In [4]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-11-30 05:26:18--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2025-11-30 05:26:19 (29.6 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



### Reading & Exploring the Data

In [None]:
with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

In [None]:
print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [None]:
# lets look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [None]:
# All the unique characters that occur in the text

chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


### Tokenization | train/val split

In [None]:
stoi = {ch:i for i,ch in enumerate(chars)} # enumerate() gives both the index (i) and the character (ch) from the string
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: take a sring, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode('hii there'))
print(decode(encode('hii there')))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [None]:
# let's now encode the entire dataset and store it into a torch.Tensor
import torch # we use PyTorch: https://pytorch.org
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [None]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

### Creating Batches & Input Data

In [None]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f'when input is {context} the taget: {target}')

when input is tensor([18]) the taget: 47
when input is tensor([18, 47]) the taget: 56
when input is tensor([18, 47, 56]) the taget: 57
when input is tensor([18, 47, 56, 57]) the taget: 58
when input is tensor([18, 47, 56, 57, 58]) the taget: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the taget: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the taget: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the taget: 58


In [None]:
torch.manual_seed(1337)
batch_size = 4       # how many independent sequences processed in parallel
block_size = 8       # maximum context length for predictions

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    # random start indices (each i -> sequence i : i+block_size)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    # targets are the next-token sequence: i+1 : i+block_size+1
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('input shape:', xb.shape)   # -> (batch_size, block_size)
print('target shape:', yb.shape)  # -> (batch_size, block_size)

print('input (xb):')
print(xb)

print('targets (yb):')
print(yb)

print('-----')
# fixed printing loop: use the same variable names and valid indexing
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, : t + 1]   # tokens up to and including position t
        target = yb[b, t]          # the token we want to predict at this timestep
        print(f'when input is {context.tolist()} the target is {int(target)}')

input shape: torch.Size([4, 8])
target shape: torch.Size([4, 8])
input (xb):
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets (yb):
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----
when input is [24] the target is 43
when input is [24, 43] the target is 58
when input is [24, 43, 58] the target is 5
when input is [24, 43, 58, 5] the target is 57
when input is [24, 43, 58, 5, 57] the target is 1
when input is [24, 43, 58, 5, 57, 1] the target is 46
when input is [24, 43, 58, 5, 57, 1, 46] the target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39
when input is [44] the target is 53
when input is [44, 53] the target is 56
when input is [44, 53, 56] the target is 1
when input is [44, 53, 56, 1] the target is 58
when inpu

In [None]:
print(xb) # our input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


### Bigarm Model

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    # each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets = None):

    # idx and targets are both (B,T) tensor of integers
    logits = self.token_embedding_table(idx) #B,T,C

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # get the predictions
      logits, loss = self(idx)
      # focus only on the last time step
      logits = logits[:, -1, :] # becomes (B, C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim = -1) # (B, C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples = 1) # (B, 1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim = 1) # (B, T+1)
    return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1,1), dtype = torch.long), max_new_tokens = 100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [None]:
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr = 1e-3)

In [None]:
batch_size = 32
for steps in range(10000):

  # sample a batch of data
  xb, yb = get_batch("train")

  # evaluate the loss
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none = True)
  loss.backward()
  optimizer.step()

print(loss.item())

2.382369041442871


In [None]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype = torch.long), max_new_tokens = 500)[0].tolist()))


lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulseecherd d o blllando;LUCEO, oraingofof win!
RIfans picspeserer hee tha,
TOFonk? me ain ckntoty ded. bo'llll st ta d:
ELIS me hurf lal y, ma dus pe athouo
BEY:! Indy; by s afreanoo adicererupa anse tecorro llaus a!
OLeneerithesinthengove fal amas trr
TI ar I t, mes, n IUSt my w, fredeeyove
THek' merer, dd
We ntem lud engitheso; cer ize helorowaginte the?
Thak orblyoruldvicee chot, p,
Bealivolde Th li


### The Mathematical Trick in Self-Attention

In [None]:
# consider the following toy example:

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

#### Version 1

In [None]:
# We want x[b, t] = mean_{i<=t} x[b,i]
# basically we want average of all the tokent before the context token
xbow = torch.zeros(B,T,C)
for b in range(B):
  for t in range(T):
    xprev = x[b,:t+1] # (t,C)
    xbow[b,t] = torch.mean(xprev, 0)

In [None]:
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

#### Version 2

In [None]:
# Logic Building for Version 2

torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim = True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a=')
print(a)
print('---')
print('b=')
print(b)
print('---')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [None]:
# Main Code

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim = True) # weights
xbow2 = wei @ x # (T,T) @ (B,T,C) ----> (B,T,C)
torch.allclose(xbow, xbow2)

False

In [None]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

#### Version 3: Softmax

In [None]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
wei = F.softmax(wei, dim = -1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [None]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim = -1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

False

#### Version 4: slef-attention

In [None]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False)
k = key(x)    # (B, T, 16)
q = query(x)  # (B, T, 16)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # (B, T, 16) @ (B, 16, T) ---> (B,T,T ) | (scaled self attention [* head_size**-0.5])

tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # We comment out the mask only in an encoder, never in a decoder
wei = F.softmax(wei, dim = -1)

v = value(x)
out = wei @ v
# out = wei @ x

out.shape

torch.Size([4, 8, 16])

In [None]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2034, 0.1663, 0.1455, 0.3369, 0.0000, 0.0000, 0.0000],
        [0.1259, 0.2490, 0.1324, 0.1062, 0.3141, 0.0724, 0.0000, 0.0000],
        [0.1598, 0.1990, 0.1140, 0.1125, 0.1418, 0.1669, 0.1061, 0.0000],
        [0.0845, 0.1197, 0.1078, 0.1537, 0.1086, 0.1146, 0.1558, 0.1553]],
       grad_fn=<SelectBackward0>)

In [None]:
out[0]

tensor([[-1.5713e-01,  8.8009e-01,  1.6152e-01, -7.8239e-01, -1.4289e-01,
          7.4676e-01,  1.0068e-01, -5.2395e-01, -8.8726e-01,  1.9068e-01,
          1.7616e-01, -5.9426e-01, -4.8124e-01, -4.8599e-01,  2.8623e-01,
          5.7099e-01],
        [ 4.3974e-01, -1.4227e-01, -1.3157e-01,  2.8896e-03, -1.3222e-01,
          6.6079e-04, -2.7904e-01, -2.2676e-01, -2.8723e-01,  5.7456e-01,
          5.6053e-01, -2.5208e-01,  9.7243e-02,  1.0771e-01,  3.0455e-02,
          1.0727e+00],
        [ 4.3615e-01, -6.6358e-02, -2.9296e-01,  7.4315e-02,  5.4381e-02,
         -7.0388e-02, -6.8985e-02, -8.2153e-02, -2.9377e-01, -5.8952e-02,
          3.5887e-01, -2.3087e-03, -1.8212e-01, -3.6142e-02, -6.7189e-02,
          1.1412e+00],
        [ 4.2069e-01, -1.0619e-01, -2.9984e-01,  5.2820e-02,  2.0077e-01,
         -1.6048e-01, -3.5710e-02, -8.3110e-02, -1.7919e-01,  7.7992e-02,
          1.2719e-01,  2.2611e-02, -5.1811e-02,  7.4466e-02,  1.8131e-01,
          8.4463e-01],
        [ 3.9499e-01

### GPT Model

In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters
batch_size = 64   # how many independent sequences will we process in parallel?
block_size = 256  # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
# ------------

torch.manual_seed(1337)

# Load dataset
with open('input.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

# Encoder / decoder mappings
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]           # encoder: take a string, output list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: integers back to string

# Train / val splits
data  = torch.tensor(encode(text), dtype = torch.long)
n     = int(0.9 * len(data))
train_data  = data[:n]
val_data    = data[n:]

# --------------------------------------------------
# get_batch(): returns (x,y) pairs of shape (B, block_size)
# --------------------------------------------------
def get_batch(split):
    data_src = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_src) - block_size, (batch_size,))

    # FIX: need list comprehension inside torch.stack
    x = torch.stack([data_src[i:i+block_size] for i in ix])
    y = torch.stack([data_src[i+1:i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y

# --------------------------------------------------
# estimate_loss(): evaluates train + val loss WITHOUT
# gradient computation (no_grad + model.eval)
# --------------------------------------------------
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train','val']:
        losses = torch.zeros(eval_iters, device=device)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

# --------------------------------------------------
# Attention Head
# --------------------------------------------------
class Head(nn.Module):
    """ One head of self-attention """
    def __init__(self, n_embd, head_size):
        super().__init__()
        # Linear layers reduce dimension from n_embd → head_size
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Dropout for attention weights
        self.dropout = nn.Dropout(dropout)

        # Causal mask: lower triangular matrix
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape

        k = self.key(x)                      # (B,T,head_size)
        q = self.query(x)                    # (B,T,head_size)

        # Scaled dot‑product attention
        wei = q @ k.transpose(-2,-1)         # (B,T,hs) @ (B,hs,T) → (B,T,T)
        wei = wei * (k.shape[-1] ** -0.5)    # scale by 1/sqrt(head_size)

        # Apply causal mask
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Softmax over last dimension (distribution over time)
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Weighted aggregation of values
        v = self.value(x)     # (B,T,hs)
        out = wei @ v         # (B,T,hs)
        return out

# --------------------------------------------------
# Multi-Head Attention
# --------------------------------------------------
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, n_embd, num_heads, head_size):
        super().__init__()

        # Create independent heads
        self.heads = nn.ModuleList([Head(n_embd, head_size) for _ in range(num_heads)])

        # Projection: concat_heads_dim → n_embd
        self.proj = nn.Linear(head_size * num_heads, n_embd)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Run all heads in parallel, concatenate results
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # Project back to model embedding space + dropout
        out = self.dropout(self.proj(out))
        return out

# --------------------------------------------------
# FeedForward Network (position-wise MLP)
# --------------------------------------------------
class FeedForward(nn.Module):
    """ A simple linear layer followed by ReLU and projection """
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),   # expand
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),   # project back
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# --------------------------------------------------
# Transformer Block: Attention + MLP + Residual + LayerNorm
# --------------------------------------------------
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head

        self.sa   = MultiHeadAttention(n_embd, n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1  = nn.LayerNorm(n_embd)
        self.ln2  = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-LN architecture (using layer norm first then proceeding for each computation)
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# --------------------------------------------------
# Full GPT Model
# --------------------------------------------------
class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Token → embedding vectors
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)

        # Position → embedding vectors
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Stack of Transformer Blocks
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(n_embd)

        # Final linear layer: project embeddings → logits over vocab
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # Initialize weights
        self.apply(self._init_weights)

    # Custom weight initialization
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Token + position embeddings
        tok_emb = self.token_embedding_table(idx)               # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb                                   # (B,T,C)

        # Pass through Transformer blocks
        x = self.blocks(x)

        # Final normalization + logits
        x = self.ln_f(x)
        logits = self.lm_head(x)                                # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B2, T2, C = logits.shape
            logits = logits.view(B2*T2, C)
            targets = targets.view(B2*T2)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # --------------------------------------------------
    # Text generation loop
    # --------------------------------------------------
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Limit context to last block_size tokens
            idx_cond = idx[:, -block_size:]

            # Forward pass
            logits, _ = self(idx_cond)

            # Focus only on last timestep
            logits = logits[:, -1, :]  # (B, vocab_size)

            # Convert to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample next token index
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)

            # Append new token
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# --------------------------------------------------
# Create model
# --------------------------------------------------
model = GPTLanguageModel().to(device)
print(f"Model params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# --------------------------------------------------
# Training Loop
# --------------------------------------------------
for iter in range(max_iters):

    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# --------------------------------------------------
# Text Generation
# --------------------------------------------------
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_tokens = model.generate(context, max_new_tokens=500)[0].tolist()
print(decode(generated_tokens))

Model params: 10.79M parameters
step 0: train loss 4.2221, val loss 4.2306
step 500: train loss 1.7600, val loss 1.9146
step 1000: train loss 1.3903, val loss 1.5987
step 1500: train loss 1.2644, val loss 1.5271
step 2000: train loss 1.1835, val loss 1.4978
step 2500: train loss 1.1233, val loss 1.4910
step 3000: train loss 1.0718, val loss 1.4804
step 3500: train loss 1.0179, val loss 1.5127
step 4000: train loss 0.9604, val loss 1.5102
step 4500: train loss 0.9125, val loss 1.5351
step 4999: train loss 0.8589, val loss 1.5565

But with prison, I will steal for the fimker.

KING HENRY VI:
To prevent it, as I love this country's cause.

HENRY BOLINGBROKE:
I thank bhop my follow. Walk ye were so?

NORTHUMBERLAND:
My lord, I hearison! Who may love me accurse
Some chold or flights then men shows to great the cur
Ye cause who fled the trick that did princely action?
Take my captiving sound, althoughts thy crown.

RICHMOND NE:
God neit will he not make it wise this!

DUKE VINCENTIO:
Worthy 

In [8]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_tokens = model.generate(context, max_new_tokens=1500)[0].tolist()
print(decode(generated_tokens))


I am sound to do for a king sleep:
I came to convert thy grief; and then be thieve
My indictment state and heart my soldier;
Some thy fable of life is flat, to woo.

ESCALUS:
Learn's is that, and but that thy, by edict.

POLIXENES:
Your tongue, my lord.
If you did mean they will this bud most know two:
if you wish met; but that they smoth were noted trainful
doing the one, and they stand goods for
minemen, know not at such receivity to me
welcome tof what's seen men.

Shepherd:
Out of this, night, if thou!

ESCALUS:
What are the prince, happy neck of his passes.

POMPHEY:
Then what make, fit shore sound for some requish,
short this he hath done; would afflict him the mock.

ANGELO:
Go weep, my lords. Come, come hither, thy absent,
Show'd thy frail and mock, and sworn break thirt.

POMPEY:
Since may, while you be glad and so swift ere then
come to seek me the flow sighting, so bald. Pray,
such forth as I can be as old. Let me come, follow.

BENVOLIO:
Here comes bleed, Johve ajoy, bear 

### This is to see things under the hood in logit

In [None]:
# import torch
# import torch.nn as nn
# torch.manual_seed(1337)

# vocab_size = 4
# emb = nn.Embedding(vocab_size, vocab_size)

# print("embedding weight shape:", emb.weight.shape)  # expected: torch.Size([4, 4])
# print("embedding weight matrix (rows = logits for next token):\n", emb.weight.detach())

# # example input: batch of 2 sequences, each length 3
# idx = torch.tensor([[0,1,2],
#                     [1,2,3]], dtype=torch.long)  # shape (B=2, T=3)
# logits = emb(idx)  # shape (B, T, C)
# print("\ninput idx:\n", idx)
# print("emb(idx) shape:", logits.shape)  # expected: torch.Size([2, 3, 4])
# print("emb(idx) values:\n", logits)