## Environment setup

In [112]:
# Install the locked project dependencies into the uv-managed virtual environment.
# The `cpu` extra presumably selects the CPU-only PyTorch build (see pyproject.toml).
!uv sync --extra cpu
# For an NVIDIA GPU use instead: !uv sync --extra cu128

[2mResolved [1m140 packages[0m [2min 2ms[0m[0m
[2mAudited [1m110 packages[0m [2min 0.11ms[0m[0m


In [113]:
# NOTE(review): `uv tool upgrade --all` upgrades globally installed uv *tools*, not this
# project's dependencies, so it is likely unnecessary in this notebook.
!uv tool upgrade --all
# Re-resolve the project's dependencies and update the lockfile.
!uv lock

Nothing to upgrade
[2mResolved [1m140 packages[0m [2min 1ms[0m[0m


In [114]:
import torch

# Check whether PyTorch can see a CUDA GPU (False means everything runs on the CPU).
print(torch.cuda.is_available())

False


## Prepare training data

In [7]:
# read it in to inspect it
# The relative path is resolved against the kernel's current working directory.
with open("src/mini_gpt/input.txt", "r", encoding="utf-8") as f:
    text = f.read()  # the entire corpus as a single Python string

In [8]:
# Size of the corpus, counted in characters (not bytes or words).
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [9]:
# let's look at the first 1000 characters to get a feel for the format
# (speaker name, colon, newline, then the spoken lines)
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [10]:
# here are all the unique characters that occur in this text
# (sorted() accepts any iterable, so the set can be passed to it directly)
chars = sorted(set(text))
vocab_size = len(chars)
# The first character in sorted order is "\n", which is why the printed
# vocabulary starts with a line break.
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


### Apply encoding (convert text into integers)

In [11]:
# Lookup tables between characters and integer ids: a character's id is its
# position in the sorted `chars` list.
stoi = dict(zip(chars, range(len(chars))))  # char -> id
itos = dict(enumerate(chars))  # id -> char


def encode(s):
    """Encoder: turn a string into a list of integer ids, one per character."""
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids


def decode(s):
    """Decoder: turn a sequence of integer ids back into a string.

    Note: despite its name, ``s`` is a sequence of integer ids, not a string.
    """
    return "".join(itos[idx] for idx in s)


print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


### Use PyTorch tensors for fast numerical computation

In [13]:
# Add torch as a project dependency (records it in pyproject.toml / uv.lock).
# NOTE(review): torch is already imported earlier in this notebook, so this cell is
# only needed the first time the project is set up.
!uv add torch

[2K[2mResolved [1m141 packages[0m [2min 11.82s[0m[0m                                      [0m
[2K[2mAudited [1m10 packages[0m [2min 0.02ms[0m[0m                                        [0m


In [12]:
# let's now encode the entire text dataset and store it into a torch.Tensor
import torch  # we use PyTorch: https://pytorch.org

# One int64 id per character, in the same order as in the text.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(
    data[:1000]
)  # the 1000 characters we looked at earlier will look like this to the GPT

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

### Split the data into training and validation sets (the held-out validation set lets us detect overfitting)

In [45]:
# Hold out the last 10% of the characters for validation and train on the first 90%.
# The split is by position (no shuffling), so the validation text is never trained on.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

## Visualize training input and output data

In [14]:
# Context length: the maximum number of characters the model looks at at once.
block_size = 8
# block_size + 1 characters yield block_size (context, target) examples, shown below.
train_data[: block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [15]:
# Inputs and targets are the same window, offset by one character.
x = train_data[:block_size]
y = train_data[1 : block_size + 1]
for t in range(block_size):
    # Every prefix of x is one training example; its label is the character that follows.
    context, target = x[: t + 1], y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [16]:
torch.manual_seed(1337)  # fixed seed so the randomly sampled batches are reproducible
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?


def get_batch(split):
    """Sample a random batch of (input, target) character sequences.

    Args:
        split: "train" to sample from train_data, "val" to sample from val_data.

    Returns:
        (x, y): two (batch_size, block_size) tensors. y is the same window as x
        shifted one character later, so y[b, t] is the character after x[b, t].

    Raises:
        ValueError: if split is neither "train" nor "val".
    """
    if split == "train":
        data = train_data
    elif split == "val":
        data = val_data
    else:
        # Previously any unrecognized string silently fell back to val_data,
        # which hides typos such as "Train" or "valid".
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Start offsets go up to len(data) - block_size - 1 (randint's bound is exclusive),
    # so the target slice data[i + 1 : i + block_size + 1] always fits.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


# Draw one batch from the training split and inspect it.
xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("----")

# Each of the batch_size * block_size positions is an independent training example:
# given the context xb[b, : t + 1], the model should predict yb[b, t].
for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, : t + 1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53

In [21]:
# The batch drawn above: yb is xb shifted one character later.
print(xb)  # our input to the transformer
print(yb)  # our target for given input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


## Build a basic bigram language model that computes logits and loss

In [23]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Reset the global RNG so the model's random initial weights are reproducible.
torch.manual_seed(1337)

<torch._C.Generator at 0x7f5b5d35f490>

In [24]:
class BigramLanguageModel(nn.Module):
    """
    A simple Bigram Language Model.
    Given a character, it outputs a score (logit) for every possible next character;
    `generate` samples from those scores to produce text.
    It doesn't have any memory of characters before the immediately preceding one.
    """

    def __init__(self, vocabulary_size):
        super().__init__()
        # This is the heart of our model: an Embedding table.
        # It has one row for each character in our vocabulary.
        # Each row contains the "logits" (raw scores) for the *next* character in the
        # sequence, so row `i` holds the model's prediction for what comes after
        # character `i`. The size is (vocabulary_size, vocabulary_size) because for
        # each character we need a score for every possible next character.
        self.token_embedding_table = nn.Embedding(vocabulary_size, vocabulary_size)

    def forward(self, input_indices, targets=None):
        """
        Forward pass: make predictions (logits) and, optionally, compute the loss.

        Args:
            input_indices: A (Batch, Time) tensor of character indices.
            targets: Optional (Batch, Time) tensor of the true next-character indices.

        Returns:
            (logits, loss): logits has shape (B, T, C) with C = vocabulary size;
            loss is the mean cross-entropy (a scalar tensor), or None when no
            targets are given.
        """
        # For every input character index, we look up its row in the embedding
        # table: the raw prediction scores (logits) for the next character.
        # (B, T) -> (B, T, C)
        logits = self.token_embedding_table(input_indices)

        # The loss measures how wrong the predictions (logits) were compared to the
        # actual targets. We want to minimize this value.
        if targets is None:
            # If we're just generating text, we don't have targets, so there's no loss.
            loss = None
        else:
            # F.cross_entropy expects (N, C) scores and (N,) class indices, so we
            # flatten the Batch and Time dimensions into one:
            # logits (B, T, C) -> (B*T, C), targets (B, T) -> (B*T)
            # reshape (rather than view) also works for non-contiguous inputs,
            # e.g. targets produced by slicing or transposing.
            B, T, C = logits.shape
            logits_reshaped = logits.reshape(B * T, C)
            targets_reshaped = targets.reshape(B * T)
            loss = F.cross_entropy(logits_reshaped, targets_reshaped)

        # The returned logits keep their original (B, T, C) shape.
        return logits, loss

    @torch.no_grad()
    def generate(self, starting_indices, max_new_tokens):
        """
        Generate new text, token by token.

        Runs under torch.no_grad(): sampling never needs gradients, so no autograd
        graph is built while generating.

        Args:
            starting_indices: (B, T) tensor with the starting sequence of character
                indices, e.g. a single newline.
            max_new_tokens: the number of new tokens (characters) to generate.

        Returns:
            A (B, T + max_new_tokens) tensor: the starting indices followed by the
            sampled ones.
        """
        current_indices = starting_indices
        for _ in range(max_new_tokens):
            # 1. Get predictions. Calling the module (rather than self.forward) runs
            #    any registered hooks; the loss is None here and is ignored.
            logits, _unused_loss = self(current_indices)

            # 2. Keep only the prediction made at the last time step.
            # (B, T, C) -> (B, C)
            last_step_logits = logits[:, -1, :]

            # 3. Softmax turns raw scores into a probability distribution summing to 1.
            # (B, C)
            probs = F.softmax(last_step_logits, dim=-1)

            # 4. Sample the next character. Sampling (instead of always taking the
            #    most likely character) adds randomness and avoids repetitive output.
            # (B, 1)
            next_index = torch.multinomial(probs, num_samples=1)

            # 5. Append it; the longer sequence is the input to the next iteration.
            # (B, T+1)
            current_indices = torch.cat((current_indices, next_index), dim=1)

        return current_indices

In [25]:
# Create an instance of our model
model = BigramLanguageModel(vocab_size)

# Perform a forward pass with our batch of data.
# Calling the module itself (rather than model.forward) is the PyTorch idiom: it also
# runs any registered hooks around forward().
# For reference: a uniform guess over vocab_size = 65 characters would give a loss of
# -ln(1/65) ≈ 4.17, so the randomly initialized model starts out somewhat worse.
print("\n--- Forward Pass and Loss Calculation ---")
logits, loss = model(xb, yb)
print("Shape of Logits:", logits.shape)
print(
    "Loss (how wrong the model is):", loss.item()
)  # .item() gets the raw number from the tensor


--- Forward Pass and Loss Calculation ---
Shape of Logits: torch.Size([4, 8, 65])
Loss (how wrong the model is): 4.878634929656982


In [26]:
# Let's generate some text!
# We start from a single token: a (1, 1) tensor containing index 0.
# Index 0 is the first character of the sorted vocabulary, which for this corpus is
# the newline character "\n" (see the encoded data above), so generation starts as
# if at the beginning of a new line.
print("\n--- Generating New Text ---")
start_context = torch.zeros((1, 1), dtype=torch.long)
# generate() returns a (1, 1 + 50) tensor; [0] takes the single sequence in the batch.
generated_indices = model.generate(starting_indices=start_context, max_new_tokens=50)[0]
generated_text = decode(generated_indices.tolist())

# The model is still untrained here, so the output is essentially random characters.
print("Generated text:\n", generated_text)


--- Generating New Text ---
Generated text:
 
SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHj


## Train Bigram model on training data

In [33]:
# create a PyTorch optimizer
# Its one and only job is to update the model's internal numbers (parameters) -- in
# our case, the numbers inside the token_embedding_table -- to make the model better.
# AdamW with a learning rate (step size) of 1e-3.

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [38]:
# The training loop. We will repeat the learning process 100 times.
for steps in range(100):
    # --- Step 1: Get a Practice Worksheet ---
    # We get a small, random sample of data (a "batch") from our training set.
    # xb = the practice problems (e.g., "hello worl")
    # yb = the correct answers (e.g., "ello world")
    xb, yb = get_batch("train")

    # --- Step 2: Take the Test (The "Forward Pass") ---
    # The model (student) looks at the problems (xb) and makes its best guess
    # at the answers. The results are the raw scores (logits) and the grade (loss).
    # The 'loss' is a single number telling us how wrong the model was on this batch.
    logits, loss = model(xb, yb)

    # --- Step 3: Prepare for Feedback (Erase Old Notes) ---
    # PyTorch *accumulates* gradients across backward() calls, so the gradients from
    # the previous batch must be cleared before computing new ones; otherwise they
    # would be added on top of each other.
    # `set_to_none=True` sets .grad to None instead of filling it with zeros, which
    # is a small performance optimization.
    optimizer.zero_grad(set_to_none=True)

    # --- Step 4: Figure Out What Went Wrong (The "Backward Pass") ---
    # The grade (loss) is used to calculate how each individual parameter in the
    # model contributed to the error. This process is called backpropagation.
    # It gives the teacher a "report card" (the gradients): for every parameter,
    # which direction to move it, and how strongly, to reduce the loss.
    loss.backward()

    # --- Step 5: Update the Knowledge (Apply Corrections) ---
    # The teacher (optimizer) takes the report card (gradients) from the
    # backward pass and uses it to update the model's parameters.
    # It nudges every number in the model in the direction that lowers the loss,
    # scaled by the learning rate (lr) we set earlier.
    # This is the moment the model *actually learns*.
    optimizer.step()

# After the loop finishes...
# We print the loss from the VERY LAST batch only. It is measured on a single random
# batch of batch_size x block_size characters, so it is noisy; compare it with the
# ~4.88 loss of the untrained model above. 100 steps at lr=1e-3 are far too few to
# converge, so expect only a small improvement.
print(loss.item())

4.716907501220703


In [47]:
print("\n--- Generating New Text (After training model) ---")
# Same start token as before: index 0, the newline character.
start_context = torch.zeros((1, 1), dtype=torch.long)
# Sample 500 new characters; [0] selects the single sequence in the batch.
generated_indices = model.generate(starting_indices=start_context, max_new_tokens=500)[
    0
]
generated_text = decode(generated_indices.tolist())

# After only 100 training steps the text is still mostly gibberish.
print("Generated text:\n", generated_text)


--- Generating New Text (After training model) ---
Generated text:
 
ogM$J:ChFN&Bju3fVb
CUzkzepxEG3SPttRF.yCX$pxBME,NBjKpNSTFviMuBUCpXRr'dxrIcL&ya N:VO-FbPlcGonDq
nN3:CEQJnA:AtepgH?BofY.R:3f?:BuFjPDnlIotRwPi.B-tRju3faevrIk!bHWCJJm,ZRFVOl,stteAUv''V;zL:CERyCYhssC'nFbotv mN3f fHx'O3!eNc'jbHxrBjmn?YCjuYUg:ChVhchOlyGJMnk$qbbI$-xKuCajXa!UXNeNT'XQ!wtNCIgcdZiNXGQzkHEJpEr?E:JGMAUNH!aHQFsmZvjnQ:'PUzPmrFM
jlcir'VJK&D-OKvCUim:GXOBLvya!b
rIXIdZ:CKHd?$gRCONgcfmrX:f?RFEaW:CdDM$GG;RltehsZya!PTS.Db?3fhs!IqNUm;Jgbkzeh.:CjgX?VwbfSF$I''VbHEn$slyKd-msBkLI&e?ay,eEX3Hd,KNMjpmO;aeASiZO


## The mathematical trick in self-attention

Imagine you have a sequence of items (like words in a sentence). \
For each item, you want to create a new representation of it by averaging it with the items that came before it. 
1. The 1st new item is just the 1st old item.
2. The 2nd new item is an average of the 1st and 2nd old items.
3. The 3rd new item is an average of the 1st, 2nd, and 3rd old items.
4. And so on...

### Matrix multiplication for weighted aggregation

In [60]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)  # makes the torch.randint call that builds `b` below reproducible
a = torch.tril(torch.ones(3, 3))  # lower-triangular matrix of ones
print("Printing matrix A:\n", a)

Printing matrix A:
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [63]:
# Normalise every row of `a` to sum to 1: row i now holds equal weights 1/(i+1) on
# columns 0..i, so `a @ b` averages rows of b instead of summing them.
a = a / torch.sum(a, 1, keepdim=True)
print("Printing matrix A:\n", a)

Printing matrix A:
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [62]:
# A random 3x2 matrix with integer entries in [0, 10), cast to float for matrix multiplication.
b = torch.randint(0, 10, (3, 2)).float()
print("Printing matrix B:\n", b)

Printing matrix B:
 tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])


In [61]:
# c = a @ b: row i of c combines rows 0..i of b using the weights in row i of a.
# NOTE(review): the saved output below shows running *sums* ([[2, 7], [8, 11], [14, 16]])
# because this cell ran (In [61]) before the row-normalisation cell (In [63]).
# Run top-to-bottom, it prints running means: [[2, 7], [4, 5.5], [4.67, 5.33]].
c = a @ b
print("Printing matrix C:\n", c)

Printing matrix C:
 tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


### Three equivalent ways to average over previous tokens (the core idea behind self-attention)

In [89]:
# consider the following toy example:

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
# Random input: 4 sequences of 8 tokens, each token a 2-dimensional vector.
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

### Version 1
Using explicit Python loops to average each token with all the tokens before it.

In [97]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t+1, C): tokens 0..t of sequence b
        xbow[b, t] = torch.mean(xprev, 0)  # average over time -> (C,)

xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Version 2
Using matrix multiplication (the weighted-aggregation trick from the toy example above)

In [96]:
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # row t: weight 1/(t+1) on positions 0..t
xbow2 = wei @ x  # (T, T) @ (B, T, C) -> broadcast over B -> (B, T, C)

xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Version 3
Using softmax

In [98]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
# All-zero scores: with no preference, softmax spreads weight evenly over allowed positions.
wei = torch.zeros((T, T))
# Block future positions: exp(-inf) = 0, so they get zero weight after softmax.
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)  # same matrix as the row-normalised one in version 2
xbow3 = wei @ x
xbow3[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

## Self-attention head
Give each word (or "token") three different roles:
1. Query (q): When a token wants to find relevant information, it sends out a "query." This is like you holding up a sign saying, "I am writing about topic A, who has information about this?"
2. Key (k): Each token in the sequence has a "key." This is like the title or index card of a book, saying, "I contain information about topic A."
3. Value (v): This is the actual content or information of the token. It's the answer you get when you find a match. "Here is the information I have on topic A."

In [100]:
torch.manual_seed(1337)
# Larger toy input for the attention head: 32 channels per token instead of 2.
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

In [102]:
# let's see a single Head perform self-attention
head_size = 16  # size of each key / query / value vector
# A linear layer to produce the "Key" for each token.
key = nn.Linear(C, head_size, bias=False)
# A linear layer to produce the "Query" for each token.
query = nn.Linear(C, head_size, bias=False)
# A linear layer to produce the "Value" for each token.
value = nn.Linear(C, head_size, bias=False)

# Now, let's generate the K and Q vectors for every token in our input x
# (the V vectors are computed later, in Step 4).
k = key(x)  # (B, T, 16) - Each of the T tokens now has a Key vector of size 16.
q = query(x)  # (B, T, 16) - Each of the T tokens now has a Query vector of size 16.

### The Calculation: Step-by-Step

#### Step 1: Find Affinity Scores (Dot Product)
Q: Why use a dot product? \
A: If a query vector and a key vector point in a similar direction, their dot product is high, so a high score means "this key matches this query". \
The resulting wei (weights) tensor has shape (B, T, T): a T x T grid of scores for each sequence in the batch. The value at wei[b, i, j] tells us how much attention token i should pay to token j.

In [103]:
# Match every query with every key.
# q shape: (B, T, 16)
# k shape: (B, T, 16) -> we transpose the last two dimensions to (B, 16, T)
# The @ performs a batch matrix multiplication.
# Resulting shape: (B, T, 16) @ (B, 16, T) ---> (B, T, T)
# (These scores are not yet scaled by 1/sqrt(head_size); see Step 5.)
wei = q @ k.transpose(-2, -1)

#### Step 2: Masking (Preventing Cheating)
Q: What is masking? \
A: We can't let a token "see" into the future. Token 3 should only be able to get information from tokens 1, 2, and 3, not from token 4.

In [104]:
# Create a lower-triangular matrix of ones.
tril = torch.tril(torch.ones(T, T))
# Where tril is 0 (the upper triangle, i.e. future positions), replace the values in
# 'wei' with negative infinity. The (T, T) mask broadcasts over the batch dimension.
wei = wei.masked_fill(tril == 0, float("-inf"))

#### Step 3: Softmax (Normalize Scores into Weights)
We use softmax to turn each row of scores into weights that are non-negative and sum to 1. Because exp(-inf) = 0, the masked (future) positions receive exactly zero weight.

In [108]:
# Softmax turns scores into a probability distribution (weights).
# dim=-1 means the softmax is applied across each row, so every row sums to 1.
wei = F.softmax(wei, dim=-1)

#### Step 4: Aggregate the Values
Now that we have the attention weights, we can create new, context-aware token representations. \
We do this by taking a weighted sum of the Value vectors.

In [109]:
# Get the "Value" for each token.
v = value(x)  # (B, T, 16)
# Perform the weighted aggregation: each output row mixes the value vectors of the
# current and earlier tokens according to that row's attention weights.
out = wei @ v  # (B, T, T) @ (B, T, 16) ---> (B, T, 16)

#### Output
The output out[i] for the i-th token is a weighted sum of the value vectors v[j] for j ≤ i, with weights wei[i, j].

The final shape of out is (B, T, 16). Each of the 8 tokens in our sequence now has a new vector of size 16 that is context-aware, built by aggregating information from itself and all the tokens that came before it. This is the output of one attention head.

#### Step 5: Scaled Dot-Product Attention
As head_size gets larger (e.g., from 16 to 64 to 512), the raw dot products grow too: if the components of q and k are roughly independent with unit variance, q · k has variance head_size. \
Consider softmax([1, 2, 3]) ≈ [0.09, 0.24, 0.67]. The outputs are reasonably spread out. \
Now imagine head_size is large and our dot products are bigger. Let's just multiply the inputs by 8: \
softmax([8, 16, 24]) ≈ [1.1e-07, 3.4e-04, 9.997e-01].

Notice what happened? The output became extremely "spiky" or "hard." It's practically [0, 0, 1]. \
The softmax is now super confident that the last element is the only one that matters.

##### Summary:
Scaling the scores by 1/sqrt(head_size) (the `head_size**-0.5` factor below) keeps the variance of the softmax inputs roughly 1, whatever the head_size. This stops softmax from saturating into a near one-hot distribution, whose gradients are close to zero. As a result, the model keeps learning effectively even with a large head_size.

In [None]:
# Scaled dot-product scores. This overwrites `wei` with fresh, unmasked,
# un-normalised scores; the full mask + softmax pipeline is repeated in the next cell.
wei = q @ k.transpose(-2, -1) * head_size**-0.5

### Single-head self-attention: all the code in one cell

In [115]:
# All single-head attention code in one block.
# No seed is set in this cell, so the Linear weights (and therefore `wei` and `out`)
# depend on whatever RNG calls ran before it; re-running this cell alone changes them.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)
# Scaled dot-product scores: (B, T, 16) @ (B, 16, T) ---> (B, T, T), multiplied by
# 1/sqrt(head_size) so the softmax below does not saturate.
wei = q @ k.transpose(-2, -1) * head_size**-0.5

# Causal mask: token t may only attend to tokens 0..t.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)  # each row now sums to 1

v = value(x)  # (B, T, 16)
out = wei @ v  # (B, T, T) @ (B, T, 16) ---> (B, T, 16)
# Sanity check: one head_size-dimensional output vector per input token.
assert out.shape == (*x.shape[:2], head_size)

out.shape

torch.Size([4, 8, 16])

In [116]:
# Attention weights of the first sequence: lower-triangular (no attention to the
# future) and each row sums to 1.
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2853, 0.7147, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2858, 0.3704, 0.3437, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2679, 0.3740, 0.2292, 0.1289, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1926, 0.1797, 0.1312, 0.1444, 0.3521, 0.0000, 0.0000, 0.0000],
        [0.1533, 0.1322, 0.2156, 0.2225, 0.0987, 0.1777, 0.0000, 0.0000],
        [0.0863, 0.1575, 0.1328, 0.1596, 0.1788, 0.1199, 0.1652, 0.0000],
        [0.1044, 0.1764, 0.1101, 0.0950, 0.1406, 0.1058, 0.1436, 0.1241]],
       grad_fn=<SelectBackward0>)

# Linting check & format

In [70]:
# Lint, then auto-format, the project's Python files with ruff.
!uv run ruff check
!uv run ruff format

All checks passed!
5 files left unchanged
