In [1]:
import json
import os
import pandas as pd
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenize,tokenize_to_id,detokenize_to_text
import gc
from tqdm import tqdm
import time
from collections import defaultdict
import numpy as np
import copy

In [2]:
TEXTPATH = "../../input/tinystories/TinyStoriesV2-GPT4-valid.txt"
TOKEN_ID_PATH = 'output/config/token_to_id_mapping.json'
ID_TOKEN_PATH = 'output/config/id_to_token_mapping.json'
SEQ_LEN = 32
BATCH_SIZE = 16
EPOCHS = 1
VOCAB_SIZE = 9772

EMBED_DIM = 300
TRANSFORMER_HEADS = 3
TRANSFORMER_LAYERS = 2
LR = 9e-5
DEVICE="cpu"

In [3]:
with open(TEXTPATH) as f:
    mainText = f.read()
# mainText = " ".join(mainText.split()[:10000])
train_index = int(0.9*len(mainText.split(" ")))
train = " ".join(mainText.split()[:train_index])
valid = " ".join(mainText.split()[train_index:])
print(len(train),len(valid))


19590883 2883876


In [4]:
with open(TOKEN_ID_PATH) as json_file:
    token_to_id_mapping = json.load(json_file)
    token_to_id_mapping = json.loads(token_to_id_mapping)
with open(ID_TOKEN_PATH) as json_file:
    id_to_token_mapping = json.load(json_file)
    id_to_token_mapping = json.loads(id_to_token_mapping)    
TOKEN_TO_ID_MAPPING = token_to_id_mapping
ID_TO_TOKEN_MAPPING = {int(t[0]):t[1] for t in id_to_token_mapping.items() }
# token_to_id_mapping

In [5]:
max(list(token_to_id_mapping.values())),max([int(t) for t in list(id_to_token_mapping.keys())])

(9771, 9771)

In [6]:
class TinyDataset(torch.utils.data.Dataset):
    def __init__(self,text):
        self.text = text
        
        
        self.tokens = tokenize(text)
    def __len__(self):
        return len(self.tokens) - SEQ_LEN
    def __getitem__(self,idx):
        x = self.tokens[idx:idx+SEQ_LEN]
        y = self.tokens[idx+1:idx+SEQ_LEN+1]
        x = tokenize_to_id(x,TOKEN_TO_ID_MAPPING)
        y = tokenize_to_id(y,TOKEN_TO_ID_MAPPING)
        x = torch.tensor(x,dtype=torch.long)
        y = torch.tensor(y,dtype=torch.long)
        return {'x':x,'y':y}
        

In [7]:
# https://habr.com/en/companies/ods/articles/708672/
class AttentionHead(nn.Module):
    """
    One head of the self-attention layer
    """

    def __init__(self, head_size, num_embed, block_size):
        super().__init__()
        """
        Initializes the AttentionHead module.

        Args:
            head_size (int): The size of each attention head.
            num_embed (int): The dimension of the input embeddings.
            block_size (int): The block size of the input sequence.
        """
        self.key = nn.Linear(num_embed, head_size, bias=False)
        self.query = nn.Linear(num_embed, head_size, bias=False)
        self.value = nn.Linear(num_embed, head_size, bias=False)
        # tril is a lower triangular matrix. it is not a parameter
        # of the model, so we assign it to the module using register_buffer
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """
        Performs the forward pass of the AttentionHead module.

        Args:
            x (Tensor): The input tensor of shape (B, T, C), where B is the batch size, T is the sequence length,
                and C is the dimension of the input embeddings.

        Returns:
            Tensor: The output tensor of shape (B, T, H).
        """
        B, T, C = x.shape
        print("att dims:" ,B,T,C)
        k = self.key(x)
        q = self.query(x)
        # compute attention scores
        # (B, T, C) @ (B, C, T) -> (B, T, T)
        # we need to transpose k to match q
        wei = q @ k.transpose(-2, -1) * C**-0.5
        # Tril matrix (lower triagular matrix) is used to mask 
        # future positions (setting them to -inf) so that the
        # decoder "learns" to predict next words
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B,T,T)
        wei = F.softmax(wei, dim=-1)  # (B,T,T)
        # weighted aggregation of the values
        v = self.value(x)
        out = wei @ v  # (B,T,T) @ (B,T,C) ---> (B,T,C)
        print(out.shape)
        return out
class MultiHeadAttention(nn.Module):
    """
    Multiple Heads of self-attention in parallel
    """

    def __init__(self, num_heads, head_size, num_embed, block_size):
        """
        Initializes the MultiHeadAttention module.

        Args:
            num_heads (int): The number of attention heads.
            head_size (int): The size of each attention head.
            num_embed (int): The dimension of the input embeddings.
            block_size (int): The block size of the input sequence.
        """
        super().__init__()
        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    head_size=head_size,
                    num_embed=num_embed,
                    block_size=block_size,
                )
                for _ in range(num_heads)
            ]
        )
        self.proj = nn.Linear(num_embed, num_embed)

    def forward(self, x):
        # output of the self-attention
        print([h(x).shape for h in self.heads])
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        print("mha cat att",out.shape)
        # apply the linear projection layer
        out = self.proj(out)
        print("mha proj",out.shape)
        return out    
class FeedForward(nn.Module):
    """
    A simple linear layer followed by ReLu
    """

    def __init__(self, num_embed):
        super().__init__()
        self.net = nn.Sequential(
            # in the Attention is All You Need paper
            # authors are using the size of the ffwd layer 2048
            # and the output of the model is 512
            # so we apply the same factor of 4
            nn.Linear(num_embed, 4 * num_embed),
            nn.ReLU(),
            # apply the linear projection layer
            nn.Linear(4 * num_embed, num_embed),
        )

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    """
    Groups together MultiHeadAttention and FeedForward modules
      to form a Transformer block.
    """

    def __init__(self, num_heads, block_size, num_embed):
        super().__init__()
        head_size = num_embed // num_heads
        self.sa = MultiHeadAttention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
            block_size=block_size,
        )
        self.ffwd = FeedForward(num_embed=num_embed)
        # add the layer normalization
        self.ln1 = nn.LayerNorm(num_embed)
        self.ln2 = nn.LayerNorm(num_embed)

    def forward(self, x):
        # "x +" is the skip (or residual) connection
        # it helps with optimization
        # also we apply layer normalization before self-attention
        # and feed-forward (a reshufle from original paper)
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x    
class Transformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # a simple lookup table that stores embeddings of a fixed dictionary and size
        # each token directly reads off the logits for the next token from a lookup table
        # see more: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        # each token reads the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        # each position from 0 to block_size-1 will get its embedding
        self.position_embedding_table = nn.Embedding(self.block_size, self.num_embed)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=self.num_heads,
                    block_size=self.block_size,
                    num_embed=self.num_embed,
                )
                for _ in range(self.num_layers)
            ]
        )
        # we add the layer norm before the Linear layer
        self.ln_f = nn.LayerNorm(self.num_embed)
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are (B,T) tensor of integers
        # the token_emb is (B, T, C), C = NUM_EMBED
        token_emb = self.token_embedding_table(idx)
        # (T, C)
        print("token_emb.shape",token_emb.shape)
        posit_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
        print("posit_emb.shape",posit_emb.shape)
        x = token_emb + posit_emb
        # apply one head of self-attention
        print("block input",x.shape)
        x = self.blocks(x)
        print("block output",x.shape)
        # (B, T, vocab_size)
        logits = self.lm_head(x)
        print("logits output",logits.shape)
        # compute the loss
        if targets != None:
            # cross_entropy accepts inputs in a (batch_size, num_classes)
            # so we need to reformat our logits dimensions to
            # (batch_size * time, dim_vocabulary), time = block_size
            B, T, C = logits.shape
            logits = torch.reshape(logits, (B * T, C))
            targets = torch.reshape(targets, (B * T,))
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: int):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop the context too the  last block_size tokens
            # because tokens don't communicate between blocks
            idx_crop = idx[:, -block_size:]
            # get the predictions
            logits, loss = self.forward(idx_crop)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution with probabilities probs
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx    

In [8]:
train_data = TinyDataset(train)
eval_data = TinyDataset(valid)
train_dataloader = torch.utils.data.DataLoader(train_data,\
                        batch_size=BATCH_SIZE,\
                        shuffle=True)
eval_dataloader = torch.utils.data.DataLoader(eval_data,\
                    batch_size=BATCH_SIZE,\
                    shuffle=False)

In [9]:
model = Transformer(vocab_size = VOCAB_SIZE,
                    num_embed = EMBED_DIM,
                    block_size = SEQ_LEN,
                    num_heads = TRANSFORMER_HEADS,
                    num_layers = TRANSFORMER_LAYERS)
model.to(DEVICE)
param_optimizer = model.parameters()
optimizer = torch.optim.AdamW(param_optimizer, lr=LR)
len(train_data),len(eval_data)

(9317481, 1369773)

In [10]:
def train_one_epoch(model, optimizer,  dataloader, device, epoch):
    model.train()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        x = data['x'].to(device)
        y = data['y'].to(device)
        
        batch_size = x.size(0)
        
        logits,loss = model.forward(x,y)
        
            
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
            
                
        running_loss += loss.item()
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    return epoch_loss  
def valid_one_epoch(model, optimizer,  dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    with torch.no_grad():
        for step, data in bar:
            x = data['x'].to(device)
            y = data['y'].to(device)
            
            batch_size = x.size(0)
            
            logits,loss = model.forward(x,y)
            
                    
            running_loss += loss.item()
            dataset_size += batch_size
            
            epoch_loss = running_loss / dataset_size
            
            bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                            LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    return epoch_loss  

In [11]:
start = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_epoch_loss = np.inf
history = defaultdict(list)
MODEL_DIR = 'output/modelwt'
os.makedirs(MODEL_DIR,exist_ok=True)
for epoch in range(1,EPOCHS+1):
    gc.collect()
    train_loss = train_one_epoch(model,optimizer,train_dataloader,DEVICE,epoch)

    val_loss = valid_one_epoch(model,optimizer,eval_dataloader,DEVICE,epoch)

    history["TrainLoss"].append(train_loss)
    history["ValLoss"].append(val_loss)
    print('EPOCH: ',epoch)
    print({"Train Loss": train_loss})
    print({"Valid Loss": val_loss})
    break
# break
    
    # if val_loss < best_epoch_loss:
    #     print(f"Validation Loss Improved ({best_epoch_loss} ---> {val_loss})")
        
    #     best_epoch_loss = val_loss
        
    #     best_model_wts = copy.deepcopy(model.state_dict())
        
    #     PATH ="model.pt"
    #     torch.save(model.state_dict(), os.path.join(MODEL_DIR,PATH))
    #     # Save a model file from the current directory
    #     print("Model Saved")
    # print("###Evaluation###")
    # evaluate_samples= ["He said goodbye ","Once upon a time","He was the best"]
    # for sample in evaluate_samples:
    #     tokens = tokenize(sample)
    #     x = tokenize_to_id(tokens,TOKEN_TO_ID_MAPPING)
    #     x = torch.tensor(x,dtype=torch.long).reshape(1,-1).to(DEVICE)
    #     gen_seq = model.generate(idx=x, max_new_tokens=32, block_size=SEQ_LEN)
    #     output = detokenize_to_text(list(gen_seq.cpu().detach().numpy()[0]), ID_TO_TOKEN_MAPPING)
    #     print(output)
    #     print("\n")
end = time.time()
time_elapsed = end - start
print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
    time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
print("Best Loss: {:.4f}".format(best_epoch_loss))
PATH ="lst_model.pt"
# torch.save(model.state_dict(), os.path.join(MODEL_DIR,PATH))

  0%|          | 0/582343 [00:00<?, ?it/s]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 1/582343 [00:01<167:40:06,  1.04s/it, Epoch=1, LR=9e-5, Train_Loss=0.594]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 3/582343 [00:01<63:16:38,  2.56it/s, Epoch=1, LR=9e-5, Train_Loss=0.579] 

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 4/582343 [00:01<50:06:27,  3.23it/s, Epoch=1, LR=9e-5, Train_Loss=0.571]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att di

  0%|          | 5/582343 [00:01<42:44:31,  3.78it/s, Epoch=1, LR=9e-5, Train_Loss=0.562]

block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att t

  0%|          | 6/582343 [00:02<38:43:35,  4.18it/s, Epoch=1, LR=9e-5, Train_Loss=0.554]

logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch

  0%|          | 8/582343 [00:02<35:26:29,  4.56it/s, Epoch=1, LR=9e-5, Train_Loss=0.537]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 9/582343 [00:02<34:31:34,  4.69it/s, Epoch=1, LR=9e-5, Train_Loss=0.529]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 10/582343 [00:02<33:54:10,  4.77it/s, Epoch=1, LR=9e-5, Train_Loss=0.52]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 11/582343 [00:03<32:48:22,  4.93it/s, Epoch=1, LR=9e-5, Train_Loss=0.511]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 12/582343 [00:03<33:05:25,  4.89it/s, Epoch=1, LR=9e-5, Train_Loss=0.502]

att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300


  0%|          | 13/582343 [00:03<33:03:10,  4.89it/s, Epoch=1, LR=9e-5, Train_Loss=0.493]

torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att di

  0%|          | 14/582343 [00:03<32:17:37,  5.01it/s, Epoch=1, LR=9e-5, Train_Loss=0.485]

att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att t

  0%|          | 15/582343 [00:03<32:17:27,  5.01it/s, Epoch=1, LR=9e-5, Train_Loss=0.477]

mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torc

  0%|          | 16/582343 [00:03<31:47:31,  5.09it/s, Epoch=1, LR=9e-5, Train_Loss=0.47] 

block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att t

  0%|          | 17/582343 [00:04<31:14:14,  5.18it/s, Epoch=1, LR=9e-5, Train_Loss=0.463]

logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch

  0%|          | 19/582343 [00:04<31:40:35,  5.11it/s, Epoch=1, LR=9e-5, Train_Loss=0.452]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 20/582343 [00:04<32:05:27,  5.04it/s, Epoch=1, LR=9e-5, Train_Loss=0.447]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])


  0%|          | 21/582343 [00:04<32:05:47,  5.04it/s, Epoch=1, LR=9e-5, Train_Loss=0.442]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 22/582343 [00:05<32:06:48,  5.04it/s, Epoch=1, LR=9e-5, Train_Loss=0.437]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att di

  0%|          | 23/582343 [00:05<31:56:05,  5.07it/s, Epoch=1, LR=9e-5, Train_Loss=0.432]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16

  0%|          | 24/582343 [00:05<31:50:54,  5.08it/s, Epoch=1, LR=9e-5, Train_Loss=0.427]

att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att t

  0%|          | 25/582343 [00:05<31:34:58,  5.12it/s, Epoch=1, LR=9e-5, Train_Loss=0.422]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 3

  0%|          | 26/582343 [00:05<31:39:59,  5.11it/s, Epoch=1, LR=9e-5, Train_Loss=0.417]

block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att t

  0%|          | 27/582343 [00:06<31:56:15,  5.06it/s, Epoch=1, LR=9e-5, Train_Loss=0.413]

logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch

  0%|          | 29/582343 [00:06<32:29:17,  4.98it/s, Epoch=1, LR=9e-5, Train_Loss=0.405]

token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.S

  0%|          | 30/582343 [00:06<32:23:14,  4.99it/s, Epoch=1, LR=9e-5, Train_Loss=0.402]

torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16, 32, 100]), torch.Size([16, 32, 100])]
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
mha cat att torch.Size([16, 32, 300])
mha proj torch.Size([16, 32, 300])
block output torch.Size([16, 32, 300])
logits output torch.Size([16, 32, 9772])
token_emb.shape torch.Size([16, 32, 300])
posit_emb.shape torch.Size([32, 300])
block input torch.Size([16, 32, 300])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
att dims: 16 32 300
torch.Size([16, 32, 100])
[torch.Size([16, 32, 100]), torch.Size([16

In [None]:
evaluate_samples= ["He said goodbye ","Once upon a time","He was the best", "in a land far far away"]

for sample in evaluate_samples:
    tokens = tokenize(sample)
    x = tokenize_to_id(tokens,TOKEN_TO_ID_MAPPING)
    x = torch.tensor(x,dtype=torch.long).reshape(1,-1).to(DEVICE)
    gen_seq = model.generate(idx=x, max_new_tokens=64, block_size=SEQ_LEN)
    output = detokenize_to_text(list(gen_seq.cpu().detach().numpy()[0]), ID_TO_TOKEN_MAPPING)
    print(output)
    print("\n")

he said goodbye - a reward. momo felt happy and excited. finally, the little bug found it stuck and he slipped and fell. ollie was a quick, but he still


once upon a time there was a little girl who was playing in the mud. she was only three years old and she had a pet before. they wanted to go to another store


he was the best grape had ever made! meadow was so excited. but when she arrived at the park, the boy found a big pile of soft stones. he saw a big


in a land far far away. <|endoftext|> sally was playing and bob when he saw a new home. they wanted to play a trick on tv like wet. tom and sue decided to organize a


