In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
torch.cuda.is_available()

False

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    data = f.read()

In [5]:
uc = set(data)

In [6]:
n = 0
stoi = {}
itos = {}
for c in uc:
    stoi[c] = n
    itos[n] = c
    n+=1

In [7]:
data_lst = []
for d in data:
    data_lst.append(stoi[d])

In [8]:
len(data_lst)

1115394

In [9]:
data_lst = torch.tensor(data_lst)

In [10]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, block_size, emb_size) -> None:
        super(EmbeddingLayer, self).__init__()

        self.vocab_emd = nn.Embedding(vocab_size, emb_size)
        self.pos_emd = nn.Embedding(block_size, emb_size)

    def forward(self, x):
        # if I want the input be variable ie as small as 1 to block_size
        seq_len = x.size(1)
        vocab_out = self.vocab_emd(x)
        pos_out = self.pos_emd(torch.arange(seq_len, device=x.device))
        emb_out = vocab_out + pos_out

        return emb_out

In [11]:
class AttentionHead(nn.Module):
    def __init__(self, emb_size, h_size) -> None:
        super(AttentionHead, self).__init__()
        
        self.emb_size = emb_size
        self.h_size = h_size
        
        
        self.key = nn.Linear(self.emb_size, self.h_size)
        self.query = nn.Linear(self.emb_size, self.h_size)
        self.value = nn.Linear(self.emb_size, self.h_size)
    
    def forward(self, emb_out):

        k_out = self.key(emb_out)
        q_out = self.query(emb_out)
        v_out = self.value(emb_out)

        kq_out = q_out @ k_out.transpose(-2, -1)/k_out.shape[-1]
        

        kq_out = torch.tril(kq_out)
        kq_out = torch.masked_fill(kq_out, kq_out==0, float("-inf"))
        
        kq_out = F.softmax(kq_out, dim = -1)

        out = kq_out @ v_out


        return out

In [12]:
class MHAttentionBlock(nn.Module):
    def __init__(self, n_head, emb_size) -> None:
        super(MHAttentionBlock, self).__init__()
        
        self.emb_size = emb_size
        self.h_size = int(self.emb_size/n_head)

        self.attention_heads = [AttentionHead(self.emb_size, self.h_size) for _ in range(0, n_head) ]

        self.out_layer = nn.Linear(self.emb_size, self.emb_size)
    
    def forward(self, emb_out ):

        attention_out = [h(emb_out) for h in self.attention_heads]
        final_attention = torch.concat(attention_out, dim=-1)
        out = self.out_layer(final_attention)

        return out


        

In [13]:
class FeedForwardLayer(nn.Module):
    def __init__(self, emb_size) -> None:
        super(FeedForwardLayer, self).__init__()

        self.layer1 = nn.Linear(emb_size, 4 * emb_size)
        self.layer2 = nn.Linear(4 * emb_size, emb_size)

    def forward(self, attention_out ):
        tmp_out = self.layer1(attention_out)
        out = self.layer2(tmp_out)

        return out

In [14]:
class TModel(nn.Module):
    def __init__(self, vocab_size, block_size, emb_size, n_head):
        super(TModel, self).__init__()
        
        self.embedding = EmbeddingLayer(vocab_size, block_size, emb_size)

        self.multiheadattention = MHAttentionBlock(n_head, emb_size)
        
        self.ffl = FeedForwardLayer(emb_size)

        self.dlm = nn.Linear(emb_size, vocab_size)

    def forward(self, x, y=None):
        
        emb_out = self.embedding(x)
        
        attention_out = self.multiheadattention(emb_out)

        out = self.ffl(attention_out)
        
        logits = self.dlm(out)


        
        if y is None:
            loss = None
        else:
            B, T, C = logits.shape
            print(logits.shape, y.shape)
            logits = logits.view(B*T, -1)
            y = y.view(B*T)
            loss = nn.functional.cross_entropy(logits, y)


        return logits, loss
    
    def generate(self, idx, max_new_tokens, block_size):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # get the last token prediction
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [15]:
model = TModel(65, 8, 100, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [16]:
for i in range(100):
    ix = torch.randint(len(data_lst) - 8, (4,))
    x = torch.stack([data_lst[i:i+8] for i in ix])
    y = torch.stack([data_lst[i+1:i+8+1] for i in ix])
    logits, loss = model.forward(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    print(loss)


torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(4.1366, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(4.1386, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(4.0341, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.9417, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.8944, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.8823, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.8138, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.8178, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.6599, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.2991, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.5648, grad_fn=<NllLossBackward0>)
torch.Size([4, 8, 65]) torch.Size([4, 8])
tensor(3.7495, grad_fn=

In [17]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

context = torch.tensor([encode("a")], dtype=torch.long)
print(decode(model.generate(context, 5, 8)[0].tolist()))

a:A
e 
