# Language Modeling With PyTorch

## Part 1 – Solutions

This notebook is a companion to the [Language Modeling with PyTorch – Part 1](./32_language_modeling_1.ipynb) notebook and contains *proposed* solutions to the exercises.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

## 1. Build a Trigram model

A trigram language model predicts the next character based on the previous **two** characters, unlike a bigram model which only uses one previous character. This additional context should help make more accurate predictions.

In this exercise, we will:
1. Train a trigram language model that takes **two** characters as input to predict the 3rd one
2. Implement this using a neural network approach
3. Evaluate the model's performance using loss metrics
4. Compare its performance to the bigram model

**Key Questions to Address:**
1. Did the trigram model improve over the bigram model?
2. If yes, by what percentage did it improve?

**Intuition:** By considering two previous characters instead of just one, a trigram model captures more context and patterns in the language, which should lead to better predictions and lower loss compared to a bigram model.

In [None]:
# Set training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
words = open("data/lm/names.txt").read().splitlines()

### Understanding Trigrams

Each trigram consists of a sequence of $2$ input characters, followed by $1$ expected output character.
The model's task is to predict the output character given the two input characters.

To generate these trigrams from our text data, we can extend the sliding-window approach we used for bigrams:

In [None]:
for w in words[:1]:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):  # Three char 'sliding-window'
        print(ch1, ch2, ch3)

**Important Implementation Detail:**

Notice how our first trigram is `('.', 'e', 'm')`, and not `('.', '.', 'e')`. This is a deliberate choice.

While we could modify the code to produce `('.', '.', 'e')` with `chs = ['.', '.'] + list(w) + ['.', '.']`, having both of the first two characters be the special token `'.'` wouldn't provide meaningful context for predicting the next character. Special tokens at the beginning don't contain real linguistic patterns, so starting with a single `'.'` followed by the first actual character gives our model more useful information.

This approach helps avoid confusing the model and reduces wasted computation on inputs that don't reflect natural language patterns.
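
To make the difference between the two padding choices concrete, here is a small sketch. The helper `trigrams` and its `pad` parameter are illustrative names introduced for this example, not part of the notebook:

```python
def trigrams(word, pad=1):
    """Return all (ch1, ch2, ch3) windows of `word`, padded with `pad` dots per side."""
    chs = ["."] * pad + list(word) + ["."] * pad
    return list(zip(chs, chs[1:], chs[2:]))


single = trigrams("emma", pad=1)  # the choice made in this notebook
double = trigrams("emma", pad=2)  # the alternative discussed above

print(single[0], len(single))  # ('.', 'e', 'm') 4
print(double[0], len(double))  # ('.', '.', 'e') 6
```

One consequence of single padding is worth keeping in mind: the first letter of each name only ever appears as an input, never as a prediction target.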

### Visualizing Trigram Occurrences

To properly represent and visualize our trigrams, we need to account for all possible character combinations:
- We have $26$ letters from the English alphabet
- **+1** special character ('.') for word boundaries

This gives us a 3D array of dimensions $27\times 27\times 27$ to store all possible trigram combinations:
- First dimension: the first character in the trigram
- Second dimension: the second character in the trigram
- Third dimension: the third (predicted) character

Each cell in this 3D array will store the count of how many times that specific trigram appears in our training data.
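
As a rough illustration of this bookkeeping, the sketch below counts trigrams on a tiny made-up word list, using a `Counter` in place of the 3D tensor, and then normalizes the counts for one context into a conditional distribution. The word list and variable names are assumptions made for the example only:

```python
from collections import Counter

toy_words = ["emma", "ema", "emil"]  # made-up sample, not the real dataset

counts = Counter()
for w in toy_words:
    chs = ["."] + list(w) + ["."]
    for tri in zip(chs, chs[1:], chs[2:]):
        counts[tri] += 1  # same role as N[i, j, k] += 1

# Distribution of the third character given the context ('e', 'm')
context = ("e", "m")
following = {t[2]: c for t, c in counts.items() if t[:2] == context}
total = sum(following.values())
probs = {ch: c / total for ch, c in following.items()}

print(counts[(".", "e", "m")])  # 3
print(probs)
```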

In [None]:
N = torch.zeros((27, 27, 27), dtype=torch.int32)

chars = sorted(set("".join(words)))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0  # Special token has position zero
itos = {i: s for s, i in stoi.items()}

for w in words:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):  # Three token 'sliding-window'
        N[stoi[ch1], stoi[ch2], stoi[ch3]] += 1  # Increment cell in 3D tensor by 1

Let's create a more informative visualization of our trigram counts.
We'll draw one heatmap per first character, for up to three of the most frequent first characters, always including the special token (`.`).

In [None]:
# Find the first characters with the most occurrences
first_char_counts = N.sum(dim=(1, 2))
top_k_indices = torch.topk(first_char_counts, k=3).indices.tolist()

# Add special character (.) to always show regardless of frequency
if 0 not in top_k_indices:
    top_k_indices = [0] + top_k_indices[:2]  # Ensure we include the special token

for k in top_k_indices:
    # Only show non-zero entries for clarity
    nonzero_mask = N[k] > 0

    if nonzero_mask.sum() > 0:  # Skip empty slices
        plt.figure(figsize=(16, 14))

        # Use a perceptually uniform colormap with better contrast
        plt.imshow(
            N[k],
            cmap="viridis",
            norm=plt.matplotlib.colors.LogNorm(vmin=0.1, vmax=N[k].max()),
        )

        plt.colorbar(label="Frequency (log scale)")
        plt.title(f"Trigram Heatmap for First Character: '{itos[k]}'", fontsize=16)
        plt.xlabel("Third Character (Predicted)", fontsize=12)
        plt.ylabel("Second Character", fontsize=12)

        # Add labels showing both the trigram and its count where non-zero
        for i in range(27):
            for j in range(27):
                if N[k, i, j] > 0:  # Only label non-zero entries
                    count = N[k, i, j].item()
                    chstr = itos[k] + itos[i] + itos[j]

                    # Color text based on background darkness for better readability
                    # White text on dark backgrounds, black text on light backgrounds
                    # Reuse the heatmap's normalization so the threshold matches the cell color
                    norm_val = plt.matplotlib.colors.LogNorm(
                        vmin=0.1, vmax=N[k].max().item()
                    )(count)
                    # Lower normalized values are darker in the viridis colormap
                    # Simple threshold - values below 0.5 get white text, others get black
                    text_color = "white" if norm_val < 0.5 else "black"

                    plt.text(
                        j,
                        i,
                        chstr,
                        ha="center",
                        va="bottom",
                        color=text_color,
                        fontsize=8,
                        fontweight="bold",
                    )
                    plt.text(
                        j, i, count, ha="center", va="top", color=text_color, fontsize=8
                    )

        # Disable grid lines so they don't clutter the heatmap
        plt.grid(False)
        plt.tight_layout()
        plt.show()

        # Print summary stats for this first character
        total_count = N[k].sum().item()
        unique_trigrams = (N[k] > 0).sum().item()
        print(
            f"First char '{itos[k]}': {unique_trigrams} unique trigrams out of {total_count} total occurrences"
        )

**How to Interpret the Trigram Heatmaps?**

The heatmaps above visualize the frequency of trigram patterns in our dataset:

- **X-axis (horizontal)**: The third character in the trigram (the character we're trying to predict)
- **Y-axis (vertical)**: The second character in the trigram
- **Title**: Shows the first character, which is fixed for each heatmap

**Color intensity**: Represents frequency on a logarithmic scale - brighter/more intense colors indicate trigrams that appear more frequently in our dataset.

**Labels on cells**: Each cell shows:
- The full trigram (in the upper half of the cell)
- The count of occurrences (in the lower half of the cell)

**What to look for**:
- Bright cells indicate common character combinations
- Dark/empty areas show rare or non-existent combinations
- Patterns along rows/columns reveal which character sequences occur more frequently in names

These visualizations help us understand the statistical patterns our trigram model will learn to predict.

### Neural Network Implementation of the Trigram Model

Now let's implement the neural network version of our trigram model. Unlike the bigram model that handled a single input character, the key challenge here is processing two input characters simultaneously to predict the third.

The following code prepares our training data for the neural network by:
1. Converting each word into a sequence of trigrams
2. Extracting input pairs (first two characters) and output targets (third character)
3. Converting these into tensor representations the network can process

In [None]:
# Create training set of all trigrams
xs, ys = [], []

for w in words:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        xs.append([stoi[ch1], stoi[ch2]])
        ys.append(stoi[ch3])

xs, ys = torch.tensor(xs), torch.tensor(ys)  # [196113, 2], [196113]
num_x, num_y = xs.nelement() // 2, ys.nelement()
print("Number of examples\nx:", num_x, "\ny:", num_y)

### Character Representation for the Neural Network

For our neural network to process characters, we need to convert them to numerical representations:

- Each character is represented by a $27$-dimensional one-hot vector (one position for each of our 26 letters plus 1 special character)
- For a trigram, we need to concatenate the one-hot vectors of the two input characters
- This creates a $27+27=54$-dimensional input vector for each trigram

You can conceptualize this 54-dimensional vector as a "two-hot" vector, where exactly two of the 54 dimensions are set to $1$ (one for each input character) and the rest are $0$.
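
A minimal sketch of the "two-hot" construction for a single input pair. The indices `5` and `13` are chosen as an example; under the `stoi` mapping above they correspond to `'e'` and `'m'`:

```python
import torch
import torch.nn.functional as F

pair = torch.tensor([[5, 13]])                     # shape [1, 2]
one_hot = F.one_hot(pair, num_classes=27).float()  # shape [1, 2, 27]
two_hot = one_hot.view(1, -1)                      # shape [1, 54]

print(two_hot.shape)
print(two_hot.nonzero()[:, 1].tolist())  # [5, 40]
```

The second character lands at position $27 + 13 = 40$, because its one-hot vector occupies the second half of the concatenated input.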

Next, we'll initialize our neural network weights and prepare for training:

In [None]:
g = torch.Generator(device=device).manual_seed(2147483647)

# Random weight matrix of shape (27+27)x27 (requires_grad=True for autograd)
W = torch.randn((27 + 27, 27), device=device, generator=g, requires_grad=True)

In [None]:
# Training cycles, using the entire dataset over 200 Epochs, like the bigram model
for k in range(200):
    # Forward pass
    # One-hot encoding, [196113, 2, 27]
    xenc = F.one_hot(xs, num_classes=27).float().to(device)
    xenc = xenc.view(num_x, -1)  # concatenate the one-hot vectors, [196113, 54]
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model

    # Normalize the counts into probabilities (softmax); this is y_pred
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(num_x), ys].log().mean() + 0.01 * (W**2).mean()
    print(f"Loss @ iteration {k + 1}: {loss}")

    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool

    # Weight update
    W.data += -50 * W.grad

### Results and Comparison

**Performance Evaluation:**

- The bigram model produced a final loss of $2.462393045425415$
- Our trigram model achieves a loss of $2.259373664855957$
- This represents an $8.24\%$ relative reduction in loss

**Conclusion:** As we hypothesized, providing the model with two characters of context (trigram) instead of just one (bigram) allows it to better capture language patterns and make more accurate predictions. This demonstrates how increasing the context window in language models leads to improved performance, a principle that extends to modern large language models which use much larger context windows.
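
The percentage above is the relative reduction in loss, which can be checked with a few lines of arithmetic (the variable names are illustrative):

```python
bigram_loss = 2.462393045425415
trigram_loss = 2.259373664855957

# Relative reduction: how much of the bigram loss the trigram model removes
relative_reduction = (bigram_loss - trigram_loss) / bigram_loss
print(f"Relative loss reduction: {relative_reduction:.2%}")  # 8.24%
```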

## 2. Split the dataset

In real-world machine learning applications, we need to evaluate how well our models generalize to unseen data. To do this, we typically split our dataset into three parts:

- **Training set (80%)**: Used to train the model parameters
- **Validation/Dev set (10%)**: Used for hyperparameter tuning and model selection
- **Test set (10%)**: Used only for final evaluation to estimate real-world performance
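
As a sketch of the mechanics, the split boils down to shuffling the example indices once and cutting them at the 80% and 90% marks. The helper `split_indices` and its `seed` parameter are hypothetical names used only for this illustration, and the example uses plain Python rather than the tensors used in the rest of the notebook:

```python
import random


def split_indices(n, seed=0):
    """Shuffle range(n) and cut it 80:10:10 into train/valid/test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut1, cut2 = int(n * 0.8), int(n * 0.9)
    return idx[:cut1], idx[cut1:cut2], idx[cut2:]


train, valid, test = split_indices(1000)
print(len(train), len(valid), len(test))  # 800 100 100
```

Shuffling before cutting matters: it keeps each split a random sample of the data rather than, say, an alphabetical slice.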

In this exercise, we will:
1. Randomly split our dataset following the 80:10:10 ratio
2. Train our language models (bigram and trigram) **only** on the training set
3. Evaluate performance on both validation and test sets
4. Observe patterns in generalization performance

**Key Questions:**
- How do the models perform on unseen data compared to training data?
- Which model generalizes better: bigram or trigram?
- What does this tell us about the tradeoff between model complexity and generalization?

In [None]:
g = torch.Generator(device=device).manual_seed(2147483647)

### Bigram model baseline

Let's start by establishing a baseline using our familiar bigram model. This simpler model will help us understand:

1. How well a basic model can generalize to new data
2. Provide a comparison point for our more complex trigram model

We'll reuse the same architecture as before, but now restrict training to just the training portion of our data.

In [None]:
# Create set of all *bigrams*
xs, ys = [], []

for w in words:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2 in zip(chs, chs[1:]):
        xs.append(stoi[ch1])
        ys.append(stoi[ch2])

xs, ys = torch.tensor(xs), torch.tensor(ys)  # [num_bigrams], [num_bigrams] (one more example per word than the trigram set)
num_x, num_y = xs.nelement(), ys.nelement()

# Shuffle/Permute the dataset, keeping pairs in sync
perm = torch.randperm(num_x)
xs, ys = xs[perm], ys[perm]

# Split 80:10:10 for train:valid:test
xs_bi_train, xs_bi_valid, xs_bi_test = (
    xs[: int(num_x * 0.8)],
    xs[int(num_x * 0.8) : int(num_x * 0.9)],
    xs[int(num_x * 0.9) :],
)
ys_bi_train, ys_bi_valid, ys_bi_test = (
    ys[: int(num_x * 0.8)],
    ys[int(num_x * 0.8) : int(num_x * 0.9)],
    ys[int(num_x * 0.9) :],
)

In [None]:
W = torch.randn((27, 27), device=device, generator=g, requires_grad=True)

# Training cycles: 200 full-batch passes over the training split only
for k in range(200):
    # Forward pass
    xenc = (
        F.one_hot(xs_bi_train, num_classes=27).float().to(device)
    )  # one-hot encode the input characters
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(len(probs)), ys_bi_train].log().mean()
        + 0.01 * (W**2).mean()
    )
    print(f"Loss @ iteration {k + 1}: {loss}")
    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool
    # Weight update
    W.data += -50 * W.grad

### Evaluating Bigram Training Performance

The bigram model trained on our 80% training set reaches a final loss of $2.4826600551605225$.

**Note**: This is slightly worse than when we trained on the full dataset (which is expected). The model has access to less information when trained on only 80% of the data.

However, this approach offers a crucial advantage: we can now measure how well our model generalizes to unseen examples using our held-out validation and test sets. This will tell us whether our model is:

- **Underfitting**: Performing poorly on both training and validation sets
- **Overfitting**: Performing well on training but poorly on validation
- **Generalizing well**: Performing similarly on both training and validation/test sets

In [None]:
# Validation Loss
with torch.no_grad():
    xenc = (
        F.one_hot(xs_bi_valid, num_classes=27).float().to(device)
    )  # one-hot encode the names
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(len(probs)), ys_bi_valid].log().mean()
        + 0.01 * (W**2).mean()
    )
print(f"Validation Loss: {loss}")

# Test Loss
with torch.no_grad():
    xenc = (
        F.one_hot(xs_bi_test, num_classes=27).float().to(device)
    )  # one-hot encode the names
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(len(probs)), ys_bi_test].log().mean() + 0.01 * (W**2).mean()
    )
print(f"Test Loss:\t {loss}")

### Bigram Model Generalization Analysis

**Key Observation**: The train, validation, and test losses are all very close to each other!

This indicates that our bigram model generalizes well to unseen data. When a model performs similarly across all three data splits, it suggests that:

1. The model has learned meaningful patterns rather than just memorizing the training data
2. The statistical patterns of character sequences in names are relatively consistent across our dataset
3. The bigram model's complexity level is appropriate for this particular task

This good generalization makes sense intuitively - with only 27² = 729 possible bigrams, our model can easily encounter most meaningful character combinations during training, allowing it to handle similar patterns in the validation and test sets.

### Compare the Bigram and Trigram Models

Now that we've established our bigram baseline, let's implement and evaluate our trigram model on the same data splits. This comparison will help us understand the tradeoffs between:

1. **Model complexity**: Trigrams capture more context but have more parameters
2. **Generalization ability**: How well each model performs on unseen data
3. **Sample efficiency**: How effectively each model learns from limited training data

### Predicting Trigram Generalization Performance

**Question:** Will the trigram model generalize as well as the bigram model? Let's reason about this before seeing the results.

**Mathematical perspective**:
- A trigram model must learn patterns for $27^3 = 19,683$ possible character combinations
- A bigram model only needs to learn $27^2 = 729$ possible combinations
- That's a 27x increase in the possible patterns to learn!

**Model complexity implications**:
- The trigram model's weight matrix `W` has shape $(27+27) \times 27 = 54 \times 27$, i.e. $1{,}458$ parameters (versus $729$ for the bigram model)
- This larger parameter space creates a higher-dimensional optimization problem
- While the model can theoretically capture more nuanced patterns in names, it needs more data to do so reliably
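
The counts quoted above can be reproduced with a short calculation. The variable names are illustrative; the observation in the comments follows from the model computing its logits as the sum of one row of `W` per input position:

```python
vocab = 27

bigram_params = vocab * vocab             # W is 27 x 27
trigram_params = (vocab + vocab) * vocab  # W is 54 x 27
full_table_cells = vocab ** 3             # the count tensor N is 27 x 27 x 27

print(bigram_params, trigram_params, full_table_cells)  # 729 1458 19683

# The neural trigram model has only 2x the parameters of the bigram model,
# far fewer than the 19,683 cells of a full trigram count table.
print(trigram_params / bigram_params, full_table_cells / trigram_params)
```

So while the number of possible trigrams grows 27-fold, this particular architecture stays compact: it cannot represent an arbitrary $27 \times 27 \times 27$ table, which also limits how much it can overfit.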

**Expected generalization behavior**:
- Due to this increased complexity, the trigram model will likely show some gap between training and validation/test performance
- We expect both validation and test losses to be higher than the training loss
- However, if the gap remains small, it would indicate that the additional complexity is justified by the improved modeling capacity

In [None]:
# Create set of all *trigrams*
xs, ys = [], []

for w in words:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        xs.append([stoi[ch1], stoi[ch2]])
        ys.append(stoi[ch3])

xs, ys = torch.tensor(xs), torch.tensor(ys)  # [196113, 2], [196113]
num_x, num_y = xs.nelement() // 2, ys.nelement()

# Shuffle/Permute the dataset, keeping (x,y) pairs in sync
perm = torch.randperm(num_x)
xs, ys = xs[perm, :], ys[perm]  # xs are shuffled along the zeroth dimension

# Split 80:10:10 for train:valid:test
xs_tri_train, xs_tri_valid, xs_tri_test = (
    xs[: int(num_x * 0.8), :],
    xs[int(num_x * 0.8) : int(num_x * 0.9), :],
    xs[int(num_x * 0.9) :, :],
)
ys_tri_train, ys_tri_valid, ys_tri_test = (
    ys[: int(num_x * 0.8)],
    ys[int(num_x * 0.8) : int(num_x * 0.9)],
    ys[int(num_x * 0.9) :],
)

In [None]:
W = torch.randn(
    (27 + 27, 27), device=device, generator=g, requires_grad=True
)  # random weight matrix of shape (27+27)x27 (requires_grad=True for autograd)

# Training cycles: 200 full-batch passes over the training split, like the bigram model
d_size = xs_tri_train.shape[0]
for k in range(200):
    # Forward pass
    xenc = (
        F.one_hot(xs_tri_train, num_classes=27).float().to(device)
    )  # One-hot encoding, [d_size, 2, 27]
    xenc = xenc.view(d_size, -1)  # concatenate the one-hot vectors, [d_size, 54]
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(d_size), ys_tri_train].log().mean() + 0.01 * (W**2).mean()
    )
    print(f"Loss @ iteration {k + 1}: {loss}")

    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool

    # Weight update
    W.data += -50 * W.grad

### Trigram Training Results

The training loss of our trigram model trained on 80% of the data reaches $2.2590503692626953$, which is virtually identical to the loss we achieved when training on the full dataset ($2.259373664855957$).

**Interesting observation**: Despite having access to 20% less data, the trigram model achieves essentially the same training performance. This suggests:

1. The 80% training set still contains most of the important trigram patterns present in the full dataset
2. Our training process (200 epochs) is sufficient for the model to converge to a good solution
3. The reduced dataset size doesn't significantly impact the model's ability to learn the core patterns

Now let's evaluate this model on our validation and test sets to see how well it generalizes.

In [None]:
# Validation Loss
d_size = xs_tri_valid.shape[0]
with torch.no_grad():
    xenc = (
        F.one_hot(xs_tri_valid, num_classes=27).float().to(device)
    )  # one-hot encode the names
    xenc = xenc.view(d_size, -1)
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(d_size), ys_tri_valid].log().mean() + 0.01 * (W**2).mean()
    )
print(f"Validation Loss: {loss}")

# Test Loss
d_size = xs_tri_test.shape[0]
with torch.no_grad():
    xenc = (
        F.one_hot(xs_tri_test, num_classes=27).float().to(device)
    )  # one-hot encode the names
    xenc = xenc.view(d_size, -1)
    logits = xenc @ W  # logits, different word for log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = -probs[torch.arange(d_size), ys_tri_test].log().mean() + 0.01 * (W**2).mean()
print(f"Test Loss:\t {loss}")

### Trigram Generalization Analysis

**Results Summary**:
- Both validation and test losses are slightly higher than the training loss, which aligns with our hypothesis
- However, the difference is smaller than we might have expected given the significant increase in model complexity

**Interpretation**:
1. **Confirmation of complexity theory**: The increased complexity does lead to some generalization gap, as predicted
2. **Surprisingly good generalization**: Despite having 27x more possible combinations to learn, the gap is quite small
3. **Value of context**: The additional context captured by trigrams appears to provide genuinely useful information for prediction
4. **Dataset characteristics**: The names dataset likely contains strong trigram patterns that are consistent across the data splits

**Key insight**: The improved performance of the trigram model (lower overall loss compared to the bigram model) combined with its good generalization suggests that the additional complexity is justified for this task. The benefit of modeling longer contextual dependencies outweighs the increased risk of overfitting.

In [None]:
# Create set of all *trigrams*
xs, ys = [], []

for w in words:
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        xs.append([stoi[ch1], stoi[ch2]])
        ys.append(stoi[ch3])

xs, ys = torch.tensor(xs), torch.tensor(ys)  # [196113, 2], [196113]
num_x, num_y = xs.nelement() // 2, ys.nelement()

# Shuffle/Permute the dataset, keeping (x,y) pairs in sync
perm = torch.randperm(num_x)
xs, ys = xs[perm, :], ys[perm]  # xs are shuffled along the zeroth dimension

# Split 80:10:10 for train:valid:test
xs_tri_train, xs_tri_valid, xs_tri_test = (
    xs[: int(num_x * 0.8), :],
    xs[int(num_x * 0.8) : int(num_x * 0.9), :],
    xs[int(num_x * 0.9) :, :],
)
ys_tri_train, ys_tri_valid, ys_tri_test = (
    ys[: int(num_x * 0.8)],
    ys[int(num_x * 0.8) : int(num_x * 0.9)],
    ys[int(num_x * 0.9) :],
)

### Hyperparameter Tuning: Regularization Strength

Now that we've established the basic performance of our trigram model, let's explore how regularization affects its generalization.

**Regularization** adds a penalty to the loss function for large weight values, which can help prevent overfitting. The strength of this penalty is a hyperparameter we can tune.
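
A minimal sketch of such a penalized loss, using a tiny made-up batch. The function name `regularized_nll` is an assumption for this example, and it uses `torch.log_softmax` instead of the explicit exponentiate-and-normalize steps used elsewhere in this notebook; the two are mathematically the same, but `log_softmax` is numerically more stable:

```python
import torch


def regularized_nll(logits, targets, W, strength):
    """Mean negative log-likelihood plus an L2 penalty on the weights."""
    log_probs = torch.log_softmax(logits, dim=1)
    nll = -log_probs[torch.arange(targets.shape[0]), targets].mean()
    return nll + strength * (W ** 2).mean()


# Tiny made-up example: 4 samples, 54 inputs, 27 classes
g = torch.Generator().manual_seed(0)
W = torch.randn((54, 27), generator=g)
x = torch.zeros((4, 54))
x[torch.arange(4), torch.tensor([1, 2, 3, 4])] = 1.0
logits = x @ W
targets = torch.tensor([0, 5, 13, 26])

for strength in (0.0, 0.01, 1.0):
    print(strength, regularized_nll(logits, targets, W, strength).item())
```

For fixed weights, the penalty term grows linearly with `strength`, so the printed loss increases even though the predictions are unchanged.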

**Experimental setup**:
- We'll sweep the regularization strength from $0.0$ (no regularization) to $1.0$ (strong regularization)
- We'll use 25 evenly spaced values across this range for thorough coverage
- For each strength value, we'll:
  1. Train a complete trigram model from scratch
  2. Evaluate it on all three data splits (train, validation, test)
  3. Record the losses for later analysis

**Selection process**:
- We'll visualize all losses to understand the relationship between regularization and generalization
- The optimal strength will be selected based on the validation loss (not the test loss)
- This mimics real-world scenarios where the test set remains untouched until final evaluation

In [None]:
# from 0.0 to 1.0 in 25 steps
strengths = torch.linspace(0.0, 1.0, 25, device=device)
losst, lossv, lossf = [], [], []

for strength in strengths:
    W = torch.randn(
        (27 + 27, 27), device=device, generator=g, requires_grad=True
    )  # random weight matrix of shape (27+27)x27 (requires_grad=True for autograd)

    # Training cycles: 200 full-batch passes over the training split
    d_size = xs_tri_train.shape[0]
    for k in range(200):
        # Forward pass
        xenc = (
            F.one_hot(xs_tri_train, num_classes=27).float().to(device)
        )  # One-hot encoding, [d_size, 2, 27]
        xenc = xenc.view(d_size, -1)  # concatenate the one-hot vectors, [d_size, 54]
        logits = xenc @ W  # logits, different word for log-counts
        counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
        probs = counts / counts.sum(
            1, keepdims=True
        )  # Normalize into probabilities (softmax); this is y_pred
        loss_t = (
            -probs[torch.arange(d_size), ys_tri_train].log().mean()
            + strength * (W**2).mean()
        )

        # Backward pass
        W.grad = None  # Make sure all gradients are reset
        loss_t.backward()  # Torch kept track of what this variable is, kinda cool

        # Weight update
        W.data += -50 * W.grad

    # Validation Loss
    d_size = xs_tri_valid.shape[0]
    with torch.no_grad():
        xenc = (
            F.one_hot(xs_tri_valid, num_classes=27).float().to(device)
        )  # one-hot encode the names
        xenc = xenc.view(d_size, -1)
        logits = xenc @ W  # logits, different word for log-counts
        counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
        probs = counts / counts.sum(
            1, keepdims=True
        )  # Normalize into probabilities (softmax); this is y_pred
        loss_v = (
            -probs[torch.arange(d_size), ys_tri_valid].log().mean()
            + strength * (W**2).mean()
        )

    # Test Loss
    d_size = xs_tri_test.shape[0]
    with torch.no_grad():
        xenc = (
            F.one_hot(xs_tri_test, num_classes=27).float().to(device)
        )  # one-hot encode the names
        xenc = xenc.view(d_size, -1)
        logits = xenc @ W  # logits, different word for log-counts
        counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
        probs = counts / counts.sum(
            1, keepdims=True
        )  # Normalize into probabilities (softmax); this is y_pred
        loss_f = (
            -probs[torch.arange(d_size), ys_tri_test].log().mean()
            + strength * (W**2).mean()
        )

    # Note the losses for this strength
    losst.append((strength, loss_t))
    lossv.append((strength, loss_v))
    lossf.append((strength, loss_f))

In [None]:
# Plot the losses
plt.figure(figsize=(16, 8))
plt.plot(
    [y.item() for (_, y) in losst],
    label="Train Loss",
    linestyle="-",
    marker="o",
    color="green",
)
plt.plot(
    [y.item() for (_, y) in lossv],
    label="Validation Loss",
    linestyle="-",
    marker="o",
    color="blue",
)
plt.plot(
    [y.item() for (_, y) in lossf],
    label="Test Loss",
    linestyle="-",
    marker="o",
    color="red",
)
plt.xlabel("Sweep step (index into strengths, 0.0 to 1.0)")
plt.ylabel("Loss")
plt.title("Loss vs. Strength")
plt.legend()
plt.show()

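# Report the best strength, selected by validation loss only;
# train and test losses are reported at that same strength
best = min(range(len(lossv)), key=lambda i: lossv[i][1].item())
print(
    f"Best Strength: {lossv[best][0].item():.4f} @ Train: {losst[best][1].item():.4f}, "
    f"Validation: {lossv[best][1].item():.4f} & Test: {lossf[best][1].item():.4f}"
)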

### Analyzing Regularization Effects

The results from our regularization experiments reveal several important insights:

**Observed patterns**:
- **Training loss**: Increases steadily with regularization strength in a logarithmic-like curve
- **Validation loss**: Follows the training loss curve closely, also increasing with regularization strength
- **Test loss**: Generally tracks with validation loss, with slightly sharper increases at lower regularization values
- **Generalization gap**: The gap between training and test loss is largest at low regularization strengths

**Understanding regularization mechanics**:
- Regularization works by penalizing large weight values in the model
- As regularization strength increases, the model is forced to use smaller weights
- In our case, this constraint appears to limit the model's ability to capture important patterns
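- The result is higher loss across all data splits as regularization increases (note that every reported loss, including validation and test, contains the penalty term `strength * (W**2).mean()`, which itself grows with the strength)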

**Counterintuitive finding**:

While regularization typically helps prevent overfitting, it does not improve the results in this particular case. Plausible reasons are:
- The dataset is relatively small and structured
- The trigram patterns are consistent and meaningful
- The model complexity, while higher than bigrams, is still manageable for this data

**Practical conclusion**:
- For this specific language modeling task, minimal or no regularization produces the best results
- This demonstrates how the ideal regularization strategy depends on the specific characteristics of the dataset and model
- In larger, more complex models or noisier datasets, we might see more benefit from stronger regularization

In [None]:
W = torch.randn(
    (27 + 27, 27), device=device, generator=g, requires_grad=True
)  # random weight matrix of shape (27+27)x27 (requires_grad=True for autograd)

# Training cycles: 200 full-batch passes over the training split, like the bigram model
d_size = xs_tri_train.shape[0]
for k in range(200):
    # Forward pass
    logits = (
        W[xs_tri_train[:, 0]] + W[27 + xs_tri_train[:, 1]]
    )  # row lookups replace the one-hot matmul; logits are log-counts
    counts = logits.exp()  # 'fake counts', analogous to the N matrix of the bigram model
    probs = counts / counts.sum(
        1, keepdims=True
    )  # Normalize into probabilities (softmax); this is y_pred
    loss = (
        -probs[torch.arange(d_size), ys_tri_train].log().mean() + 0.01 * (W**2).mean()
    )
    print(f"Loss @ iteration {k + 1}: {loss}")

    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool

    # Weight update
    W.data += -50 * W.grad

### Computational Optimization: Direct Embedding Lookup

**Understanding the Original Approach**

The core idea of our original trigram model was:
```python
xenc = F.one_hot(xs_tri_train, num_classes=27).float().to(device) # One-hot encoding, [d_size, 2, 27]
xenc = xenc.view(d_size, -1) # concatenate the one-hot vectors, [d_size, 54]
logits = xenc @ W
```

This implementation:
1. Converts each character to a one-hot vector (27 dimensions)
2. Concatenates two one-hot vectors to form a 54-dimensional input vector
3. Multiplies this with weight matrix `W` of shape $54 \times 27$

The weight matrix structure is significant:
- First 27 rows correspond to weights for the first character
- Next 27 rows correspond to weights for the second character
- This allows the model to learn different weights for each character position

**The Optimized Implementation**

We can achieve the same mathematical result more efficiently:
```python
logits = W[xs_tri_train[:,0]] + W[27 + xs_tri_train[:,1]]
```

**How and why this works**:
1. **Matrix multiplication with one-hot vectors is equivalent to lookup**: When you multiply a one-hot vector by a matrix, you're essentially selecting a specific row of that matrix
2. **Direct indexing**: Instead of creating one-hot vectors and performing matrix multiplication, we directly index into the rows of `W` using character indices
3. **Position encoding**: For the second character, we add 27 to the index to access the second half of the weight matrix
4. **Addition combines influences**: The sum combines the influence of both characters on predicting the next character

**Benefits**:
- **Computational efficiency**: Eliminates the need to create and manipulate large one-hot vectors
- **Memory efficiency**: Reduces memory usage during forward pass
- **Mathematical equivalence**: Produces exactly the same results as the original approach

This optimization illustrates an important principle in deep learning: understanding the mathematical equivalence between operations allows us to implement more efficient solutions without changing the underlying model.
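The equivalence described above can be checked on a small example. The sketch below uses made-up index pairs and a hypothetical seed rather than the notebook's `xs_tri_train`, and runs on the CPU; it only illustrates that both forward passes produce the same logits.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)  # hypothetical seed, only for reproducibility
W = torch.randn((27 + 27, 27), generator=g)  # same shape as the trigram weight matrix
xs = torch.randint(0, 27, (5, 2), generator=g)  # 5 made-up (char1, char2) index pairs

# Original approach: one-hot encode, flatten to 54 dims, then matrix-multiply
xenc = F.one_hot(xs, num_classes=27).float().view(xs.shape[0], -1)  # [5, 54]
logits_matmul = xenc @ W  # [5, 27]

# Optimized approach: index the rows of W directly
logits_lookup = W[xs[:, 0]] + W[27 + xs[:, 1]]  # [5, 27]

same = torch.allclose(logits_matmul, logits_lookup)
print(same)
```

The lookup version never builds the `[N, 54]` one-hot tensor, which is where the savings come from on the full dataset.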

## 3. Change the loss function

In neural networks, the choice of loss function is crucial for effective training. So far, we've been implementing the negative log-likelihood loss manually. However, PyTorch provides optimized implementations of common loss functions.

In this exercise, we will:
1. Replace our manual negative log-likelihood implementation with PyTorch's built-in `F.cross_entropy`
2. Compare the results to verify mathematical equivalence
3. Analyze the advantages of using the built-in function

**Key Questions:**
- Does `F.cross_entropy` produce the same results as our manual implementation?
- What are the technical advantages of using the built-in function?
- Are there any performance benefits to this approach?

### Using PyTorch's Cross Entropy Loss

PyTorch's `F.cross_entropy` ([docs](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.cross_entropy.html)) function combines several operations in one efficient call:

1. It applies a softmax function to convert logits to probabilities
2. It then computes the negative log-likelihood of the correct class
3. It averages the loss across all examples in the batch (the default `reduction='mean'`)

**Implementation Note:** The function expects:
- First argument: raw logits (unnormalized scores) from the model
- Second argument: target class indices (not one-hot encoded)
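Before swapping the loss inside the training loop, it may help to confirm the equivalence on a toy batch. The following is a minimal sketch with random logits and targets (made up for illustration, not taken from the dataset), comparing the manual negative log-likelihood to `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)  # hypothetical seed, only for reproducibility
logits = torch.randn((8, 27), generator=g)  # made-up raw scores for 8 examples
targets = torch.randint(0, 27, (8,), generator=g)  # made-up class indices (not one-hot)

# Manual negative log-likelihood, as in the earlier cells
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
manual = -probs[torch.arange(8), targets].log().mean()

# Built-in: softmax + negative log-likelihood + mean in one call
builtin = F.cross_entropy(logits, targets)

matches = torch.allclose(manual, builtin, atol=1e-5)
print(manual.item(), builtin.item(), matches)
```

The two values should agree up to floating-point rounding, since the logits here are small enough that the manual version does not run into overflow.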

Let's implement our trigram model using this function:

In [None]:
# Random (27+27)x27 weight matrix (requires_grad=True for autograd)
W = torch.randn((27 + 27, 27), device=device, generator=g, requires_grad=True)

# Training cycles, using the entire dataset -> 200 epochs, like the bigram model
d_size = xs_tri_train.shape[0]
for k in range(200):
    # Forward pass
    # Logits (another word for log-counts): one row lookup per input position
    logits = W[xs_tri_train[:, 0]] + W[27 + xs_tri_train[:, 1]]
    loss = F.cross_entropy(logits, ys_tri_train) + 0.01 * (W**2).mean()
    print(f"Loss @ iteration {k + 1}: {loss}")

    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool

    # Weight update
    W.data += -50 * W.grad

### Advantages of Using `F.cross_entropy`

The final loss matches the one from our manual implementation, but using `F.cross_entropy` offers several important advantages:

**1. Code Simplicity and Readability**
- Replaces multiple lines of complex operations with a single function call
- Makes the code easier to read, maintain, and debug
- Clarifies the intent of the computation

**2. Numerical Stability**
- Our manual implementation used these steps:
```python
counts = logits.exp() # Convert logits to unnormalized probabilities
probs = counts / counts.sum(1, keepdims=True) # Normalize to get probabilities
loss = -probs[torch.arange(d_size), ys_tri_train].log().mean() + 0.01 * (W**2).mean()
```
- This approach can cause numerical issues:
  - The `exp()` operation overflows to `inf` for large logit values (above roughly 88 in float32), and the subsequent division then yields `nan`
  - Very small probabilities can underflow to zero, and taking their logarithm yields `-inf`
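Both failure modes can be reproduced with deliberately extreme logits. The sketch below is illustrative only: the logit values and the `manual_nll` helper are made up for this example and do not come from the trained model.

```python
import torch
import torch.nn.functional as F


def manual_nll(logits, targets):
    """Manual softmax + negative log-likelihood (hypothetical helper mirroring the steps above)."""
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    return -probs[torch.arange(logits.shape[0]), targets].log().mean()


# Overflow: exp(1000) is inf in float32, and inf / inf is nan
big = torch.tensor([[1000.0, 0.0]])
overflow_manual = manual_nll(big, torch.tensor([0]))
overflow_builtin = F.cross_entropy(big, torch.tensor([0]))

# Underflow: exp(-200) rounds to 0 in float32, and log(0) is -inf
small = torch.tensor([[0.0, -200.0]])
underflow_manual = manual_nll(small, torch.tensor([1]))
underflow_builtin = F.cross_entropy(small, torch.tensor([1]))

print(overflow_manual.item(), overflow_builtin.item())
print(underflow_manual.item(), underflow_builtin.item())
```

In both cases the manual loss is not a finite number, while `F.cross_entropy` returns a finite value for the same inputs.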

**3. Computational Efficiency**

`F.cross_entropy` uses a mathematically equivalent but more stable and efficient approach:
- Applies a log-softmax operation that combines softmax and log in a single step
- Computes in log-space to avoid overflow/underflow
- Relies on PyTorch's optimized native kernels, including CUDA kernels for faster computation on GPUs
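The log-space computation is commonly implemented with the log-sum-exp trick: subtract the row maximum before exponentiating, so the largest exponent is always zero. The sketch below is an assumption about the general technique, not a copy of PyTorch's internal kernels; `log_softmax_sketch` is a hypothetical helper, compared against `F.log_softmax` as a reference.

```python
import torch
import torch.nn.functional as F


def log_softmax_sketch(logits):
    """Row-wise log-softmax via the log-sum-exp trick (illustrative helper)."""
    # Shifting by the row maximum leaves softmax unchanged but keeps exp() <= 1
    shifted = logits - logits.max(dim=1, keepdim=True).values
    return shifted - shifted.exp().sum(dim=1, keepdim=True).log()


logits = torch.tensor([[1000.0, 999.0, 0.0], [1.0, 2.0, 3.0]])  # made-up values
ours = log_softmax_sketch(logits)
reference = F.log_softmax(logits, dim=1)

agrees = torch.allclose(ours, reference, atol=1e-6)
print(agrees, torch.isfinite(ours).all().item())
```

Because the shift cancels out in the softmax ratio, the result is mathematically identical to `log(softmax(x))`, but it stays finite even for the very large logits in the first row.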

**4. Memory Efficiency**
- Avoids explicitly creating the intermediate `counts` and `probs` tensors
- Reduces memory usage during the forward and backward passes

These advantages make `F.cross_entropy` the preferred choice in production machine learning code: it helps keep training fast, stable, and less prone to numerical errors.

In [None]:
# Random 27x27 weight matrix, shared by both input positions (requires_grad=True for autograd)
W = torch.randn((27, 27), device=device, generator=g, requires_grad=True)

# Training cycles, using the entire dataset for 200 epochs
d_size = xs_tri_train.shape[0]
for k in range(200):
    # Forward pass
    # Logits: both characters now index into the same 27 rows of W
    logits = W[xs_tri_train[:, 0]] + W[xs_tri_train[:, 1]]
    loss = F.cross_entropy(logits, ys_tri_train) + 0.01 * (W**2).mean()
    print(f"Loss @ iteration {k + 1}: {loss}")

    # Backward pass
    W.grad = None  # Make sure all gradients are reset
    loss.backward()  # Torch kept track of what this variable is, kinda cool

    # Weight update
    W.data += -50 * W.grad

### Understanding Model Architecture Importance

The results from this experiment are clear, and they are **not good**:

- The training loss is significantly higher than our original trigram model
- Training behavior has become erratic and unstable
- The model fails to learn effective representations

**What went wrong?**

In this final experiment, we modified the architecture of our model by changing the weight matrix dimensions from $(27+27) \times 27$ to $27 \times 27$. This seemingly small change had dramatic consequences:

1. **Loss of position information**: By using the same weights for both character positions, we've eliminated the model's ability to learn position-specific representations

2. **Reduced modeling capacity**: The number of parameters was reduced from 1,458 to 729, cutting the model's capacity in half

3. **Parameter interference**: Updates to weights for one position now affect predictions for the other position, creating destructive interference
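One concrete way to see the loss of position information: with a shared $27 \times 27$ matrix, the logits for the character pair `(a, b)` are identical to those for `(b, a)`, so the model cannot tell the two orderings apart. The sketch below uses random weights, a hypothetical seed, and two arbitrary character indices purely for illustration.

```python
import torch

g = torch.Generator().manual_seed(0)  # hypothetical seed, only for reproducibility
W_shared = torch.randn((27, 27), generator=g)  # one set of rows for both positions
W_split = torch.randn((27 + 27, 27), generator=g)  # separate rows per position

a, b = 5, 13  # two arbitrary character indices

# Shared weights: swapping the inputs cannot change the logits
shared_ab = W_shared[a] + W_shared[b]
shared_ba = W_shared[b] + W_shared[a]
shared_is_symmetric = torch.equal(shared_ab, shared_ba)

# Position-specific weights: the order of the inputs matters
split_ab = W_split[a] + W_split[27 + b]
split_ba = W_split[b] + W_split[27 + a]
split_is_symmetric = torch.allclose(split_ab, split_ba)

print(shared_is_symmetric, split_is_symmetric)
print(W_split.numel(), W_shared.numel())  # parameter counts of the two variants
```

The parameter counts printed at the end correspond to the 1,458 and 729 figures mentioned above.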

**Key insight**: This experiment demonstrates that proper architecture design is crucial for model performance. In language modeling, preserving position information through separate parameters for each position is essential for effective learning.

This reinforces a fundamental principle in neural network design: the architecture should reflect the structure of the problem. For sequence modeling tasks like ours, position-specific parameters are vital for capturing the sequential nature of language.