# End to End Quantization Test
This notebook tests our quantization implementation in one place by making sure our operations yield good accuracy on the end task. We don't check for exact correctness of the output since quantization will introduce errors that we do not worry about unless they cause an accuracy drop.

In [54]:
%load_ext autoreload
%autoreload 2
import torch
from transformers import glue_compute_metrics
import sklearn
import math
from sklearn.metrics import f1_score
from tqdm import tqdm
import numpy as np
from transformers import RobertaForSequenceClassification, AutoTokenizer
from transformers.data.metrics import simple_accuracy
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
import torch.nn as nn
from torch.nn import CrossEntropyLoss

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Build the evaluation dataset via the project helper in src.utils.
from src.utils import roberta_mrpc_dataset
dataset = roberta_mrpc_dataset()

# NOTE(review): `encode` is not defined in any cell visible in this notebook, so this
# line only works if it is already present in the kernel namespace. It is presumably
# the tokenization function -- define or import it here so Restart & Run All succeeds.
dataset = dataset.map(encode, batched=True)
# Copy the 'label' column into a 'labels' column (the evaluation code reads batch['labels']).
dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
# Request torch-formatted output restricted to the three listed columns.
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])


Reusing dataset glue (/Users/oliver/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Some weights of the model checkpoint at textattack/roberta-base-MRPC were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Loading cached processed dataset at /Users/oliver/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/

In [64]:
def attention(layer, hidden_states, attention_mask=None):
    '''
    Readable re-implementation of the attention half of one encoder layer.

    The pretrained weights are read from `layer`; the computation mirrors what the
    layer's own attention sub-module does, spelled out step by step.

    layer: encoder layer exposing `attention.self` (query/key/value/dropout,
           num_attention_heads, attention_head_size) and `attention.output`
           (dense/dropout/LayerNorm).
    hidden_states: <bs, seqlen, dmodel>
    attention_mask: optional additive mask, added to the scaled attention scores.

    Returns the post-LayerNorm attention output, <bs, seqlen, dmodel>.
    '''
    self_attn = layer.attention.self
    out_block = layer.attention.output

    batch, tokens, width = hidden_states.size()
    heads = self_attn.num_attention_heads
    head_dim = self_attn.attention_head_size

    def split_heads(projected):
        # <bs, seqlen, dmodel> -> <bs, seqlen, num_head, dhead>
        return projected.view(batch, tokens, heads, head_dim)

    # Q/K/V projections. The query projection is written as an explicit
    # matmul + bias; key and value go through their nn.Linear modules.
    q = torch.matmul(hidden_states, self_attn.query.weight.T) + self_attn.query.bias
    k = self_attn.key(hidden_states)
    v = self_attn.value(hidden_states)

    q = split_heads(q).permute(0, 2, 1, 3)  # <bs, num_head, seqlen, dhead>
    v = split_heads(v).permute(0, 2, 1, 3)  # <bs, num_head, seqlen, dhead>
    # Key is laid out pre-transposed so that q @ k yields <bs, num_head, seqlen, seqlen>.
    k = split_heads(k).permute(0, 2, 3, 1)  # <bs, num_head, dhead, seqlen>

    # Scaled dot-product scores (scaling is done in place on the matmul result).
    scores = torch.matmul(q, k).div_(math.sqrt(head_dim))
    if attention_mask is not None:
        # Additive mask supplied by the caller.
        scores = scores + attention_mask

    probs = self_attn.dropout(nn.functional.softmax(scores, dim=-1))

    # Weighted sum of values, then merge the heads back into dmodel.
    context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
    context = context.view(batch, tokens, width)

    # Output projection, dropout, then residual connection + LayerNorm.
    projected = out_block.dropout(out_block.dense(context))
    return out_block.LayerNorm(projected + hidden_states)

In [65]:
def ffn(layer, attention_out):
    '''
    Feed-forward half of one encoder layer, written out explicitly.

    layer: encoder layer exposing `intermediate.dense` and
           `output.dense` / `output.dropout` / `output.LayerNorm`.
    attention_out: output of the attention sub-block; also used as the residual input.

    Returns LayerNorm(dropout(dense2(gelu(dense1(attention_out)))) + attention_out).
    '''
    out_block = layer.output

    # First linear layer followed by the GELU non-linearity.
    expanded = nn.functional.gelu(layer.intermediate.dense(attention_out))

    # Second linear layer, dropout, then residual connection + LayerNorm.
    projected = out_block.dropout(out_block.dense(expanded))
    return out_block.LayerNorm(projected + attention_out)

In [70]:
def encoder_stack(model, hidden_states, attention_mask):
    '''
    Run every encoder layer of `model.roberta` using the explicit `attention` and
    `ffn` helpers, then apply the pooler if the model has one.

    model: object exposing `roberta.encoder.layer` (iterable of layers) and `roberta.pooler`.
    hidden_states: input to the first encoder layer.
    attention_mask: forwarded unchanged to `attention` for every layer.

    Returns a BaseModelOutputWithPoolingAndCrossAttentions carrying only
    `last_hidden_state` and `pooler_output`; every other field is None.
    '''
    roberta = model.roberta

    state = hidden_states
    for block in roberta.encoder.layer:
        # Attention sub-block (MHA + LayerNorm) followed by the feed-forward sub-block.
        state = ffn(block, attention(block, state, attention_mask))

    pooler = roberta.pooler
    pooled = None if pooler is None else pooler(state)

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=state,
        pooler_output=pooled,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )
        
def sequence_classification(model,
                            outputs, 
                            input_ids=None,
                            attention_mask=None,
                            token_type_ids=None,
                            position_ids=None,
                            head_mask=None,
                            inputs_embeds=None,
                            labels=None,
                            output_attentions=None,
                            output_hidden_states=None,
                            return_dict=None,):
    '''
    Apply the classification head of `model` to encoder `outputs` and, when labels
    are given, compute the loss selected by `model.config.problem_type`.

    model: object exposing `classifier`, `num_labels` and `config.problem_type`.
    outputs: encoder output; `outputs[0]` is fed to `model.classifier`.
    labels: optional targets. If `model.config.problem_type` is None it is inferred
            from `model.num_labels` and `labels.dtype` and written back to the config.
    return_dict: if falsy (including the default None) a tuple is returned,
                 otherwise a SequenceClassifierOutput.
    input_ids, attention_mask, token_type_ids, position_ids, head_mask,
    inputs_embeds, output_attentions, output_hidden_states: accepted so that a whole
            batch dict can be splatted into the call; they are not used here.

    Returns `(loss, logits, *outputs[2:])` when a loss was computed,
    `(logits, *outputs[2:])` otherwise, or a SequenceClassifierOutput when
    `return_dict` is truthy.
    '''
    sequence_output = outputs[0]
    logits = model.classifier(sequence_output)

    loss = None
    if labels is not None:
        if model.config.problem_type is None:
            if model.num_labels == 1:
                model.config.problem_type = "regression"
            elif model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                model.config.problem_type = "single_label_classification"
            else:
                model.config.problem_type = "multi_label_classification"

        if model.config.problem_type == "regression":
            # nn.MSELoss: a bare `MSELoss` is not imported in this notebook.
            loss_fct = nn.MSELoss()
            # `model.num_labels`: there is no `self` inside this free function.
            if model.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif model.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, model.num_labels), labels.view(-1))
        elif model.config.problem_type == "multi_label_classification":
            # nn.BCEWithLogitsLoss: a bare `BCEWithLogitsLoss` is not imported in this notebook.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return SequenceClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def eval_model(model, dataloader, check_reference=True):
    '''
    Evaluate `model` on `dataloader` using the explicit encoder re-implementation
    (`encoder_stack` + `sequence_classification`) and report classification accuracy.

    model: sequence-classification model exposing `roberta.embeddings`,
           `roberta.get_extended_attention_mask` and the attributes the helper
           functions read. It is moved to the evaluation device and put in eval mode.
    dataloader: iterable of dict batches of tensors containing at least
                'input_ids', 'attention_mask' and 'labels'.
    check_reference: when True (default) each batch is also run through `model(**batch)`
                     and the logits are asserted to match via `torch.allclose`. Pass
                     False to skip the second forward pass, e.g. once the explicit
                     path intentionally deviates from the reference.

    Prints and returns the accuracy over all batches.
    Raises ValueError if `dataloader` yields no batches, and AssertionError if the
    reference check fails.
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Batches are moved to `device` below, so the model has to live there too;
    # otherwise a CUDA host fails with a device mismatch.
    model.to(device)
    model.eval()

    # Collect per-batch arrays and concatenate once at the end
    # (repeated np.append re-copies everything accumulated so far).
    logit_chunks = []
    label_chunks = []

    for batch in tqdm(dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            embedding_output = model.roberta.embeddings(
                input_ids=batch['input_ids'],
                position_ids=None,
                token_type_ids=None,
                inputs_embeds=None,
                past_key_values_length=0,
            )

            extended_attention_mask = model.roberta.get_extended_attention_mask(
                batch['attention_mask'], batch['input_ids'].size(), device
            )
            outputs = encoder_stack(model, embedding_output, extended_attention_mask)
            outputs = sequence_classification(model, outputs, **batch)

            # Labels are part of the batch, so the result is (loss, logits, ...).
            _, logits = outputs[:2]

            if check_reference:
                _, logits_gt = model(**batch)[:2]
                assert torch.allclose(logits_gt, logits), \
                    "explicit encoder logits differ from the reference model forward pass"

        logit_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(batch['labels'].detach().cpu().numpy())

    if not logit_chunks:
        raise ValueError("eval_model received an empty dataloader")

    preds = np.argmax(np.concatenate(logit_chunks, axis=0), axis=1)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    accuracy = simple_accuracy(preds, out_label_ids)
    print(f'accuracy: {accuracy}')
    return accuracy

In [71]:
# Run the end-to-end evaluation.
# NOTE(review): `model` and `dataloader` are not defined in any cell visible in this
# notebook; this call relies on them already existing in the kernel namespace.
# Define them in an earlier cell so the notebook survives Restart & Run All.
eval_model(model, dataloader)

  0%|                                                                                                                                                                              | 0/13 [00:00<?, ?it/s]

dict_items([('input_ids', tensor([[    0,   894,    26,  ...,     1,     1,     1],
        [    0, 43600,  1322,  ...,     1,     1,     1],
        [    0,   133,  1404,  ...,     1,     1,     1],
        ...,
        [    0,   133,  5259,  ...,     1,     1,     1],
        [    0,  3908,   209,  ...,     1,     1,     1],
        [    0, 33553,    21,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1,
        1, 1, 0, 1, 1, 1, 0, 1]))])
torch.Size([32, 512])


  8%|████████████▊                                                                                                                                                         | 1/13 [00:34<06:48, 34.02s/it]

dict_items([('input_ids', tensor([[    0, 10980,  3921,  ...,     1,     1,     1],
        [    0,  3750,   435,  ...,     1,     1,     1],
        [    0,   894,   156,  ...,     1,     1,     1],
        ...,
        [    0, 36035,     9,  ...,     1,     1,     1],
        [    0,   113,  2477,  ...,     1,     1,     1],
        [    0,  1711,    74,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1]))])
torch.Size([32, 512])


 15%|█████████████████████████▌                                                                                                                                            | 2/13 [01:09<06:21, 34.66s/it]

dict_items([('input_ids', tensor([[    0,  4763, 38805,  ...,     1,     1,     1],
        [    0,   133,  4614,  ...,     1,     1,     1],
        [    0,   347,  2723,  ...,     1,     1,     1],
        ...,
        [    0, 15248,  6909,  ...,     1,     1,     1],
        [    0, 43294,    67,  ...,     1,     1,     1],
        [    0, 42609,    42,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,
        1, 1, 0, 1, 1, 0, 1, 1]))])
torch.Size([32, 512])


 23%|██████████████████████████████████████▎                                                                                                                               | 3/13 [01:46<05:59, 35.95s/it]

dict_items([('input_ids', tensor([[    0,   133,   690,  ...,     1,     1,     1],
        [    0, 28188,    13,  ...,     1,     1,     1],
        [    0,  1106,     5,  ...,     1,     1,     1],
        ...,
        [    0,  9497,   224,  ...,     1,     1,     1],
        [    0,  4771, 11680,  ...,     1,     1,     1],
        [    0, 28012,   282,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 0, 1, 1, 1, 1, 0]))])
torch.Size([32, 512])


 31%|███████████████████████████████████████████████████                                                                                                                   | 4/13 [02:23<05:28, 36.50s/it]

dict_items([('input_ids', tensor([[    0,   133,   908,  ...,     1,     1,     1],
        [    0,  3084,   138,  ...,     1,     1,     1],
        [    0, 32703,    12,  ...,     1,     1,     1],
        ...,
        [    0,   133, 23781,  ...,     1,     1,     1],
        [    0, 30019, 20906,  ...,     1,     1,     1],
        [    0, 12582, 40787,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1,
        1, 1, 1, 1, 1, 0, 0, 1]))])
torch.Size([32, 512])


 38%|███████████████████████████████████████████████████████████████▊                                                                                                      | 5/13 [03:01<04:54, 36.84s/it]

dict_items([('input_ids', tensor([[    0, 21674,   112,  ...,     1,     1,     1],
        [    0, 44888,  1322,  ...,     1,     1,     1],
        [    0,  3762,     9,  ...,     1,     1,     1],
        ...,
        [    0,  5771,     5,  ...,     1,     1,     1],
        [    0, 42510,    56,  ...,     1,     1,     1],
        [    0,   133, 13387,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1,
        0, 1, 0, 1, 0, 1, 1, 0]))])
torch.Size([32, 512])


 46%|████████████████████████████████████████████████████████████████████████████▌                                                                                         | 6/13 [03:37<04:16, 36.60s/it]

dict_items([('input_ids', tensor([[    0, 35346,  2993,  ...,     1,     1,     1],
        [    0,  1121,   130,  ...,     1,     1,     1],
        [    0, 21674,    80,  ...,     1,     1,     1],
        ...,
        [    0,   133,   796,  ...,     1,     1,     1],
        [    0,   113,   407,  ...,     1,     1,     1],
        [    0, 37703,   135,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1,
        1, 1, 0, 1, 1, 1, 1, 0]))])
torch.Size([32, 512])


 54%|█████████████████████████████████████████████████████████████████████████████████████████▍                                                                            | 7/13 [04:12<03:35, 35.95s/it]

dict_items([('input_ids', tensor([[   0,  894,   16,  ...,    1,    1,    1],
        [   0,  133,  604,  ...,    1,    1,    1],
        [   0,  113,  440,  ...,    1,    1,    1],
        ...,
        [   0, 4993,  504,  ...,    1,    1,    1],
        [   0,  565, 3937,  ...,    1,    1,    1],
        [   0,  113,  152,  ...,    1,    1,    1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 0, 1, 0, 1, 0]))])
torch.Size([32, 512])


 62%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                               | 8/13 [04:46<02:57, 35.53s/it]

dict_items([('input_ids', tensor([[    0, 30888,    12,  ...,     1,     1,     1],
        [    0, 36310,  2156,  ...,     1,     1,     1],
        [    0,  1121,     5,  ...,     1,     1,     1],
        ...,
        [    0,   113,   166,  ...,     1,     1,     1],
        [    0, 10980,     4,  ...,     1,     1,     1],
        [    0,   133,   316,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 0, 0, 1, 1, 1]))])
torch.Size([32, 512])


 69%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                   | 9/13 [05:21<02:20, 35.16s/it]

dict_items([('input_ids', tensor([[    0,  1121,   902,  ...,     1,     1,     1],
        [    0,   100,    21,  ...,     1,     1,     1],
        [    0,   133,   194,  ...,     1,     1,     1],
        ...,
        [    0,   113,    85,  ...,     1,     1,     1],
        [    0,   717,  6712,  ...,     1,     1,     1],
        [    0, 16025,  2614,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 0, 1, 0, 1, 1, 0, 0]))])
torch.Size([32, 512])


 77%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                      | 10/13 [05:55<01:44, 34.97s/it]

dict_items([('input_ids', tensor([[    0,  9497,  3986,  ...,     1,     1,     1],
        [    0, 14229,    42,  ...,     1,     1,     1],
        [    0,   133,   806,  ...,     1,     1,     1],
        ...,
        [    0,   250,   320,  ...,     1,     1,     1],
        [    0, 11329,   219,  ...,     1,     1,     1],
        [    0, 40145,  5369,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,
        1, 1, 1, 1, 0, 1, 0, 1]))])
torch.Size([32, 512])


 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                         | 11/13 [06:29<01:09, 34.68s/it]

dict_items([('input_ids', tensor([[    0,   133, 13640,  ...,     1,     1,     1],
        [    0,   243,    67,  ...,     1,     1,     1],
        [    0,   243,  1276,  ...,     1,     1,     1],
        ...,
        [    0, 35779,    32,  ...,     1,     1,     1],
        [    0,  8800,   388,  ...,     1,     1,     1],
        [    0,   133,  4079,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,
        0, 0, 1, 1, 1, 1, 0, 1]))])
torch.Size([32, 512])


 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 12/13 [07:06<00:35, 35.37s/it]

dict_items([('input_ids', tensor([[    0,   133,  1576,  ...,     1,     1,     1],
        [    0,  1121,     5,  ...,     1,     1,     1],
        [    0, 42047, 12637,  ...,     1,     1,     1],
        ...,
        [    0,  5625,    11,  ...,     1,     1,     1],
        [    0,   133,   208,  ...,     1,     1,     1],
        [    0,  9089,    12,  ...,     1,     1,     1]])), ('attention_mask', tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])), ('labels', tensor([1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]))])
torch.Size([24, 512])


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [09:35<00:00, 44.29s/it]

accuracy: 0.9117647058823529



