In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

print(f"Torch version {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA Available: {torch.cuda.is_available()}")

Torch version 2.7.1+cu118
True


In [104]:
from config import ROOT_DIR
import os
import math
import importlib
from torch.utils.data import DataLoader
import datasets.alice_in_wonderland
importlib.reload(datasets.alice_in_wonderland)

dataset_train = None
dataset_test = None

dataset_train = datasets.alice_in_wonderland.AliceInWonderlandDataset(train=True)
tokenizer = dataset_train.tokenizer
dataset_test = datasets.alice_in_wonderland.AliceInWonderlandDataset(train=False, tokenizer=tokenizer)

log_vocab = math.log(tokenizer.get_vocab_size())
print(f"Log of vocab {log_vocab}")

class SimpleTextRnn(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(SimpleTextRnn, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        out, _ = self.rnn(x)
        last_step = out[:, -1, :]
        logits = self.fc(last_step)
        return logits

train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=True)

model = SimpleTextRnn(vocab_size=tokenizer.get_vocab_size(), embed_size=50, hidden_size=64).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

best_val_loss = float('inf')
best_epoch = 0
train_epoch_count = 20


if not os.path.exists(os.path.join(ROOT_DIR, '.models')):
    os.mkdir(os.path.join(ROOT_DIR, '.models'))

torch.save(model.state_dict(), os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_untrained.pth'))

for epoch in range(train_epoch_count+1):
    model.train()
    for batch_X, batch_Y in train_dataloader:
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_Y)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_losses = []
        for val_X, val_Y in test_dataloader:
            val_X, val_Y = val_X.to(device), val_Y.to(device)
            val_output = model(val_X)
            val_loss = criterion(val_output, val_Y)
            val_losses.append(val_loss.item())
        avg_val_loss = sum(val_losses) / len(val_losses)
    print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_trained.pth'))
        best_epoch = epoch
        print(f"New best model at epoch {epoch} | val_loss: {avg_val_loss:.4f}")

print(f"Using model from epoch {best_epoch} | val_loss: {best_val_loss:.4f} | entropy vocab: {log_vocab:.4f}")

Log of vocab 4.605170185988092
Epoch 0 | Train Loss: 3.9820 | Val Loss: 3.5189
New best model at epoch 0 | val_loss: 3.5189
Epoch 1 | Train Loss: 3.2766 | Val Loss: 3.1393
New best model at epoch 1 | val_loss: 3.1393
Epoch 2 | Train Loss: 2.7256 | Val Loss: 2.9497
New best model at epoch 2 | val_loss: 2.9497
Epoch 3 | Train Loss: 2.8989 | Val Loss: 2.8320
New best model at epoch 3 | val_loss: 2.8320
Epoch 4 | Train Loss: 2.8700 | Val Loss: 2.7464
New best model at epoch 4 | val_loss: 2.7464
Epoch 5 | Train Loss: 2.8895 | Val Loss: 2.6862
New best model at epoch 5 | val_loss: 2.6862
Epoch 6 | Train Loss: 1.3958 | Val Loss: 2.6364
New best model at epoch 6 | val_loss: 2.6364
Epoch 7 | Train Loss: 2.7366 | Val Loss: 2.5970
New best model at epoch 7 | val_loss: 2.5970
Epoch 8 | Train Loss: 2.2786 | Val Loss: 2.5653
New best model at epoch 8 | val_loss: 2.5653
Epoch 9 | Train Loss: 1.7157 | Val Loss: 2.5474
New best model at epoch 9 | val_loss: 2.5474
Epoch 10 | Train Loss: 2.3107 | Val Los

In [111]:
from tokenizers.decoders import Metaspace as MetaspaceDecoder
import torch.nn.functional as F

tokenizer.decoder = MetaspaceDecoder(replacement=" ", prepend_scheme="never")

seq_length = dataset_train.seq_length
def generate_text(input, num_tokens, path):
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input = input.lower()

    input_ids = tokenizer.encode(input).ids
    leftover_prefix = []
    if len(input_ids) > seq_length:
        leftover_prefix = input_ids[:-seq_length]
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    for _ in range(num_tokens):
        input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = F.softmax(output / 0.3, dim=1)
            next_id = torch.multinomial(probabilities, num_samples=1).item()

        generated.append(next_id)
    return tokenizer.decode(leftover_prefix + generated).replace("▁", " ")

result_untrained = generate_text("let's go to the", 100, os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_untrained.pth'))
result_trained = generate_text("let's go to the", 100, os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_trained.pth'))
print(f"Untrained sample: {result_untrained}\n")
print(f"Trained sample: {result_trained}")

Untrained sample:  let's go to theer b;?on:llhaowliceve qu. wouf ha p b, smrere t tmes the[hm it fhant qu qu alicest y' oha. o w] querw)es m youceinqou andin nm n mwlind quit: w.]'s skenn hle nhatd bx itheb you queshi't h,haler

Trained sample:  let's go to the room!   alice  i don't tell is the doors at least as if you don't tell me the stupose wid you playing at the queen off you don't tell the ready everybody excame.   alice 


In [114]:
# Show tokenization for debugging
prompt = "when the"
encoded = tokenizer.encode(prompt, add_special_tokens=False)
token_ids = encoded.ids
tokens = [tokenizer.decode([t]) for t in token_ids]
print(tokens)


['▁w', 'he', 'n', '▁the']
