In [None]:
# References
# https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
# https://ai.plainenglish.io/bert-pytorch-implementation-prepare-dataset-part-1-efd259113e5a

import torch
from torch import nn
from pathlib import Path
from tokenizers import Tokenizer
from huggingface_hub import PyTorchModelHubMixin
import os
import torch
import re
import random
import transformers, datasets
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer
import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
import itertools
import math
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam
import math
from tqdm import tqdm

In [None]:
# Pick the compute device once; every module below is built with `device=device`.
# To force CPU execution (e.g. for debugging), set device = "cpu" after this block.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Show NVIDIA driver / GPU status via an IPython shell escape. Purely informational;
# presumably unavailable on machines without an NVIDIA driver (device falls back to "cpu").
!nvidia-smi

In [None]:
# Hyperparameters
SEED = 42  # seeds python `random`, numpy and torch (torch's RNG also derives DataLoader worker seeds)

n_warmup_steps = 1000  # warm-up length used by ScheduledOptim (previously 4000)
beta_1 = 0.9           # Adam beta1
beta_2 = 0.98          # Adam beta2
epsilon = 1e-9         # Adam eps
n_segments = 3         # segment ids: 0 = padding, 1 = sentence A, 2 = sentence B
block_size = 64        # tokens per training example; also the per-utterance word cap
batch_size = 64
embeddings_dims = 128
attn_dropout = 0.1
no_of_heads = 2        # must divide embeddings_dims (head size = embeddings_dims // no_of_heads)
dropout = 0.1
epochs = 20
max_lr = 2.5e-5        # initial Adam lr only: ScheduledOptim overwrites it on every step
no_of_encoder_layers = 2

assert embeddings_dims % no_of_heads == 0, "embeddings_dims must be divisible by no_of_heads"

# Make masking, NSP negative sampling, the train/val split and weight init reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
#Data

!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -qq cornell_movie_dialogs_corpus.zip
!rm cornell_movie_dialogs_corpus.zip
!mkdir datasets
!mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
!mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets





In [None]:
import ast

### loading all data into memory
corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
# Read as ISO-8859-1: every byte decodes under it, so reading never raises UnicodeDecodeError.
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

### map line id (first " +++$+++ " field) -> utterance text (last field)
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

### generate (utterance, reply) pairs from consecutive lines of each conversation;
### each side is truncated to its first `block_size` whitespace-separated words
pairs = []
for con in conv:
    # The last field is a Python list literal such as "['L194', 'L195']".
    # ast.literal_eval parses literals only; eval() would execute arbitrary code from the file.
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())
    for first_id, second_id in zip(ids, ids[1:]):
        first = lines_dic[first_id].strip()
        second = lines_dic[second_id].strip()
        pairs.append([
            ' '.join(first.split()[:block_size]),
            ' '.join(second.split()[:block_size]),
        ])
    

In [None]:
# Sanity check: the full corpus yields ~221K pairs.
assert len(pairs) > 0, "no conversation pairs were built; check ./datasets"

########## Tokenization #################
# Train a WordPiece vocabulary on the corpus text, then reload it as a Hugging Face BertTokenizer.


def clean_text(text):
    """Return ``text`` with characters that cannot be encoded as UTF-8 (e.g. lone surrogates) dropped.

    NOTE: currently not applied to the tokenizer training text below.
    """
    return text.encode('utf-8', 'ignore').decode('utf-8')


### save data as txt file: one utterance per line
# NOTE(review): only the first utterance of each pair is written, so the last line of every
# conversation is not part of the tokenizer training text — confirm this is intended.
text_data = [pair[0] for pair in pairs]
os.makedirs('./datasets', exist_ok=True)
with open('./datasets/text.txt', 'w', encoding='utf-8') as fp:
    fp.write('\n'.join(text_data))

paths = 'datasets/text.txt'

### training own tokenizer
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True,
)
tokenizer.train(
    files=paths,
    vocab_size=10000,
    min_frequency=5,
    wordpieces_prefix='##',
    # List order assigns the ids: [PAD]=0, [CLS]=1, [SEP]=2, [MASK]=3, [UNK]=4.
    # PAD must stay 0: padding_idx=0 in the embeddings and ignore_index=0 in the MLM loss rely on it.
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]'],
)

os.makedirs('./bert-it-1', exist_ok=True)
tokenizer.save_model('./bert-it-1', 'bert-it')  # writes ./bert-it-1/bert-it-vocab.txt
# Build the BertTokenizer straight from the vocab file: calling from_pretrained() with a single
# file path is deprecated in transformers. do_lower_case matches lowercase=True above.
tokenizer = BertTokenizer(vocab_file='./bert-it-1/bert-it-vocab.txt', do_lower_case=True)

# Vocabulary size; sizes the token embedding table and the MLM output layer.
vocab_size = tokenizer.vocab_size

In [None]:
class BERTDataset(Dataset):
    """BERT pre-training examples (masked LM + next-sentence prediction) built from sentence pairs.

    Every item is laid out as ``[CLS] A [SEP] B [SEP] [PAD] ...``, truncated and then RIGHT-padded
    to exactly ``seq_len`` tokens, so ``[CLS]`` is always at position 0, where the NSP head reads it.

    Args:
        data_pair: sequence of ``[first_sentence, second_sentence]`` string pairs.
        tokenizer: tokenizer exposing ``vocab`` (token -> id) and ``__call__(...)['input_ids']``;
            assumed to wrap its output in one leading and one trailing special token
            (they are stripped with ``[1:-1]``).
        seq_len: fixed length of every returned sequence.
    """

    def __init__(self, data_pair, tokenizer, seq_len=block_size):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, item):
        """Return a dict of tensors for example ``item``.

        Keys: ``bert_input_masked`` (ids after masking), ``bert_input_labels`` (original id at
        masked positions, 0 elsewhere), ``segment_ids`` (1 for sentence A incl. [CLS]/[SEP],
        2 for sentence B incl. its [SEP], 0 for padding) and ``is_next`` (1 if B follows A, else 0).
        """
        vocab = self.tokenizer.vocab
        cls_id, sep_id, pad_id = vocab['[CLS]'], vocab['[SEP]'], vocab['[PAD]']
        no_label = 0  # label for non-target positions; the MLM loss uses ignore_index=0

        sent1, sent2, is_next = self.get_nsp(item)
        sent1_masked, label1 = self.get_masked_sentences(sent1)
        sent2_masked, label2 = self.get_masked_sentences(sent2)

        part_a = [cls_id] + sent1_masked + [sep_id]
        part_b = sent2_masked + [sep_id]
        combined_sentence = part_a + part_b
        # [CLS] and both [SEP]s are never prediction targets.
        combined_labels = [no_label] + label1 + [no_label] + label2 + [no_label]
        segment_ids = [1] * len(part_a) + [2] * len(part_b)

        # Truncate (this can cut the trailing [SEP] and part of sentence B) ...
        combined_sentence = combined_sentence[:self.seq_len]
        combined_labels = combined_labels[:self.seq_len]
        segment_ids = segment_ids[:self.seq_len]

        # ... then pad on the RIGHT so [CLS] stays at index 0 for the NSP head.
        n_pad = self.seq_len - len(combined_sentence)
        combined_sentence = combined_sentence + [pad_id] * n_pad
        combined_labels = combined_labels + [no_label] * n_pad
        segment_ids = segment_ids + [0] * n_pad

        assert len(combined_sentence) == len(combined_labels) == len(segment_ids) == self.seq_len

        values = {
            'bert_input_masked': combined_sentence,
            'bert_input_labels': combined_labels,
            'segment_ids': segment_ids,
            'is_next': is_next
        }
        # Tensors so the default DataLoader collate can stack them into batches.
        return {key: torch.tensor(value) for key, value in values.items()}

    def get_nsp(self, index):
        """Return ``(sentence_a, sentence_b, is_next)`` for pair ``index``.

        With probability 0.5, B is the true reply (``is_next=1``). Otherwise B is the reply of a
        uniformly random pair from THIS dataset (``is_next=0``). That random pair can occasionally
        be ``index`` itself.
        """
        t1, t2 = self.lines[index][0], self.lines[index][1]
        if random.random() < 0.5:
            return t1, t2, 1
        return t1, self.lines[random.randrange(len(self.lines))][1], 0

    def get_masked_sentences(self, sentence):
        """Tokenize ``sentence`` and apply BERT's 15% masking scheme.

        Each token is selected with probability 0.15. A selected token is replaced by [MASK] 80%
        of the time, by a random vocabulary id 10% of the time, and kept unchanged 10% of the time.
        Returns ``(output_ids, labels)``: ``labels`` holds the original id at selected positions
        and 0 elsewhere.
        """
        # [1:-1] drops the tokenizer's own wrapper tokens; __getitem__ adds [CLS]/[SEP] itself.
        tokens = self.tokenizer(sentence)['input_ids'][1:-1]
        mask_id = self.tokenizer.vocab['[MASK]']
        vocab_len = len(self.tokenizer.vocab)
        output, mask_label = [], []

        for token in tokens:
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15  # rescale to [0, 1) to choose the replacement strategy
                if prob < 0.8:
                    output.append(mask_id)
                elif prob < 0.9:
                    output.append(random.randrange(vocab_len))
                else:
                    output.append(token)
                mask_label.append(token)
            else:
                output.append(token)
                mask_label.append(0)

        assert len(output) == len(mask_label)
        return output, mask_label
   
        
    

In [None]:
# Build the dataset and a 90/10 train/validation split.
dataset = BERTDataset(data_pair=pairs, tokenizer=tokenizer, seq_len=block_size)

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# os.cpu_count() may return None; DataLoader needs an int (0 = load in the main process).
num_workers = os.cpu_count() or 0
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=False,
    num_workers=num_workers,
    # Keep worker processes alive between iterations: val_loader is re-iterated on every periodic
    # validation pass of the training loop, which would otherwise re-spawn all workers each time.
    persistent_workers=num_workers > 0,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# The validation loss is averaged over the whole split, so order does not matter: no shuffling.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)



In [None]:
# Smoke test: pull one collated batch from the training loader and print it.
# Keys: bert_input_masked, bert_input_labels, segment_ids, is_next.
# `sample_data` is reused later as example input for the model summary.
sample_data = next(iter(train_loader))
# print('Batch Size', sample_data['bert_input_masked'].size())
print(sample_data)

# Token id 3 is [MASK] (set by the special_tokens order used when training the tokenizer).
# To inspect a single decoded example instead:
# result = dataset[random.randrange(len(dataset))]
# print(result)
# print(tokenizer.convert_ids_to_tokens(result['bert_input_masked']))

In [None]:
# Token (WordPiece) embeddings
class TextEmbeddings(nn.Module):
    """Lookup table mapping token ids to ``embeddings_dims``-dimensional vectors.

    Id 0 ([PAD]) is the padding index: its row starts at zero and receives no gradient.
    """

    def __init__(self, vocab_size=vocab_size, embeddings_dims=embeddings_dims):
        super().__init__()
        self.embeddings_table = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embeddings_dims,
            padding_idx=0,
            device=device,
        )

    def forward(self, x):
        """Embed integer ids ``x`` of any shape into ``x.shape + (embeddings_dims,)``."""
        embedded = self.embeddings_table(x)
        return embedded

In [None]:
# Segment (sentence A / sentence B) embeddings
class SegmentEmbeddings(nn.Module):
    """Embedding table for segment ids; id 0 (padding) is the padding index."""

    def __init__(self, n_segments=n_segments, embeddings_dims=embeddings_dims):
        super().__init__()
        self.seg_embds = nn.Embedding(
            n_segments,
            embeddings_dims,
            padding_idx=0,
            device=device,
        )

    def forward(self, x):
        """Embed integer segment ids ``x`` into ``x.shape + (embeddings_dims,)``."""
        embedded = self.seg_embds(x)
        return embedded

In [None]:
# Layer normalization
class LayerNormalization(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` that normalizes over the last (embedding) dimension."""

    def __init__(self, embeddings_dims=embeddings_dims):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embeddings_dims, device=device)

    def forward(self, x):
        normalized = self.layer_norm(x)
        return normalized

In [None]:
# Position-wise feed-forward network
class MLPBlock(nn.Module):
    """Two-layer MLP applied to each position: Linear(d -> 4d) -> ReLU -> Linear(4d -> d) -> Dropout."""

    def __init__(self, dropout=dropout, embeddings_size=embeddings_dims):
        super().__init__()
        hidden_size = 4 * embeddings_size
        layers = [
            nn.Linear(embeddings_size, hidden_size, device=device),
            nn.ReLU(),
            nn.Linear(hidden_size, embeddings_size, device=device),
            nn.Dropout(p=dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.mlp(x)
        return projected

In [None]:
# Single attention head
class AttentionHead(nn.Module):
    """One scaled dot-product self-attention head with an optional (padding) mask; no causal mask.

    Projects the input to queries/keys/values of size ``embeddings_dims // no_of_heads`` and
    returns ``dropout(softmax(q @ k^T / sqrt(head_size)) @ v)``.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: input of shape (..., seq_len, embeddings_dims), presumably (batch, seq_len, dims).
            mask: optional tensor broadcastable to (..., seq_len, seq_len); score positions where
                ``mask == 0`` (or False) are set to -inf before the softmax.

        Returns:
            Tensor of shape (..., seq_len, head_size).
        """
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)
        # Attention scores (..., seq_len_q, seq_len_k), scaled by 1/sqrt(head_size).
        weights = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        if mask is not None:
            weights = weights.masked_fill(mask == 0, float('-inf'))
        # Softmax over the KEY dimension: each query's attention weights sum to 1.
        weights_normalized = nn.functional.softmax(weights, dim=-1)
        out = weights_normalized @ v
        return self.dropout(out)

In [None]:
# Multi-head attention
class MHA(nn.Module):
    """Multi-head self-attention.

    Runs ``no_of_heads`` independent ``AttentionHead``s, concatenates their outputs back to
    ``embeddings_dims`` features, then applies an output projection and dropout.
    The ``mask`` constructor argument is accepted for compatibility but unused; pass the mask
    to ``forward`` instead.

    Raises:
        ValueError: if ``embeddings_dims`` is not divisible by ``no_of_heads`` (the concatenated
            head outputs would not match the projection's input width).
    """

    def __init__(
        self,
        mask = None,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        if embeddings_dims % no_of_heads != 0:
            raise ValueError(
                f"embeddings_dims ({embeddings_dims}) must be divisible by no_of_heads ({no_of_heads})"
            )
        self.heads = nn.ModuleList([AttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads) for _ in range(no_of_heads)])
        self.dropout = nn.Dropout(p = attn_dropout)
        # no_of_heads heads x (embeddings_dims // no_of_heads) features each -> embeddings_dims
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x, mask=None):
        """Apply every head to ``x`` with the same optional ``mask`` and project the concatenation."""
        concat = torch.cat([head(x, mask) for head in self.heads], dim=-1)
        return self.dropout(self.linear(concat))

In [None]:

import math
class PositionEmbeddings(nn.Module):
    """Fixed (non-trainable) sinusoidal position encodings.

    PE[pos, 2k]   = sin(pos / 10000^(2k / embeddings_dims))
    PE[pos, 2k+1] = cos(pos / 10000^(2k / embeddings_dims))

    The table is a non-persistent buffer: it follows ``model.to(device)`` but is not saved in
    ``state_dict()``.
    """

    def __init__(
        self,
        embeddings_dims = embeddings_dims,
        block_size = block_size
    ):
        super().__init__()
        positions = torch.arange(block_size, dtype=torch.float32, device=device).unsqueeze(1)   # (block_size, 1)
        even_idx = torch.arange(0, embeddings_dims, 2, dtype=torch.float32, device=device)     # 0, 2, 4, ...
        inv_freq = 1.0 / (10000.0 ** (even_idx / embeddings_dims))                              # 10000^(-2k/d)
        angles = positions * inv_freq                                                           # (block_size, ceil(d/2))

        pos_embd = torch.zeros((block_size, embeddings_dims), device=device)
        pos_embd[:, 0::2] = torch.sin(angles)
        # Odd embeddings_dims has one fewer odd column than even columns.
        pos_embd[:, 1::2] = torch.cos(angles[:, : embeddings_dims // 2])
        self.register_buffer('pos_embd', pos_embd, persistent=False)

    def forward(self, x):
        """Return encodings for the first ``x.size(1)`` positions, shaped (1, seq_len, embeddings_dims).

        ``x`` is only used for its sequence length (dim 1); the result broadcasts over the batch.
        """
        return self.pos_embd[: x.size(1)].unsqueeze(0)

In [None]:
# Encoder block (post-LayerNorm Transformer layer, no causal mask)
class TransformerEncoderBlock(nn.Module):
    """Self-attention and feed-forward sublayers, each wrapped as ``LayerNorm(x + sublayer(x))``.

    The ``mask`` constructor argument is accepted for compatibility but unused; the attention
    mask is passed to ``forward``.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
        mask=None
    ):
        super().__init__()
        self.mha = MHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.layer_norm1 = LayerNormalization(embeddings_dims=embeddings_dims)
        self.layer_norm2 = LayerNormalization(embeddings_dims=embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, x, mask):
        attended = self.layer_norm1(x + self.mha(x, mask))
        return self.layer_norm2(attended + self.mlp_block(attended))

In [None]:
# Encoder: embeddings followed by a stack of independent Transformer encoder blocks
class EncoderModel(nn.Module):
    """Token + segment + position embeddings followed by ``no_of_encoder_layers`` encoder blocks.

    Each layer is a distinct ``TransformerEncoderBlock`` held in an ``nn.ModuleList``. Layers
    therefore do not share weights, and all of them are registered with the module
    (``parameters()``, ``.to()``, ``state_dict()``). The ``mask`` constructor argument is unused;
    the padding mask is built from the input ids in ``forward``.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        no_of_encoder_layers = no_of_encoder_layers,
        vocab_size = vocab_size,
        n_segments = n_segments,
        mask=None
    ):
        super().__init__()
        self.embeddings_dims = embeddings_dims
        self.no_of_encoder_layers = no_of_encoder_layers
        self.positional_embeddings = PositionEmbeddings(block_size=block_size, embeddings_dims=embeddings_dims)
        self.text_embds = TextEmbeddings(vocab_size=vocab_size, embeddings_dims=embeddings_dims)
        self.layer_norm = LayerNormalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)
        self.seg_embds = SegmentEmbeddings(n_segments=n_segments, embeddings_dims=embeddings_dims)
        self.encoder_layer_stacked = nn.ModuleList([
            TransformerEncoderBlock(embeddings_dims=embeddings_dims, attn_dropout=attn_dropout, no_of_heads=no_of_heads, dropout=dropout)
            for _ in range(no_of_encoder_layers)
        ])

    def forward(self, x, segment_ids):
        """Encode token ids.

        Args:
            x: (batch, seq_len) token ids; id 0 is [PAD].
            segment_ids: segment ids with the same shape as ``x``.

        Returns:
            Hidden states (batch, seq_len, embeddings_dims) after a final LayerNorm.
        """
        # (batch, seq_len, seq_len) mask: every query may attend only to non-[PAD] keys.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)
        x = (self.text_embds(x) + self.seg_embds(segment_ids) + self.positional_embeddings(x)) * math.sqrt(self.embeddings_dims)
        x = self.dropout(x)
        for layer in self.encoder_layer_stacked:
            x = layer(x, mask=mask)
        return self.layer_norm(x)

In [None]:
# Next-sentence-prediction head
class NSP(nn.Module):
    """Binary classifier over the position-0 ([CLS]) hidden state: does sentence B follow A?"""

    def __init__(
        self,
        embeddings_dims = embeddings_dims,
    ):
        super().__init__()
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=2, device=device)

    def forward(self, x, isnext):
        """Return ``(loss, logits)``.

        Args:
            x: encoder output, presumably (batch, seq_len, embeddings_dims); position 0 is read.
            isnext: (batch,) integer targets, 1 = real next sentence, 0 = random sentence.

        Returns:
            Mean cross-entropy loss and logits of shape (batch, 2).
        """
        logits = self.linear_layer(x[:, 0])
        # Both classes are real targets. Using ignore_index=0 here would silently drop every
        # negative (is_next == 0) example from the loss.
        loss = nn.functional.cross_entropy(logits, isnext)
        return loss, logits

In [None]:
# Masked-language-model head
class MLM(nn.Module):
    """Projects every hidden state to vocabulary logits and scores only the masked positions."""

    def __init__(
        self,
        embeddings_dims = embeddings_dims,
        vocab_size = vocab_size
    ):
        super().__init__()
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, device=device)

    def forward(self, x, mask_labels):
        """Return ``(loss, logits)``.

        ``x`` must be 3-D (batch, seq_len, embeddings_dims). ``logits`` is returned flattened to
        (batch * seq_len, vocab_size). Label value 0 (unmasked positions and padding) is ignored
        by the loss.
        """
        scores = self.linear_layer1(x)
        _, _, n_classes = scores.shape
        flat_scores = scores.view(-1, n_classes)
        flat_labels = mask_labels.view(-1)
        loss = nn.functional.cross_entropy(flat_scores, flat_labels, ignore_index=0)
        return loss, flat_scores

In [None]:
# BERT pre-training model
class BERT(nn.Module):
    """Encoder shared by a masked-LM head and a next-sentence-prediction head.

    All size arguments, including ``vocab_size`` and ``no_of_encoder_layers``, are forwarded to
    the encoder. The embedding table and the MLM output layer therefore always agree on the
    vocabulary size.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        vocab_size = vocab_size,
        n_segments = n_segments,
        no_of_encoder_layers = no_of_encoder_layers
    ):
        super().__init__()

        self.mlm = MLM(embeddings_dims=embeddings_dims, vocab_size=vocab_size)
        self.nsp = NSP(embeddings_dims=embeddings_dims)
        self.encoder_layer = EncoderModel(
            attn_dropout=attn_dropout,
            embeddings_dims=embeddings_dims,
            no_of_heads=no_of_heads,
            no_of_encoder_layers=no_of_encoder_layers,
            block_size=block_size,
            dropout=dropout,
            vocab_size=vocab_size,
            n_segments=n_segments,
        )

    def forward(self, x, segment_ids, labels, isnext):
        """Return ``(mlm_loss, mlm_logits, nsp_loss, nsp_logits)`` for a batch of token ids ``x``."""
        hidden = self.encoder_layer(x, segment_ids)
        mlm_loss, mlm_logits = self.mlm(hidden, labels)
        nsp_loss, nsp_logits = self.nsp(hidden, isnext)
        return mlm_loss, mlm_logits, nsp_loss, nsp_logits

In [None]:
# Build the model; nn.Module.to() returns the module itself, so one assignment suffices.
model = BERT(embeddings_dims=embeddings_dims, vocab_size=vocab_size).to(device)

In [None]:
# Layer-by-layer summary of the architecture, using the batch drawn earlier as example input.
from torchinfo import summary

sample_data = {name: tensor.to(device) for name, tensor in sample_data.items()}
example_inputs = (
    sample_data['bert_input_masked'],
    sample_data['segment_ids'],
    sample_data['bert_input_labels'],
    sample_data['is_next'],
)
summary(
    model=model,
    input_data=example_inputs,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
class ScheduledOptim():
    """Wraps an optimizer and rewrites its learning rate before every step.

    Schedule ("Attention Is All You Need"):

        lr(step) = embeddings_dims**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5)

    This is a linear warm-up over ``n_warmup_steps`` steps followed by inverse-square-root decay.
    The optimizer's own initial ``lr`` is overwritten on the first step.
    """

    def __init__(self, optimizer, embeddings_dims, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(embeddings_dims, -0.5)

    def step_and_update_lr(self):
        "Advance the schedule, write the new lr into every param group, then step the optimizer."
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Clear the wrapped optimizer's gradients (set to None rather than zero-filled)."
        self._optimizer.zero_grad(set_to_none=True)

    def _get_lr_scale(self):
        step = self.n_current_steps
        decay = np.power(step, -0.5)
        warmup = np.power(self.n_warmup_steps, -1.5) * step
        return min(decay, warmup)

    def _update_learning_rate(self):
        ''' Increment the step counter and apply the scheduled lr to all param groups '''
        self.n_current_steps += 1
        new_lr = self.init_lr * self._get_lr_scale()
        for group in self._optimizer.param_groups:
            group['lr'] = new_lr


In [None]:
# Optimizer + learning-rate schedule.
# Adam's lr=max_lr is only the initial value: ScheduledOptim overwrites it before every step.
# Without the scheduler, use torch.optim.Adam(model.parameters(), lr=max_lr) and call
# optimizer.zero_grad / optimizer.step directly in the training loop.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=max_lr,
    betas=(beta_1, beta_2),
    eps=epsilon,
)
lr_scheduler = ScheduledOptim(
    optimizer=optimizer,
    embeddings_dims=embeddings_dims,
    n_warmup_steps=n_warmup_steps,
)

In [None]:
# Silence the Hugging Face `tokenizers` parallelism warning, presumably the one emitted when
# DataLoader workers are forked after the tokenizer has already been used.
# NOTE(review): the warning is best avoided by setting this before the tokenizer is first used.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [None]:

@torch.inference_mode()
def cal_val(val_loader, max_batches=None):
    """Average the global ``model``'s losses over ``val_loader``.

    Runs under ``torch.inference_mode`` with the model in eval mode (dropout off). The model's
    previous train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        val_loader: iterable of batch dicts with keys bert_input_masked, segment_ids,
            bert_input_labels and is_next; tensors are moved to the global ``device``.
        max_batches: if given, stop after this many batches (cheaper periodic validation).

    Returns:
        ``(mean_mlm_loss, mean_nsp_loss, mean_total_loss)``: unweighted means of per-batch losses.

    Raises:
        ValueError: if no batch was evaluated (empty loader or ``max_batches <= 0``).
    """
    mlm_losses, nsp_losses, total_losses = [], [], []
    was_training = model.training
    model.eval()
    try:
        for batch_idx, data in enumerate(val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            result = {key: value.to(device) for key, value in data.items()}
            mlm_loss, _, nsp_loss, _ = model(result['bert_input_masked'], result['segment_ids'], result['bert_input_labels'], result['is_next'])
            mlm_losses.append(mlm_loss.item())
            nsp_losses.append(nsp_loss.item())
            total_losses.append((mlm_loss + nsp_loss).item())
    finally:
        model.train(was_training)

    if not total_losses:
        raise ValueError("cal_val: no validation batches were evaluated")
    n_batches = len(total_losses)
    return sum(mlm_losses) / n_batches, sum(nsp_losses) / n_batches, sum(total_losses) / n_batches

In [None]:
# Training loop: joint MLM + NSP loss, scheduled lr, periodic validation.
model.train()
n_train_batches = len(train_loader)
for epoch in tqdm(range(epochs)):
    # Per-epoch running losses; the printed train means cover all steps so far in this epoch.
    loss = []
    tot_mlm_loss = []
    tot_nsp_loss = []
    for i, data in enumerate(train_loader):
        result = {key: value.to(device) for key, value in data.items()}
        mlm_loss, _, loss_nsp, _ = model(result['bert_input_masked'], result['segment_ids'], result['bert_input_labels'], result['is_next'])
        loss_tot = mlm_loss + loss_nsp

        # ScheduledOptim sets the lr for this step before stepping Adam.
        lr_scheduler.zero_grad()
        loss_tot.backward()
        lr_scheduler.step_and_update_lr()

        loss.append(loss_tot.item())
        tot_mlm_loss.append(mlm_loss.item())
        tot_nsp_loss.append(loss_nsp.item())

        # Report every 100 steps and on the last step of the epoch.
        if i % 100 == 0 or i == n_train_batches - 1:
            mean_loss = sum(loss) / len(loss)
            mean_mlm_loss = sum(tot_mlm_loss) / len(tot_mlm_loss)
            mean_loss_nsp = sum(tot_nsp_loss) / len(tot_nsp_loss)

            # Validation losses (cal_val switches to eval mode and back internally).
            val_mean_loss_mlm, val_mean_loss_nsp, val_loss = cal_val(val_loader)
            print("Steps: ",i, "Epoch: ", epoch , "Train loss: ", mean_loss, "Train NSP Loss: ", mean_loss_nsp, "Train MLM Loss: ", mean_mlm_loss, "Val loss: ", val_loss, "Val NSP Loss: ", val_mean_loss_nsp, "Val MLM Loss: ", val_mean_loss_mlm)
        