## MiniGPT --> Word level language model trained on tinyshakespeare dataset

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

'wget' is not recognized as an internal or external command,
operable program or batch file.


In [2]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [4]:
# let's look at the first 1000 characters
print(text[:102])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You a


In [5]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [6]:
import torch
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from typing import List
from torch.utils.data import DataLoader, TensorDataset, random_split

def tokenize_sentences(sentences: List[str], vocab_size: int = 100000):
    """Tokenize sentences and create a vocabulary."""
    tokenized_sentences = [sentence.lower().split() for sentence in sentences]
    all_words = [word for sentence in tokenized_sentences for word in sentence]
    
    # Build a vocabulary of the most common words
    word_counts = Counter(all_words)
    vocab = {word: i+2 for i, (word, _) in enumerate(word_counts.most_common(vocab_size))}
    inv_vocab={i+2: word for i, (word, _) in enumerate(word_counts.most_common(vocab_size))}
    vocab['<pad>'] = 0
    vocab['<unk>'] = 1
    inv_vocab[0]='<pad>'
    inv_vocab[1]='<unk>'
    
    return tokenized_sentences, vocab ,inv_vocab

def sentence_to_tensor(sentence: List[str], vocab: dict, sequence_length: int = 10):
    """Convert a sentence to a tensor of fixed length."""
    indices = [vocab.get(word, vocab['<unk>']) for word in sentence[:sequence_length]]
    return torch.tensor(indices, dtype=torch.long)


def prepare_dataset(sentences: List[str], context_length: int, max_length: int, batch_size: int):
    """Prepare the dataset and split into train, validation, and test sets."""
    tokenized_sentences, vocab,inv_vocab = tokenize_sentences(sentences)
    input_tensors = []
    target_tensors = []

    for sentence in tokenized_sentences:
        if len(sentence)<=1:
            continue
        sentence = sentence[:max_length] + ['<pad>'] * max(0, max_length - len(sentence))
        for i in range(len(sentence) - context_length):
            input_seq = sentence[i:i+context_length]
            target_seq = sentence[i+1:i+1+context_length]
            input_tensors.append(sentence_to_tensor(input_seq, vocab, context_length))
            target_tensors.append(sentence_to_tensor(target_seq, vocab, context_length))

    input_tensor = torch.stack(input_tensors)
    target_tensor = torch.stack(target_tensors)
    # Create TensorDataset
    dataset = TensorDataset(input_tensor, target_tensor)
    
    # Split dataset into train, validation, and test sets
    train_size = int(0.7 * len(dataset))
    valid_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - valid_size
    train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

    # Create DataLoader for each set
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    #return train_loader, valid_loader, test_loader, vocab
    return train_loader,valid_loader,test_loader,vocab,inv_vocab

In [7]:
sentences=text.split('\n')
#input,target,vocab=prepare_dataset(sentences,context_length=15,max_length=50)

context_length=5
max_length=10
batch_size=3
train_loader, valid_loader, test_loader, vocab ,inv_vocab= prepare_dataset(sentences, context_length, max_length, batch_size)

In [8]:
print(f" train_loader length :{train_loader.__len__()}")

#print(f" target.shape :{target.shape}")

 train_loader length :31858


In [9]:
print(f"Length of sentences :{len(sentences)}")
vocab_size=len(vocab)
print(f"vocab size: {vocab_size}")

Length of sentences :40001
vocab size: 23643


In [22]:
# hyperparameters
maxlen=50
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 15 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 512
n_head = 4
n_layer = 4
dropout = 0.0

In [26]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class miniGPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        #print(idx.shape)
        B,T = idx.shape
        #print(f"B : {B} C:{T}")
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)        
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [30]:
from tqdm import tqdm
num_epochs=10
model = miniGPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

n_embd = 512
n_head = 8
n_layer = 6
dropout = 0.1
learning_rate=4e-5
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
eval_interval=1000
print(f"learning_rate:{learning_rate} num_head: {n_head} number_of_embeddings: {n_embd} number_of_layers: {n_layer}")
for epoch in range(num_epochs):
    model.train()
    total_loss=0
    for iter, (xb, yb) in tqdm(enumerate(train_loader)):
        xb, yb = xb.to(device), yb.to(device)
        logits, loss = model(xb, yb)
        total_loss+=loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if iter % eval_interval == 0:
            print(f"Epoch {epoch}, Step {iter}: train loss {loss:.4f}")
    avg_loss=total_loss/len(train_loader)
    print(f"==========Epoch {epoch}: Average training loss: {avg_loss:.4f}=============")
    
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for xv, yv in valid_loader:
            xv, yv = xv.to(device), yv.to(device)
            val_logits, val_loss = model(xv, yv)
            total_val_loss += val_loss.item()
    
    avg_val_loss = total_val_loss / len(valid_loader)
    print(f"Epoch {epoch}: Average validation loss: {avg_val_loss:.4f}")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pth.tar")



36.846171 M parameters
learning_rate:4e-05 num_head: 8 number_of_embeddings: 512 number_of_layers: 6


4it [00:00, 13.69it/s]

Epoch 0, Step 0: train loss 10.2910


1004it [00:47, 20.70it/s]

Epoch 0, Step 1000: train loss 6.8119


2003it [01:35, 22.02it/s]

Epoch 0, Step 2000: train loss 5.8729


3004it [02:22, 21.34it/s]

Epoch 0, Step 3000: train loss 7.3609


4004it [03:10, 19.21it/s]

Epoch 0, Step 4000: train loss 5.5416


5005it [03:57, 22.42it/s]

Epoch 0, Step 5000: train loss 3.4412


6005it [04:45, 21.67it/s]

Epoch 0, Step 6000: train loss 4.7978


7004it [05:33, 19.76it/s]

Epoch 0, Step 7000: train loss 4.2346


8005it [06:21, 21.54it/s]

Epoch 0, Step 8000: train loss 5.3159


9003it [07:09, 20.58it/s]

Epoch 0, Step 9000: train loss 6.1314


10003it [07:56, 19.61it/s]

Epoch 0, Step 10000: train loss 3.1942


11004it [08:43, 21.06it/s]

Epoch 0, Step 11000: train loss 4.8625


12004it [09:31, 22.10it/s]

Epoch 0, Step 12000: train loss 3.0743


13004it [10:18, 20.59it/s]

Epoch 0, Step 13000: train loss 5.1094


14002it [11:05, 19.99it/s]

Epoch 0, Step 14000: train loss 6.1813


15004it [11:52, 21.75it/s]

Epoch 0, Step 15000: train loss 6.1395


16003it [12:39, 19.32it/s]

Epoch 0, Step 16000: train loss 3.9581


17005it [13:25, 23.10it/s]

Epoch 0, Step 17000: train loss 6.5786


18005it [14:13, 21.71it/s]

Epoch 0, Step 18000: train loss 4.6969


19003it [15:00, 19.77it/s]

Epoch 0, Step 19000: train loss 1.0075


20003it [15:47, 19.45it/s]

Epoch 0, Step 20000: train loss 7.9932


21003it [16:34, 20.49it/s]

Epoch 0, Step 21000: train loss 4.7274


22005it [17:21, 23.02it/s]

Epoch 0, Step 22000: train loss 6.6975


23004it [18:08, 20.34it/s]

Epoch 0, Step 23000: train loss 6.0108


24004it [18:55, 21.14it/s]

Epoch 0, Step 24000: train loss 5.9495


25004it [19:44, 19.90it/s]

Epoch 0, Step 25000: train loss 5.1990


26004it [20:31, 21.94it/s]

Epoch 0, Step 26000: train loss 5.3124


27003it [21:17, 23.30it/s]

Epoch 0, Step 27000: train loss 4.6932


28005it [22:04, 21.81it/s]

Epoch 0, Step 28000: train loss 3.7212


29003it [22:52, 20.07it/s]

Epoch 0, Step 29000: train loss 4.7987


30004it [23:39, 21.44it/s]

Epoch 0, Step 30000: train loss 4.7401


31003it [24:25, 22.51it/s]

Epoch 0, Step 31000: train loss 2.1714


31858it [25:05, 21.16it/s]


Epoch 0: Average validation loss: 4.8220


3it [00:00, 22.74it/s]

Epoch 1, Step 0: train loss 5.8437


1002it [00:47, 20.77it/s]

Epoch 1, Step 1000: train loss 4.6232


2003it [01:35, 21.52it/s]

Epoch 1, Step 2000: train loss 3.3390


3006it [02:22, 22.98it/s]

Epoch 1, Step 3000: train loss 4.3802


4005it [03:09, 22.30it/s]

Epoch 1, Step 4000: train loss 4.0230


5003it [03:56, 21.89it/s]

Epoch 1, Step 5000: train loss 4.6700


6005it [04:43, 22.53it/s]

Epoch 1, Step 6000: train loss 4.4579


7004it [05:31, 22.11it/s]

Epoch 1, Step 7000: train loss 3.6362


8003it [06:18, 22.17it/s]

Epoch 1, Step 8000: train loss 3.8637


9003it [07:04, 21.32it/s]

Epoch 1, Step 9000: train loss 5.3522


10003it [07:51, 21.23it/s]

Epoch 1, Step 10000: train loss 5.0915


11005it [08:38, 22.76it/s]

Epoch 1, Step 11000: train loss 3.9541


12005it [09:26, 21.89it/s]

Epoch 1, Step 12000: train loss 3.4750


13003it [10:13, 22.32it/s]

Epoch 1, Step 13000: train loss 5.3489


14004it [11:00, 20.78it/s]

Epoch 1, Step 14000: train loss 4.9084


15003it [11:47, 21.56it/s]

Epoch 1, Step 15000: train loss 3.4323


16003it [12:35, 21.41it/s]

Epoch 1, Step 16000: train loss 5.1568


17003it [13:22, 22.69it/s]

Epoch 1, Step 17000: train loss 3.6775


18005it [14:09, 22.51it/s]

Epoch 1, Step 18000: train loss 4.8517


19004it [14:56, 19.04it/s]

Epoch 1, Step 19000: train loss 4.9108


20005it [15:43, 23.33it/s]

Epoch 1, Step 20000: train loss 5.7224


21004it [16:31, 22.32it/s]

Epoch 1, Step 21000: train loss 4.2206


22005it [17:18, 20.79it/s]

Epoch 1, Step 22000: train loss 5.6688


23003it [18:06, 21.97it/s]

Epoch 1, Step 23000: train loss 5.6394


24004it [18:53, 21.08it/s]

Epoch 1, Step 24000: train loss 3.6868


25005it [19:39, 22.31it/s]

Epoch 1, Step 25000: train loss 4.2295


26003it [20:26, 23.13it/s]

Epoch 1, Step 26000: train loss 3.7577


27005it [21:12, 21.02it/s]

Epoch 1, Step 27000: train loss 4.3567


28003it [21:58, 22.46it/s]

Epoch 1, Step 28000: train loss 4.7975


29005it [22:44, 22.16it/s]

Epoch 1, Step 29000: train loss 3.8523


30004it [23:30, 23.51it/s]

Epoch 1, Step 30000: train loss 3.7472


31004it [24:18, 20.84it/s]

Epoch 1, Step 31000: train loss 6.0246


31858it [24:59, 21.24it/s]


Epoch 1: Average validation loss: 4.4466


6it [00:00, 23.13it/s]

Epoch 2, Step 0: train loss 4.4685


1003it [00:46, 19.76it/s]

Epoch 2, Step 1000: train loss 1.9493


2004it [01:35, 19.87it/s]

Epoch 2, Step 2000: train loss 3.8532


3005it [02:22, 22.92it/s]

Epoch 2, Step 3000: train loss 5.7381


4002it [03:09, 22.01it/s]

Epoch 2, Step 4000: train loss 3.1054


5002it [03:56, 20.36it/s]

Epoch 2, Step 5000: train loss 5.1029


6004it [04:42, 22.36it/s]

Epoch 2, Step 6000: train loss 3.3400


7002it [05:29, 23.32it/s]

Epoch 2, Step 7000: train loss 4.2323


8004it [06:16, 21.25it/s]

Epoch 2, Step 8000: train loss 3.0072


9005it [07:03, 21.71it/s]

Epoch 2, Step 9000: train loss 4.9518


10005it [07:50, 22.99it/s]

Epoch 2, Step 10000: train loss 3.5184


11003it [08:38, 22.01it/s]

Epoch 2, Step 11000: train loss 3.4948


12004it [09:25, 21.15it/s]

Epoch 2, Step 12000: train loss 3.8980


13003it [10:13, 23.20it/s]

Epoch 2, Step 13000: train loss 4.3701


14005it [10:59, 21.96it/s]

Epoch 2, Step 14000: train loss 2.7115


15003it [11:45, 22.26it/s]

Epoch 2, Step 15000: train loss 2.7792


16004it [12:31, 22.38it/s]

Epoch 2, Step 16000: train loss 3.1593


17004it [13:17, 22.77it/s]

Epoch 2, Step 17000: train loss 1.7112


18003it [14:03, 20.93it/s]

Epoch 2, Step 18000: train loss 4.3856


19002it [14:50, 22.80it/s]

Epoch 2, Step 19000: train loss 2.8499


20005it [15:36, 23.42it/s]

Epoch 2, Step 20000: train loss 5.5107


21004it [16:22, 20.17it/s]

Epoch 2, Step 21000: train loss 4.0573


22003it [17:08, 21.07it/s]

Epoch 2, Step 22000: train loss 2.4762


23005it [17:54, 22.94it/s]

Epoch 2, Step 23000: train loss 5.1034


24004it [18:40, 22.55it/s]

Epoch 2, Step 24000: train loss 3.2621


25003it [19:27, 21.24it/s]

Epoch 2, Step 25000: train loss 4.4644


26004it [20:14, 21.93it/s]

Epoch 2, Step 26000: train loss 2.4353


27003it [21:01, 22.15it/s]

Epoch 2, Step 27000: train loss 4.9836


28005it [21:49, 22.61it/s]

Epoch 2, Step 28000: train loss 3.8126


29003it [22:36, 22.03it/s]

Epoch 2, Step 29000: train loss 5.3431


30005it [23:25, 21.81it/s]

Epoch 2, Step 30000: train loss 4.7577


31002it [24:13, 21.92it/s]

Epoch 2, Step 31000: train loss 4.9582


31858it [24:53, 21.33it/s]


Epoch 2: Average validation loss: 4.1636


5it [00:00, 20.73it/s]

Epoch 3, Step 0: train loss 2.7884


1003it [00:47, 20.42it/s]

Epoch 3, Step 1000: train loss 3.6525


2004it [01:34, 20.58it/s]

Epoch 3, Step 2000: train loss 4.5753


3004it [02:22, 20.35it/s]

Epoch 3, Step 3000: train loss 1.7767


4004it [03:11, 21.14it/s]

Epoch 3, Step 4000: train loss 2.0244


5005it [03:59, 22.23it/s]

Epoch 3, Step 5000: train loss 2.1214


6002it [04:47, 19.67it/s]

Epoch 3, Step 6000: train loss 3.4030


7004it [05:36, 20.41it/s]

Epoch 3, Step 7000: train loss 4.4977


8005it [06:24, 22.59it/s]

Epoch 3, Step 8000: train loss 6.1813


9003it [07:12, 20.85it/s]

Epoch 3, Step 9000: train loss 3.1783


10004it [07:59, 21.37it/s]

Epoch 3, Step 10000: train loss 3.9595


11002it [08:47, 19.63it/s]

Epoch 3, Step 11000: train loss 3.0762


12005it [09:36, 19.61it/s]

Epoch 3, Step 12000: train loss 3.5813


13003it [10:24, 22.44it/s]

Epoch 3, Step 13000: train loss 4.9545


14002it [11:13, 20.35it/s]

Epoch 3, Step 14000: train loss 2.6737


15003it [12:02, 21.66it/s]

Epoch 3, Step 15000: train loss 3.1092


16002it [12:51, 20.06it/s]

Epoch 3, Step 16000: train loss 3.3840


17004it [13:40, 22.62it/s]

Epoch 3, Step 17000: train loss 3.0379


18003it [14:29, 19.08it/s]

Epoch 3, Step 18000: train loss 0.0226


19003it [15:17, 20.64it/s]

Epoch 3, Step 19000: train loss 4.3848


20002it [16:05, 21.11it/s]

Epoch 3, Step 20000: train loss 1.8716


21003it [16:54, 22.74it/s]

Epoch 3, Step 21000: train loss 3.7701


22005it [17:42, 23.75it/s]

Epoch 3, Step 22000: train loss 3.5619


23003it [18:30, 20.94it/s]

Epoch 3, Step 23000: train loss 4.5402


24004it [19:18, 22.56it/s]

Epoch 3, Step 24000: train loss 4.0090


25004it [20:05, 22.77it/s]

Epoch 3, Step 25000: train loss 5.6327


26003it [20:53, 19.94it/s]

Epoch 3, Step 26000: train loss 2.8920


27005it [21:40, 19.74it/s]

Epoch 3, Step 27000: train loss 3.0405


28004it [22:28, 19.60it/s]

Epoch 3, Step 28000: train loss 4.4121


29004it [23:16, 19.80it/s]

Epoch 3, Step 29000: train loss 2.3838


30005it [24:03, 21.47it/s]

Epoch 3, Step 30000: train loss 3.8287


31004it [24:51, 20.99it/s]

Epoch 3, Step 31000: train loss 2.5565


31858it [25:32, 20.79it/s]


Epoch 3: Average validation loss: 3.8912


3it [00:00, 19.15it/s]

Epoch 4, Step 0: train loss 4.1716


1003it [00:47, 22.26it/s]

Epoch 4, Step 1000: train loss 0.5449


2003it [01:34, 22.37it/s]

Epoch 4, Step 2000: train loss 4.3051


3005it [02:22, 22.09it/s]

Epoch 4, Step 3000: train loss 1.6435


4005it [03:10, 21.67it/s]

Epoch 4, Step 4000: train loss 2.8632


5004it [03:58, 21.11it/s]

Epoch 4, Step 5000: train loss 3.8728


6004it [04:45, 21.32it/s]

Epoch 4, Step 6000: train loss 3.8528


7003it [05:32, 21.52it/s]

Epoch 4, Step 7000: train loss 2.8450


8006it [06:20, 20.96it/s]

Epoch 4, Step 8000: train loss 3.2799


9004it [07:07, 21.30it/s]

Epoch 4, Step 9000: train loss 3.0621


10005it [07:55, 21.95it/s]

Epoch 4, Step 10000: train loss 2.2351


11004it [08:43, 18.19it/s]

Epoch 4, Step 11000: train loss 2.4023


12003it [09:30, 22.05it/s]

Epoch 4, Step 12000: train loss 3.8411


13004it [10:17, 22.78it/s]

Epoch 4, Step 13000: train loss 4.0978


14005it [11:03, 22.80it/s]

Epoch 4, Step 14000: train loss 3.6562


15004it [11:50, 21.47it/s]

Epoch 4, Step 15000: train loss 3.5561


16003it [12:38, 18.38it/s]

Epoch 4, Step 16000: train loss 3.1934


17004it [13:26, 18.68it/s]

Epoch 4, Step 17000: train loss 2.6871


18004it [14:13, 22.05it/s]

Epoch 4, Step 18000: train loss 2.8129


19003it [15:01, 20.70it/s]

Epoch 4, Step 19000: train loss 3.2182


20004it [15:48, 19.71it/s]

Epoch 4, Step 20000: train loss 3.3022


21004it [16:36, 19.40it/s]

Epoch 4, Step 21000: train loss 3.0993


22003it [17:23, 21.42it/s]

Epoch 4, Step 22000: train loss 5.0465


23002it [18:11, 22.11it/s]

Epoch 4, Step 23000: train loss 2.2966


24003it [18:58, 22.05it/s]

Epoch 4, Step 24000: train loss 3.8719


25004it [19:47, 19.44it/s]

Epoch 4, Step 25000: train loss 2.9129


26005it [20:34, 21.02it/s]

Epoch 4, Step 26000: train loss 2.4584


27003it [21:21, 21.56it/s]

Epoch 4, Step 27000: train loss 2.2191


28003it [22:09, 21.12it/s]

Epoch 4, Step 28000: train loss 2.7575


29005it [22:56, 21.90it/s]

Epoch 4, Step 29000: train loss 2.2636


30004it [23:44, 20.99it/s]

Epoch 4, Step 30000: train loss 4.2812


31002it [24:33, 19.87it/s]

Epoch 4, Step 31000: train loss 3.5397


31858it [25:13, 21.05it/s]


Epoch 4: Average validation loss: 3.6509


3it [00:00, 21.33it/s]

Epoch 5, Step 0: train loss 2.7035


1005it [00:48, 21.41it/s]

Epoch 5, Step 1000: train loss 2.0843


2003it [01:34, 22.61it/s]

Epoch 5, Step 2000: train loss 2.3808


3004it [02:21, 21.70it/s]

Epoch 5, Step 3000: train loss 1.6234


4003it [03:07, 22.84it/s]

Epoch 5, Step 4000: train loss 3.5504


5005it [03:53, 22.60it/s]

Epoch 5, Step 5000: train loss 2.4630


6005it [04:39, 22.81it/s]

Epoch 5, Step 6000: train loss 1.2441


7003it [05:26, 22.43it/s]

Epoch 5, Step 7000: train loss 3.0137


8003it [06:12, 23.80it/s]

Epoch 5, Step 8000: train loss 1.7746


9005it [06:58, 21.40it/s]

Epoch 5, Step 9000: train loss 3.1410


10003it [07:44, 23.82it/s]

Epoch 5, Step 10000: train loss 3.4950


11002it [08:30, 20.58it/s]

Epoch 5, Step 11000: train loss 3.0709


12003it [09:16, 22.76it/s]

Epoch 5, Step 12000: train loss 3.9415


13005it [10:02, 22.16it/s]

Epoch 5, Step 13000: train loss 2.7819


14003it [10:48, 19.36it/s]

Epoch 5, Step 14000: train loss 1.9115


15005it [11:35, 20.91it/s]

Epoch 5, Step 15000: train loss 1.8702


16004it [12:22, 21.68it/s]

Epoch 5, Step 16000: train loss 0.0000


17003it [13:08, 22.35it/s]

Epoch 5, Step 17000: train loss 4.5237


18004it [13:56, 19.34it/s]

Epoch 5, Step 18000: train loss 4.0051


19003it [14:44, 20.17it/s]

Epoch 5, Step 19000: train loss 3.1783


20003it [15:31, 21.61it/s]

Epoch 5, Step 20000: train loss 0.3745


21003it [16:18, 22.06it/s]

Epoch 5, Step 21000: train loss 2.1835


22002it [17:05, 22.27it/s]

Epoch 5, Step 22000: train loss 3.9661


23004it [17:53, 23.00it/s]

Epoch 5, Step 23000: train loss 2.2549


24003it [18:40, 20.39it/s]

Epoch 5, Step 24000: train loss 3.4060


25003it [19:27, 21.45it/s]

Epoch 5, Step 25000: train loss 3.3393


26003it [20:14, 20.44it/s]

Epoch 5, Step 26000: train loss 3.5486


27006it [21:02, 20.90it/s]

Epoch 5, Step 27000: train loss 1.7173


28003it [21:49, 22.07it/s]

Epoch 5, Step 28000: train loss 2.8431


29005it [22:36, 22.83it/s]

Epoch 5, Step 29000: train loss 0.8788


30005it [23:23, 22.31it/s]

Epoch 5, Step 30000: train loss 3.2101


31002it [24:10, 22.13it/s]

Epoch 5, Step 31000: train loss 1.6157


31858it [24:50, 21.38it/s]


Epoch 5: Average validation loss: 3.4689


3it [00:00, 23.53it/s]

Epoch 6, Step 0: train loss 2.0097


1003it [00:46, 22.06it/s]

Epoch 6, Step 1000: train loss 1.7322


2003it [01:33, 22.58it/s]

Epoch 6, Step 2000: train loss 3.0186


3005it [02:21, 21.11it/s]

Epoch 6, Step 3000: train loss 2.7437


4006it [03:09, 23.39it/s]

Epoch 6, Step 4000: train loss 1.9773


5005it [03:55, 23.82it/s]

Epoch 6, Step 5000: train loss 2.8421


6004it [04:42, 21.93it/s]

Epoch 6, Step 6000: train loss 2.1961


7005it [05:29, 19.01it/s]

Epoch 6, Step 7000: train loss 3.1622


8004it [06:16, 20.73it/s]

Epoch 6, Step 8000: train loss 2.0444


9005it [07:03, 22.17it/s]

Epoch 6, Step 9000: train loss 3.5451


10003it [07:50, 22.76it/s]

Epoch 6, Step 10000: train loss 2.8807


11004it [08:37, 22.36it/s]

Epoch 6, Step 11000: train loss 2.1773


12005it [09:24, 20.69it/s]

Epoch 6, Step 12000: train loss 2.1286


13005it [10:11, 22.18it/s]

Epoch 6, Step 13000: train loss 3.2382


14004it [10:58, 22.61it/s]

Epoch 6, Step 14000: train loss 3.0481


15004it [11:46, 23.00it/s]

Epoch 6, Step 15000: train loss 1.6696


16003it [12:33, 23.43it/s]

Epoch 6, Step 16000: train loss 2.9568


17003it [13:20, 19.46it/s]

Epoch 6, Step 17000: train loss 2.8666


18004it [14:07, 19.12it/s]

Epoch 6, Step 18000: train loss 3.0819


19005it [14:55, 21.32it/s]

Epoch 6, Step 19000: train loss 3.6096


20003it [15:43, 21.66it/s]

Epoch 6, Step 20000: train loss 2.4326


21002it [16:30, 20.87it/s]

Epoch 6, Step 21000: train loss 1.6621


22005it [17:18, 22.37it/s]

Epoch 6, Step 22000: train loss 3.8232


23004it [18:05, 21.41it/s]

Epoch 6, Step 23000: train loss 3.9826


24004it [18:52, 19.45it/s]

Epoch 6, Step 24000: train loss 4.5553


25004it [19:39, 19.71it/s]

Epoch 6, Step 25000: train loss 4.0141


26005it [20:26, 23.13it/s]

Epoch 6, Step 26000: train loss 2.9665


27003it [21:12, 22.77it/s]

Epoch 6, Step 27000: train loss 2.0940


28003it [21:59, 21.28it/s]

Epoch 6, Step 28000: train loss 2.4988


29002it [22:46, 19.98it/s]

Epoch 6, Step 29000: train loss 3.2394


30004it [23:33, 20.83it/s]

Epoch 6, Step 30000: train loss 2.7844


31004it [24:20, 22.03it/s]

Epoch 6, Step 31000: train loss 2.6963


31858it [25:00, 21.24it/s]


Epoch 6: Average validation loss: 3.3594


4it [00:00, 11.16it/s]

Epoch 7, Step 0: train loss 2.1367


1004it [00:45, 20.85it/s]

Epoch 7, Step 1000: train loss 2.7909


2005it [01:32, 23.32it/s]

Epoch 7, Step 2000: train loss 1.9859


3002it [02:19, 23.13it/s]

Epoch 7, Step 3000: train loss 2.4767


4004it [03:06, 20.54it/s]

Epoch 7, Step 4000: train loss 1.0024


5005it [03:53, 21.27it/s]

Epoch 7, Step 5000: train loss 2.0491


6005it [04:41, 21.54it/s]

Epoch 7, Step 6000: train loss 2.0766


7005it [05:28, 21.45it/s]

Epoch 7, Step 7000: train loss 2.9436


8005it [06:15, 21.88it/s]

Epoch 7, Step 8000: train loss 2.2667


9006it [07:02, 22.10it/s]

Epoch 7, Step 9000: train loss 2.6662


10004it [07:48, 23.53it/s]

Epoch 7, Step 10000: train loss 2.9505


11005it [08:35, 21.18it/s]

Epoch 7, Step 11000: train loss 1.2905


12005it [09:22, 22.55it/s]

Epoch 7, Step 12000: train loss 2.2389


13005it [10:09, 22.57it/s]

Epoch 7, Step 13000: train loss 2.3344


14005it [10:56, 20.44it/s]

Epoch 7, Step 14000: train loss 3.4372


15004it [11:44, 19.56it/s]

Epoch 7, Step 15000: train loss 4.1397


16005it [12:32, 21.56it/s]

Epoch 7, Step 16000: train loss 1.1829


17004it [13:19, 22.80it/s]

Epoch 7, Step 17000: train loss 1.4616


18001it [14:07, 19.71it/s]

Epoch 7, Step 18000: train loss 3.3277


19004it [14:56, 21.12it/s]

Epoch 7, Step 19000: train loss 1.4033


20003it [15:46, 20.64it/s]

Epoch 7, Step 20000: train loss 2.5870


21003it [16:34, 21.64it/s]

Epoch 7, Step 21000: train loss 2.5717


22004it [17:22, 21.75it/s]

Epoch 7, Step 22000: train loss 2.4704


23006it [18:09, 22.22it/s]

Epoch 7, Step 23000: train loss 2.2389


24003it [18:57, 22.20it/s]

Epoch 7, Step 24000: train loss 1.1911


25004it [19:44, 21.80it/s]

Epoch 7, Step 25000: train loss 2.8738


26002it [20:32, 22.31it/s]

Epoch 7, Step 26000: train loss 2.5996


27004it [21:19, 19.50it/s]

Epoch 7, Step 27000: train loss 2.8606


28004it [22:07, 23.14it/s]

Epoch 7, Step 28000: train loss 1.3672


29004it [22:55, 21.91it/s]

Epoch 7, Step 29000: train loss 2.3058


30002it [23:42, 20.31it/s]

Epoch 7, Step 30000: train loss 2.4266


31005it [24:30, 21.91it/s]

Epoch 7, Step 31000: train loss 2.5454


31858it [25:12, 21.07it/s]


Epoch 7: Average validation loss: 3.3097


3it [00:00, 21.03it/s]

Epoch 8, Step 0: train loss 3.4038


1003it [00:46, 21.94it/s]

Epoch 8, Step 1000: train loss 2.6859


2002it [01:32, 21.83it/s]

Epoch 8, Step 2000: train loss 1.6070


3004it [02:19, 23.47it/s]

Epoch 8, Step 3000: train loss 1.5500


4003it [03:08, 18.83it/s]

Epoch 8, Step 4000: train loss 1.7096


5004it [03:55, 20.13it/s]

Epoch 8, Step 5000: train loss 0.4537


6003it [04:42, 22.97it/s]

Epoch 8, Step 6000: train loss 2.3356


7005it [05:29, 21.77it/s]

Epoch 8, Step 7000: train loss 0.7318


8005it [06:16, 22.80it/s]

Epoch 8, Step 8000: train loss 1.9523


9005it [07:03, 20.50it/s]

Epoch 8, Step 9000: train loss 1.8091


10005it [07:51, 20.70it/s]

Epoch 8, Step 10000: train loss 2.2462


11005it [08:37, 22.24it/s]

Epoch 8, Step 11000: train loss 0.9789


12004it [09:24, 22.39it/s]

Epoch 8, Step 12000: train loss 2.3640


13003it [10:11, 20.23it/s]

Epoch 8, Step 13000: train loss 2.3700


14003it [10:58, 20.75it/s]

Epoch 8, Step 14000: train loss 1.1029


15002it [11:45, 22.41it/s]

Epoch 8, Step 15000: train loss 1.8316


16004it [12:32, 21.36it/s]

Epoch 8, Step 16000: train loss 1.7499


17004it [13:20, 19.85it/s]

Epoch 8, Step 17000: train loss 0.4829


18005it [14:06, 20.71it/s]

Epoch 8, Step 18000: train loss 1.8761


19004it [14:53, 19.35it/s]

Epoch 8, Step 19000: train loss 2.1925


20005it [15:40, 22.21it/s]

Epoch 8, Step 20000: train loss 2.7347


21003it [16:27, 20.49it/s]

Epoch 8, Step 21000: train loss 3.6483


22005it [17:16, 22.12it/s]

Epoch 8, Step 22000: train loss 0.7199


23003it [18:05, 20.55it/s]

Epoch 8, Step 23000: train loss 2.6264


24005it [18:54, 22.14it/s]

Epoch 8, Step 24000: train loss 3.7633


25004it [19:43, 23.04it/s]

Epoch 8, Step 25000: train loss 2.7156


26005it [20:31, 19.99it/s]

Epoch 8, Step 26000: train loss 1.3102


27003it [21:20, 20.97it/s]

Epoch 8, Step 27000: train loss 0.6193


28005it [22:07, 20.79it/s]

Epoch 8, Step 28000: train loss 4.4250


29004it [22:55, 20.45it/s]

Epoch 8, Step 29000: train loss 1.7300


30003it [23:42, 21.95it/s]

Epoch 8, Step 30000: train loss 3.4515


31005it [24:30, 19.74it/s]

Epoch 8, Step 31000: train loss 2.6030


31858it [25:10, 21.09it/s]


Epoch 8: Average validation loss: 3.2955


3it [00:00, 21.39it/s]

Epoch 9, Step 0: train loss 1.6379


1004it [00:48, 19.61it/s]

Epoch 9, Step 1000: train loss 1.5450


2004it [01:36, 21.85it/s]

Epoch 9, Step 2000: train loss 1.5716


3002it [02:23, 21.55it/s]

Epoch 9, Step 3000: train loss 2.0977


4005it [03:11, 20.16it/s]

Epoch 9, Step 4000: train loss 1.3468


5004it [03:59, 21.52it/s]

Epoch 9, Step 5000: train loss 2.4424


6004it [04:46, 22.77it/s]

Epoch 9, Step 6000: train loss 2.3059


7003it [05:34, 22.36it/s]

Epoch 9, Step 7000: train loss 1.4233


8003it [06:23, 21.44it/s]

Epoch 9, Step 8000: train loss 3.0294


9004it [07:11, 21.73it/s]

Epoch 9, Step 9000: train loss 1.5444


10003it [07:59, 21.69it/s]

Epoch 9, Step 10000: train loss 2.8141


11004it [08:48, 20.83it/s]

Epoch 9, Step 11000: train loss 3.6105


12005it [09:38, 19.41it/s]

Epoch 9, Step 12000: train loss 1.7046


13003it [10:27, 21.40it/s]

Epoch 9, Step 13000: train loss 1.2592


14004it [11:16, 21.55it/s]

Epoch 9, Step 14000: train loss 0.1051


15003it [12:03, 23.26it/s]

Epoch 9, Step 15000: train loss 3.1299


16004it [12:50, 22.60it/s]

Epoch 9, Step 16000: train loss 3.6670


17003it [13:38, 22.53it/s]

Epoch 9, Step 17000: train loss 2.4242


18005it [14:26, 21.90it/s]

Epoch 9, Step 18000: train loss 2.1269


19003it [15:14, 21.02it/s]

Epoch 9, Step 19000: train loss 1.7603


20003it [16:03, 20.48it/s]

Epoch 9, Step 20000: train loss 0.6802


21004it [16:51, 23.70it/s]

Epoch 9, Step 21000: train loss 2.8157


22003it [17:39, 21.34it/s]

Epoch 9, Step 22000: train loss 2.8838


23003it [18:26, 21.32it/s]

Epoch 9, Step 23000: train loss 2.6814


24004it [19:14, 21.44it/s]

Epoch 9, Step 24000: train loss 1.9872


25005it [20:01, 22.56it/s]

Epoch 9, Step 25000: train loss 1.3917


26005it [20:49, 19.88it/s]

Epoch 9, Step 26000: train loss 2.0416


27003it [21:37, 20.80it/s]

Epoch 9, Step 27000: train loss 2.0826


28003it [22:25, 20.69it/s]

Epoch 9, Step 28000: train loss 1.3089


29003it [23:13, 20.31it/s]

Epoch 9, Step 29000: train loss 3.7242


30004it [24:01, 22.42it/s]

Epoch 9, Step 30000: train loss 3.2582


31004it [24:49, 20.85it/s]

Epoch 9, Step 31000: train loss 0.1406


31858it [25:30, 20.82it/s]


Epoch 9: Average validation loss: 3.2572


In [53]:
# generate from the model
seed=[[vocab['he']]]
context = torch.tensor(seed, dtype=torch.long, device=device)
output_tokens=m.generate(context, max_new_tokens=200)[0].tolist()
res=''
for i in output_tokens:
  if inv_vocab[i]=='<pad>':
    continue
  res+=inv_vocab[i]
  res+=' '
print(res)

he may ours. your broken in before, sorrow lucentio. 
