In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tokenizer import Tokenizer 
import utils
from config import vocab
from preprocess import data_dict, y
import random

torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size, learning_rate, epochs, block_size, vocab_size = utils.get_train_params()
embedding_dim = 64

class Transformer(nn.Module):
    
    def __init__ (self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding_table = nn.Embedding(block_size, embedding_dim)
        self.sa_head = SelfAttention(embedding_dim)
        self.ln_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input, target):
        B, T = idx.shape
        #idx and target are both (B,T) tensors
        tok_embed = self.token_embedding_table(idx) #B,T,C
        pos_embed = self.position_embedding_table(torch.arange(T, device=device)) #T,C
        x = tok_embed + pos_embed # B,T,C
        x = self.sa_head(x)
        logits = self.ln_head(x) # B, T, vocab size
        loss = torch.nn.functional.l1_loss(input, target, size_average=None, reduce=None, reduction='mean')
        return logits


class SelfAttention(nn.Module):

    def __init__ (self, head_size):
        super().__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)

    def forward(self, input):
        B,T,C = input.shape
        k = self.key(input) #(B,T,16)
        q = self.query(input) #(B,T,16)
        wei = q @ k.transpose(-2, -1) * head_size**-0.5 #(B,T,16)@(B,16,T)-->(B,T,T)
        wei = F.softmax(wei, dim=1)
        v = self.value(input) #(B,T,C)
        out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out



model = Transformer()
m = model.to(device)

print(sum(p.numel() for p in m.parameters()), 'M parameters')

tokenizer = Tokenizer(vocab)
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)


for iters in range(epochs):

    keys = list(data_dict.keys())
    rndm_key = random.choice(keys)
    input = torch.tensor(tokenizer.encode(data_dict[rndm_key]['Text'], block_size), dtype=torch.float32, device=device)
    target = y[rndm_key]
    print(input, input.dtype, input.shape)
    print(target, target.dtype, target.shape)
    break

#block_size, vocab_size




194609 M parameters
tensor([ 820.,   19., 1112.,   19.,   88.,    3.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.], device='cuda:0') torch.float32 torch.Size([41])
tensor([[-33.6867, -31.9190, -32.2245,  ..., -80.0000, -80.0000, -80.0000],
        [-27.3985, -28.2982, -31.5363,  ..., -80.0000, -80.0000, -80.0000],
        [-23.9671, -29.0446, -40.8727,  ..., -80.0000, -80.0000, -80.0000],
        ...,
        [-58.9457, -60.9604, -58.6698,  ..., -80.0000, -80.0000, -80.0000],
        [-66.7848, -68.3981, -67.1611,  ..., -80.0000, -80.0000, -80.0000],
        [-70.9027, -76.8588, -80.0000,  ..., -80.0000, -80.0000, -80.0000]]) torch.float32 torch.Size([128, 302])
