# GPT from scratch

Follows [this](https://www.youtube.com/watch?v=kCc8FmEb1nY) video tutorial.

In [None]:
# !curl -O https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [1]:
with open('input.txt', 'r') as f:
    text = f.read()

In [2]:
print(f"Length of text: {len(text)}")
print(text[:100])

Length of text: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


We need to create our vocabulary.

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
vocab = {c: i for i, c in enumerate(chars)}
print(len(vocab)) # 65 unique chars.

65


Next, we need a way to tokenize the text. We're building a character-level model, so tokenizing just means translating individual characters into integers.

In [6]:
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}

# take a string, return a list of integers
encode = lambda x: [stoi[c] for c in x]
# take a list of integers, return a string
decode = lambda x: ''.join([itos[i] for i in x])

In [7]:
print(encode("Hello this is a test string"))
print(decode(encode("Hello this is a test string")))

[20, 43, 50, 50, 53, 1, 58, 46, 47, 57, 1, 47, 57, 1, 39, 1, 58, 43, 57, 58, 1, 57, 58, 56, 47, 52, 45]
Hello this is a test string


We can now tokenize the entire text dataset.

In [9]:
import torch

In [10]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [18]:
# let's also split our data into train and validation datasets
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

We can't pass in the entire dataset in one iteration of the transformer, so we do need to set a "block size" (AKA context length).

In [12]:
block_size = 8

Let's take a look at one block

In [13]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

### Training using chunks of the context window
A transformer makes simultaneous predictions for each of these positions. This makes training much more efficient since we treat each position as a separate prediction task instead of just predicting the next character after the last position, which greatly increases the number of samples used for training.

This is why we take `[:block_size+1]`: to get 8 input characters, we need to look at 9 characters so that every input character has a next character to predict.

We can actually display what the prediction task is: given a block of text, we predict what token is at each position. The input passed into each prediction task is whatever tokens were in the block before this target.

In [14]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 47
When input is tensor([18, 47]), the target is: 56
When input is tensor([18, 47, 56]), the target is: 57
When input is tensor([18, 47, 56, 57]), the target is: 58
When input is tensor([18, 47, 56, 57, 58]), the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 58


So, in a chunk of 9 characters, we essentially have 8 training examples. This not only gives us more training samples, but it also gets the Transformer used to seeing context sizes from as little as 1 all the way up to the context length.

### Batching

We also need to care about the batch dimension, since we'll be feeding the input as batches of tensors. This lets us process more inputs in parallel, which matters because GPUs are great at working with batches of data.

In [15]:
torch.manual_seed(1337)

<torch._C.Generator at 0x114a0e470>

In [16]:
# number of sequences in a batch, processed in parallel.
batch_size = 4

# maximum context length for prediction
block_size = 8

In [19]:
def get_batch(split):
    """Gets a batch of data for training or validation."""
    data = (
        train_data if split == "train" else val_data
    )
    # we take random ints as our start indices. We make
    # sure to take random ints that allow us to take a full
    # sequence of `block_size` length
    ix = torch.randint(
        len(data) - block_size, # which ints to include in pool
        (batch_size,) # number of results = batch size.
    )

    # for each start index, we take a random block of chars.
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [24]:
xb, yb = get_batch(split="train")

Let's now look at our batch.

In [26]:
for b in range(batch_size):
    print(f"Batch {b}")
    print('-' * 10)
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context}, the target is: {target}")
    print('-' * 10)

Batch 0
----------
When input is tensor([43]), the target is: 1
When input is tensor([43,  1]), the target is: 51
When input is tensor([43,  1, 51]), the target is: 39
When input is tensor([43,  1, 51, 39]), the target is: 63
When input is tensor([43,  1, 51, 39, 63]), the target is: 1
When input is tensor([43,  1, 51, 39, 63,  1]), the target is: 40
When input is tensor([43,  1, 51, 39, 63,  1, 40]), the target is: 43
When input is tensor([43,  1, 51, 39, 63,  1, 40, 43]), the target is: 1
----------
Batch 1
----------
When input is tensor([58]), the target is: 46
When input is tensor([58, 46]), the target is: 43
When input is tensor([58, 46, 43]), the target is: 1
When input is tensor([58, 46, 43,  1]), the target is: 43
When input is tensor([58, 46, 43,  1, 43]), the target is: 39
When input is tensor([58, 46, 43,  1, 43, 39]), the target is: 56
When input is tensor([58, 46, 43,  1, 43, 39, 56]), the target is: 57
When input is tensor([58, 46, 43,  1, 43, 39, 56, 57]), the target is

Each of these batch samples becomes a row in the data. What we end up getting is a `(batch_size, block_size)`-shaped tensor for our training data. For `batch_size=4` and `block_size=8`, this leads to a 4x8 tensor.

Because we predict the next token for each position in the block, we get 8 training samples out of a block of size=8. This means that in this batch, we have 32 training samples.

In [22]:
len(y)

4

### Starting neural network training with the simplest model: bigrams
Now we can start using this input for neural network training. We can start with the simplest model: the bigram.

#### What is a bigram model?
A bigram model is a model that predicts the next word given the previous word in a sentence. It assumes that the probability of the next word is *only affected by the previous word*.

For our case, we'll use a bigram character model, where the probability of a given character depends only on the previous character.
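
To make the idea concrete before moving to a neural network, here is a minimal sketch that estimates bigram probabilities by counting adjacent character pairs. The helper name `bigram_probs` and the toy string are assumptions made for illustration; this is not the approach used in the rest of the notebook, which learns the table by gradient descent.

```python
from collections import Counter, defaultdict

def bigram_probs(sample):
    """Estimate P(next char | previous char) by counting adjacent pairs."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(sample, sample[1:]):
        counts[prev][nxt] += 1
    # normalize each row of counts into a probability distribution
    return {
        prev: {nxt: n / sum(following.values()) for nxt, n in following.items()}
        for prev, following in counts.items()
    }

probs = bigram_probs("hello hello")
# in this toy string, 'l' is followed by 'l' half the time and by 'o' half the time
print(probs["l"])
```

Each row of this table plays the same role as one row of the embedding lookup table used below, except that the neural version stores logits rather than probabilities.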

In [27]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x114a0e470>

In [36]:
class BigramLanguageModel(nn.Module):
    """Bigram model."""
    def __init__(self, vocab_size):
        super().__init__()
        # read off the logits for the next token
        # from a lookup table.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx):
        """Forward pass."""
        # get the logits for the next token.
        # for a given batch with shape (B, T), we will get an output
        # of shape (B, T, C), where B = batch size, T = sequence length
        # and C = vocab size.
        logits = self.token_embedding_table(idx)
        return logits

    def predict(self, x):
        """Predict the next token."""
        # get the logits
        logits = self(x)
        # get argmax
        return logits.argmax(dim=-1)


Let's see how this would look so far.

In [37]:
model = BigramLanguageModel(vocab_size=vocab_size)
print(f"{model.token_embedding_table.weight.shape=}")

model.token_embedding_table.weight.shape=torch.Size([65, 65])


Let's make a naive prediction. Given a particular index in our vocabulary, let's see what the predicted output is:

In [44]:
test_idx = 5
logits = model(torch.tensor([test_idx]))
prediction = model.predict(torch.tensor([test_idx]))
print(f"Input index: {test_idx}, Prediction index: {prediction.item()}")
# tensor of shape (1, vocab_size), with each element being the
# logit for each token in the vocabulary at the given index.
print(f"Output logits shape: {logits.shape}")
# we take the max index of the logits to get the
# predicted token.
print(f"Input: {itos[test_idx]}, Prediction: {itos[prediction.item()]}")

Input index: 5, Prediction index: 38
Output logits shape: torch.Size([1, 65])
Input: ', Prediction: Z


Let's try it on a batch of texts. Let's get the accuracy based on a random set of results. When we run `forward`, we get a result of shape `(batch_size, block_size, vocab_size)`. We get a `[vocab_size]`-shaped tensor for all positions in our `[batch_size, block_size]`-shaped batch.

In [45]:
batch_logits = model(xb)
preds = model.predict(xb)

In [46]:
# batch logits shape = (batch_size, block_size, vocab_size)
print(f"{batch_logits.shape=}")

# preds shape = (batch_size, block_size) because we
# take the argmax on the -1 dim to get the predicted token.
print(f"{preds.shape=}")
print(f"{yb.shape=}")
print(f"Accuracy: {(preds == yb).float().mean()}")

batch_logits.shape=torch.Size([4, 8, 65])
preds.shape=torch.Size([4, 8])
yb.shape=torch.Size([4, 8])
Accuracy: 0.03125


Our model only has an accuracy of about 3% (1 correct prediction out of 32). This number doesn't mean much yet, since the predictions come from a randomly initialized lookup table.

Let's add a loss function to our model so we can have a way to evaluate "how good" the predictions are.

In [47]:
class BigramLanguageModel(nn.Module):
    """Bigram model."""
    def __init__(self, vocab_size):
        super().__init__()
        # read off the logits for the next token
        # from a lookup table.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        """Forward pass."""
        # get the logits for the next token.
        # for a given batch with shape (B, T), we will get an output
        # of shape (B, T, C), where B = batch size, T = sequence length
        # and C = vocab size.
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # flatten the logits
            logits = logits.view(B*T, C)
            # flatten the targets
            targets = targets.view(B*T)
            # compute the loss
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def predict(self, x):
        """Predict the next token."""
        # get the logits (forward now returns a (logits, loss) tuple)
        logits, _ = self(x)
        # get argmax
        return logits.argmax(dim=-1)


Now let's get the predictions and loss. We get the logits for all `[batch_size, block_size]` characters as well as the mean loss over those positions.

We're expecting something like `-ln(1/65) ~ 4.17`, which is what the loss would be if the model assigned a probability of `1/65` to every character (a uniform guess over a vocab size of 65).
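
As a quick sanity check on that number, the sketch below feeds all-zero logits (a perfectly uniform guess) into `F.cross_entropy`. The tensor names and the batch of 32 random targets are assumptions for illustration only.

```python
import math
import torch
from torch.nn import functional as F

vocab_size = 65  # matches the vocabulary built above
# all-zero logits give every character the same probability, 1/vocab_size
uniform_logits = torch.zeros(32, vocab_size)
targets = torch.randint(vocab_size, (32,))
uniform_loss = F.cross_entropy(uniform_logits, targets)
# both values should be approximately 4.17
print(uniform_loss.item(), math.log(vocab_size))
```

A freshly initialized `nn.Embedding` does not produce all-zero logits, so the starting loss of the real model need not match this value exactly.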

In [49]:
model = BigramLanguageModel(vocab_size=vocab_size)
logits, loss = model(xb, yb)
print(f"{logits.shape=}")
print(f"{loss=}")

logits.shape=torch.Size([32, 65])
loss=tensor(4.7801, grad_fn=<NllLossBackward0>)


Now that we can evaluate the loss, we'd like to also be able to do generation from the model.

In [51]:
class BigramLanguageModel(nn.Module):
    """Bigram model."""
    def __init__(self, vocab_size):
        super().__init__()
        # read off the logits for the next token
        # from a lookup table.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        """Forward pass."""
        # get the logits for the next token.
        # for a given batch with shape (B, T), we will get an output
        # of shape (B, T, C), where B = batch size, T = sequence length
        # and C = vocab size.
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # flatten the logits
            logits = logits.view(B*T, C)
            # flatten the targets
            targets = targets.view(B*T)
            # compute the loss
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def predict(self, x):
        """Predict the next token."""
        # get the logits (forward returns a (logits, loss) tuple)
        logits, _ = self(x)
        # get argmax
        return logits.argmax(dim=-1)


    def generate(self, idx, max_new_tokens):
        """Generate new tokens.
        
        idx is (B,T) shaped tensor with array of indices
        that are in the current context.

        At each step, we predict the next token and add it to the context.

        Algorithm:
            - We get the logits for the last token in the context for each sequence
            in the batch
            - We apply softmax to get probabilities for the next token, for each
            sequence in the batch.
            - We sample from the distribution to get the next token for each sequence
            in the batch.
            - We add the new token to the context and repeat.
        """
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx)
            # focus on the last time step
            # becomes (B, C) since we take the logits of the
            # last position in each sequence of the batch.
            logits = logits[:, -1, :]
            # apply softmax to get probabilities.
            # (B, C), but now values are probabilities.
            # Each row is a probability distribution whose values add
            # up to 1.
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            # (B, 1). We get the next token for each sample in the batch.
            next_token = torch.multinomial(probs, num_samples=1)
            # we add the new token to the context and repeat.
            # idx originally (B, T), next_token (B, 1), so when we concatenate
            # we get (B, T+1)
            idx = torch.cat([idx, next_token], dim=1) # (B, T+1)

        # shape of idx is (B, T+max_new_tokens) since we've added
        # max_new_tokens to the context.
        return idx
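
A single iteration of the loop above can be sketched in isolation to make the shapes explicit. The random logits and the all-zero context below are stand-ins (assumptions), not outputs of the model.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(4, 65)        # stand-in for the last-position logits, (B, C)
probs = F.softmax(logits, dim=-1)  # (B, C), each row sums to 1
# draw one token index per sequence according to its probabilities
next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
context = torch.zeros((4, 8), dtype=torch.long)       # stand-in for idx, (B, T)
context = torch.cat([context, next_token], dim=1)     # (B, T+1)
print(next_token.shape, context.shape)
```

Sampling with `torch.multinomial` rather than taking the argmax means repeated calls can produce different continuations from the same context.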


Let's see how well this works.

In [71]:
model = BigramLanguageModel(vocab_size=vocab_size)

In [72]:
max_new_tokens = 100
generated_tokens = model.generate(xb, max_new_tokens)
# xb has shape (batch_size, block_size):
# 8 initial tokens for each sequence in the batch.
print(f"{xb.shape=}")
# each sequence gets 100 generated tokens appended to its 8 initial
# tokens, so the result has shape (batch_size, 108).
print(f"{generated_tokens.shape=}")

xb.shape=torch.Size([32, 8])
generated_tokens.shape=torch.Size([32, 108])


Let's decode both of these to see the actual letters. First, let's look at our batch.

In [59]:
decoded_batch = []
for seq in xb:
    output_str = ""
    for idx in seq:
        output_str += itos[idx.item()]
    decoded_batch.append(output_str)

for i, seq in enumerate(decoded_batch):
    print(f"String {i}: {seq}")

String 0: e may be
String 1: the ears
String 2: ation? Y
String 3: ore I ca


Now let's look at our generated text.

In [61]:
decoded_generated_text = []
for seq in generated_tokens:
    output_str = ""
    for idx in seq:
        output_str += itos[idx.item()]
    decoded_generated_text.append(output_str)

for i, (seq, generated_seq) in enumerate(
    zip(decoded_batch, decoded_generated_text)
):
    print(f"String {i}: {seq} -> {generated_seq}\n")
    print('-' * 10)

String 0: e may be -> e may beg:yS??smx 3WiGMiBLd$m'iH?sNQC!3OSzKldX.Mmx;kQvn
KGBtkQMcuUmINRn$I,t BhhMGdoCxIbEdLK
eIfaX-gOI-YURdNT

----------
String 1: the ears -> the ears?TLo,rTH?zdNvs-qXy.iHVb.iv:B:r
mcC:BqE!J:i.xbdY$IZxlWj!xikQ?BvESHfaQC:Q!BmIivJCCFuxXsHhrwaWILhXEJp-o

----------
String 2: ation? Y -> ation? YXRcatLGJx'3W;gXZw.
'A3YCPbJyYBy3hNZ oTkL'$'c&&qLk-Xj!fAzp;r!.YsHpoELFi?ghNTtkOGJsxruJDEJSUk,tpv'VbMZ

----------
String 3: ore I ca -> ore I caH?-:aCpPWSCZ:ejBLGqis?g'Aoj exzzoQPPfGRlvXoJbu.RhlK?g:SHlM?zkcoagxSHSAxkRLilWCqRcj!aP-jUclqXMA33PluV

----------


As we can see, the strings that are generated are a bunch of random-looking characters. This is because we randomly generated the weights of our lookup table.

Let's set up the training logic so that the model becomes less random.

The results from this first training run won't be very impressive, but we'll reuse the same loop later when we swap in something more powerful.

In [73]:
lr = 1e-3
batch_size = 32
n_epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [64]:
for epoch in range(n_epochs):
    # sample a batch of data
    xb, yb = get_batch(split="train")

    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 4.737148761749268
Epoch 10, Loss: 4.644645690917969
Epoch 20, Loss: 4.661149501800537
Epoch 30, Loss: 4.517314910888672
Epoch 40, Loss: 4.539600372314453
Epoch 50, Loss: 4.698057651519775
Epoch 60, Loss: 4.696743011474609
Epoch 70, Loss: 4.54141902923584
Epoch 80, Loss: 4.665844440460205
Epoch 90, Loss: 4.5253987312316895


This is our general logic. The loss doesn't go down by much here.

Our training loop has the effect of updating the initial weights in our lookup table so that they more accurately reflect the character statistics of the text.

Since each update step is cheap, let's train for many more iterations to see if we can get a more useful result.

In [74]:
for epoch in range(10000):
    # sample a batch of data
    xb, yb = get_batch(split="train")

    # evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 4.638540267944336
Epoch 1000, Loss: 3.658257246017456
Epoch 2000, Loss: 3.0377516746520996
Epoch 3000, Loss: 2.8597161769866943
Epoch 4000, Loss: 2.5883102416992188
Epoch 5000, Loss: 2.5793514251708984
Epoch 6000, Loss: 2.4103808403015137
Epoch 7000, Loss: 2.4814467430114746
Epoch 8000, Loss: 2.4670302867889404
Epoch 9000, Loss: 2.4843976497650146


Let's take a look at our results.

Even though the output is still pretty random, it's much better than what we started with. It's beginning to at least take on the structure and line lengths that you'd expect in a Shakespeare play. The words are still mostly gibberish, but the output looks more like an encrypted message than a random string of characters. Heck, it even gets close to resembling some text.

This is a promising approach. It's impressive that we can begin to get somewhat plausible text just from a simple bigram model trained with backpropagation. We don't even include any context besides the single preceding character! During backpropagation, the gradients are computed for the selected rows in the embedding lookup table, and then the optimizer updates these rows based on the gradients. Across training iterations, this has the effect of updating the rows of the embedding lookup table so that they better represent the probability of each next character given the current character.
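
The claim about gradients landing only on the selected rows can be checked on a tiny table. This is a sketch with a hypothetical 5-token vocabulary and made-up indices, not the notebook's model.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
toy_vocab = 5  # hypothetical tiny vocabulary
table = nn.Embedding(toy_vocab, toy_vocab)

idx = torch.tensor([1, 3])      # input tokens: only rows 1 and 3 are looked up
targets = torch.tensor([2, 0])  # the tokens that should follow them
loss = F.cross_entropy(table(idx), targets)
loss.backward()

# rows with a nonzero gradient are exactly the rows that were looked up
touched_rows = table.weight.grad.abs().sum(dim=1) > 0
print(touched_rows)
```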

In [75]:
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


Wen
TI ks, s sen bance;
Herd,Wh t!
Yolooutto knd wroushilde, ay nd Whan;mod nd mer d ononisinat I Hegids o mapte t beve theamous invere, beloveters HARO,



I aven he ft ouray o man:
Thesujhaplis, m Lind tr anngredsowinavourlfat wathin k, sureind wahavan?
DUSe wers k.
ING s hil ard istis,
Mavesh tsthe t muirey fave urcke by n EDI ss K:
I:
I murolengny fot ifothe,
O:-al:
Dondsuphenyookerou ucrcobr aly.
Ancre dagbeausoroours den.Whillof plearr'laret RYCESwnte t, bishepatifay hatt gdan, waneve me c


We can make this much more powerful though if we can use more of the context than just what character came before the one that we're trying to predict.

### The mathematical trick in self-attention

Let's consider the following example:

In [76]:
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B, T, C)
print(f"{x.shape=}")

x.shape=torch.Size([4, 8, 2])


Currently, the 8 tokens in each sequence are not talking to each other. We want the tokens to talk to each other, but in a very specific way.

For example, if we are at the 5th token, we want to use the 1st-4th tokens as context, but not any future tokens.

What is an easy way to encapsulate the "context"? One naive way is to take the average of all tokens up to and including the $t^{th}$ token and use that as the context for the $t^{th}$ token.

If I'm the 5th token, for example, I want to take the information from my own time step as well as from steps 1-4, and then average those in order to create a feature vector that "summarizes" the sequence up to the 5th token.

This obviously loses a lot of information, especially positional information. It has the same flaw as, say, a bag-of-words representation, namely that you can't capture position. For example, "the cat is jumping over the dog" and "the dog is jumping over the cat" are treated identically because they have the same words. We take the "average representation" of all the words up to and including the $t^{th}$ token.
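
The loss of order is easy to demonstrate. In this sketch the `tokens` tensor is a made-up stand-in for five token vectors with two features each; reversing their order leaves the average unchanged.

```python
import torch

# 5 hypothetical tokens with 2 features each
tokens = torch.arange(10.0).view(5, 2)
reversed_tokens = tokens.flip(0)  # same tokens, opposite order

# the mean over the sequence ignores the order of the tokens
print(tokens.mean(dim=0), reversed_tokens.mean(dim=0))
```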

In [77]:
xbow = torch.zeros((B,T,C)) # bag of words
# for each sequence in the batch
for b in range(B):
    # for each token in the sequence
    for t in range(T):
        # get the previous tokens.
        x_prev = x[b,:t+1] # (t+1, C)
        # get the mean of the previous tokens
        x_mean = x_prev.mean(dim=0)
        # set the mean to the current token.
        xbow[b,t] = x_mean

# we want xbow[b,t] = mean_{i<=t} x[b,i]

Let's check the shape of the result:

In [79]:
print(f"{xbow.shape=}")

xbow.shape=torch.Size([4, 8, 2])


Let's see how this looks for, say, the 1st sequence. We see that for the first sequence, we get a result with shape `[8,2]`.

In [89]:
print(f"{x[0,:,:].shape=}")

x[0,:,:].shape=torch.Size([8, 2])


Let's figure out how we got the value at, say, `[0, 4,:]`, the 5th token of the first sequence.

In [97]:
print(f"{x[0,4,:]=}") # original value
print(f"{xbow[0,4,:]=}") # averaged value

x[0,4,:]=tensor([0.3612, 1.1679])
xbow[0,4,:]=tensor([0.3525, 0.0545])


Let's look at the 1st through 5th elements and then average them.

In [98]:
print(f"{x[0,0:5,:]=}")
print(f"{x[0,:5].mean(dim=0)=}") # matches our averaged value.

x[0,0:5,:]=tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679]])
x[0,:5].mean(dim=0)=tensor([0.3525, 0.0545])


We took the feature vectors of the 1st through 5th tokens, averaged them, and that became our representation for the 5th token.

#### The trick: migrating this to matrix multiplication
We can actually do this quite efficiently if we use matrix multiplication.

Let's see how this can work.

In [101]:
torch.manual_seed(42)
a = torch.ones((3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f"{a=}")
print(f"{a.shape=}")
print(f"{b=}")
print(f"{b.shape=}")
print(f"{c=}")
print(f"{c.shape=}")

a=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
a.shape=torch.Size([3, 3])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
b.shape=torch.Size([3, 2])
c=tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
c.shape=torch.Size([3, 2])


Multiplying by a matrix of ones naturally adds all of the elements per column:

$$
\left[
\begin{matrix}
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
2\\
6\\
6\\
\end{matrix}
\right]
= 14
\rightarrow
\left[
\begin{matrix}
1&1&1\\
1&1&1\\
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
2\\
6\\
6\\
\end{matrix}
\right]
=
\left[
\begin{matrix}
14\\
14\\
14\\
\end{matrix}
\right]
$$

$$
\left[
\begin{matrix}
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
7\\
4\\
5\\
\end{matrix}
\right]
= 16
\rightarrow
\left[
\begin{matrix}
1&1&1\\
1&1&1\\
1&1&1\\
\end{matrix}
\right]
*
\left[
\begin{matrix}
7\\
4\\
5\\
\end{matrix}
\right]
=
\left[
\begin{matrix}
16\\
16\\
16\\
\end{matrix}
\right]
$$

We can then simply average these, if we wanted, by taking any of the values per column and dividing by the number of elements that were summed (3):
$$
\left[
\begin{matrix}
\frac{14}{3}&\frac{16}{3}\\
\end{matrix}
\right]
$$
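
The same averaging can be folded into the matrix itself by scaling the ones. The sketch below reuses the values of `b` printed above and is only meant to illustrate the idea.

```python
import torch

b = torch.tensor([[2., 7.], [6., 4.], [6., 5.]])  # same values as b above
a = torch.ones((3, 3)) / 3  # every weight is 1/3, so each row of a sums to 1
c = a @ b
# every row of c holds the column means of b, i.e. 14/3 and 16/3
print(c)
```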

We've now established that we can easily get the column sums of a matrix via matrix multiplication.

Let's now look at another piece for the calculation: `torch.tril`.

In [102]:
# lower triangular matrix
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

Now let's replicate our logic from above.

In [103]:
a = torch.tril(torch.ones(3,3))

In [105]:
c = a @ b
print(f"{a=}")
print(f"{b=}")
print(f"{c=}")
print(f"{c.shape=}")

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
c.shape=torch.Size([3, 2])


Let's break down what we're looking at by looking at the output matrix.

The first output row is $\left[2,7\right]$. We get this through the matrix multiplication of:
$$
\left[
1,0,0
\right]
*
\left[
\begin{matrix}
2&7\\
6&4\\
6&5\\
\end{matrix}
\right]
\rightarrow
\left[(1*2 + 0*6 + 0*6),(1*7 + 0*4 + 0*5)\right]
\rightarrow
\left[2, 7\right]
$$

We keep only the first row of $b$ and zero out the rest.

The second output row is $\left[8, 11\right]$. We get this through the matrix multiplication of:
$$
\left[
1,1,0
\right]
*
\left[
\begin{matrix}
2&7\\
6&4\\
6&5\\
\end{matrix}
\right]
\rightarrow
\left[(1*2 + 1*6 + 0*6),(1*7 + 1*4 + 0*5)\right]
\rightarrow
\left[8, 11\right]
$$

We calculate the second row by summing the first and second rows of $b$ and zeroing out the rest.

The same logic applies for the third row.

By multiplying by the triangular matrix of ones, we get, at index $t$, the cumulative sum of all elements up to and including $t$. This is the summation half of what we did with a for loop before. It lets us perform the operation "get the sum of all the previous terms up to the term at position $t$" as a single matrix multiplication.
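
One way to convince ourselves of this is to compare against `torch.cumsum`, which computes running sums directly. The variable name `running_sum` is an assumption for illustration.

```python
import torch

b = torch.tensor([[2., 7.], [6., 4.], [6., 5.]])  # same values as b above
a = torch.tril(torch.ones(3, 3))
running_sum = a @ b
# row t of the product is the sum of rows 0..t of b
print(torch.equal(running_sum, torch.cumsum(b, dim=0)))
```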

Now, to calculate the means, we can just change $a$ so that each row is normalized to sum to 1:

In [106]:
a = torch.tril(torch.ones(3,3))
print(f"a before normalization: {a}")
a = a / a.sum(dim=-1, keepdim=True)
print(f"a after normalization: {a}")


a before normalization: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a after normalization: tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
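
Putting the pieces together, a row-normalized lower-triangular matrix should reproduce the loop-based running averages from earlier. The following is a sketch under that assumption; the names `wei` and `xbow2` are hypothetical and chosen only for this example.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# loop version: average all tokens up to and including position t
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

# matrix version: each row of wei holds equal weights over positions 0..t
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=-1, keepdim=True)
xbow2 = wei @ x  # (T, T) @ (B, T, C) broadcasts over the batch -> (B, T, C)

# the two versions agree up to floating-point rounding
print(torch.allclose(xbow, xbow2, atol=1e-5))
```

The tolerance is passed explicitly because the loop and the matrix product add the numbers in a different order, which can change the last few floating-point digits.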
