# Karpathy Bigram Language Model

This is my Jupyter notebook replicating what Andrej Karpathy teaches in his
YouTube video: https://www.youtube.com/watch?v=kCc8FmEb1nY. The notebook
covers the first ~40 minutes of the video, where he walks through
how to train a simple bigram language model on the `tinyshakespeare`
dataset. The model predicts the next character using only the last character
in the sequence.

This is a good pedagogical introduction to the general problem of training
models, including how to pull random chunks of text out of the training corpus to
train on. How to use 

## GPU optimization using Metal
This part is about learning how batch sizes affect performance. The inference
function, `generate()`, is not optimized to use batching, so it is extremely
slow on GPU compared to CPU. This is a bit surprising given the unified memory
on Macs, and it is something to delve into more.



The first cell in the notebook downloads the `tinyshakespeare` dataset into this
repo.

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-03-12 14:02:12--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2023-03-12 14:02:12 (4.32 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [2]:
with open('input.txt', 'r') as f:
    text = f.read()

print('corpus length:', len(text))
print(text[:300])

corpus length: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


Compute the vocabulary of characters in the dataset, since we're
building a character-level model (NOT one based on subword tokens).

In [3]:
chars=sorted(list(set(text)))
vocab_size=len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


Create two dicts for the forward and reverse mappings between chars and tokens (integers in our case).

In [4]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
print(stoi)
print(itos)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i',

Write functions that encode a string into a list of integers and decode a list of integers back into a string.

In [5]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[c] for c in l])

print(encode('hii there'))
print(decode(encode('hii there')))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


Now let's do the same thing but using a real tokenizer, [tiktoken](https://github.com/openai/tiktoken). There are multiple encoders that can be used:

- `cl100k_base`: encoder used by ChatGPT models
- `gpt2`: encoder used by GPT-2 (GPT-3 models use the closely related `r50k_base`), which has a smaller vocabulary size

In [6]:
import tiktoken
enc = tiktoken.get_encoding('cl100k_base')
print(enc.n_vocab)
print(enc.encode("hii there"))
print(enc.decode(enc.encode("hii there")))


100277
[71, 3893, 1070]
hii there


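The trade-off here is that a character tokenizer produces longer token sequences than a subword tokenizer like tiktoken, in exchange for a much smaller vocabulary (65 versus 100277), but things are simpler in the end.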
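To make the sequence-length side of this trade-off concrete, here is a small sketch that tokenizes the same string at the character level and compares it with the `cl100k_base` ids printed above. The names `sample`, `sample_stoi`, and `bpe_tokens` are illustrative assumptions, not part of the notebook, and the vocabulary here is built from the sample string alone rather than from the full corpus.

```python
# Character-level tokenization of the sample string. The vocabulary is built
# from the string itself (an assumption made to keep the sketch standalone).
sample = "hii there"
sample_vocab = sorted(set(sample))
sample_stoi = {ch: i for i, ch in enumerate(sample_vocab)}
char_tokens = [sample_stoi[ch] for ch in sample]

# Token ids copied from the tiktoken output shown earlier in the notebook
bpe_tokens = [71, 3893, 1070]

print(f"character tokens: {len(char_tokens)}")  # one token per character
print(f"cl100k_base tokens: {len(bpe_tokens)}")
```

The character tokenizer needs one id per character, while the subword tokenizer covers the same text with a third as many ids drawn from a far larger vocabulary.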

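## Initialize PyTorch for this notebook

Use the `mps` device on Apple Silicon to accelerate PyTorch computations when it is available. Note that the final assignment in the cell overrides this choice and forces the CPU.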

In [7]:
import torch
import torch.nn as nn 
from torch.nn import functional as F 

torch.manual_seed(1337)

if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# FORCE CPU
device = torch.device("cpu")

Using Metal backend
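One possible way to make the CPU override explicit is a small helper with a flag, instead of a second assignment at the end of the cell. This is only a sketch: `pick_device` and `force_cpu` are hypothetical names that do not come from the video or the notebook.

```python
import torch

def pick_device(force_cpu: bool = False) -> torch.device:
    """Return the MPS device when available, unless CPU is forced."""
    # Older PyTorch builds may not ship torch.backends.mps, so look it up
    # defensively before asking whether it is available.
    mps = getattr(torch.backends, "mps", None)
    if not force_cpu and mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device(force_cpu=True)
print(device)
```

Flipping `force_cpu` then becomes the single switch for comparing CPU and GPU timings.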


Next, let's load the entire Shakespeare corpus into a PyTorch tensor. It is a one-dimensional tensor with one element per character, so its length matches the corpus length from earlier.

In [8]:
data = torch.tensor(encode(text), dtype=torch.long).to(device)
print(data.shape, data.dtype)
print(data[:300])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

Split the dataset into training and validation sets: the first 90% is used for training and the last 10% for validation.

In [9]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
print(val_data)

tensor([12,  0,  0,  ..., 45,  8,  0])


The input is never fed in its entirety into the transformer. Instead it is fed in chunks, with the `block_size` variable controlling the size of the chunk.

In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

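The goal of the transformer is to predict the next character at every position within the block, not just at the end. A chunk of `block_size + 1` characters therefore contains `block_size` training examples, each pairing a context of 1 to `block_size` characters with the character that follows it. The context is key to the prediction.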

In [11]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


The next thing to think about is the batch size. We feed chunks of length `block_size` into the transformer in batches for efficiency. The number of chunks in each batch is controlled by `batch_size`. The `get_batch` function below draws a vector of random integer offsets (`ix`) and uses `torch.stack()` to stack the chunks starting at those offsets into a two-dimensional tensor.

Each chunk is a vector of length `block_size`, and stacking turns them into a two-dimensional tensor with one chunk per row, so each batch has `batch_size` rows. Also note that the targets are just the inputs shifted by one character: `y[b, t]` is the character that follows `x[b, t]` in the corpus.

In [12]:
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("----")

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b,:t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])
targets:
torch.Size([4, 8])
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]])
----
when input is [53] the target: 59
when input is [53, 59] the target: 6
when input is [53, 59, 6] the target: 1
when input is [53, 59, 6, 1] the target: 58
when input is [53, 59, 6, 1, 58] the target: 56
when input is [53, 59, 6, 1, 58, 56] the target: 47
when input is [53, 59, 6, 1, 58, 56, 47] the target: 40
when input is [53, 59, 6, 1, 58, 56, 47, 40] the target: 59
when input is [49] the target: 43
when input is [49, 43] the target: 43
when input is [49, 43, 43] the target: 54
when input is [49, 43, 43, 54] the target: 1
when input is [49, 43, 43, 54, 1] the target: 47
when input is [49, 43, 

Build a bigram language model (see his other video for details on how it works). The output of the module has shape (`batch_size`, `block_size`, `vocab_size`). For each token, the model retrieves a row from the embedding table. That row holds the logits: unnormalized scores for each possible next token, which a softmax turns into probabilities. So given a token of 4 in batch 0, we get a vector of length `vocab_size` that scores every candidate next token. From there we can take the highest-scoring token, or use some other algorithm such as sampling, to generate the prediction.

In [13]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.token_embedding_table(idx).to(device)
        return logits 
    
m = BigramLanguageModel(vocab_size).to(device)
out = m(xb, yb).to(device)
print(out.shape)
print(out.device)

torch.Size([4, 8, 65])
cpu
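As a small illustration of the lookup described above, the sketch below uses a toy vocabulary of 5 tokens (an assumption chosen for readability, as are the names `table` and `toy_vocab_size`). It checks that the embedding lookup is plain row indexing and that a softmax over the last dimension turns each row of logits into a probability distribution.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
toy_vocab_size = 5  # hypothetical small vocabulary for illustration
table = nn.Embedding(toy_vocab_size, toy_vocab_size)

idx = torch.tensor([[4, 1, 4]])  # shape (B=1, T=3)
logits = table(idx)              # shape (1, 3, 5)

# The lookup for token 4 is just row 4 of the weight matrix
same = torch.equal(logits[0, 0], table.weight[4])

# Softmax converts unnormalized scores into probabilities that sum to 1
probs = F.softmax(logits, dim=-1)

print(logits.shape, same)
print(probs.sum(dim=-1))
```

Positions 0 and 2 both hold token 4, so they receive identical logits: the prediction depends only on the current token, which is what makes this a bigram model.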


What we have here is a (B,T,C) tensor, but `F.cross_entropy()` expects the class dimension to come second, i.e. a (B,C,T) tensor for multi-dimensional input. We sidestep this by collapsing the B and T dimensions into a single dimension, which is what Andrej refers to in his video as "stretching out" the tensor. We do this for both the logits and the targets via the `view()` method of the tensors, giving shapes (B*T, C) and (B*T).

After running the cell you can see that the collapsed logits tensor has `4*8 = 32` rows.

In [14]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.token_embedding_table(idx)

        B, T, C = logits.shape 
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

m = BigramLanguageModel(vocab_size).to(device)
logits, loss = m(xb, yb)
print("Information about logits:")
print(f"Shape:\n{logits.shape}")
print(f"Logits (predictions for first character of first batch):\n{logits[0]}")
char_idx = int(torch.argmax(logits[0]))
print(f"Index of predicted character:\n{char_idx}")
print(f"Predicted character: {decode([char_idx])}")
print(f"Overall loss:\n{loss}")

Information about logits:
Shape:
torch.Size([32, 65])
Logits (predictions for first character of first batch):
tensor([ 0.3963, -1.4631, -1.0990, -1.7104, -1.2570, -0.2641, -0.3085, -0.5644,
        -0.0534,  0.7400,  0.0912, -0.0041, -0.3235,  0.9601,  0.2023,  0.0994,
        -0.6136, -2.0696,  0.4888, -0.7050,  0.7657, -1.0252, -0.6200, -0.8280,
        -1.2047, -2.5844,  1.9835,  2.4489, -1.2784,  0.0163,  1.0204,  0.6234,
        -0.4944, -0.5679,  0.7387,  0.2977,  1.5133,  0.0898,  0.3490, -1.4351,
         1.2178, -0.7338,  0.2396, -0.0415,  0.3067,  2.1749, -0.0563,  1.0076,
        -1.5035,  1.4801, -1.3473,  1.2003,  0.3616,  0.4924, -2.3997, -1.3982,
         1.1088,  0.0864, -0.2992,  0.5236, -0.8487, -0.7711, -0.0528, -0.8631,
        -0.4571], grad_fn=<SelectBackward0>)
Index of predicted character:
27
Predicted character: O
Overall loss:
4.826590538024902
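Two sanity checks on this loss can be sketched with random stand-in tensors (the tensors below are assumptions, not the notebook's `xb`/`yb`). The first checks that flattening to (B*T, C) gives the same value as permuting to (B, C, T). The second computes the loss of a perfectly uniform prediction, which is ln(65) ≈ 4.17 for a 65-character vocabulary and is a useful reference point. An untrained model can score worse than that, as it does above, because its randomly initialized logits are not uniform.

```python
import math
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)       # stand-in for model output
targets = torch.randint(C, (B, T))  # stand-in for the target ids

# Option 1 (what the notebook does): flatten to (B*T, C) and (B*T,)
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Option 2: move the class dimension to position 1, giving (B, C, T)
loss_permuted = F.cross_entropy(logits.permute(0, 2, 1), targets)

# Reference point: all-equal logits mean a uniform distribution over C classes
uniform_loss = F.cross_entropy(torch.zeros(B * T, C), targets.view(B * T))

print(loss_flat.item(), loss_permuted.item())
print(uniform_loss.item(), math.log(C))
```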


Now we write the `generate()` function, which takes an input prompt (`idx`) and uses it as the starting point for repeatedly predicting the next character. The model hasn't been trained at all yet, so there is no reason why it should produce anything other than gibberish, which is what we see below.

In [15]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, debug=False):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        self.debug = debug

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape 
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def pprint(self, s, printed):
        # Print debug output only until the first iteration has been printed
        if self.debug and not printed:
            print(s)

    def generate(self, idx, max_new_tokens):
        printed = False
        for _ in range(max_new_tokens):
            # Get the logits for the prediction - a (B, T, C) tensor holding a
            # score for every vocabulary entry at each position
            logits, _ = self(idx)
            self.pprint(f"idx: {idx}", printed)
            self.pprint(f"logits shape: {logits.shape}", printed)
            self.pprint(f"logits: {logits}", printed)
            # Focus only on the last token - we are not using history to make
            # predictions!
            logits = logits[:, -1, :]
            self.pprint(f"logits shape: {logits.shape}", printed)
            self.pprint(f"logits: {logits}", printed)
            probs = F.softmax(logits, dim=-1)
            self.pprint(f"probs shape: {probs.shape}", printed)
            self.pprint(f"probs: {probs}", printed)
            # Sample the next token from the distribution instead of always
            # taking the most likely one
            idx_next = torch.multinomial(probs, num_samples=1)
            self.pprint(f"Predicted next character: {idx_next}", printed)
            # Append to the resulting tensor. Note that we start with the 
            # original tensor, and then append the new token
            idx = torch.cat([idx, idx_next], dim=1)
            printed = True
        return idx

m = BigramLanguageModel(vocab_size).to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long).to(device)
predictions = m.generate(idx, max_new_tokens=100)
print(f"The context is {idx}")
print(f"The predictions are {predictions}")
print(f"Decoded: {decode(predictions[0].tolist())}")


torch.Size([32, 65])
tensor(4.4065, grad_fn=<NllLossBackward0>)
The context is tensor([[0]])
The predictions are tensor([[ 0, 25, 59, 44, 30,  9,  2, 63, 15,  1, 26, 20, 23, 19, 37, 12, 15,  9,
         43, 16, 17, 19, 25, 54,  7, 34, 15, 41, 38, 26,  9, 39, 53, 30, 26,  4,
         54, 53, 61, 44, 36, 51, 14,  8, 17, 27, 36,  8, 39, 18, 50,  9, 57, 47,
         58, 19, 48,  4, 35, 41, 19, 54, 53, 17, 44, 30, 60, 20, 24, 40, 30, 12,
          0, 22, 62, 20, 53, 44, 30, 36,  6, 25, 41, 41, 13, 35, 14, 32, 18,  8,
         13, 24, 18,  3, 51, 29, 11,  2, 19, 37, 10]])
Decoded: 
MufR3!yC NHKGY?C3eDEGMp-VCcZN3aoRN&powfXmB.EOX.aFl3sitGj&WcGpoEfRvHLbR?
JxHofRX,MccAWBTF.ALF$mQ;!GY:
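The difference between picking the highest-probability token and sampling from the distribution, which is what `generate()` does with `torch.multinomial`, can be sketched on a toy example. The three logits below are made-up values for a hypothetical 3-token vocabulary.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.1]])  # hypothetical scores, shape (1, 3)
probs = F.softmax(logits, dim=-1)

# Greedy decoding always returns the single most likely token
greedy = torch.argmax(probs, dim=-1, keepdim=True)

# Sampling draws tokens in proportion to their probabilities
sampled = torch.multinomial(probs, num_samples=1000, replacement=True)
counts = torch.bincount(sampled[0], minlength=3)

print(probs)
print(greedy.item(), counts.tolist())
```

Greedy decoding would repeat the same continuation every time, while sampling lets less likely tokens appear occasionally, which is why the generated text varies between runs.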


Now let's train the model. We'll begin by creating an optimizer object and then passing it into a training loop.

In [20]:
import time

m = BigramLanguageModel(vocab_size).to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

batch_size = 32
iterations = 8192

start = time.perf_counter()
for steps in range(iterations):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(f"Final Loss = {loss.item()}")
elapsed = time.perf_counter() - start
print(f"Training time: {elapsed:.2f} seconds or {iterations/elapsed:.2f} iterations per second")

Final Loss = 2.6402642726898193
Training time: 6.19 seconds or 1323.55 iterations per second
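Because this model is just a lookup table of next-character scores, one way to reason about how low its training loss can go is to count character pairs directly: the cross-entropy on the training text is minimized when each row matches the observed next-character frequencies. The sketch below does this on a tiny made-up string (`toy_text` is an assumption, not the Shakespeare corpus), so no numbers here apply to the run above.

```python
import math
from collections import Counter, defaultdict

toy_text = "abababac"  # hypothetical stand-in for the training corpus

# Count how often each character follows each other character
pair_counts = defaultdict(Counter)
for cur, nxt in zip(toy_text, toy_text[1:]):
    pair_counts[cur][nxt] += 1

def bigram_nll(text):
    """Average negative log-likelihood under the count-based bigram table."""
    total = 0.0
    n = 0
    for cur, nxt in zip(text, text[1:]):
        row = pair_counts[cur]
        p = row[nxt] / sum(row.values())
        total += -math.log(p)
        n += 1
    return total / n

best_loss = bigram_nll(toy_text)
uniform_loss = math.log(len(set(toy_text)))
print(best_loss, uniform_loss)
```

Running the same counting over the real training split would give a floor to compare the trained model's loss against.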


Now let's generate text with the trained model. Note that this model uses only the last character in the sequence (and not the characters that came before it) to make the next prediction. So, not optimal, but if we squint really hard, we can see a bit of structure in what otherwise looks like noise below.

There seems to be a real issue with inference time on GPU (probably because the `generate()` function doesn't do any clever batching at all).

In [22]:
tokens_to_generate = 500
start = time.perf_counter()
context = torch.zeros((1, 1), dtype=torch.long).to(device)
generated = m.generate(context, max_new_tokens=tokens_to_generate)
print(decode(generated[0].tolist()))
elapsed = time.perf_counter() - start
print(f"Inference time: {elapsed:.2f} seconds or {tokens_to_generate/elapsed:.2f} tokens per second")


Ase ts
ANCoraidghith I'stNoug whes us, so hamethcomberm.
Tht thie

Flarle,'so is miswsim:
DWAy yonecrerenge bonchamy war te I' ol od:

thisprelder frerin'l k, t athat, I!journop asuldimeke--OFIO.
D ISes mamubrisowistharenot y be eDERRY tuir iffotheilin s a thescofirowndolely th oricom:
ARD foreanes so, thant,
Whert crs
CExY m ters drd coo miqbomyesgnf arpr l me d ONENende ler fr fanet!
Q:


TIZThe aldI m m a o my themy thuivere
ARUTRG sDu

LO:
NThelin!
Holener whelakind uge f m;
An:
F:
INGINCE b
Inference time: 0.68 seconds or 732.81 tokens per second
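One thing that stands out in `generate()` is that every step runs the model over the whole sequence generated so far, even though a bigram model only needs the last token. The sketch below feeds just the last token at each step and concatenates once at the end. `TinyBigram` and `generate_last_only` are hypothetical names, and whether this explains or fixes the GPU slowdown is untested here.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class TinyBigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    @torch.no_grad()
    def generate_last_only(self, idx, max_new_tokens):
        new_tokens = []
        last = idx[:, -1]  # (B,) only the most recent token matters
        for _ in range(max_new_tokens):
            logits = self.token_embedding_table(last)  # (B, C)
            probs = F.softmax(logits, dim=-1)
            last = torch.multinomial(probs, num_samples=1).squeeze(1)  # (B,)
            new_tokens.append(last)
        if not new_tokens:
            return idx
        # Concatenate once at the end instead of growing idx every step
        return torch.cat([idx, torch.stack(new_tokens, dim=1)], dim=1)

torch.manual_seed(0)
model = TinyBigram(65)
out = model.generate_last_only(torch.zeros((1, 1), dtype=torch.long), 20)
print(out.shape)
```

This shortcut is specific to the bigram model. Once the model uses earlier tokens as context, the full (cropped) history has to be passed in again.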


## Understanding some linear algebra tricks

These tricks are at the core of making the computations efficient with GPU
acceleration for matrix multiplication. They get us out of loops in Python
and into vectorized matmul instead.

In [23]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
print(x.shape)
print(x)

torch.Size([4, 8, 2])
tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])


At each position t in T we ensure that we only look at tokens at or before
position t. We can't look beyond t because that would be looking into the future.
The code below computes the running mean of the token vectors (the C channels)
over all positions up to and including t.

In [32]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        # Look only at the vectors of tokens up to and including t
        xprev = x[b, :t+1]

        # Compute the mean of those vectors
        xbow[b,t] = torch.mean(xprev, 0)

print(x[0])
print(xbow[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


We can rewrite this using a mathematical trick: multiplying by a row-normalized
lower triangular matrix vectorizes the running mean!
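A tiny worked example may help before applying this to `x`. The 3x3 case below uses made-up values for `b` (the names `a`, `b`, and `c` are illustrative). Row t of the normalized triangular matrix holds equal weights for positions 0..t and zeros afterwards, so each row of the product is the average of the rows of `b` seen so far.

```python
import torch

# Lower triangular ones, then normalize each row so it sums to 1
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)

b = torch.tensor([[2.0, 7.0],
                  [6.0, 4.0],
                  [6.0, 5.0]])

c = a @ b  # row t is the mean of b[0..t]
print(a)
print(c)
```

For instance, the second row of `c` is the mean of the first two rows of `b`: (2+6)/2 = 4 and (7+4)/2 = 5.5.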

In [38]:
# First step: still loop over the batch, but vectorize over time
tril = torch.tril(torch.ones(T, T))
tril = tril / torch.sum(tril, 1, keepdim=True)
print(tril)
for i in range(B):
    # Vectorize the computation of the averages using the tril trick
    b = x[i]
    c = tril @ b
    print(f"x[{i}]:\n{b}")
    print(f"prev_token_average_probs[{i}]:\n{c}")

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
x[0]:
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
prev_token_average_probs[0]:
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
       

But we can remove the outermost loop as well and compute a single batched
matrix multiply. tril is (T,T) and x is (B,T,C). This gives the same answer and
avoids iterating manually in Python, which is MUCH faster in general. These are
small tensors, so the perf difference is not noticeable here.

The clever thing about how PyTorch works is that when an operand is missing a
batch dimension (tril is only (T,T) and x is (B,T,C)), PyTorch
broadcasts it across the batch so that the multiplication looks like:

(B,T,T) @ (B,T,C) -> (B,T,C)
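To check that broadcasting really behaves like an explicit batch of identical (T,T) matrices, the sketch below builds that batch by hand and compares the two results. The variable names `broadcast` and `explicit` are illustrative, and the tensors are regenerated here so the snippet runs on its own.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
tril = tril / tril.sum(dim=1, keepdim=True)

# Broadcasting: (T, T) @ (B, T, C) -> (B, T, C)
broadcast = tril @ x

# Explicit version: copy tril B times, then do a batched matmul
explicit = torch.bmm(tril.unsqueeze(0).repeat(B, 1, 1), x)

print(broadcast.shape)
print(torch.allclose(broadcast, explicit))
```

Broadcasting gives the same result without materializing the B copies of `tril`.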

In [37]:
xbow2 = tril @ x
print(xbow2)
print(xbow==xbow2) # Should be all True but there are some rounding errors
print(torch.allclose(xbow, xbow2)) # This will evaluate to true though!

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])
tensor([[[ True,  True],
         [ True,  True],
         [ Tru

## Self-Attention mechanism
