In [None]:
import torch
from classes.Head import Head
from torch.nn import functional as F

# Read the entire training corpus into memory as one string.
with open('../nano-gpt-messages.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: each distinct character, in sorted order.
vocab = sorted(set(text))
vocab_size = len(vocab)

# Lookup tables between characters and their integer ids.
stoi = dict(zip(vocab, range(vocab_size)))  # char -> id, e.g. {'A': 0, ...}
itos = dict(enumerate(vocab))  # id -> char, e.g. {0: 'A', ...}

def encode(input: str) -> list[int]:
    """Translate a string into the list of ids of its characters.

    Each character is looked up in the module-level ``stoi`` table; a
    character that is not in the table raises ``KeyError``.
    """
    return list(map(stoi.__getitem__, input))

def decode(input: list[int]) -> str:
    """Translate a list of ids back into the string they spell.

    Each id is looked up in the module-level ``itos`` table; an id that is
    not in the table raises ``KeyError``.
    """
    chars = (itos[token_id] for token_id in input)
    return ''.join(chars)

# Encode the whole corpus as a 1-D tensor of int64 ids.
data = torch.tensor(encode(text), dtype=torch.long)

# The first 90% of the ids form the training split; the rest is held out
# for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# Report the total number of encoded ids.
print(len(data))

import torch.nn as nn
# Width of every token/position embedding vector; also the input width of
# the model's final linear projection.
n_embd = 8
# Number of rows in the position-embedding table, and the number of trailing
# tokens the model is given as context when generating.
block_size = 4

class BigramLanguageModel(nn.Module):
    """Small character-level language model.

    Token embeddings and position embeddings are summed, passed through a
    single ``Head`` (project-local; presumably a self-attention head), and
    projected to one logit per vocabulary entry.

    Relies on the module-level ``n_embd`` and ``block_size`` settings.
    """

    def __init__(self, vocab_size):
        """Build the embedding tables, the head and the output projection.

        Args:
            vocab_size: number of distinct token ids the model can see/emit.
        """
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the assignments below raise AttributeError.
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd, n_embd, block_size)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer tensor of token ids with shape (B, T). T must not
                exceed ``block_size`` (the position table has only that
                many rows).
            targets: optional integer tensor of shape (B, T) holding the
                expected next token for every position.

        Returns:
            ``(logits, loss)``. Without ``targets``, ``logits`` has shape
            (B, T, vocab_size) and ``loss`` is ``None``. With ``targets``,
            ``logits`` is flattened to (B*T, vocab_size) and ``loss`` is
            the scalar cross-entropy.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        # Position embeddings are looked up by position 0..T-1, not by
        # token id; built on idx's device so the lookup works on GPU too.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, n_embd)
        x = tok_emb + pos_emb  # broadcasts over the batch dimension
        # Invoke sub-modules through __call__ so registered hooks run.
        x = self.sa_head(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits against (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens):
        """Sample ``max_new_tokens`` additional tokens after ``idx``.

        Args:
            idx: integer tensor of shape (B, T) with the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the input
            followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Keep only the last block_size tokens: the position table
            # cannot index beyond that.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Only the final time step predicts the next token.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # torch.cat returns a new tensor, so rebind idx to keep it.
            idx = torch.cat((idx, next_token), dim=1)

        return idx




314937
