## Set-up environment

As usual, we first install the required libraries: PyTorch Lightning, HuggingFace Transformers, and Datasets.

In [None]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import os
from pytorch_lightning.loggers import TensorBoardLogger
%load_ext autoreload
%autoreload 2

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Nov 17 06:06:05 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0    46W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
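
Outside a notebook, the `!nvidia-smi` shell magic is not available. The same check can be sketched with only the standard library; the helper name `gpu_summary` is hypothetical and not part of the original notebook.

```python
import shutil
import subprocess


def gpu_summary():
    """Return the text printed by nvidia-smi, or None if no GPU driver is reachable."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # nvidia-smi is not installed or not on PATH
        return None
    try:
        result = subprocess.run([exe], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError):
        # the tool could not be started or did not finish in time
        return None
    if result.returncode != 0:
        # the tool exists but could not talk to a driver
        return None
    return result.stdout


summary = gpu_summary()
print(summary if summary is not None else "Not connected to a GPU")
```

The function returns `None` rather than raising, so the calling code can fall back to the CPU without extra error handling.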

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done


In [None]:
!pip install -q datasets

## Prepare data

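Here we load the IMDB dataset, a binary text classification dataset ("is a movie review positive or negative?"). We keep the full training split (25,000 reviews) and divide the 25,000-review test split into two halves of 12,500 reviews each: one used as the test set and one used as the validation set.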

In [None]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=['train', 'test[:6250]+test[-6250:]'])
train_ds, val_ds = load_dataset("imdb", split=['train', 'test[6250:12500]+test[-12500:-6250]'])

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
print(test_ds)
print(val_ds)

Dataset({
    features: ['text', 'label'],
    num_rows: 12500
})
Dataset({
    features: ['text', 'label'],
    num_rows: 12500
})
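
The two slice expressions passed to load_dataset can be checked with plain Python indices. This is only a sketch: it assumes the split strings behave like ordinary Python slices, and it assumes the test split is ordered with all 'neg' reviews first and all 'pos' reviews second (the helper `assumed_label` encodes that assumption and is not part of the notebook).

```python
N_TEST = 25000  # size of the IMDB test split, as reported in the download log

indices = list(range(N_TEST))

# Same slice boundaries as in the two load_dataset calls
test_idx = indices[:6250] + indices[-6250:]
val_idx = indices[6250:12500] + indices[-12500:-6250]


def assumed_label(i):
    """Assumption: the first half of the split is 'neg' (0), the second half 'pos' (1)."""
    return 0 if i < N_TEST // 2 else 1


test_pos = sum(assumed_label(i) for i in test_idx)
val_pos = sum(assumed_label(i) for i in val_idx)

print(len(test_idx), len(val_idx))        # sizes of the two subsets
print(set(test_idx).isdisjoint(val_idx))  # True if no review is in both subsets
print(test_pos, val_pos)                  # positives per subset under the ordering assumption
```

Under that ordering assumption, both subsets are disjoint, together cover the whole test split, and are balanced between the two classes.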


We create id2label and label2id mappings, which are handy at inference time.

In [None]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [None]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}
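
As a small illustration of why these mappings are handy at inference time, the sketch below turns per-class scores into label names. The logits are made-up numbers, not model output, and `predict_labels` is a hypothetical helper.

```python
labels = ['neg', 'pos']
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

# Hypothetical logits for three reviews: one row per review, one score per class
example_logits = [[2.1, -0.3], [-1.0, 0.7], [0.2, 0.9]]


def predict_labels(logits_batch, id2label):
    """Return the label name of the highest-scoring class for each row."""
    predictions = []
    for row in logits_batch:
        best_idx = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        predictions.append(id2label[best_idx])
    return predictions


predicted = predict_labels(example_logits, id2label)
print(predicted)
```

The reverse mapping, label2id, is useful in the other direction, for example when converting string labels from another source into the integer targets the loss function expects.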


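Next, we prepare the data for the model using the tokenizer. The Perceiver tokenizer works directly on UTF-8 bytes, and with padding="max_length" and truncation=True every review is padded or truncated to 2048 tokens.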

In [None]:
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
val_ds = val_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)

Downloading:   0%|          | 0.00/668 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/879 [00:00<?, ?B/s]

Using unk_token, but it is not set yet.


  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]
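
To make the byte-level tokenization concrete, here is a minimal pure-Python sketch. It is not the PerceiverTokenizer implementation; the offset of 6 and the [CLS]/[SEP] ids of 4 and 5 are assumptions inferred from the ids this notebook prints for "hello world" in its forward-pass cell and from the vocabulary size of 262 (256 byte values plus 6 reserved ids).

```python
SPECIAL_OFFSET = 6      # assumption: ids 0-5 are reserved for special tokens
CLS_ID, SEP_ID = 4, 5   # assumption: inferred from the ids printed in this notebook


def encode_bytes(text):
    """Map each UTF-8 byte to an id, wrapped in [CLS] ... [SEP]."""
    return [CLS_ID] + [b + SPECIAL_OFFSET for b in text.encode("utf-8")] + [SEP_ID]


def decode_bytes(ids):
    """Drop special ids and turn the remaining ids back into text."""
    raw = bytes(i - SPECIAL_OFFSET for i in ids if i >= SPECIAL_OFFSET)
    return raw.decode("utf-8", errors="replace")


ids = encode_bytes("hello world")
print(ids)
print(decode_bytes(ids))
```

Because every byte becomes one token, reviews are much longer in tokens than they would be with a subword tokenizer, which is why the 2048-token limit matters here.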

We set the format to PyTorch tensors, and create familiar PyTorch dataloaders.

In [None]:
train_ds.set_format(type="torch", columns=['input_ids', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'label'])
val_ds.set_format(type="torch", columns=['input_ids', 'label'])

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=100, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=25)
val_dataloader = DataLoader(val_ds, batch_size=25)
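
The batch sizes above determine how many steps each pass takes. A quick sketch of the arithmetic, assuming the default drop_last=False (the helper `num_batches` is hypothetical):

```python
import math


def num_batches(num_examples, batch_size):
    """Number of batches per pass when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


train_batches = num_batches(25000, 100)  # training set, batch_size=100
eval_batches = num_batches(12500, 25)    # validation or test set, batch_size=25
print(train_batches, eval_batches)
```

These counts should match the totals shown by the tqdm progress bars during training and validation.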

Here we verify a few things: the tensor shapes of one batch and the decoded text of one example (it is always important to inspect your data!).

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

label torch.Size([100])
input_ids torch.Size([100, 2048])


In [None]:
tokenizer.decode(batch['input_ids'][3])

"[CLS]I never expect a film adaptation to follow too closely to the novel (especially a beloved one, like Evening) but when I saw that the book's author, Susan Minot, was a screenplay writer and executive producer on the film, I thought that Evening would be a good adaptation.<br /><br />If you enjoyed the book, don't bother with this movie. It is so far afield of the book that the two hardly bear any resemblance to one another.<br /><br />Here, our characters are completely different: the bride is in love with Harris. Harris is the son of the housekeeper. Buddy is a drunk, in love with Ann and/or Harris. I don't think a single character made it from the book to the screen; oh it just gets worst with every passing moment.<br /><br />And, really, didn't we learn from Bridges of Madison County that cutting from the story we are meant to be enthralled in, to scenes of our heroes' grown children having obnoxious and juvenile fights, simply does not work on film? This film is a disaster. Sk

In [None]:
# batch['label']

In [None]:
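import numpy as np
# fraction of positive labels; this value is not displayed, since it is neither printed nor the last expression in the cell
train_ds['label'].double().mean()
# label of the training example at index 12499
print(train_ds['label'][12499])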

tensor(0)


## Define model

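Next, we define our model from a custom PerceiverConfig (randomly initialized, not a pretrained checkpoint) and put it on the GPU if one is available.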

In [None]:
# preprocessor we customized to use the tagkop encoder
from tagkop_encoding_functions import (
    PerceiverImagePreprocessor,
    TagkopPerceiverTextPreprocessor,
)
from transformers import PerceiverForSequenceClassification

import torch

from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder,
    PerceiverTextPreprocessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = PerceiverConfig(
    num_self_attends_per_block = 4,
    d_model = 64
)

print('config', config)

# Vanilla Perceiver Encodings


preprocessor = PerceiverTextPreprocessor(config)

# Our new awesome encodings
# preprocessor = TagkopPerceiverTextPreprocessor(config)


decoder = PerceiverClassificationDecoder(config,
                                          num_channels=config.d_latents,
                                          trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
                                          use_query_residual=True,
                                         )

# TODO: change num_self_attends_per_block, num_self_attention_heads, and num_cross_attention_heads
# to something more reasonable, and set out_channels, project_pos_dim, and num_channels to 64
model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)



model.to(device)

config PerceiverConfig {
  "attention_probs_dropout_prob": 0.1,
  "audio_samples_per_frame": 1920,
  "cross_attention_shape_for_attention": "kv",
  "cross_attention_widening_factor": 1,
  "d_latents": 1280,
  "d_model": 64,
  "hidden_act": "gelu",
  "image_size": 56,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 2048,
  "model_type": "perceiver",
  "num_blocks": 1,
  "num_cross_attention_heads": 8,
  "num_frames": 16,
  "num_latents": 256,
  "num_self_attends_per_block": 4,
  "num_self_attention_heads": 8,
  "output_shape": [
    1,
    16,
    224,
    224
  ],
  "qk_channels": null,
  "samples_per_patch": 16,
  "self_attention_widening_factor": 1,
  "train_size": [
    368,
    496
  ],
  "transformers_version": "4.25.0.dev0",
  "use_query_residual": true,
  "v_channels": null,
  "vocab_size": 262
}



PerceiverModel(
  (input_preprocessor): PerceiverTextPreprocessor(
    (embeddings): Embedding(262, 64)
    (position_embeddings): Embedding(2048, 64)
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense): Linear(in_features=64, out_features=1280, bias=True)
        )
      )
      (layernorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (mlp): PerceiverMLP(
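
The printed configuration hints at why the Perceiver can read 2048 bytes at once: the inputs are only touched by one cross-attention from num_latents = 256 latent vectors, and self-attention then runs among those latents. A back-of-the-envelope sketch of the number of attention scores per head and per example follows; it uses the values from the config above and ignores batch size, heads, and channel widths.

```python
seq_len = 2048      # max_position_embeddings in the printed config
num_latents = 256   # num_latents in the printed config

# latents (queries) attend to every input position (keys)
cross_attention_scores = num_latents * seq_len

# latents attend to each other in every self-attention layer
latent_self_attention_scores = num_latents * num_latents

# for comparison: self-attention directly over all input positions
full_self_attention_scores = seq_len * seq_len

print(cross_attention_scores)
print(latent_self_attention_scores)
print(full_self_attention_scores // cross_attention_scores)
```

The comparison is only meant as intuition for the scaling; actual memory use also depends on d_latents, the number of heads, and the batch size.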
 

In [None]:
# you can then do a forward pass as follows:
tokenizer = PerceiverTokenizer()
text = "hello world"
inputs = tokenizer(text, return_tensors="pt").input_ids
print(inputs)
# Tensor.to() returns a new tensor instead of moving it in place, so keep the result
inputs = inputs.to(device)
with torch.no_grad():
   outputs = model(inputs=inputs)
logits = outputs.logits
print('list(logits.shape): ', list(logits.shape))
# to train, one can use standard cross-entropy:
criterion = torch.nn.CrossEntropyLoss()
labels = torch.tensor([1]).to(device)
loss = criterion(logits, labels)

tensor([[  4, 110, 107, 114, 114, 117,  38, 125, 117, 120, 114, 106,   5]])
list(logits.shape):  [1, 2]
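
For intuition about the cross-entropy values that appear in the training log, here is a minimal pure-Python sketch of the loss for a single example (torch.nn.CrossEntropyLoss averages this quantity over the batch by default). The logits below are made-up numbers, and `cross_entropy` is a hypothetical helper.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example: negative log-softmax at the target index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# A model with no preference between two classes scores both equally
chance_loss = cross_entropy([0.0, 0.0], target=1)

# A model that is confident and correct gets a much smaller loss
confident_loss = cross_entropy([-2.0, 2.0], target=1)

print(round(chance_loss, 4))
print(confident_loss < chance_loss)
```

With two classes, a model that cannot tell them apart has a loss of ln 2, about 0.693, which is close to the losses logged for the first training batches.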


In [None]:
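# map_location lets the checkpoint load even if it was saved on a different device
model.load_state_dict(torch.load('/content/drive/MyDrive/saved_model/small_model_Nov_16_epoch_8.pt', map_location=device))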
model.to(device)

PerceiverModel(
  (input_preprocessor): PerceiverTextPreprocessor(
    (embeddings): Embedding(262, 64)
    (position_embeddings): Embedding(2048, 64)
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense): Linear(in_features=64, out_features=1280, bias=True)
        )
      )
      (layernorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (mlp): PerceiverMLP(
 

## Train the model

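Here we train the model using native PyTorch. Each epoch saves a checkpoint, runs one pass over the training set, and then measures accuracy on the validation set.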

In [None]:
from transformers import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from datasets import load_metric

optimizer = AdamW(model.parameters(), lr=1e-4)

# standard cross-entropy loss, created once and reused for every batch
criterion = torch.nn.CrossEntropyLoss()

# validation accuracy per epoch (initialized once, outside the epoch loop)
epoch_accs = np.array([])

for epoch in range(50):  # loop over the dataset multiple times

    # checkpoint the weights as they are at the start of this epoch
    torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/rerun_small_model_Nov_16_epoch_' + str(epoch) + '.pt')
    print('saved model')
    print("Epoch:", epoch)

    # re-enable training mode; the validation pass below switches to eval mode
    model.train()
    batch_accs = np.array([])  # training accuracy of every batch in this epoch
    for batch in tqdm(train_dataloader):
        # get the inputs (no attention_mask: set_format kept only input_ids and label)
        inputs = batch["input_ids"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # evaluate on the current training batch
        predictions = logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
        batch_accs = np.append(batch_accs, accuracy)

    print('Batch Accuracy Avg: ', np.mean(batch_accs))

    accuracy = load_metric("accuracy")

    # validation pass: eval mode disables dropout, no_grad skips gradient tracking
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            # get the inputs
            inputs = batch["input_ids"].to(device)

            # forward pass
            outputs = model(inputs=inputs)
            predictions = outputs.logits.argmax(-1).cpu().numpy()
            references = batch["label"].numpy()
            accuracy.add_batch(predictions=predictions, references=references)

    final_score = accuracy.compute()
    print("Accuracy on validation set:", final_score)
    # store the number itself rather than the {'accuracy': ...} dict
    epoch_accs = np.append(epoch_accs, final_score["accuracy"])
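
In the log produced by this loop, batch accuracy on the training set keeps rising while validation accuracy stays roughly between 0.61 and 0.64. The notebook imports PyTorch Lightning's EarlyStopping callback, but the native loop does not use it. The sketch below applies a simple patience rule in plain Python to the validation accuracies copied from the log for epochs 0 to 15; the function `early_stop_epoch` and the patience of 5 are illustrative assumptions, not Lightning's implementation.

```python
# Validation accuracies for epochs 0-15, copied from the training log
val_accs = [0.62624, 0.61136, 0.63816, 0.63096, 0.63832, 0.62456, 0.62768, 0.62544,
            0.6196, 0.62064, 0.6196, 0.61832, 0.60544, 0.61752, 0.61528, 0.61032]


def early_stop_epoch(scores, patience):
    """Return (best_epoch, stop_epoch); stop_epoch is None if patience never runs out."""
    best_epoch, best_score, waited = 0, float("-inf"), 0
    for epoch, score in enumerate(scores):
        if score > best_score:
            # new best: remember it and reset the patience counter
            best_epoch, best_score, waited = epoch, score, 0
        else:
            waited += 1
            if waited >= patience:
                return best_epoch, epoch
    return best_epoch, None


best_epoch, stop_epoch = early_stop_epoch(val_accs, patience=5)
print(best_epoch, stop_epoch)
```

With a rule like this, training would stop after five epochs without improvement, and the checkpoint to keep would be the one saved after the best epoch rather than the last one.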



saved model
Epoch: 0


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.693368136882782, Accuracy: 0.44
Loss: 0.7168745398521423, Accuracy: 0.51
Loss: 0.7018337249755859, Accuracy: 0.44
Loss: 0.7624335289001465, Accuracy: 0.39
Loss: 0.6923819184303284, Accuracy: 0.52
Loss: 0.6946253180503845, Accuracy: 0.5
Loss: 0.7050689458847046, Accuracy: 0.48
Loss: 0.687935471534729, Accuracy: 0.58
Loss: 0.6944180130958557, Accuracy: 0.49
Loss: 0.6934633851051331, Accuracy: 0.49
Loss: 0.69244384765625, Accuracy: 0.49
Loss: 0.6969552636146545, Accuracy: 0.42
Loss: 0.6833515763282776, Accuracy: 0.58
Loss: 0.7242802977561951, Accuracy: 0.49
Loss: 0.7121180891990662, Accuracy: 0.5
Loss: 0.6814194321632385, Accuracy: 0.59
Loss: 0.6940053701400757, Accuracy: 0.47
Loss: 0.6947258710861206, Accuracy: 0.49
Loss: 0.703264057636261, Accuracy: 0.47
Loss: 0.6883248686790466, Accuracy: 0.55
Loss: 0.6932123303413391, Accuracy: 0.5
Loss: 0.6910974979400635, Accuracy: 0.54
Loss: 0.6915602087974548, Accuracy: 0.51
Loss: 0.6999536752700806, Accuracy: 0.41
Loss: 0.6959879398345947



Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62624}
saved model
Epoch: 1


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.62185138463974, Accuracy: 0.7
Loss: 0.643634557723999, Accuracy: 0.64
Loss: 0.6244173049926758, Accuracy: 0.67
Loss: 0.6374948024749756, Accuracy: 0.66
Loss: 0.6095536351203918, Accuracy: 0.63
Loss: 0.6516433954238892, Accuracy: 0.66
Loss: 0.8110111355781555, Accuracy: 0.6
Loss: 0.6831523180007935, Accuracy: 0.59
Loss: 0.5747432112693787, Accuracy: 0.73
Loss: 0.6374267339706421, Accuracy: 0.62
Loss: 0.6599173545837402, Accuracy: 0.61
Loss: 0.6411436200141907, Accuracy: 0.65
Loss: 0.6616337299346924, Accuracy: 0.64
Loss: 0.6496026515960693, Accuracy: 0.63
Loss: 0.6294740438461304, Accuracy: 0.65
Loss: 0.6090351343154907, Accuracy: 0.7
Loss: 0.6553575396537781, Accuracy: 0.61
Loss: 0.6154111623764038, Accuracy: 0.65
Loss: 0.6405566334724426, Accuracy: 0.64
Loss: 0.6442317962646484, Accuracy: 0.66
Loss: 0.5971122980117798, Accuracy: 0.69
Loss: 0.6693634986877441, Accuracy: 0.55
Loss: 0.6417197585105896, Accuracy: 0.65
Loss: 0.6962986588478088, Accuracy: 0.55
Loss: 0.60108649730682

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61136}
saved model
Epoch: 2


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.618423581123352, Accuracy: 0.69
Loss: 0.6432235240936279, Accuracy: 0.68
Loss: 0.6087047457695007, Accuracy: 0.66
Loss: 0.6184061169624329, Accuracy: 0.67
Loss: 0.6981399059295654, Accuracy: 0.59
Loss: 0.6385088562965393, Accuracy: 0.63
Loss: 0.6183379292488098, Accuracy: 0.65
Loss: 0.542667806148529, Accuracy: 0.75
Loss: 0.6330515146255493, Accuracy: 0.65
Loss: 0.6864461302757263, Accuracy: 0.63
Loss: 0.5906670093536377, Accuracy: 0.71
Loss: 0.6512994170188904, Accuracy: 0.64
Loss: 0.5889741778373718, Accuracy: 0.67
Loss: 0.5885836482048035, Accuracy: 0.73
Loss: 0.6780649423599243, Accuracy: 0.63
Loss: 0.6033157706260681, Accuracy: 0.66
Loss: 0.6070911288261414, Accuracy: 0.7
Loss: 0.5798065066337585, Accuracy: 0.68
Loss: 0.6059763431549072, Accuracy: 0.67
Loss: 0.6297754049301147, Accuracy: 0.66
Loss: 0.5455560684204102, Accuracy: 0.78
Loss: 0.6410852670669556, Accuracy: 0.63
Loss: 0.6127059459686279, Accuracy: 0.67
Loss: 0.6022117733955383, Accuracy: 0.74
Loss: 0.63760566711

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63816}
saved model
Epoch: 3


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6100651025772095, Accuracy: 0.66
Loss: 0.6429837942123413, Accuracy: 0.61
Loss: 0.6156074404716492, Accuracy: 0.6
Loss: 0.6552769541740417, Accuracy: 0.64
Loss: 0.5786274075508118, Accuracy: 0.73
Loss: 0.5688806772232056, Accuracy: 0.76
Loss: 0.5767425298690796, Accuracy: 0.71
Loss: 0.6244853734970093, Accuracy: 0.67
Loss: 0.5805818438529968, Accuracy: 0.71
Loss: 0.6278192400932312, Accuracy: 0.68
Loss: 0.5823531746864319, Accuracy: 0.74
Loss: 0.5837095379829407, Accuracy: 0.7
Loss: 0.568777859210968, Accuracy: 0.7
Loss: 0.6303024888038635, Accuracy: 0.67
Loss: 0.7044219970703125, Accuracy: 0.62
Loss: 0.5598028302192688, Accuracy: 0.74
Loss: 0.6845500469207764, Accuracy: 0.59
Loss: 0.5752893686294556, Accuracy: 0.71
Loss: 0.6367887258529663, Accuracy: 0.62
Loss: 0.6748003959655762, Accuracy: 0.61
Loss: 0.5437723398208618, Accuracy: 0.75
Loss: 0.6352124214172363, Accuracy: 0.61
Loss: 0.6042026877403259, Accuracy: 0.64
Loss: 0.5991851687431335, Accuracy: 0.7
Loss: 0.5996627807617

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63096}
saved model
Epoch: 4


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6480636596679688, Accuracy: 0.68
Loss: 0.5410439968109131, Accuracy: 0.72
Loss: 0.5749983787536621, Accuracy: 0.71
Loss: 0.5383532047271729, Accuracy: 0.78
Loss: 0.5376818180084229, Accuracy: 0.71
Loss: 0.647513210773468, Accuracy: 0.62
Loss: 0.5889079570770264, Accuracy: 0.66
Loss: 0.5898928642272949, Accuracy: 0.73
Loss: 0.6336202025413513, Accuracy: 0.62
Loss: 0.5483720898628235, Accuracy: 0.75
Loss: 0.5501085519790649, Accuracy: 0.72
Loss: 0.5381302833557129, Accuracy: 0.76
Loss: 0.5033523440361023, Accuracy: 0.76
Loss: 0.608782172203064, Accuracy: 0.69
Loss: 0.5744411945343018, Accuracy: 0.75
Loss: 0.6373198628425598, Accuracy: 0.61
Loss: 0.5736528635025024, Accuracy: 0.74
Loss: 0.5653077960014343, Accuracy: 0.73
Loss: 0.6792781352996826, Accuracy: 0.59
Loss: 0.6138243675231934, Accuracy: 0.68
Loss: 0.5822780728340149, Accuracy: 0.72
Loss: 0.5889145135879517, Accuracy: 0.7
Loss: 0.6039156913757324, Accuracy: 0.67
Loss: 0.633599579334259, Accuracy: 0.67
Loss: 0.659981489181

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63832}
saved model
Epoch: 5


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5292834043502808, Accuracy: 0.78
Loss: 0.5675082206726074, Accuracy: 0.72
Loss: 0.5504863262176514, Accuracy: 0.71
Loss: 0.5429918766021729, Accuracy: 0.73
Loss: 0.48285019397735596, Accuracy: 0.79
Loss: 0.6105239391326904, Accuracy: 0.73
Loss: 0.6185600161552429, Accuracy: 0.66
Loss: 0.47951018810272217, Accuracy: 0.78
Loss: 0.6214414238929749, Accuracy: 0.64
Loss: 0.5851924419403076, Accuracy: 0.71
Loss: 0.6528581976890564, Accuracy: 0.67
Loss: 0.705859899520874, Accuracy: 0.61
Loss: 0.6123847961425781, Accuracy: 0.65
Loss: 0.5524715185165405, Accuracy: 0.69
Loss: 0.5700356960296631, Accuracy: 0.73
Loss: 0.6005215644836426, Accuracy: 0.66
Loss: 0.6202996969223022, Accuracy: 0.74
Loss: 0.6004704833030701, Accuracy: 0.65
Loss: 0.47300484776496887, Accuracy: 0.84
Loss: 0.5851019620895386, Accuracy: 0.65
Loss: 0.5555510520935059, Accuracy: 0.69
Loss: 0.5382039546966553, Accuracy: 0.73
Loss: 0.5486600995063782, Accuracy: 0.76
Loss: 0.48107925057411194, Accuracy: 0.77
Loss: 0.72511

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62456}
saved model
Epoch: 6


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6241729259490967, Accuracy: 0.66
Loss: 0.4860441982746124, Accuracy: 0.78
Loss: 0.4260340929031372, Accuracy: 0.84
Loss: 0.5415945053100586, Accuracy: 0.74
Loss: 0.5149493217468262, Accuracy: 0.78
Loss: 0.6012431979179382, Accuracy: 0.72
Loss: 0.5390426516532898, Accuracy: 0.73
Loss: 0.5465705990791321, Accuracy: 0.72
Loss: 0.41675159335136414, Accuracy: 0.86
Loss: 0.4851936995983124, Accuracy: 0.8
Loss: 0.5100862383842468, Accuracy: 0.8
Loss: 0.49805691838264465, Accuracy: 0.78
Loss: 0.4853283762931824, Accuracy: 0.79
Loss: 0.5580868721008301, Accuracy: 0.74
Loss: 0.4883686900138855, Accuracy: 0.8
Loss: 0.4543458819389343, Accuracy: 0.8
Loss: 0.5759958624839783, Accuracy: 0.69
Loss: 0.5397241711616516, Accuracy: 0.76
Loss: 0.5211876034736633, Accuracy: 0.77
Loss: 0.42720481753349304, Accuracy: 0.83
Loss: 0.5602141618728638, Accuracy: 0.74
Loss: 0.6991088390350342, Accuracy: 0.65
Loss: 0.5452858805656433, Accuracy: 0.74
Loss: 0.5219681262969971, Accuracy: 0.76
Loss: 0.517638087

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62768}
saved model
Epoch: 7


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5676006078720093, Accuracy: 0.72
Loss: 0.5658920407295227, Accuracy: 0.73
Loss: 0.4776662588119507, Accuracy: 0.78
Loss: 0.548894464969635, Accuracy: 0.77
Loss: 0.39671969413757324, Accuracy: 0.88
Loss: 0.5036609768867493, Accuracy: 0.77
Loss: 0.45910879969596863, Accuracy: 0.81
Loss: 0.5291429758071899, Accuracy: 0.73
Loss: 0.5547239780426025, Accuracy: 0.72
Loss: 0.6960685849189758, Accuracy: 0.67
Loss: 0.5549798011779785, Accuracy: 0.72
Loss: 0.4614452123641968, Accuracy: 0.74
Loss: 0.5301339626312256, Accuracy: 0.76
Loss: 0.45747867226600647, Accuracy: 0.83
Loss: 0.5432084798812866, Accuracy: 0.77
Loss: 0.45549020171165466, Accuracy: 0.83
Loss: 0.6142105460166931, Accuracy: 0.7
Loss: 0.531283438205719, Accuracy: 0.72
Loss: 0.5650184154510498, Accuracy: 0.74
Loss: 0.6350781917572021, Accuracy: 0.66
Loss: 0.5135440230369568, Accuracy: 0.78
Loss: 0.6003095507621765, Accuracy: 0.66
Loss: 0.49980977177619934, Accuracy: 0.76
Loss: 0.5066453814506531, Accuracy: 0.8
Loss: 0.5447557

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62544}
saved model
Epoch: 8


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.49535679817199707, Accuracy: 0.75
Loss: 0.4993487298488617, Accuracy: 0.75
Loss: 0.493366003036499, Accuracy: 0.76
Loss: 0.45093125104904175, Accuracy: 0.82
Loss: 0.49502238631248474, Accuracy: 0.75
Loss: 0.5469013452529907, Accuracy: 0.75
Loss: 0.4574143886566162, Accuracy: 0.76
Loss: 0.4880107045173645, Accuracy: 0.82
Loss: 0.5861297845840454, Accuracy: 0.76
Loss: 0.44738471508026123, Accuracy: 0.8
Loss: 0.5987837314605713, Accuracy: 0.69
Loss: 0.4419490098953247, Accuracy: 0.8
Loss: 0.6031410694122314, Accuracy: 0.8
Loss: 0.5749553442001343, Accuracy: 0.7
Loss: 0.4992290437221527, Accuracy: 0.8
Loss: 0.5451517701148987, Accuracy: 0.73
Loss: 0.48848530650138855, Accuracy: 0.77
Loss: 0.5080888867378235, Accuracy: 0.76
Loss: 0.5084244012832642, Accuracy: 0.77
Loss: 0.5371139049530029, Accuracy: 0.72
Loss: 0.4750209450721741, Accuracy: 0.77
Loss: 0.581333577632904, Accuracy: 0.71
Loss: 0.5083690881729126, Accuracy: 0.78
Loss: 0.4962942898273468, Accuracy: 0.77
Loss: 0.5483811497

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6196}
saved model
Epoch: 9


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.45927751064300537, Accuracy: 0.8
Loss: 0.5387266874313354, Accuracy: 0.75
Loss: 0.4731772541999817, Accuracy: 0.8
Loss: 0.5069727897644043, Accuracy: 0.79
Loss: 0.42448660731315613, Accuracy: 0.82
Loss: 0.4800146222114563, Accuracy: 0.81
Loss: 0.6003707647323608, Accuracy: 0.71
Loss: 0.47113505005836487, Accuracy: 0.79
Loss: 0.48230302333831787, Accuracy: 0.78
Loss: 0.37082844972610474, Accuracy: 0.84
Loss: 0.5099362730979919, Accuracy: 0.75
Loss: 0.5995870232582092, Accuracy: 0.7
Loss: 0.500679075717926, Accuracy: 0.79
Loss: 0.4729091227054596, Accuracy: 0.76
Loss: 0.49909600615501404, Accuracy: 0.76
Loss: 0.34143516421318054, Accuracy: 0.88
Loss: 0.5485038757324219, Accuracy: 0.7
Loss: 0.5605069994926453, Accuracy: 0.74
Loss: 0.5510785579681396, Accuracy: 0.77
Loss: 0.4790676534175873, Accuracy: 0.78
Loss: 0.5195727348327637, Accuracy: 0.75
Loss: 0.47792717814445496, Accuracy: 0.79
Loss: 0.49921849370002747, Accuracy: 0.78
Loss: 0.5472725033760071, Accuracy: 0.73
Loss: 0.4234

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62064}
saved model
Epoch: 10


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.43403172492980957, Accuracy: 0.82
Loss: 0.49487510323524475, Accuracy: 0.77
Loss: 0.44377613067626953, Accuracy: 0.82
Loss: 0.4918035566806793, Accuracy: 0.8
Loss: 0.4560002088546753, Accuracy: 0.81
Loss: 0.5188539028167725, Accuracy: 0.8
Loss: 0.4118436574935913, Accuracy: 0.83
Loss: 0.5451862812042236, Accuracy: 0.75
Loss: 0.5157116651535034, Accuracy: 0.77
Loss: 0.3662593364715576, Accuracy: 0.85
Loss: 0.4841400980949402, Accuracy: 0.79
Loss: 0.5138429403305054, Accuracy: 0.77
Loss: 0.43496963381767273, Accuracy: 0.81
Loss: 0.3943134546279907, Accuracy: 0.85
Loss: 0.41919979453086853, Accuracy: 0.85
Loss: 0.49401265382766724, Accuracy: 0.8
Loss: 0.44234219193458557, Accuracy: 0.83
Loss: 0.3931306004524231, Accuracy: 0.87
Loss: 0.44230708479881287, Accuracy: 0.82
Loss: 0.4721396267414093, Accuracy: 0.79
Loss: 0.4912811517715454, Accuracy: 0.8
Loss: 0.4956333041191101, Accuracy: 0.84
Loss: 0.5082308053970337, Accuracy: 0.8
Loss: 0.4901511073112488, Accuracy: 0.79
Loss: 0.36265

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6196}
saved model
Epoch: 11


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.47881171107292175, Accuracy: 0.81
Loss: 0.5161840319633484, Accuracy: 0.74
Loss: 0.4542214572429657, Accuracy: 0.81
Loss: 0.5960199236869812, Accuracy: 0.75
Loss: 0.41545745730400085, Accuracy: 0.84
Loss: 0.4419266879558563, Accuracy: 0.79
Loss: 0.3816995322704315, Accuracy: 0.87
Loss: 0.3886268138885498, Accuracy: 0.84
Loss: 0.39738982915878296, Accuracy: 0.87
Loss: 0.542382001876831, Accuracy: 0.74
Loss: 0.42605912685394287, Accuracy: 0.83
Loss: 0.5229822397232056, Accuracy: 0.77
Loss: 0.40765249729156494, Accuracy: 0.82
Loss: 0.42975670099258423, Accuracy: 0.86
Loss: 0.42226409912109375, Accuracy: 0.82
Loss: 0.5072095394134521, Accuracy: 0.78
Loss: 0.30939993262290955, Accuracy: 0.9
Loss: 0.5348346829414368, Accuracy: 0.78
Loss: 0.2860414981842041, Accuracy: 0.92
Loss: 0.4449138343334198, Accuracy: 0.82
Loss: 0.32108452916145325, Accuracy: 0.89
Loss: 0.5023000240325928, Accuracy: 0.76
Loss: 0.43429237604141235, Accuracy: 0.83
Loss: 0.34996941685676575, Accuracy: 0.86
Loss: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61832}
saved model
Epoch: 12


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.4915142357349396, Accuracy: 0.77
Loss: 0.40222081542015076, Accuracy: 0.81
Loss: 0.4220956563949585, Accuracy: 0.84
Loss: 0.4437098801136017, Accuracy: 0.83
Loss: 0.49125412106513977, Accuracy: 0.78
Loss: 0.44419074058532715, Accuracy: 0.79
Loss: 0.4780412018299103, Accuracy: 0.81
Loss: 0.4124375581741333, Accuracy: 0.83
Loss: 0.4459555745124817, Accuracy: 0.81
Loss: 0.3773423433303833, Accuracy: 0.85
Loss: 0.4382587671279907, Accuracy: 0.85
Loss: 0.3849336504936218, Accuracy: 0.87
Loss: 0.3760392367839813, Accuracy: 0.86
Loss: 0.45546501874923706, Accuracy: 0.82
Loss: 0.38301289081573486, Accuracy: 0.83
Loss: 0.42031043767929077, Accuracy: 0.8
Loss: 0.48174381256103516, Accuracy: 0.79
Loss: 0.3910718560218811, Accuracy: 0.86
Loss: 0.4878092110157013, Accuracy: 0.79
Loss: 0.38693028688430786, Accuracy: 0.88
Loss: 0.5603100061416626, Accuracy: 0.76
Loss: 0.3100435733795166, Accuracy: 0.87
Loss: 0.41653233766555786, Accuracy: 0.82
Loss: 0.37238746881484985, Accuracy: 0.83
Loss: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60544}
saved model
Epoch: 13


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.3967166244983673, Accuracy: 0.83
Loss: 0.4159645736217499, Accuracy: 0.81
Loss: 0.43586266040802, Accuracy: 0.84
Loss: 0.42786797881126404, Accuracy: 0.82
Loss: 0.36661162972450256, Accuracy: 0.87
Loss: 0.4719901382923126, Accuracy: 0.77
Loss: 0.38883116841316223, Accuracy: 0.83
Loss: 0.4211593568325043, Accuracy: 0.8
Loss: 0.44791823625564575, Accuracy: 0.82
Loss: 0.392427533864975, Accuracy: 0.85
Loss: 0.3759743869304657, Accuracy: 0.84
Loss: 0.35732781887054443, Accuracy: 0.85
Loss: 0.3861161768436432, Accuracy: 0.88
Loss: 0.41943082213401794, Accuracy: 0.83
Loss: 0.35976940393447876, Accuracy: 0.86
Loss: 0.345342755317688, Accuracy: 0.85
Loss: 0.2984813451766968, Accuracy: 0.88
Loss: 0.4520217776298523, Accuracy: 0.81
Loss: 0.615769624710083, Accuracy: 0.72
Loss: 0.5668064951896667, Accuracy: 0.77
Loss: 0.4113619327545166, Accuracy: 0.83
Loss: 0.4039527475833893, Accuracy: 0.85
Loss: 0.3748815655708313, Accuracy: 0.86
Loss: 0.39132604002952576, Accuracy: 0.84
Loss: 0.416294

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61752}
saved model
Epoch: 14


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.40161705017089844, Accuracy: 0.84
Loss: 0.29820716381073, Accuracy: 0.92
Loss: 0.3358488082885742, Accuracy: 0.87
Loss: 0.32162415981292725, Accuracy: 0.88
Loss: 0.36810067296028137, Accuracy: 0.87
Loss: 0.45045143365859985, Accuracy: 0.83
Loss: 0.3654628396034241, Accuracy: 0.85
Loss: 0.41925713419914246, Accuracy: 0.85
Loss: 0.31340619921684265, Accuracy: 0.88
Loss: 0.3585239350795746, Accuracy: 0.85
Loss: 0.3894445300102234, Accuracy: 0.84
Loss: 0.24811722338199615, Accuracy: 0.91
Loss: 0.31983861327171326, Accuracy: 0.88
Loss: 0.364098459482193, Accuracy: 0.85
Loss: 0.41002795100212097, Accuracy: 0.82
Loss: 0.47471657395362854, Accuracy: 0.81
Loss: 0.32614266872406006, Accuracy: 0.87
Loss: 0.24192893505096436, Accuracy: 0.95
Loss: 0.39169201254844666, Accuracy: 0.8
Loss: 0.4558236598968506, Accuracy: 0.82
Loss: 0.35117411613464355, Accuracy: 0.88
Loss: 0.3712027668952942, Accuracy: 0.86
Loss: 0.3139232397079468, Accuracy: 0.91
Loss: 0.4273320436477661, Accuracy: 0.81

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61528}
saved model
Epoch: 15


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.32089972496032715, Accuracy: 0.86
Loss: 0.3825347125530243, Accuracy: 0.85
Loss: 0.313533216714859, Accuracy: 0.86
Loss: 0.30315539240837097, Accuracy: 0.89
Loss: 0.28310996294021606, Accuracy: 0.91
Loss: 0.3511834740638733, Accuracy: 0.86
Loss: 0.29238125681877136, Accuracy: 0.92
Loss: 0.4176744222640991, Accuracy: 0.85
Loss: 0.30398979783058167, Accuracy: 0.89
Loss: 0.40232178568840027, Accuracy: 0.84
Loss: 0.3912501931190491, Accuracy: 0.86
Loss: 0.4029470682144165, Accuracy: 0.86
Loss: 0.413875937461853, Accuracy: 0.83
Loss: 0.4098045229911804, Accuracy: 0.83
Loss: 0.4749307334423065, Accuracy: 0.82
Loss: 0.3357962369918823, Accuracy: 0.89
Loss: 0.41791942715644836, Accuracy: 0.86
Loss: 0.4729686975479126, Accuracy: 0.8
Loss: 0.4251798391342163, Accuracy: 0.81
Loss: 0.31789177656173706, Accuracy: 0.87
Loss: 0.31912997364997864, Accuracy: 0.88
Loss: 0.3887966275215149, Accuracy: 0.84
Loss: 0.4341805577278137, Accuracy: 0.82
Loss: 0.4332205057144165, Accuracy: 0.83
Loss: 0.48

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61032}
saved model
Epoch: 16


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.3861946761608124, Accuracy: 0.85
Loss: 0.4407121241092682, Accuracy: 0.84
Loss: 0.35815393924713135, Accuracy: 0.85
Loss: 0.24794960021972656, Accuracy: 0.93
Loss: 0.3594679534435272, Accuracy: 0.88
Loss: 0.3937285244464874, Accuracy: 0.84
Loss: 0.45561474561691284, Accuracy: 0.83
Loss: 0.3614693582057953, Accuracy: 0.86
Loss: 0.3935316205024719, Accuracy: 0.84
Loss: 0.2526785433292389, Accuracy: 0.93
Loss: 0.3581256568431854, Accuracy: 0.86
Loss: 0.44851112365722656, Accuracy: 0.83
Loss: 0.3818344175815582, Accuracy: 0.87
Loss: 0.5022295713424683, Accuracy: 0.79
Loss: 0.4455125331878662, Accuracy: 0.8
Loss: 0.323324590921402, Accuracy: 0.89
Loss: 0.41079825162887573, Accuracy: 0.84
Loss: 0.3199506103992462, Accuracy: 0.91
Loss: 0.27432218194007874, Accuracy: 0.91
Loss: 0.3625207543373108, Accuracy: 0.89
Loss: 0.41715800762176514, Accuracy: 0.83
Loss: 0.3545086979866028, Accuracy: 0.87
Loss: 0.37958160042762756, Accuracy: 0.83
Loss: 0.40891021490097046, Accuracy: 0.83
Loss: 0.3

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60944}
saved model
Epoch: 17


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.3160068988800049, Accuracy: 0.9
Loss: 0.3234253227710724, Accuracy: 0.86
Loss: 0.30172088742256165, Accuracy: 0.9
Loss: 0.2668686807155609, Accuracy: 0.92
Loss: 0.26490941643714905, Accuracy: 0.93
Loss: 0.32110556960105896, Accuracy: 0.87
Loss: 0.43530845642089844, Accuracy: 0.84
Loss: 0.3169238567352295, Accuracy: 0.88
Loss: 0.268298476934433, Accuracy: 0.92
Loss: 0.3347815275192261, Accuracy: 0.87
Loss: 0.32154905796051025, Accuracy: 0.9
Loss: 0.24116292595863342, Accuracy: 0.92
Loss: 0.37052199244499207, Accuracy: 0.86
Loss: 0.4416542053222656, Accuracy: 0.83
Loss: 0.4156263470649719, Accuracy: 0.83
Loss: 0.28817522525787354, Accuracy: 0.93
Loss: 0.2872862219810486, Accuracy: 0.91
Loss: 0.3449198603630066, Accuracy: 0.87
Loss: 0.4372456669807434, Accuracy: 0.83
Loss: 0.34314167499542236, Accuracy: 0.89
Loss: 0.2599028944969177, Accuracy: 0.92
Loss: 0.2620565891265869, Accuracy: 0.92
Loss: 0.31609320640563965, Accuracy: 0.89
Loss: 0.27229785919189453, Accuracy: 0.9
Loss: 0.27

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61432}
saved model
Epoch: 18


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.500494658946991, Accuracy: 0.78
Loss: 0.3036966323852539, Accuracy: 0.89
Loss: 0.3058449625968933, Accuracy: 0.91
Loss: 0.27960672974586487, Accuracy: 0.91
Loss: 0.34571555256843567, Accuracy: 0.87
Loss: 0.3799661695957184, Accuracy: 0.83
Loss: 0.3822852671146393, Accuracy: 0.86
Loss: 0.28347325325012207, Accuracy: 0.92
Loss: 0.2880224585533142, Accuracy: 0.91
Loss: 0.23296837508678436, Accuracy: 0.91
Loss: 0.35311341285705566, Accuracy: 0.87
Loss: 0.3272928297519684, Accuracy: 0.87
Loss: 0.3579740524291992, Accuracy: 0.88
Loss: 0.2554500997066498, Accuracy: 0.91
Loss: 0.3307286202907562, Accuracy: 0.87
Loss: 0.33129313588142395, Accuracy: 0.87
Loss: 0.5115101337432861, Accuracy: 0.8
Loss: 0.3398759067058563, Accuracy: 0.87
Loss: 0.2808329164981842, Accuracy: 0.9
Loss: 0.41933178901672363, Accuracy: 0.85
Loss: 0.47842803597450256, Accuracy: 0.78
Loss: 0.2902930676937103, Accuracy: 0.89
Loss: 0.4134212136268616, Accuracy: 0.82
Loss: 0.3426622152328491, Accuracy: 0.86
Loss: 0.407

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61592}
saved model
Epoch: 19


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.32467329502105713, Accuracy: 0.9
Loss: 0.2884633541107178, Accuracy: 0.92
Loss: 0.4957159459590912, Accuracy: 0.81
Loss: 0.4468596577644348, Accuracy: 0.82
Loss: 0.4541866183280945, Accuracy: 0.82
Loss: 0.43605685234069824, Accuracy: 0.83
Loss: 0.27645769715309143, Accuracy: 0.92
Loss: 0.40973761677742004, Accuracy: 0.82
Loss: 0.34126347303390503, Accuracy: 0.87
Loss: 0.3766231834888458, Accuracy: 0.83
Loss: 0.35178229212760925, Accuracy: 0.85
Loss: 0.3149975538253784, Accuracy: 0.88
Loss: 0.3149186074733734, Accuracy: 0.88
Loss: 0.21653097867965698, Accuracy: 0.92
Loss: 0.30575546622276306, Accuracy: 0.88
Loss: 0.46521008014678955, Accuracy: 0.83
Loss: 0.3036750853061676, Accuracy: 0.9
Loss: 0.359806090593338, Accuracy: 0.88
Loss: 0.2457513064146042, Accuracy: 0.91
Loss: 0.26047423481941223, Accuracy: 0.91
Loss: 0.35599255561828613, Accuracy: 0.86
Loss: 0.29782116413116455, Accuracy: 0.92
Loss: 0.35477790236473083, Accuracy: 0.86
Loss: 0.3211032450199127, Accuracy: 0.89

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61008}
saved model
Epoch: 20


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.38313937187194824, Accuracy: 0.86
Loss: 0.29917386174201965, Accuracy: 0.88
Loss: 0.23447532951831818, Accuracy: 0.93
Loss: 0.2700178921222687, Accuracy: 0.92
Loss: 0.2010556161403656, Accuracy: 0.93
Loss: 0.2875867784023285, Accuracy: 0.88
Loss: 0.28024372458457947, Accuracy: 0.9
Loss: 0.28526389598846436, Accuracy: 0.9
Loss: 0.2767389416694641, Accuracy: 0.91
Loss: 0.36581695079803467, Accuracy: 0.86
Loss: 0.2474098950624466, Accuracy: 0.9
Loss: 0.2829655408859253, Accuracy: 0.9
Loss: 0.38787373900413513, Accuracy: 0.88
Loss: 0.23598001897335052, Accuracy: 0.92
Loss: 0.3474273383617401, Accuracy: 0.84
Loss: 0.3551529049873352, Accuracy: 0.87
Loss: 0.3589775562286377, Accuracy: 0.88
Loss: 0.3237679600715637, Accuracy: 0.89
Loss: 0.25378167629241943, Accuracy: 0.9
Loss: 0.30825385451316833, Accuracy: 0.9
Loss: 0.29707905650138855, Accuracy: 0.89
Loss: 0.2701086699962616, Accuracy: 0.89
Loss: 0.3739948570728302, Accuracy: 0.84
Loss: 0.32722169160842896, Accuracy: 0.88
Loss: 0.32

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61}
saved model
Epoch: 21


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2440122812986374, Accuracy: 0.93
Loss: 0.2630845308303833, Accuracy: 0.91
Loss: 0.2661028802394867, Accuracy: 0.9
Loss: 0.2561592757701874, Accuracy: 0.91
Loss: 0.21001169085502625, Accuracy: 0.92
Loss: 0.27046576142311096, Accuracy: 0.91
Loss: 0.10041265189647675, Accuracy: 0.97
Loss: 0.40310508012771606, Accuracy: 0.86
Loss: 0.23938526213169098, Accuracy: 0.93
Loss: 0.3178989589214325, Accuracy: 0.87
Loss: 0.3619075417518616, Accuracy: 0.88
Loss: 0.21298781037330627, Accuracy: 0.93
Loss: 0.32421597838401794, Accuracy: 0.88
Loss: 0.3262437582015991, Accuracy: 0.85
Loss: 0.2827666997909546, Accuracy: 0.9
Loss: 0.27190256118774414, Accuracy: 0.92
Loss: 0.411784827709198, Accuracy: 0.85
Loss: 0.25773459672927856, Accuracy: 0.9
Loss: 0.3248022794723511, Accuracy: 0.88
Loss: 0.324944406747818, Accuracy: 0.87
Loss: 0.246078222990036, Accuracy: 0.93
Loss: 0.22179840505123138, Accuracy: 0.94
Loss: 0.3039829730987549, Accuracy: 0.9
Loss: 0.2702566087245941, Accuracy: 0.91
Loss: 0.26875

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61448}
saved model
Epoch: 22


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2919447124004364, Accuracy: 0.87
Loss: 0.2779441177845001, Accuracy: 0.92
Loss: 0.25245630741119385, Accuracy: 0.91
Loss: 0.29211530089378357, Accuracy: 0.87
Loss: 0.3953106701374054, Accuracy: 0.82
Loss: 0.35133227705955505, Accuracy: 0.89
Loss: 0.19714663922786713, Accuracy: 0.92
Loss: 0.2772025465965271, Accuracy: 0.9
Loss: 0.2447173148393631, Accuracy: 0.91
Loss: 0.3385223150253296, Accuracy: 0.87
Loss: 0.26374903321266174, Accuracy: 0.89
Loss: 0.23059462010860443, Accuracy: 0.91
Loss: 0.26172465085983276, Accuracy: 0.91
Loss: 0.13969232141971588, Accuracy: 0.95
Loss: 0.25299403071403503, Accuracy: 0.9
Loss: 0.3331693708896637, Accuracy: 0.85
Loss: 0.20471785962581635, Accuracy: 0.92
Loss: 0.3119121789932251, Accuracy: 0.89
Loss: 0.2480340152978897, Accuracy: 0.92
Loss: 0.229888916015625, Accuracy: 0.91
Loss: 0.321424663066864, Accuracy: 0.87
Loss: 0.21655189990997314, Accuracy: 0.93
Loss: 0.23278258740901947, Accuracy: 0.92
Loss: 0.2710375487804413, Accuracy: 0.89

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6156}
saved model
Epoch: 23


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.25566309690475464, Accuracy: 0.91
Loss: 0.26128485798835754, Accuracy: 0.89
Loss: 0.24430827796459198, Accuracy: 0.94
Loss: 0.31344273686408997, Accuracy: 0.91
Loss: 0.2293236404657364, Accuracy: 0.93
Loss: 0.3092680275440216, Accuracy: 0.9
Loss: 0.35456782579421997, Accuracy: 0.86
Loss: 0.24841804802417755, Accuracy: 0.9
Loss: 0.22594013810157776, Accuracy: 0.93
Loss: 0.20914095640182495, Accuracy: 0.92
Loss: 0.24491795897483826, Accuracy: 0.91
Loss: 0.3474430739879608, Accuracy: 0.86
Loss: 0.2068951576948166, Accuracy: 0.93
Loss: 0.1547219157218933, Accuracy: 0.98
Loss: 0.16789858043193817, Accuracy: 0.94
Loss: 0.24249334633350372, Accuracy: 0.91
Loss: 0.2644539475440979, Accuracy: 0.9
Loss: 0.27936816215515137, Accuracy: 0.88
Loss: 0.16447965800762177, Accuracy: 0.95
Loss: 0.19244574010372162, Accuracy: 0.92
Loss: 0.1714622527360916, Accuracy: 0.94
Loss: 0.2836500406265259, Accuracy: 0.89
Loss: 0.20176593959331512, Accuracy: 0.94
Loss: 0.21610476076602936, Accuracy: 0.92

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.614}
saved model
Epoch: 24


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2831202745437622, Accuracy: 0.91
Loss: 0.27817341685295105, Accuracy: 0.91
Loss: 0.2776177227497101, Accuracy: 0.92
Loss: 0.2637168765068054, Accuracy: 0.91
Loss: 0.27276477217674255, Accuracy: 0.91
Loss: 0.39267489314079285, Accuracy: 0.86
Loss: 0.31143343448638916, Accuracy: 0.9
Loss: 0.27853602170944214, Accuracy: 0.91
Loss: 0.17671246826648712, Accuracy: 0.96
Loss: 0.1912926286458969, Accuracy: 0.93
Loss: 0.3097798526287079, Accuracy: 0.88
Loss: 0.3565799295902252, Accuracy: 0.86
Loss: 0.4068276286125183, Accuracy: 0.86
Loss: 0.24893680214881897, Accuracy: 0.92
Loss: 0.35372069478034973, Accuracy: 0.88
Loss: 0.3598334789276123, Accuracy: 0.86
Loss: 0.2707442343235016, Accuracy: 0.91
Loss: 0.32077276706695557, Accuracy: 0.88
Loss: 0.31230229139328003, Accuracy: 0.9
Loss: 0.2793136239051819, Accuracy: 0.87
Loss: 0.29717570543289185, Accuracy: 0.88
Loss: 0.3461020290851593, Accuracy: 0.89
Loss: 0.310286283493042, Accuracy: 0.89
Loss: 0.4274953603744507, Accuracy: 0.83

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61304}
saved model
Epoch: 25


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.19352735579013824, Accuracy: 0.94
Loss: 0.24874883890151978, Accuracy: 0.92
Loss: 0.27444806694984436, Accuracy: 0.92
Loss: 0.37225738167762756, Accuracy: 0.87
Loss: 0.22725015878677368, Accuracy: 0.91
Loss: 0.27259889245033264, Accuracy: 0.89
Loss: 0.2638755142688751, Accuracy: 0.92
Loss: 0.22264832258224487, Accuracy: 0.94
Loss: 0.17451819777488708, Accuracy: 0.97
Loss: 0.2746339738368988, Accuracy: 0.89
Loss: 0.20656239986419678, Accuracy: 0.94
Loss: 0.3150765299797058, Accuracy: 0.88
Loss: 0.28004613518714905, Accuracy: 0.89
Loss: 0.2884618639945984, Accuracy: 0.89
Loss: 0.21744216978549957, Accuracy: 0.93
Loss: 0.19825991988182068, Accuracy: 0.93
Loss: 0.26885777711868286, Accuracy: 0.89
Loss: 0.2603648900985718, Accuracy: 0.93
Loss: 0.24629254639148712, Accuracy: 0.91
Loss: 0.2093554437160492, Accuracy: 0.93
Loss: 0.4677087664604187, Accuracy: 0.79
Loss: 0.20595712959766388, Accuracy: 0.92
Loss: 0.3727884292602539, Accuracy: 0.86
Loss: 0.3487829267978668, Accuracy: 0.88

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61008}
saved model
Epoch: 26


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.3198738396167755, Accuracy: 0.89
Loss: 0.2974900007247925, Accuracy: 0.9
Loss: 0.29627788066864014, Accuracy: 0.87
Loss: 0.2681240439414978, Accuracy: 0.92
Loss: 0.2850070297718048, Accuracy: 0.9
Loss: 0.3210826516151428, Accuracy: 0.87
Loss: 0.2567519247531891, Accuracy: 0.91
Loss: 0.3534400165081024, Accuracy: 0.9
Loss: 0.3048822283744812, Accuracy: 0.89
Loss: 0.24715708196163177, Accuracy: 0.92
Loss: 0.2100289762020111, Accuracy: 0.94
Loss: 0.2191886454820633, Accuracy: 0.92
Loss: 0.2205522209405899, Accuracy: 0.94
Loss: 0.2729313373565674, Accuracy: 0.92
Loss: 0.299363374710083, Accuracy: 0.9
Loss: 0.29640281200408936, Accuracy: 0.9
Loss: 0.23258602619171143, Accuracy: 0.91
Loss: 0.20104852318763733, Accuracy: 0.94
Loss: 0.2211582213640213, Accuracy: 0.92
Loss: 0.2545337378978729, Accuracy: 0.91
Loss: 0.24595248699188232, Accuracy: 0.9
Loss: 0.23236769437789917, Accuracy: 0.93
Loss: 0.20205055177211761, Accuracy: 0.94
Loss: 0.1799129843711853, Accuracy: 0.93
Loss: 0.2227015

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60928}
saved model
Epoch: 27


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.27442002296447754, Accuracy: 0.87
Loss: 0.19499288499355316, Accuracy: 0.93
Loss: 0.2629956305027008, Accuracy: 0.91
Loss: 0.2539111375808716, Accuracy: 0.9
Loss: 0.2716565728187561, Accuracy: 0.91
Loss: 0.2216530591249466, Accuracy: 0.91
Loss: 0.20816460251808167, Accuracy: 0.92
Loss: 0.13343371450901031, Accuracy: 0.96
Loss: 0.1787244826555252, Accuracy: 0.94
Loss: 0.14063194394111633, Accuracy: 0.95
Loss: 0.18198566138744354, Accuracy: 0.95
Loss: 0.30764591693878174, Accuracy: 0.89
Loss: 0.20473065972328186, Accuracy: 0.93
Loss: 0.36779865622520447, Accuracy: 0.88
Loss: 0.25906074047088623, Accuracy: 0.9
Loss: 0.26316946744918823, Accuracy: 0.9
Loss: 0.3392848074436188, Accuracy: 0.89
Loss: 0.31404662132263184, Accuracy: 0.88
Loss: 0.23420925438404083, Accuracy: 0.92
Loss: 0.24805475771427155, Accuracy: 0.92
Loss: 0.16444043815135956, Accuracy: 0.96
Loss: 0.20751051604747772, Accuracy: 0.93
Loss: 0.23189374804496765, Accuracy: 0.91
Loss: 0.20226608216762543, Accuracy: 0.93

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61096}
saved model
Epoch: 28


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.1957368403673172, Accuracy: 0.92
Loss: 0.24015991389751434, Accuracy: 0.9
Loss: 0.2129828780889511, Accuracy: 0.91
Loss: 0.16516365110874176, Accuracy: 0.95
Loss: 0.24739812314510345, Accuracy: 0.91
Loss: 0.24983114004135132, Accuracy: 0.92
Loss: 0.2980630397796631, Accuracy: 0.9
Loss: 0.18938377499580383, Accuracy: 0.95
Loss: 0.2389567494392395, Accuracy: 0.92
Loss: 0.13270379602909088, Accuracy: 0.96
Loss: 0.19831451773643494, Accuracy: 0.93
Loss: 0.24926616251468658, Accuracy: 0.91
Loss: 0.2613704204559326, Accuracy: 0.89
Loss: 0.2922513782978058, Accuracy: 0.88
Loss: 0.27622783184051514, Accuracy: 0.92
Loss: 0.24011048674583435, Accuracy: 0.93
Loss: 0.258596807718277, Accuracy: 0.9
Loss: 0.20002658665180206, Accuracy: 0.92
Loss: 0.2870984971523285, Accuracy: 0.87
Loss: 0.24362319707870483, Accuracy: 0.92
Loss: 0.23354700207710266, Accuracy: 0.89
Loss: 0.18559134006500244, Accuracy: 0.95
Loss: 0.22177784144878387, Accuracy: 0.9
Loss: 0.19453641772270203, Accuracy: 0.94

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61056}
saved model
Epoch: 29


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.18358416855335236, Accuracy: 0.92
Loss: 0.15998202562332153, Accuracy: 0.95
Loss: 0.22641414403915405, Accuracy: 0.9
Loss: 0.20247051119804382, Accuracy: 0.92
Loss: 0.19114351272583008, Accuracy: 0.94
Loss: 0.19565893709659576, Accuracy: 0.92
Loss: 0.17109794914722443, Accuracy: 0.94
Loss: 0.19495542347431183, Accuracy: 0.94
Loss: 0.14843051135540009, Accuracy: 0.95
Loss: 0.17290934920310974, Accuracy: 0.94
Loss: 0.20343074202537537, Accuracy: 0.91
Loss: 0.1806451380252838, Accuracy: 0.91
Loss: 0.21580582857131958, Accuracy: 0.91
Loss: 0.22299884259700775, Accuracy: 0.9
Loss: 0.19457188248634338, Accuracy: 0.92
Loss: 0.2995147407054901, Accuracy: 0.88
Loss: 0.2425239384174347, Accuracy: 0.91
Loss: 0.13607251644134521, Accuracy: 0.96
Loss: 0.21730577945709229, Accuracy: 0.94
Loss: 0.2222960889339447, Accuracy: 0.92
Loss: 0.1951252818107605, Accuracy: 0.93
Loss: 0.16271209716796875, Accuracy: 0.93
Loss: 0.15115948021411896, Accuracy: 0.94
Loss: 0.1218440905213356, Accuracy: 0.96


  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60896}
saved model
Epoch: 30


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.23657268285751343, Accuracy: 0.92
Loss: 0.20722472667694092, Accuracy: 0.94
Loss: 0.21157488226890564, Accuracy: 0.92
Loss: 0.25575631856918335, Accuracy: 0.89
Loss: 0.2068854123353958, Accuracy: 0.92
Loss: 0.27209559082984924, Accuracy: 0.89
Loss: 0.16650086641311646, Accuracy: 0.94
Loss: 0.08393159508705139, Accuracy: 0.98
Loss: 0.30002084374427795, Accuracy: 0.89
Loss: 0.26183509826660156, Accuracy: 0.91
Loss: 0.3232405483722687, Accuracy: 0.87
Loss: 0.17981523275375366, Accuracy: 0.92
Loss: 0.2234303504228592, Accuracy: 0.93
Loss: 0.23171228170394897, Accuracy: 0.91
Loss: 0.15593241155147552, Accuracy: 0.94
Loss: 0.16783688962459564, Accuracy: 0.93
Loss: 0.24017837643623352, Accuracy: 0.9
Loss: 0.208795428276062, Accuracy: 0.93
Loss: 0.19959445297718048, Accuracy: 0.95
Loss: 0.1826590597629547, Accuracy: 0.94
Loss: 0.28842535614967346, Accuracy: 0.87
Loss: 0.2937922179698944, Accuracy: 0.88
Loss: 0.22754091024398804, Accuracy: 0.9
Loss: 0.2733290493488312, Accuracy: 0.88

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61232}
saved model
Epoch: 31


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.21321380138397217, Accuracy: 0.93
Loss: 0.15928250551223755, Accuracy: 0.93
Loss: 0.16443219780921936, Accuracy: 0.92
Loss: 0.19484364986419678, Accuracy: 0.91
Loss: 0.1919546276330948, Accuracy: 0.93
Loss: 0.19110947847366333, Accuracy: 0.91
Loss: 0.20222410559654236, Accuracy: 0.92
Loss: 0.25867635011672974, Accuracy: 0.88
Loss: 0.2740214765071869, Accuracy: 0.88
Loss: 0.16224630177021027, Accuracy: 0.94
Loss: 0.21197053790092468, Accuracy: 0.92
Loss: 0.1296393722295761, Accuracy: 0.96
Loss: 0.12926575541496277, Accuracy: 0.94
Loss: 0.18566040694713593, Accuracy: 0.93
Loss: 0.2416597604751587, Accuracy: 0.91
Loss: 0.2215699553489685, Accuracy: 0.91
Loss: 0.16942653059959412, Accuracy: 0.93
Loss: 0.19549018144607544, Accuracy: 0.91
Loss: 0.23054108023643494, Accuracy: 0.91
Loss: 0.1940590888261795, Accuracy: 0.93
Loss: 0.29379767179489136, Accuracy: 0.87
Loss: 0.1346503347158432, Accuracy: 0.95
Loss: 0.17418445646762848, Accuracy: 0.95
Loss: 0.156544491648674, Accuracy: 0.96

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61072}
saved model
Epoch: 32


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.25430119037628174, Accuracy: 0.87
Loss: 0.19953271746635437, Accuracy: 0.9
Loss: 0.29584887623786926, Accuracy: 0.9
Loss: 0.23017467558383942, Accuracy: 0.89
Loss: 0.22844427824020386, Accuracy: 0.89
Loss: 0.0850871279835701, Accuracy: 0.98
Loss: 0.2359427809715271, Accuracy: 0.89
Loss: 0.2379143387079239, Accuracy: 0.9
Loss: 0.20218469202518463, Accuracy: 0.9
Loss: 0.17941415309906006, Accuracy: 0.92
Loss: 0.16634295880794525, Accuracy: 0.95
Loss: 0.1622912585735321, Accuracy: 0.92
Loss: 0.17352771759033203, Accuracy: 0.94
Loss: 0.1802922487258911, Accuracy: 0.91
Loss: 0.12431557476520538, Accuracy: 0.95
Loss: 0.18908868730068207, Accuracy: 0.89
Loss: 0.12367267906665802, Accuracy: 0.95
Loss: 0.19561509788036346, Accuracy: 0.91
Loss: 0.2821294069290161, Accuracy: 0.89
Loss: 0.239565908908844, Accuracy: 0.92
Loss: 0.16427698731422424, Accuracy: 0.9
Loss: 0.24988217651844025, Accuracy: 0.88
Loss: 0.1987106204032898, Accuracy: 0.95
Loss: 0.17030885815620422, Accuracy: 0.94

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60712}
saved model
Epoch: 33


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.24025748670101166, Accuracy: 0.91
Loss: 0.1989925354719162, Accuracy: 0.94
Loss: 0.2208327054977417, Accuracy: 0.91
Loss: 0.2868322432041168, Accuracy: 0.87
Loss: 0.14393524825572968, Accuracy: 0.95
Loss: 0.14365212619304657, Accuracy: 0.95
Loss: 0.1654650866985321, Accuracy: 0.92
Loss: 0.19739732146263123, Accuracy: 0.92
Loss: 0.14261050522327423, Accuracy: 0.96
Loss: 0.28424352407455444, Accuracy: 0.88
Loss: 0.2077140212059021, Accuracy: 0.94
Loss: 0.31134098768234253, Accuracy: 0.91
Loss: 0.2646043300628662, Accuracy: 0.91
Loss: 0.1987074315547943, Accuracy: 0.91
Loss: 0.24068278074264526, Accuracy: 0.91
Loss: 0.20207025110721588, Accuracy: 0.92
Loss: 0.14887601137161255, Accuracy: 0.93
Loss: 0.19217859208583832, Accuracy: 0.93
Loss: 0.23109538853168488, Accuracy: 0.89
Loss: 0.2614268958568573, Accuracy: 0.88
Loss: 0.2998264729976654, Accuracy: 0.87
Loss: 0.18861238658428192, Accuracy: 0.93
Loss: 0.2675873041152954, Accuracy: 0.91
Loss: 0.1938839852809906, Accuracy: 0.94

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60888}
saved model
Epoch: 34


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.21869085729122162, Accuracy: 0.92
Loss: 0.12249128520488739, Accuracy: 0.96
Loss: 0.1364467293024063, Accuracy: 0.95
Loss: 0.16246174275875092, Accuracy: 0.93
Loss: 0.26941609382629395, Accuracy: 0.86
Loss: 0.1583898812532425, Accuracy: 0.94
Loss: 0.16382525861263275, Accuracy: 0.92
Loss: 0.186823308467865, Accuracy: 0.92
Loss: 0.1646043211221695, Accuracy: 0.92
Loss: 0.15210658311843872, Accuracy: 0.92
Loss: 0.1252046823501587, Accuracy: 0.95
Loss: 0.1857956349849701, Accuracy: 0.91
Loss: 0.13034048676490784, Accuracy: 0.96
Loss: 0.15740767121315002, Accuracy: 0.94
Loss: 0.17803360521793365, Accuracy: 0.94
Loss: 0.14483891427516937, Accuracy: 0.96
Loss: 0.2080276608467102, Accuracy: 0.92
Loss: 0.16953939199447632, Accuracy: 0.93
Loss: 0.29040420055389404, Accuracy: 0.88
Loss: 0.14238138496875763, Accuracy: 0.93
Loss: 0.11485113948583603, Accuracy: 0.96
Loss: 0.15199607610702515, Accuracy: 0.92
Loss: 0.2167545109987259, Accuracy: 0.9
Loss: 0.16469888389110565, Accuracy: 0.95

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61104}
saved model
Epoch: 35


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2298627495765686, Accuracy: 0.9
Loss: 0.16711212694644928, Accuracy: 0.92
Loss: 0.3105304539203644, Accuracy: 0.85
Loss: 0.24660354852676392, Accuracy: 0.88
Loss: 0.15881377458572388, Accuracy: 0.95
Loss: 0.1544579565525055, Accuracy: 0.94
Loss: 0.16152969002723694, Accuracy: 0.93
Loss: 0.15228527784347534, Accuracy: 0.95
Loss: 0.14610977470874786, Accuracy: 0.96
Loss: 0.10418267548084259, Accuracy: 0.98
Loss: 0.1111309602856636, Accuracy: 0.97
Loss: 0.26095420122146606, Accuracy: 0.88
Loss: 0.1987108290195465, Accuracy: 0.93
Loss: 0.21744345128536224, Accuracy: 0.93
Loss: 0.15445172786712646, Accuracy: 0.93
Loss: 0.14930546283721924, Accuracy: 0.94
Loss: 0.15598049759864807, Accuracy: 0.93
Loss: 0.18760830163955688, Accuracy: 0.95
Loss: 0.18379490077495575, Accuracy: 0.95
Loss: 0.2255590558052063, Accuracy: 0.91
Loss: 0.1311049908399582, Accuracy: 0.95
Loss: 0.2714959383010864, Accuracy: 0.88
Loss: 0.20674623548984528, Accuracy: 0.91
Loss: 0.2030375599861145, Accuracy: 0.89

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60936}
saved model
Epoch: 36


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.27465176582336426, Accuracy: 0.84
Loss: 0.1732875406742096, Accuracy: 0.96
Loss: 0.19774101674556732, Accuracy: 0.93
Loss: 0.15115486085414886, Accuracy: 0.94
Loss: 0.12090345472097397, Accuracy: 0.96
Loss: 0.1068200170993805, Accuracy: 0.98
Loss: 0.1653348207473755, Accuracy: 0.94
Loss: 0.15662118792533875, Accuracy: 0.91
Loss: 0.14896024763584137, Accuracy: 0.93
Loss: 0.15442676842212677, Accuracy: 0.93
Loss: 0.10113608092069626, Accuracy: 0.96
Loss: 0.1562957912683487, Accuracy: 0.91
Loss: 0.13037018477916718, Accuracy: 0.94
Loss: 0.24967096745967865, Accuracy: 0.9
Loss: 0.24135053157806396, Accuracy: 0.89
Loss: 0.1615566611289978, Accuracy: 0.94
Loss: 0.10322659462690353, Accuracy: 0.94
Loss: 0.23789344727993011, Accuracy: 0.9
Loss: 0.18628019094467163, Accuracy: 0.93
Loss: 0.20118053257465363, Accuracy: 0.9
Loss: 0.20616558194160461, Accuracy: 0.91
Loss: 0.2222118228673935, Accuracy: 0.92
Loss: 0.19185027480125427, Accuracy: 0.9
Loss: 0.20803199708461761, Accuracy: 0.91

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60392}
saved model
Epoch: 37


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.11416777968406677, Accuracy: 0.95
Loss: 0.27632561326026917, Accuracy: 0.89
Loss: 0.17093881964683533, Accuracy: 0.91
Loss: 0.10255111753940582, Accuracy: 0.95
Loss: 0.2936195135116577, Accuracy: 0.86
Loss: 0.34073829650878906, Accuracy: 0.84
Loss: 0.18628084659576416, Accuracy: 0.94
Loss: 0.11115196347236633, Accuracy: 0.95
Loss: 0.19637127220630646, Accuracy: 0.91
Loss: 0.2613203525543213, Accuracy: 0.89
Loss: 0.28458964824676514, Accuracy: 0.9
Loss: 0.13829036056995392, Accuracy: 0.97
Loss: 0.19272497296333313, Accuracy: 0.95
Loss: 0.161763533949852, Accuracy: 0.95
Loss: 0.3098893165588379, Accuracy: 0.88
Loss: 0.15200816094875336, Accuracy: 0.94
Loss: 0.1813039630651474, Accuracy: 0.91
Loss: 0.14540445804595947, Accuracy: 0.95
Loss: 0.24358130991458893, Accuracy: 0.88
Loss: 0.22206783294677734, Accuracy: 0.9
Loss: 0.2032308429479599, Accuracy: 0.91
Loss: 0.20901890099048615, Accuracy: 0.9
Loss: 0.13848070800304413, Accuracy: 0.95
Loss: 0.22789858281612396, Accuracy: 0.91

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60904}
saved model
Epoch: 38


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.11792024970054626, Accuracy: 0.96
Loss: 0.16231024265289307, Accuracy: 0.94
Loss: 0.15438319742679596, Accuracy: 0.93
Loss: 0.18248631060123444, Accuracy: 0.92
Loss: 0.132557675242424, Accuracy: 0.94
Loss: 0.15329094231128693, Accuracy: 0.94
Loss: 0.1285632997751236, Accuracy: 0.95
Loss: 0.10299177467823029, Accuracy: 0.95
Loss: 0.1009957492351532, Accuracy: 0.98
Loss: 0.1746903657913208, Accuracy: 0.92
Loss: 0.10624660551548004, Accuracy: 0.95
Loss: 0.17023494839668274, Accuracy: 0.93
Loss: 0.1779995709657669, Accuracy: 0.91
Loss: 0.18405082821846008, Accuracy: 0.91
Loss: 0.15523618459701538, Accuracy: 0.96
Loss: 0.1116601973772049, Accuracy: 0.94
Loss: 0.11844068765640259, Accuracy: 0.94
Loss: 0.1528584510087967, Accuracy: 0.93
Loss: 0.15748785436153412, Accuracy: 0.94
Loss: 0.2135777324438095, Accuracy: 0.9
Loss: 0.13970132172107697, Accuracy: 0.93
Loss: 0.24855123460292816, Accuracy: 0.9
Loss: 0.19695021212100983, Accuracy: 0.94
Loss: 0.14652197062969208, Accuracy: 0.96

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60168}
saved model
Epoch: 39


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.1518017202615738, Accuracy: 0.94
Loss: 0.18745379149913788, Accuracy: 0.93
Loss: 0.12530732154846191, Accuracy: 0.96
Loss: 0.1579606682062149, Accuracy: 0.96
Loss: 0.12283818423748016, Accuracy: 0.95
Loss: 0.13049063086509705, Accuracy: 0.92
Loss: 0.13528212904930115, Accuracy: 0.94
Loss: 0.1588747501373291, Accuracy: 0.93
Loss: 0.17498290538787842, Accuracy: 0.92
Loss: 0.15212875604629517, Accuracy: 0.93
Loss: 0.1441577970981598, Accuracy: 0.94
Loss: 0.1930728405714035, Accuracy: 0.91
Loss: 0.06388261169195175, Accuracy: 0.98
Loss: 0.16495868563652039, Accuracy: 0.94
Loss: 0.09020349383354187, Accuracy: 0.97
Loss: 0.13260728120803833, Accuracy: 0.93
Loss: 0.1297728419303894, Accuracy: 0.95
Loss: 0.18295206129550934, Accuracy: 0.92
Loss: 0.13479001820087433, Accuracy: 0.94
Loss: 0.16931603848934174, Accuracy: 0.93
Loss: 0.13428065180778503, Accuracy: 0.95
Loss: 0.15958555042743683, Accuracy: 0.89
Loss: 0.14068718254566193, Accuracy: 0.92
Loss: 0.13214929401874542, Accuracy: 0.9

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.59856}
saved model
Epoch: 40


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.13126124441623688, Accuracy: 0.93
Loss: 0.09955421090126038, Accuracy: 0.94
Loss: 0.10275115072727203, Accuracy: 0.96
Loss: 0.1435774862766266, Accuracy: 0.93
Loss: 0.16190224885940552, Accuracy: 0.93
Loss: 0.07985430955886841, Accuracy: 0.98
Loss: 0.11578050255775452, Accuracy: 0.96
Loss: 0.040943779051303864, Accuracy: 0.99
Loss: 0.11632095277309418, Accuracy: 0.94
Loss: 0.14691044390201569, Accuracy: 0.91
Loss: 0.10049227625131607, Accuracy: 0.97
Loss: 0.16294124722480774, Accuracy: 0.94
Loss: 0.33105218410491943, Accuracy: 0.88
Loss: 0.11348812282085419, Accuracy: 0.96
Loss: 0.14382809400558472, Accuracy: 0.94
Loss: 0.19911180436611176, Accuracy: 0.9
Loss: 0.39033129811286926, Accuracy: 0.86
Loss: 0.08305501937866211, Accuracy: 0.96
Loss: 0.1215275451540947, Accuracy: 0.95
Loss: 0.22277361154556274, Accuracy: 0.91
Loss: 0.2353353351354599, Accuracy: 0.87
Loss: 0.2354823648929596, Accuracy: 0.91
Loss: 0.16614669561386108, Accuracy: 0.95
Loss: 0.16792207956314087, Accuracy: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6}
saved model
Epoch: 41


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.13321788609027863, Accuracy: 0.94
Loss: 0.16838988661766052, Accuracy: 0.92
Loss: 0.1333666443824768, Accuracy: 0.93
Loss: 0.08484615385532379, Accuracy: 0.95
Loss: 0.0836251899600029, Accuracy: 0.95
Loss: 0.17496177554130554, Accuracy: 0.94
Loss: 0.09316890686750412, Accuracy: 0.97
Loss: 0.10049615800380707, Accuracy: 0.96
Loss: 0.08119388669729233, Accuracy: 0.98
Loss: 0.15891686081886292, Accuracy: 0.93
Loss: 0.13137347996234894, Accuracy: 0.96
Loss: 0.1481444090604782, Accuracy: 0.94
Loss: 0.11913524568080902, Accuracy: 0.97
Loss: 0.09141165018081665, Accuracy: 0.97
Loss: 0.18133559823036194, Accuracy: 0.88
Loss: 0.15778952836990356, Accuracy: 0.95
Loss: 0.12520477175712585, Accuracy: 0.94
Loss: 0.15679438412189484, Accuracy: 0.91
Loss: 0.11478690057992935, Accuracy: 0.97
Loss: 0.15745985507965088, Accuracy: 0.93
Loss: 0.25009670853614807, Accuracy: 0.9
Loss: 0.08172839134931564, Accuracy: 0.98
Loss: 0.12341798096895218, Accuracy: 0.95
Loss: 0.10880488157272339, Accuracy: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6052}
saved model
Epoch: 42


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.07769729197025299, Accuracy: 0.97
Loss: 0.13738054037094116, Accuracy: 0.96
Loss: 0.09343542158603668, Accuracy: 0.95
Loss: 0.1831527203321457, Accuracy: 0.91
Loss: 0.10439004749059677, Accuracy: 0.95
Loss: 0.20419609546661377, Accuracy: 0.92
Loss: 0.14143690466880798, Accuracy: 0.94
Loss: 0.12898315489292145, Accuracy: 0.94
Loss: 0.09775853902101517, Accuracy: 0.96
Loss: 0.1976974606513977, Accuracy: 0.9
Loss: 0.1333020180463791, Accuracy: 0.94
Loss: 0.12201005965471268, Accuracy: 0.94
Loss: 0.08196661621332169, Accuracy: 0.99
Loss: 0.13384848833084106, Accuracy: 0.94
Loss: 0.16305692493915558, Accuracy: 0.93
Loss: 0.16834206879138947, Accuracy: 0.92
Loss: 0.08992251753807068, Accuracy: 0.96
Loss: 0.12546281516551971, Accuracy: 0.93
Loss: 0.14280757308006287, Accuracy: 0.95
Loss: 0.224192276597023, Accuracy: 0.88
Loss: 0.13048960268497467, Accuracy: 0.96
Loss: 0.18095825612545013, Accuracy: 0.91
Loss: 0.1302037388086319, Accuracy: 0.93
Loss: 0.12410309910774231, Accuracy: 0.93

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60592}
saved model
Epoch: 43


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.10394768416881561, Accuracy: 0.96
Loss: 0.045483626425266266, Accuracy: 0.99
Loss: 0.08806139975786209, Accuracy: 0.98
Loss: 0.10022979974746704, Accuracy: 0.96
Loss: 0.09261549264192581, Accuracy: 0.96
Loss: 0.12226663529872894, Accuracy: 0.93
Loss: 0.1074814423918724, Accuracy: 0.94
Loss: 0.1472824215888977, Accuracy: 0.94
Loss: 0.2322303205728531, Accuracy: 0.92
Loss: 0.05126675218343735, Accuracy: 0.98
Loss: 0.12364155799150467, Accuracy: 0.94
Loss: 0.12229987233877182, Accuracy: 0.96
Loss: 0.16776792705059052, Accuracy: 0.94
Loss: 0.1346559375524521, Accuracy: 0.94
Loss: 0.19058330357074738, Accuracy: 0.92
Loss: 0.17129825055599213, Accuracy: 0.91
Loss: 0.11386986821889877, Accuracy: 0.93
Loss: 0.14246419072151184, Accuracy: 0.95
Loss: 0.11449071019887924, Accuracy: 0.97
Loss: 0.1569310873746872, Accuracy: 0.95
Loss: 0.13214462995529175, Accuracy: 0.92
Loss: 0.11926670372486115, Accuracy: 0.93
Loss: 0.10367489606142044, Accuracy: 0.95
Loss: 0.07952144742012024, Accuracy: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6088}
saved model
Epoch: 44


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.06932856887578964, Accuracy: 0.96
Loss: 0.11815192550420761, Accuracy: 0.96
Loss: 0.2694569528102875, Accuracy: 0.89
Loss: 0.09051179140806198, Accuracy: 0.96
Loss: 0.10305342823266983, Accuracy: 0.95
Loss: 0.05429180711507797, Accuracy: 0.99
Loss: 0.10959421098232269, Accuracy: 0.98
Loss: 0.08927323669195175, Accuracy: 0.95
Loss: 0.12316036969423294, Accuracy: 0.93
Loss: 0.06908321380615234, Accuracy: 0.98
Loss: 0.10238433629274368, Accuracy: 0.95
Loss: 0.09229260683059692, Accuracy: 0.96
Loss: 0.1311962753534317, Accuracy: 0.92
Loss: 0.18985439836978912, Accuracy: 0.9
Loss: 0.13394418358802795, Accuracy: 0.95
Loss: 0.09168203175067902, Accuracy: 0.98
Loss: 0.09692781418561935, Accuracy: 0.95
Loss: 0.15614861249923706, Accuracy: 0.94
Loss: 0.20128479599952698, Accuracy: 0.94
Loss: 0.1387788951396942, Accuracy: 0.96
Loss: 0.17514610290527344, Accuracy: 0.92
Loss: 0.22109751403331757, Accuracy: 0.89
Loss: 0.15794628858566284, Accuracy: 0.93
Loss: 0.0987027958035469, Accuracy: 0.

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60528}
saved model
Epoch: 45


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.10354232043027878, Accuracy: 0.93
Loss: 0.06996703892946243, Accuracy: 0.99
Loss: 0.18750552833080292, Accuracy: 0.93
Loss: 0.13128861784934998, Accuracy: 0.93
Loss: 0.10623747110366821, Accuracy: 0.96
Loss: 0.10390264540910721, Accuracy: 0.94
Loss: 0.11811498552560806, Accuracy: 0.92
Loss: 0.13429255783557892, Accuracy: 0.95
Loss: 0.1422652006149292, Accuracy: 0.95
Loss: 0.09181922674179077, Accuracy: 0.95
Loss: 0.1122845932841301, Accuracy: 0.95
Loss: 0.10139130800962448, Accuracy: 0.98
Loss: 0.11039144545793533, Accuracy: 0.95
Loss: 0.11481699347496033, Accuracy: 0.96
Loss: 0.148000106215477, Accuracy: 0.94
Loss: 0.11990435421466827, Accuracy: 0.95
Loss: 0.14597216248512268, Accuracy: 0.96
Loss: 0.12123185396194458, Accuracy: 0.95
Loss: 0.07984570413827896, Accuracy: 0.97
Loss: 0.050088897347450256, Accuracy: 0.98
Loss: 0.22633907198905945, Accuracy: 0.93
Loss: 0.12608236074447632, Accuracy: 0.95
Loss: 0.09691056609153748, Accuracy: 0.97
Loss: 0.1576603353023529, Accuracy: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.59048}
saved model
Epoch: 46


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.09202826768159866, Accuracy: 0.96
Loss: 0.05728016793727875, Accuracy: 0.98
Loss: 0.058014653623104095, Accuracy: 0.98
Loss: 0.09831242263317108, Accuracy: 0.93
Loss: 0.145015150308609, Accuracy: 0.92
Loss: 0.07635408639907837, Accuracy: 0.98
Loss: 0.11935214698314667, Accuracy: 0.96
Loss: 0.11493992805480957, Accuracy: 0.96
Loss: 0.14126186072826385, Accuracy: 0.95
Loss: 0.10258188098669052, Accuracy: 0.95
Loss: 0.07654036581516266, Accuracy: 0.95
Loss: 0.07768821716308594, Accuracy: 0.97
Loss: 0.10165343433618546, Accuracy: 0.96
Loss: 0.10192898660898209, Accuracy: 0.95
Loss: 0.13121938705444336, Accuracy: 0.96
Loss: 0.11364677548408508, Accuracy: 0.93
Loss: 0.136021688580513, Accuracy: 0.94
Loss: 0.10120879113674164, Accuracy: 0.96
Loss: 0.0860944464802742, Accuracy: 0.97
Loss: 0.10555274784564972, Accuracy: 0.97
Loss: 0.08167804777622223, Accuracy: 0.96
Loss: 0.13331784307956696, Accuracy: 0.94
Loss: 0.15542222559452057, Accuracy: 0.92
Loss: 0.07637979090213776, Accuracy: 0

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.596}
saved model
Epoch: 47


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.065459705889225, Accuracy: 0.98
Loss: 0.0895276814699173, Accuracy: 0.96
Loss: 0.1733100861310959, Accuracy: 0.92
Loss: 0.058823589235544205, Accuracy: 0.99
Loss: 0.1573769599199295, Accuracy: 0.94
Loss: 0.20289890468120575, Accuracy: 0.93
Loss: 0.1236470490694046, Accuracy: 0.95
Loss: 0.11341109871864319, Accuracy: 0.96
Loss: 0.13476847112178802, Accuracy: 0.94
Loss: 0.13537436723709106, Accuracy: 0.97
Loss: 0.07159429043531418, Accuracy: 0.96
Loss: 0.1297394186258316, Accuracy: 0.93
Loss: 0.11374234408140182, Accuracy: 0.97
Loss: 0.11511829495429993, Accuracy: 0.95
Loss: 0.07588373869657516, Accuracy: 0.97
Loss: 0.07168499380350113, Accuracy: 0.99
Loss: 0.14925481379032135, Accuracy: 0.94
Loss: 0.08275246620178223, Accuracy: 0.95
Loss: 0.09572780877351761, Accuracy: 0.96
Loss: 0.2687523663043976, Accuracy: 0.91
Loss: 0.11382968723773956, Accuracy: 0.93
Loss: 0.09549564123153687, Accuracy: 0.97
Loss: 0.177683025598526, Accuracy: 0.93
Loss: 0.08940383046865463, Accuracy: 0.96

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.5948}
saved model
Epoch: 48


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.08778943121433258, Accuracy: 0.99
Loss: 0.16105370223522186, Accuracy: 0.94
Loss: 0.08220705389976501, Accuracy: 0.95
Loss: 0.07787182182073593, Accuracy: 0.97
Loss: 0.0915243998169899, Accuracy: 0.96
Loss: 0.09487205743789673, Accuracy: 0.95
Loss: 0.033935319632291794, Accuracy: 1.0
Loss: 0.04927386716008186, Accuracy: 0.97
Loss: 0.05673044174909592, Accuracy: 1.0
Loss: 0.05232391878962517, Accuracy: 0.99
Loss: 0.13067565858364105, Accuracy: 0.94
Loss: 0.1144631952047348, Accuracy: 0.94
Loss: 0.15464641153812408, Accuracy: 0.94
Loss: 0.1495627462863922, Accuracy: 0.93
Loss: 0.0759919136762619, Accuracy: 0.96
Loss: 0.14757061004638672, Accuracy: 0.95
Loss: 0.06639459729194641, Accuracy: 0.96
Loss: 0.08106418699026108, Accuracy: 0.95
Loss: 0.11168432980775833, Accuracy: 0.95
Loss: 0.28088369965553284, Accuracy: 0.84
Loss: 0.11511273682117462, Accuracy: 0.93
Loss: 0.11547751724720001, Accuracy: 0.95
Loss: 0.299515962600708, Accuracy: 0.92
Loss: 0.079022116959095, Accuracy: 0.98

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60144}
saved model
Epoch: 49


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.12841780483722687, Accuracy: 0.93
Loss: 0.08401262015104294, Accuracy: 0.97
Loss: 0.1365337073802948, Accuracy: 0.93
Loss: 0.08215785771608353, Accuracy: 0.97
Loss: 0.06760230660438538, Accuracy: 0.96
Loss: 0.14563098549842834, Accuracy: 0.94
Loss: 0.09546879678964615, Accuracy: 0.95
Loss: 0.0674472376704216, Accuracy: 0.98
Loss: 0.1319706290960312, Accuracy: 0.97
Loss: 0.10262922942638397, Accuracy: 0.95
Loss: 0.08252410590648651, Accuracy: 0.99
Loss: 0.08442489802837372, Accuracy: 0.96
Loss: 0.09942060708999634, Accuracy: 0.94
Loss: 0.07538425177335739, Accuracy: 0.96
Loss: 0.043820735067129135, Accuracy: 0.99
Loss: 0.11426311731338501, Accuracy: 0.97
Loss: 0.07761441171169281, Accuracy: 0.96
Loss: 0.04873674735426903, Accuracy: 0.99
Loss: 0.10058242827653885, Accuracy: 0.94
Loss: 0.08797258138656616, Accuracy: 0.97
Loss: 0.02841714769601822, Accuracy: 0.98
Loss: 0.060032982379198074, Accuracy: 0.98
Loss: 0.06743095070123672, Accuracy: 0.97
Loss: 0.10952449589967728, Accuracy

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.59728}
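
In the log above, validation accuracy stays near 0.60 while the per-batch training accuracy is mostly above 0.90, and "saved model" is printed after each of epochs 41 to 48, including epochs where validation accuracy dropped. As a minimal sketch (not the notebook's actual training loop, which is not shown in this section), the snippet below picks the best epoch from the logged validation accuracies and counts how many epochs passed without improvement. The names `val_accuracy`, `should_save`, `best_epoch` and `epochs_since_best` are illustrative, and pairing each accuracy with the preceding `Epoch:` header is an assumption about how the log is laid out.

```python
# Validation accuracies copied from the log above (epochs 41 to 49).
val_accuracy = {
    41: 0.6052, 42: 0.60592, 43: 0.6088, 44: 0.60528, 45: 0.59048,
    46: 0.596, 47: 0.5948, 48: 0.60144, 49: 0.59728,
}


def should_save(epoch_accuracy, best_so_far):
    """Return True only when the new validation accuracy beats the best seen."""
    return epoch_accuracy > best_so_far


best_accuracy = float("-inf")
best_epoch = None
saved_epochs = []  # epochs at which a "save best only" rule would write a checkpoint
for epoch, acc in val_accuracy.items():
    if should_save(acc, best_accuracy):
        best_accuracy, best_epoch = acc, epoch
        saved_epochs.append(epoch)

# Number of epochs trained after the best one without any improvement.
epochs_since_best = max(val_accuracy) - best_epoch

print("best epoch:", best_epoch, "accuracy:", best_accuracy)
print("checkpoints that would have been written:", saved_epochs)
print("epochs without improvement:", epochs_since_best)
```

Under this rule the checkpoint from the best validation epoch would be the one kept, rather than whichever epoch happened to run last.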


## Evaluate the model

Finally, we evaluate the model on the test set. We use the Datasets library to compute the accuracy.

In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/rerun_small_model_Nov_16_epoch_last.pt')

In [None]:

# import torch
# checkpoint = torch.load('/content/drive/MyDrive/saved_model/small_network_model.pt')
# model.load_state_dict(checkpoint)
# model.eval()

In [None]:
import torch
from tqdm.notebook import tqdm
# Note: load_metric is deprecated in newer Datasets releases;
# evaluate.load("accuracy") from the Evaluate library is its replacement there.
from datasets import load_metric

accuracy = load_metric("accuracy")

model.eval()
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in tqdm(test_dataloader):
        # get the inputs
        inputs = batch["input_ids"].to(device)
        # attention_mask = batch["attention_mask"].to(device)

        # forward pass
        outputs = model(inputs=inputs)
        logits = outputs.logits
        predictions = logits.argmax(-1).cpu().numpy()
        references = batch["label"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.5928}
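
The `add_batch` / `compute` pattern above accumulates predictions and references across batches and reports a single accuracy at the end. The plain-Python sketch below mimics that behaviour for illustration only: `RunningAccuracy` is a hypothetical helper, not part of the Datasets library, and the toy batches are made up.

```python
class RunningAccuracy:
    """Hypothetical stand-in that accumulates batches and reports overall accuracy."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        # Count matching positions in this batch and add them to the running totals.
        self.correct += sum(int(p == r) for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        # Same dictionary shape as the result printed above.
        return {"accuracy": self.correct / self.total}


metric = RunningAccuracy()
metric.add_batch(predictions=[1, 0, 1, 1], references=[1, 0, 0, 1])  # 3 of 4 correct
metric.add_batch(predictions=[0, 0], references=[1, 0])              # 1 of 2 correct
result = metric.compute()
print(result)
```

Because the counts are pooled, the result is the accuracy over all examples, which can differ from the mean of per-batch accuracies when batch sizes are unequal.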


## Inference

In [None]:
text = "I hated this movie, it's really bad."

input_ids = tokenizer(text, return_tensors="pt").input_ids

# forward pass (no gradients are needed at inference time)
with torch.no_grad():
    outputs = model(inputs=input_ids.to(device))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()

print("Predicted:", model.config.id2label[predicted_class_idx])
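
The `argmax` above only returns the winning class. To see how confident a prediction is, the logits can be passed through a softmax. The sketch below does this in plain Python on made-up logits; the `example_logits` values and the `id2label` mapping are illustrative assumptions, not output from the model above.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities, subtracting the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative values only; real logits would come from outputs.logits.
example_logits = [2.0, -1.0]
id2label = {0: "NEGATIVE", 1: "POSITIVE"}  # hypothetical label mapping

probs = softmax(example_logits)
predicted_class_idx = max(range(len(probs)), key=probs.__getitem__)
print(id2label[predicted_class_idx], round(probs[predicted_class_idx], 3))
```

With PyTorch tensors, the equivalent is `logits.softmax(-1)`.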