Imports


Main code resource: https://www.analyticsvidhya.com/blog/2021/05/bert-for-natural-language-inference-simplified-in-pytorch/#h2_3

In [None]:
!pip install transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import torch
import numpy as np
import pandas as pd
import pyarrow
if int(pyarrow.__version__.split('.')[1]) < 16 and int(pyarrow.__version__.split('.')[0]) == 0:
    import os
    os.kill(os.getpid(), 9)
from datasets import list_datasets, list_metrics, load_dataset_builder, load_dataset, load_metric, Dataset
from transformers import AutoTokenizer


In [None]:
SEED = 1111
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

Check for GPU

In [None]:
import tensorflow as tf

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")


There are 1 GPU(s) available.
We will use the GPU: Tesla T4


# Load MNLI Dataset

In [None]:
ds_builder = load_dataset_builder('multi_nli')
ds_builder.info.description

'The Multi-Genre Natural Language Inference (MultiNLI) corpus is a\ncrowd-sourced collection of 433k sentence pairs annotated with textual\nentailment information. The corpus is modeled on the SNLI corpus, but differs in\nthat covers a range of genres of spoken and written text, and supports a\ndistinctive cross-genre generalization evaluation. The corpus served as the\nbasis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.\n'

In [None]:
ds_builder.info.features

{'promptID': Value(dtype='int32', id=None),
 'pairID': Value(dtype='string', id=None),
 'premise': Value(dtype='string', id=None),
 'premise_binary_parse': Value(dtype='string', id=None),
 'premise_parse': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'hypothesis_binary_parse': Value(dtype='string', id=None),
 'hypothesis_parse': Value(dtype='string', id=None),
 'genre': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}

In [None]:
mnli = load_dataset('multi_nli', split='train')
mnli = pd.DataFrame(mnli)



In [None]:
mnli.head(3)

Unnamed: 0,promptID,pairID,premise,premise_binary_parse,premise_parse,hypothesis,hypothesis_binary_parse,hypothesis_parse,genre,label
0,31193,31193n,Conceptually cream skimming has two basic dime...,( ( Conceptually ( cream skimming ) ) ( ( has ...,(ROOT (S (NP (JJ Conceptually) (NN cream) (NN ...,Product and geography are what make cream skim...,( ( ( Product and ) geography ) ( ( are ( what...,(ROOT (S (NP (NN Product) (CC and) (NN geograp...,government,1
1,101457,101457e,you know during the season and i guess at at y...,( you ( ( know ( during ( ( ( the season ) and...,(ROOT (S (NP (PRP you)) (VP (VBP know) (PP (IN...,You lose the things to the following level if ...,( You ( ( ( ( lose ( the things ) ) ( to ( the...,(ROOT (S (NP (PRP You)) (VP (VBP lose) (NP (DT...,telephone,0
2,134793,134793e,One of our number will carry out your instruct...,( ( One ( of ( our number ) ) ) ( ( will ( ( (...,(ROOT (S (NP (NP (CD One)) (PP (IN of) (NP (PR...,A member of my team will execute your orders w...,( ( ( A member ) ( of ( my team ) ) ) ( ( will...,(ROOT (S (NP (NP (DT A) (NN member)) (PP (IN o...,fiction,0


Preprocess MNLI Dataset

In [None]:
test = mnli[mnli['genre'] == 'telephone']
train = mnli[mnli['genre'] != 'telephone']

In [None]:
test = test.loc[:, ['premise', 'hypothesis', 'label']]
test = test.iloc[0:1600,:]
test.head(3)

Unnamed: 0,premise,hypothesis,label
1,you know during the season and i guess at at y...,You lose the things to the following level if ...,0
4,yeah i tell you what though if you go price so...,The tennis shoes have a range of prices.,1
5,my walkman broke so i'm upset now i just have ...,I'm upset that my walkman broke and now I have...,0


In [None]:
train_df = train.loc[:, ['premise', 'hypothesis', 'label']]
train_df = train_df.iloc[0:8000,:]
train_df.head(3)

Unnamed: 0,premise,hypothesis,label
0,Conceptually cream skimming has two basic dime...,Product and geography are what make cream skim...,1
2,One of our number will carry out your instruct...,A member of my team will execute your orders w...,0
3,How do you know? All this is their information...,This information belongs to them.,0


In [None]:
train = train_df.sample(frac=.8)
val = train_df.drop(train.index)
print(train.shape)
print(val.shape)

(6400, 3)
(1600, 3)


In [None]:
def trim_sentence(sentence):
    try:
        sent = sentence.split()
        sent = sent[:128]
        return " ".join(sent)
    except:
        return sentence

In [None]:
# trim sentences
train['premise'] = train['premise'].apply(trim_sentence)
val['premise'] = val['premise'].apply(trim_sentence)
test['premise'] = test['premise'].apply(trim_sentence)
train['hypothesis'] = train['hypothesis'].apply(trim_sentence)
val['hypothesis'] = val['hypothesis'].apply(trim_sentence)
test['hypothesis'] = test['hypothesis'].apply(trim_sentence)

# add special tokens
train['premise'] = '[CLS] ' + train['premise'] + ' [SEP] '
val['premise'] = '[CLS] ' + val['premise'] + ' [SEP] '
test['premise'] = '[CLS] ' + test['premise'] + ' [SEP] '
train['hypothesis'] = train['hypothesis'] + ' [SEP]'
val['hypothesis'] = val['hypothesis'] + ' [SEP]'
test['hypothesis'] = test['hypothesis'] + ' [SEP]'

train.head(5)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test['premise'] = test['premise'].apply(trim_sentence)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test['hypothesis'] = test['hypothesis'].apply(trim_sentence)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test['premise'] = '[CLS] ' + test['premise'] + ' [SEP] '
A value is trying to be set on a

Unnamed: 0,premise,hypothesis,label
675,"[CLS] No, I can't say it did. [SEP]",There were no clues. [SEP],1
2502,"[CLS] Finally, he poured a few drops of the co...",He put drops in a test tube and sealed it. [SEP],0
1188,[CLS] But there is one place where Will's jour...,Will's articles are only good in regards to sp...,0
4642,[CLS] He has been serving as Vice Chair of the...,He was dismissed as the Board's Vice Chair aft...,2
490,"[CLS] Adrienne Worthy, executive director of L...",Federal funds are dependent upon the populous ...,0


Tokenize Dataset

In [None]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token
print(cls_token, sep_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [None]:
cls_token_idx = tokenizer.cls_token_id
sep_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id
print(cls_token_idx, sep_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [None]:
max_input_length = 256
def tokenize_bert(sentence):
    tokens = tokenizer.tokenize(sentence) 
    return tokens
def split_and_cut(sentence):
    tokens = sentence.strip().split(" ")
    tokens = tokens[:max_input_length]
    return tokens

#Get list of 0s 
def get_sent1_token_type(sent):
    try:
        return [0]* len(sent)
    except:
        return []
#Get list of 1s
def get_sent2_token_type(sent):
    try:
        return [1]* len(sent)
    except:
        return []
#combine from lists
def combine_seq(seq):
    return " ".join(seq)
#combines from lists of int
def combine_mask(mask):
    mask = [str(m) for m in mask]
    return " ".join(mask)

In [None]:
# tokenize
train['premise_token'] = train['premise'].apply(tokenize_bert)
train['hypothesis_token'] = train['hypothesis'].apply(tokenize_bert)
test['premise_token'] = test['premise'].apply(tokenize_bert)
test['hypothesis_token'] = test['hypothesis'].apply(tokenize_bert)
val['premise_token'] = val['premise'].apply(tokenize_bert)
val['hypothesis_token'] = val['hypothesis'].apply(tokenize_bert)
# token type
train['premise_token_type'] = train['premise_token'].apply(get_sent1_token_type)
train['hypothesis_token_type'] = train['hypothesis_token'].apply(get_sent2_token_type)
test['premise_token_type'] = test['premise_token'].apply(get_sent1_token_type)
test['hypothesis_token_type'] = test['hypothesis_token'].apply(get_sent2_token_type)
val['premise_token_type'] = val['premise_token'].apply(get_sent1_token_type)
val['hypothesis_token_type'] = val['hypothesis_token'].apply(get_sent2_token_type)
# combine premise & hypothesis
train['sequence'] = train['premise_token'] + train['hypothesis_token']
test['sequence'] = test['premise_token'] + test['hypothesis_token']
val['sequence'] = val['premise_token'] + val['hypothesis_token']
# attention mask
train['attention_mask'] = train['sequence'].apply(get_sent2_token_type)
test['attention_mask'] = test['sequence'].apply(get_sent2_token_type)
val['attention_mask'] = val['sequence'].apply(get_sent2_token_type)
# combine token type ids
train['token_type'] = train['premise_token_type'] + train['hypothesis_token_type']
test['token_type'] = test['premise_token_type'] + test['hypothesis_token_type']
val['token_type'] = val['premise_token_type'] + val['hypothesis_token_type']
# make sequential
train['sequence'] = train['sequence'].apply(combine_seq)
train['attention_mask'] = train['attention_mask'].apply(combine_mask)
train['token_type'] = train['token_type'].apply(combine_mask)
test['sequence'] = test['sequence'].apply(combine_seq)
test['attention_mask'] = test['attention_mask'].apply(combine_mask)
test['token_type'] = test['token_type'].apply(combine_mask)
val['sequence'] = val['sequence'].apply(combine_seq)
val['attention_mask'] = val['attention_mask'].apply(combine_mask)
val['token_type'] = val['token_type'].apply(combine_mask)
# change label

In [None]:
!mkdir data

mkdir: cannot create directory ‘data’: File exists


In [None]:
train = train.loc[:,['label', 'sequence', 'attention_mask', 'token_type']]
train.to_csv('data/train_mnli.csv', index=False)
train.head(5)

Unnamed: 0,label,sequence,attention_mask,token_type
675,1,"[CLS] no , i can ' t say it did . [SEP] there ...",1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1,0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1
2502,0,"[CLS] finally , he poured a few drops of the c...",1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 ...
1188,0,[CLS] but there is one place where will ' s jo...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
4642,2,[CLS] he has been serving as vice chair of the...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 ...
490,0,"[CLS] ad ##rien ##ne worthy , executive direct...",1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...


In [None]:
test = test.loc[:, ['label', 'sequence', 'attention_mask', 'token_type']]
test.to_csv('data/test_mnli.csv', index=False)
test.head(5)

Unnamed: 0,label,sequence,attention_mask,token_type
1,0,[CLS] you know during the season and i guess a...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
4,1,[CLS] yeah i tell you what though if you go pr...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
5,0,[CLS] my walk ##man broke so i ' m upset now i...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 ...
16,0,[CLS] well you see that on television also [SE...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1,0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
23,0,[CLS] well it ' s been very interesting [SEP] ...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1,0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1


In [None]:
val = val.loc[:, ['label', 'sequence', 'attention_mask', 'token_type']]
val.to_csv('data/val_mnli.csv', index=False)
val.head(5)

Unnamed: 0,label,sequence,attention_mask,token_type
0,1,[CLS] conceptual ##ly cream ski ##mming has tw...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 ...
9,2,[CLS] at the end of rue des francs - bourgeois...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
12,1,[CLS] it ' s not that the questions they asked...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
14,2,[CLS] i don ' t mean to be g ##lib about your ...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
17,2,[CLS] vr ##enna and i both fought him and he n...,1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ...,0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 ...


In [None]:
# To convert back attention mask and token type ids to integer.
def convert_to_int(tok_ids):
    tok_ids = [int(x) for x in tok_ids]
    return tok_ids

# Load BERT Model

In [None]:
!pip install torchtext==0.10.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from torchtext.legacy import data

TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

LABEL = data.LabelField()

ATTENTION = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = pad_token_idx)

TTYPE = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = 1)


In [None]:
fields = [('label', LABEL), ('sequence', TEXT), ('attention_mask', ATTENTION), ('token_type', TTYPE)]

In [None]:
train_data, valid_data, test_data = data.TabularDataset.splits(path='data', 
                                        train = 'train_mnli.csv',
                                        validation='val_mnli.csv',
                                        test = 'test_mnli.csv',
                                        format = 'csv',
                                        fields = fields,
                                        skip_header = True)


In [None]:
LABEL.build_vocab(train_data)
train_data_len = len(train_data)

In [None]:
BATCH_SIZE = 16

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    sort_key = lambda x: len(x.sequence),
    sort_within_batch = False, 
    device = device)


In [None]:
from transformers import BertModel
bert_model = BertModel.from_pretrained('bert-base-uncased')

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
import torch.nn as nn

class BERTNLIModel(nn.Module):
    def __init__(self,
                 bert_model,
                 hidden_dim,
                 output_dim,
                ):
        
        super().__init__()
        
        self.bert = bert_model
        
        embedding_dim = bert_model.config.to_dict()['hidden_size']

        self.out = nn.Linear(embedding_dim, output_dim)
        
        
    def forward(self, sequence, attn_mask, token_type):
                
        embedded = self.bert(input_ids = sequence, attention_mask = attn_mask, token_type_ids= token_type)[1]
        
        output = self.out(embedded)
        
        return output

In [None]:
HIDDEN_DIM = 512
OUTPUT_DIM = len(LABEL.vocab)

model = BERTNLIModel(bert_model,
                         HIDDEN_DIM,
                         OUTPUT_DIM,
                        ).to(device)

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 109,484,547 trainable parameters


In [None]:
from transformers import optimization

In [None]:
import torch.optim as optim

#optimizer = optim.Adam(model.parameters())
optimizer = optimization.AdamW(model.parameters(),lr=2e-5,eps=1e-6, correct_bias=False)

def get_scheduler(optimizer, warmup_steps):
  scheduler = optimization.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
  return scheduler




In [None]:
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
def categorical_accuracy(preds, y):
  max_preds = preds.argmax(dim = 1, keepdim = True)

  correct = (max_preds.squeeze(1)==y).float()

  return correct.sum() / len(y)

In [None]:
max_grad_norm = 1
fp16 = False

def train(model, iterator, optimizer, criterion, scheduler):
    #print(iterator)
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:

        optimizer.zero_grad() # clear gradients first
        torch.cuda.empty_cache() # releases all unoccupied cached memory 
        

        sequence = batch.sequence
        attn_mask = batch.attention_mask
        token_type = batch.token_type
        #print(sequence.size(), attn_mask.size(), token_type.size())
        #print(sequence[0])
        #print(attn_mask[0])
        #print(token_type[0])
        label = batch.label
        
        predictions = model(sequence, attn_mask, token_type)
        
        #predictions = [batch_size, 3]
        #print(predictions.size())
        
        loss = criterion(predictions, label)
        
        acc = categorical_accuracy(predictions, label)
        
        if fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
        else:
            loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [None]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
[K     |████████████████████████████████| 512 kB 5.1 MB/s 
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.0


In [None]:
from torchmetrics.classification import MulticlassCalibrationError

def evaluate(model, iterator, criterion):
    #print(iterator)
    epoch_loss = 0
    epoch_acc = 0
    epoch_ece = 0
    epoch_nll = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:
            #print(batch)

            sequence = batch.sequence
            attn_mask = batch.attention_mask
            token_type = batch.token_type
            labels = batch.label
                        
            predictions = model(sequence, attn_mask, token_type)
            
            loss = criterion(predictions, labels)
            nll_loss = nn.NLLLoss().to(device)
            m = nn.LogSoftmax(dim=1)
            nll = nll_loss(m(predictions), labels)
                
            acc = categorical_accuracy(predictions, labels)
            metric = MulticlassCalibrationError(num_classes=3, n_bins=10, norm='l1')
            ece_score = metric(predictions, labels)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
            epoch_ece += ece_score.item()
            epoch_nll += nll.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator), epoch_ece / len(iterator), epoch_nll / len(iterator)


In [None]:
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


In [None]:
import math
N_EPOCHS = 2

warmup_percent = 0.2
total_steps = math.ceil(N_EPOCHS*train_data_len*1./BATCH_SIZE)
warmup_steps = int(total_steps*warmup_percent)
scheduler = get_scheduler(optimizer, warmup_steps)

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, scheduler)
    valid_loss, valid_acc, valid_ece, valid_nll = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'bert-nli.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}% | Val. ECE: {valid_ece} | Val. NLL: {valid_nll}')



Epoch: 01 | Epoch Time: 1m 53s
	Train Loss: 0.202 | Train Acc: 93.64%
	 Val. Loss: 0.876 |  Val. Acc: 72.88% | Val. ECE: 0.2275235214829445 | Val. NLL: 0.8763128073513508
Epoch: 02 | Epoch Time: 1m 52s
	Train Loss: 0.094 | Train Acc: 97.09%
	 Val. Loss: 1.001 |  Val. Acc: 72.50% | Val. ECE: 0.24089108720421792 | Val. NLL: 1.0013126012682916


In [None]:
model.load_state_dict(torch.load('bert-nli.pt'))
test_loss, test_acc, test_ece, test_nll = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}% | Test ECE: {test_ece} | Test NLL: {test_nll}')