In [2]:
! pip install transformers
!pip install  --pre pycaret

[0mCollecting pycaret
  Downloading pycaret-3.0.0rc4-py3-none-any.whl (487 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.2/487.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting plotly-resampler>=0.7.2.2
  Downloading plotly_resampler-0.8.2.tar.gz (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting schemdraw>=0.14
  Downloading schemdraw-0.15-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.8/106.8 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
Collecting sktime~=0.13.2
  Downloading sktime-0.13.4-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m26.3 MB/s[0m eta [36

In [3]:
import numpy as np
import pandas as pd
import pycaret
import transformers
from transformers import AutoModel, BertTokenizerFast
import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# Select the compute device: use the GPU when CUDA is available, otherwise fall back
# to the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
!pip install datasets 
from datasets import load_dataset
dataset = load_dataset("liar")

[0m

Downloading builder script:   0%|          | 0.00/2.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Downloading and preparing dataset liar/default (download: 989.82 KiB, generated: 3.26 MiB, post-processed: Unknown size, total: 4.22 MiB) to /root/.cache/huggingface/datasets/liar/default/1.0.0/479463e757b7991eed50ffa7504d7788d6218631a484442e2098dabbf3b44514...


Downloading data:   0%|          | 0.00/1.01M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10269 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1283 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1284 [00:00<?, ? examples/s]

Dataset liar downloaded and prepared to /root/.cache/huggingface/datasets/liar/default/1.0.0/479463e757b7991eed50ffa7504d7788d6218631a484442e2098dabbf3b44514. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Show the splits of the loaded dataset with their feature columns and row counts.
print(f"{dataset}")

DatasetDict({
    train: Dataset({
        features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
        num_rows: 10269
    })
    test: Dataset({
        features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
        num_rows: 1283
    })
    validation: Dataset({
        features: ['id', 'label', 'statement', 'subject', 'speaker', 'job_title', 'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts', 'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context'],
        num_rows: 1284
    })
})


In [6]:
# Load the pre-trained RoBERTa encoder and its matching fast tokenizer from the Hub.
from transformers import RobertaForSequenceClassification, RobertaTokenizer, RobertaTokenizerFast

PRETRAINED_CHECKPOINT = "roberta-base"   # encoder and tokenizer must come from the same checkpoint

# RoBERTa: bare encoder (no classification head); the custom head is defined later.
roberta = AutoModel.from_pretrained(PRETRAINED_CHECKPOINT)
roberta_tokenizer = RobertaTokenizerFast.from_pretrained(PRETRAINED_CHECKPOINT)

print(' Base models loaded')

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

 Base models loaded


In [7]:

# Tokenize the `statement` column of each split into fixed-length RoBERTa inputs.
MAX_LENGTH = 100          # tokens per statement after truncation / padding
MAX_LENGHT = MAX_LENGTH   # backward-compatible alias for the original (misspelled) name


def encode_statements(split_name, max_length=MAX_LENGTH):
    """Tokenize and encode the statements of one dataset split.

    Args:
        split_name: key into ``dataset`` ("train", "validation" or "test").
        max_length: every sequence is truncated or padded to exactly this many tokens.

    Returns:
        The tokenizer's encoding, with ``input_ids`` and ``attention_mask`` entries.
    """
    return roberta_tokenizer.batch_encode_plus(
        dataset[split_name]["statement"],
        max_length=max_length,
        padding="max_length",   # replaces the deprecated pad_to_max_length=True
        truncation=True,
    )


tokens_train = encode_statements("train")
tokens_val = encode_statements("validation")
tokens_test = encode_statements("test")

In [8]:
def _split_tensors(encoded, split_name):
    """Return (input_ids, attention_mask, labels) tensors for one tokenized split."""
    ids = torch.tensor(encoded['input_ids'])
    attention = torch.tensor(encoded['attention_mask'])
    labels = torch.tensor(dataset[split_name]["label"])
    return ids, attention, labels


# Convert the tokenizer output lists and the labels of every split to tensors.
train_seq, train_mask, train_y = _split_tensors(tokens_train, "train")
val_seq, val_mask, val_y = _split_tensors(tokens_val, "validation")
test_seq, test_mask, test_y = _split_tensors(tokens_test, "test")

In [9]:
# Sanity check: each split must have as many attention masks and labels as token sequences.
for _seq, _mask, _y in (
    (val_seq, val_mask, val_y),
    (test_seq, test_mask, test_y),
    (train_seq, train_mask, train_y),
):
    print(len(_seq), len(_mask), len(_y))

print(val_y)
print(len(dataset["validation"]["statement"]))
print()

1284 1284 1284
1283 1283 1283
10269 10269 10269
tensor([4, 5, 0,  ..., 3, 0, 4])
1284



In [10]:
# Data Loader structure definition
batch_size = 32   # examples per mini-batch, shared by both loaders

# Wrap (ids, mask, label) triples so each index yields one aligned example.
train_data = TensorDataset(train_seq, train_mask, train_y)
val_data = TensorDataset(val_seq, val_mask, val_y)

# Training batches are drawn in random order; validation batches in a fixed order,
# so validation losses are computed over the same batches every epoch.
train_sampler = RandomSampler(train_data)
val_sampler = SequentialSampler(val_data)

train_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
val_dataloader = DataLoader(val_data, batch_size=batch_size, sampler=val_sampler)
                                                              # dataLoader for validation set

In [11]:
# Freezing the parameters and defining trainable BERT structure
# Make every encoder weight trainable (full fine-tuning). Passing False instead would
# freeze the encoder so that only the classification head is trained.
roberta.requires_grad_(True)

In [12]:
class RoBERT_Arch(nn.Module):
    def __init__(self, roberta):  
      super(RoBERT_Arch, self).__init__()
      self.roberta = roberta   
      self.dropout = nn.Dropout(0.1)            # dropout layer
      self.relu =  nn.ReLU()                    # relu activation function
      self.fc1 = nn.Linear(768,512)             # dense layer 1
      self.fc2 = nn.Linear(512,6)               # dense layer 2 (Output layer)
      self.softmax = nn.LogSoftmax(dim=1)       # softmax activation function
    def forward(self, sent_id, mask):           # define the forward pass  
      cls_hs = self.roberta(sent_id, attention_mask=mask)['pooler_output']
                                                # pass the inputs to the model
      x = self.fc1(cls_hs)
      x = self.relu(x)
      x = self.dropout(x)
      x = self.fc2(x)                           # output layer
      x = self.softmax(x)                       # apply softmax activation
      return x

model = RoBERT_Arch(roberta)
# Defining the hyperparameters (optimizer, weights of the classes and the epochs)
# Define the optimizer
from transformers import AdamW
optimizer = AdamW(model.parameters(),
                  lr = 1e-5)          # learning rate
# Define the loss function
cross_entropy  = nn.NLLLoss() 
# Number of training epochs
epochs = 3

In [13]:
# Defining training and evaluation functions
def train():  
  model.train()
  total_loss, total_accuracy = 0, 0
  
  for step,batch in enumerate(tqdm(train_dataloader)):                # iterate over batches
    if step % 50 == 0 and not step == 0:                        # progress update after every 50 batches.
      print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))
    batch = [r for r in batch]                                  # push the batch to gpu
    sent_id, mask, labels = batch 
    model.zero_grad()                                           # clear previously calculated gradients
    preds = model(sent_id, mask)                                # get model predictions for current batch
    loss = cross_entropy(preds, labels)                         # compute loss between actual & predicted values
    total_loss = total_loss + loss.item()                       # add on to the total loss
    loss.backward()                                             # backward pass to calculate the gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)     # clip gradients to 1.0. It helps in preventing exploding gradient problem
    optimizer.step()                                            # update parameters
    preds=preds.detach().cpu().numpy()                          # model predictions are stored on GPU. So, push it to CPU

  avg_loss = total_loss / len(train_dataloader)                 # compute training loss of the epoch  
                                                                # reshape predictions in form of (# samples, # classes)
  return avg_loss                                 # returns the loss and predictions

def evaluate():  
  print("\nEvaluating...")  
  model.eval()                                    # Deactivate dropout layers
  total_loss, total_accuracy = 0, 0  
  for step,batch in enumerate(tqdm(val_dataloader)):    # Iterate over batches  
    if step % 50 == 0 and not step == 0:          # Progress update every 50 batches.     
                                                  # Calculate elapsed time in minutes.
                                                  # Elapsed = format_time(time.time() - t0)
      print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))
                                                  # Report progress
    batch = [t for t in batch]                    # Push the batch to GPU
    sent_id, mask, labels = batch
    with torch.no_grad():                         # Deactivate autograd
      preds = model(sent_id, mask)                # Model predictions
      loss = cross_entropy(preds,labels)          # Compute the validation loss between actual and predicted values
      total_loss = total_loss + loss.item()
      preds = preds.detach().cpu().numpy()
  avg_loss = total_loss / len(val_dataloader)         # compute the validation loss of the epoch
  return avg_loss

In [14]:
# Train for `epochs` epochs; keep a checkpoint of the weights with the lowest validation loss.
CHECKPOINT_PATH = 'changed_weights_roberta.pt'

best_valid_loss = float('inf')
train_losses = []   # mean training loss of each epoch
valid_losses = []   # mean validation loss of each epoch

for epoch in tqdm(range(epochs)):
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
    epoch_train_loss, epoch_valid_loss = train(), evaluate()
    train_loss, valid_loss = epoch_train_loss, epoch_valid_loss

    is_best = valid_loss < best_valid_loss
    if is_best:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

  0%|          | 0/3 [00:00<?, ?it/s]


 Epoch 1 / 3



  0%|          | 0/321 [00:00<?, ?it/s][A
  0%|          | 1/321 [00:19<1:41:24, 19.01s/it][A
  1%|          | 2/321 [00:36<1:36:46, 18.20s/it][A
  1%|          | 3/321 [00:52<1:31:23, 17.24s/it][A
  1%|          | 4/321 [01:08<1:28:33, 16.76s/it][A
  2%|▏         | 5/321 [01:24<1:26:53, 16.50s/it][A
  2%|▏         | 6/321 [01:40<1:25:22, 16.26s/it][A
  2%|▏         | 7/321 [01:56<1:23:58, 16.05s/it][A
  2%|▏         | 8/321 [02:12<1:23:58, 16.10s/it][A
  3%|▎         | 9/321 [02:28<1:23:15, 16.01s/it][A
  3%|▎         | 10/321 [02:44<1:22:56, 16.00s/it][A
  3%|▎         | 11/321 [02:59<1:22:06, 15.89s/it][A
  4%|▎         | 12/321 [03:16<1:22:38, 16.05s/it][A
  4%|▍         | 13/321 [03:32<1:22:04, 15.99s/it][A
  4%|▍         | 14/321 [03:48<1:22:31, 16.13s/it][A
  5%|▍         | 15/321 [04:04<1:21:36, 16.00s/it][A
  5%|▍         | 16/321 [04:20<1:22:17, 16.19s/it][A
  5%|▌         | 17/321 [04:36<1:20:47, 15.95s/it][A
  6%|▌         | 18/321 [04:52<1:20:32, 15.95s/

  Batch    50  of    321.



 16%|█▌        | 51/321 [13:39<1:11:28, 15.88s/it][A
 16%|█▌        | 52/321 [13:55<1:11:11, 15.88s/it][A
 17%|█▋        | 53/321 [14:11<1:11:10, 15.94s/it][A
 17%|█▋        | 54/321 [14:27<1:10:51, 15.92s/it][A
 17%|█▋        | 55/321 [14:43<1:10:53, 15.99s/it][A
 17%|█▋        | 56/321 [14:59<1:10:26, 15.95s/it][A
 18%|█▊        | 57/321 [15:14<1:10:08, 15.94s/it][A
 18%|█▊        | 58/321 [15:30<1:09:39, 15.89s/it][A
 18%|█▊        | 59/321 [15:47<1:10:13, 16.08s/it][A
 19%|█▊        | 60/321 [16:03<1:09:28, 15.97s/it][A
 19%|█▉        | 61/321 [16:19<1:10:31, 16.28s/it][A
 19%|█▉        | 62/321 [16:35<1:09:46, 16.16s/it][A
 20%|█▉        | 63/321 [16:52<1:09:50, 16.24s/it][A
 20%|█▉        | 64/321 [17:07<1:08:33, 16.01s/it][A
 20%|██        | 65/321 [17:24<1:09:01, 16.18s/it][A
 21%|██        | 66/321 [17:40<1:08:15, 16.06s/it][A
 21%|██        | 67/321 [17:56<1:08:36, 16.21s/it][A
 21%|██        | 68/321 [18:12<1:07:54, 16.10s/it][A
 21%|██▏       | 69/321 [18

  Batch   100  of    321.



 31%|███▏      | 101/321 [26:56<57:54, 15.79s/it][A
 32%|███▏      | 102/321 [27:12<57:21, 15.72s/it][A
 32%|███▏      | 103/321 [27:28<57:37, 15.86s/it][A
 32%|███▏      | 104/321 [27:43<56:49, 15.71s/it][A
 33%|███▎      | 105/321 [27:59<56:55, 15.81s/it][A
 33%|███▎      | 106/321 [28:15<56:19, 15.72s/it][A
 33%|███▎      | 107/321 [28:31<56:32, 15.85s/it][A
 34%|███▎      | 108/321 [28:46<55:41, 15.69s/it][A
 34%|███▍      | 109/321 [29:02<55:32, 15.72s/it][A
 34%|███▍      | 110/321 [29:17<54:56, 15.62s/it][A
 35%|███▍      | 111/321 [29:34<55:14, 15.78s/it][A
 35%|███▍      | 112/321 [29:49<54:42, 15.70s/it][A
 35%|███▌      | 113/321 [30:05<54:57, 15.85s/it][A
 36%|███▌      | 114/321 [30:21<54:32, 15.81s/it][A
 36%|███▌      | 115/321 [30:37<54:31, 15.88s/it][A
 36%|███▌      | 116/321 [30:52<53:25, 15.64s/it][A
 36%|███▋      | 117/321 [31:08<53:37, 15.77s/it][A
 37%|███▋      | 118/321 [31:24<53:01, 15.67s/it][A
 37%|███▋      | 119/321 [31:40<53:24, 15.87s

  Batch   150  of    321.



 47%|████▋     | 151/321 [40:00<44:22, 15.66s/it][A
 47%|████▋     | 152/321 [40:16<43:56, 15.60s/it][A
 48%|████▊     | 153/321 [40:31<43:45, 15.63s/it][A
 48%|████▊     | 154/321 [40:47<43:08, 15.50s/it][A
 48%|████▊     | 155/321 [41:02<42:59, 15.54s/it][A
 49%|████▊     | 156/321 [41:18<42:34, 15.48s/it][A
 49%|████▉     | 157/321 [41:33<42:40, 15.61s/it][A
 49%|████▉     | 158/321 [41:49<42:10, 15.53s/it][A
 50%|████▉     | 159/321 [42:05<42:06, 15.60s/it][A
 50%|████▉     | 160/321 [42:20<41:57, 15.64s/it][A
 50%|█████     | 161/321 [42:36<41:49, 15.69s/it][A
 50%|█████     | 162/321 [42:52<41:26, 15.64s/it][A
 51%|█████     | 163/321 [43:07<40:50, 15.51s/it][A
 51%|█████     | 164/321 [43:22<40:28, 15.47s/it][A
 51%|█████▏    | 165/321 [43:37<40:02, 15.40s/it][A
 52%|█████▏    | 166/321 [43:53<40:06, 15.53s/it][A
 52%|█████▏    | 167/321 [44:08<39:33, 15.41s/it][A
 52%|█████▏    | 168/321 [44:24<39:42, 15.57s/it][A
 53%|█████▎    | 169/321 [44:40<39:29, 15.59s

  Batch   200  of    321.



 63%|██████▎   | 201/321 [53:02<31:35, 15.80s/it][A
 63%|██████▎   | 202/321 [53:17<31:21, 15.81s/it][A
 63%|██████▎   | 203/321 [53:33<30:48, 15.66s/it][A
 64%|██████▎   | 204/321 [53:49<30:39, 15.72s/it][A
 64%|██████▍   | 205/321 [54:04<30:05, 15.57s/it][A
 64%|██████▍   | 206/321 [54:20<30:00, 15.66s/it][A
 64%|██████▍   | 207/321 [54:35<29:40, 15.61s/it][A
 65%|██████▍   | 208/321 [54:51<29:33, 15.69s/it][A
 65%|██████▌   | 209/321 [55:06<29:00, 15.54s/it][A
 65%|██████▌   | 210/321 [55:22<28:56, 15.65s/it][A
 66%|██████▌   | 211/321 [55:37<28:31, 15.56s/it][A
 66%|██████▌   | 212/321 [55:53<28:26, 15.66s/it][A
 66%|██████▋   | 213/321 [56:09<28:03, 15.59s/it][A
 67%|██████▋   | 214/321 [56:25<28:05, 15.75s/it][A
 67%|██████▋   | 215/321 [56:40<27:40, 15.66s/it][A
 67%|██████▋   | 216/321 [56:56<27:30, 15.72s/it][A
 68%|██████▊   | 217/321 [57:11<26:59, 15.57s/it][A
 68%|██████▊   | 218/321 [57:27<26:40, 15.54s/it][A
 68%|██████▊   | 219/321 [57:42<26:16, 15.46s

  Batch   250  of    321.



 78%|███████▊  | 251/321 [1:06:03<18:13, 15.62s/it][A
 79%|███████▊  | 252/321 [1:06:19<18:09, 15.78s/it][A
 79%|███████▉  | 253/321 [1:06:34<17:42, 15.62s/it][A
 79%|███████▉  | 254/321 [1:06:50<17:31, 15.70s/it][A
 79%|███████▉  | 255/321 [1:07:06<17:10, 15.61s/it][A
 80%|███████▉  | 256/321 [1:07:21<16:56, 15.64s/it][A
 80%|████████  | 257/321 [1:07:37<16:50, 15.78s/it][A
 80%|████████  | 258/321 [1:07:53<16:33, 15.77s/it][A
 81%|████████  | 259/321 [1:08:08<16:09, 15.63s/it][A
 81%|████████  | 260/321 [1:08:24<15:56, 15.68s/it][A
 81%|████████▏ | 261/321 [1:08:40<15:35, 15.59s/it][A
 82%|████████▏ | 262/321 [1:08:55<15:20, 15.61s/it][A
 82%|████████▏ | 263/321 [1:09:11<15:01, 15.54s/it][A
 82%|████████▏ | 264/321 [1:09:26<14:47, 15.57s/it][A
 83%|████████▎ | 265/321 [1:09:42<14:27, 15.49s/it][A
 83%|████████▎ | 266/321 [1:09:58<14:19, 15.62s/it][A
 83%|████████▎ | 267/321 [1:10:13<14:05, 15.65s/it][A
 83%|████████▎ | 268/321 [1:10:29<13:54, 15.75s/it][A
 84%|████

  Batch   300  of    321.



 94%|█████████▍| 301/321 [1:19:06<05:06, 15.35s/it][A
 94%|█████████▍| 302/321 [1:19:22<04:51, 15.35s/it][A
 94%|█████████▍| 303/321 [1:19:37<04:35, 15.32s/it][A
 95%|█████████▍| 304/321 [1:19:52<04:20, 15.33s/it][A
 95%|█████████▌| 305/321 [1:20:07<04:05, 15.33s/it][A
 95%|█████████▌| 306/321 [1:20:23<03:49, 15.28s/it][A
 96%|█████████▌| 307/321 [1:20:38<03:35, 15.41s/it][A
 96%|█████████▌| 308/321 [1:20:53<03:18, 15.30s/it][A
 96%|█████████▋| 309/321 [1:21:09<03:04, 15.37s/it][A
 97%|█████████▋| 310/321 [1:21:24<02:47, 15.22s/it][A
 97%|█████████▋| 311/321 [1:21:40<02:33, 15.39s/it][A
 97%|█████████▋| 312/321 [1:21:55<02:17, 15.27s/it][A
 98%|█████████▊| 313/321 [1:22:10<02:02, 15.35s/it][A
 98%|█████████▊| 314/321 [1:22:25<01:46, 15.21s/it][A
 98%|█████████▊| 315/321 [1:22:41<01:31, 15.32s/it][A
 98%|█████████▊| 316/321 [1:22:56<01:16, 15.25s/it][A
 99%|█████████▉| 317/321 [1:23:11<01:01, 15.40s/it][A
 99%|█████████▉| 318/321 [1:23:26<00:45, 15.27s/it][A
 99%|████


Evaluating...



  0%|          | 0/41 [00:00<?, ?it/s][A
  2%|▏         | 1/41 [00:04<03:09,  4.75s/it][A
  5%|▍         | 2/41 [00:09<03:02,  4.69s/it][A
  7%|▋         | 3/41 [00:13<02:52,  4.54s/it][A
 10%|▉         | 4/41 [00:18<02:51,  4.62s/it][A
 12%|█▏        | 5/41 [00:23<02:45,  4.60s/it][A
 15%|█▍        | 6/41 [00:27<02:43,  4.68s/it][A
 17%|█▋        | 7/41 [00:32<02:43,  4.80s/it][A
 20%|█▉        | 8/41 [00:37<02:38,  4.79s/it][A
 22%|██▏       | 9/41 [00:42<02:31,  4.74s/it][A
 24%|██▍       | 10/41 [00:46<02:24,  4.68s/it][A
 27%|██▋       | 11/41 [00:51<02:20,  4.70s/it][A
 29%|██▉       | 12/41 [00:56<02:14,  4.64s/it][A
 32%|███▏      | 13/41 [01:01<02:16,  4.88s/it][A
 34%|███▍      | 14/41 [01:06<02:09,  4.78s/it][A
 37%|███▋      | 15/41 [01:10<02:03,  4.77s/it][A
 39%|███▉      | 16/41 [01:15<01:56,  4.67s/it][A
 41%|████▏     | 17/41 [01:19<01:51,  4.65s/it][A
 44%|████▍     | 18/41 [01:24<01:47,  4.65s/it][A
 46%|████▋     | 19/41 [01:29<01:41,  4.60s/it]


Training Loss: 1.748
Validation Loss: 1.717

 Epoch 2 / 3



  0%|          | 0/321 [00:00<?, ?it/s][A
  0%|          | 1/321 [00:15<1:20:38, 15.12s/it][A
  1%|          | 2/321 [00:31<1:23:35, 15.72s/it][A
  1%|          | 3/321 [00:46<1:22:56, 15.65s/it][A
  1%|          | 4/321 [01:02<1:22:50, 15.68s/it][A
  2%|▏         | 5/321 [01:18<1:22:09, 15.60s/it][A
  2%|▏         | 6/321 [01:33<1:22:34, 15.73s/it][A
  2%|▏         | 7/321 [01:49<1:21:25, 15.56s/it][A
  2%|▏         | 8/321 [02:04<1:21:02, 15.53s/it][A
  3%|▎         | 9/321 [02:19<1:20:06, 15.41s/it][A
  3%|▎         | 10/321 [02:35<1:19:48, 15.40s/it][A
  3%|▎         | 11/321 [02:50<1:18:47, 15.25s/it][A
  4%|▎         | 12/321 [03:04<1:17:54, 15.13s/it][A
  4%|▍         | 13/321 [03:20<1:18:52, 15.36s/it][A
  4%|▍         | 14/321 [03:36<1:18:16, 15.30s/it][A
  5%|▍         | 15/321 [03:52<1:19:13, 15.53s/it][A
  5%|▍         | 16/321 [04:07<1:18:26, 15.43s/it][A
  5%|▌         | 17/321 [04:22<1:17:58, 15.39s/it][A
  6%|▌         | 18/321 [04:37<1:17:25, 15.33s/

  Batch    50  of    321.



 16%|█▌        | 51/321 [13:11<1:10:21, 15.63s/it][A
 16%|█▌        | 52/321 [13:27<1:09:50, 15.58s/it][A
 17%|█▋        | 53/321 [13:43<1:09:52, 15.64s/it][A
 17%|█▋        | 54/321 [13:58<1:09:19, 15.58s/it][A
 17%|█▋        | 55/321 [14:14<1:09:02, 15.57s/it][A
 17%|█▋        | 56/321 [14:29<1:08:47, 15.58s/it][A
 18%|█▊        | 57/321 [14:45<1:08:37, 15.60s/it][A
 18%|█▊        | 58/321 [15:00<1:07:42, 15.45s/it][A
 18%|█▊        | 59/321 [15:16<1:07:55, 15.55s/it][A
 19%|█▊        | 60/321 [15:31<1:07:43, 15.57s/it][A
 19%|█▉        | 61/321 [15:47<1:08:13, 15.75s/it][A
 19%|█▉        | 62/321 [16:03<1:07:13, 15.57s/it][A
 20%|█▉        | 63/321 [16:18<1:07:04, 15.60s/it][A
 20%|█▉        | 64/321 [16:34<1:06:36, 15.55s/it][A
 20%|██        | 65/321 [16:50<1:06:37, 15.62s/it][A
 21%|██        | 66/321 [17:05<1:06:16, 15.59s/it][A
 21%|██        | 67/321 [17:21<1:06:33, 15.72s/it][A
 21%|██        | 68/321 [17:36<1:05:39, 15.57s/it][A
 21%|██▏       | 69/321 [17

  Batch   100  of    321.



 31%|███▏      | 101/321 [26:12<57:24, 15.66s/it][A
 32%|███▏      | 102/321 [26:28<57:24, 15.73s/it][A
 32%|███▏      | 103/321 [26:44<57:06, 15.72s/it][A
 32%|███▏      | 104/321 [27:00<57:04, 15.78s/it][A
 33%|███▎      | 105/321 [27:15<56:15, 15.63s/it][A
 33%|███▎      | 106/321 [27:32<56:44, 15.83s/it][A
 33%|███▎      | 107/321 [27:47<56:06, 15.73s/it][A
 34%|███▎      | 108/321 [28:03<56:07, 15.81s/it][A
 34%|███▍      | 109/321 [28:18<55:19, 15.66s/it][A
 34%|███▍      | 110/321 [28:34<55:20, 15.74s/it][A
 35%|███▍      | 111/321 [28:50<54:42, 15.63s/it][A
 35%|███▍      | 112/321 [29:06<54:40, 15.70s/it][A
 35%|███▌      | 113/321 [29:21<54:15, 15.65s/it][A
 36%|███▌      | 114/321 [29:37<53:56, 15.64s/it][A
 36%|███▌      | 115/321 [29:52<53:14, 15.51s/it][A
 36%|███▌      | 116/321 [30:08<53:20, 15.61s/it][A
 36%|███▋      | 117/321 [30:23<52:49, 15.54s/it][A
 37%|███▋      | 118/321 [30:39<52:54, 15.64s/it][A
 37%|███▋      | 119/321 [30:54<52:01, 15.45s

  Batch   150  of    321.



 47%|████▋     | 151/321 [39:13<45:09, 15.94s/it][A
 47%|████▋     | 152/321 [39:29<44:47, 15.90s/it][A
 48%|████▊     | 153/321 [39:45<44:12, 15.79s/it][A
 48%|████▊     | 154/321 [40:00<43:56, 15.79s/it][A
 48%|████▊     | 155/321 [40:16<43:12, 15.62s/it][A
 49%|████▊     | 156/321 [40:31<43:04, 15.66s/it][A
 49%|████▉     | 157/321 [40:47<42:35, 15.58s/it][A
 49%|████▉     | 158/321 [41:02<42:10, 15.52s/it][A
 50%|████▉     | 159/321 [41:17<41:37, 15.41s/it][A
 50%|████▉     | 160/321 [41:33<41:21, 15.41s/it][A
 50%|█████     | 161/321 [41:48<41:11, 15.44s/it][A
 50%|█████     | 162/321 [42:04<40:53, 15.43s/it][A
 51%|█████     | 163/321 [42:19<40:23, 15.34s/it][A
 51%|█████     | 164/321 [42:35<40:32, 15.49s/it][A
 51%|█████▏    | 165/321 [42:50<40:09, 15.44s/it][A
 52%|█████▏    | 166/321 [43:05<39:54, 15.45s/it][A
 52%|█████▏    | 167/321 [43:21<39:31, 15.40s/it][A
 52%|█████▏    | 168/321 [43:36<39:01, 15.31s/it][A
 53%|█████▎    | 169/321 [43:52<39:30, 15.60s

  Batch   200  of    321.



 63%|██████▎   | 201/321 [52:10<31:01, 15.51s/it][A
 63%|██████▎   | 202/321 [52:25<30:38, 15.45s/it][A
 63%|██████▎   | 203/321 [52:41<30:43, 15.62s/it][A
 64%|██████▎   | 204/321 [52:56<30:19, 15.55s/it][A
 64%|██████▍   | 205/321 [53:12<30:07, 15.58s/it][A
 64%|██████▍   | 206/321 [53:27<29:38, 15.46s/it][A
 64%|██████▍   | 207/321 [53:43<29:43, 15.64s/it][A
 65%|██████▍   | 208/321 [53:59<29:12, 15.51s/it][A
 65%|██████▌   | 209/321 [54:14<29:06, 15.60s/it][A
 65%|██████▌   | 210/321 [54:30<28:48, 15.57s/it][A
 66%|██████▌   | 211/321 [54:45<28:33, 15.58s/it][A
 66%|██████▌   | 212/321 [55:01<28:13, 15.54s/it][A
 66%|██████▋   | 213/321 [55:16<28:01, 15.56s/it][A
 67%|██████▋   | 214/321 [55:32<27:41, 15.53s/it][A
 67%|██████▋   | 215/321 [55:48<27:42, 15.68s/it][A
 67%|██████▋   | 216/321 [56:03<27:12, 15.55s/it][A
 68%|██████▊   | 217/321 [56:19<27:01, 15.59s/it][A
 68%|██████▊   | 218/321 [56:34<26:37, 15.51s/it][A
 68%|██████▊   | 219/321 [56:50<26:39, 15.68s

  Batch   250  of    321.



 78%|███████▊  | 251/321 [1:05:05<18:14, 15.63s/it][A
 79%|███████▊  | 252/321 [1:05:20<18:02, 15.69s/it][A
 79%|███████▉  | 253/321 [1:05:36<17:40, 15.59s/it][A
 79%|███████▉  | 254/321 [1:05:52<17:38, 15.80s/it][A
 79%|███████▉  | 255/321 [1:06:07<17:08, 15.59s/it][A
 80%|███████▉  | 256/321 [1:06:23<17:00, 15.70s/it][A
 80%|████████  | 257/321 [1:06:39<16:41, 15.64s/it][A
 80%|████████  | 258/321 [1:06:55<16:32, 15.76s/it][A
 81%|████████  | 259/321 [1:07:10<16:07, 15.60s/it][A
 81%|████████  | 260/321 [1:07:25<15:50, 15.58s/it][A
 81%|████████▏ | 261/321 [1:07:41<15:28, 15.48s/it][A
 82%|████████▏ | 262/321 [1:07:57<15:20, 15.60s/it][A
 82%|████████▏ | 263/321 [1:08:12<15:01, 15.54s/it][A
 82%|████████▏ | 264/321 [1:08:28<14:50, 15.62s/it][A
 83%|████████▎ | 265/321 [1:08:43<14:34, 15.61s/it][A
 83%|████████▎ | 266/321 [1:09:00<14:29, 15.82s/it][A
 83%|████████▎ | 267/321 [1:09:14<13:56, 15.49s/it][A
 83%|████████▎ | 268/321 [1:09:30<13:48, 15.63s/it][A
 84%|████

  Batch   300  of    321.



 94%|█████████▍| 301/321 [1:18:02<05:11, 15.59s/it][A
 94%|█████████▍| 302/321 [1:18:17<04:52, 15.39s/it][A
 94%|█████████▍| 303/321 [1:18:32<04:39, 15.52s/it][A
 95%|█████████▍| 304/321 [1:18:48<04:23, 15.50s/it][A
 95%|█████████▌| 305/321 [1:19:04<04:09, 15.62s/it][A
 95%|█████████▌| 306/321 [1:19:19<03:52, 15.47s/it][A
 96%|█████████▌| 307/321 [1:19:35<03:38, 15.62s/it][A
 96%|█████████▌| 308/321 [1:19:50<03:20, 15.45s/it][A
 96%|█████████▋| 309/321 [1:20:06<03:08, 15.73s/it][A
 97%|█████████▋| 310/321 [1:20:22<02:51, 15.61s/it][A
 97%|█████████▋| 311/321 [1:20:38<02:37, 15.73s/it][A
 97%|█████████▋| 312/321 [1:20:53<02:21, 15.70s/it][A
 98%|█████████▊| 313/321 [1:21:09<02:05, 15.75s/it][A
 98%|█████████▊| 314/321 [1:21:24<01:48, 15.46s/it][A
 98%|█████████▊| 315/321 [1:21:40<01:33, 15.52s/it][A
 98%|█████████▊| 316/321 [1:21:55<01:17, 15.45s/it][A
 99%|█████████▉| 317/321 [1:22:11<01:02, 15.58s/it][A
 99%|█████████▉| 318/321 [1:22:26<00:46, 15.40s/it][A
 99%|████


Evaluating...



  0%|          | 0/41 [00:00<?, ?it/s][A
  2%|▏         | 1/41 [00:04<03:15,  4.89s/it][A
  5%|▍         | 2/41 [00:09<03:02,  4.68s/it][A
  7%|▋         | 3/41 [00:14<03:01,  4.78s/it][A
 10%|▉         | 4/41 [00:18<02:53,  4.68s/it][A
 12%|█▏        | 5/41 [00:23<02:46,  4.63s/it][A
 15%|█▍        | 6/41 [00:28<02:52,  4.92s/it][A
 17%|█▋        | 7/41 [00:33<02:43,  4.80s/it][A
 20%|█▉        | 8/41 [00:38<02:38,  4.81s/it][A
 22%|██▏       | 9/41 [00:43<02:33,  4.80s/it][A
 24%|██▍       | 10/41 [00:47<02:29,  4.83s/it][A
 27%|██▋       | 11/41 [00:52<02:22,  4.75s/it][A
 29%|██▉       | 12/41 [00:57<02:18,  4.78s/it][A
 32%|███▏      | 13/41 [01:02<02:17,  4.92s/it][A
 34%|███▍      | 14/41 [01:07<02:10,  4.83s/it][A
 37%|███▋      | 15/41 [01:11<02:04,  4.78s/it][A
 39%|███▉      | 16/41 [01:16<01:57,  4.72s/it][A
 41%|████▏     | 17/41 [01:21<01:53,  4.72s/it][A
 44%|████▍     | 18/41 [01:25<01:46,  4.65s/it][A
 46%|████▋     | 19/41 [01:30<01:46,  4.83s/it]


Training Loss: 1.690
Validation Loss: 1.692

 Epoch 3 / 3



  0%|          | 0/321 [00:00<?, ?it/s][A
  0%|          | 1/321 [00:14<1:19:35, 14.92s/it][A
  1%|          | 2/321 [00:31<1:24:35, 15.91s/it][A
  1%|          | 3/321 [00:46<1:21:27, 15.37s/it][A
  1%|          | 4/321 [01:01<1:20:38, 15.26s/it][A
  2%|▏         | 5/321 [01:16<1:19:45, 15.15s/it][A
  2%|▏         | 6/321 [01:32<1:20:39, 15.36s/it][A
  2%|▏         | 7/321 [01:47<1:20:16, 15.34s/it][A
  2%|▏         | 8/321 [02:02<1:20:03, 15.35s/it][A
  3%|▎         | 9/321 [02:17<1:18:39, 15.13s/it][A
  3%|▎         | 10/321 [02:33<1:19:41, 15.38s/it][A
  3%|▎         | 11/321 [02:48<1:18:45, 15.24s/it][A
  4%|▎         | 12/321 [03:04<1:19:26, 15.43s/it][A
  4%|▍         | 13/321 [03:19<1:19:03, 15.40s/it][A
  4%|▍         | 14/321 [03:34<1:18:46, 15.40s/it][A
  5%|▍         | 15/321 [03:49<1:17:48, 15.26s/it][A
  5%|▍         | 16/321 [04:04<1:16:53, 15.13s/it][A
  5%|▌         | 17/321 [04:19<1:16:38, 15.13s/it][A
  6%|▌         | 18/321 [04:34<1:16:16, 15.10s/

  Batch    50  of    321.



 16%|█▌        | 51/321 [12:53<1:07:54, 15.09s/it][A
 16%|█▌        | 52/321 [13:08<1:08:09, 15.20s/it][A
 17%|█▋        | 53/321 [13:23<1:07:06, 15.02s/it][A
 17%|█▋        | 54/321 [13:39<1:08:01, 15.29s/it][A
 17%|█▋        | 55/321 [13:54<1:07:27, 15.22s/it][A
 17%|█▋        | 56/321 [14:10<1:07:55, 15.38s/it][A
 18%|█▊        | 57/321 [14:25<1:07:13, 15.28s/it][A
 18%|█▊        | 58/321 [14:40<1:07:23, 15.38s/it][A
 18%|█▊        | 59/321 [14:55<1:06:25, 15.21s/it][A
 19%|█▊        | 60/321 [15:11<1:06:59, 15.40s/it][A
 19%|█▉        | 61/321 [15:26<1:06:10, 15.27s/it][A
 19%|█▉        | 62/321 [15:42<1:06:48, 15.48s/it][A
 20%|█▉        | 63/321 [15:57<1:06:16, 15.41s/it][A
 20%|█▉        | 64/321 [16:14<1:07:18, 15.71s/it][A
 20%|██        | 65/321 [16:30<1:07:23, 15.80s/it][A
 21%|██        | 66/321 [16:47<1:08:49, 16.19s/it][A
 21%|██        | 67/321 [17:03<1:08:28, 16.18s/it][A
 21%|██        | 68/321 [17:19<1:08:14, 16.19s/it][A
 21%|██▏       | 69/321 [17

  Batch   100  of    321.



 31%|███▏      | 101/321 [25:50<56:35, 15.44s/it][A
 32%|███▏      | 102/321 [26:05<55:51, 15.30s/it][A
 32%|███▏      | 103/321 [26:21<56:23, 15.52s/it][A
 32%|███▏      | 104/321 [26:36<55:43, 15.41s/it][A
 33%|███▎      | 105/321 [26:52<55:27, 15.41s/it][A
 33%|███▎      | 106/321 [27:07<54:49, 15.30s/it][A
 33%|███▎      | 107/321 [27:23<55:00, 15.42s/it][A
 34%|███▎      | 108/321 [27:38<55:01, 15.50s/it][A
 34%|███▍      | 109/321 [27:54<54:46, 15.50s/it][A
 34%|███▍      | 110/321 [28:09<54:15, 15.43s/it][A
 35%|███▍      | 111/321 [28:25<54:18, 15.52s/it][A
 35%|███▍      | 112/321 [28:40<53:56, 15.48s/it][A
 35%|███▌      | 113/321 [28:56<54:12, 15.64s/it][A
 36%|███▌      | 114/321 [29:12<53:49, 15.60s/it][A
 36%|███▌      | 115/321 [29:28<53:59, 15.72s/it][A
 36%|███▌      | 116/321 [29:43<53:28, 15.65s/it][A
 36%|███▋      | 117/321 [29:59<53:24, 15.71s/it][A
 37%|███▋      | 118/321 [30:14<52:54, 15.64s/it][A
 37%|███▋      | 119/321 [30:30<52:45, 15.67s

  Batch   150  of    321.



 47%|████▋     | 151/321 [38:54<44:44, 15.79s/it][A
 47%|████▋     | 152/321 [39:10<44:11, 15.69s/it][A
 48%|████▊     | 153/321 [39:26<44:12, 15.79s/it][A
 48%|████▊     | 154/321 [39:41<43:40, 15.69s/it][A
 48%|████▊     | 155/321 [39:57<43:40, 15.78s/it][A
 49%|████▊     | 156/321 [40:13<43:01, 15.65s/it][A
 49%|████▉     | 157/321 [40:29<43:27, 15.90s/it][A
 49%|████▉     | 158/321 [40:45<43:04, 15.85s/it][A
 50%|████▉     | 159/321 [41:01<43:07, 15.97s/it][A
 50%|████▉     | 160/321 [41:17<42:42, 15.92s/it][A
 50%|█████     | 161/321 [41:33<42:31, 15.95s/it][A
 50%|█████     | 162/321 [41:49<42:03, 15.87s/it][A
 51%|█████     | 163/321 [42:04<41:51, 15.90s/it][A
 51%|█████     | 164/321 [42:20<41:10, 15.73s/it][A
 51%|█████▏    | 165/321 [42:36<41:19, 15.89s/it][A
 52%|█████▏    | 166/321 [42:51<40:36, 15.72s/it][A
 52%|█████▏    | 167/321 [43:07<40:19, 15.71s/it][A
 52%|█████▏    | 168/321 [43:22<39:40, 15.56s/it][A
 53%|█████▎    | 169/321 [43:38<39:37, 15.64s

  Batch   200  of    321.



 63%|██████▎   | 201/321 [51:54<30:45, 15.38s/it][A
 63%|██████▎   | 202/321 [52:10<30:42, 15.48s/it][A
 63%|██████▎   | 203/321 [52:25<30:14, 15.37s/it][A
 64%|██████▎   | 204/321 [52:41<30:06, 15.44s/it][A
 64%|██████▍   | 205/321 [52:56<29:41, 15.36s/it][A
 64%|██████▍   | 206/321 [53:12<29:33, 15.42s/it][A
 64%|██████▍   | 207/321 [53:26<28:57, 15.24s/it][A
 65%|██████▍   | 208/321 [53:42<29:05, 15.45s/it][A
 65%|██████▌   | 209/321 [53:58<28:52, 15.46s/it][A
 65%|██████▌   | 210/321 [54:14<29:07, 15.74s/it][A
 66%|██████▌   | 211/321 [54:29<28:30, 15.55s/it][A
 66%|██████▌   | 212/321 [54:46<28:43, 15.81s/it][A
 66%|██████▋   | 213/321 [55:01<28:20, 15.74s/it][A
 67%|██████▋   | 214/321 [55:18<28:32, 16.01s/it][A
 67%|██████▋   | 215/321 [55:34<28:04, 15.89s/it][A
 67%|██████▋   | 216/321 [55:50<28:12, 16.12s/it][A
 68%|██████▊   | 217/321 [56:06<27:49, 16.06s/it][A
 68%|██████▊   | 218/321 [56:22<27:43, 16.15s/it][A
 68%|██████▊   | 219/321 [56:38<27:06, 15.95s

  Batch   250  of    321.



 78%|███████▊  | 251/321 [1:05:03<18:27, 15.83s/it][A
 79%|███████▊  | 252/321 [1:05:19<18:16, 15.89s/it][A
 79%|███████▉  | 253/321 [1:05:34<17:57, 15.84s/it][A
 79%|███████▉  | 254/321 [1:05:50<17:38, 15.80s/it][A
 79%|███████▉  | 255/321 [1:06:06<17:16, 15.71s/it][A
 80%|███████▉  | 256/321 [1:06:22<17:07, 15.81s/it][A
 80%|████████  | 257/321 [1:06:37<16:49, 15.77s/it][A
 80%|████████  | 258/321 [1:06:53<16:38, 15.86s/it][A
 81%|████████  | 259/321 [1:07:09<16:11, 15.67s/it][A
 81%|████████  | 260/321 [1:07:25<16:04, 15.81s/it][A
 81%|████████▏ | 261/321 [1:07:40<15:43, 15.72s/it][A
 82%|████████▏ | 262/321 [1:07:56<15:34, 15.84s/it][A
 82%|████████▏ | 263/321 [1:08:12<15:13, 15.76s/it][A
 82%|████████▏ | 264/321 [1:08:28<15:04, 15.87s/it][A
 83%|████████▎ | 265/321 [1:08:44<14:49, 15.88s/it][A
 83%|████████▎ | 266/321 [1:09:00<14:35, 15.91s/it][A
 83%|████████▎ | 267/321 [1:09:16<14:19, 15.92s/it][A
 83%|████████▎ | 268/321 [1:09:32<14:07, 15.99s/it][A
 84%|████

  Batch   300  of    321.



 94%|█████████▍| 301/321 [1:18:16<05:18, 15.94s/it][A
 94%|█████████▍| 302/321 [1:18:33<05:04, 16.02s/it][A
 94%|█████████▍| 303/321 [1:18:49<04:48, 16.04s/it][A
 95%|█████████▍| 304/321 [1:19:05<04:32, 16.03s/it][A
 95%|█████████▌| 305/321 [1:19:20<04:14, 15.92s/it][A
 95%|█████████▌| 306/321 [1:19:36<03:55, 15.72s/it][A
 96%|█████████▌| 307/321 [1:19:51<03:38, 15.59s/it][A
 96%|█████████▌| 308/321 [1:20:07<03:23, 15.66s/it][A
 96%|█████████▋| 309/321 [1:20:22<03:08, 15.67s/it][A
 97%|█████████▋| 310/321 [1:20:38<02:52, 15.69s/it][A
 97%|█████████▋| 311/321 [1:20:54<02:36, 15.64s/it][A
 97%|█████████▋| 312/321 [1:21:09<02:20, 15.61s/it][A
 98%|█████████▊| 313/321 [1:21:25<02:04, 15.59s/it][A
 98%|█████████▊| 314/321 [1:21:40<01:49, 15.61s/it][A
 98%|█████████▊| 315/321 [1:21:56<01:33, 15.61s/it][A
 98%|█████████▊| 316/321 [1:22:12<01:18, 15.67s/it][A
 99%|█████████▉| 317/321 [1:22:27<01:02, 15.60s/it][A
 99%|█████████▉| 318/321 [1:22:43<00:46, 15.65s/it][A
 99%|████


Evaluating...



  0%|          | 0/41 [00:00<?, ?it/s][A
  2%|▏         | 1/41 [00:04<03:12,  4.82s/it][A
  5%|▍         | 2/41 [00:09<03:11,  4.91s/it][A
  7%|▋         | 3/41 [00:14<03:04,  4.86s/it][A
 10%|▉         | 4/41 [00:19<02:54,  4.72s/it][A
 12%|█▏        | 5/41 [00:23<02:49,  4.70s/it][A
 15%|█▍        | 6/41 [00:28<02:45,  4.74s/it][A
 17%|█▋        | 7/41 [00:33<02:41,  4.74s/it][A
 20%|█▉        | 8/41 [00:38<02:43,  4.94s/it][A
 22%|██▏       | 9/41 [00:43<02:35,  4.86s/it][A
 24%|██▍       | 10/41 [00:48<02:34,  4.97s/it][A
 27%|██▋       | 11/41 [00:53<02:26,  4.87s/it][A
 29%|██▉       | 12/41 [00:58<02:23,  4.93s/it][A
 32%|███▏      | 13/41 [01:02<02:15,  4.83s/it][A
 34%|███▍      | 14/41 [01:07<02:09,  4.81s/it][A
 37%|███▋      | 15/41 [01:13<02:11,  5.04s/it][A
 39%|███▉      | 16/41 [01:17<02:03,  4.95s/it][A
 41%|████▏     | 17/41 [01:22<01:57,  4.89s/it][A
 44%|████▍     | 18/41 [01:27<01:50,  4.80s/it][A
 46%|████▋     | 19/41 [01:32<01:45,  4.80s/it]


Training Loss: 1.630
Validation Loss: 1.686





In [15]:
# Restore the best-model checkpoint that training wrote to the Kaggle working directory.
# The value displayed by this cell is the result object returned by load_state_dict
# (it reports missing/unexpected keys, if any).
WEIGHTS_FILENAME = 'changed_weights_roberta.pt'
path = '/kaggle/working/' + WEIGHTS_FILENAME
model.load_state_dict(torch.load(path))

<All keys matched successfully>

In [16]:
# Score the held-out test split with the restored weights.
# eval() switches off dropout so predictions (and the report below) are deterministic;
# load_state_dict alone does not change the module's train/eval mode.
model.eval()
# Send inputs to wherever the model's parameters live (no-op if already there).
model_device = next(model.parameters()).device

EVAL_BATCH_SIZE = 64  # forward the test set in chunks instead of all rows at once, to bound memory
batch_outputs = []
with torch.no_grad():
  for seq_chunk, mask_chunk in zip(torch.split(test_seq, EVAL_BATCH_SIZE),
                                   torch.split(test_mask, EVAL_BATCH_SIZE)):
    chunk_out = model(seq_chunk.to(model_device), mask_chunk.to(model_device))
    batch_outputs.append(chunk_out.detach().cpu().numpy())

# Stack per-chunk scores back into one (n_samples, n_classes) array and take the top class.
preds = np.argmax(np.concatenate(batch_outputs, axis = 0), axis = 1)
print(classification_report(test_y, preds))

              precision    recall  f1-score   support

           0       0.29      0.40      0.34       250
           1       0.28      0.30      0.29       267
           2       0.28      0.42      0.34       249
           3       0.30      0.28      0.29       211
           4       0.25      0.05      0.08       214
           5       0.44      0.21      0.28        92

    accuracy                           0.29      1283
   macro avg       0.31      0.28      0.27      1283
weighted avg       0.29      0.29      0.27      1283



In [17]:
# Sanity-check the model on a few headlines that are not part of the LIAR dataset.
# NOTE(review): the Fake/True tags below are the headlines' source labels, but the model
# predicts one of the six LIAR classes, so they are only a rough reference. Check the
# meaning of each class index against dataset['train'].features['label'].names.
unseen_news_text = ["Donald Trump Sends Out Embarrassing New Year’s Eve Message; This is Disturbing",     # Fake
                    "WATCH: George W. Bush Calls Out Trump For Supporting White Supremacy",               # Fake
                    "U.S. lawmakers question businessman at 2016 Trump Tower meeting: sources",           # True
                    "Trump administration issues new rules on U.S. visa waivers"                          # True
                    ]

# tokenize and encode the unseen headlines
MAX_LENGTH = 100  # NOTE(review): should match the max_length used when encoding the training data — confirm
tokens_unseen = roberta_tokenizer.batch_encode_plus(
    unseen_news_text,
    max_length = MAX_LENGTH,
    padding='max_length',  # replaces the deprecated pad_to_max_length=True
    truncation=True,
    return_tensors='pt'    # return torch tensors directly instead of wrapping lists afterwards
)

unseen_seq = tokens_unseen['input_ids']
unseen_mask = tokens_unseen['attention_mask']

# Inference mode (no dropout) and inputs on the same device as the model's parameters.
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
  preds = model(unseen_seq.to(model_device), unseen_mask.to(model_device))
  preds = preds.detach().cpu().numpy()

preds = np.argmax(preds, axis = 1)
preds

array([5, 0, 0, 0])

In [18]:
# Render a download link for the saved weights. FileLink resolves the relative
# filename against the current directory, so change to the Kaggle working dir first.
import os
from IPython.display import FileLink

# Plain-Python replacement for the IPython-only `%cd` magic; print the new directory
# so the cell still echoes it like `%cd` did.
os.chdir('/kaggle/working')
print(os.getcwd())
FileLink(r'changed_weights_roberta.pt')

/kaggle/working
