# CentraleSupelec - Natural language processing
# Practical session n°7

## Natural Language Inferencing (NLI): 

(NLI) is a classical NLP (Natural Language Processing) problem that involves taking two sentences (the premise and the hypothesis ), and deciding how they are related- if the premise *entails* the hypothesis, *contradicts* it, or *neither*.

Ex: 


| Premise | Label | Hypothesis |
| --- | --- | --- |
| A man inspects the uniform of a figure in some East Asian country. | contradiction | The man is sleeping. |
| An older and younger man smiling. | neutral | Two men are smiling and laughing at the cats playing on the floor. |
| A soccer game with multiple males playing. | entailment | Some men are playing a sport. |

### Stanford NLI (SNLI) corpus

In this labwork, I propose to use the Stanford NLI (SNLI) corpus ( https://nlp.stanford.edu/projects/snli/ ), available in the *Datasets* library by Huggingface.

    from datasets import load_dataset
    snli = load_dataset("snli")
    #Removing sentence pairs with no label (-1)
    snli = snli.filter(lambda example: example['label'] != -1) 

## Subject

You are asked to provide an operational Jupyter notebook that performs the task of NLI. For that, you need to tackle the following aspects of the problem:

1. Loading and preprocessing the data
2. Designing a PyTorch model that, given two sentences, decides how they are related (*entails*, *contradicts* or *neither*.)
3. Training and evaluating the model using appropriate metrics
4. (Optional) Allowing to play with the model (forward user sentences and visualize the prediction easily)
5. (Optional) Providing visual insight about the model (i.e. visualizing the attention if your model is using attention)

Although it is not mandatory, I suggest that you use a transformer model to perform the task. For that, you can use the *Transformer* library by Huggingface.

## Evaluation

The evaluation will be based on several criteria:

- Clarity and readability of the notebook. The notebook is the report of you project. Make it easy and pleasant to read.
- Justification of implementation choices (i.e. the network, the cost funtion, the optimizer, ...)
- Quality of the code. The various deeplearning and NLP labworks provide many example of good practices for designing experiments with neural networks. Use them as inspirational examples!

## Additional recommendations

- You are not seeking to publish a research paper! I'm not expecting state-of-the-art results! The idea of this labwork is to assess that you have integrated the skills necessary to handle textual data using deep neural network techniques.

- This labwork will be evaluated but we are still here to help you! Don't hesitate to request our help if you are stuck.

- If you intend to use BERT based models, let me give you an advice. The bert-base-* models available in *Transformers* need more than 12Go to be fine-tuned on GPU. To avoid memory issues, you can use several solutions: 

    - Use a lighter BERT based model such as DistilBERT, ALBERT, ...
    - Train a classification model on top of BERT, whithout fine-tuning it (i.e. freezing BERT weights)

## Huggingface documentations

In case you want to use the huggingface *Datasets* and *Transformer* libraries (which I advice), here are some useful documentation pages:

- Dataset quick tour

    https://huggingface.co/docs/datasets/quicktour.html
    
- Documentation on data preprocessing for transformers

    https://huggingface.co/transformers/preprocessing.html
    
- Transformer Quick tour (with distilbert example for classification).

    https://huggingface.co/transformers/quicktour.html
    


## Part 0 : Imports

In [1]:
from nltk.tokenize import word_tokenize 
import pandas as pd
import os
from datasets import load_dataset
from transformers import AutoTokenizer, BertForSequenceClassification,AutoModelForSequenceClassification, DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification, BertTokenizer, BertModel
import time
import multiprocessing
from tqdm import tqdm
import torch
from torch import nn

## Part 0 bis : Variables

In [2]:
BATCH_SIZE = 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-5
epocs = 10

## Part 1 : Load data / tokenizer and preprocess


Dans cette première partie, on va télécharger le dataset et le stocker dans un dataframe pandas. Ensuite, on effectura un preprocessing des données en effectuant une tokenization des phrases des corpus.

In [3]:
snli = load_dataset("snli")
#Removing sentence pairs with no label (-1)
snli = snli.filter(lambda example: example['label'] != -1)

W0406 08:13:22.185691 140051160676160 builder.py:506] Reusing dataset snli (/usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
W0406 08:13:22.797408 140051160676160 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-3bdabd315eb35115.arrow
W0406 08:13:22.836282 140051160676160 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-65b48051e0e50eaa.arrow
W0406 08:13:23.633090 140051160676160 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5

In [4]:
snli

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9842
    })
})

In [5]:
snli['train']['premise'][1]

'A person on a horse jumps over a broken down airplane.'

In [6]:
snli['train']['hypothesis'][1]

'A person is at a diner, ordering an omelette.'

In [7]:
snli['train']['label'][1]

2

In [8]:
train = snli['train']
validation = snli['validation']
test = snli['test']

In [9]:
print(train.shape)
print(validation.shape)
print(test.shape)

(549367, 3)
(9842, 3)
(9824, 3)


In [10]:
from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification

bertmodel = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased",output_attentions=True,num_labels=3)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [11]:
tokenizer

PreTrainedTokenizer(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [12]:
def bert_encode(dataset):
    return(tokenizer(dataset['premise'], dataset['hypothesis'], truncation=True, padding='max_length',max_length = 110))

In [13]:
train = train.map(bert_encode, batched=True,batch_size = BATCH_SIZE, num_proc=os.cpu_count())
test = test.map(bert_encode, batched=True,batch_size = BATCH_SIZE,num_proc=os.cpu_count())
validation = validation.map(bert_encode, batch_size = BATCH_SIZE,batched=True,num_proc=os.cpu_count())



























In [14]:
print(train[0])

{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'hypothesis': 'A person is training his horse for a competition.', 'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': 1, 'premise': 'A person on a horse jumps over a broken down airplane.'}


In [15]:
train = train.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())
validation = validation.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())
test = test.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())



























In [16]:
train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
validation.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
val_dataloader = torch.utils.data.DataLoader(validation, batch_size=BATCH_SIZE,num_workers=os.cpu_count())
test_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE,num_workers=os.cpu_count())

In [19]:
next(iter(train_dataloader))

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[ 101, 1037, 2711,  ...,    0,    0,    0],
         [ 101, 1037, 2711,  ...,    0,    0,    0],
         [ 101, 1037, 2711,  ...,    0,    0,    0],
         ...,
         [ 101, 2450, 1999,  ...,    0,    0,    0],
         [ 101, 2450, 1999,  ...,    0,    0,    0],
         [ 101, 2450, 1999,  ...,    0,    0,    0]]),
 'labels': tensor([1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 2, 1, 1, 2, 0, 1, 2, 0, 0, 2, 1, 1, 2, 0,
         2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 1, 0, 2, 0, 1, 1, 0, 2, 1, 0, 0, 0, 1, 2,
         2, 0, 1, 2, 0, 1, 2, 1, 0, 1, 2, 0, 0, 2, 1, 0])}

In [18]:
print(train['premise'][1])
print(tokenizer.decode(train['input_ids'][1]))

A person on a horse jumps over a broken down airplane.
[CLS] a person on a horse jumps over a broken down airplane. [SEP] a person is at a diner, ordering an omelette. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


Entrainons maintenant le modèle Bert large cased

## Part 2 : Train

In [29]:
def train_bert(model,clip):
    totloss=0
    model.train().to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        totloss+= loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if i % 100 == 0:
            print(f"train_loss: {loss}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg

def val(model):
    model.train().to(device)    
    model.eval()
    totloss = 0
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(val_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        totloss+=loss.item()
        if i % 100 == 0:
            print(f"Val loss: {loss:.3f}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg


In [28]:
#https://github.com/12860/NLP/blob/master/Seq2seq.py

best_validation_loss = float('inf')
clip=1
for epoch in range(epocs):
    print('Epoch : ',epoch)
    start_time = time.time()    
    train_loss = train_bert(bertmodel,clip)
    valid_loss= val(bertmodel)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss > best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), '/usr/users/gpusdi1/gpusdi1_26/Documents/NLP/NLP_Natural-Language-Inferencing/dist_model.pt')
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

  0%|          | 0/8584 [00:00<?, ?it/s]

Epoch :  0


  0%|          | 1/8584 [00:00<1:45:45,  1.35it/s]

train_loss: 0.14522214233875275


  1%|          | 101/8584 [00:30<43:08,  3.28it/s]

train_loss: 0.23425036668777466


  2%|▏         | 201/8584 [01:00<43:01,  3.25it/s]

train_loss: 0.1981733739376068


  4%|▎         | 301/8584 [01:31<42:17,  3.26it/s]

train_loss: 0.42086783051490784


  5%|▍         | 401/8584 [02:01<42:11,  3.23it/s]

train_loss: 0.4618408679962158


  6%|▌         | 501/8584 [02:32<40:51,  3.30it/s]

train_loss: 0.13465049862861633


  7%|▋         | 601/8584 [03:02<42:50,  3.11it/s]

train_loss: 0.31676730513572693


  8%|▊         | 701/8584 [03:33<40:25,  3.25it/s]

train_loss: 0.3315136730670929


  9%|▉         | 801/8584 [04:03<39:56,  3.25it/s]

train_loss: 0.28398969769477844


 10%|█         | 901/8584 [04:34<39:59,  3.20it/s]

train_loss: 0.5319575667381287


 12%|█▏        | 1001/8584 [05:04<38:01,  3.32it/s]

train_loss: 0.3529589772224426


 13%|█▎        | 1101/8584 [05:34<38:26,  3.24it/s]

train_loss: 0.41666850447654724


 14%|█▍        | 1201/8584 [06:05<36:42,  3.35it/s]

train_loss: 0.40517866611480713


 15%|█▌        | 1301/8584 [06:35<36:31,  3.32it/s]

train_loss: 0.5030125379562378


 16%|█▋        | 1401/8584 [07:06<36:54,  3.24it/s]

train_loss: 0.2665935754776001


 17%|█▋        | 1501/8584 [07:36<36:14,  3.26it/s]

train_loss: 0.15697376430034637


 19%|█▊        | 1601/8584 [08:07<35:38,  3.26it/s]

train_loss: 0.3297790288925171


 20%|█▉        | 1701/8584 [08:37<34:17,  3.34it/s]

train_loss: 0.338992714881897


 21%|██        | 1801/8584 [09:07<34:39,  3.26it/s]

train_loss: 0.2529921531677246


 22%|██▏       | 1901/8584 [09:38<34:12,  3.26it/s]

train_loss: 0.3599524199962616


 23%|██▎       | 2001/8584 [10:08<32:41,  3.36it/s]

train_loss: 0.28833869099617004


 24%|██▍       | 2101/8584 [10:38<32:46,  3.30it/s]

train_loss: 0.3159860074520111


 26%|██▌       | 2201/8584 [11:08<32:41,  3.25it/s]

train_loss: 0.49734583497047424


 27%|██▋       | 2301/8584 [11:38<32:09,  3.26it/s]

train_loss: 0.22605301439762115


 28%|██▊       | 2401/8584 [12:09<31:38,  3.26it/s]

train_loss: 0.2692491412162781


 29%|██▉       | 2501/8584 [12:39<30:29,  3.33it/s]

train_loss: 0.2712995409965515


 30%|███       | 2601/8584 [13:09<30:42,  3.25it/s]

train_loss: 0.33650824427604675


 31%|███▏      | 2701/8584 [13:39<29:51,  3.28it/s]

train_loss: 0.5786368250846863


 33%|███▎      | 2801/8584 [14:10<29:14,  3.30it/s]

train_loss: 0.28621914982795715


 34%|███▍      | 2901/8584 [14:40<29:10,  3.25it/s]

train_loss: 0.5617309808731079


 35%|███▍      | 3001/8584 [15:10<27:54,  3.33it/s]

train_loss: 0.1801915019750595


 36%|███▌      | 3101/8584 [15:40<28:02,  3.26it/s]

train_loss: 0.4273701310157776


 37%|███▋      | 3201/8584 [16:11<27:11,  3.30it/s]

train_loss: 0.1515745222568512


 38%|███▊      | 3301/8584 [16:41<26:39,  3.30it/s]

train_loss: 0.2576152980327606


 40%|███▉      | 3401/8584 [17:11<25:50,  3.34it/s]

train_loss: 0.26958131790161133


 41%|████      | 3501/8584 [17:42<25:50,  3.28it/s]

train_loss: 0.5718315243721008


 42%|████▏     | 3601/8584 [18:12<24:44,  3.36it/s]

train_loss: 0.31357091665267944


 43%|████▎     | 3701/8584 [18:43<24:54,  3.27it/s]

train_loss: 0.2723030149936676


 44%|████▍     | 3801/8584 [19:13<24:08,  3.30it/s]

train_loss: 0.312578946352005


 45%|████▌     | 3901/8584 [19:43<23:53,  3.27it/s]

train_loss: 0.3922831118106842


 47%|████▋     | 4001/8584 [20:13<23:11,  3.29it/s]

train_loss: 0.35326677560806274


 48%|████▊     | 4101/8584 [20:44<23:07,  3.23it/s]

train_loss: 0.31724220514297485


 49%|████▉     | 4201/8584 [21:14<22:30,  3.25it/s]

train_loss: 0.4671997129917145


 50%|█████     | 4301/8584 [21:44<21:49,  3.27it/s]

train_loss: 0.31056806445121765


 51%|█████▏    | 4401/8584 [22:15<21:27,  3.25it/s]

train_loss: 0.4410552978515625


 52%|█████▏    | 4501/8584 [22:46<20:55,  3.25it/s]

train_loss: 0.3993662893772125


 54%|█████▎    | 4601/8584 [23:16<20:16,  3.27it/s]

train_loss: 0.3992123007774353


 55%|█████▍    | 4701/8584 [23:46<19:15,  3.36it/s]

train_loss: 0.3004539906978607


 56%|█████▌    | 4801/8584 [24:17<19:25,  3.25it/s]

train_loss: 0.3292571008205414


 57%|█████▋    | 4901/8584 [24:47<18:18,  3.35it/s]

train_loss: 0.30453914403915405


 58%|█████▊    | 5001/8584 [25:17<18:23,  3.25it/s]

train_loss: 0.3773057758808136


 59%|█████▉    | 5101/8584 [25:47<17:51,  3.25it/s]

train_loss: 0.48989343643188477


 61%|██████    | 5201/8584 [26:18<17:21,  3.25it/s]

train_loss: 0.49028849601745605


 62%|██████▏   | 5301/8584 [26:48<16:49,  3.25it/s]

train_loss: 0.30512362718582153


 63%|██████▎   | 5401/8584 [27:18<16:16,  3.26it/s]

train_loss: 0.2565942108631134


 64%|██████▍   | 5501/8584 [27:49<15:49,  3.25it/s]

train_loss: 0.20433278381824493


 65%|██████▌   | 5601/8584 [28:19<15:15,  3.26it/s]

train_loss: 0.2872927784919739


 66%|██████▋   | 5701/8584 [28:49<14:40,  3.28it/s]

train_loss: 0.3647804856300354


 68%|██████▊   | 5801/8584 [29:19<14:09,  3.27it/s]

train_loss: 0.35779431462287903


 69%|██████▊   | 5901/8584 [29:50<13:23,  3.34it/s]

train_loss: 0.27364858984947205


 70%|██████▉   | 6001/8584 [30:20<13:11,  3.26it/s]

train_loss: 0.25485339760780334


 71%|███████   | 6101/8584 [30:50<12:20,  3.35it/s]

train_loss: 0.4888186454772949


 72%|███████▏  | 6201/8584 [31:20<12:06,  3.28it/s]

train_loss: 0.26375651359558105


 73%|███████▎  | 6301/8584 [31:51<11:42,  3.25it/s]

train_loss: 0.23241664469242096


 75%|███████▍  | 6401/8584 [32:21<11:09,  3.26it/s]

train_loss: 0.373620867729187


 76%|███████▌  | 6501/8584 [32:52<10:23,  3.34it/s]

train_loss: 0.30040988326072693


 77%|███████▋  | 6601/8584 [33:22<10:05,  3.27it/s]

train_loss: 0.23150545358657837


 78%|███████▊  | 6701/8584 [33:53<09:40,  3.24it/s]

train_loss: 0.3117799460887909


 79%|███████▉  | 6801/8584 [34:23<08:52,  3.35it/s]

train_loss: 0.2698381841182709


 80%|████████  | 6901/8584 [34:53<08:25,  3.33it/s]

train_loss: 0.33881625533103943


 82%|████████▏ | 7001/8584 [35:24<08:06,  3.25it/s]

train_loss: 0.382996529340744


 83%|████████▎ | 7101/8584 [35:54<07:35,  3.26it/s]

train_loss: 0.26738622784614563


 84%|████████▍ | 7201/8584 [36:24<07:06,  3.25it/s]

train_loss: 0.3100307583808899


 85%|████████▌ | 7301/8584 [36:54<06:28,  3.30it/s]

train_loss: 0.39548155665397644


 86%|████████▌ | 7401/8584 [37:25<05:53,  3.35it/s]

train_loss: 0.3277791440486908


 87%|████████▋ | 7501/8584 [37:55<05:26,  3.32it/s]

train_loss: 0.3025898337364197


 89%|████████▊ | 7601/8584 [38:25<04:54,  3.34it/s]

train_loss: 0.2561916708946228


 90%|████████▉ | 7701/8584 [38:56<04:30,  3.26it/s]

train_loss: 0.5366554856300354


 91%|█████████ | 7801/8584 [39:26<04:00,  3.26it/s]

train_loss: 0.33864399790763855


 92%|█████████▏| 7901/8584 [39:56<03:28,  3.28it/s]

train_loss: 0.27003368735313416


 93%|█████████▎| 8001/8584 [40:26<02:59,  3.25it/s]

train_loss: 0.39738425612449646


 94%|█████████▍| 8101/8584 [40:57<02:26,  3.29it/s]

train_loss: 0.41703933477401733


 96%|█████████▌| 8201/8584 [41:27<01:55,  3.32it/s]

train_loss: 0.3341292142868042


 97%|█████████▋| 8301/8584 [41:58<01:26,  3.26it/s]

train_loss: 0.18375955522060394


 98%|█████████▊| 8401/8584 [42:28<00:56,  3.27it/s]

train_loss: 0.34118711948394775


 99%|█████████▉| 8501/8584 [42:58<00:25,  3.29it/s]

train_loss: 0.37995296716690063


100%|██████████| 8584/8584 [43:24<00:00,  3.30it/s]
  0%|          | 0/154 [00:00<?, ?it/s]

Loss: 0.3508328425503559


  0%|          | 0/154 [00:00<?, ?it/s]


UnboundLocalError: local variable 'total_loss' referenced before assignment

## Part 3 : Val

In [None]:
def val():
    model.eval()
    totloss = 0
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(val_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        total_loss+=loss.item()
        if i % 100 == 0:
            print(f"Val loss: {loss:.3f}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg

In [None]:
val()