# Chapter 12: A Language model from scratch

preparing dataset from karpathy makemore https://github.com/karpathy/makemore/blob/master/makemore.py and using shakespeare dataset 

In [1]:
from torch.utils.data import Dataset
import torch

class CharDataset(Dataset):

    def __init__(self, words, chars, max_word_length):
        self.words = words
        self.chars = chars
        self.max_word_length = max_word_length
        # char to idx
        self.stoi = {ch:i+1 for i,ch in enumerate(chars)}
        # indx to char
        self.itos = {i:s for s,i in self.stoi.items()} # inverse mapping

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        return word in self.words

    def get_vocab_size(self):
        return len(self.chars) + 1 # all the possible characters and special 0 token

    def get_output_length(self):
        return self.max_word_length + 1 # <START> token followed by words

    def encode(self, word):
        ix = torch.tensor([self.stoi[w] for w in word], dtype=torch.long)
        return ix

    def decode(self, ix):
        word = ''.join(self.itos[i] for i in ix)
        return word

    def __getitem__(self, idx):
        """so what is happening here is we encode then for
        x: start at 1 to leave for start token
        y: start at 0, so theyre offset. x->y pred; mask everything beyond that"""
        word = self.words[idx]
        ix = self.encode(word)
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        x[1:1+len(ix)] = ix # starts at 1 to leave 0 for the <START> token
        y[:len(ix)] = ix
        y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations
        return x, y


def create_datasets(input_file):

    # preprocessing of the input text file
    with open(input_file, 'r') as f:
        data = f.read()
    words = data.splitlines()
    words = [w.strip() for w in words] # get rid of any leading or trailing white space
    words = [w for w in words if w] # get rid of any empty strings
    chars = sorted(list(set(''.join(words)))) # all the possible characters
    max_word_length = max(len(w) for w in words)
    print(f"number of examples in the dataset: {len(words)}")
    print(f"max word length: {max_word_length}")
    print(f"number of unique characters in the vocabulary: {len(chars)}")
    print("vocabulary:")
    print(''.join(chars))

    # partition the input data into a training and the test set
    test_set_size = min(1000, int(len(words) * 0.1)) # 10% of the training set, or up to 1000 examples
    rp = torch.randperm(len(words)).tolist()
    train_words = [words[i] for i in rp[:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]
    print(f"split up the dataset into {len(train_words)} training examples and {len(test_words)} test examples")

    # wrap in dataset objects
    train_dataset = CharDataset(train_words, chars, max_word_length)
    test_dataset = CharDataset(test_words, chars, max_word_length)

    return train_dataset, test_dataset

In [2]:
train_dataset, test_dataset = create_datasets('../data/input.txt')
vocab_size = train_dataset.get_vocab_size() # unique characters
block_size = train_dataset.get_output_length() # longest word + 1 char for start sequence
print(f"dataset determined that: {vocab_size=}, {block_size=}")

number of examples in the dataset: 32777
max word length: 63
number of unique characters in the vocabulary: 64
vocabulary:
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
split up the dataset into 31777 training examples and 1000 test examples
dataset determined that: vocab_size=65, block_size=64


In [3]:
# ### detour to understand the __getitem__

# def lol():
#     word = train_dataset.words[0]
#     print(f'WORD: {word}')
#     print(f'LENGTH WORD: {len(word)}')
#     ix = train_dataset.encode(word)
#     print(ix)
#     print(f'LENGTH IX: {len(ix)}') # should match since its idx:char mapping

#     # that is 0th + longest sequence length
#     x = torch.zeros(train_dataset.max_word_length + 1, dtype=torch.long)
#     y = torch.zeros(train_dataset.max_word_length + 1, dtype=torch.long)
#     print(f'X Y shape: {x.shape} {y.shape}')

#     # so x and y are offset because 0 --> 32, 32 --> 46, etc
#     x[1:1+len(ix)] = ix # start at 1; leave 0 for start token; rest is padded zeros for sequence length
#     print(x)
#     y[:len(ix)] = ix # start at 0, its the encoding
#     y[len(ix)+1:] = -1 # mask the reset
#     print(y)

# lol()

In [4]:
from torch.utils.data import DataLoader

# Create data loaders with drop_last=True to ensure all batches have the same size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True, drop_last=True)

In [5]:
from tqdm import tqdm

num_epochs = 10

for epoch in range(num_epochs):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for batch in progress_bar:
        X, Y = batch

print(X.size(), Y.size())

Epoch 1/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1290.46it/s]
Epoch 2/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1291.95it/s]
Epoch 3/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1369.72it/s]
Epoch 4/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1354.00it/s]
Epoch 5/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1381.78it/s]
Epoch 6/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1384.04it/s]
Epoch 7/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1382.36it/s]
Epoch 8/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1398.81it/s]
Epoch 9/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1394.99it/s]
Epoch 10/10 [Train]: 100%|██████████| 496/496 [00:00<00:00, 1388.43it/s]

torch.Size([64, 64]) torch.Size([64, 64])





In [6]:
X.size(), Y.size()

(torch.Size([64, 64]), torch.Size([64, 64]))

### start with simple RNN

In [35]:
import torch.nn as nn
import torch
import torch.nn.functional as F 

mps_device = torch.device('mps')


class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Embedding(vocab_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs, targets=None):
        batch_size, sequence_length = inputs.size()
        device = inputs.device

        h_prev = self.initHidden(batch_size, device)
        hiddens = []
        for i in range(sequence_length):
            h = h_prev + self.i2h(inputs[:, i]) # embed and combine with hidden unit
            h = F.tanh(h) # apply non-linearity)
            h = self.h2h(h) # linear transform to hidden size
            h_prev = h
            hiddens.append(h)
        
        # decode the outputs
        hidden = torch.stack(hiddens, dim=1)
        logits = self.h2o(hidden)
        h = h.detach()

        # calculate loss if targets are provided
        loss = None
        if targets is not None:
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        
        return logits, loss
        
    
    def initHidden(self, batch_size, device):
        # Initialize the hidden state based on the current batch size
        return torch.zeros(batch_size, self.hidden_size).to(device)

In [36]:
@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None):
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    total_correct = 0
    total_predictions = 0

    for i, batch in enumerate(loader):
        batch = [t.to(mps_device) for t in batch]
        X, Y = batch

        logits, loss = model(X, Y)
        losses.append(loss.item())

        # Calculate accuracy
        preds = torch.argmax(logits, dim=-1)  # Predictions
        total_correct += (preds == Y).sum().item()  # Count correct predictions
        total_predictions += Y.numel()  # Total number of predictions

        if max_batches is not None and i >= max_batches:
            break

    mean_loss = torch.tensor(losses).mean().item()
    accuracy = total_correct / total_predictions  # Calculate accuracy

    model.train()  # Reset model back to training mode
    return mean_loss, accuracy

In [37]:
mps_device = torch.device('mps')
hidden_size=64

model = RNN(vocab_size, hidden_size)
model.to(mps_device)
optimizer = torch.optim.Adam(model.parameters())


In [51]:
from tqdm import tqdm

num_epochs = 5

step = 0
for epoch in range(num_epochs):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    
    # get batch
    for batch in progress_bar:
        X, Y = batch
        X, Y = X.to(mps_device), Y.to(mps_device)

        # feed into the model
        logits, loss = model(X, Y)

        # calculate the gradient, update the weights
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # logging
        if step % 100 == 0:
            print(f"step {step} | loss {loss.item():.4f}")
        
        if step > 0 and step % 500 == 0:
            train_loss, train_accuracy = evaluate(model, train_dataset, batch_size=64, max_batches=10)
            test_loss, test_accuracy  = evaluate(model, test_dataset,  batch_size=64, max_batches=10)
            print(f'LOSS')
            print(train_loss, test_loss)
            print('ACCURACY')
            print(train_accuracy, test_accuracy)
        
        step += 1

Epoch 1/5 [Train]:   1%|          | 6/496 [00:00<00:17, 28.07it/s]

step 0 | loss 1.8545


Epoch 1/5 [Train]:  21%|██▏       | 106/496 [00:03<00:12, 30.55it/s]

step 100 | loss 1.9131


Epoch 1/5 [Train]:  41%|████      | 204/496 [00:06<00:10, 28.81it/s]

step 200 | loss 1.8504


Epoch 1/5 [Train]:  62%|██████▏   | 306/496 [00:10<00:06, 28.72it/s]

step 300 | loss 1.8463


Epoch 1/5 [Train]:  81%|████████▏ | 404/496 [00:13<00:02, 31.95it/s]

step 400 | loss 1.8252


Epoch 1/5 [Train]: 100%|██████████| 496/496 [00:16<00:00, 29.84it/s]
Epoch 2/5 [Train]:   1%|          | 4/496 [00:00<00:16, 30.14it/s]

step 500 | loss 1.8420


Epoch 2/5 [Train]:   2%|▏         | 8/496 [00:00<00:37, 12.88it/s]

LOSS
1.858009696006775 1.8755420446395874
ACCURACY
0.23464133522727273 0.23772638494318182


Epoch 2/5 [Train]:  22%|██▏       | 107/496 [00:03<00:12, 30.44it/s]

step 600 | loss 1.8327


Epoch 2/5 [Train]:  42%|████▏     | 208/496 [00:07<00:09, 30.32it/s]

step 700 | loss 1.8473


Epoch 2/5 [Train]:  62%|██████▏   | 308/496 [00:10<00:06, 31.18it/s]

step 800 | loss 1.8674


Epoch 2/5 [Train]:  82%|████████▏ | 408/496 [00:13<00:02, 30.32it/s]

step 900 | loss 1.9037


Epoch 2/5 [Train]: 100%|██████████| 496/496 [00:16<00:00, 29.93it/s]
Epoch 3/5 [Train]:   1%|          | 6/496 [00:00<00:16, 29.25it/s]

step 1000 | loss 1.8607


Epoch 3/5 [Train]:   2%|▏         | 12/496 [00:00<00:29, 16.26it/s]

LOSS
1.8779292106628418 1.8526597023010254
ACCURACY
0.23801491477272727 0.24400745738636365


Epoch 3/5 [Train]:  23%|██▎       | 114/496 [00:04<00:12, 31.32it/s]

step 1100 | loss 1.8864


Epoch 3/5 [Train]:  43%|████▎     | 215/496 [00:07<00:09, 29.45it/s]

step 1200 | loss 1.8089


Epoch 3/5 [Train]:  63%|██████▎   | 311/496 [00:10<00:06, 30.13it/s]

step 1300 | loss 1.8830


Epoch 3/5 [Train]:  84%|████████▎ | 415/496 [00:13<00:02, 31.44it/s]

step 1400 | loss 1.8254


Epoch 3/5 [Train]: 100%|██████████| 496/496 [00:16<00:00, 30.17it/s]
Epoch 4/5 [Train]:   2%|▏         | 12/496 [00:00<00:15, 31.97it/s]

step 1500 | loss 1.8178


Epoch 4/5 [Train]:   3%|▎         | 16/496 [00:00<00:28, 16.71it/s]

LOSS
1.8369488716125488 1.8403290510177612
ACCURACY
0.23790394176136365 0.2400790127840909


Epoch 4/5 [Train]:  23%|██▎       | 116/496 [00:03<00:12, 31.58it/s]

step 1600 | loss 1.7904


Epoch 4/5 [Train]:  44%|████▎     | 216/496 [00:07<00:08, 32.41it/s]

step 1700 | loss 1.7954


Epoch 4/5 [Train]:  64%|██████▎   | 316/496 [00:10<00:05, 30.94it/s]

step 1800 | loss 1.8780


Epoch 4/5 [Train]:  84%|████████▍ | 416/496 [00:13<00:02, 31.43it/s]

step 1900 | loss 1.7907


Epoch 4/5 [Train]: 100%|██████████| 496/496 [00:15<00:00, 31.05it/s]
Epoch 5/5 [Train]:   3%|▎         | 16/496 [00:00<00:15, 31.12it/s]

step 2000 | loss 1.8256


Epoch 5/5 [Train]:   4%|▍         | 20/496 [00:00<00:27, 17.30it/s]

LOSS
1.8139830827713013 1.835822582244873
ACCURACY
0.253173828125 0.24345259232954544


Epoch 5/5 [Train]:  24%|██▍       | 120/496 [00:04<00:11, 32.39it/s]

step 2100 | loss 1.7965


Epoch 5/5 [Train]:  44%|████▍     | 220/496 [00:07<00:08, 32.17it/s]

step 2200 | loss 1.8056


Epoch 5/5 [Train]:  65%|██████▍   | 320/496 [00:10<00:05, 31.12it/s]

step 2300 | loss 1.8396


Epoch 5/5 [Train]:  85%|████████▍ | 420/496 [00:13<00:02, 31.82it/s]

step 2400 | loss 1.8272


Epoch 5/5 [Train]: 100%|██████████| 496/496 [00:15<00:00, 31.10it/s]


In [52]:
def generate_sequence(model, start_sequence, length, stoi, itos):
    model.eval()  # Set model to evaluation mode
    
    # Convert start sequence to tensor
    input_indices = [stoi[char] for char in start_sequence]
    input_tensor = torch.tensor([input_indices], dtype=torch.long).to(mps_device)
    
    # Generate sequence
    generated_sequence = start_sequence
    for _ in range(length - len(start_sequence)):
        with torch.inference_mode():
            logits, _ = model(input_tensor)
            # Get last character from the last position
            last_char_logits = logits[0, -1, :]
            predicted_char_index = torch.argmax(last_char_logits).item()
            predicted_char = itos[predicted_char_index]
            
            # Update the input for the next prediction
            generated_sequence += predicted_char
            input_tensor = torch.cat([input_tensor, torch.tensor([[predicted_char_index]], dtype=torch.long).to(mps_device)], dim=1)
    
    return generated_sequence


In [53]:
print(generate_sequence(model, 'i am very interested', 150, train_dataset.stoi, train_dataset.itos))

i am very interested with the shall the with the shall the with the shall the with the shall the with the shall the with the shall the with the shall 


### RNN with more layers

In [54]:
import torch.nn as nn
import torch
import torch.nn.functional as F 

mps_device = torch.device('mps')


class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Embedding(vocab_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2h2 = nn.Linear(hidden_size, hidden_size) # second RNN layer
        self.h2o = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, inputs, targets=None):
        batch_size, sequence_length = inputs.size()
        device = inputs.device

        h_prev = self.initHidden(batch_size, device)
        hiddens = []
        for i in range(sequence_length):
            h = h_prev + self.i2h(inputs[:, i]) # embed and combine with hidden unit

            # first RNN
            h = F.tanh(h) # apply non-linearity)
            h = self.h2h(h) # linear transform to hidden size

            h = F.tanh(h) # apply non-linearity
            h = self.h2h2(h)

            h_prev = h
            hiddens.append(h)
        
        # decode the outputs
        hidden = torch.stack(hiddens, dim=1)
        logits = self.h2o(hidden)
        h = h.detach()

        # calculate loss if targets are provided
        loss = None
        if targets is not None:
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        
        return logits, loss
        
    
    def initHidden(self, batch_size, device):
        # Initialize the hidden state based on the current batch size
        return torch.zeros(batch_size, self.hidden_size).to(device)

In [55]:
mps_device = torch.device('mps')
hidden_size=64

model = RNN(vocab_size, hidden_size)
model.to(mps_device)
optimizer = torch.optim.Adam(model.parameters())

num_epochs = 5

step = 0
for epoch in range(num_epochs):
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    
    # get batch
    for batch in progress_bar:
        X, Y = batch
        X, Y = X.to(mps_device), Y.to(mps_device)

        # feed into the model
        logits, loss = model(X, Y)

        # calculate the gradient, update the weights
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # logging
        if step % 100 == 0:
            print(f"step {step} | loss {loss.item():.4f}")
        
        if step > 0 and step % 500 == 0:
            train_loss, train_accuracy = evaluate(model, train_dataset, batch_size=64, max_batches=10)
            test_loss, test_accuracy  = evaluate(model, test_dataset,  batch_size=64, max_batches=10)
            print(f'LOSS')
            print(train_loss, test_loss)
            print('ACCURACY')
            print(train_accuracy, test_accuracy)
        
        step += 1

Epoch 1/5 [Train]:   1%|          | 5/496 [00:00<00:25, 19.54it/s]

step 0 | loss 4.2373


Epoch 1/5 [Train]:  21%|██        | 104/496 [00:05<00:19, 20.47it/s]

step 100 | loss 2.6489


Epoch 1/5 [Train]:  41%|████      | 204/496 [00:10<00:14, 20.56it/s]

step 200 | loss 2.4441


Epoch 1/5 [Train]:  61%|██████    | 303/496 [00:14<00:09, 21.01it/s]

step 300 | loss 2.2876


Epoch 1/5 [Train]:  82%|████████▏ | 405/496 [00:19<00:04, 21.07it/s]

step 400 | loss 2.2166


Epoch 1/5 [Train]: 100%|██████████| 496/496 [00:24<00:00, 20.64it/s]
Epoch 2/5 [Train]:   1%|          | 3/496 [00:00<00:23, 20.77it/s]

step 500 | loss 2.1128


Epoch 2/5 [Train]:   2%|▏         | 8/496 [00:00<00:47, 10.32it/s]

LOSS
2.139448642730713 2.1553823947906494
ACCURACY
0.20106090198863635 0.20192649147727273


Epoch 2/5 [Train]:  22%|██▏       | 108/496 [00:05<00:18, 21.06it/s]

step 600 | loss 2.0799


Epoch 2/5 [Train]:  42%|████▏     | 207/496 [00:10<00:13, 21.07it/s]

step 700 | loss 2.0445


Epoch 2/5 [Train]:  62%|██████▏   | 309/496 [00:15<00:08, 20.98it/s]

step 800 | loss 2.0393


Epoch 2/5 [Train]:  82%|████████▏ | 408/496 [00:19<00:04, 21.09it/s]

step 900 | loss 2.0512


Epoch 2/5 [Train]: 100%|██████████| 496/496 [00:24<00:00, 20.66it/s]
Epoch 3/5 [Train]:   1%|          | 6/496 [00:00<00:23, 20.70it/s]

step 1000 | loss 1.9563


Epoch 3/5 [Train]:   2%|▏         | 11/496 [00:00<00:42, 11.29it/s]

LOSS
2.004164934158325 2.0106141567230225
ACCURACY
0.22369939630681818 0.21952681107954544


Epoch 3/5 [Train]:  23%|██▎       | 112/496 [00:05<00:18, 21.02it/s]

step 1100 | loss 1.9487


Epoch 3/5 [Train]:  43%|████▎     | 211/496 [00:10<00:13, 20.54it/s]

step 1200 | loss 2.0290


Epoch 3/5 [Train]:  63%|██████▎   | 313/496 [00:15<00:09, 20.00it/s]

step 1300 | loss 1.9559


Epoch 3/5 [Train]:  83%|████████▎ | 411/496 [00:20<00:04, 20.98it/s]

step 1400 | loss 1.9229


Epoch 3/5 [Train]: 100%|██████████| 496/496 [00:24<00:00, 20.51it/s]
Epoch 4/5 [Train]:   2%|▏         | 12/496 [00:00<00:22, 21.28it/s]

step 1500 | loss 1.9187


Epoch 4/5 [Train]:   3%|▎         | 17/496 [00:01<00:38, 12.34it/s]

LOSS
1.9215422868728638 1.9208076000213623
ACCURACY
0.23206676136363635 0.23326526988636365


Epoch 4/5 [Train]:  23%|██▎       | 115/496 [00:05<00:18, 21.04it/s]

step 1600 | loss 1.9148


Epoch 4/5 [Train]:  44%|████▍     | 217/496 [00:10<00:13, 21.05it/s]

step 1700 | loss 1.9083


Epoch 4/5 [Train]:  64%|██████▎   | 316/496 [00:15<00:08, 21.00it/s]

step 1800 | loss 1.8427


Epoch 4/5 [Train]:  84%|████████▎ | 415/496 [00:20<00:03, 21.04it/s]

step 1900 | loss 1.8992


Epoch 4/5 [Train]: 100%|██████████| 496/496 [00:23<00:00, 20.71it/s]
Epoch 5/5 [Train]:   3%|▎         | 14/496 [00:00<00:22, 21.14it/s]

step 2000 | loss 1.9133


Epoch 5/5 [Train]:   4%|▍         | 20/496 [00:01<00:36, 13.18it/s]

LOSS
1.8672634363174438 1.8908804655075073
ACCURACY
0.23479669744318182 0.23652787642045456


Epoch 5/5 [Train]:  24%|██▍       | 119/496 [00:06<00:18, 20.92it/s]

step 2100 | loss 1.8781


Epoch 5/5 [Train]:  45%|████▍     | 221/496 [00:10<00:13, 21.10it/s]

step 2200 | loss 1.8564


Epoch 5/5 [Train]:  65%|██████▍   | 320/496 [00:15<00:08, 21.27it/s]

step 2300 | loss 1.8068


Epoch 5/5 [Train]:  84%|████████▍ | 419/496 [00:20<00:03, 20.95it/s]

step 2400 | loss 1.7893


Epoch 5/5 [Train]: 100%|██████████| 496/496 [00:23<00:00, 20.71it/s]


In [56]:
print(generate_sequence(model, 'i am very interested', 150, train_dataset.stoi, train_dataset.itos))

i am very interested the shall the shall the shall the shall the shall the shall the shall the shall the shall the shall the shall the shall the shall


In [None]:
# lstm

In [None]:
# lstm with dropout

In [None]:
import requests
import gzip
import pandas as pd

# URL of the gzipped text file
url = "https://github.com/lsb/human-numbers/blob/trunk/one-hundred-thousand-numbers.txt.gz?raw=true"

# Downloading the file using requests
response = requests.get(url)
response.raise_for_status()  # This will raise an error if the download failed

# Unzipping the content
content = gzip.decompress(response.content).decode('utf-8')

# Since the file contains numbers, each number on a new line, we can split the content into a list
numbers = content.splitlines()

In [None]:
text = ' '.join(x for x in numbers)
tokens = text.split(' ')

tokens[:10]

In [None]:
vocab = sorted(list(set(tokens)))
len(vocab)

In [None]:
# token to numbers
word2idx = {w:i for i, w in enumerate(vocab)}
idx2word = {i:w for i, w in enumerate(vocab)}
nums = [word2idx[i] for i in tokens]
nums[:10]

In [None]:
for idx in nums[:25]:
    print(idx, vocab[idx])

### dataset prep

In [None]:
dummy_tokens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']  # Example list

for i in range(0, len(dummy_tokens) - 3, 1):
    three_tokens = dummy_tokens[i:i+3]  # Get a slice of three tokens
    next_token = dummy_tokens[i+3]      # Get the token immediately following the slice
    print(f"Three tokens: {three_tokens}, Next token: {next_token}, Step: {i}")

In [None]:
### predict next token based on three previous tokens; book uses step size of 3 which has no overlap, i prefer 1
[(tokens[i:i+3], tokens[i+3]) for i in range(0, len(tokens)-3, 1)][2000:2020] # change from 4-2 to 3-1

In [None]:
import torch

# Assuming 'mps_device' is defined as your MPS device
mps_device = torch.device('mps')

In [None]:
# create dataset karpathy style
xs = []
ys = []
for i in range(0, len(tokens) - 3, 1):
    three_tokens = torch.tensor(nums[i:i+3])  # Get a slice of three tokens
    next_token = torch.tensor(nums[i+3])      # Get the token immediately following the slice
    xs.append(three_tokens)
    ys.append(next_token)

In [None]:
from collections import Counter

def get_most_common_scalar(data):
    """given list of scalars, find the most comon scalar value"""
    counter = Counter(data)
    most_common = counter.most_common(1)[0][0]
    return idx2word[most_common.item()]

get_most_common_scalar(ys)

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class LLMModel1(nn.Module):
    def __init__(self, vocab_size, n_hidden):
        super(LLMModel1, self).__init__()  # Initialize the superclass
        self.i_h = nn.Embedding(vocab_size, n_hidden) #vocab to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden) # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_size) # hidden to vocab (logits)
        
    def forward(self, x):
        """hidden states are accumulated. subsequent hidden state is added to embedding of next token before being passed through next linear layer and ReLU"""
        # create first hidden state from first word
        # embed --> linear --> relu
        h = F.relu(self.h_h(self.i_h(x[:, 0])))

        # second hidden state from second word
        h = h + self.i_h(x[:, 1])
        h = F.relu(self.h_h(h))

        # hidden state from third word
        h = h + self.i_h(x[:, 2])
        h = F.relu(self.h_h(h))
        return self.h_o(h)

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Assuming X and Y are your data tensors
X = torch.stack(xs)
Y = torch.stack(ys)
dataset = TensorDataset(X, Y)

# Calculate the sizes of splits
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset (this method shuffles the data)
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Split the dataset without shuffling
train_dataset = TensorDataset(X[:train_size], Y[:train_size])
val_dataset = TensorDataset(X[train_size:], Y[train_size:])

# Create data loaders with drop_last=True to ensure all batches have the same size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)


In [None]:
from tqdm import tqdm

def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch in progress_bar:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar
            avg_loss = total_loss / total
            accuracy = 100 * correct / total
            progress_bar.set_postfix(loss=avg_loss, accuracy=f'{accuracy:.2f}%')

        train_losses.append(total_loss / len(train_loader))
        train_accuracies.append(100 * correct / total)

        # Validation phase
        model.eval()
        total_loss = 0
        correct = 0
        total = 0
        progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Validation]')
        with torch.no_grad():
            for batch in progress_bar:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                total_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Update progress bar
                avg_loss = total_loss / total
                accuracy = 100 * correct / total
                progress_bar.set_postfix(loss=avg_loss, accuracy=f'{accuracy:.2f}%')

        val_losses.append(total_loss / len(val_loader))
        val_accuracies.append(100 * correct / total)

    return train_losses, val_losses, train_accuracies, val_accuracies, model


In [None]:
vocab_size = len(vocab)
n_hidden = 64
model = LLMModel1(vocab_size, n_hidden)
model.to(mps_device)  # Move model to MPS device

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

num_epochs = 2
device = mps_device

train_losses, val_losses, train_accuracies, val_accuracies, model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device)

In [None]:
# print number of params in model
def print_params(model):
    num_params = 0
    for param in model.parameters():
        num_params += param.numel()
    print(f"#Params: {num_params}")
    return num_params

def predict_next_word(model, input_sequence, word2idx, idx2word):
    # Ensure the model is in evaluation mode
    model.eval()
    
    # Convert the input sequence to a tensor of word indices
    input_indices = [word2idx[word] for word in input_sequence]
    input_tensor = torch.tensor(input_indices, dtype=torch.long).unsqueeze(0)  # Add a batch dimension
    input_tensor = input_tensor.to(device)  # Move the tensor to the appropriate device
    
    # Get the prediction
    with torch.no_grad():  # No need to track gradients for prediction
        output = model(input_tensor)
    
    # Get the predicted word index
    _, predicted_index = torch.max(output, 1) # output is of shape (1, vocab_size)
    predicted_index = predicted_index.item()  # Convert to a Python integer
    
    # Convert the predicted index to the corresponding word
    predicted_word = idx2word[predicted_index]
    print(f'input sequence: {input_sequence}.\nprediction: {predicted_word}')
    model.reset()


print_params(model)

# Example usage
input_sequence = ['one', 'two', 'three']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['twenty', 'one', 'twenty']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['zero', 'one', 'two']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['one', 'hundred', 'one']
predict_next_word(model, input_sequence, word2idx, idx2word)


In [None]:
input_sequence = ['twenty', 'one', 'twenty']

model.eval()
    
# Convert the input sequence to a tensor of word indices
input_indices = [word2idx[word] for word in input_sequence]
input_tensor = torch.tensor(input_indices, dtype=torch.long).unsqueeze(0)  # Add a batch dimension
input_tensor = input_tensor.to(device)  # Move the tensor to the appropriate device

with torch.no_grad():  # No need to track gradients for prediction
    output = model(input_tensor)

_, predicted_index = torch.max(output, 1) # output is of shape (1, vocab_size)
predicted_index = predicted_index.item()  # Convert to a Python integer
    
# Convert the predicted index to the corresponding word
predicted_word = idx2word[predicted_index]

predicted_word

In [None]:
output.shape

In [None]:
import matplotlib.pyplot as plt

def plot_training_history(train_losses, val_losses, train_accuracies=None, val_accuracies=None):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 4))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'bo-', label='Training loss')
    plt.plot(epochs, val_losses, 'ro-', label='Validation loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot accuracy if provided
    if train_accuracies and val_accuracies:
        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_accuracies, 'bo-', label='Training accuracy')
        plt.plot(epochs, val_accuracies, 'ro-', label='Validation accuracy')
        plt.title('Training and Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

    plt.show()


In [None]:
plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies)

### refactor to RNN

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class LLMModel2(nn.Module):
    def __init__(self, vocab_size, n_hidden):
        super(LLMModel2, self).__init__()  # Initialize the superclass
        self.i_h = nn.Embedding(vocab_size, n_hidden) #vocab to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden) # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_size) # hidden to vocab (logits)
        
    def forward(self, x):
        """hidden states are accumulated. subsequent hidden state is added to embedding of next token before being passed through next linear layer and ReLU"""
        h = 0
        for i in range(3):
            h = h + self.i_h(x[:, i])
            h = F.relu(self.h_h(h))
        return self.h_o(h)



In [None]:
torch.mps.empty_cache()

In [None]:
vocab_size = len(vocab)
n_hidden = 64
model = LLMModel2(vocab_size, n_hidden)
model.to(mps_device)  # Move model to MPS device

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

train_losses, val_losses, train_accuracies, val_accuracies, model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device)

In [None]:
# Example usage
input_sequence = ['one', 'two', 'three']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['twenty', 'one', 'twenty']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['zero', 'one', 'two']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['one', 'hundred', 'one']
predict_next_word(model, input_sequence, word2idx, idx2word)

### maintaining hidden state

- instead of resetting to 0 for each forward

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class LLMModel3(nn.Module):
    def __init__(self, vocab_size, n_hidden):
        super(LLMModel3, self).__init__()  # Initialize the superclass
        self.i_h = nn.Embedding(vocab_size, n_hidden) #vocab to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden) # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_size) # hidden to vocab (logits)
        self.h = 0
        
    def forward(self, x):
        """hidden states are accumulated. subsequent hidden state is added to embedding of next token before being passed through next linear layer and ReLU"""
        for i in range(3):
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)
        self.h = self.h.detach() # stop tracking previous gradients
        return out

    def reset(self):
        # Resets or reinitializes the hidden state to zero
        self.h = 0


# detach means we only calculate gradients of previous 3 steps (for loop) aka the steps used to calculate this self.h; this is used to avoid vanishing/exploding graidnets

In [None]:
torch.mps.empty_cache()


model = LLMModel3(vocab_size, n_hidden)
model.to(mps_device)  # Move model to MPS device

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

train_losses, val_losses, train_accuracies, val_accuracies, model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device)

In [None]:
print_params(model)

In [None]:
model

In [None]:
model.reset()

# Example usage
input_sequence = ['one', 'two', 'three']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['twenty', 'one', 'twenty']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['zero', 'one', 'two']
predict_next_word(model, input_sequence, word2idx, idx2word)

input_sequence = ['one', 'hundred', 'one']
predict_next_word(model, input_sequence, word2idx, idx2word)



slight boost in performance...

### now instead of outputting one word per three input words, output after each step...

this gives it more signal for back prop

In [None]:
sl = 16
nums[i:i+sl], nums[i+1:i+sl+1]

In [None]:
# create dataset karpathy style
sl = 16

xs = []
ys = []
for i in range(0, len(nums) - sl-1, 1):
    sixteen_tokens = torch.tensor(nums[i:i+16])  # Get a slice of three tokens
    next_sixteen_tokens = torch.tensor(nums[i+1: i+sl+1])      # Get the token immediately following the slice
    xs.append(sixteen_tokens)
    ys.append(next_sixteen_tokens)

# Assuming X and Y are your data tensors
X = torch.stack(xs)
Y = torch.stack(ys)
dataset = TensorDataset(X, Y)

# Calculate the sizes of splits
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset (this method shuffles the data)
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Split the dataset without shuffling
train_dataset = TensorDataset(X[:train_size], Y[:train_size])
val_dataset = TensorDataset(X[train_size:], Y[train_size:])

# Create data loaders with drop_last=True to ensure all batches have the same size
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)


In [None]:
i_ = 1
xs[i_], ys[i_]

for num in xs[i_]:
    print(vocab[num])

print('\n')

for num in ys[i_]:
    print(vocab[num])

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class LLMModel4(nn.Module):
    def __init__(self, vocab_size, n_hidden):
        super(LLMModel4, self).__init__()  # Initialize the superclass
        self.i_h = nn.Embedding(vocab_size, n_hidden) #vocab to hidden
        self.h_h = nn.Linear(n_hidden, n_hidden) # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_size) # hidden to vocab (logits)
        self.h = 0
        
    def forward(self, x):
        """output a prediction after each word

        will return outs as shape batch_size x seq_len x vocab_sz
        """
        outs = [] # (batch_size x vocab_sz)
        for i in range(sl): # CHANGED TO SL
            self.h = self.h + self.i_h(x[:, i])
            self.h = F.relu(self.h_h(self.h))
            outs.append(self.h_o(self.h))
        self.h = self.h.detach() # stop tracking previous gradients
        return torch.stack(outs, dim=1)

    def reset(self):
        # Resets or reinitializes the hidden state to zero
        self.h = 0


In [None]:
# dumb but we'll update the function for outputs.shape

from tqdm import tqdm

def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch in progress_bar:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # FLATTTEN AND RESHAPE
            outputs = outputs.view(-1, vocab_size)
            labels = labels.view(-1)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar
            avg_loss = total_loss / total
            accuracy = 100 * correct / total
            progress_bar.set_postfix(loss=avg_loss, accuracy=f'{accuracy:.2f}%')

        train_losses.append(total_loss / len(train_loader))
        train_accuracies.append(100 * correct / total)

        # Validation phase
        model.eval()
        total_loss = 0
        correct = 0
        total = 0
        progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Validation]')
        with torch.no_grad():
            for batch in progress_bar:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)

                outputs = outputs.view(-1, vocab_size)
                labels = labels.view(-1)

                loss = criterion(outputs, labels)

                total_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Update progress bar
                avg_loss = total_loss / total
                accuracy = 100 * correct / total
                progress_bar.set_postfix(loss=avg_loss, accuracy=f'{accuracy:.2f}%')

        val_losses.append(total_loss / len(val_loader))
        val_accuracies.append(100 * correct / total)

    return train_losses, val_losses, train_accuracies, val_accuracies, model

In [None]:
torch.mps.empty_cache()


model = LLMModel4(vocab_size, n_hidden)
model.to(mps_device)  # Move model to MPS device

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 2

train_losses, val_losses, train_accuracies, val_accuracies, model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device)

In [None]:
# Define your sequence length
sl = 16

model.reset()

# Make sure to have your word2idx dictionary ready to convert words to indices
input_sequence = ['one', 'hundred', 'one', 
                  'one', 'hundred', 'two',
                  'one', 'hundred', 'three', 
                  'one', 'hundred', 'four',
                  'one', 'hundred', 'five',
                  'one']  # sequence of length sl

# Convert input sequence to tensor of word indices
input_indices = [word2idx[word] for word in input_sequence]
input_tensor = torch.tensor(input_indices, dtype=torch.long).unsqueeze(0)  # Add batch dimension
print(input_tensor.shape)
input_tensor = input_tensor.to(device)  # Move to the appropriate device
print(input_tensor.shape)

# Set model to evaluation mode
model.eval()

# Get predictions
with torch.no_grad():
    logits = model(input_tensor)  # Logits will have shape [1, sl, vocab_size] since we did unsqueeze

# If you want to get the most likely word predictions for each position in the sequence:
predictions = torch.argmax(logits, dim=2)  # This will give you [1, sl] tensor of word indices

# Convert predicted indices to words
predicted_words = [idx2word[idx.item()] for idx in predictions[0]]
predicted_words

### multi-layer RNNs 

will suffer from exploding...vanishing gradients...

In [None]:
### LSTM

class LSTMCell(nn.Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh) # since we stack ni + nh into one tensor
        self.input_gate = nn.Linear(ni + nh, nh)
        self.cell_gate = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)
    
    def forward(self, input, state):
        """
        h and c start of shape (batch, nh)
        h becomes shape (batch, nh+ni) after being concat with input

        forget, c, inp, cell, and out are all of shape (batch, nh)
        h is passed into the linear layers as (batch, ni + nh)

        ### MATH
        concat([h, input], dim=1) * W_forget + b_forget -> forget "how much past memory to let go"
        concat([h, input], dim=1) * W_input + b_input -> inp "how much new information to let in"
        concat([h, input], dim=1) * W_cell + b_cell -> cell "what are the candidate values for new information"
        concat([h, input], dim=1) * W_output + b_output -> out "baesd on new cell state, what should my output be?"
        """
        h, c = state # hidden state and cell state; both of shape (batch, nh)
        h = torch.stack([h, input], dim=1) # stack h and input together so h is shape (batch, nh + ni)

        # forget cell state
        forget = torch.sigmoid(self.forget_gate(h)) # input shape (batch, nh + ni); output shape (batch, nh); squashed to be between values of 0 and 1
        c = c * forget # (batch, nh) * (batch, nh) = (batch, nh)

        # update cell state
        inp = torch.sigmoid(self.input_gate(h)) # input shape is (batch, nh + ni); output shape is (batch, nh)
        cell = torch.tanh(self.cell_gate(h)) # input shape is (batch, nh + ni); output shape is (batch, nh)
        c = c + (inp * cell) # (batch, nh) + (batch, nh) = (batch, nh)

        # generate new hidden state, using h and c
        out = torch.sigmoid(self.output_gate(h)) # input shape is (batch, nh + ni); output shape is (batch, nh)
        h = out * torch.tanh(c) # input shape of tanh is (batch, nh); output shape is (batch, nh)
        return h, (h, c)
    
### Refactor for speed into one giant matrix multiplication
    
class LSTMCell(nn.Module):
    def __init__(self, ni, nh):
        self.ih = nn.Linear(ni, 4 * nh)
        self.hh = nn.Linear(nh, 4 * nh)
    
    def forward(self, input, state):
        h, c = state # hidden state and cell state 
        # One big matmul for all gates
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate, forgetgate, outgate = map(torch.sigmoid, gates[:3]) 
        cellgate = gates[3].tanh()

        c = (forgetgate * c) + (ingate * cellgate)
        h = outgate * c.tanh()
        return h, (h, c)


In [None]:
t = torch.arange(0, 10); t

t.chunk(2)

```
x * ih -> [batch_size, ni] * [ni, 4*nh] = [batch_size, 4*nh]
h * hh -> [batch_size, nh] * [nh, 4*nh] = [batch_size, 4*nh]
gates = (x * ih) + (h * hh) -> [batch_size, 4*nh]

ingate, forgetgate, outgate, cellgate = gates.chunk(4, 1)


ingate = sigmoid([batch_size, nh])
forgetgate = sigmoid([batch_size, nh])
outgate = sigmoid([batch_size, nh])
cellgate = tanh([batch_size, nh])

c = (forgetgate * c) + (ingate * tanh(cellgate))
h = outgate * tanh(c)
```


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LMModel6(nn.Module):
    def __init__(self, vocab_size, n_hidden, n_layers):
        super(LMModel6, self).__init__()  # Initialize the superclass

        # An embedding layer that converts input data (indices) to embeddings of shape (batch_size, n_hidden)
        self.i_h = nn.Embedding(vocab_size, n_hidden)

        # The LSTM layers
        self.lstm = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)

        # A fully connected output layer that maps the LSTM layer outputs to vocabulary size; also often called FC for fully connected
        self.h_o = nn.Linear(n_hidden, vocab_size)

        # Initialize the hidden and cell states with zeros
        self.h = self.init_hidden(n_layers, n_hidden)

    def forward(self, x):
        # Convert input indices to embeddings
        embeds = self.i_h(x)
        # get the output and new hidden state from the LSTM; lstm_out is of shape (batch_size, seq_len, n_hidden)
        lstm_out, self.h = self.lstm(embeds, self.h)
        # detach
        self.h = [h_.detach() for h_ in self.h]
        # fc / h_o
        return self.h_o(lstm_out)


    def init_hidden(self, n_layers, n_hidden):
        # generates the first hidden state of zeros which we'll use in the forward pass; right now its hard-coded to batch size of 64...
        # there is two because its for hidden + cell state
        return [torch.zeros(n_layers, 64, n_hidden) for _ in range(2)]

    def reset(self):
        # Reset the hidden state. Needed in case of detaching the hidden state between batches
        for h in self.h:
            h.zero_()


In [None]:
torch.mps.empty_cache()


model = LMModel6(vocab_size, n_hidden, 2)
model.to(mps_device)  # Move model to MPS device

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 3

train_losses, val_losses, train_accuracies, val_accuracies, model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, device)

### detour of detach

In [None]:
import torch

# Creating a simple computation graph
x = torch.tensor([2.], requires_grad=True)
y = x * 2

# Without detach
z = y * y  # z = (x * 2) * (x * 2) = 4x^2
z.backward()  # Computes gradients for the whole graph
gradients_without_detach = x.grad.item()  # Should be dz/dx = 8x, which is 16 when x = 2

# Resetting gradients
x.grad.data.zero_()

# With detach
y_detached = y.detach()  # y_detached is a new tensor with the same value as y but no history of operations
# No need to call z.backward() because y_detached has no grad_fn
# and thus won't contribute to gradients in x

# The gradient of x should still be zero since the detached part does not contribute to the computation
gradients_with_detach = x.grad.item()

gradients_without_detach, gradients_with_detach


In [None]:
## torch.stack()
lst = [torch.rand(2, 3) for i in range(3)]
print([x.shape for x in lst])

# stacking on dimension 1 maintains batch dimension, second dimension (1) becomes the timesteps
torch.stack(lst, dim=1).shape, torch.stack(lst, dim=1)

In [None]:
torch.stack(lst, dim=0).shape, torch.stack(lst, dim=0)