## Set up environment

We first check that a GPU is attached, mount Google Drive (used later to save checkpoints), and install HuggingFace Transformers (from source) and Datasets.

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Nov 16 22:31:19 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   24C    P0    41W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
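The `!nvidia-smi` line relies on IPython shell syntax, so it only runs inside a notebook. As a rough sketch, the same check could be written with the standard library alone. The helper name `gpu_summary` is an illustrative assumption and not part of the original notebook.

```python
import shutil
import subprocess


def gpu_summary() -> str:
    """Return the nvidia-smi report, or a fallback message if no GPU is visible."""
    # nvidia-smi is not installed at all on machines without NVIDIA drivers.
    if shutil.which("nvidia-smi") is None:
        return "Not connected to a GPU"
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    # Mirror the notebook's check: treat a failing call as "no GPU".
    if result.returncode != 0 or "failed" in result.stdout:
        return "Not connected to a GPU"
    return result.stdout


print(gpu_summary())
```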

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [3]:
!pip install -q git+https://github.com/huggingface/transformers.git



In [4]:
!pip install -q datasets

## Prepare data

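Here we load the IMDB dataset, a binary text classification dataset ("is a movie review positive or negative?"). We use the full train and test splits (25,000 reviews each); the commented-out line loads a small subset instead, which is handy for quick experiments.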

In [5]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=['train', 'test'])
# train_ds, test_ds = load_dataset("imdb", split=['train[:10]+train[-10:]', 'test[:5]+test[-5:]'])


Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...



Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.



We create id2label and label2id mappings, which are handy at inference time.

In [6]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [7]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}
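To make the role of these mappings at inference time concrete, here is a minimal sketch in plain Python. The logits are made-up numbers and `predict_label` is a hypothetical helper; the real model returns a tensor, but the lookup works the same way.

```python
labels = ["neg", "pos"]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}


def predict_label(logits):
    """Return the label name of the highest-scoring class."""
    best_idx = max(range(len(logits)), key=lambda i: logits[i])
    return id2label[best_idx]


print(predict_label([0.3, 1.7]))   # pos
print(predict_label([2.1, -0.4]))  # neg
```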


Next, we prepare the data for the model using the tokenizer. 

In [8]:
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)

Using unk_token, but it is not set yet.
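The Perceiver tokenizer works on raw UTF-8 bytes rather than words or subwords. The sketch below imitates that behaviour in plain Python to show roughly what ends up in `input_ids` and `attention_mask` when `padding="max_length"` and `truncation=True` are used. It assumes that six special tokens occupy IDs 0 to 5 (with `[PAD]` = 0, `[CLS]` = 4, `[SEP]` = 5) and that every byte value is shifted by 6, which is consistent with the 262-entry embedding table (256 byte values + 6 special tokens) but should be treated as an assumption. `encode_bytes` is a hypothetical helper, not a library function.

```python
PAD_ID, CLS_ID, SEP_ID = 0, 4, 5   # assumed special-token IDs
NUM_SPECIAL_TOKENS = 6             # assumed offset applied to every byte value


def encode_bytes(text, max_length=2048):
    """Byte-level encoding with truncation and padding to max_length."""
    byte_ids = [b + NUM_SPECIAL_TOKENS for b in text.encode("utf-8")]
    byte_ids = byte_ids[: max_length - 2]  # truncate, leaving room for [CLS] and [SEP]
    input_ids = [CLS_ID] + byte_ids + [SEP_ID]
    attention_mask = [1] * len(input_ids)
    num_pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * num_pad,
        "attention_mask": attention_mask + [0] * num_pad,
    }


encoded = encode_bytes("hi", max_length=8)
print(encoded["input_ids"])       # [4, 110, 111, 5, 0, 0, 0, 0]
print(encoded["attention_mask"])  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Because every character costs at least one position, a 2048-token limit corresponds to roughly 2,000 characters of review text; longer reviews are cut off.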

We set the format to PyTorch tensors, and create familiar PyTorch dataloaders.

In [9]:
train_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])

In [10]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=100, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=50)
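As a quick arithmetic check on these batch sizes: each IMDB split has 25,000 reviews, and a `DataLoader` over a map-style dataset with the default `drop_last=False` yields `ceil(num_examples / batch_size)` batches per pass. The helper `num_batches` below is a hypothetical name used only for illustration.

```python
import math


def num_batches(num_examples, batch_size, drop_last=False):
    """Number of batches in one pass over the data."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


print(num_batches(25_000, 100))  # 250 training steps per epoch
print(num_batches(25_000, 50))   # 500 evaluation steps
```

These numbers match the totals shown by the progress bars during training and evaluation.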

Here we sanity-check one batch: the tensor shapes, a decoded example, and the labels (it is always important to inspect your data!).

In [11]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

label torch.Size([100])
input_ids torch.Size([100, 2048])
attention_mask torch.Size([100, 2048])


In [12]:
tokenizer.decode(batch['input_ids'][3])

"[CLS]This movie was kind of interesting...I had to watch it for a college class about India, however the synopsis tells you this movie is about one thing when it doesn't really contain much cold, hard information on those details. It is not really true to the synopsis until the very end where they sloppily try to tie all the elements together. The gore factor is superb, however. Even right at the very beginning, you want to look away because the gore is pretty intense. Only watch this movie if you want to see some cool gore, because the plot is thin and will make you sad that you wasted time listening to it. I've seen rumors on other websites about this movie being based on true events, however you can not find any information about it online...so basically this movie was a waste of time to watch.[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][

In [13]:
batch['label']

tensor([1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1,
        0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1,
        1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
        0, 0, 0, 1])

## Define model

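Next, we define our model and put it on the GPU. Note that the model is built from a fresh `PerceiverConfig`, so its weights are randomly initialized rather than loaded from a pretrained checkpoint.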

In [14]:
import torch

from transformers.models.perceiver.modeling_perceiver import (
    PerceiverConfig,
    PerceiverModel,
    PerceiverClassificationDecoder,
    PerceiverTextPreprocessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default configuration, except for a smaller number of self-attention layers per block.
config = PerceiverConfig(
    num_self_attends_per_block=4
)
# The preprocessor embeds the byte IDs and adds position embeddings.
preprocessor = PerceiverTextPreprocessor(config)
# The decoder turns the latents into classification logits.
decoder = PerceiverClassificationDecoder(
    config,
    num_channels=config.d_latents,
    trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
    use_query_residual=True,
)

# TODO: set num_self_attends_per_block, num_self_attention_heads and num_cross_attention_heads
# to something more reasonable, and out_channels, project_pos_dim and num_channels to 64
model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)



model.to(device)

PerceiverModel(
  (input_preprocessor): PerceiverTextPreprocessor(
    (embeddings): Embedding(262, 768)
    (position_embeddings): Embedding(2048, 768)
  )
  (embeddings): PerceiverEmbeddings()
  (encoder): PerceiverEncoder(
    (cross_attention): PerceiverLayer(
      (attention): PerceiverAttention(
        (self): PerceiverSelfAttention(
          (layernorm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (layernorm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (query): Linear(in_features=1280, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): PerceiverSelfOutput(
          (dense): Linear(in_features=768, out_features=1280, bias=True)
        )
      )
      (layernorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (mlp): Percei

In [14]:
# a forward pass on a single example:
tokenizer = PerceiverTokenizer()
text = "hello world"
# .to() is not in-place, so keep the tensor it returns
inputs = tokenizer(text, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
  outputs = model(inputs=inputs)
logits = outputs.logits
print('list(logits.shape): ', list(logits.shape))
# the model can be trained with a standard cross-entropy loss on these logits:
criterion = torch.nn.CrossEntropyLoss()
labels = torch.tensor([1]).to(device)
loss = criterion(logits, labels)

list(logits.shape):  [1, 2]
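To see what the cross-entropy loss computes for a single example, here is a small worked sketch in plain Python (the function name `cross_entropy` and the logit values are illustrative). The loss is the negative log of the softmax probability assigned to the true class. For two classes, a model that is undecided (equal logits) gets a loss of ln 2 ≈ 0.693, which is a useful reference point when reading the training log.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class (numerically stable)."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]


print(cross_entropy([0.0, 0.0], 1))   # about 0.693 (ln 2): chance level for two classes
print(cross_entropy([-2.0, 2.0], 1))  # small loss: confident and correct
print(cross_entropy([-2.0, 2.0], 0))  # large loss: confident and wrong
```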


## Train the model

Here we train the model using native PyTorch.

In [17]:
# transformers.AdamW is deprecated; use the PyTorch implementation instead
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# weight_decay=0.0 matches the transformers default (the PyTorch default is 0.01)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.0)
# standard cross-entropy loss, created once outside the loop
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(10):  # loop over the dataset multiple times
    # checkpoint at the start of each epoch (holds the weights from the previous epoch)
    torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/small_network_model.pt')
    print('saved model')
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        # get the inputs
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs=inputs, attention_mask=attention_mask)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # training accuracy on the current batch
        predictions = logits.argmax(-1).detach().cpu().numpy()
        accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")



saved model
Epoch: 0


  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.8149105310440063, Accuracy: 0.42
Loss: 4.890848636627197, Accuracy: 0.52
Loss: 2.8547942638397217, Accuracy: 0.49
Loss: 0.6953761577606201, Accuracy: 0.48
Loss: 1.6854546070098877, Accuracy: 0.53
Loss: 1.6291992664337158, Accuracy: 0.54
Loss: 1.1833958625793457, Accuracy: 0.48
Loss: 0.6936582326889038, Accuracy: 0.49
Loss: 1.0830514430999756, Accuracy: 0.46
Loss: 0.9893204569816589, Accuracy: 0.57
Loss: 1.0927374362945557, Accuracy: 0.44
Loss: 0.768355131149292, Accuracy: 0.47
Loss: 0.7112301588058472, Accuracy: 0.51
Loss: 0.9394814968109131, Accuracy: 0.44
Loss: 0.9648151397705078, Accuracy: 0.47
Loss: 0.8758718967437744, Accuracy: 0.48
Loss: 0.7268973588943481, Accuracy: 0.52
Loss: 0.6942002177238464, Accuracy: 0.47
Loss: 0.7207542657852173, Accuracy: 0.52
Loss: 0.8085280656814575, Accuracy: 0.48
Loss: 0.695830225944519, Accuracy: 0.61
Loss: 0.7303501963615417, Accuracy: 0.56
Loss: 0.6912200450897217, Accuracy: 0.58
Loss: 0.6976315379142761, Accuracy: 0.53
Loss: 0.69216847419

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6972556114196777, Accuracy: 0.5
Loss: 0.7049053311347961, Accuracy: 0.49
Loss: 0.6707363724708557, Accuracy: 0.6
Loss: 0.6714637279510498, Accuracy: 0.66
Loss: 0.7107384204864502, Accuracy: 0.49
Loss: 0.6561586260795593, Accuracy: 0.62
Loss: 0.699669361114502, Accuracy: 0.53
Loss: 0.7172229886054993, Accuracy: 0.45
Loss: 0.6640281081199646, Accuracy: 0.7
Loss: 0.703586220741272, Accuracy: 0.52
Loss: 0.7204954624176025, Accuracy: 0.51
Loss: 0.7539860606193542, Accuracy: 0.44
Loss: 0.6946790218353271, Accuracy: 0.49
Loss: 0.7065412998199463, Accuracy: 0.5
Loss: 0.7140844464302063, Accuracy: 0.48
Loss: 0.6971065402030945, Accuracy: 0.52
Loss: 0.7033061385154724, Accuracy: 0.49
Loss: 0.6977470517158508, Accuracy: 0.49
Loss: 0.6843752264976501, Accuracy: 0.52
Loss: 0.7209268808364868, Accuracy: 0.48
Loss: 0.6712276339530945, Accuracy: 0.59
Loss: 0.7058241963386536, Accuracy: 0.54
Loss: 0.6966939568519592, Accuracy: 0.51
Loss: 0.6914870738983154, Accuracy: 0.51
Loss: 0.68211913108825

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.624742865562439, Accuracy: 0.63
Loss: 0.64366614818573, Accuracy: 0.63
Loss: 0.6188514828681946, Accuracy: 0.68
Loss: 0.6124616861343384, Accuracy: 0.65
Loss: 0.6473525166511536, Accuracy: 0.6
Loss: 0.6460333466529846, Accuracy: 0.63
Loss: 0.5869605541229248, Accuracy: 0.67
Loss: 0.5223217606544495, Accuracy: 0.78
Loss: 0.7632920145988464, Accuracy: 0.59
Loss: 0.6462527513504028, Accuracy: 0.61
Loss: 0.7072687745094299, Accuracy: 0.57
Loss: 0.5659775733947754, Accuracy: 0.72
Loss: 0.5970693826675415, Accuracy: 0.69
Loss: 0.582088828086853, Accuracy: 0.7
Loss: 0.5917508602142334, Accuracy: 0.69
Loss: 0.6096542477607727, Accuracy: 0.63
Loss: 0.609229564666748, Accuracy: 0.66
Loss: 0.6544513702392578, Accuracy: 0.65
Loss: 0.6954410076141357, Accuracy: 0.58
Loss: 0.643179714679718, Accuracy: 0.6
Loss: 0.6104573607444763, Accuracy: 0.67
Loss: 0.7712092399597168, Accuracy: 0.5
Loss: 0.5910837650299072, Accuracy: 0.68
Loss: 0.6511735320091248, Accuracy: 0.6
Loss: 0.6416870355606079, A

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.6520463824272156, Accuracy: 0.63
Loss: 0.5973290205001831, Accuracy: 0.68
Loss: 0.5760281085968018, Accuracy: 0.7
Loss: 0.6003818511962891, Accuracy: 0.67
Loss: 0.5666558146476746, Accuracy: 0.68
Loss: 0.5752565264701843, Accuracy: 0.68
Loss: 0.5777528882026672, Accuracy: 0.7
Loss: 0.64691162109375, Accuracy: 0.65
Loss: 0.5532845258712769, Accuracy: 0.7
Loss: 0.6207194924354553, Accuracy: 0.68
Loss: 0.5620065331459045, Accuracy: 0.71
Loss: 0.5550460815429688, Accuracy: 0.7
Loss: 0.5572571754455566, Accuracy: 0.72
Loss: 0.5894081592559814, Accuracy: 0.68
Loss: 0.596457839012146, Accuracy: 0.71
Loss: 0.5844485759735107, Accuracy: 0.72
Loss: 0.6435142755508423, Accuracy: 0.62
Loss: 0.853756844997406, Accuracy: 0.53
Loss: 0.6056102514266968, Accuracy: 0.64
Loss: 0.6658840179443359, Accuracy: 0.59
Loss: 0.6126827001571655, Accuracy: 0.68
Loss: 0.6427200436592102, Accuracy: 0.65
Loss: 0.6132991909980774, Accuracy: 0.67
Loss: 0.6163018941879272, Accuracy: 0.68
Loss: 0.6115663051605225

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5113239884376526, Accuracy: 0.79
Loss: 0.5913992524147034, Accuracy: 0.71
Loss: 0.6241251230239868, Accuracy: 0.68
Loss: 0.5786783695220947, Accuracy: 0.69
Loss: 0.5484888553619385, Accuracy: 0.73
Loss: 0.6197939515113831, Accuracy: 0.6
Loss: 0.6078637838363647, Accuracy: 0.66
Loss: 0.5823365449905396, Accuracy: 0.67
Loss: 0.6083751320838928, Accuracy: 0.67
Loss: 0.6838860511779785, Accuracy: 0.6
Loss: 0.6315487623214722, Accuracy: 0.62
Loss: 0.5908374190330505, Accuracy: 0.69
Loss: 0.618710458278656, Accuracy: 0.64
Loss: 0.595784604549408, Accuracy: 0.67
Loss: 0.5636928677558899, Accuracy: 0.71
Loss: 0.614997923374176, Accuracy: 0.64
Loss: 0.5990140438079834, Accuracy: 0.65
Loss: 0.622967541217804, Accuracy: 0.65
Loss: 0.637529194355011, Accuracy: 0.66
Loss: 0.690875768661499, Accuracy: 0.59
Loss: 0.5678872466087341, Accuracy: 0.72
Loss: 0.6110143065452576, Accuracy: 0.65
Loss: 0.6187320947647095, Accuracy: 0.66
Loss: 0.5941398739814758, Accuracy: 0.7
Loss: 0.6071233153343201,

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5351019501686096, Accuracy: 0.75
Loss: 0.6276341080665588, Accuracy: 0.65
Loss: 0.6279945969581604, Accuracy: 0.62
Loss: 0.5905330777168274, Accuracy: 0.7
Loss: 0.534795880317688, Accuracy: 0.74
Loss: 0.618872880935669, Accuracy: 0.66
Loss: 0.6622747778892517, Accuracy: 0.58
Loss: 0.6244267821311951, Accuracy: 0.64
Loss: 0.5633645057678223, Accuracy: 0.72
Loss: 0.5493077635765076, Accuracy: 0.73
Loss: 0.5931513905525208, Accuracy: 0.74
Loss: 0.5478883981704712, Accuracy: 0.74
Loss: 0.5568809509277344, Accuracy: 0.74
Loss: 0.5398498177528381, Accuracy: 0.7
Loss: 0.5699109435081482, Accuracy: 0.72
Loss: 0.4856789708137512, Accuracy: 0.78
Loss: 0.5436105132102966, Accuracy: 0.71
Loss: 0.6430768370628357, Accuracy: 0.66
Loss: 0.6677400469779968, Accuracy: 0.64
Loss: 0.5601648092269897, Accuracy: 0.74
Loss: 0.5590203404426575, Accuracy: 0.76
Loss: 0.5025854706764221, Accuracy: 0.74
Loss: 0.614219069480896, Accuracy: 0.72
Loss: 0.5419645309448242, Accuracy: 0.7
Loss: 0.66423112154006

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5973528027534485, Accuracy: 0.7
Loss: 0.5843245983123779, Accuracy: 0.68
Loss: 0.594640851020813, Accuracy: 0.66
Loss: 0.5338210463523865, Accuracy: 0.71
Loss: 0.624112069606781, Accuracy: 0.64
Loss: 0.5400570631027222, Accuracy: 0.72
Loss: 0.5695996284484863, Accuracy: 0.69
Loss: 0.6221241354942322, Accuracy: 0.61
Loss: 0.6082391738891602, Accuracy: 0.63
Loss: 0.5860830545425415, Accuracy: 0.69
Loss: 0.5560788512229919, Accuracy: 0.73
Loss: 0.5386773943901062, Accuracy: 0.72
Loss: 0.5597430467605591, Accuracy: 0.7
Loss: 0.47777146100997925, Accuracy: 0.76
Loss: 0.5152416825294495, Accuracy: 0.73
Loss: 0.6107290983200073, Accuracy: 0.69
Loss: 0.6271894574165344, Accuracy: 0.69
Loss: 0.6407873034477234, Accuracy: 0.65
Loss: 0.517270028591156, Accuracy: 0.73
Loss: 0.5667163133621216, Accuracy: 0.73
Loss: 0.7169260382652283, Accuracy: 0.57
Loss: 0.6702940464019775, Accuracy: 0.61
Loss: 0.6225292682647705, Accuracy: 0.65
Loss: 0.6315984129905701, Accuracy: 0.66
Loss: 0.529555618762

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5310337543487549, Accuracy: 0.74
Loss: 0.5592631101608276, Accuracy: 0.71
Loss: 0.5887239575386047, Accuracy: 0.73
Loss: 0.5964827537536621, Accuracy: 0.7
Loss: 0.5215810537338257, Accuracy: 0.76
Loss: 0.5206197500228882, Accuracy: 0.69
Loss: 0.5317584872245789, Accuracy: 0.76
Loss: 0.5371253490447998, Accuracy: 0.74
Loss: 0.5013887286186218, Accuracy: 0.73
Loss: 0.599990963935852, Accuracy: 0.72
Loss: 0.5893210768699646, Accuracy: 0.68
Loss: 0.551275372505188, Accuracy: 0.76
Loss: 0.5610489845275879, Accuracy: 0.7
Loss: 0.4997388422489166, Accuracy: 0.77
Loss: 0.5505139231681824, Accuracy: 0.74
Loss: 0.5193781852722168, Accuracy: 0.7
Loss: 0.5350690484046936, Accuracy: 0.75
Loss: 0.7128503322601318, Accuracy: 0.55
Loss: 0.5307496786117554, Accuracy: 0.77
Loss: 0.6133606433868408, Accuracy: 0.7
Loss: 0.5929160118103027, Accuracy: 0.62
Loss: 0.5020261406898499, Accuracy: 0.76
Loss: 0.5454347729682922, Accuracy: 0.7
Loss: 0.5289321541786194, Accuracy: 0.74
Loss: 0.46331986784935,

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.5762773752212524, Accuracy: 0.66
Loss: 0.5616966485977173, Accuracy: 0.74
Loss: 0.5981369614601135, Accuracy: 0.69
Loss: 0.44648295640945435, Accuracy: 0.81
Loss: 0.5554646849632263, Accuracy: 0.67
Loss: 0.49976861476898193, Accuracy: 0.74
Loss: 0.4716343581676483, Accuracy: 0.81
Loss: 0.4973128139972687, Accuracy: 0.76
Loss: 0.4895037114620209, Accuracy: 0.77
Loss: 0.5987851619720459, Accuracy: 0.65
Loss: 0.443391889333725, Accuracy: 0.81
Loss: 0.6356395483016968, Accuracy: 0.65
Loss: 0.6145510077476501, Accuracy: 0.72
Loss: 0.5902463793754578, Accuracy: 0.77
Loss: 0.5878056287765503, Accuracy: 0.7
Loss: 0.5631515979766846, Accuracy: 0.7
Loss: 0.5118572115898132, Accuracy: 0.77
Loss: 0.5665714144706726, Accuracy: 0.67
Loss: 0.5206239819526672, Accuracy: 0.81
Loss: 0.62160325050354, Accuracy: 0.65
Loss: 0.5432358384132385, Accuracy: 0.73
Loss: 0.49285688996315, Accuracy: 0.77
Loss: 0.535338282585144, Accuracy: 0.75
Loss: 0.6071218252182007, Accuracy: 0.67
Loss: 0.53748327493667

  0%|          | 0/250 [00:00<?, ?it/s]

Loss: 0.49541667103767395, Accuracy: 0.77
Loss: 0.5485630631446838, Accuracy: 0.75
Loss: 0.5459569692611694, Accuracy: 0.72
Loss: 0.528515100479126, Accuracy: 0.72
Loss: 0.4873869717121124, Accuracy: 0.79
Loss: 0.5611368417739868, Accuracy: 0.73
Loss: 0.49046647548675537, Accuracy: 0.79
Loss: 0.4401278793811798, Accuracy: 0.79
Loss: 0.5053061842918396, Accuracy: 0.76
Loss: 0.5958578586578369, Accuracy: 0.67
Loss: 0.5386735796928406, Accuracy: 0.73
Loss: 0.5323777198791504, Accuracy: 0.77
Loss: 0.45400798320770264, Accuracy: 0.78
Loss: 0.4605492353439331, Accuracy: 0.77
Loss: 0.48613080382347107, Accuracy: 0.73
Loss: 0.5844961404800415, Accuracy: 0.67
Loss: 0.5472129583358765, Accuracy: 0.76
Loss: 0.49379104375839233, Accuracy: 0.76
Loss: 0.5259675979614258, Accuracy: 0.71
Loss: 0.5354723334312439, Accuracy: 0.71
Loss: 0.48391464352607727, Accuracy: 0.82
Loss: 0.5211396813392639, Accuracy: 0.78
Loss: 0.46093645691871643, Accuracy: 0.75
Loss: 0.47968336939811707, Accuracy: 0.78
Loss: 0.5
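Per-batch numbers like the ones above are noisy, since each accuracy is computed on only 100 reviews. One possible way to get a smoother view is to keep a running average per epoch. The class `RunningMean` below is a hypothetical helper sketched for illustration; it is not used in the original notebook.

```python
class RunningMean:
    """Weighted running average, e.g. of batch accuracy or batch loss."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is the mean over n examples, so weight it by n
        self.total += value * n
        self.count += n

    @property
    def value(self):
        return self.total / self.count if self.count else float("nan")


acc = RunningMean()
# the first four batch accuracies logged in epoch 0
for batch_accuracy in [0.42, 0.52, 0.49, 0.48]:
    acc.update(batch_accuracy, n=100)
print(f"Running accuracy: {acc.value:.4f}")
```

Inside the training loop, one would call `update` after every step and print `value` once per epoch instead of once per batch.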

## Evaluate the model

Finally, we save the trained weights to Google Drive and evaluate the model on the test set. We use the Datasets library to compute the accuracy.

In [23]:
torch.save(model.state_dict(), '/content/drive/MyDrive/saved_model/small_network_model_checkpoint.pt')

Mounted at /content/drive


In [16]:

# import torch
# checkpoint = torch.load('/content/drive/MyDrive/saved_model/small_network_model.pt')
# model.load_state_dict(checkpoint)
# model.eval()

RuntimeError: ignored

In [18]:
from tqdm.notebook import tqdm
# note: load_metric is deprecated in newer Datasets releases in favor of the separate evaluate library
from datasets import load_metric

accuracy = load_metric("accuracy")

model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for batch in tqdm(test_dataloader):
        # get the inputs
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # forward pass
        outputs = model(inputs=inputs, attention_mask=attention_mask)
        logits = outputs.logits
        predictions = logits.argmax(-1).cpu().numpy()
        references = batch["label"].numpy()
        accuracy.add_batch(predictions=predictions, references=references)

final_score = accuracy.compute()
print("Accuracy on test set:", final_score)


Accuracy on test set: {'accuracy': 0.62816}
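An accuracy of 0.62816 on 25,000 test reviews corresponds to 15,704 correct predictions. The `add_batch` / `compute` pattern used above simply accumulates counts across batches and divides at the end. The following plain-Python sketch illustrates the idea; `RunningAccuracy` is a hypothetical stand-in, not the library's implementation.

```python
class RunningAccuracy:
    """Accumulates correct/total counts over batches, then reports accuracy."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        self.correct += sum(int(p == r) for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        return {"accuracy": self.correct / self.total}


metric = RunningAccuracy()
metric.add_batch(predictions=[1, 0, 1, 1], references=[1, 0, 0, 1])  # 3 of 4 correct
metric.add_batch(predictions=[0, 0], references=[1, 0])              # 1 of 2 correct
print(metric.compute())
```

Accumulating counts rather than averaging per-batch accuracies gives the right answer even when the last batch is smaller than the others.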


## Inference

In [22]:
text = "I hated this movie, it's really bad."

input_ids = tokenizer(text, return_tensors="pt").input_ids

# forward pass (no gradients are needed at inference time)
with torch.no_grad():
    outputs = model(inputs=input_ids.to(device))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()

print("Predicted:", model.config.id2label[predicted_class_idx])

Predicted: LABEL_1

The output is the generic name `LABEL_1` because the `id2label` and `label2id` mappings created earlier were never passed to `PerceiverConfig`, so the config falls back to default label names. Index 1 corresponds to `pos` in our mapping, so this clearly negative review is misclassified, which is not surprising for a model trained from scratch that reaches about 63% accuracy on the test set.
