In [1]:
import requests , os , zipfile , shutil

In [2]:
# data_url = "http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"

# zip_name = data_url.split("/")[-1]
# directory = zip_name.split(".")[0]

# with open(zip_name , "wb") as f:
#     r = requests.get(data_url)
#     f.write(r.content)
    
# with zipfile.ZipFile(zip_name,"r") as zip_ref:
#     zip_ref.extractall(directory)



In [3]:
# data_dir = "dataset"
# os.makedirs(data_dir,exist_ok=True)

# shutil.move(r"cornell_movie_dialogs_corpus\cornell movie-dialogs corpus\movie_conversations.txt" , r"dataset\movie_conversations.txt")
# shutil.move(r"cornell_movie_dialogs_corpus\cornell movie-dialogs corpus\movie_lines.txt" , r"dataset\movie_lines.txt")


In [4]:
import os
from pathlib import Path
import re
import random
import numpy as np
import itertools
import math
from dataclasses import dataclass

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer
from tokenizers import BertWordPieceTokenizer

In [6]:
@dataclass
class ModelArgs:
    seq_len = 64
    batch_size = 128
    num_heads = 4
    embedding_dim = 128
    
    ffnn_units = 4*embedding_dim
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    nx = 2
    num_heads = 4
    
    dropout = 0.1

In [7]:
corpus_movie_conv = "./dataset/movie_conversations.txt"
corpus_movie_lines = "./dataset/movie_lines.txt"

with open(corpus_movie_conv,"r",encoding="iso-8859-1") as f:
    conv = f.readlines()
    
with open(corpus_movie_lines,"r",encoding="iso-8859-1") as f:
    lines = f.readlines()
    
    
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]
    
    
pairs = []
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []
        
        if i == len(ids)-1:
            break
        first = lines_dic[ids[i]].strip()
        second = lines_dic[ids[i+1]].strip()
        
        qa_pairs.append(" ".join(first.split()[:ModelArgs.seq_len]))
        qa_pairs.append(" ".join(second.split()[:ModelArgs.seq_len]))
        
        pairs.append(qa_pairs)

        

In [8]:
pairs[0]

['Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [9]:
os.makedirs("./data",exist_ok=True)
text_data = []
file_count = 0

for sample in [x[0] for x in pairs]:
    text_data.append(sample)
    
    if len(text_data)==1000:
        with open(f"./data/file_{file_count}.txt","w",encoding="utf-8") as f:
            f.write("\n".join(text_data))
        text_data = []
        file_count += 1
        
paths = [str(x) for x in Path("./data").glob('**/*.txt')]

tokenizer = BertWordPieceTokenizer(
    clean_text = True,
    handle_chinese_chars = False,
    strip_accents = False,
    lowercase = False
)

tokenizer.train(
    files=paths,
    vocab_size=30_000,
    min_frequency=5,
    limit_alphabet = 1000,
    wordpieces_prefix="##",
    special_tokens=["[PAD]","[CLS]","[SEP]","[MASK]","[UNK]"]
)


os.makedirs("./bert-it-l",exist_ok=True)
tokenizer.save_model("./bert-it-l","bert-it")


['./bert-it-l\\bert-it-vocab.txt']

In [10]:
tokenizer = BertTokenizer("./bert-it-l/bert-it-vocab.txt")

In [11]:
class BERTDataset(Dataset):
    def __init__(self,data_pair, tokenizer, seq_len=ModelArgs.seq_len):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair
        
    def __len__(self):
        return self.corpus_lines
    
    def __getitem__(self,index):
        
        # 1. get a random sentence pair either it is positive or negative represend by is_next
        t1,t2,is_next_label =self._get_sent(index)
  
  
        # 2. replace random words in a sentece with mask / random wrods      
        t1_random , t1_label = self._random_word(t1)
        t2_random , t2_label = self._random_word(t2)
        
        # 3. Add cls and sep tokens at the start and end of the sequence & add pad token for the labels
        t1 = [self.tokenizer.vocab["[CLS]"]]+t1_random+[self.tokenizer.vocab["[SEP]"]]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab["[PAD]"]]+t1_label+[self.tokenizer.vocab["[PAD]"]]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]
        
        # 4. combine sentence 1 and sentence 2
        # add pad tokens to make the input sentence equal to the seq_len
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1+t2)[:self.seq_len]
        bert_label = (t1_label+t2_label)[:self.seq_len]
        padding = [self.tokenizer.vocab["[PAD]"] for _ in range(self.seq_len - len(bert_input))]
        segment_label.extend(padding)
        bert_input.extend(padding)
        bert_label.extend(padding)
        
        output = {"bert_input":bert_input,
                  "bert_label":bert_label,
                  "segment_label":segment_label,
                  "is_next_label":is_next_label}
        
        return {key:torch.tensor(value) for key,value in output.items()}

        
    def _random_word(self,sentence):
        tokens = sentence.split()
        
        output = []
        output_label = []
        
        for i,token in enumerate(tokens):
            prob = random.random()
            
            token_id = self.tokenizer(token)["input_ids"][1:-1]
            
            if prob < 0.15:
                prob /= 0.15
                
                if prob < 0.8:
                    for i in range(len(token_id)):
                        output.append(self.tokenizer.vocab["[MASK]"])
                        
                elif prob < 0.9:
                    for i in range(len(token_id)):
                        output.append(random.randrange(len(self.tokenizer.vocab)))
            
                else:
                    output.append(token_id)
                    
                output_label.append(token_id)
            else:
                output.append(token_id)
                for i in range(len(token_id)):
                    output_label.append(0)
                    
        # flattening
        output = list(itertools.chain(*[[x] if not isinstance(x , list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x,list) else x for x in output_label]))
        
        assert len(output) == len(output_label)
        
        return output,output_label
    
    
    def _get_sent(self,index):
        t1,t2 = self._get_corpus_line(index)
        
        p = random.random()
        
        if p > 0.5:
            return t1,t2,1
        else :
            return t1,self._get_random_line(),0
    
    
    def _get_corpus_line(self,index):
        return self.lines[index][0],self.lines[index][1]
        
    def _get_random_line(self):
        "return a random sentence from the second pairs "
        return self.lines[random.randrange(len(self.lines))][1]

In [12]:
# tokenizer(pairs[0][0])["input"]

In [13]:
class PositionalEmbeddings(nn.Module):
    def __init__(self,embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.pe = torch.zeros(size=(ModelArgs.seq_len,ModelArgs.embedding_dim))
        
        
        for pos in range(self.pe.shape[0]):
            for i in range(0,embedding_dim,2):
                self.pe[pos,i] = math.sin(pos / 10_000 ** (2 * i / embedding_dim))
                self.pe[pos,i+1] = math.cos(pos / 10_000 ** (2 * i / embedding_dim))
            
        self.register_buffer("pos_encoding",self.pe)
        
        
    
    def forward(self):
        return self.pos_encoding.unsqueeze(0).expand(ModelArgs.batch_size , -1 , -1)
        

In [14]:
class BERTEmbeddings(nn.Module):
    def __init__(self,vocab_size,embedding_dim,dropout):
        super().__init__()
        
        self.token = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embedding_dim)
        self.segment = nn.Embedding(3,embedding_dim=embedding_dim)
        self.position = PositionalEmbeddings(embedding_dim=embedding_dim)
        self.dropout =  nn.Dropout(p=dropout)
        
    def forward(self,sequence,segment_labels):
        
        # print(self.token(sequence).device)
        # print(self.segment(segment_labels).device)
        # print(self.position().device)
        embeddings = self.token(sequence) + self.segment(segment_labels) + self.position()
        return embeddings

In [15]:
# row = [0,0,0,0,1,1,1]
# col = [
#     0,
#     0,
#     0,
#     1,
#     1,
#     1,
#     1
# ]


# 0&0 0&0 0&0 0&0 1&0 1&0 1&0
# 0&0 0&0 0&0 0&0 1&0 1&0 1&0
# 0&0 0&0 0&0 0&0 1&0 1&0 1&0
# 0&1 0&1 0&1 0&0 1&1 1&1 1&1
# 0&1 0&1 0&1 0&1 1&1 1&1 1&1
# 0&1 0&1 0&1 0&1 1&1 1&1 1&1
# 0&1 0&1 0&1 0&1 1&1 1&1 1&1

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embedding_dim,num_heads ,dropout):
        super().__init__()
        
        assert embedding_dim % num_heads == 0
        
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // num_heads
        self.query = nn.Linear(embedding_dim,embedding_dim)
        self.key = nn.Linear(embedding_dim,embedding_dim)
        self.value = nn.Linear(embedding_dim,embedding_dim)
        self.out_proj = nn.Linear(embedding_dim,embedding_dim)
        self.dropout = nn.Dropout(p=dropout)    
        
    def forward(self,query,key,value,mask=None):
        # query -> [batch_size , seq_len , embedding_dim]
        # key -> [batch_size , seq_len , embedding_dim]
        # value -> [batch_size , seq_len , embedding_dim]
        
        # mask -> [batch_size , seq_len]
        
        def shape(X):
            return  X.view(ModelArgs.batch_size , self.num_heads , ModelArgs.seq_len , self.head_dim)
        
        query = shape(self.query(query)) # [batch_size , num_heads , seq_len , head_dim]
        key = shape(self.key(key)) # [batch_size , num_heads , seq_len , head_dim]
        value = shape(self.value(value)) # [batch_size , num_heads , seq_len , head_dim]
        
        scores = torch.matmul(query,value.transpose(-2,-1)) / (self.head_dim ** 0.5) # [batch_size , num_heads , seq_len(query) , seq_len(key)]
        
        if mask is not None:
            q_mask = mask.unsqueeze(2) # [batch_size , seq_len , 1]
            k_mask = mask.unsqueeze(1) # [batch_size , 1 , seq_len]
            
            full_mask = q_mask | k_mask 
            identiy = torch.eye(n=ModelArgs.seq_len , device=ModelArgs.device)
            full_mask = full_mask + identiy # [batch_size , seq_len , seq_len ]
            # print(full_mask.shape) # [batch_size , seq_len , seq_len ]
            # print(scores.shape) #  [batch_size , num_heads , seq_len(query) , seq_len(key)]
            full_mask = full_mask.unsqueeze(1).expand(-1,self.num_heads,-1,-1)
            scores = scores.masked_fill(full_mask==1,float('-inf'))
        
        attn = torch.softmax(scores,dim=-1) # [batch_size , num_heads , seq_len , seq_le]
        attn = self.dropout(attn) # [batch_size , num_heads , seq_len , seq_le]
        
        out = torch.matmul(attn,value) # [batch_size , num_heads , seq_len , seq_len] @ [batch_size , num_heads , seq_len , head_dim] = [batch_size , num_heads , seq_len , head_dim]
        
        out = out.transpose(-2,-1).contiguous().view(ModelArgs.batch_size,ModelArgs.seq_len,ModelArgs.embedding_dim)
        
        return self.out_proj(out) # use for how to adjust the context from different heads
        
        
        
        

In [17]:
class Norm(nn.Module):
    def __init__(self,embedding_dim):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        
    def forward(self,embeddings):
        return self.norm(embeddings)

In [18]:
class AddResidual(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,X1,X2):
        return X1+X2
        

In [19]:
class FeedForwardNeuralNetwork(nn.Module):
    def __init__(self,embedding_dim,dropout):
        super().__init__()
        self.fn = nn.Sequential(
            nn.Linear(embedding_dim,ModelArgs.ffnn_units),
            nn.GELU(),
            nn.Linear(ModelArgs.ffnn_units,embedding_dim)
        )
        self.dropout = nn.Dropout(p=dropout)
    def forward(self,embeddings):
        embeddings = self.fn(embeddings)
        embeddings = self.dropout(embeddings)
        
        return embeddings

In [20]:
class EncoderBlock(nn.Module):
    def __init__(self,embedding_dim,num_heads,dropout):
        super().__init__()
        self.norm1 = Norm(embedding_dim=embedding_dim)
        self.multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim,num_heads=num_heads,dropout=dropout)
        self.add = AddResidual()
        self.feed_forward_neural_network = FeedForwardNeuralNetwork(embedding_dim=embedding_dim,dropout=dropout)
        self.norm2 = Norm(embedding_dim=embedding_dim)
        
    def forward(self,embeddings,mask):
        # pre normalizatiom
        
        embeddings_norm = self.norm1(embeddings)
        embeddings_norm_mha = self.multi_head_attention(query=embeddings_norm,key=embeddings_norm,value=embeddings_norm , mask=mask)
        embeddings_norm_mha_add = self.add(embeddings_norm_mha,embeddings)
        
        embeddings_norm_mha_add_norm = self.norm2(embeddings_norm_mha_add)
        embeddings_norm_mha_add_norm_fn = self.feed_forward_neural_network(embeddings_norm_mha_add_norm)
        embeddings_norm_mha_add_norm_fn_add = self.add(embeddings_norm_mha_add,embeddings_norm_mha_add_norm_fn)
        
        return embeddings_norm_mha_add_norm_fn_add

In [21]:
class BERT(nn.Module):
    def __init__(self,nx,vocab_size,embedding_dim,num_heads,dropout):
        super().__init__()
        self.bert_embeddings = BERTEmbeddings(vocab_size=vocab_size,embedding_dim=embedding_dim,dropout=dropout)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(embedding_dim=embedding_dim,num_heads=num_heads,dropout=dropout) for _ in range(nx)])
        self.apply(self._init_weights)  
        
    def _init_weights(self,module):
        if isinstance(module,nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module,nn.Embedding):
            nn.init.xavier_normal_(module.weight)
        elif isinstance(module,nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        
    def forward(self,sequence,segment_labels):
        # sequnce -> [batch_size , seq_len ]
        batch_size, seq_len = sequence.shape
        mask = torch.zeros((batch_size, seq_len), device=sequence.device)

        mask = (sequence == 0).long()  # Automatically sets 1 where padding token (id=0), else 0
                
        embeddings = self.bert_embeddings(sequence,segment_labels)
        
        for encoder in self.encoder_blocks:
            embeddings = encoder(embeddings,mask)
            
        return embeddings
        

In [22]:
class NextSentencePrediction(nn.Module):
    """
    2-class classification : is_next
    """
    def __init__(self,hidden):
        super().__init__()
        self.linear = nn.Linear(hidden,2)
        self.log_softmax = nn.LogSoftmax(dim=-1)
        
    def forward(self,X):
        # print("NSP")
        # print(X.shape)
        # X -> [batch_size , seq_len , embedding_dim]
        # use the [CLS] toke for the classification head
        X = self.linear(X[:,0,:])
        return self.log_softmax(X)

In [23]:
class MaskedLanguageModel(nn.Module):
    def __init__(self,hidden,vocab_size):
        super().__init__()
        self.linear = nn.Linear(hidden,vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=-1)
    def forward(self,X):
        X = self.linear(X)
        return self.log_softmax(X)

In [24]:
class BERTLM(nn.Module):
    def __init__(self,bert,embedding_dim,vocab_size):
        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(hidden=embedding_dim)
        self.masked_lm = MaskedLanguageModel(hidden=embedding_dim,vocab_size=vocab_size)
        
    def forward(self,sequence,segment_label):
        embeddings = self.bert(sequence,segment_label) # [batch_size , seq_len , embedding_dim]
        
        nsp = self.next_sentence(embeddings)
        mlm = self.masked_lm(embeddings)
        
        return nsp,mlm

In [25]:
class SchduledOptim():
    """A simple wrapper class for learning rate scheduling"""
    def __init__(self,optimizer,d_model,warmup_steps , base_lrs =None):
        self.optimizer = optimizer
        self.init_lr = math.pow(d_model,-0.5)
        self.warmup_steps = warmup_steps
        self.current_step = 0
        
        if base_lrs is None:
            self.base_lrs = [group['lr'] for group in self.optimizer.param_groups]
        else:
            self.base_lrs = base_lrs
        
    def zero_grad(self):
        self.optimizer.zero_grad()    
    
    def step_and_update_lr(self):
        self.optimizer.step()
        self.update_lr()
        
    def update_lr(self):
        self.current_step += 1
        lr_scale =  self.get_lr_scale()
        
        # update the lr in optimizer
        for base_lr , param_group in zip(self.base_lrs,self.optimizer.param_groups):
            param_group['lr'] = base_lr * lr_scale
        
    def get_lr_scale(self):
        return min(
            math.pow(self.current_step,-0.5),
            self.current_step*math.pow(self.warmup_steps,-1.5)
        )

In [26]:
from collections import defaultdict
class BERTTrainer():
    def __init__(self,
                 model,
                 embedding_dim,
                 vocab_size,
                 warmup_steps,
                 train_dataloader,
                 val_dataloader,
                 epochs,
                 lr,
                 min_val_loss,
                 betas,
                 weight_decay,
                 layer_decay,
                 wandb,
                 device):
        self.vocab_size = vocab_size
        self.bert = model.to(device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.epochs = epochs
        self.device = device
        self.min_val_loss = min_val_loss
        
        self.wandb = wandb
        
        param_groups = self.get_layerwise_lr_params(self.bert , base_lr=lr , layer_decay= layer_decay)
        base_lrs = [group['lr'] for group in param_groups]
        self._optimizer = torch.optim.Adam(param_groups,betas=betas,weight_decay=weight_decay,eps=1e-8)
        self.scheduled_optimizer = SchduledOptim(optimizer=self._optimizer,d_model=embedding_dim,warmup_steps=warmup_steps,base_lrs=base_lrs)
        self.criterion = nn.NLLLoss(ignore_index=0)
        
        
        
    def get_layerwise_lr_params(self,model, base_lr, layer_decay=0.95):
        """
        Returns param groups with exponentially decayed learning rates per layer.
        """
        param_groups = []
        assigned = set()

        # Embeddings
        param_groups.append({
            "params": list(model.bert.bert_embeddings.parameters()),
            "lr": base_lr * (layer_decay ** 0)
        })
        assigned |= set(model.bert.bert_embeddings.parameters())

        # Encoder layers
        for i, layer in enumerate(model.bert.encoder_blocks):
            layer_params = list(layer.parameters())
            param_groups.append({
                "params": layer_params,
                "lr": base_lr * (layer_decay ** (i + 1))  # deeper layers get smaller lr
            })
            assigned |= set(layer_params)

        # Any remaining (top) layers like NSP or MLM head
        remaining_params = [p for p in model.parameters() if p not in assigned]
        if remaining_params:
            param_groups.append({
                "params": remaining_params,
                "lr": base_lr  # full LR for heads
            })

        return param_groups

    
    def train_and_evaluate(self,log_grad_norm):
        from tqdm import tqdm
        best_val_loss = float('inf')
        for epoch in range(self.epochs):
            train_loss,correct,total = 0.0,0,0
            train_progress = tqdm(self.train_dataloader,desc="Training")
            for i,data in enumerate(train_progress):
                data = {key:value.to(self.device) for key,value in data.items()}
                
                # bert_input : [batch_size , seq_len]
                # bert_label : [batch_size , seq_len]
                # segment_label : [batch_size , seq_len]
                # is_next_label : a scalar value
                
                nsp_output , mlm_output = self.bert(data["bert_input"],data["segment_label"]) #  [batch_size , 2] ,  [batch_size , seq_len ,vocab_size ] 
                
                nsp_loss = self.criterion(nsp_output,data["is_next_label"])
                
                # mlm_output = [batch_size , seq_len , vocab_size] vs bert_label = [batch_size , seq_len] 
                mlm_output = mlm_output.view(-1,self.vocab_size)
                data["bert_label"] = data["bert_label"].view(-1)
                
                mlm_loss = self.criterion(mlm_output,data['bert_label'])
                loss = nsp_loss + mlm_loss
                ###########################################################################################################################
                self.scheduled_optimizer.zero_grad()
                self._optimizer
                loss.backward()
                # --------------------------------------------------------------------------------------------------------------------------
                torch.nn.utils.clip_grad_norm_(self.bert.parameters(),max_norm=1.0)
                
                if log_grad_norm:
                    total_norm = 0
                    for p in self.bert.parameters():
                        if p.grad is not None:
                            param_norm = p.grad.data.norm(2) # L2 norm
                            total_norm = param_norm.item()**2
                            
                    

                    total_norm = total_norm**0.5
                    
                    self.wandb.log({"norm":total_norm})

                    grad_groups = defaultdict(list)

                    for name, param in self.bert.named_parameters():
                        if param.grad is None:
                            continue
                        # print(f"name : {name} grad_norm : {param.grad.norm().item():.6f}")
                        # Get group name: e.g., "encoder.0", "decoder.1", "embedding", etc.
                        tokens = name.split('.')
                        group_name = '.'.join(tokens[:3]) if len(tokens) >= 3 else '.'.join(tokens[:2])
                
                        grad_norm = param.grad.norm().item()
                        grad_groups[group_name].append(grad_norm)
                
                    # Average gradients per group
                    avg_grad_per_group = {k: sum(v)/len(v) for k, v in grad_groups.items()}
                    # ----------------------------------------------------------------------------------------------------------------------
                    self.wandb.log(avg_grad_per_group)
                    self.wandb.log({"total_norm":total_norm})
                    
                    self.scheduled_optimizer.step_and_update_lr()
                ######################################################################################################################
                
                preds = torch.argmax(nsp_output,dim=-1)
                correct += (preds == data["is_next_label"]).sum().item()
                total += ModelArgs.batch_size
                
                preds = torch.argmax(mlm_output,dim=-1) # [batch_size * seq_len]
                correct += (preds == data["bert_label"]).sum().item()
                total += ModelArgs.batch_size * ModelArgs.seq_len
                
                train_loss += loss.item()
                
                train_progress.set_postfix({"nsp_loss":f"{nsp_loss.item():.4f}","mlm_loss":f"{mlm_loss.item():.4f}","loss":f"{loss.item():.4f}"})
                
                self.wandb.log({"train_loss":train_loss/(i+1),"train_acc":correct/total})
            
            train_loss /= len(self.train_dataloader)
            train_acc = correct / total
                
            # with torch.inference_mode():
            #     val_loss,correct,total = 0.0,0,0
            #     val_progress = tqdm(self.val_dataloader)
            #     for i,data in enumerate(val_progress):
            #         data = {key:value.to(self.device) for key,value in data.items()}
                    
            #         # bert_input : [batch_size , seq_len]
            #         # bert_label : [batch_size , seq_len]
            #         # segment_label : [batch_size , seq_len]
            #         # is_next_label : a scalar value
                    
            #         nsp_output , mlm_output = self.bert(data["bert_input"],data["segment_label"])  #  [batch_size , 2] ,  [batch_size , seq_len ,vocab_size ] 
                    
            #         nsp_loss = self.criterion(nsp_output,data["is_next_label"])
                    
            #         mlm_loss = self.criterion(mlm_output,data["bert_label"])
                    
            #         loss = nsp_loss + mlm_loss
                    
            #         preds = torch.argmax(nsp_output,dim=-1)
            #         correct += (preds == data["is_next_label"]).sum().item()
            #         total += ModelArgs.batch_size
                    
            #         preds = torch.argmax(mlm_output,dim=-1).view(-1)
            #         correct += (preds == data["bert_label"].view(-1)).sum().item()
            #         total += ModelArgs.batch_size*ModelArgs.seq_len
                    
            #         val_progress.set_postfix({"nsp_loss":f"{nsp_loss.item():.4f}","mlm_loss":f"{mlm_loss.item():.4f}","loss":f"{loss.item():.4f}"})
                    
            #         self.wandb.log({"val_loss":f"{val_loss/i+1}","val_acc":f"{correct/total}"})
                
            #     val_loss /= len(self.val_dataloader)
            #     val_acc = correct / total
                
                
            print(f"Epoch : {epoch+1}/{self.epochs} \n train_loss : {train_loss:.5f} train_acc : {train_acc:.5f} \n")
                
            if train_loss < best_val_loss:
                best_val_loss = train_loss
                torch.save(self.bert.state_dict(),f"bert_scratch_{train_loss:.4f}.pth")
            
            if train_loss < self.min_val_loss:
                print("[SUCCESS] model trained successfully")
                break
                    
                
                
                
                 
                

In [27]:
train_dataset = BERTDataset(data_pair=pairs,tokenizer=tokenizer)

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=ModelArgs.batch_size,
                              shuffle=True,
                              drop_last=True)

In [28]:
sample = next(iter(train_dataloader))

In [29]:
for key,value in sample.items():
    print(f"key : {key} value : {value.shape}")

key : bert_input value : torch.Size([128, 64])
key : bert_label value : torch.Size([128, 64])
key : segment_label value : torch.Size([128, 64])
key : is_next_label value : torch.Size([128])


In [30]:
bert = BERT(nx=ModelArgs.nx , vocab_size=len(tokenizer.vocab) , embedding_dim=ModelArgs.embedding_dim , num_heads=ModelArgs.num_heads , dropout= ModelArgs.dropout) 

In [31]:
bert = bert.to(ModelArgs.device)
bert(sample["bert_input"].to(ModelArgs.device),sample["segment_label"].to(ModelArgs.device)).shape

torch.Size([128, 64, 128])

In [32]:
bert_lm = BERTLM(bert=bert,embedding_dim=ModelArgs.embedding_dim , vocab_size=len(tokenizer.vocab))
bert_lm = bert_lm.to(ModelArgs.device)
sample_out = bert_lm(sample["bert_input"].to(ModelArgs.device),sample["segment_label"].to(ModelArgs.device))
sample_out[0].shape,sample_out[1].shape

(torch.Size([128, 2]), torch.Size([128, 64, 24864]))

In [33]:
from torchinfo import summary

summary(model=bert,
        input_data=(sample["bert_input"].to(ModelArgs.device),sample["segment_label"].to(ModelArgs.device)),
        col_names=["input_size","output_size","num_params","trainable"],
        col_width=20,
        row_settings=["var_names"],
        device=ModelArgs.device)

Layer (type (var_name))                                                Input Shape          Output Shape         Param #              Trainable
BERT (BERT)                                                            [128, 64]            [128, 64, 128]       --                   True
├─BERTEmbeddings (bert_embeddings)                                     [128, 64]            [128, 64, 128]       --                   True
│    └─Embedding (token)                                               [128, 64]            [128, 64, 128]       3,182,592            True
│    └─Embedding (segment)                                             [128, 64]            [128, 64, 128]       384                  True
│    └─PositionalEmbeddings (position)                                 --                   [128, 64, 128]       --                   --
├─ModuleList (encoder_blocks)                                          --                   --                   --                   True
│    └─EncoderBlock (0) 

In [34]:
import dotenv
import wandb
dotenv.load_dotenv()
wandb.login(key=os.environ["WANDB_API_KEY"])

wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\janar\_netrc
wandb: Currently logged in as: janardhanthippabattini (janardhanthippabattini-rajiv-gandhi-university-of-knowle) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [35]:
bert_lm.load_state_dict(torch.load(r"bert_scratch_7.8702.pth",weights_only=True))

<All keys matched successfully>

In [None]:
for

SyntaxError: invalid syntax (3193057967.py, line 1)

In [37]:
exp_name = f"bert_scratch-nsp&mlm layer_wise_learning_rate_decay=True,weight_init=normal,lr=1e-4,nx={ModelArgs.nx},heads={ModelArgs.num_heads},d_model={ModelArgs.embedding_dim},dropout={ModelArgs.dropout},scheduler=reserch_paper,optimizer=adam,seq_len={ModelArgs.seq_len},batch_size={ModelArgs.batch_size},warmup_steps=4_000,clip_grad_norm=True"

In [42]:
wandb.init(
    project="bert_scratch-nsp&mlm",
    name=exp_name,
    # reinit=True,
    id="rajnz4f6",
    resume="allow"
)

0,1
bert.bert_embeddings.segment,▆▆▆▅▄▅▇▅▆▇▅▇▄▆▅▆█▅▆▂▅▅▇▃▇▅▆▅▇▁▆▆▄▅▇▆▄▅▄▇
bert.bert_embeddings.token,▅▆▁▇▆▆▄▅█▄▅▆█▃▄▆▃▇▂▄▃▅▇▅▆▂▃▃▂▇▃▆▄█▂▇▁▆▆▇
bert.encoder_blocks.0,▆▆▆▄▅▅▆▄▄▇▄▂▅▅▁▂▅▅▁▄▁▄▃▅▆▇▃█▅▄▇▆▆▃▂▃▆▃▄▅
bert.encoder_blocks.1,▇▅▆▇▅▂▅▇▇█▇█▆▅▆▆▇█▄▇▆▃▇▇▆▄█▁██▆▇▆▃▄▇▅▅▇▆
masked_lm.linear.bias,▃▃▅▂▅▃▃▃▁▁▃▁█▇▂▄▃▃▁▃▃▃▇▃▁▄▂▄▄▂▄▂▃▂▂▄▅▆▃▂
masked_lm.linear.weight,▃▄▃▂▃▆▃▂▅▂▅▄▄▂▅█▆▃▄▅▄█▄▁▄▅▄▆▁▂▄▂▆▇▃▃▅▄▅▄
next_sentence.linear.bias,▃▂▂▆▃▅█▃▄▇▇▆▅▁▃▆▄▆▄▆▃▆▇▃▃▅▃▅▇▃▅▇▃▂▃▃▃▇▇▄
next_sentence.linear.weight,█▆▄▄▇▇▇▇▅▄▄▅▆▂▁▇▆▅▅▅▅▆▆▆▇█▅█▄▇▆▅▃▃▄▃▇▅▄▆
norm,▃▃▂▃▂▃▂▂▃▃▅▃▃▂▁▇▅▂▃▄▄▅▅▄▄▂▁▄█▃▂▄▄▃▃▃▅▅▅▄
total_norm,▃▃▁▂▅▄▁▂▂▂▄▅▄▃▁▁▂▆▅▃▅▄▃▃▃▅█▂▁▅▃▅▂▃▂▄▃▃▃▂

0,1
bert.bert_embeddings.segment,0.04294
bert.bert_embeddings.token,0.03797
bert.encoder_blocks.0,0.07035
bert.encoder_blocks.1,0.06405
masked_lm.linear.bias,0.03762
masked_lm.linear.weight,0.83514
next_sentence.linear.bias,0.00227
next_sentence.linear.weight,0.04674
norm,0.03762
total_norm,0.03762


In [39]:
14*1731

24234

In [43]:
trainer = BERTTrainer(model=bert_lm,
            embedding_dim=ModelArgs.embedding_dim,
            vocab_size=tokenizer.vocab_size,
            warmup_steps=4_000,
            train_dataloader=train_dataloader,
            val_dataloader=None,
            epochs=200,
            lr=1e-4,
            min_val_loss=1e-2,
            betas=(0.9,0.999),
            weight_decay=0.1,
            layer_decay=0.95,
            wandb=wandb,    
            device=ModelArgs.device
            )

trainer.scheduled_optimizer.current_step = 24234




In [44]:
trainer.train_and_evaluate(log_grad_norm=True)

Training: 100%|██████████| 1731/1731 [14:37<00:00,  1.97it/s, nsp_loss=0.0026, mlm_loss=7.7095, loss=7.7120]


Epoch : 1/200 
 train_loss : 7.71548 train_acc : 0.01441 



Training: 100%|██████████| 1731/1731 [16:16<00:00,  1.77it/s, nsp_loss=0.0025, mlm_loss=7.3963, loss=7.3988]


Epoch : 2/200 
 train_loss : 7.60933 train_acc : 0.01437 



Training: 100%|██████████| 1731/1731 [09:55<00:00,  2.91it/s, nsp_loss=0.0021, mlm_loss=7.4242, loss=7.4263]


Epoch : 3/200 
 train_loss : 7.50741 train_acc : 0.01440 



Training: 100%|██████████| 1731/1731 [11:33<00:00,  2.50it/s, nsp_loss=0.0021, mlm_loss=7.4698, loss=7.4719]


Epoch : 4/200 
 train_loss : 7.42131 train_acc : 0.01439 



Training: 100%|██████████| 1731/1731 [09:37<00:00,  2.99it/s, nsp_loss=0.0018, mlm_loss=7.3663, loss=7.3681]


Epoch : 5/200 
 train_loss : 7.34176 train_acc : 0.01446 



Training: 100%|██████████| 1731/1731 [09:30<00:00,  3.03it/s, nsp_loss=0.0018, mlm_loss=7.3366, loss=7.3383]


Epoch : 6/200 
 train_loss : 7.27010 train_acc : 0.01450 



Training: 100%|██████████| 1731/1731 [09:27<00:00,  3.05it/s, nsp_loss=0.0016, mlm_loss=7.1562, loss=7.1578]


Epoch : 7/200 
 train_loss : 7.20579 train_acc : 0.01449 



Training: 100%|██████████| 1731/1731 [09:24<00:00,  3.06it/s, nsp_loss=0.0014, mlm_loss=7.0495, loss=7.0509]


Epoch : 8/200 
 train_loss : 7.14859 train_acc : 0.01457 



Training: 100%|██████████| 1731/1731 [09:17<00:00,  3.10it/s, nsp_loss=0.0014, mlm_loss=7.1209, loss=7.1224]


Epoch : 9/200 
 train_loss : 7.09846 train_acc : 0.01455 



Training: 100%|██████████| 1731/1731 [20:53<00:00,  1.38it/s, nsp_loss=0.0013, mlm_loss=6.8906, loss=6.8919]


Epoch : 10/200 
 train_loss : 7.04691 train_acc : 0.01451 



Training: 100%|██████████| 1731/1731 [24:59<00:00,  1.15it/s, nsp_loss=0.0014, mlm_loss=7.2688, loss=7.2702]


Epoch : 11/200 
 train_loss : 7.00420 train_acc : 0.01451 



Training: 100%|██████████| 1731/1731 [09:17<00:00,  3.11it/s, nsp_loss=0.0013, mlm_loss=6.7739, loss=6.7753]


Epoch : 12/200 
 train_loss : 6.96027 train_acc : 0.01456 



Training: 100%|██████████| 1731/1731 [09:22<00:00,  3.08it/s, nsp_loss=0.0013, mlm_loss=7.2829, loss=7.2841]


Epoch : 13/200 
 train_loss : 6.92983 train_acc : 0.01455 



Training: 100%|██████████| 1731/1731 [09:32<00:00,  3.02it/s, nsp_loss=0.0013, mlm_loss=6.7767, loss=6.7780]


Epoch : 14/200 
 train_loss : 6.89251 train_acc : 0.01454 



Training: 100%|██████████| 1731/1731 [09:27<00:00,  3.05it/s, nsp_loss=0.0013, mlm_loss=6.7933, loss=6.7946]


Epoch : 15/200 
 train_loss : 6.84945 train_acc : 0.01459 



Training: 100%|██████████| 1731/1731 [09:26<00:00,  3.05it/s, nsp_loss=0.0013, mlm_loss=6.5736, loss=6.5749]


Epoch : 16/200 
 train_loss : 6.82113 train_acc : 0.01455 



Training: 100%|██████████| 1731/1731 [09:19<00:00,  3.09it/s, nsp_loss=0.0012, mlm_loss=6.7428, loss=6.7440]


Epoch : 17/200 
 train_loss : 6.79331 train_acc : 0.01460 



Training: 100%|██████████| 1731/1731 [10:38<00:00,  2.71it/s, nsp_loss=0.0011, mlm_loss=6.8537, loss=6.8548]


Epoch : 18/200 
 train_loss : 6.76502 train_acc : 0.01461 



Training: 100%|██████████| 1731/1731 [11:26<00:00,  2.52it/s, nsp_loss=0.0011, mlm_loss=6.5662, loss=6.5673]


Epoch : 19/200 
 train_loss : 6.74143 train_acc : 0.01454 



Training: 100%|██████████| 1731/1731 [12:24<00:00,  2.33it/s, nsp_loss=0.0011, mlm_loss=6.8324, loss=6.8335]


Epoch : 20/200 
 train_loss : 6.72159 train_acc : 0.01461 



Training: 100%|██████████| 1731/1731 [1:30:01<00:00,  3.12s/it, nsp_loss=0.0012, mlm_loss=6.8174, loss=6.8186]      


Epoch : 21/200 
 train_loss : 6.70040 train_acc : 0.01459 



Training: 100%|██████████| 1731/1731 [11:33<00:00,  2.50it/s, nsp_loss=0.0011, mlm_loss=6.3831, loss=6.3842]


Epoch : 22/200 
 train_loss : 6.67245 train_acc : 0.01461 



Training: 100%|██████████| 1731/1731 [11:24<00:00,  2.53it/s, nsp_loss=0.0010, mlm_loss=6.7139, loss=6.7149]


Epoch : 23/200 
 train_loss : 6.65459 train_acc : 0.01459 



Training: 100%|██████████| 1731/1731 [09:36<00:00,  3.00it/s, nsp_loss=0.0011, mlm_loss=6.6103, loss=6.6114]


Epoch : 24/200 
 train_loss : 6.63974 train_acc : 0.01458 



Training: 100%|██████████| 1731/1731 [10:11<00:00,  2.83it/s, nsp_loss=0.0011, mlm_loss=6.6269, loss=6.6279]


Epoch : 25/200 
 train_loss : 6.62418 train_acc : 0.01456 



Training: 100%|██████████| 1731/1731 [10:59<00:00,  2.63it/s, nsp_loss=0.0011, mlm_loss=6.5891, loss=6.5901]


Epoch : 26/200 
 train_loss : 6.60721 train_acc : 0.01462 



Training: 100%|██████████| 1731/1731 [10:54<00:00,  2.64it/s, nsp_loss=0.0010, mlm_loss=6.5592, loss=6.5602]


Epoch : 27/200 
 train_loss : 6.58842 train_acc : 0.01462 



Training: 100%|██████████| 1731/1731 [09:35<00:00,  3.01it/s, nsp_loss=0.0010, mlm_loss=6.4998, loss=6.5008]


Epoch : 28/200 
 train_loss : 6.57692 train_acc : 0.01460 



Training:  11%|█         | 191/1731 [01:01<08:18,  3.09it/s, nsp_loss=0.0010, mlm_loss=6.6831, loss=6.6841]


KeyboardInterrupt: 

In [None]:
bert_lm.bert.encoder_blocks

ModuleList(
  (0-1): 2 x EncoderBlock(
    (norm1): Norm(
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (multi_head_attention): MultiHeadAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (out_proj): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (add): AddResidual()
    (feed_forward_neural_network): FeedForwardNeuralNetwork(
      (fn): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (norm2): Norm(
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [None]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [None]:
bert_lm.load_state_dict(torch.load(r"bert_scratc_6.7553.pth",weights_only=True))

<All keys matched successfully>

In [None]:
trainer = BERTTrainer(model=bert_lm,
            embedding_dim=ModelArgs.embedding_dim,
            vocab_size=tokenizer.vocab_size,
            warmup_steps=10_000,
            train_dataloader=train_dataloader,
            val_dataloader=None,
            epochs=200,
            lr=1e-9,
            min_val_loss=1e-2,
            betas=(0.9,0.999),
            weight_decay=0.1,
            wandb=wandb,    
            device=ModelArgs.device
            )

trainer.train_and_evaluate(log_grad_norm=True)



Training:   0%|          | 0/6925 [00:00<?, ?it/s]

name : bert.bert_embeddings.token.weight grad_norm : 0.030429
name : bert.bert_embeddings.segment.weight grad_norm : 0.083727
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.125002
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 1/6925 [00:00<1:01:23,  1.88it/s, nsp_loss=0.0013, mlm_loss=6.2047, loss=6.2060]

name : bert.bert_embeddings.token.weight grad_norm : 0.022721
name : bert.bert_embeddings.segment.weight grad_norm : 0.056245
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.085627
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 2/6925 [00:00<34:59,  3.30it/s, nsp_loss=0.0011, mlm_loss=6.5978, loss=6.5989]  

name : bert.bert_embeddings.token.weight grad_norm : 0.033137
name : bert.bert_embeddings.segment.weight grad_norm : 0.077929
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.122776
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 3/6925 [00:00<27:24,  4.21it/s, nsp_loss=0.0011, mlm_loss=6.1897, loss=6.1909]

name : bert.bert_embeddings.token.weight grad_norm : 0.032030
name : bert.bert_embeddings.segment.weight grad_norm : 0.069821
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.100095
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 4/6925 [00:01<24:27,  4.72it/s, nsp_loss=0.0011, mlm_loss=6.5918, loss=6.5929]

name : bert.bert_embeddings.token.weight grad_norm : 0.028416
name : bert.bert_embeddings.segment.weight grad_norm : 0.070623
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.103616
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 5/6925 [00:01<23:19,  4.94it/s, nsp_loss=0.0011, mlm_loss=6.2718, loss=6.2729]

name : bert.bert_embeddings.token.weight grad_norm : 0.021399
name : bert.bert_embeddings.segment.weight grad_norm : 0.056173
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.086489
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 6/6925 [00:01<22:48,  5.06it/s, nsp_loss=0.0013, mlm_loss=6.6226, loss=6.6238]

name : bert.bert_embeddings.token.weight grad_norm : 0.021973
name : bert.bert_embeddings.segment.weight grad_norm : 0.026167
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.042617
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 7/6925 [00:01<20:58,  5.50it/s, nsp_loss=0.0011, mlm_loss=6.9632, loss=6.9643]

name : bert.bert_embeddings.token.weight grad_norm : 0.019383
name : bert.bert_embeddings.segment.weight grad_norm : 0.046990
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.063360
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 8/6925 [00:01<20:25,  5.64it/s, nsp_loss=0.0012, mlm_loss=6.7498, loss=6.7510]

name : bert.bert_embeddings.token.weight grad_norm : 0.031175
name : bert.bert_embeddings.segment.weight grad_norm : 0.058702
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.090762
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 9/6925 [00:01<20:10,  5.71it/s, nsp_loss=0.0011, mlm_loss=6.3247, loss=6.3259]

name : bert.bert_embeddings.token.weight grad_norm : 0.035593
name : bert.bert_embeddings.segment.weight grad_norm : 0.068615
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.105126
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 10/6925 [00:02<18:53,  6.10it/s, nsp_loss=0.0012, mlm_loss=6.3010, loss=6.3022]

name : bert.bert_embeddings.token.weight grad_norm : 0.035445
name : bert.bert_embeddings.segment.weight grad_norm : 0.078642
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.124909
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 11/6925 [00:02<19:27,  5.92it/s, nsp_loss=0.0013, mlm_loss=6.1867, loss=6.1880]

name : bert.bert_embeddings.token.weight grad_norm : 0.014054
name : bert.bert_embeddings.segment.weight grad_norm : 0.028199
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.044826
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 12/6925 [00:02<19:58,  5.77it/s, nsp_loss=0.0012, mlm_loss=7.0329, loss=7.0342]

name : bert.bert_embeddings.token.weight grad_norm : 0.033373
name : bert.bert_embeddings.segment.weight grad_norm : 0.079571
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.119537
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 13/6925 [00:02<20:06,  5.73it/s, nsp_loss=0.0012, mlm_loss=6.2223, loss=6.2234]

name : bert.bert_embeddings.token.weight grad_norm : 0.017140
name : bert.bert_embeddings.segment.weight grad_norm : 0.040809
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.056294
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 14/6925 [00:02<19:12,  6.00it/s, nsp_loss=0.0012, mlm_loss=6.6142, loss=6.6154]

name : bert.bert_embeddings.token.weight grad_norm : 0.020117
name : bert.bert_embeddings.segment.weight grad_norm : 0.057771
name : bert.encoder_blocks.0.norm1.norm.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.norm1.norm.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.query.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.value.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.multi_head_attention.out_proj.bias grad_norm : 0.083242
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.weight grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.fn.0.bias grad_norm : 0.000000
name : bert.encoder_blocks.0.feed_forward_neural_network.f

Training:   0%|          | 15/6925 [00:02<22:31,  5.11it/s, nsp_loss=0.0011, mlm_loss=6.6223, loss=6.6234]


KeyboardInterrupt: 

In [None]:
bert_lm

BERTLM(
  (bert): BERT(
    (bert_embeddings): BERTEmbeddings(
      (token): Embedding(24864, 128)
      (segment): Embedding(3, 128)
      (position): PositionalEmbeddings()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder_blocks): ModuleList(
      (0-1): 2 x EncoderBlock(
        (norm1): Norm(
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (multi_head_attention): MultiHeadAttention(
          (query): Linear(in_features=128, out_features=128, bias=True)
          (key): Linear(in_features=128, out_features=128, bias=True)
          (value): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (add): AddResidual()
        (feed_forward_neural_network): FeedForwardNeuralNetwork(
          (fn): Sequential(
            (0): Linear(in_features=128, out_features=512, bias=True)
       