# Understand ChatGPT (and self-attention)

In [18]:
import torch
import torch.nn as nn
from torch.nn import functional as F

## 1. Get some playful data
I will use Tiny Shakespeare, a single text file of concatenated Shakespeare writing (about 1.1 million characters), to train a model.

In [1]:
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

File ‘input.txt’ already there; not retrieving.



In [2]:
#read it to inspect it
with open('input.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

In [4]:
print("Number of characters: " + str(len(text)))
print(text[:500])

Number of characters: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


## 2. Pre-process data
In this exercise, each token is a single character. Uppercase and lowercase letters are distinct tokens, and so are punctuation marks and whitespace.

In [5]:
# Unique characters of the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


We will map each character (listed above) to an integer, using two dictionaries: one from character to integer and one from integer to character.

In [6]:
# Create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
print(stoi)
print(itos)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i',

In [7]:
# Encoder: take a string, output the list of integers assigned to each character.
encode = lambda s: [stoi[c] for c in s]
# Decoder: take a list of integers, output a string.
decode = lambda l: ''.join(itos[i] for i in l)

print(encode("Hello World"))
print(decode(encode("Hello World")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
Hello World
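A quick way to check a character-level tokenizer like this one is a round trip: decoding an encoded string should give back the original string. The sketch below is self-contained and uses a tiny stand-in corpus (`sample_text` is an assumption for illustration, not the real `input.txt`), with a `decode` that joins the characters back into a string.

```python
# Tiny stand-in corpus (hypothetical; the notebook uses input.txt instead).
sample_text = "First Citizen:\nSpeak, speak."

chars = sorted(set(sample_text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[c] for c in s]

def decode(ids):
    """Map a list of integer token ids back to a string."""
    return "".join(itos[i] for i in ids)

ids = encode("Speak")
round_trip = decode(ids)
print(ids, round_trip)

# Characters that never appear in the corpus have no id and cannot be encoded.
try:
    encode("Z")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print("can encode unseen character:", unknown_ok)
```

The second check points at a limitation of this scheme: any character that is not in the training text raises a `KeyError`, so the vocabulary has to be built from text that covers everything you plan to encode.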


## 3. Define the problem as Deep-Learning with PyTorch

In [9]:
# Transform the encoded text into a tensor of integers to be used in Deep Learning
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


**Note:** Keep in mind that every character (letter, punctuation mark or whitespace) is a data point here.

As a good practice, we hold out some of the data for validation, so we can test the model on text it has not seen during training.

In [10]:
# Split the data into training (first 90%) and validation (last 10%)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

**Block:** When a language model is trained, each next-token prediction is conditioned not only on the latest token but on up to n previous tokens. This n is the block size (also called the context length), not the batch size. Within a block of 8 tokens, predicting the 8th token uses the previous 7, predicting the 5th token uses the previous 4, and so on. A chunk of `block_size + 1` characters therefore yields `block_size` training examples, with context lengths from 1 to `block_size`.

In [11]:
block_size = 8  # maximum context length (number of tokens) used for a prediction

In [13]:
# Intuition: we want the model to see a sequence of 1 to 8 integers (block size) and be able to predict the next character.
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target is: {target}")

# Note: the transformer should be comfortable with any context size from 1 up to 8 (block_size).

when input is tensor([18]) the target is: 47
when input is tensor([18, 47]) the target is: 56
when input is tensor([18, 47, 56]) the target is: 57
when input is tensor([18, 47, 56, 57]) the target is: 58
when input is tensor([18, 47, 56, 57, 58]) the target is: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58
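The same context/target pairs are easier to read as characters. The sketch below repeats the slicing logic on a plain Python string instead of a tensor (the short `text` value is an assumption standing in for the start of the dataset), so each line shows which characters the model sees and which one it has to predict.

```python
text = "First Citizen"   # stand-in for the beginning of the dataset
block_size = 8

x = text[:block_size]        # inputs:  "First Ci"
y = text[1:block_size + 1]   # targets: "irst Cit" (inputs shifted by one)

# One (context, target) pair per position in the block.
pairs = [(x[:t + 1], y[t]) for t in range(block_size)]
for context, target in pairs:
    print(f"when input is {context!r} the target is {target!r}")
```

The first pair is `('F', 'i')` and the last is `('First Ci', 't')`, which corresponds to the integer sequences printed above (18 is `F`, 47 is `i`, 58 is `t`).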


**Batch:** The batch size is the number of blocks processed at the same time. GPUs are good at parallel work, so we stack several blocks into one tensor and process them together. The sequences in a batch are independent: they do not communicate with each other.

In [14]:
# Process multiple sequences at the same time to make use of the GPU
torch.manual_seed(1337)
batch_size = 4 # number of independent sequences we will process in parallel
block_size = 8 # maximum number of characters in one block (context length)

In [15]:
def get_batch(split):
    """
    Generate a batch of input and target sequences for training or validation.

    This function randomly samples `batch_size` starting positions from either 
    the training or validation dataset and returns two tensors:
      - x (inputs): sequences of length `block_size`
      - y (targets): the same sequences shifted one position to the right

    Parameters
    ----------
    split : str
        The dataset to draw from. Must be either:
          - "train" : use the global variable `train_data`
          - "val"   : use the global variable `val_data`

    Returns
    -------
    x : torch.Tensor, shape (batch_size, block_size), dtype=torch.long
        The input sequences, each of length `block_size`. 
        Each row corresponds to one sequence sampled from the dataset.

    y : torch.Tensor, shape (batch_size, block_size), dtype=torch.long
        The target sequences, each also of length `block_size`.
        For each row, `y` is the same as `x` but shifted by one character/token.
        (If x is [c1, c2, c3, ...], then y is [c2, c3, c4, ...])

    Notes
    -----
    - `ix` contains the random starting indices for each sequence.
    - The slicing ensures that `y` is the "next-token prediction target" 
      for `x`, which is the standard setup for language modeling.
    - Both x and y are stacked into tensors so they can be processed in 
      parallel by the model (GPU-friendly).
    """
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

We are randomly selecting `batch_size` blocks of `block_size` characters each from the original text (already encoded as integers).

In [17]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 8])
tensor([[57, 43, 60, 43, 52,  1, 63, 43],
        [60, 43, 42,  8,  0, 25, 63,  1],
        [56, 42,  5, 57,  1, 57, 39, 49],
        [43, 57, 58, 63,  6,  1, 58, 46]])
targets:
torch.Size([4, 8])
tensor([[43, 60, 43, 52,  1, 63, 43, 39],
        [43, 42,  8,  0, 25, 63,  1, 45],
        [42,  5, 57,  1, 57, 39, 49, 43],
        [57, 58, 63,  6,  1, 58, 46, 47]])
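One way to convince yourself that the targets really are the inputs shifted by one position is to run the same sampling logic on data where the answer is obvious. The sketch below is an illustration under assumptions: `toy_data` and `get_toy_batch` are hypothetical names, and the toy data is simply the integers 0 to 99, so the token after `i` is always `i + 1`.

```python
import torch

torch.manual_seed(0)
toy_data = torch.arange(100)      # token at position i has id i
batch_size, block_size = 4, 8

def get_toy_batch(data):
    """Same sampling logic as get_batch, applied to an explicit tensor."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

x, y = get_toy_batch(toy_data)

# Every target is the token that follows the corresponding input token.
shift_ok = bool(torch.equal(y, x + 1))
# Columns 1..T-1 of x are the same as columns 0..T-2 of y.
overlap_ok = bool(torch.equal(x[:, 1:], y[:, :-1]))
print(x.shape, y.shape, shift_ok, overlap_ok)
```

The same overlap is visible in the real batch above: each row of `targets` starts with the second element of the matching row of `inputs`.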


## 4. Bigram Language Model
What if we create a language model that predicts the next token using only the current token? This is the simplest possible model, which makes it a good starting point for learning.

**B (Batch size)** = number of independent sequences processed in parallel.

**T (Time / Context length)** = max number of tokens per sequence.

**C (Channels / Classes)** = vocabulary size.

In [26]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """
    BigramLanguageModel

    This simple model learns to predict the next token using *only the current token*
    (no context beyond one step = "bigram").

    Tensor dimensions:
        B = batch size
            → number of independent sequences processed in parallel.

        T = time (context length)
            → maximum number of tokens per sequence.
              Each row has T tokens, so total tokens in a batch = B * T.
        C = vocab size
            → number of output classes (the logits for each next token).

    Flow:
        idx: (B, T) integers [0, vocab_size-1]
        targets: (B, T)
        logits: (B, T, C) → reshaped to (B*T, C)
        loss: scalar cross-entropy over all B*T predictions
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Embedding table: each input token maps to a row of length vocab_size.
        # Shape: (vocab_size, vocab_size)
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor = None):
        """
        Forward pass.

        Parameters
        ----------
        inputs : torch.Tensor
            Shape (B, T) with token ids.
        targets : torch.Tensor, optional
            Shape (B, T) with ground-truth next token ids.
            If None, no loss is computed.

        Returns
        -------
        logits : torch.Tensor
            Shape (B*T, C) when targets are given, otherwise (B, T, C).
        loss : torch.Tensor or None
            Scalar mean cross-entropy loss, or None when targets is None.
        """
        # (B, T, C): each token index is mapped to a vocab-sized vector
        logits = self.token_embedding_table(inputs)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # Flatten batch and time so CE sees (N, C) with N=B*T
        logits = logits.view(B * T, C)
        flat_targets = targets.view(B * T)

        loss = F.cross_entropy(logits, flat_targets)
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """
        Autoregressive text generation.

        Starting from an initial sequence of token indices, repeatedly sample
        the next token using the model’s bigram predictions and append it.

        Parameters
        ----------
        idx : torch.Tensor
            Shape (B, T) with initial token ids (the "prompt").
        max_new_tokens : int
            Number of tokens to generate beyond the initial prompt.

        Returns
        -------
        idx : torch.Tensor
            Shape (B, T + max_new_tokens). The original prompt followed by
            the newly generated tokens.
        """
        for _ in range(max_new_tokens):
            # Forward pass: (B, T, C) logits for all positions
            logits_3d = self.token_embedding_table(idx)

            # Focus only on the last time step
            # Shape: (B, C)
            logits_last = logits_3d[:, -1, :]

            # Convert logits to probabilities via softmax
            probs = F.softmax(logits_last, dim=-1)  # (B, C)

            # Sample next token from the probability distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append new token to the sequence along the time dimension
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)

# Perplexity = exp(loss)
perplexity = torch.exp(loss)

print("Logits shape:", logits.shape)     # (B*T, C): (32, 65) for a (4, 8) batch, (256, 65) for a (32, 8) batch
print("Loss:", loss.item())              # scalar
print("Perplexity:", perplexity.item())  # scalar

Logits shape: torch.Size([256, 65])
Loss: 4.756565093994141
Perplexity: 116.3456039428711


In [27]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(''.join(decode(m.generate(idx, max_new_tokens=100)[0].tolist())))


Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


The text above was sampled from the untrained model, so it is not human-readable: it is just a sequence of random characters. The optimizer's job is to reduce the loss, and with it the perplexity.

**Intuition**

When we say the vocabulary size is 65, we mean there are 65 possible tokens that the model can generate at each step. The logits tensor has shape (B, T, C), where B is the batch size, T is the number of tokens per sequence and C = 65 is the vocabulary size; for the loss it is flattened to (B*T, C). Each entry is a score (before applying softmax) for how likely the model thinks one of the 65 vocabulary tokens is to be the next token. In other words, for every position in the sequence the model produces 65 scores, which softmax turns into a probability distribution over the entire vocabulary.
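To make the step from scores to probabilities concrete, here is a minimal plain-Python sketch of softmax over 65 logits for a single position. The logit values are made up for illustration; only the vocabulary size comes from the notebook.

```python
import math

vocab_size = 65

def softmax(logits):
    """Turn a list of scores into probabilities that sum to 1."""
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Case 1: all logits equal, so every token gets probability 1/65.
uniform_probs = softmax([0.0] * vocab_size)
uniform_loss = -math.log(uniform_probs[0])   # cross-entropy of a uniform guess
print(uniform_probs[0], uniform_loss)

# Case 2: one token (index 43 here, chosen arbitrarily) gets a higher score.
logits = [0.0] * vocab_size
logits[43] = 3.0
probs = softmax(logits)
print(probs[43], probs[0], sum(probs))
```

The uniform case gives a loss of ln(65), about 4.17. The untrained model above reports a loss of about 4.76, somewhat worse than uniform guessing, presumably because its randomly initialized logits are not all equal.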

**Perplexity** is a way to measure how well a language model predicts text.

If the model is very confident and usually correct, the perplexity is low.

If the model is often unsure or wrong, the perplexity is high.

You can think of perplexity as “How surprised is the model when it sees the real next token?”

If perplexity = 1 → the model is perfect (always predicts the right token with probability 1).

If perplexity = vocabulary size (e.g., 65) → the model does no better than guessing uniformly at random.

The perplexity is defined as $\text{Perplexity} = \exp\!\left(-\frac{1}{N} \sum_{i=1}^N \log P(w_i \mid w_{<i})\right)$, which is the exponential of the mean cross-entropy loss.

Here $P(w_i \mid w_{<i})$ is the probability the model assigns to the correct token $w_i$ given the preceding tokens, and $N$ is the number of predictions.
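The definition can be checked by hand on a few numbers. The sketch below computes perplexity directly from the probabilities a model assigned to the correct tokens; the helper name `perplexity` and the probability lists are illustrative assumptions, not values from the trained model.

```python
import math

def perplexity(probs_of_correct_token):
    """exp of the average negative log-probability of the correct tokens."""
    n = len(probs_of_correct_token)
    avg_nll = -sum(math.log(p) for p in probs_of_correct_token) / n
    return math.exp(avg_nll)

perfect = perplexity([1.0, 1.0, 1.0])      # always certain and right
uniform = perplexity([1 / 65] * 10)        # uniform guess over 65 tokens
mixed = perplexity([0.5, 0.25, 0.125])     # geometric mean of 2, 4 and 8

print(perfect, uniform, mixed)
```

The three cases give 1, 65 and 4 (up to floating-point rounding), matching the two reference points described above. The last case shows that perplexity is the geometric mean of the inverse probabilities.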

### What if we now optimize it?

In [22]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

By training with the AdamW optimizer, our goal is to get a lower perplexity.

In [24]:
batch_size = 32
for steps in range(100000):

  # Sample a batch of data
  xb, yb = get_batch('train')

  # evaluate the loss
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

print(loss.item())

# Perplexity = exp(loss)
perplexity = torch.exp(loss)

print("Loss:", loss.item())              # scalar
print("Perplexity:", perplexity.item())  # scalar


2.4275641441345215
Loss: 2.4275641441345215
Perplexity: 11.331247329711914


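Perplexity dropped from about 116 to about 11.3, roughly 10% of its initial value (both measured on a single training batch), which shows what the optimizer can do even for this simple model.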
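As a quick arithmetic check, the two perplexities can be recomputed from the losses printed in the cells above. This is only a sanity check in plain Python; the variable names are illustrative.

```python
import math

initial_loss = 4.756565093994141    # loss of the untrained model (printed above)
final_loss = 2.4275641441345215     # loss after training (printed above)

initial_ppl = math.exp(initial_loss)
final_ppl = math.exp(final_loss)
ratio = final_ppl / initial_ppl     # same as exp(final_loss - initial_loss)

print(f"initial perplexity: {initial_ppl:.2f}")
print(f"final perplexity:   {final_ppl:.2f}")
print(f"ratio:              {ratio:.3f}")
```

Because perplexity is the exponential of the loss, a drop of about 2.33 in the loss corresponds to a perplexity ratio of about exp(-2.33), which is a little under 0.10.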

Below is what the trained model generates when we sample 100 new characters from it.

In [25]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(''.join(decode(m.generate(idx, max_new_tokens=100)[0].tolist())))


Ar bungin u,
Myo fe.
Heabeayoisurbet beld, bybe s; bonsth, t hid Y:
Plooweaureiss fou, Fouffltyorbra


## 5. A Simple Mathematical Version of the Self-Attention Mechanism

In [29]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time (tokens), channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

x is a tensor of shape (B, T, C): B sequences, T tokens per sequence and C channels (features) per token. The values are random here and stand in for whatever information each token carries.

**Intuition:** The simplest way to let a token use its context is to average the feature vectors of that token and all the previous tokens within the same block.

In [34]:
# We want xbow[b, t] = mean_{i<=t} x[b, i]
# The simplest form of "attention": each position takes the average of its own
# channels and those of all the preceding positions.
xbow = torch.zeros(B, T, C, device=x.device, dtype=x.dtype)
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t+1].mean(dim=0)

In [31]:
# To avoid the explicit for loops, use matrix multiplication
wei = torch.tril(torch.ones(T, T, device=x.device, dtype=x.dtype))
wei = wei / wei.sum(dim=1, keepdim=True)   # row-normalize
xbow2 = wei @ x
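The loop and the matrix multiplication should produce the same averages, and so should a third variant that builds the weights with a mask and a softmax, which is the form used in the attention code. The sketch below compares the three; it is self-contained, and the names `xbow3`, `scores` and `wei3` are introduced here for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Version 1: explicit loops over batch and time.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

# Version 2: row-normalized lower-triangular matrix.
tril = torch.tril(torch.ones(T, T))
wei = tril / tril.sum(dim=1, keepdim=True)
xbow2 = wei @ x                      # (T, T) @ (B, T, C) broadcasts to (B, T, C)

# Version 3: mask future positions with -inf, then softmax each row.
scores = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei3 = F.softmax(scores, dim=-1)
xbow3 = wei3 @ x

same_12 = torch.allclose(xbow, xbow2, atol=1e-5)
same_13 = torch.allclose(xbow, xbow3, atol=1e-5)
print(same_12, same_13)
```

A small absolute tolerance is passed to `torch.allclose` because the three versions add the numbers in a different order, which can change the last floating-point digits. In version 3 all unmasked scores are zero, so softmax gives uniform weights; self-attention replaces those zeros with data-dependent scores.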

In [None]:
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

head_size = 16
key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)                  # (B, T, H)
q = query(x)                # (B, T, H)
v = value(x)                # (B, T, H)

# scaled dot-product attention scores
wei = q @ k.transpose(-2, -1) / (head_size ** 0.5)  # (B, T, T)

# causal mask (lower triangular)
tril = torch.tril(torch.ones(T, T, device=wei.device)).bool()
wei = wei.masked_fill(~tril, float('-inf'))

# normalize to probabilities
wei = F.softmax(wei, dim=-1)  # (B, T, T)

# weighted sum of values
out = wei @ v  # (B, T, H)

out.shape  # -> torch.Size([4, 8, 16])


torch.Size([4, 8, 16])

**Explanation:**

**Self-attention** is a way for each word (or token) in a sequence to decide which other words matter to it when forming its meaning. In the code, every word is first turned into three vectors: **a query** (what it’s looking for), **a key** (what it contains), and **a value** (the information it can share).

1. Each query is compared with all keys using dot products, scaled by the square root of the head size, to produce attention scores that measure relevance.
2. The scores are masked so that a word cannot look at words that come after it (the future).
3. The masked scores pass through a softmax to become probabilities. This creates an “attention distribution” that tells a word how much to focus on each earlier word, including itself.
4. Finally, each word gathers the values of the words it can see, weighted by these attention probabilities, to create a new, context-aware representation of itself.

In other words, instead of treating words in isolation, self-attention lets them dynamically borrow information from others:

*Classic example:* “bank” in “river bank” pays attention to “river,” while “bank” in “money bank” focuses on “money.”

The output of the code is a matrix where every word has been enriched with the context it deems most important.
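The masking step can be tested directly: if attention is causal, changing the last token must leave the outputs at all earlier positions untouched. The sketch below wraps the same computation as the cell above in a small function and checks this. The function name `attention_head` and the small sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, head_size = 1, 5, 8, 4

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def attention_head(x):
    """Single causal self-attention head; returns (output, attention weights)."""
    T = x.shape[1]
    k, q, v = key(x), query(x), value(x)                  # each (B, T, head_size)
    wei = q @ k.transpose(-2, -1) / head_size ** 0.5      # (B, T, T) scaled scores
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    wei = wei.masked_fill(~mask, float("-inf"))           # hide future positions
    wei = F.softmax(wei, dim=-1)                          # rows become probabilities
    return wei @ v, wei

with torch.no_grad():
    x = torch.randn(B, T, C)
    out, wei = attention_head(x)

    # Replace only the last token and run the head again.
    x_changed = x.clone()
    x_changed[:, -1] = torch.randn(C)
    out_changed, _ = attention_head(x_changed)

rows_sum_to_one = torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
no_future = bool((wei[0].triu(diagonal=1) == 0).all())
past_unchanged = torch.allclose(out[:, :-1], out_changed[:, :-1])
last_changed = not torch.allclose(out[:, -1], out_changed[:, -1])
print(rows_sum_to_one, no_future, past_unchanged, last_changed)
```

Each row of the weight matrix sums to 1 and has zeros above the diagonal, so position `t` only mixes values from positions `0..t`. This is what makes the head usable for next-token prediction: no position can peek at the token it is supposed to predict.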