# CS336 Assignments

| # | Topic                         | Description                                 |
|---|-------------------------------|---------------------------------------------|
| 1 | Basics                        | Train an LLM from scratch                   |
| 2 | Systems                       | Make it run fast!                           |
| 3 | Scaling                       | Make it performant at a FLOP budget         |
| 4 | Data                          | Prepare the right datasets                  |
| 5 | Alignment & Reasoning RL      | Align it to real-world use cases            |

# Assignment #1
- Implement all of the components (tokenizer, model, loss function, optimizer) necessary to train a standard Transformer language model
- Train a minimal language model

In [5]:
import warnings
warnings.filterwarnings("ignore")

import torch
import lovely_tensors as lt
lt.monkey_patch()

import tiktoken

from datasets import load_dataset

In [8]:
data_hf = load_dataset("roneneldan/TinyStories")
data_hf

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [10]:
sample_sentence = "Your journey starts with one step"
sample_sentence

emb_matrix = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]    
])

In [11]:
ttok = tiktoken.get_encoding("gpt2")
ttok.encode(sample_sentence)

[7120, 7002, 4940, 351, 530, 2239]

## Code a simplified self-attention model without trainable weights

In [14]:
attn_scores = emb_matrix @ emb_matrix.T
attn_scores.v

tensor[6, 6] n=36 x∈[0.294, 1.495] μ=0.806 σ=0.333
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [38]:
attn_weights = attn_scores / attn_scores.sum(dim=1, keepdim=True)
attn_weights.v

tensor[6, 6] n=36 x∈[0.063, 0.233] μ=0.167 σ=0.052
tensor([[0.2241, 0.2140, 0.2113, 0.1066, 0.1026, 0.1415],
        [0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656],
        [0.1454, 0.2277, 0.2248, 0.1280, 0.1104, 0.1637],
        [0.1304, 0.2313, 0.2275, 0.1354, 0.0953, 0.1801],
        [0.1436, 0.2219, 0.2245, 0.1090, 0.2088, 0.0921],
        [0.1350, 0.2325, 0.2269, 0.1405, 0.0628, 0.2022]])

In [39]:
attn_weights = torch.softmax(attn_weights, dim=0)
attn_weights.v

tensor[6, 6] n=36 x∈[0.156, 0.183] μ=0.167 σ=0.005
tensor([[0.1787, 0.1647, 0.1647, 0.1637, 0.1645, 0.1639],
        [0.1652, 0.1670, 0.1669, 0.1673, 0.1654, 0.1679],
        [0.1652, 0.1670, 0.1669, 0.1672, 0.1658, 0.1676],
        [0.1627, 0.1676, 0.1674, 0.1685, 0.1633, 0.1704],
        [0.1649, 0.1660, 0.1669, 0.1641, 0.1829, 0.1560],
        [0.1635, 0.1678, 0.1673, 0.1693, 0.1581, 0.1742]])

In [43]:
context_vec = attn_weights @ emb_matrix
context_vec.v

tensor[6, 3] n=18 x∈[0.426, 0.590] μ=0.514 σ=0.065
tensor([[0.4321, 0.5772, 0.5337],
        [0.4305, 0.5846, 0.5281],
        [0.4308, 0.5844, 0.5279],
        [0.4288, 0.5873, 0.5281],
        [0.4421, 0.5767, 0.5213],
        [0.4256, 0.5897, 0.5307]])

## Code a simplified self-attention model with trainable weights

This self-attention is also called as scaled dot product attention.

Instead of directly computing `attn_scores` from the embedding matrix, we want to do a weighted multiplication with input vectors. This helps in identifying "good context vectors".

In [45]:
emb_matrix.v

tensor[6, 3] n=18 x∈[0.050, 0.890] μ=0.514 σ=0.276
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [74]:
torch.manual_seed(42)

W_query = torch.randn((3, 2))
W_key = torch.randn((3, 2))
W_value = torch.randn((3, 2))

W_query.v, W_key.v, W_value.v


(tensor[3, 2] n=6 x∈[-1.123, 0.337] μ=-0.063 σ=0.549 [[0.337, 0.129], [0.234, 0.230], [-1.123, -0.186]]
 tensor([[ 0.3367,  0.1288],
         [ 0.2345,  0.2303],
         [-1.1229, -0.1863]]),
 tensor[3, 2] n=6 x∈[-0.638, 2.208] μ=0.607 σ=0.927 [[2.208, -0.638], [0.462, 0.267], [0.535, 0.809]]
 tensor([[ 2.2082, -0.6380],
         [ 0.4617,  0.2674],
         [ 0.5349,  0.8094]]),
 tensor[3, 2] n=6 x∈[-1.690, 1.322] μ=0.255 σ=1.266 [[1.110, -1.690], [-0.989, 0.958], [1.322, 0.817]]
 tensor([[ 1.1103, -1.6898],
         [-0.9890,  0.9580],
         [ 1.3221,  0.8172]]))

In [78]:
query = emb_matrix @ W_query
key = emb_matrix @ W_key
value = emb_matrix @ W_value

query.v, key.v, value.v

(tensor[6, 2] n=12 x∈[-0.819, 0.206] μ=-0.110 σ=0.314
 tensor([[-0.8194, -0.0759],
         [-0.3519,  0.1483],
         [-0.3274,  0.1500],
         [-0.1605,  0.1004],
         [ 0.2056,  0.1381],
         [-0.4132,  0.0882]]),
 tensor[6, 2] n=12 x∈[-0.343, 1.993] μ=0.907 σ=0.758
 tensor([[ 1.4948,  0.4861],
         [ 1.9692,  0.4159],
         [ 1.9934,  0.3816],
         [ 0.9301,  0.2818],
         [ 1.8692, -0.3435],
         [ 0.7739,  0.6271]]),
 tensor[6, 2] n=12 x∈[-0.980, 1.506] μ=0.431 σ=0.618
 tensor([[ 1.5058,  0.1444],
         [ 0.6229,  0.4434],
         [ 0.6384,  0.3741],
         [ 0.1070,  0.4535],
         [ 0.7399, -0.9799],
         [-0.0085,  1.1313]]))

In [80]:
attn_scores = query @ key.T
attn_weights = attn_scores / attn_scores.sum(dim=0, keepdim=True)
attn_weights = torch.softmax(attn_weights, dim=1)
attn_weights.v

tensor[6, 6] n=36 x∈[0.153, 0.188] μ=0.167 σ=0.005
tensor([[0.1670, 0.1633, 0.1626, 0.1663, 0.1526, 0.1882],
        [0.1666, 0.1671, 0.1672, 0.1667, 0.1687, 0.1638],
        [0.1665, 0.1672, 0.1673, 0.1667, 0.1691, 0.1632],
        [0.1665, 0.1672, 0.1673, 0.1667, 0.1691, 0.1632],
        [0.1661, 0.1685, 0.1689, 0.1666, 0.1758, 0.1541],
        [0.1667, 0.1663, 0.1663, 0.1667, 0.1652, 0.1688]])

In [81]:
context_vec = attn_weights @ value
context_vec.v

tensor[6, 2] n=12 x∈[0.239, 0.609] μ=0.431 σ=0.178
tensor([[0.5861, 0.2962],
        [0.6029, 0.2562],
        [0.6033, 0.2553],
        [0.6033, 0.2553],
        [0.6095, 0.2395],
        [0.5994, 0.2648]])

### Pytorchification of SelfAttention class

In [84]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out))
        self.W_key   = torch.nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        # x shape: (seq_len, emb_sz)
        query = x @ self.W_query # query shape: (seq_len, d_out)
        key   = x @ self.W_key # key shape: (seq_len, d_out)
        value = x @ self.W_value # value shape: (seq_len, d_out)

        attn_scores = query @ key.T # attn_scores shape: (seq_len, seq_len)
        attn_weights = attn_scores / attn_scores.sum(dim=0, keepdim=True)
        attn_weights = torch.softmax(attn_weights, dim=1)
        context_vec = attn_weights @ value
        return context_vec
    
torch.manual_seed(42)
model = SelfAttention(3, 2)
out = model(emb_matrix)
out.v

tensor[6, 2] n=12 x∈[0.239, 0.609] μ=0.431 σ=0.178 grad MmBackward0
tensor([[0.5861, 0.2962],
        [0.6029, 0.2562],
        [0.6033, 0.2553],
        [0.6033, 0.2553],
        [0.6095, 0.2395],
        [0.5994, 0.2648]], grad_fn=<MmBackward0>)

There have been numerous studies that suggest to do a normalization with d_out ** 0.5 prior to softmax. Let's do that now.

In [None]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out))
        self.W_key   = torch.nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        # x shape: (seq_len, emb_sz)
        query = x @ self.W_query # query shape: (seq_len, d_out)
        key   = x @ self.W_key # key shape: (seq_len, d_out)
        value = x @ self.W_value # value shape: (seq_len, d_out)

        attn_scores = query @ key.T # attn_scores shape: (seq_len, seq_len)
        attn_weights = attn_scores / query.shape[1] ** 0.5 # normalizing with sqrt(d_out)
        attn_weights = torch.softmax(attn_weights, dim=1)
        context_vec = attn_weights @ value
        return context_vec
    
torch.manual_seed(42)
model = SelfAttention(3, 2)
out = model(emb_matrix)
out.v

tensor[6, 2] n=12 x∈[0.254, 0.618] μ=0.441 σ=0.136 grad MmBackward0
tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]], grad_fn=<MmBackward0>)

## Code a self-attention model with causal attention

For next token prediction, we attend to words prior to the current word and hence we cannot access future words. This means, our attention cannot be extended to the complete sequence but just the words prior. This is called causal attention. Let's go ahead and mask the attention with future tokens.

In [130]:
torch.manual_seed(42)

W_query = torch.randn((3, 2))
W_key = torch.randn((3, 2))
W_value = torch.randn((3, 2))

W_query.v, W_key.v, W_value.v


(tensor[3, 2] n=6 x∈[-1.123, 0.337] μ=-0.063 σ=0.549 [[0.337, 0.129], [0.234, 0.230], [-1.123, -0.186]]
 tensor([[ 0.3367,  0.1288],
         [ 0.2345,  0.2303],
         [-1.1229, -0.1863]]),
 tensor[3, 2] n=6 x∈[-0.638, 2.208] μ=0.607 σ=0.927 [[2.208, -0.638], [0.462, 0.267], [0.535, 0.809]]
 tensor([[ 2.2082, -0.6380],
         [ 0.4617,  0.2674],
         [ 0.5349,  0.8094]]),
 tensor[3, 2] n=6 x∈[-1.690, 1.322] μ=0.255 σ=1.266 [[1.110, -1.690], [-0.989, 0.958], [1.322, 0.817]]
 tensor([[ 1.1103, -1.6898],
         [-0.9890,  0.9580],
         [ 1.3221,  0.8172]]))

In [131]:
query = emb_matrix @ W_query
key = emb_matrix @ W_key
value = emb_matrix @ W_value


In [132]:
attn_scores = query @ key.T
attn_scores.v

causal_mask = torch.triu(torch.ones(6, 6), diagonal=1)
attn_scores = attn_scores.masked_fill(causal_mask.bool(), -torch.inf)
attn_scores.v


attn_weights = attn_scores / query.shape[1] ** 0.5
# attn_scores.v, attn_weights.v

attn_weights = torch.softmax(attn_weights, dim=1)
attn_scores.v, attn_weights.v

(tensor[6, 6] n=36 x∈[-1.262, 0.463] μ=-0.310 σ=0.468 [31m-Inf![0m
 tensor([[-1.2618,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.4540, -0.6313,    -inf,    -inf,    -inf,    -inf],
         [-0.4166, -0.5824, -0.5955,    -inf,    -inf,    -inf],
         [-0.1911, -0.2742, -0.2816, -0.1210,    -inf,    -inf],
         [ 0.3745,  0.4623,  0.4625,  0.2301,  0.3368,    -inf],
         [-0.5747, -0.7769, -0.7900, -0.3594, -0.8026, -0.2644]]),
 tensor[6, 6] n=36 x∈[0., 1.000] μ=0.167 σ=0.204
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5313, 0.4687, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3609, 0.3210, 0.3181, 0.0000, 0.0000, 0.0000],
         [0.2543, 0.2398, 0.2386, 0.2673, 0.0000, 0.0000],
         [0.1998, 0.2126, 0.2126, 0.1804, 0.1946, 0.0000],
         [0.1670, 0.1448, 0.1435, 0.1945, 0.1422, 0.2080]]))

In [133]:
context_vec = attn_weights @ value
context_vec.v

tensor[6, 2] n=12 x∈[0.094, 1.506] μ=0.589 σ=0.426
tensor([[1.5058, 0.1444],
        [1.0920, 0.2845],
        [0.9465, 0.3134],
        [0.7133, 0.3535],
        [0.7323, 0.0938],
        [0.5575, 0.3262]])

Let's combine this into a contained class!

In [137]:
class CausalAttention_v1(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_key  = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)

    def forward(self, x):
        # x shape: (seq_len, emb_sz)

        seq_len, emb_sz = x.shape
        query = x @ self.W_query
        key   = x @ self.W_key
        value = x @ self.W_value

        attn_scores = query @ key.T
        
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        attn_scores = attn_scores.masked_fill(causal_mask.bool(), -torch.inf)

        attn_scores = attn_scores / emb_sz ** 0.5
        attn_weights = torch.softmax(attn_scores, dim=1)

        context_vec = attn_weights @ value
        return context_vec

torch.manual_seed(42)
model = CausalAttention_v1(3, 2)
context_vec = model(emb_matrix)
context_vec.v

tensor[6, 2] n=12 x∈[0.093, 1.506] μ=0.588 σ=0.425 grad MmBackward0
tensor([[1.5058, 0.1444],
        [1.0869, 0.2862],
        [0.9420, 0.3148],
        [0.7143, 0.3536],
        [0.7306, 0.0925],
        [0.5660, 0.3140]], grad_fn=<MmBackward0>)

It's a common practice to add a dropout layer right after computing the attention weights.

In [None]:
class CausalAttention_v2(torch.nn.Module):
    def __init__(self, d_in, d_out, drop_rate=0.5):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_key  = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.dropout = torch.nn.Dropout(drop_rate)

    def forward(self, x):
        # x shape: (seq_len, emb_sz)

        seq_len, emb_sz = x.shape
        query = x @ self.W_query
        key   = x @ self.W_key
        value = x @ self.W_value

        attn_scores = query @ key.T
        
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        attn_scores = attn_scores.masked_fill(causal_mask.bool(), -torch.inf)

        attn_scores = attn_scores / emb_sz ** 0.5
        attn_weights = torch.softmax(attn_scores, dim=1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ value
        return context_vec

torch.manual_seed(42)
model = CausalAttention_v2(3, 2)
context_vec = model(emb_matrix)
context_vec.v

tensor[6, 2] n=12 x∈[-0.139, 2.174] μ=0.643 σ=0.758 grad MmBackward0
tensor([[ 0.0000,  0.0000],
        [ 2.1738,  0.5725],
        [ 1.8840,  0.6295],
        [ 0.6649,  0.6339],
        [ 1.1534, -0.1391],
        [ 0.2575, -0.1156]], grad_fn=<MmBackward0>)

As there will be many data points as an input to the model, let's grab them into a batch and pass them to the model.

In [143]:
batch = torch.stack([emb_matrix, emb_matrix], dim=0)
batch.v

tensor[2, 6, 3] n=36 x∈[0.050, 0.890] μ=0.514 σ=0.272
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [189]:
# modified the softmax to consider the last dimension to normalize.

class CausalAttention_v3(torch.nn.Module):
    def __init__(self, d_in, d_out, drop_rate=0.5):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_key  = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.dropout = torch.nn.Dropout(drop_rate)

    def forward(self, x):
        # x shape: (batch_size, seq_len, emb_sz)

        batch_sz, seq_len, emb_sz = x.shape
        query = x @ self.W_query # query shape: (batch_size, seq_len, d_out)
        key   = x @ self.W_key # key shape: (batch_size, seq_len, d_out)
        value = x @ self.W_value # value shape: (batch_size, seq_len, d_out)

        attn_scores = query @ key.transpose(1, 2) # (transposing 1st and 2nd dimension for matmul)
        
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) 
        attn_scores = attn_scores.masked_fill(causal_mask.bool(), -torch.inf)

        attn_scores = attn_scores / emb_sz ** 0.5
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ value
        return context_vec

torch.manual_seed(42)
model = CausalAttention_v3(3, 2)
context_vec = model(batch)
context_vec.v

tensor[2, 6, 2] n=24 x∈[-0.139, 3.012] μ=0.674 σ=0.797 grad UnsafeViewBackward0
tensor([[[ 0.0000,  0.0000],
         [ 2.1738,  0.5725],
         [ 1.8840,  0.6295],
         [ 0.6649,  0.6339],
         [ 1.1534, -0.1391],
         [ 0.2575, -0.1156]],

        [[ 3.0116,  0.2888],
         [ 1.5828,  0.1518],
         [ 0.0000,  0.0000],
         [ 1.1213,  0.5271],
         [ 0.8639,  0.2442],
         [ 0.4022,  0.2759]]], grad_fn=<UnsafeViewBackward0>)

We see that the output is the right size (2, 6, 2) which is (batch_size, seq_len, d_out).

However, for the curious, the 1st element output is different from 2nd element even though we passed the exact same input via `torch.stack([emb_matrix, emb_matrix])`. Why is that the case?

The answer is Dropout! As dropout layer is randomly cutting off a few weights, we see the outputs changing. If we turn off the dropout layer, we see that the output is same across the batch dimension.

In [191]:
# modified the softmax to consider the last dimension to normalize.

class CausalAttention_v4(torch.nn.Module):
    def __init__(self, d_in, d_out, drop_rate=0.5):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_key  = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=True)
        self.dropout = torch.nn.Dropout(drop_rate)

    def forward(self, x):
        # x shape: (batch_size, seq_len, emb_sz)

        batch_sz, seq_len, emb_sz = x.shape
        query = x @ self.W_query # query shape: (batch_size, seq_len, d_out)
        key   = x @ self.W_key # key shape: (batch_size, seq_len, d_out)
        value = x @ self.W_value # value shape: (batch_size, seq_len, d_out)

        attn_scores = query @ key.transpose(1, 2) # (transposing 1st and 2nd dimension for matmul)
        
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) 
        attn_scores = attn_scores.masked_fill(causal_mask.bool(), -torch.inf)

        attn_scores = attn_scores / emb_sz ** 0.5
        attn_weights = torch.softmax(attn_scores, dim=-1)
        # attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ value
        return context_vec

torch.manual_seed(42)
model = CausalAttention_v4(3, 2)
context_vec = model(batch)
context_vec.v

tensor[2, 6, 2] n=24 x∈[0.093, 1.506] μ=0.588 σ=0.416 grad UnsafeViewBackward0
tensor([[[1.5058, 0.1444],
         [1.0869, 0.2862],
         [0.9420, 0.3148],
         [0.7143, 0.3536],
         [0.7306, 0.0925],
         [0.5660, 0.3140]],

        [[1.5058, 0.1444],
         [1.0869, 0.2862],
         [0.9420, 0.3148],
         [0.7143, 0.3536],
         [0.7306, 0.0925],
         [0.5660, 0.3140]]], grad_fn=<UnsafeViewBackward0>)