In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

print(f"Torch version {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA Available: {torch.cuda.is_available()}")

import importlib
import models.nlp
importlib.reload(models.nlp)
from models.nlp import SimpleTextRnn, SimpleTransformerDecoderOnly

from config import ROOT_DIR
import os
import math
from datetime import datetime
from torch.utils.data import DataLoader
import custom_datasets.alice_in_wonderland
importlib.reload(custom_datasets.alice_in_wonderland)

# options
#mode = 'RNN'
mode = 'Transformer'

# hyperparameters
embedding_size = 256
lr = 0.0025
seq_length = 10
vocab_size = 100

# transformer only
ff_dimension = 1024
num_heads = 4
num_layers = 2

# rnn only
hidden_size = 32

dataset_train = None
dataset_test = None

dataset_train = custom_datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=True)
tokenizer = dataset_train.tokenizer
dataset_test = custom_datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=False, tokenizer=tokenizer)

log_vocab = math.log(tokenizer.get_vocab_size())
print(f"Log of vocab {log_vocab}")

train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=True)

if mode == 'RNN':
    model = SimpleTextRnn(vocab_size=tokenizer.get_vocab_size(), embed_size=embedding_size, hidden_size=hidden_size).to(device)
elif mode == 'Transformer':
    model = SimpleTransformerDecoderOnly(vocab_size=tokenizer.get_vocab_size(), embedding_dim=embedding_size, num_heads=num_heads, num_layers=num_layers, ff_dimension=ff_dimension, max_length=seq_length, dropout=0.1).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)


best_val_loss = float('inf')
best_epoch = 0
train_epoch_count = 100

if not os.path.exists(os.path.join(ROOT_DIR, '.models')):
    os.mkdir(os.path.join(ROOT_DIR, '.models'))

trained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_trained.pth')
untrained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_untrained.pth')

torch.save(model.state_dict(), untrained_path)
no_improvement_count = 0

try:
    for epoch in range(train_epoch_count+1):
        model.train()
        for batch_X, batch_Y in train_dataloader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_X)

            loss = criterion(outputs, batch_Y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_losses = []
            for val_X, val_Y in test_dataloader:
                val_X, val_Y = val_X.to(device), val_Y.to(device)
                val_output = model(val_X)
                val_loss = criterion(val_output, val_Y)
                val_losses.append(val_loss.item())
            avg_val_loss = sum(val_losses) / len(val_losses)
        print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            no_improvement_count = 0
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), trained_path)
            best_epoch = epoch
            print(f"New best model at epoch {epoch} | val_loss: {avg_val_loss:.4f}")
        else:
            no_improvement_count += 1
            if no_improvement_count >= 5:
                print(f"Stopping training after {no_improvement_count} epochs without improvement")
                break

        scheduler.step(avg_val_loss)
except KeyboardInterrupt:
    print("Stopping training due to keyboard interrupt")

print(f"Using model from epoch {best_epoch} | val_loss: {best_val_loss:.4f} | entropy vocab: {log_vocab:.4f}")


current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
formatted_date = current_datetime.strftime("%Y_%m_%d")
with open(os.path.join(ROOT_DIR, '.models', f"{formatted_date}_alice_in_wonderland_training_log.txt"), 'a') as f:
    f.write(f"{formatted_datetime}: Train result\n")
    f.write(f"Mode: {mode}\n")
    f.write(f"Vocab size: {tokenizer.get_vocab_size()}\n")
    f.write(f"Entropy of vocab: {log_vocab:.4f}\n")
    f.write(f"Embedding size: {embedding_size}\n")

    if mode == 'RNN':
        f.write(f"Hidden size: {hidden_size}\n")
    if mode == 'Transformer':
        f.write(f"Num heads: {num_heads}\n")
        f.write(f"Num layers: {num_layers}\n")
        f.write(f"FF dimension: {ff_dimension}\n")

    f.write(f"Seq length: {dataset_train.seq_length}\n")
    f.write(f"Train epochs: {train_epoch_count}\n")
    f.write(f"Learning rate: {lr}\n")
    f.write(f"Best epoch: {best_epoch}\n")
    f.write(f"Best val loss: {best_val_loss:.4f}\n")
    f.write("***\n")

Torch version 2.7.1+cu118
CUDA Available: True
Log of vocab 4.605170185988092
Epoch 0 | Train Loss: 2.6673 | Val Loss: 3.1016
New best model at epoch 0 | val_loss: 3.1016
Epoch 1 | Train Loss: 3.1412 | Val Loss: 2.9007
New best model at epoch 1 | val_loss: 2.9007
Epoch 2 | Train Loss: 2.4748 | Val Loss: 2.7587
New best model at epoch 2 | val_loss: 2.7587
Epoch 3 | Train Loss: 2.6124 | Val Loss: 2.6776
New best model at epoch 3 | val_loss: 2.6776
Epoch 4 | Train Loss: 2.7788 | Val Loss: 2.6359
New best model at epoch 4 | val_loss: 2.6359
Epoch 5 | Train Loss: 3.1725 | Val Loss: 2.5644
New best model at epoch 5 | val_loss: 2.5644
Epoch 6 | Train Loss: 2.4069 | Val Loss: 2.5070
New best model at epoch 6 | val_loss: 2.5070
Epoch 7 | Train Loss: 2.3649 | Val Loss: 2.4936
New best model at epoch 7 | val_loss: 2.4936
Epoch 8 | Train Loss: 1.9982 | Val Loss: 2.4339
New best model at epoch 8 | val_loss: 2.4339
Epoch 9 | Train Loss: 2.4670 | Val Loss: 2.4268
New best model at epoch 9 | val_loss:

In [5]:
from tokenizers.decoders import Metaspace as MetaspaceDecoder
import torch.nn.functional as F

tokenizer.decoder = MetaspaceDecoder(replacement=" ", prepend_scheme="never")

seq_length = dataset_train.seq_length

def clean_text(text):
    return text.lower()

def generate_text(input, num_tokens, path):
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    leftover_prefix = []
    if len(input_ids) > seq_length:
        leftover_prefix = input_ids[:-seq_length]
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    for _ in range(num_tokens):
        input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = F.softmax(output / 0.3, dim=1)
            next_id = torch.multinomial(probabilities, num_samples=1).item()

        generated.append(next_id)
    return tokenizer.decode(leftover_prefix + generated).replace("▁", " ")

def print_next_token_probabilities(input, path):
    """
    Print the top 10 next token probabilities
    :param input:
    :param path:
    """
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    if len(input_ids) > seq_length:
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        top_probs, top_indices = torch.topk(probabilities[0], 10)
    
    for i in range(10):
      token_id = top_indices[i].item()
      prob = top_probs[i].item()
      token_text = tokenizer.decode([token_id])  # Assuming you have a tokenizer
      print(f"Token: '{token_text}' (ID: {token_id}) - Probability: {prob:.4f}")
    print("\n")


# Output and various stats
prompt = clean_text("Oh, you can't help that; we're all ")

encoded = tokenizer.encode(prompt, add_special_tokens=False)
token_ids = encoded.ids
tokens = [tokenizer.decode([t]) for t in token_ids]
print(f"Prompt: {prompt}\n")
print(f"Prompt tokens: {tokens}\n")

print_next_token_probabilities(prompt, trained_path)

result_untrained = generate_text(prompt, 100, untrained_path)
result_trained = generate_text(prompt, 100, trained_path)
print(f"Untrained sample: {result_untrained}\n")
print(f"Trained sample: {result_trained}\n")

Prompt: oh, you can't help that; we're all 

Prompt tokens: ['▁o', 'h', ',', '▁', 'you', '▁c', 'a', 'n', "'", 't', '▁he', 'l', 'p', '▁t', 'hat', ';', '▁w', 'e', "'", 're', '▁a', 'll', '▁']

Token: 'you' (ID: 98) - Probability: 0.3383
Token: 'v' (ID: 45) - Probability: 0.1275
Token: 'r' (ID: 41) - Probability: 0.0827
Token: 'h' (ID: 31) - Probability: 0.0740
Token: 'i' (ID: 32) - Probability: 0.0547
Token: 'ha' (ID: 66) - Probability: 0.0488
Token: 're' (ID: 67) - Probability: 0.0436
Token: 'j' (ID: 33) - Probability: 0.0403
Token: 'e' (ID: 28) - Probability: 0.0401
Token: 'q' (ID: 40) - Probability: 0.0378


Untrained sample:  oh, you can't help that; we're all f msgileuiit  n1 a]g tt o nougvq[ux&t noua]enc]sq[onou and n] "/] tt]iu·ouaeouen,): mit]goup[ion b].0 and t]iit7ùp aiceidlliiu]iitfg t].f"6

Trained sample:  oh, you can't help that; we're all you had executions to beland you know." "i've you, won't have you know." "what are you'll think you, what don't you," said the mock turtl

In [3]:
for name, param in model.named_parameters():
    print(name, param)

decoder_embedding.weight Parameter containing:
tensor([[-1.3233,  1.2643, -1.7912,  ..., -0.5810, -1.0481,  0.0549],
        [-1.0497, -0.3482,  0.5135,  ..., -0.2672, -0.5990,  0.1220],
        [-0.8384,  0.4315, -0.6385,  ...,  0.6479, -0.2344, -0.1590],
        ...,
        [ 0.6905,  0.0968, -0.5390,  ..., -1.8336,  0.8855, -0.9251],
        [-0.6524, -1.4054, -0.4443,  ..., -0.3894,  0.3077, -0.5570],
        [ 0.6964,  0.5013, -1.5934,  ...,  0.3673, -0.3531, -2.3285]],
       device='cuda:0', requires_grad=True)
decoder_layers.0.self_attn.W_q.weight Parameter containing:
tensor([[-0.0441,  0.1766,  0.4301,  ...,  0.1938,  0.4314, -0.1383],
        [ 0.0760,  0.4064,  0.2866,  ..., -0.0631,  0.2546,  0.5881],
        [-0.2954, -0.4650, -0.1654,  ..., -0.1698,  0.4043,  0.1639],
        ...,
        [ 0.0858,  0.2781, -0.3729,  ..., -0.0609,  0.0443, -0.3178],
        [-0.1457, -0.2222,  0.6933,  ...,  0.1617, -0.1677,  0.5245],
        [-0.1054,  0.0496, -0.0506,  ...,  0.8333,  