# CentraleSupelec - Natural language processing
# Practical session n°7

## Natural Language Inferencing (NLI): 

(NLI) is a classical NLP (Natural Language Processing) problem that involves taking two sentences (the premise and the hypothesis ), and deciding how they are related- if the premise *entails* the hypothesis, *contradicts* it, or *neither*.

Ex: 


| Premise | Label | Hypothesis |
| --- | --- | --- |
| A man inspects the uniform of a figure in some East Asian country. | contradiction | The man is sleeping. |
| An older and younger man smiling. | neutral | Two men are smiling and laughing at the cats playing on the floor. |
| A soccer game with multiple males playing. | entailment | Some men are playing a sport. |

### Stanford NLI (SNLI) corpus

In this labwork, I propose to use the Stanford NLI (SNLI) corpus ( https://nlp.stanford.edu/projects/snli/ ), available in the *Datasets* library by Huggingface.

    from datasets import load_dataset
    snli = load_dataset("snli")
    #Removing sentence pairs with no label (-1)
    snli = snli.filter(lambda example: example['label'] != -1) 

## Subject

You are asked to provide an operational Jupyter notebook that performs the task of NLI. For that, you need to tackle the following aspects of the problem:

1. Loading and preprocessing the data
2. Designing a PyTorch model that, given two sentences, decides how they are related (*entails*, *contradicts* or *neither*.)
3. Training and evaluating the model using appropriate metrics
4. (Optional) Allowing to play with the model (forward user sentences and visualize the prediction easily)
5. (Optional) Providing visual insight about the model (i.e. visualizing the attention if your model is using attention)

Although it is not mandatory, I suggest that you use a transformer model to perform the task. For that, you can use the *Transformer* library by Huggingface.

## Evaluation

The evaluation will be based on several criteria:

- Clarity and readability of the notebook. The notebook is the report of you project. Make it easy and pleasant to read.
- Justification of implementation choices (i.e. the network, the cost funtion, the optimizer, ...)
- Quality of the code. The various deeplearning and NLP labworks provide many example of good practices for designing experiments with neural networks. Use them as inspirational examples!

## Additional recommendations

- You are not seeking to publish a research paper! I'm not expecting state-of-the-art results! The idea of this labwork is to assess that you have integrated the skills necessary to handle textual data using deep neural network techniques.

- This labwork will be evaluated but we are still here to help you! Don't hesitate to request our help if you are stuck.

- If you intend to use BERT based models, let me give you an advice. The bert-base-* models available in *Transformers* need more than 12Go to be fine-tuned on GPU. To avoid memory issues, you can use several solutions: 

    - Use a lighter BERT based model such as DistilBERT, ALBERT, ...
    - Train a classification model on top of BERT, whithout fine-tuning it (i.e. freezing BERT weights)

## Huggingface documentations

In case you want to use the huggingface *Datasets* and *Transformer* libraries (which I advice), here are some useful documentation pages:

- Dataset quick tour

    https://huggingface.co/docs/datasets/quicktour.html
    
- Documentation on data preprocessing for transformers

    https://huggingface.co/transformers/preprocessing.html
    
- Transformer Quick tour (with distilbert example for classification).

    https://huggingface.co/transformers/quicktour.html
    


## Part 0 : Imports

In [1]:
from nltk.tokenize import word_tokenize 
import pandas as pd
import os
from datasets import load_dataset
from transformers import AutoTokenizer, BertForSequenceClassification,AutoModelForSequenceClassification, DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification, BertTokenizer, BertModel
import time
import multiprocessing
from tqdm import tqdm
import torch
from torch import nn

## Part 0 bis : Variables

In [2]:
BATCH_SIZE = 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-5
epocs = 10

## Part 1 : Load data / tokenizer and preprocess


Dans cette première partie, on va télécharger le dataset et le stocker dans un dataframe pandas. Ensuite, on effectura un preprocessing des données en effectuant une tokenization des phrases des corpus.

In [3]:
snli = load_dataset("snli")
#Removing sentence pairs with no label (-1)
snli = snli.filter(lambda example: example['label'] != -1)

W0406 10:02:28.070537 140015683462976 builder.py:506] Reusing dataset snli (/usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
W0406 10:02:28.083729 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-3bdabd315eb35115.arrow
W0406 10:02:28.090463 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-65b48051e0e50eaa.arrow
W0406 10:02:28.111936 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5

In [4]:
snli

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9842
    })
})

In [5]:
snli['train']['premise'][1]

'A person on a horse jumps over a broken down airplane.'

In [6]:
snli['train']['hypothesis'][1]

'A person is at a diner, ordering an omelette.'

In [7]:
snli['train']['label'][1]

2

In [8]:
train = snli['train']
validation = snli['validation']
test = snli['test']

In [9]:
print(train.shape)
print(validation.shape)
print(test.shape)

(549367, 3)
(9842, 3)
(9824, 3)


In [10]:
from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertForSequenceClassification

bertmodel = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased",output_attentions=True,num_labels=3)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [11]:
tokenizer

PreTrainedTokenizer(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [12]:
def bert_encode(dataset):
    return(tokenizer(dataset['premise'], dataset['hypothesis'], truncation=True, padding='max_length',max_length = 110))

In [13]:
train = train.map(bert_encode, batched=True,batch_size = BATCH_SIZE, num_proc=os.cpu_count())
test = test.map(bert_encode, batched=True,batch_size = BATCH_SIZE,num_proc=os.cpu_count())
validation = validation.map(bert_encode, batch_size = BATCH_SIZE,batched=True,num_proc=os.cpu_count())

W0406 10:02:45.158911 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-1f011e0f672b052c.arrow
W0406 10:02:45.174207 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-f50f38d2decf143c.arrow
W0406 10:02:45.175216 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-40f2833bf04d4326.arrow
W0406 10:02:45.174928 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c

In [14]:
print(train[0])

{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'hypothesis': 'A person is training his horse for a competition.', 'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': 1, 'premise': 'A person on a horse jumps over a broken down airplane.'}


In [15]:
train = train.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())
validation = validation.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())
test = test.map(lambda examples: {'labels': examples['label']}, batched=True, num_proc = os.cpu_count())

W0406 10:02:53.009278 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-7b0d43643ccae4ae.arrow
W0406 10:02:53.015909 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-0093c65474077e83.arrow
W0406 10:02:53.027776 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-310ed81319acdf76.arrow
W0406 10:02:53.031366 140015683462976 arrow_dataset.py:1349] Loading cached processed dataset at /usr/users/gpusdi1/gpusdi1_26/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c

In [16]:
train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
validation.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
val_dataloader = torch.utils.data.DataLoader(validation, batch_size=BATCH_SIZE,num_workers=os.cpu_count())
test_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE,num_workers=os.cpu_count())

In [17]:
next(iter(train_dataloader))

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[ 101, 1037, 2711,  ...,    0,    0,    0],
         [ 101, 1037, 2711,  ...,    0,    0,    0],
         [ 101, 1037, 2711,  ...,    0,    0,    0],
         ...,
         [ 101, 2450, 1999,  ...,    0,    0,    0],
         [ 101, 2450, 1999,  ...,    0,    0,    0],
         [ 101, 2450, 1999,  ...,    0,    0,    0]]),
 'labels': tensor([1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 2, 1, 1, 2, 0, 1, 2, 0, 0, 2, 1, 1, 2, 0,
         2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 1, 0, 2, 0, 1, 1, 0, 2, 1, 0, 0, 0, 1, 2,
         2, 0, 1, 2, 0, 1, 2, 1, 0, 1, 2, 0, 0, 2, 1, 0])}

In [18]:
print(train['premise'][1])
print(tokenizer.decode(train['input_ids'][1]))

A person on a horse jumps over a broken down airplane.
[CLS] a person on a horse jumps over a broken down airplane. [SEP] a person is at a diner, ordering an omelette. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


Entrainons maintenant le modèle Bert large cased

## Part 2 : Train

In [19]:
def train_bert(model,clip):
    totloss=0
    model.train().to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        totloss+= loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if i % 100 == 0:
            print(f"train_loss: {loss}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg

def val(model):
    model.train().to(device)    
    model.eval()
    totloss = 0
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(val_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        totloss+=loss.item()
        if i % 100 == 0:
            print(f"Val loss: {loss:.3f}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg


In [22]:
#https://github.com/12860/NLP/blob/master/Seq2seq.py

best_validation_loss = float('inf')
clip=1
for epoch in range(epocs):
    print('Epoch : ',epoch)    
    train_loss = train_bert(bertmodel,clip)
    valid_loss = val(bertmodel)
    if valid_loss < best_validation_loss:
        print('Saving a new model')
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), '/usr/users/gpusdi1/gpusdi1_26/Documents/NLP/NLP_Natural-Language-Inferencing/dist_model.pt')
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

  0%|          | 0/8584 [00:00<?, ?it/s]

Epoch :  0


  0%|          | 1/8584 [00:00<1:50:53,  1.29it/s]

train_loss: 0.23379069566726685


  1%|          | 101/8584 [00:30<43:02,  3.28it/s]

train_loss: 0.39080458879470825


  2%|▏         | 201/8584 [01:01<42:30,  3.29it/s]

train_loss: 0.28283268213272095


  4%|▎         | 301/8584 [01:31<42:14,  3.27it/s]

train_loss: 0.45580270886421204


  5%|▍         | 401/8584 [02:01<42:01,  3.25it/s]

train_loss: 0.4035240411758423


  6%|▌         | 501/8584 [02:32<41:19,  3.26it/s]

train_loss: 0.15157850086688995


  7%|▋         | 601/8584 [03:03<41:00,  3.24it/s]

train_loss: 0.31889790296554565


  8%|▊         | 701/8584 [03:33<39:17,  3.34it/s]

train_loss: 0.33595940470695496


  9%|▉         | 801/8584 [04:04<39:49,  3.26it/s]

train_loss: 0.21622586250305176


 10%|█         | 901/8584 [04:34<39:43,  3.22it/s]

train_loss: 0.5131458044052124


 12%|█▏        | 1001/8584 [05:05<39:04,  3.23it/s]

train_loss: 0.24892859160900116


 13%|█▎        | 1101/8584 [05:35<37:21,  3.34it/s]

train_loss: 0.3187336325645447


 14%|█▍        | 1201/8584 [06:06<37:54,  3.25it/s]

train_loss: 0.4094257354736328


 15%|█▌        | 1301/8584 [06:37<37:11,  3.26it/s]

train_loss: 0.42424020171165466


 16%|█▋        | 1401/8584 [07:07<36:43,  3.26it/s]

train_loss: 0.1889568716287613


 17%|█▋        | 1501/8584 [07:37<35:26,  3.33it/s]

train_loss: 0.11261671781539917


 19%|█▊        | 1601/8584 [08:08<35:22,  3.29it/s]

train_loss: 0.2809886634349823


 20%|█▉        | 1701/8584 [08:38<35:01,  3.28it/s]

train_loss: 0.2983098328113556


 21%|██        | 1801/8584 [09:08<35:06,  3.22it/s]

train_loss: 0.26347851753234863


 22%|██▏       | 1901/8584 [09:39<34:21,  3.24it/s]

train_loss: 0.35806456208229065


 23%|██▎       | 2001/8584 [10:09<33:52,  3.24it/s]

train_loss: 0.28656429052352905


 24%|██▍       | 2101/8584 [10:39<33:15,  3.25it/s]

train_loss: 0.23085331916809082


 26%|██▌       | 2201/8584 [11:10<32:57,  3.23it/s]

train_loss: 0.416031152009964


 27%|██▋       | 2301/8584 [11:40<31:54,  3.28it/s]

train_loss: 0.21625684201717377


 28%|██▊       | 2401/8584 [12:10<31:50,  3.24it/s]

train_loss: 0.22633975744247437


 29%|██▉       | 2501/8584 [12:41<31:12,  3.25it/s]

train_loss: 0.1954478919506073


 30%|███       | 2601/8584 [13:11<30:49,  3.24it/s]

train_loss: 0.30572083592414856


 31%|███▏      | 2701/8584 [13:42<30:17,  3.24it/s]

train_loss: 0.5074304342269897


 33%|███▎      | 2801/8584 [14:12<28:54,  3.33it/s]

train_loss: 0.27139750123023987


 34%|███▍      | 2901/8584 [14:42<28:19,  3.34it/s]

train_loss: 0.46154600381851196


 35%|███▍      | 3001/8584 [15:13<28:56,  3.22it/s]

train_loss: 0.12941071391105652


 36%|███▌      | 3101/8584 [15:43<27:35,  3.31it/s]

train_loss: 0.2925505042076111


 37%|███▋      | 3201/8584 [16:14<27:38,  3.25it/s]

train_loss: 0.09802065789699554


 38%|███▊      | 3301/8584 [16:44<26:26,  3.33it/s]

train_loss: 0.24954740703105927


 40%|███▉      | 3401/8584 [17:15<26:33,  3.25it/s]

train_loss: 0.2523605525493622


 41%|████      | 3501/8584 [17:45<26:06,  3.25it/s]

train_loss: 0.5727783441543579


 42%|████▏     | 3601/8584 [18:16<25:32,  3.25it/s]

train_loss: 0.21960699558258057


 43%|████▎     | 3701/8584 [18:46<24:29,  3.32it/s]

train_loss: 0.25089022517204285


 44%|████▍     | 3801/8584 [19:16<23:56,  3.33it/s]

train_loss: 0.2519577145576477


 45%|████▌     | 3901/8584 [19:47<24:02,  3.25it/s]

train_loss: 0.2667706608772278


 47%|████▋     | 4001/8584 [20:17<23:01,  3.32it/s]

train_loss: 0.25893568992614746


 48%|████▊     | 4101/8584 [20:48<23:04,  3.24it/s]

train_loss: 0.20490750670433044


 49%|████▉     | 4201/8584 [21:18<22:02,  3.31it/s]

train_loss: 0.41373199224472046


 50%|█████     | 4301/8584 [21:48<21:21,  3.34it/s]

train_loss: 0.22858776152133942


 51%|█████▏    | 4401/8584 [22:19<21:28,  3.25it/s]

train_loss: 0.372648686170578


 52%|█████▏    | 4501/8584 [22:49<20:37,  3.30it/s]

train_loss: 0.34148073196411133


 54%|█████▎    | 4601/8584 [23:20<20:18,  3.27it/s]

train_loss: 0.35716238617897034


 55%|█████▍    | 4701/8584 [23:50<19:28,  3.32it/s]

train_loss: 0.250332236289978


 56%|█████▌    | 4801/8584 [24:20<19:16,  3.27it/s]

train_loss: 0.25036707520484924


 57%|█████▋    | 4901/8584 [24:51<18:36,  3.30it/s]

train_loss: 0.31143757700920105


 58%|█████▊    | 5001/8584 [25:21<17:58,  3.32it/s]

train_loss: 0.33451828360557556


 59%|█████▉    | 5101/8584 [25:52<17:52,  3.25it/s]

train_loss: 0.42082688212394714


 61%|██████    | 5201/8584 [26:22<16:57,  3.33it/s]

train_loss: 0.36367711424827576


 62%|██████▏   | 5301/8584 [26:53<16:51,  3.24it/s]

train_loss: 0.24696150422096252


 63%|██████▎   | 5401/8584 [27:23<15:49,  3.35it/s]

train_loss: 0.19395332038402557


 64%|██████▍   | 5501/8584 [27:54<15:41,  3.28it/s]

train_loss: 0.20873580873012543


 65%|██████▌   | 5601/8584 [28:24<15:16,  3.25it/s]

train_loss: 0.2924080193042755


 66%|██████▋   | 5701/8584 [28:55<14:26,  3.33it/s]

train_loss: 0.26274868845939636


 68%|██████▊   | 5801/8584 [29:25<14:06,  3.29it/s]

train_loss: 0.2728598713874817


 69%|██████▊   | 5901/8584 [29:55<13:43,  3.26it/s]

train_loss: 0.1568974256515503


 70%|██████▉   | 6001/8584 [30:26<13:13,  3.25it/s]

train_loss: 0.25266432762145996


 71%|███████   | 6101/8584 [30:56<12:39,  3.27it/s]

train_loss: 0.46954596042633057


 72%|███████▏  | 6201/8584 [31:27<12:12,  3.25it/s]

train_loss: 0.22389955818653107


 73%|███████▎  | 6301/8584 [31:57<11:40,  3.26it/s]

train_loss: 0.1842035949230194


 75%|███████▍  | 6401/8584 [32:28<11:12,  3.25it/s]

train_loss: 0.2841540575027466


 76%|███████▌  | 6501/8584 [32:58<10:34,  3.28it/s]

train_loss: 0.21475611627101898


 77%|███████▋  | 6601/8584 [33:29<10:11,  3.24it/s]

train_loss: 0.18583279848098755


 78%|███████▊  | 6701/8584 [33:59<09:39,  3.25it/s]

train_loss: 0.21518085896968842


 79%|███████▉  | 6801/8584 [34:29<09:11,  3.23it/s]

train_loss: 0.17485876381397247


 80%|████████  | 6901/8584 [35:00<08:38,  3.25it/s]

train_loss: 0.3247925341129303


 82%|████████▏ | 7001/8584 [35:30<07:58,  3.31it/s]

train_loss: 0.32546892762184143


 83%|████████▎ | 7101/8584 [36:00<07:23,  3.35it/s]

train_loss: 0.2432672530412674


 84%|████████▍ | 7201/8584 [36:31<07:03,  3.26it/s]

train_loss: 0.27494484186172485


 85%|████████▌ | 7301/8584 [37:01<06:33,  3.26it/s]

train_loss: 0.32086455821990967


 86%|████████▌ | 7401/8584 [37:32<06:03,  3.26it/s]

train_loss: 0.27378156781196594


 87%|████████▋ | 7501/8584 [38:02<05:34,  3.24it/s]

train_loss: 0.20876918733119965


 89%|████████▊ | 7601/8584 [38:33<04:55,  3.33it/s]

train_loss: 0.29961317777633667


 90%|████████▉ | 7701/8584 [39:03<04:24,  3.34it/s]

train_loss: 0.5068380236625671


 91%|█████████ | 7801/8584 [39:33<03:59,  3.27it/s]

train_loss: 0.2993374466896057


 92%|█████████▏| 7901/8584 [40:04<03:24,  3.34it/s]

train_loss: 0.1660275161266327


 93%|█████████▎| 8001/8584 [40:34<02:55,  3.31it/s]

train_loss: 0.3249013125896454


 94%|█████████▍| 8101/8584 [41:04<02:26,  3.31it/s]

train_loss: 0.4079590439796448


 96%|█████████▌| 8201/8584 [41:35<01:57,  3.26it/s]

train_loss: 0.24881604313850403


 97%|█████████▋| 8301/8584 [42:05<01:26,  3.27it/s]

train_loss: 0.14462189376354218


 98%|█████████▊| 8401/8584 [42:35<00:56,  3.24it/s]

train_loss: 0.28190112113952637


 99%|█████████▉| 8501/8584 [43:06<00:24,  3.34it/s]

train_loss: 0.2752790153026581


100%|██████████| 8584/8584 [43:31<00:00,  3.29it/s]
  0%|          | 0/154 [00:00<?, ?it/s]

Loss: 0.30684887012076206


  1%|▏         | 2/154 [00:00<00:56,  2.70it/s]

Val loss: 0.232


 66%|██████▌   | 102/154 [00:10<00:05, 10.04it/s]

Val loss: 0.283


100%|██████████| 154/154 [00:15<00:00,  9.74it/s]
  0%|          | 0/8584 [00:00<?, ?it/s]

Loss: 0.005505850308601643
Epoch: 01
	Train Loss: 0.307
	 Val. Loss: 0.006
Epoch :  1


  0%|          | 1/8584 [00:00<1:39:37,  1.44it/s]

train_loss: 0.2126312106847763


  1%|          | 101/8584 [00:31<42:47,  3.30it/s]

train_loss: 0.371002733707428


  2%|▏         | 201/8584 [01:01<42:35,  3.28it/s]

train_loss: 0.18701884150505066


  3%|▎         | 233/8584 [01:11<42:43,  3.26it/s]


KeyboardInterrupt: 

## Part 3 : Val

In [None]:
def val():
    model.eval()
    totloss = 0
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    for i, batch in enumerate(tqdm(val_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        total_loss+=loss.item()
        if i % 100 == 0:
            print(f"Val loss: {loss:.3f}")
    lossavg = totloss/len(train_dataloader)
    print(f"Loss: {lossavg}")
    return lossavg

In [None]:
val()