## Set-up environment

We first install the dependencies: PyTorch Lightning, HuggingFace Transformers (from source), and Datasets.

In [None]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import os
from pytorch_lightning.loggers import TensorBoardLogger
%load_ext autoreload
%autoreload 2

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Nov 17 05:45:55 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   28C    P0    44W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done


In [None]:
!pip install -q datasets

## Prepare data

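Here we use the IMDB dataset, a binary text classification dataset ("is a movie review positive or negative?"). We train on the full 25,000-review training split and divide the 25,000-review test split into a test set and a validation set of 12,500 reviews each.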

In [None]:
from datasets import load_dataset

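# test set: the first and last 6,250 examples of the IMDB test split
train_ds, test_ds = load_dataset("imdb", split=['train', 'test[:6250]+test[-6250:]'])
# validation set: the middle 12,500 examples of the test split (train is loaded again unchanged)
train_ds, val_ds = load_dataset("imdb", split=['train', 'test[6250:12500]+test[-12500:-6250]'])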

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.



In [None]:

print(test_ds)
print(val_ds)

Dataset({
    features: ['text', 'label'],
    num_rows: 12500
})
Dataset({
    features: ['text', 'label'],
    num_rows: 12500
})
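The two slice expressions can be checked with ordinary Python slicing. The sketch below assumes that the split-slicing syntax follows Python's slice semantics and that the test split has 25,000 rows, as the download log reports. The names `idx`, `test_idx`, and `val_idx` are illustrative and not part of the notebook.

```python
# Stand-in for the row indices of the 25,000-example IMDB test split
n_test = 25000
idx = list(range(n_test))

# 'test[:6250]+test[-6250:]' -> first and last 6,250 rows
test_idx = idx[:6250] + idx[-6250:]

# 'test[6250:12500]+test[-12500:-6250]' -> the middle 12,500 rows
val_idx = idx[6250:12500] + idx[-12500:-6250]

# The two subsets should not share any row
overlap = set(test_idx) & set(val_idx)
print(len(test_idx), len(val_idx), len(overlap))
```

Under these assumptions the two index sets are disjoint and together cover the whole test split, so no test review also appears in the validation set.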


We create id2label and label2id mappings, which are handy at inference time.

In [None]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [None]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}
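To illustrate why these mappings are handy at inference time, here is a minimal sketch that turns per-class scores into label names. The scores are made up for the example, and `decode_predictions` is a hypothetical helper, not a function from the notebook or from Transformers.

```python
# Same mappings as printed above
id2label = {0: "neg", 1: "pos"}
label2id = {label: idx for idx, label in id2label.items()}

def decode_predictions(score_rows):
    """Map each row of class scores to a label name via argmax."""
    names = []
    for row in score_rows:
        best_idx = max(range(len(row)), key=row.__getitem__)
        names.append(id2label[best_idx])
    return names

# Hypothetical scores for two reviews (one row per review, one column per class)
example_scores = [[2.0, -1.0], [0.3, 1.2]]
predicted = decode_predictions(example_scores)
print(predicted)
```

The reverse mapping, `label2id`, is useful in the other direction, for example when converting string labels from another data source into class indices.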


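Next, we prepare the data for the model using the tokenizer. Each review is padded or truncated to the tokenizer's maximum length (2,048 tokens, as the batch shapes below confirm).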

In [None]:
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

def tokenize(examples):
    # pad or truncate every review to the tokenizer's maximum length
    return tokenizer(examples['text'], padding="max_length", truncation=True)

train_ds = train_ds.map(tokenize, batched=True)
val_ds = val_ds.map(tokenize, batched=True)
test_ds = test_ds.map(tokenize, batched=True)

Using unk_token, but it is not set yet.

We set the format to PyTorch tensors, and create familiar PyTorch dataloaders.

In [None]:
train_ds.set_format(type="torch", columns=['input_ids', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'label'])
val_ds.set_format(type="torch", columns=['input_ids', 'label'])

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=100, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=25)
val_dataloader = DataLoader(val_ds, batch_size=25)
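The batch sizes fix the number of iterations per epoch, which is what the progress bars in the training log count. A small worked calculation, assuming the dataloaders keep the default `drop_last=False` (`batches_per_epoch` is an illustrative helper):

```python
import math

def batches_per_epoch(num_examples, batch_size):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)

train_batches = batches_per_epoch(25000, 100)  # training set, batch_size=100
eval_batches = batches_per_epoch(12500, 25)    # validation or test set, batch_size=25
print(train_batches, eval_batches)  # 250 500
```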

Here we check the batch shapes and decode one example (it is always worth inspecting your data).

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

label torch.Size([100])
input_ids torch.Size([100, 2048])
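The 2048-wide `input_ids` above result from padding every review to a fixed length. The sketch below imitates that behaviour in plain Python. It rests on assumptions about how the Perceiver tokenizer lays out its ids (UTF-8 bytes shifted by 6 reserved ids, `[PAD]`=0, `[CLS]`=4, `[SEP]`=5); `toy_byte_tokenize` is a hypothetical stand-in, not the library's implementation.

```python
def toy_byte_tokenize(text, max_length=2048, pad_id=0, cls_id=4, sep_id=5, offset=6):
    """Byte-level tokenization with truncation and max-length padding (simplified)."""
    # One id per UTF-8 byte, shifted past the assumed reserved special ids
    body = [byte + offset for byte in text.encode("utf-8")]
    # Truncate so that [CLS] and [SEP] still fit
    body = body[: max_length - 2]
    ids = [cls_id] + body + [sep_id]
    # Right-pad to the fixed length
    return ids + [pad_id] * (max_length - len(ids))

ids = toy_byte_tokenize("hello world")
print(len(ids), ids[:4])
```

Because every sequence has the same length, the default collate function can stack the examples into a single `[batch_size, 2048]` tensor.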


In [None]:
tokenizer.decode(batch['input_ids'][3])

"[CLS]This film would usually classify as the worst movie production ever. Ever. But in my opinion it is possibly the funniest. The horrifying direction and screenplay makes this film priceless. I bought the movie whilst sifting through the bargain DVD's at my local pound shop. Me and some friends then watched it, admittedly whilst rather drunk. It soon occurred that this wasn't any normal film. Instead a priceless relic of what will probably be James Cahill's last film. At first we were confused and were screaming for the DVD player to be turned off but thankfully in our abnormal state no-one could be bothered. Instead we watched the film right through. At the end we soon realised we had found any wasters dream, something that you can acceptably laugh at for hours, whilst laughing for all the wrong reasons. We soon showed all our other friends and they too agreed, this wasn't a work of abysmal film. This was a film that you can truly wet yourself laughing at. This was a film that anyo

In [None]:
# batch['label']

In [None]:
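import numpy as np
# mean label (fraction of positive reviews); not shown, since only print() output is displayed here
train_ds['label'].double().mean()
# label of the example at index 12499
print(train_ds['label'][12499])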

tensor(0)


## Define model

Next, we define our model, a Perceiver with a custom text preprocessor and a classification decoder, and put it on the GPU.

In [None]:
# preprocessor we customized to use the tagkop encoder
# Must include tagkop_encoding_functions.py and pos_embeddings_IMDB_tSNE_2048x64.pth in base directory
from tagkop_encoding_functions import (
    PerceiverImagePreprocessor,
    TagkopPerceiverTextPreprocessor,
)
from transformers import PerceiverForSequenceClassification

import torch

from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder,
    PerceiverTextPreprocessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = PerceiverConfig(
    num_self_attends_per_block = 4,
    d_model = 64
)

print('config', config)

# MAE-Derived Perceiver Encodings
preprocessor = TagkopPerceiverTextPreprocessor(config)


decoder = PerceiverClassificationDecoder(config,
                                          num_channels=config.d_latents,
                                          trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
                                          use_query_residual=True,
                                         )

model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)



model.to(device)

ModuleNotFoundError: ignored

In [None]:
# you can then do a forward pass as follows:
tokenizer = PerceiverTokenizer()
text = "hello world"
inputs = tokenizer(text, return_tensors="pt").input_ids
print(inputs)
with torch.no_grad():
    # .to(device) returns a new tensor, so move the inputs at the call site
    outputs = model(inputs=inputs.to(device))
logits = outputs.logits
print('list(logits.shape): ', list(logits.shape))
# to train the model, one can use standard cross-entropy:
criterion = torch.nn.CrossEntropyLoss()
labels = torch.tensor([1]).to(device)
loss = criterion(logits, labels)
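For intuition about the loss values this criterion produces, here is a small worked example of cross-entropy for a single example, written in plain Python rather than PyTorch. The function name `cross_entropy` and the logits are illustrative only.

```python
import math

def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class for one example."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Equal logits: the model is guessing between two classes
chance_loss = cross_entropy([0.0, 0.0], 1)
# A confident, correct prediction
confident_loss = cross_entropy([0.0, 5.0], 1)
print(chance_loss, confident_loss)
```

With two classes, an uninformative prediction gives a loss of ln 2 ≈ 0.693, so per-batch losses hovering around 0.69 indicate the model is still at chance level.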

## Train the model

Here we train the model using native PyTorch.

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from datasets import load_metric

# torch.optim.AdamW replaces the deprecated transformers.AdamW;
# weight_decay=0.0 matches the old transformers default
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)

# standard cross-entropy loss, created once and reused for every batch
criterion = torch.nn.CrossEntropyLoss()

# validation accuracy per epoch (kept across epochs)
epoch_accs = np.array([])

for epoch in range(50):  # loop over the dataset multiple times

    # Add model save location below
    # torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/small_model_tagop_Nov_16_epoch_' + str(epoch) + '.pt')
    print('saved model')
    print("Epoch:", epoch)

    # switch back to training mode (validation below puts the model in eval mode)
    model.train()
    # training accuracy per batch within this epoch
    batch_accs = np.array([])
    for batch in tqdm(train_dataloader):
        # get the inputs
        inputs = batch["input_ids"].to(device)
        # attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # evaluate on the current training batch
        predictions = logits.argmax(-1).cpu().detach().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")
        batch_accs = np.append(batch_accs, accuracy)

    print('Batch Accuracy Avg: ', np.mean(batch_accs))

    metric = load_metric("accuracy")

    # Validation Test
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in tqdm(val_dataloader):
            # get the inputs
            inputs = batch["input_ids"].to(device)
            labels = batch["label"].to(device)

            # forward pass
            outputs = model(inputs=inputs)
            logits = outputs.logits
            predictions = logits.argmax(-1).cpu().numpy()
            references = batch["label"].numpy()
            metric.add_batch(predictions=predictions, references=references)

    final_score = metric.compute()
    print("Accuracy on validation set:", final_score)
    # store the scalar accuracy, not the whole result dict
    epoch_accs = np.append(epoch_accs, final_score["accuracy"])
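The loop runs for a fixed 50 epochs, although the notebook imports Lightning's `EarlyStopping` callback without using it. As one possible complement, the sketch below applies a simple patience rule to a list of per-epoch validation accuracies. The helper `best_epoch_with_patience` and the patience value of 5 are assumptions made for illustration; the accuracies are the ones printed in the log below for epochs 0 to 18.

```python
def best_epoch_with_patience(val_accs, patience):
    """Return (best_epoch, best_acc, stop_epoch) under a patience-based stopping rule."""
    best_epoch, best_acc = 0, float("-inf")
    for epoch, acc in enumerate(val_accs):
        if acc > best_acc:
            # new best: remember it and reset the patience window
            best_epoch, best_acc = epoch, acc
        elif epoch - best_epoch >= patience:
            # no improvement for `patience` epochs: stop here
            return best_epoch, best_acc, epoch
    return best_epoch, best_acc, len(val_accs) - 1

# Validation accuracies reported in the training log (epochs 0 to 18)
val_accs = [0.50624, 0.5, 0.5, 0.50376, 0.62528, 0.62984, 0.60816, 0.6324,
            0.63848, 0.63648, 0.63024, 0.63, 0.6252, 0.62376, 0.6164,
            0.61256, 0.62496, 0.61896, 0.6204]

best_epoch, best_acc, stop_epoch = best_epoch_with_patience(val_accs, patience=5)
print(best_epoch, best_acc, stop_epoch)
```

With these numbers the best validation accuracy occurs at epoch 8, and a patience of 5 would stop training at epoch 13, while training accuracy in the log keeps rising.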



saved model
Epoch: 0


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6997937560081482, Accuracy: 0.59
Loss: 5.860103130340576, Accuracy: 0.47
Loss: 1.0586258172988892, Accuracy: 0.49
Loss: 0.7263509631156921, Accuracy: 0.41
Loss: 1.219814658164978, Accuracy: 0.47
Loss: 0.7055288553237915, Accuracy: 0.54
Loss: 0.889110803604126, Accuracy: 0.56
Loss: 0.9297664165496826, Accuracy: 0.54
Loss: 0.7063116431236267, Accuracy: 0.51
Loss: 0.8001638650894165, Accuracy: 0.52
Loss: 0.8801435828208923, Accuracy: 0.47
Loss: 0.7252963781356812, Accuracy: 0.44
Loss: 0.8669946193695068, Accuracy: 0.4
Loss: 0.8065783977508545, Accuracy: 0.47
Loss: 0.7030810713768005, Accuracy: 0.44
Loss: 0.8251209259033203, Accuracy: 0.42
Loss: 0.832796037197113, Accuracy: 0.43
Loss: 0.701611340045929, Accuracy: 0.48
Loss: 0.68707674741745, Accuracy: 0.58
Loss: 0.7653590440750122, Accuracy: 0.56
Loss: 0.8590285778045654, Accuracy: 0.47
Loss: 0.7559733390808105, Accuracy: 0.41
Loss: 0.7004354596138, Accuracy: 0.56
Loss: 0.9003586769104004, Accuracy: 0.49
Loss: 0.9321032762527466, A



Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.50624}
saved model
Epoch: 1


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6876648664474487, Accuracy: 0.57
Loss: 0.6884174942970276, Accuracy: 0.53
Loss: 0.6905404925346375, Accuracy: 0.51
Loss: 0.6925503015518188, Accuracy: 0.47
Loss: 0.7394405603408813, Accuracy: 0.42
Loss: 0.6944580674171448, Accuracy: 0.49
Loss: 0.6814687252044678, Accuracy: 0.6
Loss: 0.6943730115890503, Accuracy: 0.51
Loss: 0.702823281288147, Accuracy: 0.47
Loss: 0.7085275053977966, Accuracy: 0.37
Loss: 0.6844261884689331, Accuracy: 0.55
Loss: 0.6995547413825989, Accuracy: 0.47
Loss: 0.6983396410942078, Accuracy: 0.52
Loss: 0.6901745796203613, Accuracy: 0.54
Loss: 0.6884239315986633, Accuracy: 0.55
Loss: 0.6962190270423889, Accuracy: 0.51
Loss: 0.6942296028137207, Accuracy: 0.47
Loss: 0.6917767524719238, Accuracy: 0.57
Loss: 0.6894019246101379, Accuracy: 0.56
Loss: 0.6928794980049133, Accuracy: 0.5
Loss: 0.6991918683052063, Accuracy: 0.48
Loss: 0.6876569390296936, Accuracy: 0.58
Loss: 0.6883142590522766, Accuracy: 0.57
Loss: 0.6918371319770813, Accuracy: 0.53
Loss: 0.69743919372

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.5}
saved model
Epoch: 2


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6960604190826416, Accuracy: 0.51
Loss: 0.6881473064422607, Accuracy: 0.56
Loss: 0.7005516290664673, Accuracy: 0.54
Loss: 0.7346038818359375, Accuracy: 0.46
Loss: 0.720877468585968, Accuracy: 0.4
Loss: 0.6883983612060547, Accuracy: 0.55
Loss: 0.7336866855621338, Accuracy: 0.48
Loss: 0.7830134630203247, Accuracy: 0.42
Loss: 0.7145869731903076, Accuracy: 0.48
Loss: 0.6925964951515198, Accuracy: 0.54
Loss: 0.7211853861808777, Accuracy: 0.5
Loss: 0.7352039813995361, Accuracy: 0.49
Loss: 0.6972746253013611, Accuracy: 0.52
Loss: 0.6921192407608032, Accuracy: 0.55
Loss: 0.6984204649925232, Accuracy: 0.52
Loss: 0.7497259378433228, Accuracy: 0.42
Loss: 0.6967321634292603, Accuracy: 0.49
Loss: 0.6922349333763123, Accuracy: 0.53
Loss: 0.7370051741600037, Accuracy: 0.47
Loss: 0.7090666890144348, Accuracy: 0.51
Loss: 0.6945560574531555, Accuracy: 0.48
Loss: 0.7065544724464417, Accuracy: 0.47
Loss: 0.6844021677970886, Accuracy: 0.57
Loss: 0.693297803401947, Accuracy: 0.55
Loss: 0.717640817165

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.5}
saved model
Epoch: 3


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.7250745892524719, Accuracy: 0.44
Loss: 0.688485860824585, Accuracy: 0.56
Loss: 0.6933708190917969, Accuracy: 0.48
Loss: 0.6930732131004333, Accuracy: 0.51
Loss: 0.7044492363929749, Accuracy: 0.42
Loss: 0.6917608380317688, Accuracy: 0.49
Loss: 0.6967355608940125, Accuracy: 0.43
Loss: 0.6948646306991577, Accuracy: 0.47
Loss: 0.6958025097846985, Accuracy: 0.47
Loss: 0.6933327317237854, Accuracy: 0.58
Loss: 0.6935656666755676, Accuracy: 0.51
Loss: 0.6918783783912659, Accuracy: 0.53
Loss: 0.6947967410087585, Accuracy: 0.53
Loss: 0.7079599499702454, Accuracy: 0.47
Loss: 0.6911702156066895, Accuracy: 0.53
Loss: 0.6905354261398315, Accuracy: 0.5
Loss: 0.6940131187438965, Accuracy: 0.5
Loss: 0.6938678026199341, Accuracy: 0.53
Loss: 0.6862005591392517, Accuracy: 0.56
Loss: 0.6860588788986206, Accuracy: 0.56
Loss: 0.6869661808013916, Accuracy: 0.55
Loss: 0.6907328963279724, Accuracy: 0.54
Loss: 0.7207450270652771, Accuracy: 0.41
Loss: 0.6951367855072021, Accuracy: 0.46
Loss: 0.70593035221

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.50376}
saved model
Epoch: 4


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6956081390380859, Accuracy: 0.43
Loss: 0.690199077129364, Accuracy: 0.52
Loss: 0.6938334703445435, Accuracy: 0.46
Loss: 0.6906189918518066, Accuracy: 0.6
Loss: 0.690032422542572, Accuracy: 0.59
Loss: 0.6931760311126709, Accuracy: 0.49
Loss: 0.6867191791534424, Accuracy: 0.61
Loss: 0.7114690542221069, Accuracy: 0.39
Loss: 0.6841830611228943, Accuracy: 0.58
Loss: 0.6910602450370789, Accuracy: 0.56
Loss: 0.6895921230316162, Accuracy: 0.6
Loss: 0.6917084455490112, Accuracy: 0.51
Loss: 0.6851580739021301, Accuracy: 0.66
Loss: 0.6894033551216125, Accuracy: 0.52
Loss: 0.6848595142364502, Accuracy: 0.55
Loss: 0.6727922558784485, Accuracy: 0.59
Loss: 0.7355667948722839, Accuracy: 0.42
Loss: 0.687350332736969, Accuracy: 0.54
Loss: 0.6844865679740906, Accuracy: 0.6
Loss: 0.7024347186088562, Accuracy: 0.48
Loss: 0.7153236865997314, Accuracy: 0.45
Loss: 0.6825596690177917, Accuracy: 0.58
Loss: 0.6866722106933594, Accuracy: 0.51
Loss: 0.6957660913467407, Accuracy: 0.5
Loss: 0.710485517978668

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62528}
saved model
Epoch: 5


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6594849228858948, Accuracy: 0.61
Loss: 0.6786109209060669, Accuracy: 0.54
Loss: 0.6282404065132141, Accuracy: 0.68
Loss: 0.6767128109931946, Accuracy: 0.57
Loss: 0.6557548642158508, Accuracy: 0.62
Loss: 0.6554307341575623, Accuracy: 0.6
Loss: 0.6574298739433289, Accuracy: 0.58
Loss: 0.6550685167312622, Accuracy: 0.65
Loss: 0.6457695960998535, Accuracy: 0.69
Loss: 0.6588848829269409, Accuracy: 0.71
Loss: 0.6426658630371094, Accuracy: 0.64
Loss: 0.6700355410575867, Accuracy: 0.52
Loss: 0.6006492972373962, Accuracy: 0.69
Loss: 0.7234657406806946, Accuracy: 0.48
Loss: 0.5971710085868835, Accuracy: 0.69
Loss: 0.6147657632827759, Accuracy: 0.65
Loss: 0.6201580762863159, Accuracy: 0.62
Loss: 0.5670078992843628, Accuracy: 0.74
Loss: 0.6799859404563904, Accuracy: 0.61
Loss: 0.5993856191635132, Accuracy: 0.69
Loss: 0.585237979888916, Accuracy: 0.69
Loss: 0.6629307270050049, Accuracy: 0.6
Loss: 0.7580854892730713, Accuracy: 0.58
Loss: 0.6403773427009583, Accuracy: 0.62
Loss: 0.64284348487

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62984}
saved model
Epoch: 6


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5736016035079956, Accuracy: 0.69
Loss: 0.6369208693504333, Accuracy: 0.64
Loss: 0.6211540102958679, Accuracy: 0.67
Loss: 0.6168565154075623, Accuracy: 0.63
Loss: 0.5701101422309875, Accuracy: 0.75
Loss: 0.5961709022521973, Accuracy: 0.67
Loss: 0.6107425689697266, Accuracy: 0.67
Loss: 0.5718780159950256, Accuracy: 0.72
Loss: 0.675110936164856, Accuracy: 0.58
Loss: 0.6023373603820801, Accuracy: 0.66
Loss: 0.5940759181976318, Accuracy: 0.71
Loss: 0.5982183218002319, Accuracy: 0.7
Loss: 0.6096442341804504, Accuracy: 0.63
Loss: 0.6009698510169983, Accuracy: 0.7
Loss: 0.6023644804954529, Accuracy: 0.69
Loss: 0.5985700488090515, Accuracy: 0.73
Loss: 0.6805984973907471, Accuracy: 0.55
Loss: 0.6127275824546814, Accuracy: 0.71
Loss: 0.712043046951294, Accuracy: 0.57
Loss: 0.5690233707427979, Accuracy: 0.7
Loss: 0.6009969115257263, Accuracy: 0.7
Loss: 0.6755563616752625, Accuracy: 0.62
Loss: 0.6580053567886353, Accuracy: 0.64
Loss: 0.6575368642807007, Accuracy: 0.62
Loss: 0.59280896186828

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60816}
saved model
Epoch: 7


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6091173887252808, Accuracy: 0.67
Loss: 0.5798378586769104, Accuracy: 0.82
Loss: 0.6311081051826477, Accuracy: 0.65
Loss: 0.5867358446121216, Accuracy: 0.73
Loss: 0.6110517382621765, Accuracy: 0.67
Loss: 0.6375366449356079, Accuracy: 0.65
Loss: 0.5521501898765564, Accuracy: 0.69
Loss: 0.5592952966690063, Accuracy: 0.7
Loss: 0.6049760580062866, Accuracy: 0.69
Loss: 0.6346427202224731, Accuracy: 0.66
Loss: 0.5810177326202393, Accuracy: 0.69
Loss: 0.5346622467041016, Accuracy: 0.76
Loss: 0.5579668283462524, Accuracy: 0.72
Loss: 0.6191281080245972, Accuracy: 0.72
Loss: 0.5555858612060547, Accuracy: 0.71
Loss: 0.6700376868247986, Accuracy: 0.6
Loss: 0.5908315181732178, Accuracy: 0.68
Loss: 0.6822964549064636, Accuracy: 0.61
Loss: 0.6195418238639832, Accuracy: 0.69
Loss: 0.6277719736099243, Accuracy: 0.7
Loss: 0.618732213973999, Accuracy: 0.65
Loss: 0.5283442139625549, Accuracy: 0.77
Loss: 0.5940207242965698, Accuracy: 0.71
Loss: 0.7032347917556763, Accuracy: 0.59
Loss: 0.594731390476

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6324}
saved model
Epoch: 8


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5220690369606018, Accuracy: 0.77
Loss: 0.5756580233573914, Accuracy: 0.68
Loss: 0.5694621205329895, Accuracy: 0.72
Loss: 0.5631093382835388, Accuracy: 0.71
Loss: 0.5307955145835876, Accuracy: 0.77
Loss: 0.593721866607666, Accuracy: 0.71
Loss: 0.6141173839569092, Accuracy: 0.68
Loss: 0.5737866759300232, Accuracy: 0.7
Loss: 0.5529550313949585, Accuracy: 0.73
Loss: 0.5961859226226807, Accuracy: 0.66
Loss: 0.6397421360015869, Accuracy: 0.65
Loss: 0.664842963218689, Accuracy: 0.6
Loss: 0.6242197751998901, Accuracy: 0.66
Loss: 0.6058371663093567, Accuracy: 0.66
Loss: 0.623290479183197, Accuracy: 0.65
Loss: 0.579964816570282, Accuracy: 0.72
Loss: 0.5791273713111877, Accuracy: 0.73
Loss: 0.5464359521865845, Accuracy: 0.71
Loss: 0.5994261503219604, Accuracy: 0.66
Loss: 0.6854531764984131, Accuracy: 0.58
Loss: 0.5772769451141357, Accuracy: 0.71
Loss: 0.5408257246017456, Accuracy: 0.78
Loss: 0.5680973529815674, Accuracy: 0.71
Loss: 0.5393107533454895, Accuracy: 0.74
Loss: 0.52255672216415

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63848}
saved model
Epoch: 9


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5771596431732178, Accuracy: 0.74
Loss: 0.5968730449676514, Accuracy: 0.7
Loss: 0.4902860224246979, Accuracy: 0.78
Loss: 0.49646803736686707, Accuracy: 0.79
Loss: 0.47113731503486633, Accuracy: 0.81
Loss: 0.5402935147285461, Accuracy: 0.74
Loss: 0.546873152256012, Accuracy: 0.74
Loss: 0.6328594088554382, Accuracy: 0.69
Loss: 0.49872756004333496, Accuracy: 0.78
Loss: 0.6004610657691956, Accuracy: 0.71
Loss: 0.5145317912101746, Accuracy: 0.78
Loss: 0.5513550043106079, Accuracy: 0.7
Loss: 0.6175340414047241, Accuracy: 0.72
Loss: 0.5832383632659912, Accuracy: 0.7
Loss: 0.6040747165679932, Accuracy: 0.71
Loss: 0.538643479347229, Accuracy: 0.76
Loss: 0.6125513315200806, Accuracy: 0.68
Loss: 0.6340511441230774, Accuracy: 0.61
Loss: 0.5699845552444458, Accuracy: 0.69
Loss: 0.6055436134338379, Accuracy: 0.71
Loss: 0.5698255300521851, Accuracy: 0.72
Loss: 0.5752027630805969, Accuracy: 0.77
Loss: 0.5500239729881287, Accuracy: 0.71
Loss: 0.5610905885696411, Accuracy: 0.74
Loss: 0.5994993448

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63648}
saved model
Epoch: 10


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5460675954818726, Accuracy: 0.74
Loss: 0.5517252087593079, Accuracy: 0.77
Loss: 0.5495566725730896, Accuracy: 0.76
Loss: 0.5237377285957336, Accuracy: 0.75
Loss: 0.5509363412857056, Accuracy: 0.76
Loss: 0.4515773355960846, Accuracy: 0.78
Loss: 0.6133741140365601, Accuracy: 0.71
Loss: 0.5435428023338318, Accuracy: 0.71
Loss: 0.47514334321022034, Accuracy: 0.79
Loss: 0.4900457262992859, Accuracy: 0.8
Loss: 0.5881919264793396, Accuracy: 0.68
Loss: 0.4651056230068207, Accuracy: 0.77
Loss: 0.5292748212814331, Accuracy: 0.75
Loss: 0.6042894124984741, Accuracy: 0.67
Loss: 0.5859552025794983, Accuracy: 0.71
Loss: 0.5503884553909302, Accuracy: 0.73
Loss: 0.5528212189674377, Accuracy: 0.78
Loss: 0.5132895112037659, Accuracy: 0.73
Loss: 0.5374352931976318, Accuracy: 0.77
Loss: 0.6917477250099182, Accuracy: 0.64
Loss: 0.6125379800796509, Accuracy: 0.67
Loss: 0.5150631070137024, Accuracy: 0.74
Loss: 0.6173283457756042, Accuracy: 0.68
Loss: 0.5349119305610657, Accuracy: 0.74
Loss: 0.54631233

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63024}
saved model
Epoch: 11


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.4787557125091553, Accuracy: 0.77
Loss: 0.5010474324226379, Accuracy: 0.77
Loss: 0.47467637062072754, Accuracy: 0.75
Loss: 0.5135046243667603, Accuracy: 0.77
Loss: 0.5655401945114136, Accuracy: 0.74
Loss: 0.5929601192474365, Accuracy: 0.71
Loss: 0.590990424156189, Accuracy: 0.71
Loss: 0.42929214239120483, Accuracy: 0.79
Loss: 0.5309746265411377, Accuracy: 0.72
Loss: 0.6377794146537781, Accuracy: 0.69
Loss: 0.4772289991378784, Accuracy: 0.78
Loss: 0.46819865703582764, Accuracy: 0.82
Loss: 0.5729703307151794, Accuracy: 0.71
Loss: 0.5475023984909058, Accuracy: 0.74
Loss: 0.490728497505188, Accuracy: 0.75
Loss: 0.5059670209884644, Accuracy: 0.77
Loss: 0.4176243841648102, Accuracy: 0.86
Loss: 0.5182346701622009, Accuracy: 0.75
Loss: 0.5863945484161377, Accuracy: 0.73
Loss: 0.5824228525161743, Accuracy: 0.75
Loss: 0.5539919137954712, Accuracy: 0.69
Loss: 0.5371127724647522, Accuracy: 0.78
Loss: 0.4403504431247711, Accuracy: 0.78
Loss: 0.43970951437950134, Accuracy: 0.86
Loss: 0.503617

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.63}
saved model
Epoch: 12


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6018043756484985, Accuracy: 0.72
Loss: 0.5267796516418457, Accuracy: 0.76
Loss: 0.5695263147354126, Accuracy: 0.74
Loss: 0.5050003528594971, Accuracy: 0.82
Loss: 0.4492891728878021, Accuracy: 0.79
Loss: 0.48582062125205994, Accuracy: 0.79
Loss: 0.49706923961639404, Accuracy: 0.78
Loss: 0.5221792459487915, Accuracy: 0.76
Loss: 0.4668689966201782, Accuracy: 0.81
Loss: 0.4251071512699127, Accuracy: 0.85
Loss: 0.5342414379119873, Accuracy: 0.74
Loss: 0.39019957184791565, Accuracy: 0.86
Loss: 0.4698043167591095, Accuracy: 0.83
Loss: 0.5612892508506775, Accuracy: 0.77
Loss: 0.3903263509273529, Accuracy: 0.85
Loss: 0.48617884516716003, Accuracy: 0.8
Loss: 0.49318981170654297, Accuracy: 0.77
Loss: 0.5487803220748901, Accuracy: 0.76
Loss: 0.493254691362381, Accuracy: 0.79
Loss: 0.4818846583366394, Accuracy: 0.82
Loss: 0.5692043900489807, Accuracy: 0.72
Loss: 0.5921375155448914, Accuracy: 0.73
Loss: 0.5209249258041382, Accuracy: 0.75
Loss: 0.5327970385551453, Accuracy: 0.76
Loss: 0.53034

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6252}
saved model
Epoch: 13


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.48108384013175964, Accuracy: 0.81
Loss: 0.5154817700386047, Accuracy: 0.81
Loss: 0.5010709166526794, Accuracy: 0.77
Loss: 0.5072630047798157, Accuracy: 0.81
Loss: 0.4279564917087555, Accuracy: 0.85
Loss: 0.42559880018234253, Accuracy: 0.85
Loss: 0.45108702778816223, Accuracy: 0.82
Loss: 0.5263615250587463, Accuracy: 0.75
Loss: 0.5479734539985657, Accuracy: 0.76
Loss: 0.4606199562549591, Accuracy: 0.83
Loss: 0.4565628170967102, Accuracy: 0.84
Loss: 0.5272867679595947, Accuracy: 0.8
Loss: 0.4934660792350769, Accuracy: 0.77
Loss: 0.4455423355102539, Accuracy: 0.82
Loss: 0.42274153232574463, Accuracy: 0.83
Loss: 0.4218193292617798, Accuracy: 0.83
Loss: 0.3741583526134491, Accuracy: 0.87
Loss: 0.5342292189598083, Accuracy: 0.75
Loss: 0.6392752528190613, Accuracy: 0.71
Loss: 0.4836200773715973, Accuracy: 0.8
Loss: 0.42808765172958374, Accuracy: 0.82
Loss: 0.4067937731742859, Accuracy: 0.84
Loss: 0.4899139106273651, Accuracy: 0.81
Loss: 0.47417324781417847, Accuracy: 0.77
Loss: 0.5363

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62376}
saved model
Epoch: 14


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.44015613198280334, Accuracy: 0.8
Loss: 0.4566906690597534, Accuracy: 0.81
Loss: 0.5062787532806396, Accuracy: 0.77
Loss: 0.5371149182319641, Accuracy: 0.75
Loss: 0.47772955894470215, Accuracy: 0.79
Loss: 0.44092392921447754, Accuracy: 0.79
Loss: 0.4285544455051422, Accuracy: 0.8
Loss: 0.475587397813797, Accuracy: 0.83
Loss: 0.48609215021133423, Accuracy: 0.81
Loss: 0.38313284516334534, Accuracy: 0.82
Loss: 0.5124626755714417, Accuracy: 0.76
Loss: 0.5809324383735657, Accuracy: 0.74
Loss: 0.3509413003921509, Accuracy: 0.87
Loss: 0.485423743724823, Accuracy: 0.8
Loss: 0.4075598418712616, Accuracy: 0.82
Loss: 0.4591451585292816, Accuracy: 0.79
Loss: 0.4631789028644562, Accuracy: 0.82
Loss: 0.5113885998725891, Accuracy: 0.77
Loss: 0.4144427478313446, Accuracy: 0.84
Loss: 0.5344179272651672, Accuracy: 0.76
Loss: 0.5392760634422302, Accuracy: 0.81
Loss: 0.43293440341949463, Accuracy: 0.82
Loss: 0.46887537837028503, Accuracy: 0.82
Loss: 0.49362799525260925, Accuracy: 0.8
Loss: 0.391436

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.6164}
saved model
Epoch: 15


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.4812204837799072, Accuracy: 0.81
Loss: 0.41461604833602905, Accuracy: 0.83
Loss: 0.609056293964386, Accuracy: 0.66
Loss: 0.45480769872665405, Accuracy: 0.81
Loss: 0.39931434392929077, Accuracy: 0.84
Loss: 0.5563493371009827, Accuracy: 0.8
Loss: 0.4516492187976837, Accuracy: 0.81
Loss: 0.5203502774238586, Accuracy: 0.8
Loss: 0.46597325801849365, Accuracy: 0.79
Loss: 0.3984338343143463, Accuracy: 0.87
Loss: 0.40547311305999756, Accuracy: 0.84
Loss: 0.45197948813438416, Accuracy: 0.8
Loss: 0.385908305644989, Accuracy: 0.84
Loss: 0.34904345870018005, Accuracy: 0.88
Loss: 0.39129626750946045, Accuracy: 0.88
Loss: 0.3753148913383484, Accuracy: 0.84
Loss: 0.369381308555603, Accuracy: 0.84
Loss: 0.3402599096298218, Accuracy: 0.89
Loss: 0.536942720413208, Accuracy: 0.78
Loss: 0.3052680194377899, Accuracy: 0.89
Loss: 0.5254909992218018, Accuracy: 0.78
Loss: 0.4280507564544678, Accuracy: 0.84
Loss: 0.4474238157272339, Accuracy: 0.76
Loss: 0.3398231863975525, Accuracy: 0.87
Loss: 0.4419507

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61256}
saved model
Epoch: 16


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.528493344783783, Accuracy: 0.77
Loss: 0.4441882073879242, Accuracy: 0.81
Loss: 0.4397088289260864, Accuracy: 0.83
Loss: 0.4087584614753723, Accuracy: 0.86
Loss: 0.5690801739692688, Accuracy: 0.75
Loss: 0.41594353318214417, Accuracy: 0.82
Loss: 0.509721040725708, Accuracy: 0.79
Loss: 0.3486555814743042, Accuracy: 0.88
Loss: 0.4746953248977661, Accuracy: 0.8
Loss: 0.3274739980697632, Accuracy: 0.89
Loss: 0.5938897132873535, Accuracy: 0.73
Loss: 0.42929187417030334, Accuracy: 0.85
Loss: 0.3377719521522522, Accuracy: 0.89
Loss: 0.357666552066803, Accuracy: 0.87
Loss: 0.38067543506622314, Accuracy: 0.83
Loss: 0.36995211243629456, Accuracy: 0.87
Loss: 0.4260447323322296, Accuracy: 0.86
Loss: 0.4230680763721466, Accuracy: 0.84
Loss: 0.48433685302734375, Accuracy: 0.76
Loss: 0.449655681848526, Accuracy: 0.81
Loss: 0.4428527057170868, Accuracy: 0.82
Loss: 0.4756460189819336, Accuracy: 0.8
Loss: 0.3882283866405487, Accuracy: 0.87
Loss: 0.4541362524032593, Accuracy: 0.82
Loss: 0.437700718

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.62496}
saved model
Epoch: 17


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.4990676939487457, Accuracy: 0.77
Loss: 0.4391956329345703, Accuracy: 0.83
Loss: 0.5066943168640137, Accuracy: 0.78
Loss: 0.4353516101837158, Accuracy: 0.82
Loss: 0.4666807949542999, Accuracy: 0.83
Loss: 0.5184822678565979, Accuracy: 0.77
Loss: 0.41458621621131897, Accuracy: 0.84
Loss: 0.5170308351516724, Accuracy: 0.78
Loss: 0.43047401309013367, Accuracy: 0.86
Loss: 0.47527816891670227, Accuracy: 0.8
Loss: 0.32894641160964966, Accuracy: 0.91
Loss: 0.41560786962509155, Accuracy: 0.85
Loss: 0.37461644411087036, Accuracy: 0.86
Loss: 0.4457318186759949, Accuracy: 0.85
Loss: 0.31785812973976135, Accuracy: 0.9
Loss: 0.36215049028396606, Accuracy: 0.87
Loss: 0.42394015192985535, Accuracy: 0.84
Loss: 0.4340316653251648, Accuracy: 0.85
Loss: 0.47314149141311646, Accuracy: 0.83
Loss: 0.4422467052936554, Accuracy: 0.86
Loss: 0.3518911600112915, Accuracy: 0.84
Loss: 0.5406724810600281, Accuracy: 0.77
Loss: 0.3401414453983307, Accuracy: 0.89
Loss: 0.33022332191467285, Accuracy: 0.88
Loss: 0

Accuracy on validation set: {'accuracy': 0.61896}
saved model
Epoch: 18


Loss: 0.44519999623298645, Accuracy: 0.83
Loss: 0.45113861560821533, Accuracy: 0.82
Loss: 0.45135730504989624, Accuracy: 0.82
Loss: 0.44150370359420776, Accuracy: 0.83
Loss: 0.4105626344680786, Accuracy: 0.84
Loss: 0.37723883986473083, Accuracy: 0.87
Loss: 0.43033507466316223, Accuracy: 0.85
Loss: 0.40616652369499207, Accuracy: 0.85
Loss: 0.3213843107223511, Accuracy: 0.9
Loss: 0.3617285192012787, Accuracy: 0.88
Loss: 0.3520749807357788, Accuracy: 0.89
Loss: 0.349943071603775, Accuracy: 0.86
Loss: 0.34546756744384766, Accuracy: 0.9
Loss: 0.36413413286209106, Accuracy: 0.85
Loss: 0.29547375440597534, Accuracy: 0.91
Loss: 0.4066150188446045, Accuracy: 0.84
Loss: 0.45640814304351807, Accuracy: 0.83
Loss: 0.3818664848804474, Accuracy: 0.87
Loss: 0.5040828585624695, Accuracy: 0.82
Loss: 0.39015936851501465, Accuracy: 0.86
Loss: 0.33355382084846497, Accuracy: 0.9
Loss: 0.4466431140899658, Accuracy: 0.82
Loss: 0.3455190658569336, Accuracy: 0.88
Loss: 0.5000810027122498, Accuracy: 0.78
Loss: 0

Accuracy on validation set: {'accuracy': 0.6204}
saved model
Epoch: 19


Loss: 0.28666922450065613, Accuracy: 0.91
Loss: 0.47047796845436096, Accuracy: 0.82
Loss: 0.46620461344718933, Accuracy: 0.83
Loss: 0.3479302227497101, Accuracy: 0.85
Loss: 0.4027581810951233, Accuracy: 0.85
Loss: 0.3498714566230774, Accuracy: 0.87
Loss: 0.3631443679332733, Accuracy: 0.86
Loss: 0.45916369557380676, Accuracy: 0.83
Loss: 0.34362491965293884, Accuracy: 0.87
Loss: 0.33667975664138794, Accuracy: 0.89
Loss: 0.26664575934410095, Accuracy: 0.92
Loss: 0.3409402072429657, Accuracy: 0.88
Loss: 0.3898669183254242, Accuracy: 0.85
Loss: 0.4340405762195587, Accuracy: 0.85
Loss: 0.3224949240684509, Accuracy: 0.89
Loss: 0.270290344953537, Accuracy: 0.93
Loss: 0.44755637645721436, Accuracy: 0.84
Loss: 0.23927761614322662, Accuracy: 0.93
Loss: 0.3603837192058563, Accuracy: 0.87
Loss: 0.4568796455860138, Accuracy: 0.83
Loss: 0.3150533139705658, Accuracy: 0.89
Loss: 0.3350246548652649, Accuracy: 0.87
Loss: 0.4459417760372162, Accuracy: 0.85
Loss: 0.3045227527618408, Accuracy: 0.91
Loss: 0.

Accuracy on validation set: {'accuracy': 0.61512}
saved model
Epoch: 20


Loss: 0.44044753909111023, Accuracy: 0.87
Loss: 0.3607744574546814, Accuracy: 0.87
Loss: 0.340681254863739, Accuracy: 0.9
Loss: 0.2750430405139923, Accuracy: 0.94
Loss: 0.3115691542625427, Accuracy: 0.89
Loss: 0.43348270654678345, Accuracy: 0.84
Loss: 0.2672026753425598, Accuracy: 0.93
Loss: 0.39986464381217957, Accuracy: 0.85
Loss: 0.3371357321739197, Accuracy: 0.88
Loss: 0.4102306663990021, Accuracy: 0.84
Loss: 0.36134016513824463, Accuracy: 0.88
Loss: 0.4246845245361328, Accuracy: 0.84
Loss: 0.4135754108428955, Accuracy: 0.86
Loss: 0.32359421253204346, Accuracy: 0.89
Loss: 0.400865375995636, Accuracy: 0.86
Loss: 0.4745349586009979, Accuracy: 0.82
Loss: 0.37181299924850464, Accuracy: 0.9
Loss: 0.4260178506374359, Accuracy: 0.84
Loss: 0.4082271456718445, Accuracy: 0.84
Loss: 0.2707132399082184, Accuracy: 0.92
Loss: 0.4423961341381073, Accuracy: 0.82
Loss: 0.3494219481945038, Accuracy: 0.88
Loss: 0.36541929841041565, Accuracy: 0.88
Loss: 0.2966911196708679, Accuracy: 0.88
Loss: 0.38687

Accuracy on validation set: {'accuracy': 0.61848}
saved model
Epoch: 21


Loss: 0.32159537076950073, Accuracy: 0.89
Loss: 0.40272682905197144, Accuracy: 0.84
Loss: 0.4089112877845764, Accuracy: 0.84
Loss: 0.408155620098114, Accuracy: 0.85
Loss: 0.3366684019565582, Accuracy: 0.87
Loss: 0.277452290058136, Accuracy: 0.9
Loss: 0.32995688915252686, Accuracy: 0.9
Loss: 0.2849683463573456, Accuracy: 0.9
Loss: 0.31045520305633545, Accuracy: 0.88
Loss: 0.37380295991897583, Accuracy: 0.86
Loss: 0.37322819232940674, Accuracy: 0.88
Loss: 0.37189650535583496, Accuracy: 0.86
Loss: 0.2900523841381073, Accuracy: 0.93
Loss: 0.38228529691696167, Accuracy: 0.86
Loss: 0.44750282168388367, Accuracy: 0.83
Loss: 0.4085089862346649, Accuracy: 0.87
Loss: 0.4151778817176819, Accuracy: 0.83
Loss: 0.45686575770378113, Accuracy: 0.79
Loss: 0.28741922974586487, Accuracy: 0.88
Loss: 0.4404161870479584, Accuracy: 0.84
Loss: 0.36130011081695557, Accuracy: 0.87
Loss: 0.3159375488758087, Accuracy: 0.91
Loss: 0.30318784713745117, Accuracy: 0.91
Loss: 0.35420650243759155, Accuracy: 0.88
Loss: 0

Accuracy on validation set: {'accuracy': 0.6144}
saved model
Epoch: 22


Loss: 0.2829907238483429, Accuracy: 0.91
Loss: 0.2861933410167694, Accuracy: 0.91
Loss: 0.33549216389656067, Accuracy: 0.86
Loss: 0.27565720677375793, Accuracy: 0.92
Loss: 0.3279639780521393, Accuracy: 0.9
Loss: 0.26221245527267456, Accuracy: 0.92
Loss: 0.3393292725086212, Accuracy: 0.9
Loss: 0.31851744651794434, Accuracy: 0.91
Loss: 0.2593070864677429, Accuracy: 0.92
Loss: 0.29059576988220215, Accuracy: 0.9
Loss: 0.2745796740055084, Accuracy: 0.92
Loss: 0.35991424322128296, Accuracy: 0.88
Loss: 0.22011566162109375, Accuracy: 0.91
Loss: 0.3554503619670868, Accuracy: 0.88
Loss: 0.3402865529060364, Accuracy: 0.89
Loss: 0.33012208342552185, Accuracy: 0.89
Loss: 0.39944320917129517, Accuracy: 0.85
Loss: 0.3582092225551605, Accuracy: 0.87
Loss: 0.33593955636024475, Accuracy: 0.89
Loss: 0.2791005074977875, Accuracy: 0.92
Loss: 0.4191173315048218, Accuracy: 0.83
Loss: 0.2950027585029602, Accuracy: 0.91
Loss: 0.2609281837940216, Accuracy: 0.92
Loss: 0.34761863946914673, Accuracy: 0.87
Loss: 0.

Accuracy on validation set: {'accuracy': 0.61976}
saved model
Epoch: 23


Loss: 0.4440721571445465, Accuracy: 0.83
Loss: 0.26526370644569397, Accuracy: 0.92
Loss: 0.2628634572029114, Accuracy: 0.91
Loss: 0.33855143189430237, Accuracy: 0.86
Loss: 0.3369344472885132, Accuracy: 0.87
Loss: 0.2749282717704773, Accuracy: 0.93
Loss: 0.3110150098800659, Accuracy: 0.9
Loss: 0.3636658191680908, Accuracy: 0.88
Loss: 0.3748765289783478, Accuracy: 0.86
Loss: 0.2635796368122101, Accuracy: 0.91
Loss: 0.35815033316612244, Accuracy: 0.87
Loss: 0.2863176167011261, Accuracy: 0.9
Loss: 0.36318251490592957, Accuracy: 0.86
Loss: 0.4205714762210846, Accuracy: 0.84
Loss: 0.2089974284172058, Accuracy: 0.94
Loss: 0.4448585510253906, Accuracy: 0.82
Loss: 0.2069050669670105, Accuracy: 0.95
Loss: 0.23107533156871796, Accuracy: 0.95
Loss: 0.29371994733810425, Accuracy: 0.9
Loss: 0.32825568318367004, Accuracy: 0.87
Loss: 0.388207346200943, Accuracy: 0.85
Loss: 0.28240537643432617, Accuracy: 0.89
Loss: 0.33461588621139526, Accuracy: 0.87
Loss: 0.4293409585952759, Accuracy: 0.84
Loss: 0.370

Accuracy on validation set: {'accuracy': 0.6144}
saved model
Epoch: 24


Loss: 0.3334583342075348, Accuracy: 0.88
Loss: 0.39797303080558777, Accuracy: 0.84
Loss: 0.5185604095458984, Accuracy: 0.79
Loss: 0.320835143327713, Accuracy: 0.9
Loss: 0.2732694149017334, Accuracy: 0.92
Loss: 0.3388841152191162, Accuracy: 0.86
Loss: 0.3383959233760834, Accuracy: 0.89
Loss: 0.3744342029094696, Accuracy: 0.85
Loss: 0.22155620157718658, Accuracy: 0.95
Loss: 0.20687393844127655, Accuracy: 0.95
Loss: 0.25511085987091064, Accuracy: 0.92
Loss: 0.31965646147727966, Accuracy: 0.89
Loss: 0.35671311616897583, Accuracy: 0.88
Loss: 0.2541815936565399, Accuracy: 0.91
Loss: 0.334075927734375, Accuracy: 0.88
Loss: 0.47109970450401306, Accuracy: 0.83
Loss: 0.22545550763607025, Accuracy: 0.94
Loss: 0.2860647737979889, Accuracy: 0.9
Loss: 0.2145485281944275, Accuracy: 0.93
Loss: 0.2664546072483063, Accuracy: 0.88
Loss: 0.2782609164714813, Accuracy: 0.91
Loss: 0.30982473492622375, Accuracy: 0.92
Loss: 0.27159935235977173, Accuracy: 0.91
Loss: 0.29398292303085327, Accuracy: 0.88
Loss: 0.3

Accuracy on validation set: {'accuracy': 0.62064}
saved model
Epoch: 25


Loss: 0.35097163915634155, Accuracy: 0.88
Loss: 0.24886491894721985, Accuracy: 0.93
Loss: 0.42608538269996643, Accuracy: 0.85
Loss: 0.2916083037853241, Accuracy: 0.91
Loss: 0.33122193813323975, Accuracy: 0.9
Loss: 0.323169469833374, Accuracy: 0.86
Loss: 0.3070865869522095, Accuracy: 0.87
Loss: 0.3742899000644684, Accuracy: 0.87
Loss: 0.22919709980487823, Accuracy: 0.94
Loss: 0.297802209854126, Accuracy: 0.9
Loss: 0.3786265254020691, Accuracy: 0.87
Loss: 0.28924325108528137, Accuracy: 0.91
Loss: 0.359787255525589, Accuracy: 0.87
Loss: 0.3173147141933441, Accuracy: 0.9
Loss: 0.25896549224853516, Accuracy: 0.93
Loss: 0.2587931454181671, Accuracy: 0.92
Loss: 0.3733789920806885, Accuracy: 0.86
Loss: 0.374095618724823, Accuracy: 0.87
Loss: 0.39999920129776, Accuracy: 0.85
Loss: 0.2645173966884613, Accuracy: 0.91
Loss: 0.3580119013786316, Accuracy: 0.85
Loss: 0.26315560936927795, Accuracy: 0.92
Loss: 0.21024508774280548, Accuracy: 0.95
Loss: 0.29960912466049194, Accuracy: 0.9
Loss: 0.23831033

Accuracy on validation set: {'accuracy': 0.6148}
saved model
Epoch: 26


Loss: 0.32970765233039856, Accuracy: 0.86
Loss: 0.3129243552684784, Accuracy: 0.87
Loss: 0.2890338599681854, Accuracy: 0.92
Loss: 0.3346942961215973, Accuracy: 0.87
Loss: 0.44235533475875854, Accuracy: 0.83
Loss: 0.32250186800956726, Accuracy: 0.88
Loss: 0.3467421233654022, Accuracy: 0.87
Loss: 0.3063863515853882, Accuracy: 0.89
Loss: 0.32623812556266785, Accuracy: 0.9
Loss: 0.2958354651927948, Accuracy: 0.89
Loss: 0.3139464855194092, Accuracy: 0.89
Loss: 0.22303558886051178, Accuracy: 0.93
Loss: 0.28567954897880554, Accuracy: 0.9
Loss: 0.24077986180782318, Accuracy: 0.92
Loss: 0.30114421248435974, Accuracy: 0.9
Loss: 0.25981661677360535, Accuracy: 0.91
Loss: 0.2576514184474945, Accuracy: 0.93
Loss: 0.26865696907043457, Accuracy: 0.91
Loss: 0.24087268114089966, Accuracy: 0.92
Loss: 0.35445430874824524, Accuracy: 0.86
Loss: 0.1814737319946289, Accuracy: 0.95
Loss: 0.29986199736595154, Accuracy: 0.88
Loss: 0.17534737288951874, Accuracy: 0.96
Loss: 0.37152299284935, Accuracy: 0.87
Loss: 0

Accuracy on validation set: {'accuracy': 0.61536}
saved model
Epoch: 27


Loss: 0.42271414399147034, Accuracy: 0.85
Loss: 0.3539898693561554, Accuracy: 0.88
Loss: 0.4091081917285919, Accuracy: 0.82
Loss: 0.24273681640625, Accuracy: 0.93
Loss: 0.31368789076805115, Accuracy: 0.91
Loss: 0.43516409397125244, Accuracy: 0.83
Loss: 0.28372177481651306, Accuracy: 0.91
Loss: 0.28337740898132324, Accuracy: 0.91
Loss: 0.28644469380378723, Accuracy: 0.91
Loss: 0.28039658069610596, Accuracy: 0.9
Loss: 0.32157421112060547, Accuracy: 0.9
Loss: 0.31689196825027466, Accuracy: 0.9
Loss: 0.23707668483257294, Accuracy: 0.94
Loss: 0.2874632477760315, Accuracy: 0.89
Loss: 0.2980930805206299, Accuracy: 0.89
Loss: 0.3149212598800659, Accuracy: 0.91
Loss: 0.2305692434310913, Accuracy: 0.92
Loss: 0.29564112424850464, Accuracy: 0.89
Loss: 0.20987337827682495, Accuracy: 0.93
Loss: 0.19737958908081055, Accuracy: 0.96
Loss: 0.3156333267688751, Accuracy: 0.89
Loss: 0.2697065472602844, Accuracy: 0.9
Loss: 0.2863699197769165, Accuracy: 0.9
Loss: 0.36384502053260803, Accuracy: 0.86
Loss: 0.2

Accuracy on validation set: {'accuracy': 0.61656}
saved model
Epoch: 28


Loss: 0.36434081196784973, Accuracy: 0.88
Loss: 0.29940593242645264, Accuracy: 0.89
Loss: 0.39258521795272827, Accuracy: 0.88
Loss: 0.3420184254646301, Accuracy: 0.88
Loss: 0.3043457269668579, Accuracy: 0.91
Loss: 0.24564027786254883, Accuracy: 0.94
Loss: 0.2064177691936493, Accuracy: 0.94
Loss: 0.43267595767974854, Accuracy: 0.85
Loss: 0.35417571663856506, Accuracy: 0.86
Loss: 0.3250901699066162, Accuracy: 0.86
Loss: 0.22417284548282623, Accuracy: 0.93
Loss: 0.26638686656951904, Accuracy: 0.93
Loss: 0.2489989697933197, Accuracy: 0.92
Loss: 0.5169954895973206, Accuracy: 0.8
Loss: 0.23184971511363983, Accuracy: 0.94
Loss: 0.33183300495147705, Accuracy: 0.9
Loss: 0.5113485455513, Accuracy: 0.82
Loss: 0.4241285026073456, Accuracy: 0.82
Loss: 0.3624258041381836, Accuracy: 0.85
Loss: 0.4143217206001282, Accuracy: 0.82
Loss: 0.2868041694164276, Accuracy: 0.89
Loss: 0.35066699981689453, Accuracy: 0.89
Loss: 0.27623289823532104, Accuracy: 0.91
Loss: 0.280769944190979, Accuracy: 0.91
Loss: 0.23

Accuracy on validation set: {'accuracy': 0.61576}
saved model
Epoch: 29


Loss: 0.21563343703746796, Accuracy: 0.95
Loss: 0.27456212043762207, Accuracy: 0.91
Loss: 0.2913884222507477, Accuracy: 0.88
Loss: 0.2162828892469406, Accuracy: 0.93
Loss: 0.19061905145645142, Accuracy: 0.94
Loss: 0.25504183769226074, Accuracy: 0.93
Loss: 0.22910811007022858, Accuracy: 0.93
Loss: 0.21669046580791473, Accuracy: 0.91
Loss: 0.3110771179199219, Accuracy: 0.87
Loss: 0.2794114351272583, Accuracy: 0.92
Loss: 0.32511425018310547, Accuracy: 0.9
Loss: 0.2891803979873657, Accuracy: 0.9
Loss: 0.18079029023647308, Accuracy: 0.94
Loss: 0.34072163701057434, Accuracy: 0.88
Loss: 0.2527897357940674, Accuracy: 0.92
Loss: 0.140689417719841, Accuracy: 0.97
Loss: 0.26914751529693604, Accuracy: 0.91
Loss: 0.396391361951828, Accuracy: 0.85
Loss: 0.2682052552700043, Accuracy: 0.9
Loss: 0.25612467527389526, Accuracy: 0.94
Loss: 0.23347541689872742, Accuracy: 0.91
Loss: 0.20532625913619995, Accuracy: 0.9
Loss: 0.3522593379020691, Accuracy: 0.85
Loss: 0.24643395841121674, Accuracy: 0.91
Loss: 0.

Accuracy on validation set: {'accuracy': 0.61888}
saved model
Epoch: 30


Loss: 0.3084583580493927, Accuracy: 0.91
Loss: 0.21566934883594513, Accuracy: 0.92
Loss: 0.31212031841278076, Accuracy: 0.9
Loss: 0.3167377710342407, Accuracy: 0.88
Loss: 0.2585124671459198, Accuracy: 0.91
Loss: 0.21648286283016205, Accuracy: 0.93
Loss: 0.22428090870380402, Accuracy: 0.92
Loss: 0.3387545645236969, Accuracy: 0.86
Loss: 0.12376399338245392, Accuracy: 0.97
Loss: 0.21553298830986023, Accuracy: 0.93
Loss: 0.3306281566619873, Accuracy: 0.89
Loss: 0.22983913123607635, Accuracy: 0.93
Loss: 0.25532686710357666, Accuracy: 0.91
Loss: 0.2744995057582855, Accuracy: 0.91
Loss: 0.25846806168556213, Accuracy: 0.9
Loss: 0.4223755896091461, Accuracy: 0.85
Loss: 0.2563720643520355, Accuracy: 0.92
Loss: 0.358210027217865, Accuracy: 0.87
Loss: 0.41455313563346863, Accuracy: 0.84
Loss: 0.4852334260940552, Accuracy: 0.83
Loss: 0.3225027024745941, Accuracy: 0.9
Loss: 0.3833012282848358, Accuracy: 0.84
Loss: 0.2850002348423004, Accuracy: 0.91
Loss: 0.2501644790172577, Accuracy: 0.93
Loss: 0.27

Accuracy on validation set: {'accuracy': 0.61624}
saved model
Epoch: 31


Loss: 0.2141581028699875, Accuracy: 0.94
Loss: 0.30814990401268005, Accuracy: 0.87
Loss: 0.3087371289730072, Accuracy: 0.91
Loss: 0.3258511424064636, Accuracy: 0.87
Loss: 0.21219517290592194, Accuracy: 0.92
Loss: 0.2979241907596588, Accuracy: 0.88
Loss: 0.2627246379852295, Accuracy: 0.91
Loss: 0.243495374917984, Accuracy: 0.92
Loss: 0.23292386531829834, Accuracy: 0.92
Loss: 0.25120124220848083, Accuracy: 0.93
Loss: 0.2872995138168335, Accuracy: 0.9
Loss: 0.3072104752063751, Accuracy: 0.89
Loss: 0.244731143116951, Accuracy: 0.91
Loss: 0.33915406465530396, Accuracy: 0.89
Loss: 0.17790794372558594, Accuracy: 0.94
Loss: 0.26215291023254395, Accuracy: 0.91
Loss: 0.23201239109039307, Accuracy: 0.92
Loss: 0.15820477902889252, Accuracy: 0.96
Loss: 0.20712697505950928, Accuracy: 0.93
Loss: 0.5862968564033508, Accuracy: 0.79
Loss: 0.40599915385246277, Accuracy: 0.88
Loss: 0.25305700302124023, Accuracy: 0.92
Loss: 0.2888101041316986, Accuracy: 0.89
Loss: 0.2325080931186676, Accuracy: 0.93
Loss: 0

Accuracy on validation set: {'accuracy': 0.61624}
saved model
Epoch: 32


Loss: 0.23111289739608765, Accuracy: 0.92
Loss: 0.27380436658859253, Accuracy: 0.91
Loss: 0.20408710837364197, Accuracy: 0.94
Loss: 0.16987189650535583, Accuracy: 0.95
Loss: 0.2306998074054718, Accuracy: 0.91
Loss: 0.2169504314661026, Accuracy: 0.93
Loss: 0.22432845830917358, Accuracy: 0.92
Loss: 0.44460949301719666, Accuracy: 0.83
Loss: 0.21080277860164642, Accuracy: 0.92
Loss: 0.364725798368454, Accuracy: 0.85
Loss: 0.3035331666469574, Accuracy: 0.85
Loss: 0.28376027941703796, Accuracy: 0.88
Loss: 0.23005525767803192, Accuracy: 0.94
Loss: 0.17247207462787628, Accuracy: 0.93
Loss: 0.3368854522705078, Accuracy: 0.89
Loss: 0.2650088369846344, Accuracy: 0.92
Loss: 0.24514888226985931, Accuracy: 0.9
Loss: 0.2722817361354828, Accuracy: 0.88
Loss: 0.28194624185562134, Accuracy: 0.91
Loss: 0.3022318184375763, Accuracy: 0.89
Loss: 0.25259843468666077, Accuracy: 0.9
Loss: 0.2323606312274933, Accuracy: 0.92
Loss: 0.2619558274745941, Accuracy: 0.9
Loss: 0.28685155510902405, Accuracy: 0.9
Loss: 0

Accuracy on validation set: {'accuracy': 0.6196}
saved model
Epoch: 33


Loss: 0.19213241338729858, Accuracy: 0.93
Loss: 0.30649664998054504, Accuracy: 0.9
Loss: 0.2840655446052551, Accuracy: 0.9
Loss: 0.18336273729801178, Accuracy: 0.94
Loss: 0.2951018810272217, Accuracy: 0.87
Loss: 0.23004573583602905, Accuracy: 0.92
Loss: 0.19945937395095825, Accuracy: 0.93
Loss: 0.2838096618652344, Accuracy: 0.9
Loss: 0.24533498287200928, Accuracy: 0.88
Loss: 0.18601499497890472, Accuracy: 0.96
Loss: 0.2586430013179779, Accuracy: 0.9
Loss: 0.1786956638097763, Accuracy: 0.94
Loss: 0.3565261960029602, Accuracy: 0.85
Loss: 0.18094687163829803, Accuracy: 0.94
Loss: 0.2202979028224945, Accuracy: 0.94
Loss: 0.2713870704174042, Accuracy: 0.9
Loss: 0.28611788153648376, Accuracy: 0.86
Loss: 0.20578201115131378, Accuracy: 0.92
Loss: 0.3060731291770935, Accuracy: 0.88
Loss: 0.21745023131370544, Accuracy: 0.92
Loss: 0.284233033657074, Accuracy: 0.88
Loss: 0.19077353179454803, Accuracy: 0.94
Loss: 0.27085307240486145, Accuracy: 0.88
Loss: 0.23873576521873474, Accuracy: 0.92
Loss: 0.

Accuracy on validation set: {'accuracy': 0.61376}
saved model
Epoch: 34


Loss: 0.17451885342597961, Accuracy: 0.95
Loss: 0.18866503238677979, Accuracy: 0.95
Loss: 0.27875274419784546, Accuracy: 0.88
Loss: 0.20941844582557678, Accuracy: 0.93
Loss: 0.2794789671897888, Accuracy: 0.89
Loss: 0.35289886593818665, Accuracy: 0.84
Loss: 0.22165770828723907, Accuracy: 0.93
Loss: 0.3124421238899231, Accuracy: 0.89
Loss: 0.11033559590578079, Accuracy: 0.98
Loss: 0.34053361415863037, Accuracy: 0.84
Loss: 0.18679441511631012, Accuracy: 0.94
Loss: 0.2508583664894104, Accuracy: 0.91
Loss: 0.24202817678451538, Accuracy: 0.9
Loss: 0.1523059606552124, Accuracy: 0.96
Loss: 0.15492981672286987, Accuracy: 0.95
Loss: 0.24894018471240997, Accuracy: 0.89
Loss: 0.25946754217147827, Accuracy: 0.91
Loss: 0.22771833837032318, Accuracy: 0.93
Loss: 0.2707156836986542, Accuracy: 0.9
Loss: 0.3142308294773102, Accuracy: 0.88
Loss: 0.21905525028705597, Accuracy: 0.93
Loss: 0.2857102155685425, Accuracy: 0.9
Loss: 0.12357840687036514, Accuracy: 0.97
Loss: 0.2625057101249695, Accuracy: 0.9
Loss

Accuracy on validation set: {'accuracy': 0.61256}
saved model
Epoch: 35


Loss: 0.16912461817264557, Accuracy: 0.94
Loss: 0.3957393765449524, Accuracy: 0.85
Loss: 0.39586013555526733, Accuracy: 0.84
Loss: 0.3092193305492401, Accuracy: 0.88
Loss: 0.306203156709671, Accuracy: 0.87
Loss: 0.1902935951948166, Accuracy: 0.94
Loss: 0.28554970026016235, Accuracy: 0.91
Loss: 0.23647451400756836, Accuracy: 0.92
Loss: 0.21078313887119293, Accuracy: 0.92
Loss: 0.16295616328716278, Accuracy: 0.96
Loss: 0.16227230429649353, Accuracy: 0.97
Loss: 0.2743699848651886, Accuracy: 0.89
Loss: 0.2718263566493988, Accuracy: 0.89
Loss: 0.34281760454177856, Accuracy: 0.84
Loss: 0.2328519970178604, Accuracy: 0.91
Loss: 0.17853082716464996, Accuracy: 0.93
Loss: 0.22050583362579346, Accuracy: 0.91
Loss: 0.2358444780111313, Accuracy: 0.92
Loss: 0.2057027816772461, Accuracy: 0.92
Loss: 0.2540697753429413, Accuracy: 0.88
Loss: 0.2836326062679291, Accuracy: 0.89
Loss: 0.27056631445884705, Accuracy: 0.9
Loss: 0.2727956175804138, Accuracy: 0.86
Loss: 0.23535297811031342, Accuracy: 0.9
Loss: 0

Accuracy on validation set: {'accuracy': 0.6128}
saved model
Epoch: 36


Loss: 0.20909962058067322, Accuracy: 0.93
Loss: 0.19976110756397247, Accuracy: 0.93
Loss: 0.16284117102622986, Accuracy: 0.95
Loss: 0.33052459359169006, Accuracy: 0.87
Loss: 0.17931795120239258, Accuracy: 0.94
Loss: 0.2910885810852051, Accuracy: 0.89
Loss: 0.19459624588489532, Accuracy: 0.93
Loss: 0.24990878999233246, Accuracy: 0.91
Loss: 0.2580242455005646, Accuracy: 0.88
Loss: 0.15837855637073517, Accuracy: 0.95
Loss: 0.2595258951187134, Accuracy: 0.88
Loss: 0.2979432940483093, Accuracy: 0.9
Loss: 0.24607324600219727, Accuracy: 0.9
Loss: 0.2606407403945923, Accuracy: 0.89
Loss: 0.2917453646659851, Accuracy: 0.87
Loss: 0.13659806549549103, Accuracy: 0.96
Loss: 0.36833658814430237, Accuracy: 0.84
Loss: 0.3576773703098297, Accuracy: 0.86
Loss: 0.19795750081539154, Accuracy: 0.95
Loss: 0.2804943323135376, Accuracy: 0.9
Loss: 0.2522766590118408, Accuracy: 0.89
Loss: 0.16763648390769958, Accuracy: 0.96
Loss: 0.22513389587402344, Accuracy: 0.92
Loss: 0.35224366188049316, Accuracy: 0.83
Loss

Accuracy on validation set: {'accuracy': 0.61592}
saved model
Epoch: 37


Loss: 0.23426908254623413, Accuracy: 0.91
Loss: 0.2732282876968384, Accuracy: 0.89
Loss: 0.20720651745796204, Accuracy: 0.91
Loss: 0.18838825821876526, Accuracy: 0.93
Loss: 0.22847868502140045, Accuracy: 0.9
Loss: 0.27069079875946045, Accuracy: 0.89
Loss: 0.2254396378993988, Accuracy: 0.91
Loss: 0.234313502907753, Accuracy: 0.91
Loss: 0.24262215197086334, Accuracy: 0.88
Loss: 0.3304847478866577, Accuracy: 0.88
Loss: 0.18601909279823303, Accuracy: 0.92
Loss: 0.2083459049463272, Accuracy: 0.93
Loss: 0.25463011860847473, Accuracy: 0.91
Loss: 0.18695925176143646, Accuracy: 0.92
Loss: 0.17655010521411896, Accuracy: 0.95
Loss: 0.18232208490371704, Accuracy: 0.91
Loss: 0.18128080666065216, Accuracy: 0.92
Loss: 0.22476497292518616, Accuracy: 0.91
Loss: 0.25954902172088623, Accuracy: 0.9
Loss: 0.17886143922805786, Accuracy: 0.92
Loss: 0.15613223612308502, Accuracy: 0.95
Loss: 0.20399406552314758, Accuracy: 0.94
Loss: 0.15384453535079956, Accuracy: 0.95
Loss: 0.24666549265384674, Accuracy: 0.92


Accuracy on validation set: {'accuracy': 0.61688}
saved model
Epoch: 38


Loss: 0.2415795922279358, Accuracy: 0.92
Loss: 0.21686524152755737, Accuracy: 0.89
Loss: 0.335824579000473, Accuracy: 0.84
Loss: 0.22964182496070862, Accuracy: 0.93
Loss: 0.32750874757766724, Accuracy: 0.85
Loss: 0.18644073605537415, Accuracy: 0.94
Loss: 0.3383084535598755, Accuracy: 0.87
Loss: 0.3554915487766266, Accuracy: 0.85
Loss: 0.2733764350414276, Accuracy: 0.87
Loss: 0.22465814650058746, Accuracy: 0.91
Loss: 0.2918251156806946, Accuracy: 0.89
Loss: 0.2564985454082489, Accuracy: 0.89
Loss: 0.24471434950828552, Accuracy: 0.89
Loss: 0.34713131189346313, Accuracy: 0.86
Loss: 0.19743946194648743, Accuracy: 0.95
Loss: 0.2899978458881378, Accuracy: 0.87
Loss: 0.3130553960800171, Accuracy: 0.85
Loss: 0.3805559575557709, Accuracy: 0.84
Loss: 0.2598465085029602, Accuracy: 0.91
Loss: 0.23574957251548767, Accuracy: 0.91
Loss: 0.31915634870529175, Accuracy: 0.85
Loss: 0.16319479048252106, Accuracy: 0.96
Loss: 0.21099156141281128, Accuracy: 0.93
Loss: 0.24362698197364807, Accuracy: 0.92
Loss

Accuracy on validation set: {'accuracy': 0.612}
saved model
Epoch: 39


Loss: 0.40626060962677, Accuracy: 0.84
Loss: 0.318271279335022, Accuracy: 0.88
Loss: 0.2836039662361145, Accuracy: 0.9
Loss: 0.25052568316459656, Accuracy: 0.89
Loss: 0.22410354018211365, Accuracy: 0.88
Loss: 0.1960413008928299, Accuracy: 0.93
Loss: 0.23908500373363495, Accuracy: 0.9
Loss: 0.15753298997879028, Accuracy: 0.95
Loss: 0.2446359395980835, Accuracy: 0.93
Loss: 0.26405230164527893, Accuracy: 0.88
Loss: 0.30313462018966675, Accuracy: 0.88
Loss: 0.2285533845424652, Accuracy: 0.91
Loss: 0.1794051229953766, Accuracy: 0.93
Loss: 0.2782919406890869, Accuracy: 0.91
Loss: 0.23185083270072937, Accuracy: 0.91
Loss: 0.26094475388526917, Accuracy: 0.91
Loss: 0.2248104065656662, Accuracy: 0.92
Loss: 0.2507394254207611, Accuracy: 0.89
Loss: 0.2603190243244171, Accuracy: 0.88
Loss: 0.264822393655777, Accuracy: 0.88
Loss: 0.20239734649658203, Accuracy: 0.91
Loss: 0.19909046590328217, Accuracy: 0.93
Loss: 0.2057945430278778, Accuracy: 0.93
Loss: 0.27826014161109924, Accuracy: 0.91
Loss: 0.230

Accuracy on validation set: {'accuracy': 0.61632}
saved model
Epoch: 40


Loss: 0.2942763566970825, Accuracy: 0.88
Loss: 0.22937601804733276, Accuracy: 0.91
Loss: 0.16152173280715942, Accuracy: 0.95
Loss: 0.13723239302635193, Accuracy: 0.95
Loss: 0.20434574782848358, Accuracy: 0.92
Loss: 0.18307162821292877, Accuracy: 0.92
Loss: 0.2154277265071869, Accuracy: 0.9
Loss: 0.19935426115989685, Accuracy: 0.91
Loss: 0.16023626923561096, Accuracy: 0.94
Loss: 0.2784402370452881, Accuracy: 0.9
Loss: 0.22218942642211914, Accuracy: 0.91
Loss: 0.24631080031394958, Accuracy: 0.88
Loss: 0.22249165177345276, Accuracy: 0.92
Loss: 0.2388155162334442, Accuracy: 0.9
Loss: 0.19886581599712372, Accuracy: 0.9
Loss: 0.19046729803085327, Accuracy: 0.93
Loss: 0.21178825199604034, Accuracy: 0.92
Loss: 0.2272808998823166, Accuracy: 0.92
Loss: 0.25710129737854004, Accuracy: 0.91
Loss: 0.19285453855991364, Accuracy: 0.95
Loss: 0.18313907086849213, Accuracy: 0.92
Loss: 0.16694560647010803, Accuracy: 0.92
Loss: 0.19286848604679108, Accuracy: 0.93
Loss: 0.321824312210083, Accuracy: 0.87
Los

Accuracy on validation set: {'accuracy': 0.6172}
saved model
Epoch: 41


Loss: 0.22037826478481293, Accuracy: 0.93
Loss: 0.2935532331466675, Accuracy: 0.88
Loss: 0.2782524824142456, Accuracy: 0.9
Loss: 0.17173638939857483, Accuracy: 0.92
Loss: 0.24646766483783722, Accuracy: 0.91
Loss: 0.20113599300384521, Accuracy: 0.91
Loss: 0.23078957200050354, Accuracy: 0.88
Loss: 0.22014249861240387, Accuracy: 0.9
Loss: 0.23002345860004425, Accuracy: 0.93
Loss: 0.13644686341285706, Accuracy: 0.94
Loss: 0.26863518357276917, Accuracy: 0.88
Loss: 0.2824307978153229, Accuracy: 0.9
Loss: 0.16069269180297852, Accuracy: 0.94
Loss: 0.1904515027999878, Accuracy: 0.92
Loss: 0.24593904614448547, Accuracy: 0.89
Loss: 0.22907216846942902, Accuracy: 0.91
Loss: 0.33770251274108887, Accuracy: 0.83
Loss: 0.279094398021698, Accuracy: 0.87
Loss: 0.1882229447364807, Accuracy: 0.94
Loss: 0.27628451585769653, Accuracy: 0.89
Loss: 0.1884835958480835, Accuracy: 0.92
Loss: 0.2989330589771271, Accuracy: 0.88
Loss: 0.2427651286125183, Accuracy: 0.88
Loss: 0.2804175317287445, Accuracy: 0.89
Loss: 

Accuracy on validation set: {'accuracy': 0.616}
saved model
Epoch: 42


Loss: 0.2576991319656372, Accuracy: 0.9
Loss: 0.21069519221782684, Accuracy: 0.91
Loss: 0.26723307371139526, Accuracy: 0.9
Loss: 0.19298133254051208, Accuracy: 0.95
Loss: 0.1658998429775238, Accuracy: 0.92
Loss: 0.1386117786169052, Accuracy: 0.95
Loss: 0.2232355922460556, Accuracy: 0.91
Loss: 0.18395233154296875, Accuracy: 0.93
Loss: 0.16420866549015045, Accuracy: 0.95
Loss: 0.237367644906044, Accuracy: 0.89
Loss: 0.1352178156375885, Accuracy: 0.95
Loss: 0.38326388597488403, Accuracy: 0.87
Loss: 0.22949349880218506, Accuracy: 0.91
Loss: 0.21032729744911194, Accuracy: 0.9
Loss: 0.219087615609169, Accuracy: 0.91
Loss: 0.19369439780712128, Accuracy: 0.9
Loss: 0.16650596261024475, Accuracy: 0.93
Loss: 0.14311523735523224, Accuracy: 0.93
Loss: 0.19049999117851257, Accuracy: 0.9
Loss: 0.17247866094112396, Accuracy: 0.93
Loss: 0.1341547966003418, Accuracy: 0.95
Loss: 0.2211775779724121, Accuracy: 0.91
Loss: 0.18117870390415192, Accuracy: 0.94
Loss: 0.31385478377342224, Accuracy: 0.87
Loss: 0.

Accuracy on validation set: {'accuracy': 0.61424}
saved model
Epoch: 43


Loss: 0.2378859966993332, Accuracy: 0.87
Loss: 0.10044322907924652, Accuracy: 0.97
Loss: 0.24089860916137695, Accuracy: 0.89
Loss: 0.12018444389104843, Accuracy: 0.96
Loss: 0.18374189734458923, Accuracy: 0.91
Loss: 0.2357138842344284, Accuracy: 0.92
Loss: 0.17076557874679565, Accuracy: 0.92
Loss: 0.2141217738389969, Accuracy: 0.89
Loss: 0.18745604157447815, Accuracy: 0.91
Loss: 0.1396757960319519, Accuracy: 0.95
Loss: 0.15531615912914276, Accuracy: 0.92
Loss: 0.23181892931461334, Accuracy: 0.93
Loss: 0.10746055841445923, Accuracy: 0.98
Loss: 0.11504912376403809, Accuracy: 0.95
Loss: 0.22362028062343597, Accuracy: 0.94
Loss: 0.1415630429983139, Accuracy: 0.96
Loss: 0.1958741545677185, Accuracy: 0.91
Loss: 0.23951545357704163, Accuracy: 0.86
Loss: 0.19045330584049225, Accuracy: 0.94
Loss: 0.2096138447523117, Accuracy: 0.91
Loss: 0.1929926723241806, Accuracy: 0.92
Loss: 0.2626420259475708, Accuracy: 0.93
Loss: 0.12890063226222992, Accuracy: 0.94
Loss: 0.12835240364074707, Accuracy: 0.95
L

Accuracy on validation set: {'accuracy': 0.61136}
saved model
Epoch: 44


Loss: 0.1847122758626938, Accuracy: 0.96
Loss: 0.2323591113090515, Accuracy: 0.88
Loss: 0.19063100218772888, Accuracy: 0.96
Loss: 0.24725432693958282, Accuracy: 0.87
Loss: 0.18131978809833527, Accuracy: 0.9
Loss: 0.22296079993247986, Accuracy: 0.88
Loss: 0.13847847282886505, Accuracy: 0.94
Loss: 0.26384565234184265, Accuracy: 0.86
Loss: 0.2835090160369873, Accuracy: 0.85
Loss: 0.1605241894721985, Accuracy: 0.92
Loss: 0.22158057987689972, Accuracy: 0.91
Loss: 0.14364908635616302, Accuracy: 0.93
Loss: 0.2896186411380768, Accuracy: 0.87
Loss: 0.19795560836791992, Accuracy: 0.92
Loss: 0.12608055770397186, Accuracy: 0.95
Loss: 0.19462580978870392, Accuracy: 0.92
Loss: 0.2658568322658539, Accuracy: 0.88
Loss: 0.17665785551071167, Accuracy: 0.92
Loss: 0.17371907830238342, Accuracy: 0.89
Loss: 0.15463560819625854, Accuracy: 0.92
Loss: 0.1411873698234558, Accuracy: 0.94
Loss: 0.223043292760849, Accuracy: 0.88
Loss: 0.15901586413383484, Accuracy: 0.96
Loss: 0.2537767291069031, Accuracy: 0.89
Los

Accuracy on validation set: {'accuracy': 0.6124}
saved model
Epoch: 45


Loss: 0.17007671296596527, Accuracy: 0.91
Loss: 0.16907066106796265, Accuracy: 0.94
Loss: 0.20972101390361786, Accuracy: 0.9
Loss: 0.19727273285388947, Accuracy: 0.93
Loss: 0.19992084801197052, Accuracy: 0.91
Loss: 0.21646951138973236, Accuracy: 0.92
Loss: 0.2705214023590088, Accuracy: 0.86
Loss: 0.24571830034255981, Accuracy: 0.91
Loss: 0.1602751612663269, Accuracy: 0.91
Loss: 0.20987354218959808, Accuracy: 0.91
Loss: 0.1890941560268402, Accuracy: 0.91
Loss: 0.2130739986896515, Accuracy: 0.9
Loss: 0.25627401471138, Accuracy: 0.89
Loss: 0.20417657494544983, Accuracy: 0.91
Loss: 0.23006445169448853, Accuracy: 0.92
Loss: 0.16591721773147583, Accuracy: 0.92
Loss: 0.3392643630504608, Accuracy: 0.84
Loss: 0.21070390939712524, Accuracy: 0.92
Loss: 0.08084960281848907, Accuracy: 0.99
Loss: 0.19396105408668518, Accuracy: 0.9
Loss: 0.290638267993927, Accuracy: 0.84
Loss: 0.2853984832763672, Accuracy: 0.84
Loss: 0.2800976634025574, Accuracy: 0.86
Loss: 0.11791569739580154, Accuracy: 0.96

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.61168}
saved model
Epoch: 46


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2069481909275055, Accuracy: 0.93
Loss: 0.15387821197509766, Accuracy: 0.94
Loss: 0.17594440281391144, Accuracy: 0.92
Loss: 0.28417980670928955, Accuracy: 0.88
Loss: 0.2516609728336334, Accuracy: 0.89
Loss: 0.12604078650474548, Accuracy: 0.95
Loss: 0.15488673746585846, Accuracy: 0.94
Loss: 0.14209690690040588, Accuracy: 0.94
Loss: 0.12580937147140503, Accuracy: 0.94
Loss: 0.06299333274364471, Accuracy: 0.98
Loss: 0.24786759912967682, Accuracy: 0.87
Loss: 0.17898567020893097, Accuracy: 0.9
Loss: 0.18744033575057983, Accuracy: 0.9
Loss: 0.16706804931163788, Accuracy: 0.95
Loss: 0.27012112736701965, Accuracy: 0.87
Loss: 0.19048985838890076, Accuracy: 0.91
Loss: 0.21436764299869537, Accuracy: 0.9
Loss: 0.12193135172128677, Accuracy: 0.94
Loss: 0.26504579186439514, Accuracy: 0.88
Loss: 0.13246330618858337, Accuracy: 0.95
Loss: 0.16460245847702026, Accuracy: 0.9
Loss: 0.1309836506843567, Accuracy: 0.94
Loss: 0.13259875774383545, Accuracy: 0.95
Loss: 0.1481427103281021, Accuracy: 0.96


  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60792}
saved model
Epoch: 47


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.17531825602054596, Accuracy: 0.92
Loss: 0.1805199831724167, Accuracy: 0.92
Loss: 0.14783650636672974, Accuracy: 0.94
Loss: 0.18736259639263153, Accuracy: 0.88
Loss: 0.1664814054965973, Accuracy: 0.91
Loss: 0.11943762004375458, Accuracy: 0.96
Loss: 0.16305577754974365, Accuracy: 0.92
Loss: 0.12323834747076035, Accuracy: 0.96
Loss: 0.14606091380119324, Accuracy: 0.94
Loss: 0.159690260887146, Accuracy: 0.94
Loss: 0.14488577842712402, Accuracy: 0.96
Loss: 0.24718207120895386, Accuracy: 0.88
Loss: 0.10134658217430115, Accuracy: 0.96
Loss: 0.1839987337589264, Accuracy: 0.91
Loss: 0.1694524586200714, Accuracy: 0.93
Loss: 0.2472141683101654, Accuracy: 0.88
Loss: 0.23331470787525177, Accuracy: 0.91
Loss: 0.12045090645551682, Accuracy: 0.96
Loss: 0.17611975967884064, Accuracy: 0.91
Loss: 0.2863982021808624, Accuracy: 0.87
Loss: 0.21500149369239807, Accuracy: 0.92
Loss: 0.20612654089927673, Accuracy: 0.91
Loss: 0.18322229385375977, Accuracy: 0.94
Loss: 0.16257864236831665, Accuracy: 0.94


  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60888}
saved model
Epoch: 48


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.09459611773490906, Accuracy: 0.95
Loss: 0.13672132790088654, Accuracy: 0.93
Loss: 0.22176849842071533, Accuracy: 0.9
Loss: 0.12462016195058823, Accuracy: 0.95
Loss: 0.1803765445947647, Accuracy: 0.92
Loss: 0.09353934228420258, Accuracy: 0.97
Loss: 0.2677121162414551, Accuracy: 0.9
Loss: 0.2525424361228943, Accuracy: 0.9
Loss: 0.16651466488838196, Accuracy: 0.91
Loss: 0.25203415751457214, Accuracy: 0.91
Loss: 0.23462936282157898, Accuracy: 0.9
Loss: 0.21017012000083923, Accuracy: 0.92
Loss: 0.12647497653961182, Accuracy: 0.95
Loss: 0.13783957064151764, Accuracy: 0.96
Loss: 0.2103293389081955, Accuracy: 0.89
Loss: 0.15513628721237183, Accuracy: 0.95
Loss: 0.15778174996376038, Accuracy: 0.91
Loss: 0.1821390986442566, Accuracy: 0.89
Loss: 0.1921074539422989, Accuracy: 0.93
Loss: 0.2271161675453186, Accuracy: 0.9
Loss: 0.15445907413959503, Accuracy: 0.93
Loss: 0.15585552155971527, Accuracy: 0.95
Loss: 0.14065366983413696, Accuracy: 0.94
Loss: 0.31868329644203186, Accuracy: 0.84

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60528}
saved model
Epoch: 49


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.2540372908115387, Accuracy: 0.87
Loss: 0.24622032046318054, Accuracy: 0.88
Loss: 0.20250073075294495, Accuracy: 0.93
Loss: 0.18723514676094055, Accuracy: 0.9
Loss: 0.19631098210811615, Accuracy: 0.89
Loss: 0.17129811644554138, Accuracy: 0.93
Loss: 0.20397424697875977, Accuracy: 0.93
Loss: 0.16012921929359436, Accuracy: 0.9
Loss: 0.24851404130458832, Accuracy: 0.88
Loss: 0.25202542543411255, Accuracy: 0.89
Loss: 0.1167372614145279, Accuracy: 0.95
Loss: 0.17603947222232819, Accuracy: 0.95
Loss: 0.16904594004154205, Accuracy: 0.95
Loss: 0.17749334871768951, Accuracy: 0.94
Loss: 0.203091099858284, Accuracy: 0.92
Loss: 0.14064723253250122, Accuracy: 0.97
Loss: 0.30937880277633667, Accuracy: 0.89
Loss: 0.24844114482402802, Accuracy: 0.88
Loss: 0.20715168118476868, Accuracy: 0.93
Loss: 0.13639488816261292, Accuracy: 0.97
Loss: 0.23486702144145966, Accuracy: 0.91
Loss: 0.19155223667621613, Accuracy: 0.94
Loss: 0.11080603301525116, Accuracy: 0.94
Loss: 0.17072126269340515, Accuracy: 0.9

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on validation set: {'accuracy': 0.60992}
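
The log above prints `saved model` after nearly every epoch, while the validation accuracy stays close to 0.61 and does not improve steadily. One common alternative is to keep a checkpoint only when the validation accuracy improves. Below is a minimal sketch of that idea, using the validation accuracies visible in this part of the log. The epoch number 43 for the first value is an assumption (the log only shows that it precedes epoch 44), and the actual save call is left as a comment.

```python
# Validation accuracies copied from the log excerpt above.
# Epoch 43 is assumed: the log only shows this value right before "Epoch: 44".
val_accuracy = {
    43: 0.61136,
    44: 0.6124,
    45: 0.61168,
    46: 0.60792,
    47: 0.60888,
    48: 0.60528,
    49: 0.60992,
}

best_epoch, best_acc = None, float("-inf")
for epoch, acc in val_accuracy.items():
    if acc > best_acc:
        best_epoch, best_acc = epoch, acc
        # In the training loop, this is where one would call something like
        # torch.save(model.state_dict(), checkpoint_path)

print(f"Best epoch in this excerpt: {best_epoch} (accuracy {best_acc})")
```

Among the epochs shown here, only the checkpoint from epoch 44 would be kept. Earlier epochs are outside this excerpt and may have scored higher.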


## Evaluate the model

Finally, we evaluate the model on the test set, computing the accuracy with the metric provided by the Datasets library.

In [None]:
# torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/small_model_tagop_Nov_16_epoch_last.pt')

In [None]:
# To load previously saved weights, uncomment the lines below and point the path at your own checkpoint
# import torch
# checkpoint = torch.load('/content/drive/MyDrive/saved_model/small_network_model.pt')
# model.load_state_dict(checkpoint)
# model.eval()

In [None]:
import torch
from tqdm.notebook import tqdm
from datasets import load_metric

# Note: recent releases of Datasets deprecate load_metric in favor of the
# separate `evaluate` library (evaluate.load("accuracy")).
accuracy = load_metric("accuracy")

model.eval()
# Gradients are not needed for evaluation, so disable them to save memory.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # get the inputs
        inputs = batch["input_ids"].to(device)
        # attention_mask = batch["attention_mask"].to(device)
        references = batch["label"].numpy()

        # forward pass
        outputs = model(inputs=inputs)
        logits = outputs.logits

        # predicted class = index of the largest logit
        predictions = logits.argmax(-1).cpu().numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)

  0%|          | 0/500 [00:00<?, ?it/s]

Accuracy on test set: {'accuracy': 0.61232}
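
To make the accumulation pattern above concrete, here is a minimal pure-Python sketch of what an accuracy metric with `add_batch` and `compute` does conceptually: it counts correct predictions across batches and divides by the total at the end. The `RunningAccuracy` class and the toy predictions are hypothetical and only illustrate the idea; this is not the library's implementation.

```python
class RunningAccuracy:
    """Hypothetical stand-in that mimics the add_batch / compute pattern."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        # Count how many predictions match their reference label.
        for pred, ref in zip(predictions, references):
            self.correct += int(pred == ref)
            self.total += 1

    def compute(self):
        # Accuracy over everything seen so far, not an average of per-batch scores.
        return {"accuracy": self.correct / self.total}


metric = RunningAccuracy()
metric.add_batch(predictions=[1, 0, 1, 1], references=[1, 0, 0, 1])  # 3 of 4 correct
metric.add_batch(predictions=[0, 0], references=[0, 1])              # 1 of 2 correct
result = metric.compute()
print(result)
```

Because the counts are accumulated over all examples, batches of different sizes are weighted correctly; here the result is 4 correct out of 6.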


## Inference

In [None]:
import torch

text = "I hated this movie, it's really bad."

input_ids = tokenizer(text, return_tensors="pt").input_ids

# forward pass (evaluation mode, no gradients needed at inference time)
model.eval()
with torch.no_grad():
    outputs = model(inputs=input_ids.to(device))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()

print("Predicted:", model.config.id2label[predicted_class_idx])
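
The cell above reports only the predicted label. If a confidence score is also useful, the logits can be turned into probabilities with a softmax. The sketch below shows the computation in plain Python; the logit values and the `id2label` mapping are hypothetical placeholders, not outputs of the model trained in this notebook.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtracting the maximum keeps the exponentials numerically stable.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [2.0, -1.0]                # hypothetical logits for one review
id2label = {0: "negative", 1: "positive"}   # assumed label mapping

probs = softmax(example_logits)
predicted_class_idx = max(range(len(probs)), key=probs.__getitem__)

print("Predicted:", id2label[predicted_class_idx])
print("Probabilities:", probs)
```

With real model outputs, the same result can be obtained by applying `softmax(-1)` to the `logits` tensor before taking the argmax.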