# Language Modeling With PyTorch

The purpose of this notebook is to give a gentle introduction to the PyTorch library, with a focus on language modeling.
At a high level, we will build progressively more complex **character-level language models** that can generate new text similar to the training data.

The final result is not meant to be a "production-ready" language model, but rather a simple yet effective example of how to use PyTorch for language modeling. Along the way, we will learn the fundamental building blocks that lay the groundwork for more complex models, including the base models that power state-of-the-art LLMs and derived products, like our friendly and always helpful assistant ChatGPT.

The final implementation will allow you to experiment with different models, from the simplest one (a **bigram** model) to a more complex **RNN** and finally a **Transformer** model.
In particular, we'll roughly follow these key papers:

- Bigram (one character predicts the next one with a lookup table of counts)
- Multi-Layer Perceptron (MLP): [Bengio et al. 2003](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf)
- Convolutional Neural Network (CNN): [DeepMind WaveNet 2016](https://arxiv.org/abs/1609.03499)
- Recurrent Neural Network (RNN): [Mikolov et al. 2010](https://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf)
- Long Short-Term Memory (LSTM): [Graves et al. 2014](https://arxiv.org/abs/1308.0850)
- Gated Recurrent Unit (GRU): [Cho et al. 2014](https://arxiv.org/abs/1409.1259)
- Transformer: [Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)

## Inspecting the data

In [None]:
import pathlib as pl

Our initial dataset is a simple list of strings that represent common names:

In [None]:
words = pl.Path("data/names.txt").read_text().splitlines()
words[:10]

In [None]:
len(words), min(len(w) for w in words), max(len(w) for w in words)

A single name, e.g. `isabella`, contains several pieces of information:

- We know that the character `i` is followed by `s`.
- We know that, after the characters `isabell`, the following character is `a`.
- We know that, after the characters `isabella`, the name ends.

The idea is that a single word packs multiple pieces of information regarding the statistical structure of the language it belongs to.
And since we have about 32k words, there's quite a lot of information we can use to train even a simple language model.
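
To make this concrete, here is a small pure-Python sketch that lists every (context, next character) pair contained in a single name. The helper `context_pairs` is hypothetical, and `.` is used as an end-of-name marker, anticipating the convention adopted later in this notebook.

```python
def context_pairs(word, end="."):
    """Return (context, next_char) pairs for every non-empty prefix of `word`.

    The last pair maps the whole word to the end-of-name marker.
    """
    chars = list(word) + [end]
    return [(word[:i], chars[i]) for i in range(1, len(chars))]

for context, nxt in context_pairs("isabella"):
    print(f"{context!r:>12} -> {nxt!r}")
```

A name with $T$ characters yields $T$ such pairs, so the whole dataset provides far more training signal than its number of names suggests.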

## Bigram language model

A bigram language model is the simplest possible language model.
Given a sequence of characters (each character is usually referred to as a **token**), the bigram language model assigns a probability to each possible next token, given the previous token.
In other words, it only models the statistics of pairs of consecutive tokens.
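
Under the bigram assumption, the probability of a whole name factorizes into a product of next-token probabilities, each conditioned only on the previous token. Writing $c_1, \dots, c_T$ for the characters of a name, and $c_0$ and $c_{T+1}$ for the start and end markers introduced below:

$$
P(c_1, \dots, c_T) \approx \prod_{t=1}^{T+1} P(c_t \mid c_{t-1})
$$

For example, $P(\texttt{emma}) \approx P(\texttt{e} \mid \langle S \rangle)\, P(\texttt{m} \mid \texttt{e})\, P(\texttt{m} \mid \texttt{m})\, P(\texttt{a} \mid \texttt{m})\, P(\langle E \rangle \mid \texttt{a})$.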

In [None]:
for w in words[:3]:
    for ch1, ch2 in zip(w, w[1:]):
        print(ch1, ch2)

The most basic way to model the statistical patterns in our input data is to predict the next token **by frequency**.
We can build a simple dictionary that counts how many times each bigram (i.e., sequence of two tokens) appears in our dataset.

We also need to add the *special* information about the start and end of each name.
We can "encode" that information with two **special tokens**: `<S>` (start) and `<E>` (end).

In [None]:
#% Student code
from collections import defaultdict

bigrams = defaultdict(int)

for w in words:
    chars = ["<S>"] + list(w) + ["<E>"]
    for ch1, ch2 in zip(chars, chars[1:]):
        bigrams[(ch1, ch2)] += 1

print(dict(bigrams))

What does an entry of the bigram dictionary look like?
It's something like `('a', '<E>'): 6640`, which means: the bigram `('a', '<E>')` occurred 6640 times.
That is: the letter `a` is quite likely to appear at the end of a name.
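
The same counting can also be sketched with the standard library's `collections.Counter`, whose `most_common(n)` method returns the `n` most frequent entries sorted by count, so it covers the sorting step as well. A minimal sketch on a toy word list (`bigram_counts` is a hypothetical helper):

```python
from collections import Counter

def bigram_counts(words):
    """Count bigrams, including the <S> and <E> special tokens."""
    counts = Counter()
    for w in words:
        chars = ["<S>"] + list(w) + ["<E>"]
        counts.update(zip(chars, chars[1:]))  # each (ch1, ch2) tuple counts once
    return counts

toy = ["emma", "olivia", "ava"]
counts = bigram_counts(toy)
print(counts.most_common(3))
```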

We want to sort the bigrams by their count, from the most frequent to the least frequent.
Let's see the first 10 most frequent bigrams:

In [None]:
sorted(bigrams.items(), key=lambda x: x[1], reverse=True)[:10]

A much better way to store this information is in a 2D array, where the rows correspond to the first character and the columns to the second character.
The entry at row `a` and column `b` is the count of the bigram `ab`.

Since we are dealing with 26 characters, plus `<S>` and `<E>`, we need a total of 28x28 = 784 entries.
Bracketed tokens are customary in NLP to represent special tokens, but here all we need to know is where a name starts or ends.

But there's also a problem with keeping two special tokens for the start and end of a name: bigrams such as `('<S>', '<S>')`, `('<E>', '<E>')`, `('a', '<S>')` or `('<E>', 'b')` can never occur.
As a result, the entire `<E>` row and the entire `<S>` column of the matrix would always be zero, wasting space.
To solve this, we can replace both `<S>` and `<E>` with a single special token `.` and have a total of 27x27 = 729 entries.
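
The sketch below counts bigrams on a toy word list with both conventions and checks which rows and columns stay empty. It uses plain Python lists instead of a tensor, and `count_table` is a hypothetical helper written for illustration.

```python
import string

def count_table(words, start, end):
    """Hypothetical helper: bigram counts as a list of lists, one row/column per token."""
    tokens = [start] + list(string.ascii_lowercase)
    if end != start:
        tokens.append(end)
    idx = {t: i for i, t in enumerate(tokens)}
    table = [[0] * len(tokens) for _ in tokens]
    for w in words:
        chars = [start] + list(w) + [end]
        for a, b in zip(chars, chars[1:]):
            table[idx[a]][idx[b]] += 1
    return tokens, idx, table

toy = ["emma", "olivia", "ava"]

# Two special tokens: a 28x28 table
tokens, idx, table = count_table(toy, "<S>", "<E>")
row_E = table[idx["<E>"]]                    # nothing ever follows <E>
col_S = [row[idx["<S>"]] for row in table]   # nothing ever precedes <S>
print(len(tokens), sum(row_E), sum(col_S))   # 28 0 0

# A single "." token: a 27x27 table where "." is used both as a row and as a column
tokens1, idx1, table1 = count_table(toy, ".", ".")
print(len(tokens1))                          # 27
```

The empty row and column are structural: they stay at zero for any word list, not just this toy one.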

In [None]:
import torch

N = torch.zeros((27, 27), dtype=torch.int32)

How should we encode the bigrams?
Our 2D array can be indexed by integers **only**, so we need a way to convert characters to integers.
One way is to build a so-called **vocabulary** from our input data.

A vocabulary requires two functions:
1. `stoi`: string to integer (**encoding**)
2. `itos`: integer to string (**decoding**)
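
A minimal sketch of both directions as functions, assuming the same convention used in the next cell (`.` mapped to `0`, the other characters mapped to `1, 2, ...` in sorted order). The names `build_vocab`, `encode` and `decode` are hypothetical helpers, not part of PyTorch.

```python
def build_vocab(words):
    """Return (stoi, itos) with '.' reserved for index 0."""
    chars = sorted(set("".join(words)))
    stoi = {s: i + 1 for i, s in enumerate(chars)}
    stoi["."] = 0
    itos = {i: s for s, i in stoi.items()}
    return stoi, itos

def encode(text, stoi):
    """String -> list of integers."""
    return [stoi[ch] for ch in text]

def decode(ids, itos):
    """List of integers -> string."""
    return "".join(itos[i] for i in ids)

stoi_toy, itos_toy = build_vocab(["emma", "olivia", "ava"])
ids = encode(".emma.", stoi_toy)
print(ids)
print(decode(ids, itos_toy))   # .emma.
```

On this toy list the indices differ from those of the full dataset, since only seven distinct letters occur; the round trip `decode(encode(s))` returns `s` either way.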

In [None]:
chars = sorted(set(''.join(words)))

stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0  # special token for the start/end of a name is mapped to 0
itos = {i:s for s, i in stoi.items()}

print(stoi)
print(itos)

In [None]:
#% Student code
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    i = stoi[ch1]
    j = stoi[ch2]
    N[i, j] += 1

Let's create a nice visualization of our bigram frequency table:

In [None]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16,16))
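# Same linear min-max scaling that imshow applies, used to pick a readable text color
norm = mpl.colors.Normalize(vmin=N.min().item(), vmax=N.max().item())
im = ax.imshow(N, cmap='Blues')

for i in range(N.shape[0]):
    for j in range(N.shape[1]):
        val = norm(N[i, j].item())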
        text_color = 'white' if val > 0.5 else 'black'
        # character
        ax.text(j, i, itos[i]+itos[j],
                ha="center", va="bottom",
                color=text_color, fontweight='bold')
        # count
        ax.text(j, i, int(N[i,j]),
                ha="center", va="top",
                color=text_color)

ax.set_axis_off()

Generating names with our bigram model boils down to iteratively sampling from a probability distribution that describes how frequent each bigram is in the dataset.

We already have the frequency table, so we need to build a probability distribution and a sampling mechanism.
Let's do it for the first row, which holds the counts of all bigrams starting with the character `.`, i.e., how often each character starts a name.

In [None]:
proba = N[0].float()
proba /= proba.sum()
print(proba)

PyTorch's `torch.multinomial` samples indices according to a given vector of probabilities, i.e., it samples from a [**Multinomial distribution**](https://docs.pytorch.org/docs/main/distributions.html#multinomial).
A Multinomial distribution generalizes the [**Binomial distribution**](https://en.wikipedia.org/wiki/Binomial_distribution) from two possible outcomes to any number of outcomes.

To make the results reproducible, we can initialize a random number generator with a fixed seed.
Also, we need to allow sampling **with replacement**, so that we can sample the same token multiple times.
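
Conceptually, drawing one sample amounts to picking a uniform random number and finding the first index whose cumulative probability exceeds it (inverse transform sampling). A pure-Python sketch of this idea follows; `sample_index` is a hypothetical helper, and it will not reproduce the exact draws of `torch.multinomial`, whose internal algorithm may differ. Each call draws independently, which is exactly sampling with replacement.

```python
import bisect
import itertools
import random

def sample_index(probs, rng):
    """Draw one index with probability probs[i] via the cumulative distribution."""
    cdf = list(itertools.accumulate(probs))
    u = rng.random() * cdf[-1]          # scale in case probs don't sum exactly to 1
    return bisect.bisect_right(cdf, u)  # first index whose cumulative value exceeds u

rng = random.Random(42)                 # fixed seed -> reproducible draws
probs = [0.1, 0.6, 0.3]
draws = [sample_index(probs, rng) for _ in range(10_000)]
freqs = [draws.count(i) / len(draws) for i in range(len(probs))]
print(freqs)  # close to [0.1, 0.6, 0.3]
```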

In [None]:
gen = torch.Generator().manual_seed(2147483647)
ix = torch.multinomial(proba, num_samples=1, replacement=True, generator=gen).item()
print(itos[ix])

We can of course sample as many tokens as we want:

In [None]:
torch.multinomial(proba, num_samples=100, replacement=True, generator=gen)

You might have guessed how the process continues: after sampling a character, we use it as the row index, sample the next character from that row's distribution, and repeat until we sample the end token `.`.

In [None]:
g = torch.Generator().manual_seed(2147483647)
NUM_WORDS = 20

for _ in range(NUM_WORDS):
  
  out = []
  ix = 0

  while True:
    # Compute the probabilities
    p = N[ix].float()
    p /= p.sum()
    
    # Sample the next character
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    
    # Add the character to the output
    out.append(itos[ix])
    
    # Stop if we reach the end of the text
    if ix == 0:
      break

  print(''.join(out))

The results are quite terrible, although they're reasonable given the simplicity of the model and the patterns we're trying to capture.

The core problem is that a bigram model only looks at the frequency of pairs of tokens: when predicting the next token, it has no information about anything that came before the current one.
You can imagine that the obvious next step is a **trigram** model, which looks at the frequency of a triplet of tokens.
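
A minimal sketch of how the counting would extend to trigrams, keeping counts in a nested dictionary keyed by the two preceding characters. Padding the start with two `.` tokens is one possible convention, assumed here for illustration, and `trigram_counts` is a hypothetical helper.

```python
from collections import defaultdict

def trigram_counts(words):
    """Map (ch1, ch2) -> {ch3: count}, with '.' as start/end padding."""
    counts = defaultdict(lambda: defaultdict(int))
    for w in words:
        chars = [".", "."] + list(w) + ["."]
        for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):
            counts[(ch1, ch2)][ch3] += 1
    return counts

counts3 = trigram_counts(["emma", "emily", "ava"])
print(dict(counts3[("e", "m")]))   # {'m': 1, 'i': 1}
```

With 27 tokens, a full trigram table would need $27^3 = 19{,}683$ entries, so the data needed to fill it grows quickly with the context length.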

Let's now improve our code a bit: first, we compute **all** the probabilities once, and then we sample from them.
PyTorch tensors support **vectorized** operations, which means that we can perform operations on entire tensors at once, without having to loop through them.

Each row of our 2D matrix contains the counts of how many times the token with that row index is followed by all the other tokens, whose indexes run along the columns.

$$
P_{ij}= \frac{N_{ij}}{\displaystyle\sum_{k} N_{i k}}
$$

For each pair $(i,j)$:
- The numerator $N_{ij}$ is the count of the number of times token `j` follows token `i`.
- The denominator $\sum_{k} N_{i k}$ is the total number of times *any* character follows `i`.
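
As a small worked example with made-up numbers, take a toy vocabulary of three tokens and a row of counts $N_{i\cdot} = (2, 6, 2)$:

$$
\sum_{k} N_{ik} = 2 + 6 + 2 = 10, \qquad P_{i\cdot} = \left(\tfrac{2}{10},\ \tfrac{6}{10},\ \tfrac{2}{10}\right) = (0.2,\ 0.6,\ 0.2)
$$

Every row of $P$ therefore sums to $1$, as a probability distribution should.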

In PyTorch:

```python
P = N.float()
P /= P.sum(dim=1, keepdim=True)
```

Here `dim=1` tells PyTorch to sum over the columns (the second index), while `keepdim=True` tells it to keep the summed dimension as a singleton (size `1`) dimension, so the row sums have shape `(27, 1)`.
Without `keepdim=True`, the result would have shape `(27,)`, and the division would silently produce the wrong result because of how [broadcasting](https://pytorch.org/docs/stable/notes/broadcasting.html) works.

> Try to experiment with the `keepdim` parameter and see what happens if you remove it.
> Can you explain why the predictions become complete garbage?
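
If you want to check your answer, a small experiment on a 3x3 matrix of made-up counts makes the shapes explicit. With `keepdim=True` the row sums have shape `(3, 1)` and broadcast across the columns, so each row is divided by its own sum. Without it, the sums have shape `(3,)`, broadcast as a row vector of shape `(1, 3)`, and column `j` gets divided by the sum of row `j`.

```python
import torch

counts = torch.tensor([[1., 3., 0.],
                       [2., 2., 4.],
                       [0., 5., 5.]])

row_sums_kept = counts.sum(dim=1, keepdim=True)   # shape (3, 1)
row_sums_flat = counts.sum(dim=1)                 # shape (3,)

P_good = counts / row_sums_kept   # each row divided by its own sum
P_bad = counts / row_sums_flat    # column j divided by the sum of row j

print(row_sums_kept.shape, row_sums_flat.shape)   # torch.Size([3, 1]) torch.Size([3])
print(P_good.sum(dim=1))   # every row sums to 1
print(P_bad.sum(dim=1))    # rows no longer sum to 1
```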

In [None]:
g = torch.Generator().manual_seed(2147483647)

P = N.float()
P /= P.sum(dim=1, keepdim=True)

for _ in range(10):
  
  out = []
  ix = 0

  while True:
    # Get the probabilities
    p = P[ix]
    
    # Sample the next character
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    
    # Add the character to the output
    out.append(itos[ix])
    
    # Stop if we reach the end of the text
    if ix == 0:
      break

  print(''.join(out))

### Evaluating the quality of the model

We have built a bigram language model by counting bigram frequencies, normalizing the counts into probabilities, and sampling from them.

We trained the model and sampled from it (iteratively, one character at a time), but it's still bad at coming up with names.

But how bad? We know that the model's "knowledge" is stored in `P`, but how can we summarize the model's quality in a single number?

First, let's look at the bigrams we created from the dataset: the bigrams of `emma` are `.e, em, mm, ma, a.`.
**What probability does the model assign to each of those bigrams?**

In [None]:
for w in words[:1]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]): # Neat way for two char 'sliding-window'
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        prob = P[ix1, ix2]
        print(f'{ch1}{ch2}: {prob:.2%}')

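A model that has learned nothing would spread its probability evenly, assigning $\frac{1}{27} \approx 3.7\%$ to every possible next character.
Probabilities that deviate from this uniform value, especially much higher ones for bigrams that actually occur, mean the model has learned something from the bigram statistics.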

How can we summarize these probabilities into a single quality measure?
We could compute the product of all probabilities, a number called the **likelihood**.
But since all these probabilities are small numbers, their product is tiny (it quickly underflows to $0$ in floating point) and hard to compare across models.
Solution: use the **log-likelihood**, the **sum** of $\log(P)$ over all the individual token probabilities. Since $\log(ab) = \log(a) + \log(b)$, the logarithm turns the product into a sum, and because $\log$ is monotonic it does not change which model is better.
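
A quick numerical illustration with made-up numbers: multiplying many small probabilities underflows to `0.0` in double precision, while summing their logarithms stays well within range.

```python
import math

probs = [0.01] * 1000      # 1000 tokens, each predicted with probability 1%

likelihood = math.prod(probs)                     # 1e-2000 underflows to 0.0
log_likelihood = sum(math.log(p) for p in probs)  # 1000 * log(0.01)

print(likelihood)        # 0.0
print(log_likelihood)    # about -4605.17
```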

> The higher the log-likelihood, the better the model, because the more capable it is of predicting the next character in a sequence from the dataset.

In [None]:
# Initialize variables
log_likelihood = 0.0
n = 0  # character pair count

for word in words:
    # Add start/end tokens and convert to character list
    chars = ['.'] + list(word) + ['.']
    
    # Accumulate the log-probability of each bigram
    for ch1, ch2 in zip(chars, chars[1:]):
        prob = P[stoi[ch1], stoi[ch2]]
        log_likelihood += torch.log(prob)
        n += 1

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')  # Negative log likelihood
print(f'Average NLL: {nll/n:.4f}')  # More descriptive output

We computed the *negative* log-likelihood to follow the convention of framing training as the minimization of a **loss function**, the function that drives the optimization (i.e., training) process.
The lower the loss (here, the negative log-likelihood), the better the model.

The average negative log-likelihood of our model is about $2.45$.
Training means finding the parameters that reduce this value.

**Goal:** Maximize the likelihood of the training data with respect to the model parameters (the entries of `P`)
- This is equivalent to maximizing the log-likelihood (as $\log$ is monotonic)
- This is equivalent to minimizing the *negative* log-likelihood
- This, in turn, is equivalent to minimizing the average negative log-likelihood (our quality measure, which gave $2.45$ above)
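
To put the $2.45$ computed above in perspective, the sketch below computes the average negative log-likelihood $-\frac{1}{n}\sum \log p$ for a list of probabilities (`average_nll` is a hypothetical helper). A model that learned nothing and assigns $\frac{1}{27}$ to every bigram scores $\log 27 \approx 3.30$, so a lower value means the model has picked up real structure.

```python
import math

def average_nll(probs):
    """Average negative log-likelihood of the probabilities assigned to the observed tokens."""
    return sum(-math.log(p) for p in probs) / len(probs)

uniform = [1 / 27] * 5          # e.g. the 5 bigrams of "emma" under a uniform model
print(round(average_nll(uniform), 4))   # 3.2958, i.e. log(27)

perfect = [1.0] * 5             # a model that is always certain and always right
print(average_nll(perfect))     # the lowest possible value
```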

There's an immediate problem, though: if we have a word containing a bigram that **never** appears in our training data, the model will assign a probability of $0$ to it, which will make the log-likelihood $-\infty$.

In [None]:
# Initialize variables
log_likelihood = 0.0
n = 0  # character pair count

for word in ["edobq"]:
    # Add start/end tokens and convert to character list
    chars = ['.'] + list(word) + ['.']
    
    # Accumulate the log-probability of each bigram
    for ch1, ch2 in zip(chars, chars[1:]):
        prob = P[stoi[ch1], stoi[ch2]]
        log_likelihood += torch.log(prob)
        n += 1

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')  # Negative log likelihood
print(f'Average NLL: {nll/n:.4f}')  # More descriptive output

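An infinite negative log-likelihood is a serious problem: a single unseen bigram makes the loss of the whole dataset infinite, which says nothing about how well the model does everywhere else and gives an optimizer no useful signal.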

One simple fix is to assign a small but non-zero probability to every bigram: this is called **model smoothing**.
The easiest way is to make sure every bigram appears at least once, by adding a constant to every entry of our 2D matrix `N` (adding $1$ is known as add-one, or Laplace, smoothing).
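
With a smoothing constant $\alpha$, the smoothed probabilities become $P_{ij} = \frac{N_{ij} + \alpha}{\sum_k N_{ik} + 27\alpha}$ (the cell below uses $\alpha = 1$). A small pure-Python sketch on one made-up row of counts shows the trade-off: a larger $\alpha$ removes the zeros but also pulls the distribution toward uniform. `smoothed_row` is a hypothetical helper.

```python
def smoothed_row(counts, alpha):
    """Add-alpha smoothing of one row of bigram counts."""
    total = sum(counts) + alpha * len(counts)
    return [(c + alpha) / total for c in counts]

row = [8, 2, 0]   # made-up counts for a 3-token toy vocabulary

for alpha in (0, 1, 100):
    probs = smoothed_row(row, alpha)
    print(alpha, [round(p, 3) for p in probs])
# alpha=0 keeps the zero, alpha=1 removes it, alpha=100 is close to uniform
```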

In [None]:
PS = (N + 1).float()  # The higher the number, the more smoothing we apply
PS /= PS.sum(dim=1, keepdim=True)

# Initialize variables
log_likelihood = 0.0
n = 0  # character pair count

for word in ["edobq"]:
    # Add start/end tokens and convert to character list
    chars = ['.'] + list(word) + ['.']
    
    # Accumulate the log-probability of each bigram
    for ch1, ch2 in zip(chars, chars[1:]):
        prob = PS[stoi[ch1], stoi[ch2]]  # Use the smoothed probabilities
        log_likelihood += torch.log(prob)
        n += 1

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')  # Negative log likelihood
print(f'Average NLL: {nll/n:.4f}')  # More descriptive output

## A neural network approach

We will now cast the problem of next-character prediction into the framework of neural networks.
The problem remains the same, the approach changes, and the outcome should look similar.

Our neural network **receives a single character** and **outputs the probability distribution over the next possible characters** ($27$ in this case).

In effect, it guesses which character is most likely to come next.
We can still measure its performance with the *same* loss function, the negative log-likelihood.

From the training data, we also know which character actually comes next in each example.
We'll use this information to train (i.e., update the parameters of) the neural network so that it makes better guesses: this is a textbook example of **supervised learning**.
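
As a preview, here is a minimal sketch of one possible such network, assuming the simplest design: the input character is one-hot encoded, multiplied by a single 27x27 weight matrix `W` to get logits, and turned into probabilities with a softmax. The architecture, the variable names, the learning rate `0.1` and the tiny hand-written batch are illustrative assumptions, not necessarily what this notebook builds next.

```python
import torch
import torch.nn.functional as F

# Hypothetical single-layer model: a 27x27 weight matrix, one row of logits per input character
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

# Tiny made-up batch: the bigrams of "emma" encoded with the vocabulary above ('.'=0, 'a'=1, 'e'=5, 'm'=13)
xs_demo = torch.tensor([0, 5, 13, 13, 1])    # inputs:  . e m m a
ys_demo = torch.tensor([5, 13, 13, 1, 0])    # targets: e m m a .

xenc = F.one_hot(xs_demo, num_classes=27).float()   # (5, 27) one-hot inputs
logits = xenc @ W                                   # (5, 27) unnormalized scores
probs = logits.softmax(dim=1)                       # (5, 27) each row sums to 1
loss = -probs[torch.arange(len(ys_demo)), ys_demo].log().mean()   # average NLL
print(loss.item())

# One step of gradient descent on W
loss.backward()
with torch.no_grad():
    W -= 0.1 * W.grad
```

The loss is the same average negative log-likelihood used for the counting model; the difference is that `W` is now adjusted by gradient descent instead of being filled in by counting.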

### The training set

In [None]:
# Create training set of all bigrams
xs, ys = [], [] # Input and output character indices

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        xs.append(ix1)
        ys.append(ix2)

# Convert lists to tensors
xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [None]:
for i in range(5):
    print(f'For character #{i} "{itos[xs[i].item()]}" in xs, we expect the model to predict "{itos[ys[i].item()]}"')