# Implementing Transformer form scratch in pytorch

In [1]:
import requests
import gzip
import shutil
from time import sleep
import math
import pickle
from collections import Counter,defaultdict

In [2]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader

In [3]:
base_url = "https://github.com/multi30k/dataset/raw/refs/heads/master/data/task1/raw/"

train_url = ("train.de.gz","train.en.gz")
val_url = ("val.de.gz","val.en.gz",)
test_url = ("test_2016_flickr.de.gz","test_2016_flickr.en.gz",)

In [4]:
@dataclass
class ModelArgs:
    max_lr = 1e-4
    batch_size = 32
    embedding_dim = 64
    no_of_neurons_ffnn = 4*embedding_dim
    
    seq_len = 50
    
    num_heads = 3
    
    en_vocab_size = None
    de_vocab_size = None
    
    
    attn_dropout = 0.1
    dropout = 0.1
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    

In [7]:
2048/512

4.0

In [8]:
def download_file(url, file_name, retries=3):
    for attempt in range(retries):
        try:
            with requests.get(url, stream=True, timeout=10) as r:
                r.raise_for_status()
                with open(file_name, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print(f"Downloaded: {file_name}")
            break
        except (requests.exceptions.RequestException, ConnectionResetError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            sleep(2)  # wait before retrying
            if attempt == retries - 1:
                print(f"Failed to download {file_name} after {retries} attempts.")
    return file_name

def extract_data(in_file,out_file):
    with gzip.open(in_file,"rb") as f_in:
        with open(out_file,"wb") as f_out:
            shutil.copyfileobj(f_in,f_out)
    
    print(f"Extracted : {in_file}")
    
    return out_file

train_paths = [ download_file(base_url+url,url) for url in train_url ]
val_paths = [download_file(base_url+url,url) for url in val_url]
test_paths = [download_file(base_url+url,url) for url in test_url]

train_paths = [extract_data(file,file[:-3]) for file in train_paths]
val_paths = [extract_data(file,file[:-3]) for file in val_paths]
test_paths = [extract_data(file,file[:-3]) for file in test_paths]


Downloaded: train.de.gz
Downloaded: train.en.gz
Downloaded: val.de.gz
Downloaded: val.en.gz
Downloaded: test_2016_flickr.de.gz
Downloaded: test_2016_flickr.en.gz
Extracted : train.de.gz
Extracted : train.en.gz
Extracted : val.de.gz
Extracted : val.en.gz
Extracted : test_2016_flickr.de.gz
Extracted : test_2016_flickr.en.gz


In [9]:
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m94.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m81.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m


In [10]:
from collections import Counter , defaultdict
import io
import spacy

en_tokenizer = spacy.load("en_core_web_sm")
de_tokenizer = spacy.load("de_core_news_sm")

def tokenize(text,tokenizer):
    tokens = tokenizer(text)
    return [token.text.lower for token in tokens]

def build_vocab(file_path,tokenizer,min_freq=1,speacial_tokens = ["<bos>","<pad>","<unk>","<eos>"]):
    counter = Counter()
    with open(file_path,"r") as f:
        for string_ in f:
            tokens = tokenize(string_,tokenizer)
            counter.update(tokens)
    tokens = [token for token,freq in counter.items() if freq>=min_freq]
    vocab = {token:idx for idx,token in enumerate(tokens+speacial_tokens)}
    unk_idx = vocab["<unk>"]
    vocab = defaultdict(lambda : unk_idx,vocab)
    
    return vocab

de_vocab = build_vocab(train_paths[0],de_tokenizer)
en_vocab = build_vocab(train_paths[1],en_tokenizer)

ModelArgs.de_vocab_size = len(de_vocab)+1
ModelArgs.en_vocab_size = len(en_vocab)+1



In [11]:
def data_process(file_names):
    data = []
    
    raw_de_iter = iter(io.open(file_names[0],encoding="utf-8"))
    raw_en_iter = iter(io.open(file_names[1],encoding="utf-8"))
    
    for raw_de,raw_en in zip(raw_de_iter,raw_en_iter):
        de_tensor = torch.tensor([de_vocab[token] for token in tokenize(raw_de,de_tokenizer)],dtype=torch.long)
        en_tensor = torch.tensor([en_vocab[token] for token in tokenize(raw_en,en_tokenizer)],dtype=torch.long)
        
        en_tensor = torch.cat([torch.tensor([en_vocab["<bos>"]]), en_tensor , torch.tensor([en_vocab["<eos>"]])])
        de_tensor = torch.cat([torch.tensor([de_vocab["<bos>"]]), de_tensor , torch.tensor([de_vocab["<eos>"]])])
        
        data.append((de_tensor,en_tensor))
        
    return data

train_data = data_process(train_paths)
val_data = data_process(val_paths)
test_data = data_process(test_paths)

In [5]:
# get the data from the local
train_data = torch.load("/kaggle/input/transformer-dataset/train_data.pt")
val_data = torch.load("/kaggle/input/transformer-dataset/val_data.pt")
test_data = torch.load("/kaggle/input/transformer-dataset/test_data.pt")
with open("/kaggle/input/model-args-transformer/model_args.pkl","rb") as f:
    args = pickle.load(f)
ModelArgs.en_vocab_size = args["en_vocab_size"]
ModelArgs.de_vocab_size = args["de_vocab_size"]

with open("/kaggle/input/transformer-dataset/de_voab.pkl","rb") as f:
    de_vocab = pickle.load(f)
with open("/kaggle/input/transformer-dataset/en_vocab.pkl","rb") as f:
    en_vocab = pickle.load(f)



In [6]:

de_vocab = defaultdict(lambda : de_vocab["<unk>"],de_vocab)
en_vocab = defaultdict(lambda : en_vocab["<unk>"],en_vocab)

In [7]:
DE_MAX_SEQ_LEN,EN_MAX_SEQ_LEN = 0,0
for de,en in train_data:
    DE_MAX_SEQ_LEN = max(DE_MAX_SEQ_LEN,len(de))
    EN_MAX_SEQ_LEN = max(EN_MAX_SEQ_LEN,len(en))

DE_MAX_SEQ_LEN,EN_MAX_SEQ_LEN

(47, 44)

In [8]:
class TranslationDataset(Dataset):
    def __init__(self,data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return self.data[idx]
    
train_dataset = TranslationDataset(train_data)
val_dataset = TranslationDataset(val_data)
test_dataset = TranslationDataset(test_data)

In [9]:
ModelArgs.seq_len

50

In [10]:
def collate_fn(batch,seq_len=ModelArgs.seq_len):
    
    def pad_or_truncate(sequence,pad_value):
        if len(sequence) >= seq_len:
            mask = torch.full([seq_len],fill_value=1)
            return sequence[:seq_len],mask
        elif len(sequence) < seq_len:
            padding_len = seq_len - len(sequence)
            mask = torch.cat([torch.full([len(sequence)],fill_value=1) , torch.full([padding_len],fill_value=0)])
            return torch.cat([sequence,torch.full([padding_len],fill_value=pad_value)]) , mask
    de_batch,en_batch,de_pad_mask,en_pad_mask = [],[],[],[]
    
    for de_tensor,en_tensor in batch:
        
        de_tensor,de_mask = pad_or_truncate(de_tensor,pad_value=de_vocab["<pad>"])
        en_tensor,en_mask = pad_or_truncate(en_tensor,pad_value=en_vocab["<pad>"])
        
        de_batch.append(de_tensor)
        en_batch.append(en_tensor)
        de_pad_mask.append(de_mask)
        en_pad_mask.append(en_mask)
        
    de_batch = torch.stack(de_batch)
    en_batch = torch.stack(en_batch)
    en_pad_mask = torch.stack(en_pad_mask)
    de_pad_mask = torch.stack(de_pad_mask)      
    
    return de_batch,en_batch , en_pad_mask , de_pad_mask

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=ModelArgs.batch_size,
                              collate_fn=collate_fn,
                              drop_last=True,
                              shuffle=True)

val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=ModelArgs.batch_size,
                            collate_fn=collate_fn,
                            drop_last=True,
                            shuffle=True)

test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=ModelArgs.batch_size,
                             collate_fn=collate_fn,
                             drop_last=True,
                             shuffle=True)

In [11]:
sample_de,sample_en,sample_de_mask,sample_en_mask = next(iter(train_dataloader))

sample_de.shape,sample_en.shape,sample_de_mask.shape,sample_en_mask.shape

(torch.Size([32, 50]),
 torch.Size([32, 50]),
 torch.Size([32, 50]),
 torch.Size([32, 50]))

In [12]:
# s = torch.randn(size=[100,100])

# s[:,sample_en_mask[0]==0] = float('-inf')

# s

In [62]:

class Embeddings(nn.Module):
    def __init__(self,vocab_size,embedding_dim):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embedding_dim)
        
    def forward(self,X):
        return self.embedding_layer(X)

In [63]:
class PositionEmbedding(nn.Module):
    def __init__(self,d_model):
        super().__init__()
        # d_model : embedding_dim 
        self.d_model = d_model
    def forward(self):
        def sin_fun(pos,i):
            return math.sin(pos / 10000 ** (2 * i / self.d_model) )
        def cos_fun(pos,i):
            return math.cos(pos / 10000 ** (2 * i / self.d_model) )
        funs = [sin_fun,cos_fun]
        pos_tensor = torch.stack( [torch.tensor([fun(pos,i) for fun in funs for i in range(0,self.d_model//2 - 1 + 1)],device=ModelArgs.device) for pos in range(ModelArgs.seq_len)] ) 
        # -1 + 1 becuase : by default the loop excludes the last step
        pos_tensor_for_batch = pos_tensor.unsqueeze(0).repeat(ModelArgs.batch_size,1,1)
        
        # embeddings = embeddings + pos_tensor_for_batch
        
        return pos_tensor_for_batch

In [64]:
class SelfAttention(nn.Module):
    def __init__(self,embedding_dim,q_dim,k_dim,v_dim,attn_dropout,mask=False,):
        super().__init__()
        self.mask = mask
        self.embedding_dim = embedding_dim
        self.w_q = nn.Linear(in_features=embedding_dim,out_features=q_dim)
        self.w_k = nn.Linear(in_features=embedding_dim,out_features=k_dim)
        self.w_v = nn.Linear(in_features=embedding_dim,out_features=v_dim)
        self.dropout = nn.Dropout(p=attn_dropout)
        
    def forward(self,embeddings,pad_mask):
        # embeddings : [batch_size , seq_len , embedding_dim]
        # print(f"embedding : {embeddings.device}")
        # print(f"w_q : {self.w_q.weight.device}")
        # print(f"{self.w_q.dtype}")
        # if torch.isnan(embeddings).any():
        #     print("input embeddings to self attention is nan")
        query = self.w_q(embeddings) # [batch_size , seq_len , q_dim]
        key = self.w_k(embeddings) # [batch_size , seq_len , k_dim]
        value = self.w_v(embeddings) # [batch_size , seq_len , v_dm]
        
        similarity_scores = torch.bmm(query , key.transpose(-2,-1)) # [batch_size , seq_len ,seq_len]
        
        # if torch.isnan(similarity_scores).any():
        #     print("scores in sa after q@k is nan")
        
        scaled_similarity_score = similarity_scores * (1 / math.sqrt(self.embedding_dim)) # [batch_size , seq_len ,seq_len]
        
        # if torch.isnan(scaled_similarity_score).any():
        #     print("score in sa after scaling is nan")
        
        if self.mask :
            mask = torch.ones_like(scaled_similarity_score)
            mask = torch.tril(mask)
            scaled_similarity_score = scaled_similarity_score.masked_fill(mask==0,float('-inf'))
            
        # if torch.isnan(scaled_similarity_score).any():
        #     print("scores in sa after masking is nan")
            
        if pad_mask != None:
            q_mask = pad_mask.unsqueeze(2) # [batch_size,seq_len] -> [batch_size,seq_len,1]
            k_mask = pad_mask.unsqueeze(1) # [batch_size,seq_len] -> [batch_size,1,seq_len]
            full_mask = q_mask & k_mask # [batch_size , seq_len , seq_len]
            identity = torch.eye(n=ModelArgs.seq_len,device=ModelArgs.device).unsqueeze(0).expand(ModelArgs.batch_size,-1,-1)
            full_mask = full_mask + identity # if all values in row becomes -inf the max = -inf then at stabilizing we got (-inf)-(-inf) = NaN so we prevent all values becoming -inf , we make the diagonal 1 by identity matrix
            scaled_similarity_score = scaled_similarity_score.masked_fill(full_mask==0,float('-inf'))
        
        # if torch.isnan(scaled_similarity_score).any():
        #     print("scores in sa after pad mask is nan")
            
        # Stabilize Attention Scores Before Softmax -> since the values becomming nan after softmax
        # if X = [x1,x2.....xn] if xi is very large or very small the exp(xi) will overflow and leads to nan values so we stabilize them
        scaled_similarity_score = scaled_similarity_score - scaled_similarity_score.max(dim=-1,keepdim=True).values # .values -> the max returns both indicies and values we access values by .values
        
        # if torch.isnan(scaled_similarity_score).any():
        #     print("scores in sa after stabilize is nan")

        weights = torch.softmax(scaled_similarity_score , dim = -1) # # [batch_size , seq_len ,seq_len]
        # if torch.isnan(weights).any():
        #     print("scores in sa after softmax is nan")
        
        weights = self.dropout(weights)
        # if torch.isnan(weights).any():
        #     print("scores in sa after dropout is nana")
        
        contextual_embedding = torch.matmul(weights,value) # [batch_size , seq_len ,seq_len] @ [batch_size , seq_len , v_dim]
        
        # if torch.isnan(contextual_embedding).any():
        #     print("embeddings in sa after weight@v is nan")
        
        return contextual_embedding # [batch_size , seq_len , v_dim]

In [65]:
s = torch.ones(size=[4,4])
torch.tril(s)

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [66]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embedding_dim,num_heads,attn_dropout):
        super().__init__()
        # self.embedding_layer = Embeddings(vocab_size=vocab_size,embedding_dim=embedding_dim)
        # self.position_encoding = PositionEmbedding(d_model=embedding_dim)
        self.multi_head_attention = nn.ModuleList([SelfAttention(embedding_dim=embedding_dim,q_dim=embedding_dim,k_dim=embedding_dim,v_dim=embedding_dim,attn_dropout=attn_dropout) for i in range(num_heads)])
        self.dropout = nn.Dropout(p=attn_dropout)
        self.ce_proj = nn.Linear(in_features=num_heads * embedding_dim , out_features=embedding_dim)
        
    def forward(self,embeddings,pad_mask):
        # # print(f"input shape : {embeddings.shape}")
        # embeddings = self.embedding_layer(embeddings)
        # # print(f"embedding : {embeddings.shape}")
        # embeddings = self.position_encoding(embeddings)
        # # print(f"pe + embedding : {embeddings.shape}")
        embeddings = torch.stack([self_attention(embeddings,pad_mask) for self_attention in self.multi_head_attention]) # [num_heads , batch_size , seq_len , embedding_dim] # embedding_dim = head_dim
        # if torch.isnan(embeddings).any():
        #     print("embeddings in mha after self attention is nan")
        # # print(f"multi head embeddings : {embeddings.shape}")
        embeddings = embeddings.permute(1,2,0,3) # [batch_size , seq_len , num_heads , head_dim]
        # if torch.isnan(embeddings).any():
        #     print("embeddings in mha after permute is nan")
        embeddings = embeddings.reshape(ModelArgs.batch_size ,ModelArgs.seq_len , -1 )
        # if torch.isnan(embeddings).any():
        #     print("embeddings in mha after reshape is nan")
        multi_head_embeddings = self.ce_proj(embeddings) # linear does not work with 4D tensors
        # if torch.isnan(multi_head_embeddings).any():
        #     print("embeddeings in mha after ce proj is nan")
        multi_head_embeddings = self.dropout(multi_head_embeddings)
        # if torch.isnan(multi_head_embeddings).any():
        #     print("embeddings in mha after dropout is nan")
        return multi_head_embeddings

In [67]:
# eb.shape

In [68]:
class LayerNorm(nn.Module):
    def __init__(self,embedding_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        
    def forward(self,input_for_norm ):
        return self.layer_norm(input_for_norm)

In [69]:
class AddResidual(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,X1,X2):
        return X1+X2

In [70]:
class FeedForwardNeuralNetwork(nn.Module):
    def __init__(self,embedding_dim,no_of_neurons,dropout):
        super().__init__()
        self.feed_nn = nn.Sequential(
            nn.Linear(in_features=embedding_dim,out_features=no_of_neurons),
            # nn.ReLU(),
            nn.GELU(),
            nn.Linear(in_features=no_of_neurons,out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )
        
    def forward(self,X):
        return self.feed_nn(X)

In [71]:
t = torch.full(size=[4,4],fill_value=float('-inf'))
t = torch.triu(t,diagonal=1)
print(t)

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])


In [72]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self,embedding_dim,num_heads,attn_dropout):
        super().__init__()
        self.masked_multi_head_attention = nn.ModuleList([SelfAttention(embedding_dim=embedding_dim,
                                                                        q_dim=embedding_dim,
                                                                        k_dim=embedding_dim,
                                                                        v_dim=embedding_dim,
                                                                        attn_dropout=attn_dropout,
                                                                        mask=True) for i in range(num_heads)])

        self.ce_proj = nn.Linear(in_features=num_heads*embedding_dim , out_features=embedding_dim)
        self.dropout = nn.Dropout(p=attn_dropout)
    def forward(self,embeddings,pad_mask):
        # print("masked attention")
        embeddings = torch.stack([self_attention(embeddings,pad_mask) for self_attention in self.masked_multi_head_attention])
        # if torch.isnan(embeddings).any():
        #     print("self attention")
        embeddings = embeddings.permute(1,2,0,3) # [batch_size , seq_len , num_heads , head_dim]
        embeddings = embeddings.reshape(ModelArgs.batch_size ,ModelArgs.seq_len , -1 )
        multi_head_embeddings = self.ce_proj(embeddings)
        multi_head_embeddings = self.dropout(multi_head_embeddings)
        return multi_head_embeddings
        
        

In [73]:
class CrossAttention(nn.Module):
    def __init__(self,embedding_dim,q_dim,k_dim,v_dim,attn_dropout):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.w_q = nn.Linear(in_features=embedding_dim,out_features=q_dim)
        self.w_k = nn.Linear(in_features=embedding_dim,out_features=k_dim)
        self.w_v = nn.Linear(in_features=embedding_dim,out_features=v_dim)
        self.dropout = nn.Dropout(p=attn_dropout)
        
    def forward(self,encoder_embeddings,decoder_embeddings,encoder_pad_mask,decoder_pad_mask):
        query = self.w_q(decoder_embeddings) # [batch_size , out_seq_len , embedding_dim]
        key = self.w_k(encoder_embeddings) # [batch_size , in_seq_len , embedding_dim]
        value = self.w_v(encoder_embeddings) # [batch_size , in_seq_len , embedding_dim]
        
        similarity_scores = torch.bmm(query,key.transpose(-2,-1)) # [batch_size , out_seq_len, in_seq_len]
        
        scaled_similarity_scores = similarity_scores * (1 / math.sqrt(self.embedding_dim)) # [batch_size , out_seq_len , in_seq_len]
        
        if encoder_pad_mask != None and decoder_pad_mask != None:
            encoder_pad_mask = encoder_pad_mask.unsqueeze(1) # [batch_size , in_seq_len] -> [batch_size ,1 , in_seq_len]
            decoder_pad_mask = decoder_pad_mask.unsqueeze(2) # [batch_size , out_seq_len] -> [batch_size , out_seq_len, 1]
            full_mask = encoder_pad_mask & decoder_pad_mask
            eye = torch.eye(n=ModelArgs.seq_len,device=ModelArgs.device).unsqueeze(0).expand(ModelArgs.batch_size,-1,-1)
            full_mask = full_mask + eye
            
            scaled_similarity_scores = scaled_similarity_scores.masked_fill(full_mask==0,float('-inf'))
            
        weights = torch.softmax(scaled_similarity_scores , dim = -1) # [batch_size , out_seq_len , in_seq_len]    
        
        # print(f"weight : {weights.shape}")
        # print(f"value : {value.shape}")
        weights = self.dropout(weights)
        contextual_embedding = torch.matmul(weights , value)
        
        return contextual_embedding
        

In [74]:
class MultiHeadCrossAttention(nn.Module):
    def __init__(self,num_heads,embedding_dim,attn_dropout):
        super().__init__()
        self.multi_head_cross_attention = nn.ModuleList([CrossAttention(embedding_dim=embedding_dim,q_dim=embedding_dim,k_dim=embedding_dim,v_dim=embedding_dim,attn_dropout=attn_dropout) for i in range(num_heads)])
        self.ce_proj = nn.Linear(in_features=num_heads*embedding_dim,out_features=embedding_dim)
        self.dropout = nn.Dropout(p=attn_dropout)
        
    def forward(self,encoder_embeddings,decoder_embeddings,encoder_pad_mask,decoder_pad_mask):
        embeddings = torch.stack([cross_attention(encoder_embeddings=encoder_embeddings,decoder_embeddings=decoder_embeddings,encoder_pad_mask=encoder_pad_mask,decoder_pad_mask=decoder_pad_mask) for cross_attention in self.multi_head_cross_attention]) # each output of size -> [batch_size,seq_len,embedding_dim]
        embeddings = embeddings.permute(1,2,0,3) # [batch_size , seq_len , num_heads , head_dim]
        embeddings = embeddings.reshape(ModelArgs.batch_size ,ModelArgs.seq_len , -1 )
        multi_head_cross_attention = self.ce_proj(embeddings)
        multi_head_cross_attention = self.dropout(multi_head_cross_attention)
        return multi_head_cross_attention

In [89]:
class EncoderBlock(nn.Module):
    def __init__(self,num_heads,embedding_dim,ff_units,attn_dropout,dropout):
        super().__init__()
        # self.embedding_layer = Embeddings(vocab_size=vocab_size,embedding_dim=embedding_dim)
        # self.position_encoding = PositionEmbedding(d_model=embedding_dim)
        self.multi_head_attention = MultiHeadAttention(embedding_dim=embedding_dim,num_heads=num_heads,attn_dropout=attn_dropout)
        self.norm1 = LayerNorm(embedding_dim=embedding_dim)
        self.feed_nn = FeedForwardNeuralNetwork(embedding_dim=embedding_dim,no_of_neurons=ff_units,dropout=dropout)
        self.norm2 = LayerNorm(embedding_dim=embedding_dim)
        self.add = AddResidual()
        
    def forward(self,embeddings,pad_mask):
        # print("encoder")
        # embeddings = self.embedding_layer(input) # [batch_size , seq_len , embedding_dim]
        # positional_encoding = self.position_encoding() # [batch_size,seq_len , embedding_dim]
        # embeddings = embeddings + positional_encoding # [batch_size , seq_len , embedding_dim]

        # contextual_embeddings = self.multi_head_attention(embeddings,pad_mask) # [batch_size , seq_len , embedding_dim]

        # contextual_embeddings_norm = self.norm1(input_for_norm=contextual_embeddings ) # [batch_size , seq_len , embedding_dim]

        # contextual_embeddings_norm_fn = self.feed_nn(X=contextual_embeddings_norm) # [batch_size,seq_len,embedding_dim]

        # contextual_embeddings_norm_fn_norm = self.norm2(input_for_norm=contextual_embeddings_norm_fn) # [batch_size,seq_len,embedding_dim]


        embeddings_norm = self.norm1(input_for_norm=embeddings)
        embeddings_norm_mha = self.multi_head_attention(embeddings_norm,pad_mask)
        embeddings_norm_mha_add = self.add(embeddings,embeddings_norm_mha)

        embeddings_norm_mha_add_norm = self.norm2(input_for_norm=embeddings_norm_mha_add)
        embeddings_norm_mha_add_norm_fn = self.feed_nn(X=embeddings_norm_mha_add_norm)
        embeddubgs_norm_mha_add_norm_fn_add = self.add(embeddings,embeddings_norm_mha_add_norm_fn)        
        
        
        return embeddubgs_norm_mha_add_norm_fn_add # [batch_size,seq_len,embedding_dim]

In [81]:
class DecoderBlock(nn.Module):
    def __init__(self,embedding_dim,num_heads,ff_units,attn_dropout,dropout):
        super().__init__()
        # self.embedding_layer = Embeddings(vocab_size=vocab_size,embedding_dim=embedding_dim)
        # self.positional_encoding = PositionEmbedding(d_model=embedding_dim)
        self.masked_multi_head_attention = MaskedMultiHeadAttention(embedding_dim=embedding_dim,num_heads=num_heads,attn_dropout=attn_dropout)
        self.norm1 = LayerNorm(embedding_dim=embedding_dim)
        self.multi_head_cross_attention = MultiHeadCrossAttention(num_heads=num_heads,embedding_dim=embedding_dim,attn_dropout=attn_dropout)
        self.norm2 = LayerNorm(embedding_dim=embedding_dim)
        self.feed_nn = FeedForwardNeuralNetwork(embedding_dim=embedding_dim,no_of_neurons=ff_units,dropout=dropout)
        self.norm3 = LayerNorm(embedding_dim=embedding_dim)
        self.add = AddResidual()
        
    def forward(self,embeddings,encoder_embeddings,encoder_pad_mask,decoder_pad_mask):
        # embeddings = self.embedding_layer(input) # [batch_size,seq_len,embedding_dim]
        # positional_encoding = self.positional_encoding()
        # embeddings = embeddings+positional_encoding
        
        # embeddings_mmha = self.masked_multi_head_attention(embeddings=embeddings,pad_mask=decoder_pad_mask)

        # embeddings_mmha_norm = self.add_norm1(input_for_norm=embeddings_mmha,residual=embeddings)
        # embeddings_mmha_norm_mhca = self.multi_head_cross_attention(encoder_embeddings=encoder_embeddings,decoder_embeddings=embeddings_mmha_norm,encoder_pad_mask=encoder_pad_mask,decoder_pad_mask=decoder_pad_mask)

        # embeddings_mmha_norm_mhca_norm = self.add_norm2(input_for_norm=embeddings_mmha_norm_mhca,residual=embeddings_mmha_norm)
        # embeddings_mmha_norm_mhca_norm_ffnn = self.feed_nn(X=embeddings_mmha_norm_mhca_norm)
        # embeddings_mmha_norm_mhca_norm_ffnn_norm = self.add_norm3(input_for_norm=embeddings_mmha_norm_mhca_norm_ffnn,residual=embeddings_mmha_norm_mhca_norm)


        embeddings_norm = self.norm1(embeddings)
        embeddings_norm_mmha = self.masked_multi_head_attention(embeddings=embeddings,pad_mask=decoder_pad_mask)
        embeddings_norm_mmha_add = self.add(embeddings,embeddings_norm_mmha)
        
        embeddings_norm_mmha_add_norm = self.norm2(embeddings_norm_mmha_add)
        embeddings_norm_mmha_add_norm_mhca = self.multi_head_cross_attention(encoder_embeddings=encoder_embeddings,decoder_embeddings=embeddings,encoder_pad_mask=encoder_pad_mask,decoder_pad_mask=decoder_pad_mask)
        embeddings_norm_mmha_add_norm_mhca_add = self.add(embeddings_norm_mmha_add ,embeddings_norm_mmha_add_norm_mhca )

        embeddings_norm_mmha_add_norm_mhca_add_norm = self.norm3(embeddings_norm_mmha_add_norm_mhca_add)
        embeddings_norm_mmha_add_norm_mhca_add_norm_fn = self.feed_nn(X=embeddings_norm_mmha_add_norm_mhca_add_norm)
        embeddings_norm_mmha_add_norm_mhca_add_norm_fn_add = self.add(embeddings_norm_mmha_add_norm_mhca_add ,embeddings_norm_mmha_add_norm_mhca_add_norm_fn )
        return embeddings_norm_mmha_add_norm_mhca_add_norm_fn_add


In [82]:
class TransformerBlock(nn.Module):
    def __init__(self,nx,num_heads,src_vocab_size,dest_vocab_size,embedding_dim,ff_units,attn_dropout,dropout):
        super().__init__()
        self.src_embedding_layer = Embeddings(vocab_size=src_vocab_size,embedding_dim=embedding_dim)
        self.dest_embedding_layer = Embeddings(vocab_size=dest_vocab_size,embedding_dim=embedding_dim)
        self.positional_encoding = PositionEmbedding(d_model=embedding_dim)
        self.encoders = nn.ModuleList([EncoderBlock(num_heads=num_heads,
                                      embedding_dim=embedding_dim,
                                      ff_units=ff_units,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout) for i in range(nx)])
        self.decoders = nn.ModuleList([DecoderBlock(embedding_dim=embedding_dim,
                                      num_heads=num_heads,
                                      ff_units=ff_units,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout) for i in range(nx)])
        
    def forward(self,X,y,X_pad_mask,y_pad_mask):
        X_embedded = self.src_embedding_layer(X)
        y_embedded = self.dest_embedding_layer(y)
        
        # if torch.isnan(X_embedded).any():
        #     print("X_embd is nan before pe")
        # if torch.isnan(y_embedded).any():
        #     print("y embd is nan before pe")
        
        positional_encoding = self.positional_encoding()
        X_embedded = X_embedded + positional_encoding
        y_embedded = y_embedded + positional_encoding
        
        # if torch.isnan(X_embedded).any():
        #     print("X_embd is nan after pe")
        # if torch.isnan(y_embedded).any():
        #     print("y embd is nan after pe")

        
        
        for encoder in self.encoders:
            X_embedded = encoder(embeddings=X_embedded,pad_mask=X_pad_mask)
            
        for decoder in self.decoders:
            y_embedded = decoder(embeddings=y_embedded,encoder_embeddings=X_embedded,encoder_pad_mask=X_pad_mask,decoder_pad_mask=y_pad_mask)
        
        return X_embedded,y_embedded
        
        
        

In [83]:
class Transformer(nn.Module):
    def __init__(self,nx,num_heads,src_vocab_size,dest_vocab_size,embedding_dim,ff_units,attn_dropout,dropout):
        super().__init__()
        self.transformer = TransformerBlock(nx,num_heads,src_vocab_size,dest_vocab_size,embedding_dim,ff_units,attn_dropout=attn_dropout,dropout=dropout)
        self.classification_head = nn.Sequential(
            nn.Linear(in_features=embedding_dim,out_features=dest_vocab_size),
            nn.Softmax(dim=-1)
        )

        self.apply(self._init_weights)

    def _init_weights(self,module):
        if isinstance(module,nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module,nn.Embedding):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module,nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self,X,y,X_pad_mask,y_pad_mask):
        encoder_outputs,decoder_outputs = self.transformer(X=X,y=y,X_pad_mask=X_pad_mask,y_pad_mask=y_pad_mask)
        logits = self.classification_head(decoder_outputs)
        
        return logits
    

In [84]:
ModelArgs.device = "cuda" if torch.cuda.is_available() else "cpu"

In [90]:
model = Transformer(nx=5,
                    num_heads=4,
                    src_vocab_size=ModelArgs.de_vocab_size,
                    dest_vocab_size=ModelArgs.en_vocab_size,
                    embedding_dim=ModelArgs.embedding_dim,
                    ff_units=ModelArgs.no_of_neurons_ffnn,
                    attn_dropout=ModelArgs.attn_dropout,
                    dropout=ModelArgs.dropout)



In [91]:
model = model.to(ModelArgs.device)

In [94]:
from torchinfo import summary

summary(model=model,
        input_data=(sample_de.to(ModelArgs.device),sample_en.to(ModelArgs.device),sample_en_mask.to(ModelArgs.device),sample_de_mask.to(ModelArgs.device)),
        col_names=["input_size","output_size","num_params","trainable"],
        col_width=20,
        row_settings=["var_names"],
       device=ModelArgs.device)

Layer (type (var_name))                                                     Input Shape          Output Shape         Param #              Trainable
Transformer (Transformer)                                                   [32, 50]             [32, 50, 297120]     --                   True
├─TransformerBlock (transformer)                                            --                   [32, 50, 64]         --                   True
│    └─Embeddings (src_embedding_layer)                                     [32, 50]             [32, 50, 64]         --                   True
│    │    └─Embedding (embedding_layer)                                     [32, 50]             [32, 50, 64]         20,649,600           True
│    └─Embeddings (dest_embedding_layer)                                    [32, 50]             [32, 50, 64]         --                   True
│    │    └─Embedding (embedding_layer)                                     [32, 50]             [32, 50, 64]         19,015,680   

In [92]:
sample_out = model(sample_de.to(ModelArgs.device),sample_en.to(ModelArgs.device),sample_en_mask.to(ModelArgs.device),sample_de_mask.to(ModelArgs.device))

In [93]:
sample_out.shape

torch.Size([32, 50, 297120])

In [50]:
# for

In [95]:
def train(model,model_name,train_dataloader,val_dataloader,criterion,optimizer,epochs,min_val_loss,device,wandb=None,scheduler=None,clip_grad=False):
    from tqdm import tqdm
    model = model.to(device)
    train_losses,val_losses = [],[]
    
    best_val_loss = float('inf')
    all_counters = []
    
    try:
        for epoch in range(epochs):
            counter = Counter()
            train_loss,correct,total = 0.0,0,0
            train_progress = tqdm(train_dataloader)
            for idx,(de_batch,en_batch,de_pad_mask,en_pad_mask) in enumerate(train_progress):
                de_batch = de_batch.to(device)
                en_batch = en_batch.to(device)
                de_pad_mask = de_pad_mask.to(device)
                en_pad_mask = en_pad_mask.to(device)
                optimizer.zero_grad()
                all_logits = model(de_batch,en_batch,de_pad_mask,en_pad_mask) # [batch_size,seq_len,vocab_size]
                all_logits = all_logits.view(-1,ModelArgs.en_vocab_size) # [batch_size , seq_len * vocab_size]
                en_batch = en_batch.view(-1) # [batch_size * seq_len]
                loss = criterion(all_logits,en_batch) 
                
                loss.backward()

                #######################################################################################################
                if clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1.0)
                    
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2) # L2 norm
                        total_norm = param_norm.item()**2

                total_norm = total_norm**0.5

                    

                # print(total_norm)

                # for name,param in model.named_parameters():
                #     if param.grad is not None:
                #         print(f"{name} : {param.grad.norm().item():.6f}")

                if wandb is not None:
                    wandb.log({"norm":total_norm})

                    grad_groups = defaultdict(list)

                    for name, param in model.named_parameters():
                        if param.grad is None:
                            continue
                        # Get group name: e.g., "encoder.0", "decoder.1", "embedding", etc.
                        tokens = name.split('.')
                        group_name = '.'.join(tokens[:3]) if len(tokens) >= 3 else '.'.join(tokens[:2])
                
                        grad_norm = param.grad.norm().item()
                        grad_groups[group_name].append(grad_norm)
                
                    # Average gradients per group
                    avg_grad_per_group = {k: sum(v)/len(v) for k, v in grad_groups.items()}

                    wandb.log(avg_grad_per_group)

                #####################################################################################################
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                
                train_loss += loss.item()
                pred_probs = torch.softmax(all_logits,dim=-1) # [batch_size , seq_len * vocab_size]
                preds = torch.argmax(pred_probs,dim=-1)
                correct += (preds == en_batch).sum()
                total +=en_batch.shape[0]
                
                train_progress.set_postfix({"loss":f"{loss.item():.4f}"})

                train_losses.append(train_loss)
                counter.update(preds.tolist())
                if wandb is not None:
                    wandb.log({"train_loss":train_loss/(idx+1),"train_acc":correct/total})
                
            train_loss /= len(train_dataloader)
            train_acc = correct/total
                
            with torch.inference_mode():
                val_loss,correct,total = 0.0,0,0
                val_progress = tqdm(val_dataloader)
                for idx,(de_batch,en_batch,de_pad_mask,en_pad_mask) in enumerate(val_progress):
                    de_batch = de_batch.to(device)
                    en_batch = en_batch.to(device)
                    de_pad_mask = de_pad_mask.to(device)
                    en_pad_mask = en_pad_mask.to(device)
                    
                    all_logits = model(de_batch,en_batch,de_pad_mask,en_pad_mask)
                    all_logits = all_logits.view(-1,ModelArgs.en_vocab_size)
                    en_batch = en_batch.view(-1)
                    
                    loss = criterion(all_logits,en_batch)

                    pred_probs = torch.softmax(all_logits,dim=-1)
                    preds = torch.argmax(pred_probs,dim=-1)
                    
                    val_loss += loss.item()
                    correct += (preds==en_batch).sum()
                    total += en_batch.shape[0]
                    
                    val_progress.set_postfix({"loss":f"{loss.item():.4f}"})
                    
                    val_losses.append(val_loss)
                    if wandb is not None:
                        wandb.log({"val_loss":val_loss/(idx+1),"val_acc":correct/total})
                    
                    
                val_loss /= len(val_dataloader)
                val_acc = correct/total
                
            print(f"Epoch : {epoch}/{epochs} \n train loss : {train_loss:.5f} train acc : {train_acc:.5f}\n val loss : {val_loss:.5f}  val acc : {val_acc:.5f}")

            all_counters.append(counter)
            if best_val_loss > val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(),model_name)
            
            if val_loss < min_val_loss:
                print("model trained successully")
                return train_losses,val_losses , all_counters
        return train_losses,val_losses , all_counters
    

    except KeyboardInterrupt:
        return train_losses,val_losses , all_counters
            
        

In [96]:
!pip install dotenv



In [97]:
import os
from dotenv import load_dotenv
load_dotenv("/kaggle/input/wandb-key/.env")
wandb_key = os.getenv("WANDB_API_KEY")

In [98]:
import wandb
wandb.login(key=wandb_key)



True

In [112]:
# criterion = nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"],label_smoothing=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=en_vocab["<pad>"])

# optimizer = torch.optim.Adam(model.parameters(),lr=ModelArgs.max_lr)

from torch.optim.lr_scheduler import LambdaLR

def get_lr(step, d_model=ModelArgs.embedding_dim, warmup_steps=4500):
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-8)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-5)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: get_lr(step, d_model=ModelArgs.embedding_dim))

In [113]:

wandb.init(
    project="transformer_form_scratch",
    name="logging layer wise grad + warmup 4500 + lr tuning to 1e-3 + weight intialization + optimizer Adam + nx to 5 + num_heads to 4 + tuned the embedding dim 32 -> 64 + pre normalization + scheduler + 500 epochs",
    config={
        "lr":ModelArgs.max_lr,
        "batch_size":ModelArgs.batch_size,
        "embedding_dim":ModelArgs.embedding_dim,
        "no_of_neurons_ffnn":ModelArgs.no_of_neurons_ffnn,
        "seq_len":ModelArgs.seq_len,
        "num_heads":ModelArgs.num_heads,
        "en_vocab_size":ModelArgs.en_vocab_size,
        "de_vocab_size":ModelArgs.de_vocab_size,
        
    }
)

0,1
classification_head.0.bias,▁▅▄▃▂▄▄▅▃▆▃▅▄▂▅▃▆▅▅▆▂▅█▆▅▂█▃▇▇█▆▆▄▅▇▂▅▄█
classification_head.0.weight,▆▁▅▄▆▁▂▄▃▃▃▁▃▄▅▃▅▅▅▅▄▅▆▆▂▆▃▁▄█▅▄▅▇▆▇▆▆▇▇
norm,▃▄▄▂▅▅▃▃▇▃▃▃▇▇▄▃▇▃▆▅▃▃▆▅▂▄▃▃▇█▂▁▂▅▄▅▂▆▇▅
train_acc,▆██▅▅▂▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▄▄
train_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
transformer.decoders.0,▃▂▄▂▂▂▅▁▃▃▆▄▂▂▂▃▃▄▃▂▃▃▆▂▄▂█▁▃▂▄▃▄▁▂▂▄▂▂▇
transformer.decoders.1,▂▄▃▃▅▅▂▄▄▆▆▄▃▄▆▄▁▂▃▂▃▄▆█▅▄▄▂▂▁▄▇▇▃▃▄▂▆▃▃
transformer.decoders.2,▃▃▄▆▄▂▅▃▂▃█▃▁▁▁▄▃▂▁▄▂▅▃▂▄▃▁▃▇▂▄▃▂▁▄▁▂▄▄█
transformer.decoders.3,▄▄▃▄▅▁▃▅▃▃▅▂▃▁▅▄▄▃▃▅█▄▆▂▂▅▄▅▃▂▅▄▆▂▄▃▄▆▅▄
transformer.decoders.4,▄█▄▃▅▂▂▄▄▂▅▅▆▅▃▄▅▃▁▄▅▅▅▆▂▄▅▄▅▂▅▃█▄▄▆▇▆▄▃

0,1
classification_head.0.bias,1e-05
classification_head.0.weight,0.00078
norm,1e-05
train_acc,0.00022
train_loss,12.60189
transformer.decoders.0,1e-05
transformer.decoders.1,1e-05
transformer.decoders.2,0.0
transformer.decoders.3,0.0
transformer.decoders.4,0.0


In [114]:
train_loss,val_loss,counter = train(model=model,
      model_name="transformer_form_scratch.pth",
      train_dataloader=train_dataloader,
      val_dataloader=val_dataloader,
      criterion=criterion,
      optimizer=optimizer,
      epochs=500,
      min_val_loss=1e-3,
    scheduler=scheduler,
      device=ModelArgs.device,
        clip_grad=True,
        wandb=wandb)

100%|██████████| 906/906 [03:46<00:00,  4.01it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.82it/s, loss=12.6019]


Epoch : 0/500 
 train loss : 12.60189 train acc : 0.00031
 val loss : 12.60189  val acc : 0.00030


100%|██████████| 906/906 [03:46<00:00,  4.00it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.75it/s, loss=12.6019]


Epoch : 1/500 
 train loss : 12.60189 train acc : 0.00031
 val loss : 12.60189  val acc : 0.00024


100%|██████████| 906/906 [03:45<00:00,  4.01it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.04it/s, loss=12.6019]


Epoch : 2/500 
 train loss : 12.60189 train acc : 0.00038
 val loss : 12.60189  val acc : 0.00048


100%|██████████| 906/906 [03:44<00:00,  4.04it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.08it/s, loss=12.6019]


Epoch : 4/500 
 train loss : 12.60189 train acc : 0.00090
 val loss : 12.60189  val acc : 0.00121


100%|██████████| 906/906 [03:43<00:00,  4.05it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.05it/s, loss=12.6019]


Epoch : 5/500 
 train loss : 12.60189 train acc : 0.00147
 val loss : 12.60189  val acc : 0.00240


100%|██████████| 906/906 [03:43<00:00,  4.05it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.09it/s, loss=12.6019]


Epoch : 6/500 
 train loss : 12.60189 train acc : 0.00240
 val loss : 12.60188  val acc : 0.00240


100%|██████████| 906/906 [03:44<00:00,  4.03it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.97it/s, loss=12.6019]


Epoch : 7/500 
 train loss : 12.60188 train acc : 0.00379
 val loss : 12.60188  val acc : 0.00494


100%|██████████| 906/906 [03:43<00:00,  4.05it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.05it/s, loss=12.6019]


Epoch : 8/500 
 train loss : 12.60188 train acc : 0.00530
 val loss : 12.60188  val acc : 0.00649


100%|██████████| 906/906 [03:43<00:00,  4.05it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.04it/s, loss=12.6019]


Epoch : 9/500 
 train loss : 12.60188 train acc : 0.00745
 val loss : 12.60188  val acc : 0.00883


100%|██████████| 906/906 [03:43<00:00,  4.06it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.81it/s, loss=12.6019]


Epoch : 10/500 
 train loss : 12.60188 train acc : 0.00995
 val loss : 12.60188  val acc : 0.01085


100%|██████████| 906/906 [03:44<00:00,  4.04it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.89it/s, loss=12.6019]


Epoch : 12/500 
 train loss : 12.60188 train acc : 0.01697
 val loss : 12.60188  val acc : 0.01974


100%|██████████| 906/906 [03:44<00:00,  4.04it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00, 10.00it/s, loss=12.6019]


Epoch : 14/500 
 train loss : 12.60188 train acc : 0.02551
 val loss : 12.60188  val acc : 0.02796


100%|██████████| 906/906 [03:45<00:00,  4.02it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.98it/s, loss=12.6019]


Epoch : 15/500 
 train loss : 12.60188 train acc : 0.02991
 val loss : 12.60188  val acc : 0.03593


100%|██████████| 906/906 [03:44<00:00,  4.03it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.92it/s, loss=12.6019]


Epoch : 16/500 
 train loss : 12.60188 train acc : 0.03514
 val loss : 12.60188  val acc : 0.03744


100%|██████████| 906/906 [03:44<00:00,  4.03it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.92it/s, loss=12.6019]


Epoch : 17/500 
 train loss : 12.60188 train acc : 0.04074
 val loss : 12.60187  val acc : 0.04319


100%|██████████| 906/906 [03:44<00:00,  4.03it/s, loss=12.6019]
100%|██████████| 31/31 [00:03<00:00,  9.95it/s, loss=12.6019]


Epoch : 18/500 
 train loss : 12.60187 train acc : 0.04675
 val loss : 12.60187  val acc : 0.05052


 52%|█████▏    | 471/906 [01:57<01:48,  4.01it/s, loss=12.6019]


In [115]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [49]:
torch.save(train_data,"train_data.pt")
torch.save(val_data,"val_data.pt")
torch.save(test_data,"test_data.pt")

In [1]:
model.parameters().__next___().dtype

NameError: name 'model' is not defined

In [51]:
torch.save(args,"model_args.pt")

In [52]:
import pickle
with open("en_vocab.pkl","wb") as f:
    pickle.dump(dict(en_vocab),f)

In [53]:
with open("de_voab.pkl","wb") as f:
    pickle.dump(dict(de_vocab),f)

In [55]:
args = {"en_vocab_size":ModelArgs.en_vocab_size,
       "de_vocab_size":ModelArgs.de_vocab_size}

with open("model_args.pkl","wb") as f:
    pickle.dump(args,f)