In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

print(f"Torch version {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA Available: {torch.cuda.is_available()}")

Torch version 2.7.1+cu118
CUDA Available: True


In [41]:
from config import ROOT_DIR
import os
import math
import importlib
from datetime import datetime
from torch.utils.data import DataLoader
import datasets.alice_in_wonderland
importlib.reload(datasets.alice_in_wonderland)

# hyperparameters
embedding_size = 50
hidden_size = 64
lr = 0.001
seq_length = 5
vocab_size = 100

dataset_train = None
dataset_test = None

dataset_train = datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=True)
tokenizer = dataset_train.tokenizer
dataset_test = datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=False, tokenizer=tokenizer)

log_vocab = math.log(tokenizer.get_vocab_size())
print(f"Log of vocab {log_vocab}")

class SimpleTextRnn(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(SimpleTextRnn, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        out, _ = self.rnn(x)
        last_step = out[:, -1, :]
        logits = self.fc(last_step)
        return logits

train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=True)

model = SimpleTextRnn(vocab_size=tokenizer.get_vocab_size(), embed_size=embedding_size, hidden_size=hidden_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

best_val_loss = float('inf')
best_epoch = 0
train_epoch_count = 20

if not os.path.exists(os.path.join(ROOT_DIR, '.models')):
    os.mkdir(os.path.join(ROOT_DIR, '.models'))

trained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_trained.pth')
untrained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_untrained.pth')

torch.save(model.state_dict(), untrained_path)

for epoch in range(train_epoch_count+1):
    model.train()
    for batch_X, batch_Y in train_dataloader:
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_Y)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_losses = []
        for val_X, val_Y in test_dataloader:
            val_X, val_Y = val_X.to(device), val_Y.to(device)
            val_output = model(val_X)
            val_loss = criterion(val_output, val_Y)
            val_losses.append(val_loss.item())
        avg_val_loss = sum(val_losses) / len(val_losses)
    print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), trained_path)
        best_epoch = epoch
        print(f"New best model at epoch {epoch} | val_loss: {avg_val_loss:.4f}")

print(f"Using model from epoch {best_epoch} | val_loss: {best_val_loss:.4f} | entropy vocab: {log_vocab:.4f}")


current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
with open(os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_training_log.txt'), 'a') as f:
    f.write(f"{formatted_datetime}: Train result\n")
    f.write(f"Vocab size: {tokenizer.get_vocab_size()}\n")
    f.write(f"Entropy of vocab: {log_vocab:.4f}\n")
    f.write(f"Embedding size: {embedding_size}\n")
    f.write(f"Hidden size: {hidden_size}\n")
    f.write(f"Seq length: {dataset_train.seq_length}\n")
    f.write(f"Train epochs: {train_epoch_count}\n")
    f.write(f"Learning rate: {lr}\n")
    f.write(f"Best epoch: {best_epoch}\n")
    f.write(f"Best val loss: {best_val_loss:.4f}\n")
    f.write("***\n")

Log of vocab 4.605170185988092
Epoch 0 | Train Loss: 3.1294 | Val Loss: 3.0685
New best model at epoch 0 | val_loss: 3.0685
Epoch 1 | Train Loss: 2.5324 | Val Loss: 2.7666
New best model at epoch 1 | val_loss: 2.7666
Epoch 2 | Train Loss: 2.1445 | Val Loss: 2.6188
New best model at epoch 2 | val_loss: 2.6188
Epoch 3 | Train Loss: 2.2889 | Val Loss: 2.5225
New best model at epoch 3 | val_loss: 2.5225
Epoch 4 | Train Loss: 2.2487 | Val Loss: 2.4561
New best model at epoch 4 | val_loss: 2.4561
Epoch 5 | Train Loss: 2.5141 | Val Loss: 2.4052
New best model at epoch 5 | val_loss: 2.4052
Epoch 6 | Train Loss: 2.6724 | Val Loss: 2.3671
New best model at epoch 6 | val_loss: 2.3671
Epoch 7 | Train Loss: 2.3075 | Val Loss: 2.3347
New best model at epoch 7 | val_loss: 2.3347
Epoch 8 | Train Loss: 1.8914 | Val Loss: 2.3116
New best model at epoch 8 | val_loss: 2.3116
Epoch 9 | Train Loss: 2.3460 | Val Loss: 2.2918
New best model at epoch 9 | val_loss: 2.2918
Epoch 10 | Train Loss: 1.8476 | Val Los

In [43]:
from tokenizers.decoders import Metaspace as MetaspaceDecoder
import torch.nn.functional as F

tokenizer.decoder = MetaspaceDecoder(replacement=" ", prepend_scheme="never")

seq_length = dataset_train.seq_length

def clean_text(text):
    return text.lower()

def generate_text(input, num_tokens, path):
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    leftover_prefix = []
    if len(input_ids) > seq_length:
        leftover_prefix = input_ids[:-seq_length]
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    for _ in range(num_tokens):
        input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = F.softmax(output / 0.05, dim=1)
            next_id = torch.multinomial(probabilities, num_samples=1).item()

        generated.append(next_id)
    return tokenizer.decode(leftover_prefix + generated).replace("▁", " ")

def print_next_token_probabilities(input, path):
    """
    Print the top 10 next token probabilities
    :param input:
    :param path:
    """
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    if len(input_ids) > seq_length:
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        top_probs, top_indices = torch.topk(probabilities[0], 10)
    
    for i in range(10):
      token_id = top_indices[i].item()
      prob = top_probs[i].item()
      token_text = tokenizer.decode([token_id])  # Assuming you have a tokenizer
      print(f"Token: '{token_text}' (ID: {token_id}) - Probability: {prob:.4f}")
    print("\n")


# Output and various stats
prompt = clean_text("Oh, you can't help that; we're all ")

encoded = tokenizer.encode(prompt, add_special_tokens=False)
token_ids = encoded.ids
tokens = [tokenizer.decode([t]) for t in token_ids]
print(f"Prompt: {prompt}\n")
print(f"Prompt tokens: {tokens}\n")

print_next_token_probabilities(prompt, trained_path)

result_untrained = generate_text(prompt, 100, untrained_path)
result_trained = generate_text(prompt, 100, trained_path)
print(f"Untrained sample: {result_untrained}\n")
print(f"Trained sample: {result_trained}\n")

Prompt: oh, you can't help that; we're all 

Prompt tokens: ['▁o', 'h', ',', '▁', 'you', '▁c', 'a', 'n', "'", 't', '▁he', 'l', 'p', '▁t', 'hat', ';', '▁w', 'e', "'", 're', '▁a', 'll', '▁']

Token: 'e' (ID: 29) - Probability: 0.2377
Token: 'u' (ID: 45) - Probability: 0.1019
Token: 'r' (ID: 42) - Probability: 0.0908
Token: 'ha' (ID: 67) - Probability: 0.0884
Token: 'h' (ID: 32) - Probability: 0.0867
Token: 'v' (ID: 46) - Probability: 0.0756
Token: 'you' (ID: 99) - Probability: 0.0614
Token: 'i' (ID: 33) - Probability: 0.0603
Token: 're' (ID: 68) - Probability: 0.0338
Token: 'j' (ID: 34) - Probability: 0.0233


Untrained sample:  oh, you can't help that; we're all  saar2 g&&  6 "2 n and of88233n s: sa sa(edt mr itha sided5 s8 of8 l8 of sa"2cstc sah;ed5t05letrtled and1s ofas it cas5?aseds8ghe saj t l?ed5 sand he82 sas sandhaowcowc

Trained sample:  oh, you can't help that; we're all everything is explanage your everything replied. "i've you know, and the king.  "there's alice, "and your kn