In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
import pickle

In [2]:
# --- Notebook-wide runtime configuration ------------------------------------
SEED = 42

# Set FORCE_CPU to False to use Apple's MPS backend when it is available.
# It is True by default, which keeps the CPU-only behaviour this cell had when
# the MPS probe was unconditionally overwritten by a second assignment.
FORCE_CPU = True
use_mps = (not FORCE_CPU) and torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")

# Seed the global RNG: calls that do not receive an explicit generator
# (e.g. torch.randint without `generator=`) draw from it, so without this
# line a fresh kernel produces different numbers on every run.
torch.manual_seed(SEED)

# Dedicated, seeded generator for code that accepts a `generator=` argument.
g = torch.Generator().manual_seed(SEED)

In [3]:
class DumbDataset(Dataset):
    """
    Synthetic sorting task: a random token sequence paired with its sorted copy.

    Samples are generated on the fly. The index passed to ``__getitem__`` is
    only bounds-checked; it does not select a particular sample, so two
    lookups of the same index generally return different data.

    Args:
        length: number of tokens in every sequence.
        num_digits: exclusive upper bound for token values, i.e. tokens are
            drawn from ``[0, num_digits)``.
        size: value reported by ``len()`` (samples per epoch).
        generator: optional ``torch.Generator`` used for sampling. ``None``
            (the default) draws from torch's global RNG.
    """

    def __init__(self, length=6, num_digits=3, size=10000, generator=None):
        self.length = length
        self.num_digits = num_digits
        self.size = size
        self.generator = generator

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(inp, sol)``: a random 1-D long tensor and the same values sorted ascending.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this check plain `for sample in dataset` / `list(dataset)`
        # would never stop, because Python's sequence iteration protocol
        # relies on IndexError to detect the end.
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.size}")

        inp = torch.randint(
            self.num_digits,
            size=(self.length,),
            generator=self.generator,
            dtype=torch.long,
        )
        # solve the task: i.e. sort (torch.sort returns (values, indices))
        sol = torch.sort(inp)[0]

        return inp, sol

In [4]:
# Show one sample from the dataset: each printed row is
# "<input token> <target token>" for one sequence position.
train_ds = DumbDataset()
x, y = train_ds[0]
for unsorted_tok, sorted_tok in zip(x.tolist(), y.tolist()):
    print(unsorted_tok, sorted_tok)

2 0
0 2
2 2
2 2
2 2
2 2


In [11]:

class DecoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a position-wise MLP.

    Both sub-layers are wrapped in additive residual connections::

        x = x + Attention(LayerNorm1(x))
        x = x + Dropout(MLP(LayerNorm2(x)))

    No attention mask is applied, so every position attends to every other
    position of the sequence.

    Args:
        model_size: embedding width of the inputs and outputs.
        nb_heads: number of attention heads (must divide ``model_size``).
        dropout: dropout probability used inside the attention module and on
            the MLP output.
        bias: whether the attention projections use bias terms.
        mlp_factor: hidden width of the MLP, as a multiple of ``model_size``.
        device: device on which parameters are created. When ``None``, a
            module-level variable named ``device`` is used if one exists
            (this is what the layer always did before the parameter was
            added); otherwise PyTorch's default device is used.
    """

    def __init__(self, model_size:int, nb_heads: int=1, dropout: float=0., bias:bool=True, mlp_factor=4,
                 device=None) -> None:
        super().__init__()
        if device is None:
            # Backward compatibility: the layer used to read the notebook's
            # global `device` unconditionally. Keep honouring it when present,
            # but do not fail when it is absent.
            device = globals().get("device")

        self.layer_norm1 = nn.LayerNorm(model_size, device=device)
        self.attn = nn.MultiheadAttention(embed_dim=model_size, num_heads=nb_heads,
                                          dropout=dropout, bias=bias, batch_first=True, device=device)

        self.mlp1 = nn.Linear(model_size, mlp_factor * model_size, bias=True, device=device)
        self.mlp2 = nn.Linear(mlp_factor * model_size, model_size, bias=True, device=device)
        self.activation = torch.nn.GELU()
        self.layer_norm2 = nn.LayerNorm(model_size, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` (batch-first, last dim = ``model_size``); the output has the same shape."""
        # --- attention sub-layer -------------------------------------------
        norm_x = self.layer_norm1(x)
        # The attention weights are not used, so skip computing/returning them.
        attn_out, _ = self.attn(norm_x, norm_x, norm_x, need_weights=False)
        # Residual connections are additive. Multiplying here (as the code
        # previously did) rescales the input by the attention output instead
        # of adding to it, and breaks the identity path through the block.
        x = x + attn_out

        # --- feed-forward sub-layer ----------------------------------------
        # The MLP reads from, and is added back onto, the post-attention
        # stream so the attention result is carried forward in the output.
        hidden = self.activation(self.mlp1(self.layer_norm2(x)))
        x = x + self.dropout(self.mlp2(hidden))
        return x


class MyGPT(nn.Module):
    """Small transformer that maps a sequence of token ids to per-position logits.

    The model sums learned token and position embeddings, applies
    ``nb_layers`` ``DecoderLayer`` blocks and projects every position onto
    the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        seq_len: maximum sequence length (size of the position table).
        model_size: embedding width.
        nb_layers: number of ``DecoderLayer`` blocks.
        nb_heads: attention heads per block.
        dropout: dropout probability forwarded to every block.
        device: device on which all parameters and buffers are placed.
    """

    def __init__(self, vocab_size: int, seq_len: int, model_size: int = 20,
                 nb_layers:int=1, nb_heads:int =1, dropout: float=0., device=torch.device('cpu')) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_size = model_size

        self.device = device

        self.tok_emb = nn.Embedding(vocab_size, embedding_dim=model_size, device=device)
        # Position ids 0..seq_len-1, shape (1, seq_len). Created directly on
        # `device` so it lives next to the embedding table that consumes it.
        pos = torch.arange(0, seq_len, dtype=torch.int, device=device).unsqueeze(0)
        self.register_buffer('pos', pos)
        self.pos_emb = nn.Embedding(seq_len, model_size, device=device)
        self.layers = nn.ModuleList([
            DecoderLayer(model_size=model_size, nb_heads=nb_heads, dropout=dropout, bias=True)
            for _ in range(nb_layers)
        ])

        self.lin = nn.Linear(model_size, vocab_size, device=device)

        # The blocks above are built without an explicit device, so move the
        # assembled module once to guarantee every parameter and buffer ends
        # up on the requested device.
        self.to(device)

    def forward(self, x):
        """Compute logits for a batch of token ids.

        Args:
            x: integer tensor of shape ``(batch, length)`` with
                ``length <= seq_len``.

        Returns:
            Tensor of shape ``(batch, length, vocab_size)``.

        Raises:
            ValueError: if ``length`` exceeds ``seq_len``.
        """
        _, l = x.size()
        if l > self.seq_len:
            raise ValueError(
                f"input length {l} exceeds the model's maximum sequence length {self.seq_len}"
            )

        # Only embed as many positions as the input actually has; using the
        # full table fails to broadcast for inputs shorter than seq_len.
        x = self.tok_emb(x) + self.pos_emb(self.pos[:, :l])
        for layer in self.layers:
            x = layer(x)

        return self.lin(x)
        


In [12]:
class Trainer:
    """Minimal training loop: cross-entropy over per-position logits.

    Args:
        model: module mapping a batch ``x`` to logits whose last dimension is
            the number of classes.
        optimizer: optimizer already constructed over ``model``'s parameters.
        device: device the model and every batch are moved to.
        max_iter: total number of optimizer steps this trainer may perform
            (counted in ``iter_num`` across calls to ``run``). ``None`` means
            run until interrupted.
        batch_size: number of samples per batch drawn from the dataset.

    Attributes:
        iter_num: number of optimizer steps performed so far.
        last_loss: loss value (float) of the most recent step, or ``None``
            before the first step.
    """

    def __init__(self, model, optimizer, device=torch.device('cpu'), max_iter=10, batch_size=1):
        # Batches are moved to `device` in run(), so the model must live there
        # too. Module.to() updates parameters in place, so the optimizer's
        # references to them stay valid.
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.iter_num = 0
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.last_loss = None

    def run(self, train_ds):
        """Train on ``train_ds`` until ``iter_num`` reaches ``max_iter``.

        The dataset is cycled as often as needed. Each item must be an
        ``(x, y)`` pair of tensors. The running step count is printed after
        every optimizer step.
        """
        model = self.model

        # setup the dataloader
        train_loader = DataLoader(train_ds, batch_size=self.batch_size)
        model.train()
        data_iter = iter(train_loader)

        # The budget is checked *before* each step, so max_iter=0 (or a
        # trainer whose budget is already used up) performs no update.
        while self.max_iter is None or self.iter_num < self.max_iter:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            x, y = (t.to(self.device) for t in batch)

            # forward the model; flatten (batch, position) so every position
            # counts as one classification example
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()

            self.last_loss = loss.item()
            self.iter_num += 1
            print(self.iter_num)
            
# Hyper-parameters of the toy sorting experiment.
vocab_size, seq_len, model_size = 3, 11, 20

# Network and optimizer.
model = MyGPT(
    vocab_size=vocab_size,
    seq_len=seq_len,
    model_size=model_size,
    nb_layers=1,
    nb_heads=1,
    device=device,
)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

# Data and a short training run (5 optimizer steps).
train_ds = DumbDataset(length=seq_len, num_digits=vocab_size)
trainer = Trainer(model=model, optimizer=optimizer, device=device, max_iter=5)
trainer.run(train_ds)



1
2
3
4
5
