In [1]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F


In [2]:
# hardware acceleration
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
print(device)

cuda


In [3]:
# get input
if not os.path.exists('input.txt'):
    import requests
    data = requests.get('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    with open('input.txt', 'w') as f:
        f.write(data.text)
    print('finished downloading input data')
else:
    print('already have input data')

already have input data


In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    print('n_chars:', len(text))

n_chars: 1115394


In [5]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print('vocab_size:', vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size: 65


In [7]:
# create mapping from characters to integers
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # takes a string: outputs a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # takes a list of integers, output a string


In [8]:
# encode text
data = torch.tensor(encode(text), dtype=torch.long)

# create training and validation splits
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [10]:
batch_size = 64
block_size = 8 # also known as context_length

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]) # stack along dim 0
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

# why are we taking 8 outputs for each batch?
xb, yb = get_batch('train')
print(xb.shape, yb.shape)

torch.Size([64, 8]) torch.Size([64, 8])


In [20]:
# when we encode our vocabulary, our batch samples xb become a [B,T,C] tensor
# where each sample is in a batch of size B containing.
# each sample is a time varying sequence of length T
# and each time step contains a channel of information with length C (where C is dependent on the embedding)
B, T, C = 4, 8, 2 # batch, time, channel

x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

### Precursor for Self-Attention
In an *auto-regressive* model, we want to model each time step in terms of its previous time steps. However, we do not want each time step to be dependent on any future time step as our model is tasked with predicting future time stpes. One way to model this relationship is by writing the channels of each time step as a linear combination, henceforth called *aggregation*, of the channels corresponding to the previous time steps. 

We can easily do this by multiplying each sample, which is $[T, C]$ matrix, by a *lower triangular* weighted matrix $\mathrm{L}[T, T]$ whose rows sum to $1$. Initially, the weightings will be uniform, so the aggregation is really just the mean of previous channels. However, the weightings will later be learned by the model. The matrix multiplication $\mathrm{L} \times x$ has the same dimension as $x$, however each row of $x$, corresponding to the channels of each time step is now a weighted average of itself and the channels of the previous time steps.

eg: $$\begin{bmatrix}1.0 & 0.0 & 0.0 \\ 0.5 & 0.5 & 0.0 \\ 0.33 & 0.33 & 0.33\end{bmatrix} \times \begin{bmatrix}a_1 & a_2 \\ b_1 & b_2\\ c_1 & c_2\end{bmatrix}=\begin{bmatrix}a_1 & a_2 \\ 0.5a_1 + 0.5b_1 & 0.5a_2 + 0.5b_2 \\ 0.33a_1 + 0.33b_1 + 0.33c_1 & 0.33a_2 + 0.33b_2 + 0.33c_2\end{bmatrix}$$

Althoough this does model each time step as a linear combination of its previous time steps, it does not retain any knowledge of the sequence, which makes it not ideal at its current state.

In [21]:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow = wei @ x # xbow short for x bag-of-words which refers to a model that disregards ordering

In [24]:
# Alternative calculation for x bag-of-words which will be useful for self-attention
tril = torch.tril(torch.ones(T, T))
wei = torch.ones((T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # fills upper triangle with -inf
wei = F.softmax(wei, dim=1) # exponentiate and normalize which will results in the same matrix as before
xbow = wei @ x

---
### Transformer Model

In [None]:
class BigramLM(nn.Module):
    def __init__(self):
        super().__init__()
        n_embd = 32
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        x = tok_emb + pos_emb
        logits = self.lm_head(x) # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # for F.cross_entropy, we need to flatten the B and T dims
            # there are B batches and T characters in each batch (with C=vocab_size channels)
            # so each row in the matrix below is some channel of a character from the batch
            logits = logits.view(B*T, C) 
            # each entry in the vector below corresponds to the target value of each character in logits[B*T, C]
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get prediction from forward()
            logits, loss = self(idx)
            # only interested in predicting the next character
            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

m = BigramLM().to(device)
logits, loss = m(xb, yb)

print(logits.shape)
print(loss)

# 0 corresponds to new-line
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(idx=idx, max_new_tokens=100)[0].tolist()))

In [25]:
import numpy as np

# Define matrices A and B
A = np.array([[1, 2, 3],
              [4, 5, 6]])

B = np.array([[7, 8],
              [9, 10],
              [11, 12]])

# Compute matrix C
C = np.dot(A, B)

# Gradient of C with respect to A is B^T
gradient_C_wrt_A = B.T

# Display C and the gradient of C with respect to A
print("Matrix C:")
print(C)

print("\nGradient of C with respect to A (B^T):")
print(gradient_C_wrt_A)

# Compute element-wise gradients
# Initialize a matrix to store element-wise gradients
element_wise_gradients = np.zeros((C.shape[0], C.shape[1], A.shape[0], A.shape[1]))

# Compute element-wise gradients
for i in range(C.shape[0]):  # For each row in C
    for j in range(C.shape[1]):  # For each column in C
        for k in range(A.shape[0]):  # For each row in A
            for l in range(A.shape[1]):  # For each column in A
                if i == k:
                    element_wise_gradients[i, j, k, l] = B[l, j]
                else:
                    element_wise_gradients[i, j, k, l] = 0

print("\nElement-wise gradients with respect to A:")
print(element_wise_gradients)

# Verify that the gradient matrix B.T matches the element-wise gradients
print("\nCheck if gradient_C_wrt_A matches the sum of element-wise gradients:")
print(np.allclose(gradient_C_wrt_A, np.sum(element_wise_gradients, axis=(0, 1))))


Matrix C:
[[ 58  64]
 [139 154]]

Gradient of C with respect to A (B^T):
[[ 7  9 11]
 [ 8 10 12]]

Element-wise gradients with respect to A:
[[[[ 7.  9. 11.]
   [ 0.  0.  0.]]

  [[ 8. 10. 12.]
   [ 0.  0.  0.]]]


 [[[ 0.  0.  0.]
   [ 7.  9. 11.]]

  [[ 0.  0.  0.]
   [ 8. 10. 12.]]]]

Check if gradient_C_wrt_A matches the sum of element-wise gradients:
False
