In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

print(f"Torch version {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA Available: {torch.cuda.is_available()}")

Torch version 2.7.1+cu118
CUDA Available: True


In [28]:
import importlib
import models.nlp
importlib.reload(models.nlp)
from models.nlp import SimpleTextRnn, SimpleTransformerDecoderOnly

from config import ROOT_DIR
import os
import math
from datetime import datetime
from torch.utils.data import DataLoader
import datasets.alice_in_wonderland
importlib.reload(datasets.alice_in_wonderland)

# options
#mode = 'RNN'
mode = 'Transformer'

# hyperparameters
embedding_size = 256
hidden_size = 32
lr = 0.0025
seq_length = 10
vocab_size = 500

dataset_train = None
dataset_test = None

dataset_train = datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=True)
tokenizer = dataset_train.tokenizer
dataset_test = datasets.alice_in_wonderland.AliceInWonderlandDataset(seq_length=seq_length, vocab_size=vocab_size, train=False, tokenizer=tokenizer)

log_vocab = math.log(tokenizer.get_vocab_size())
print(f"Log of vocab {log_vocab}")

train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=True)

if mode == 'RNN':
    model = SimpleTextRnn(vocab_size=tokenizer.get_vocab_size(), embed_size=embedding_size, hidden_size=hidden_size).to(device)
elif mode == 'Transformer':
    model = SimpleTransformerDecoderOnly(vocab_size=tokenizer.get_vocab_size(), embedding_dim=embedding_size, num_heads=4, num_layers=2, ff_dimension=128, max_length=seq_length, dropout=0.1).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)


best_val_loss = float('inf')
best_epoch = 0
train_epoch_count = 100

if not os.path.exists(os.path.join(ROOT_DIR, '.models')):
    os.mkdir(os.path.join(ROOT_DIR, '.models'))

trained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_trained.pth')
untrained_path = os.path.join(ROOT_DIR, '.models', 'alice_in_wonderland_untrained.pth')

torch.save(model.state_dict(), untrained_path)
no_improvement_count = 0

try:
    for epoch in range(train_epoch_count+1):
        model.train()
        for batch_X, batch_Y in train_dataloader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_X)

            loss = criterion(outputs, batch_Y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_losses = []
            for val_X, val_Y in test_dataloader:
                val_X, val_Y = val_X.to(device), val_Y.to(device)
                val_output = model(val_X)
                val_loss = criterion(val_output, val_Y)
                val_losses.append(val_loss.item())
            avg_val_loss = sum(val_losses) / len(val_losses)
        print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            no_improvement_count = 0
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), trained_path)
            best_epoch = epoch
            print(f"New best model at epoch {epoch} | val_loss: {avg_val_loss:.4f}")
        else:
            no_improvement_count += 1
            if no_improvement_count >= 5:
                print(f"Stopping training after {no_improvement_count} epochs without improvement")
                break

        scheduler.step(avg_val_loss)
except KeyboardInterrupt:
    print("Stopping training due to keyboard interrupt")

print(f"Using model from epoch {best_epoch} | val_loss: {best_val_loss:.4f} | entropy vocab: {log_vocab:.4f}")


current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
formatted_date = current_datetime.strftime("%Y_%m_%d")
with open(os.path.join(ROOT_DIR, '.models', f"{formatted_date}_alice_in_wonderland_training_log.txt"), 'a') as f:
    f.write(f"{formatted_datetime}: Train result\n")
    f.write(f"Mode: {mode}\n")
    f.write(f"Vocab size: {tokenizer.get_vocab_size()}\n")
    f.write(f"Entropy of vocab: {log_vocab:.4f}\n")
    f.write(f"Embedding size: {embedding_size}\n")
    f.write(f"Hidden size: {hidden_size}\n")
    f.write(f"Seq length: {dataset_train.seq_length}\n")
    f.write(f"Train epochs: {train_epoch_count}\n")
    f.write(f"Learning rate: {lr}\n")
    f.write(f"Best epoch: {best_epoch}\n")
    f.write(f"Best val loss: {best_val_loss:.4f}\n")
    f.write("***\n")

Log of vocab 6.214608098422191
Epoch 0 | Train Loss: 5.8170 | Val Loss: 5.6404
New best model at epoch 0 | val_loss: 5.6404
Epoch 1 | Train Loss: 5.5702 | Val Loss: 5.2805
New best model at epoch 1 | val_loss: 5.2805
Epoch 2 | Train Loss: 5.4175 | Val Loss: 5.0783
New best model at epoch 2 | val_loss: 5.0783
Epoch 3 | Train Loss: 5.2295 | Val Loss: 4.9453
New best model at epoch 3 | val_loss: 4.9453
Epoch 4 | Train Loss: 4.8739 | Val Loss: 4.9173
New best model at epoch 4 | val_loss: 4.9173
Epoch 5 | Train Loss: 4.4088 | Val Loss: 4.8631
New best model at epoch 5 | val_loss: 4.8631
Epoch 6 | Train Loss: 4.6895 | Val Loss: 4.8101
New best model at epoch 6 | val_loss: 4.8101
Epoch 7 | Train Loss: 4.1526 | Val Loss: 4.7598
New best model at epoch 7 | val_loss: 4.7598
Epoch 8 | Train Loss: 4.8414 | Val Loss: 4.7209
New best model at epoch 8 | val_loss: 4.7209
Epoch 9 | Train Loss: 4.5583 | Val Loss: 4.7266
Epoch 10 | Train Loss: 5.0736 | Val Loss: 4.7213
Epoch 11 | Train Loss: 4.4614 | Val

In [29]:
from tokenizers.decoders import Metaspace as MetaspaceDecoder
import torch.nn.functional as F

tokenizer.decoder = MetaspaceDecoder(replacement=" ", prepend_scheme="never")

seq_length = dataset_train.seq_length

def clean_text(text):
    return text.lower()

def generate_text(input, num_tokens, path):
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    leftover_prefix = []
    if len(input_ids) > seq_length:
        leftover_prefix = input_ids[:-seq_length]
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    for _ in range(num_tokens):
        input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = F.softmax(output / 0.2, dim=1)
            next_id = torch.multinomial(probabilities, num_samples=1).item()

        generated.append(next_id)
    return tokenizer.decode(leftover_prefix + generated).replace("▁", " ")

def print_next_token_probabilities(input, path):
    """
    Print the top 10 next token probabilities
    :param input:
    :param path:
    """
    model.load_state_dict(torch.load(path, weights_only=True))
    model.eval()

    input_ids = tokenizer.encode(input).ids
    if len(input_ids) > seq_length:
        input_ids = input_ids[-seq_length:]

    generated = input_ids.copy()
    input_tensor = torch.tensor([generated[-seq_length:]], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        top_probs, top_indices = torch.topk(probabilities[0], 10)
    
    for i in range(10):
      token_id = top_indices[i].item()
      prob = top_probs[i].item()
      token_text = tokenizer.decode([token_id])  # Assuming you have a tokenizer
      print(f"Token: '{token_text}' (ID: {token_id}) - Probability: {prob:.4f}")
    print("\n")


# Output and various stats
prompt = clean_text("Oh, you can't help that; we're all ")

encoded = tokenizer.encode(prompt, add_special_tokens=False)
token_ids = encoded.ids
tokens = [tokenizer.decode([t]) for t in token_ids]
print(f"Prompt: {prompt}\n")
print(f"Prompt tokens: {tokens}\n")

print_next_token_probabilities(prompt, trained_path)

result_untrained = generate_text(prompt, 100, untrained_path)
result_trained = generate_text(prompt, 100, trained_path)
print(f"Untrained sample: {result_untrained}\n")
print(f"Trained sample: {result_trained}\n")

Prompt: oh, you can't help that; we're all 

Prompt tokens: ['▁o', 'h,', '▁you', '▁can', "'t", '▁he', 'l', 'p', '▁that', ';', '▁w', 'e', "'re", '▁all', '▁']

Token: 'ought' (ID: 231) - Probability: 0.1276
Token: 'ver' (ID: 119) - Probability: 0.1043
Token: 'x' (ID: 47) - Probability: 0.0937
Token: ''t' (ID: 142) - Probability: 0.0696
Token: '1' (ID: 10) - Probability: 0.0672
Token: 'our' (ID: 224) - Probability: 0.0667
Token: 'id' (ID: 80) - Probability: 0.0635
Token: 'z' (ID: 49) - Probability: 0.0594
Token: '4' (ID: 13) - Probability: 0.0528
Token: 'ed' (ID: 73) - Probability: 0.0391


Untrained sample:  oh, you can't help that; we're all  mock king when0ackveroutverose theirow.'t mock neverout de wheni back litppureenamesverose inxy,lyly, shaver'tverunverose not marchkedust when0 when0 when a got "andome whenance,"at't gollss lastceningup "nroortroortverose sa su lj ne when't she alice, lastag neveroutver kn alice, ducheovenonkedw g ab. shaackver

Trained sample:  oh, you can't help

In [13]:
for name, param in model.named_parameters():
    print(name, param)

embed.weight Parameter containing:
tensor([[-0.9728,  0.4297,  0.1628,  ...,  1.3895, -0.0952, -1.7533],
        [-1.6825,  0.7517, -0.2961,  ...,  0.3894,  2.3665, -0.5458],
        [ 1.0750, -0.3850, -0.0465,  ...,  1.3011,  0.4996, -0.9055],
        ...,
        [ 1.2738,  0.2364, -1.5954,  ...,  0.7503, -0.8904,  1.3999],
        [ 0.9704, -0.0747,  0.5486,  ..., -0.3408,  0.2730,  0.4335],
        [ 0.6290, -0.6324, -0.5008,  ...,  0.4935, -0.6753,  0.5932]],
       device='cuda:0', requires_grad=True)
rnn.weight_ih_l0 Parameter containing:
tensor([[ 0.2566,  0.0724,  0.0699,  ..., -0.2430, -0.2904,  0.1087],
        [-0.3271,  0.0460, -0.2292,  ...,  0.4650, -0.4118,  0.0568],
        [ 0.0422,  0.0064,  0.0764,  ...,  0.5999, -0.2273,  0.0566],
        ...,
        [ 0.1442, -0.7310, -0.4212,  ..., -0.1402,  0.2384, -0.3092],
        [ 0.0835, -0.1230,  0.1182,  ...,  0.4482, -0.5176, -0.1190],
        [ 0.3978,  0.3557, -0.0986,  ...,  0.1547, -0.0208, -0.2164]],
       device=