In [None]:
# Load the raw training corpus and lowercase it. The encoding is given
# explicitly so decoding does not depend on the platform's locale default.
with open("data/input.txt", "r", encoding="utf-8") as f:
    data = f.read().lower()

In [None]:
from tokenizers import Tokenizer, models, decoders, trainers, tools, pre_tokenizers

# Byte-level BPE: the pre-tokenizer maps text to byte-level symbols before
# merges are applied, and the matching decoder reverses that mapping.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
tokenizer.decoder = decoders.ByteLevel()
# Seed the vocabulary with every byte-level symbol so bytes that never occur in
# the training corpus still get an id at encode time (models.BPE() above has no
# unk token to fall back on).
trainer = trainers.BpeTrainer(
    vocab_size=8192,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

In [None]:
# Train on the same lowercased text (`data`) that is encoded later, not on the
# raw file, so the learned merges match the text the model actually sees.
tokenizer.train_from_iterator([data], trainer=trainer)
print(f"vocab size: {tokenizer.get_vocab_size()}")
# Render how the first 512 characters of the corpus are split into tokens.
viz = tools.EncodingVisualizer(tokenizer)
viz(data[:512])

In [None]:
# train val split
import torch
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """Next-token-prediction samples cut from one tokenized text.

    The whole text is encoded once in ``__init__``. Sample ``i`` is the window of
    ``block_size + 1`` token ids starting at ``i * block_size``, returned as
    ``(inputs, targets)`` where ``targets`` is ``inputs`` shifted left by one.

    Args:
        data: the raw text to tokenize.
        tokenizer: object whose ``encode(text)`` returns something with an
            ``ids`` sequence of ints (e.g. a ``tokenizers.Tokenizer``).
        block_size: length of every input/target sequence. All samples have the
            same length, so the default DataLoader collate can stack them.

    Raises:
        ValueError: if ``block_size`` is smaller than 1.
    """

    def __init__(self, data, tokenizer, block_size=128):
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.data = data
        self.tokenizer = tokenizer
        self.block_size = block_size
        # Encode once up front; indexing a str would yield single characters.
        self.ids = torch.tensor(tokenizer.encode(data).ids, dtype=torch.long)

    def __len__(self):
        # Consecutive windows share one token (the last target of sample i is the
        # first input of sample i + 1); a trailing partial window is dropped.
        return max(len(self.ids) - 1, 0) // self.block_size

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(input_ids, target_ids)``, each of length ``block_size``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        start = pos * self.block_size
        chunk = self.ids[start : start + self.block_size + 1]
        return chunk[:-1], chunk[1:]
        
    
# Contiguous (unshuffled) split: the first 80% of the characters are used for
# training and the remainder for validation.
train_val = 0.8
BATCH_SIZE = 8
split_idx = int(len(data) * train_val)
train_data = data[:split_idx]
val_data = data[split_idx:]

train_dataset = TextDataset(train_data, tokenizer)
val_dataset = TextDataset(val_data, tokenizer)
assert len(train_dataset) > 0 and len(val_dataset) > 0, "split produced an empty dataset"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Shuffling serves no purpose during evaluation; a fixed order keeps
# validation runs comparable across epochs.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from torch import nn
from tqdm import tqdm
from nn_zoo.models.components.attention import SelfAttention


class Model(nn.Module):
    """Language model: token + learned positional embeddings, a stack of
    ``SelfAttention`` blocks, and a linear head producing vocabulary logits.

    Args:
        vocab_size: number of token ids covered by the embedding and output head.
        embed_size: width of the embeddings passed between blocks.
        hidden_size: forwarded as the second positional argument to each
            ``SelfAttention`` (its exact meaning is defined by nn_zoo).
        num_layers: number of ``SelfAttention`` blocks applied in sequence.
        max_seq_len: longest sequence the positional table supports.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, max_seq_len=1024):
        super().__init__()
        # Indexed by token id (the original attribute of this shape was misnamed
        # "positional_embedding").
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned table indexed by position so the blocks can see token order.
        # NOTE(review): assumes nn_zoo's SelfAttention adds no positional encoding
        # of its own and applies a causal mask for next-token prediction — confirm.
        self.positional_embedding = nn.Embedding(max_seq_len, embed_size)
        self.max_seq_len = max_seq_len
        self.blocks = nn.Sequential(
            *[SelfAttention(embed_size, hidden_size) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` (assumed shape (batch, seq)) to logits whose last
        dimension is ``vocab_size``.

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        # (seq, embed) position vectors broadcast over the batch dimension.
        h = self.token_embedding(x) + self.positional_embedding(positions)
        h = self.blocks(h)
        return self.fc(h)
    
# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without Apple silicon.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed before model creation so weight init and loader shuffling are reproducible.
torch.manual_seed(0)

model = Model(tokenizer.get_vocab_size(), 512, 512, 6)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    progress = tqdm(train_loader)
    for x, y in progress:
        optimizer.zero_grad()
        x, y = x.long().to(device), y.long().to(device)
        y_hat = model(x)
        # Flatten logits to (N, vocab) and targets to (N,) for CrossEntropyLoss;
        # reshape (unlike view) also accepts non-contiguous tensors.
        loss = criterion(y_hat.reshape(-1, y_hat.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        # Surface the running training loss without adding extra log lines.
        progress.set_postfix(train_loss=f"{loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        total_loss = 0
        for x, y in val_loader:
            # Same dtype/device handling as the training loop.
            x, y = x.long().to(device), y.long().to(device)
            y_hat = model(x)
            loss = criterion(y_hat.reshape(-1, y_hat.size(-1)), y.reshape(-1))
            total_loss += loss.item()
        # max(..., 1) avoids ZeroDivisionError if the validation loader is empty.
        print(f"Epoch {epoch} Loss: {total_loss/max(len(val_loader), 1)}")
        