In [1]:
import random

import torch
import torch.nn as nn
from transformers import AutoTokenizer

from rnn_utils import download_and_prepare_data, get_hyperparameters


def set_seed(seed):
    """Seed Python's and PyTorch's random number generators with ``seed``.

    Besides seeding, this switches cuDNN to deterministic mode and turns its
    benchmark mode off.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    # Python RNG first, then the torch default generator, then all CUDA devices.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # NOTE: no MPS-specific seeding or backend flags are set here.

In [2]:
class ElmanRNNUnit(nn.Module):
    """A single Elman recurrence step: ``tanh(x @ Wh + h @ Uh + b)``.

    Parameters held by the unit:
        Uh: ``(emb_dim, emb_dim)`` matrix applied to the previous hidden state.
        Wh: ``(emb_dim, emb_dim)`` matrix applied to the current input.
        b:  bias vector of length ``emb_dim``.

    All three are created from standard-normal samples.
    """

    def __init__(self, emb_dim):
        super().__init__()
        square = (emb_dim, emb_dim)
        # Registration order (Uh, Wh, b) fixes both the parameter listing and
        # the order in which random numbers are drawn.
        self.Uh = nn.Parameter(torch.randn(*square))
        self.Wh = nn.Parameter(torch.randn(*square))
        self.b = nn.Parameter(torch.randn(emb_dim))

    def forward(self, x, h):
        """Return the next hidden state given input ``x`` and previous state ``h``."""
        input_term = x @ self.Wh
        recurrent_term = h @ self.Uh
        return torch.tanh(input_term + recurrent_term + self.b)

class ElmanRNN(nn.Module):
    """A stack of ``num_layers`` Elman units unrolled along the time axis.

    Args:
        emb_dim: Width passed to every ``ElmanRNNUnit``.
        num_layers: Number of stacked units.
    """

    def __init__(self, emb_dim, num_layers):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        units = []
        for _ in range(num_layers):
            units.append(ElmanRNNUnit(emb_dim))
        self.rnn_units = nn.ModuleList(units)

    def forward(self, x):
        """Run the stack over every time step of ``x``.

        Args:
            x: Tensor whose shape unpacks as ``(batch_size, seq_len, emb_dim)``.

        Returns:
            The top layer's hidden state at every step, stacked along dim 1.
        """
        batch_size, seq_len, emb_dim = x.shape

        # One hidden state per layer, all starting from zeros on x's device.
        states = []
        for _ in range(self.num_layers):
            states.append(torch.zeros(batch_size, emb_dim, device=x.device))

        top_layer_states = []
        for step in range(seq_len):
            signal = x[:, step]
            # Each layer consumes the output of the layer below it and
            # overwrites its own stored state for the next time step.
            for depth, unit in enumerate(self.rnn_units):
                signal = unit(signal, states[depth])
                states[depth] = signal
            top_layer_states.append(signal)

        return torch.stack(top_layer_states, dim=1)

class RecurrentLanguageModel(nn.Module):
    """Language model: token embedding -> ``ElmanRNN`` -> linear layer over the vocabulary.

    Args:
        vocab_size: Number of embedding rows and of output logits.
        emb_dim: Embedding width, also used as the RNN width.
        num_layers: Number of stacked RNN layers.
        pad_idx: Index passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, emb_dim, num_layers, pad_idx):
        super().__init__()
        # The embedding turns a tensor of token indices into embedding vectors;
        # valid indices are 0 ... vocab_size - 1.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=pad_idx,
        )
        self.rnn = ElmanRNN(emb_dim, num_layers)
        self.fc = nn.Linear(in_features=emb_dim, out_features=vocab_size)

    def forward(self, x):
        """Map a tensor of token indices ``x`` to per-position vocabulary logits."""
        return self.fc(self.rnn(self.embedding(x)))



In [3]:
def initialize_weights(model):
    """Re-initialise every parameter of ``model`` in place.

    Parameters with more than one dimension (weight matrices) get Xavier
    uniform initialisation, which aims to keep activation variance stable
    across layers. One-dimensional parameters (such as biases) are drawn with
    ``nn.init.uniform_`` using its default range.

    Args:
        model: Module whose parameters are overwritten.
    """
    for param in model.parameters():
        is_matrix = param.dim() > 1
        initializer = nn.init.xavier_uniform_ if is_matrix else nn.init.uniform_
        initializer(param)

In [4]:
# Seed every RNG before any weights are created so that initialisation and
# sampling are repeatable; `set_seed` is defined in the first cell.
SEED = 42
set_seed(SEED)

# Pick the best available accelerator: CUDA first, then Apple's MPS backend,
# falling back to the CPU. `torch.backends.mps.is_available()` is the
# documented availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
# The embedding table and the output layer are both sized from len(tokenizer).
vocab_size = len(tokenizer)
emb_dim, num_layers, batch_size, learning_rate, num_epochs = get_hyperparameters()
data_url = "https://www.thelmbook.com/data/news"
train_loader, test_loader = download_and_prepare_data(data_url, batch_size, tokenizer)
model = RecurrentLanguageModel(vocab_size, emb_dim, num_layers, tokenizer.pad_token_id)
initialize_weights(model)
model.to(device)

# Positions whose target is the padding token do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

news.tar.gz already downloaded.
Data files already extracted.
Counting sentences in news/train.txt...
Found 22034911 sentences in news/train.txt.
Counting sentences in news/test.txt...
Found 449693 sentences in news/test.txt.
Training sentences: 22034911
Test sentences: 449693


In [5]:
# Display the vocabulary size computed earlier as len(tokenizer).
vocab_size

32011

In [7]:
# Token ids the tokenizer assigns to a two-word string, wrapped in a tensor.
torch.tensor(tokenizer.encode("man child"))

tensor([ 767, 2278])

In [9]:
# Look up the embedding row(s) for the token id(s) of "man". The id tensor is
# created on `device`, the same device the model was moved to.
print(
    model.embedding(
        torch.tensor(tokenizer.encode("man"), device=device)
    )
)

tensor([[ 8.5834e-03, -7.8187e-03, -1.1247e-04, -1.0457e-02, -7.3725e-03,
          1.3349e-02,  7.0731e-03, -4.5699e-03,  2.5393e-03,  7.0062e-03,
         -9.3500e-03, -2.5188e-03,  4.5688e-03,  4.7263e-03,  9.6404e-03,
         -3.5349e-03,  1.4849e-03,  5.3625e-03,  8.7292e-03, -1.1444e-02,
          2.3268e-03,  1.1292e-02,  4.8490e-03,  1.0792e-02, -6.2727e-03,
          1.2222e-02, -1.0780e-02, -3.3701e-03, -7.6370e-03, -4.4505e-03,
         -1.1230e-02, -4.2551e-03, -1.5404e-03,  6.9711e-03,  5.3803e-03,
         -8.8852e-03, -9.1255e-03,  6.7602e-03,  1.3893e-03, -8.2448e-03,
          7.6066e-03, -3.7837e-03,  1.1616e-02, -4.1171e-03,  1.2506e-02,
          4.1735e-03,  5.5551e-03,  4.6947e-03,  1.2231e-02,  2.7532e-03,
         -9.1870e-03,  1.1686e-02,  2.8765e-03, -1.1717e-02, -1.0116e-02,
         -9.2402e-03,  7.2016e-03, -8.2284e-03,  1.3208e-02,  3.5805e-03,
          7.0149e-03, -5.2691e-03,  7.1728e-03,  8.4686e-03, -9.9511e-03,
          1.0647e-02, -1.1975e-02, -1.

In [32]:
# List the dotted name of every registered parameter; only the names are
# printed, the tensors themselves are not used here.
for name, param in model.named_parameters():
    print(name)

embedding.weight
rnn.rnn_units.0.Uh
rnn.rnn_units.0.Wh
rnn.rnn_units.0.b
rnn.rnn_units.1.Uh
rnn.rnn_units.1.Wh
rnn.rnn_units.1.b
fc.weight
fc.bias


In [42]:
# Train for `num_epochs` passes over `train_loader`, taking one optimizer step
# per batch.
# NOTE(review): nothing is logged inside the loop; only the `loss` of the very
# last batch remains available afterwards.
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        # Each batch unpacks into an input tensor and a target tensor.
        # Presumably the targets are the inputs shifted by one token — confirm
        # in rnn_utils.download_and_prepare_data.
        input_seq, target_seq = batch
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        batch_size_current, seq_len = input_seq.shape

        optimizer.zero_grad()
        output = model(input_seq)
        # Flatten batch and time so the loss sees one row of vocab_size logits
        # per token position and one target id per row.
        output = output.reshape(batch_size_current * seq_len, vocab_size)
        target = target_seq.reshape(batch_size_current * seq_len)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

In [None]:
# Write a checkpoint holding the model weights, the optimizer state and the
# hyperparameters used to build the model.
# NOTE(review): the tokenizer object itself is stored in the checkpoint, so it
# is serialized together with the tensors. Loading presumably needs a
# compatible `transformers` install and may not work with weights-only
# loading — confirm, or store the tokenizer name instead.
# NOTE(review): the padding index passed to RecurrentLanguageModel is not saved
# as its own entry; it has to be recovered from the tokenizer when rebuilding.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'vocab_size': vocab_size,
    'emb_dim': emb_dim,
    'num_layers': num_layers,
    'tokenizer': tokenizer
}, 'rnn_model_checkpoint.pt')


In [58]:
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
    """Sample a continuation of ``prompt`` from ``model``, one token at a time.

    The whole running sequence (prompt plus everything generated so far) is fed
    to the model at every step. This matters because the model's forward pass
    starts from a fresh hidden state on each call: passing only the last token
    would condition every prediction on that single token and ignore the
    prompt. The cost is that each step re-processes the full sequence.

    Args:
        model: Callable module mapping a ``(1, seq_len)`` tensor of token ids
            to logits of shape ``(1, seq_len, vocab)``.
        tokenizer: Object providing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue; must encode to at least one token.
        max_length: Maximum number of tokens to generate.
        temperature: Positive divisor applied to the logits before softmax;
            lower values make sampling greedier.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive or the prompt encodes to
            no tokens.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    tokens = list(tokenizer.encode(prompt))
    if not tokens:
        raise ValueError("prompt must encode to at least one token")

    # Build inputs on the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for _ in range(max_length):
            input_tensor = torch.tensor([tokens], device=model_device)  # (1, len(tokens))
            output = model(input_tensor)
            # Only the logits at the last position predict the next token.
            probs = torch.softmax(output[0, -1] / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1).item()
            tokens.append(next_token)
            if next_token == tokenizer.eos_token_id:
                break
    return tokenizer.decode(tokens)

# Sample a continuation for a short prompt from the trained model.
# NOTE(review): the output recorded below this cell does not start with the
# prompt "ping", so it looks stale — re-run the cell to refresh it.
print(generate_text(model, tokenizer, "ping"))
# sample("ping")

Once upon a time to his way in a 'We ' I amendin are mystifying African individuals and CEOpoint put outbreak 'The daylight now said there was nurses crystal punch up to travel to stop all made upstream revamp


In [62]:
# Sample a continuation for a longer, news-style prompt.
print(generate_text(model, tokenizer, "Most of the slick has been"))

Most of the slick has been another layer has released headlines still must serve as it is an artist a terrible that could resurred in Tributes for six-cake e-free video ; again by every contact with Sonal just dozbola Barnesia will


In [64]:
# Scalar value of `loss` as left by the last training step that ran.
loss.item()

4.316527366638184

In [70]:
# Calling the tokenizer directly returns a mapping with `input_ids` and an
# `attention_mask`, as shown in the output below.
tokenizer("We train a recurrent neural network as a language model")

{'input_ids': [1334, 7945, 263, 1162, 1264, 19677, 3564, 408, 263, 4086, 1904], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [10]:
# `encode` returns just the list of token ids; truncation is enabled with a cap
# of 30 tokens, which this 11-token sentence does not reach.
enc = tokenizer.encode("We train a recurrent neural network as a language model", max_length=30, truncation=True)
print(enc)

[1334, 7945, 263, 1162, 1264, 19677, 3564, 408, 263, 4086, 1904]


In [11]:
# Map the ids in `enc` back to their token strings to see how the sentence was
# split (e.g. "recurrent" becomes two pieces in the output below).
str(tokenizer.convert_ids_to_tokens(enc))

"['▁We', '▁train', '▁a', '▁rec', 'urrent', '▁neural', '▁network', '▁as', '▁a', '▁language', '▁model']"

In [12]:
# Decoding the ids in `enc` reproduces the original sentence.
tokenizer.decode(enc)

'We train a recurrent neural network as a language model'

In [25]:
# Embedding row(s) for the token id(s) of "hello", looked up on `device`.
print(model.embedding(torch.tensor(tokenizer.encode("hello"), device=device)))

tensor([[-4.9974e-03,  7.2700e-03, -3.7599e-03,  8.3719e-04,  9.5455e-03,
          2.6397e-03,  4.7868e-03,  1.2753e-02, -1.0245e-02, -7.3390e-03,
          9.9733e-03,  5.9080e-03,  1.2885e-02, -3.7159e-03, -1.1508e-02,
         -8.0394e-03, -4.9860e-03,  1.8733e-03,  4.4149e-03,  7.5154e-03,
         -1.2728e-02,  6.5686e-03,  1.0498e-02,  1.1756e-02, -3.1670e-03,
          1.1699e-02, -8.6411e-03, -6.2517e-03,  8.3845e-03,  6.9486e-03,
         -7.1743e-04,  7.3013e-04,  4.6167e-05,  7.3659e-03,  5.4039e-03,
          5.6552e-03,  9.9196e-03,  8.4207e-03, -7.9735e-03,  9.6027e-03,
          4.1047e-04,  3.3144e-03,  1.9215e-03,  1.3058e-02, -4.7337e-04,
          5.7468e-03, -2.5526e-03,  1.1254e-02, -3.4482e-03,  4.4873e-03,
         -2.0660e-03,  8.8983e-03,  9.7610e-03,  1.1558e-02,  9.5240e-04,
         -9.1947e-03, -1.2801e-02, -1.4333e-04,  2.1315e-03,  8.7773e-03,
         -3.4797e-03, -9.8399e-03,  5.6754e-04, -1.1483e-03,  1.2687e-02,
          7.0264e-03,  5.2931e-03,  2.