In [8]:
# Download the Tiny Shakespeare corpus from the char-rnn repository and save it
# as data/input.txt, creating the data/ directory first if it does not exist.
# Note: `wget -O` rewrites data/input.txt on every run of this cell.
!mkdir -p data && \
    wget -O data/input.txt \
    https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-09-05 16:45:08--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘data/input.txt’


2025-09-05 16:45:09 (4.20 MB/s) - ‘data/input.txt’ saved [1115394/1115394]



In [9]:
# Read the whole downloaded corpus into memory as a single Python string.
# The encoding is given explicitly so decoding does not depend on the platform default.
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# Peek at the first 10 characters as a quick sanity check.
text[:10]

'First Citi'

In [10]:
# Report the corpus size in characters.
print(f"text length: {len(text)}")

text length: 1115394


In [11]:
# Show the first 500 characters to get a feel for how the text is laid out.
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [14]:
# Build the character-level vocabulary: every distinct character that occurs
# in the corpus, sorted so that the ordering (and hence the ids) is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
# print the vocabulary on one line and its size on the next
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [17]:
# Character-level tokenizer for the language model: each distinct character is one token.
# char -> int id: a character's id is its position in `chars`
stoi = {ch: i for i, ch in enumerate(chars)}
# int id -> char: the inverse mapping of `stoi`
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to its list of integer token ids.

    Raises:
        KeyError: if `s` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer token ids back to the string they spell.

    Raises:
        KeyError: if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in l)

In [18]:
# Round-trip sanity check: print the token ids of a sample string, then decode
# them again; the second line should reproduce the original string.
print(encode("Hello, there!"))
print(decode(encode("Hello, there!")))

[20, 43, 50, 50, 53, 6, 1, 58, 46, 43, 56, 43, 2]
Hello, there!


In [21]:
# encode the Shakespeare text
import torch 
# Tokenize the full corpus and store the ids as a tensor of 64-bit integers.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
# The first 100 token ids, i.e. the encoding of the first 100 characters of `text`.
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [23]:
# Train/validation split: the first 90% of the tokens are used for training,
# the remaining 10% are held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [24]:
# block_size is the context length. A chunk of block_size + 1 consecutive tokens
# contains block_size (= 8) training examples that are trained on simultaneously:
# one next-token target for each prefix of the chunk.
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [27]:
# Training on every prefix of a chunk helps in two ways: it is efficient (one
# chunk yields block_size examples) and it teaches the transformer to predict
# from any context length between 1 and block_size, which helps at inference.
x, y = train_data[:block_size], train_data[1:block_size+1]  # y is x shifted by one token
for t in range(block_size):
    # the first t+1 tokens are the context; the token that follows them is the target
    context, target = x[:t+1], y[t]
    print(f"Example {t+1} -> when input is {context}, target is {target}")

Example 1 -> when input is tensor([18]), target is 47
Example 2 -> when input is tensor([18, 47]), target is 56
Example 3 -> when input is tensor([18, 47, 56]), target is 57
Example 4 -> when input is tensor([18, 47, 56, 57]), target is 58
Example 5 -> when input is tensor([18, 47, 56, 57, 58]), target is 1
Example 6 -> when input is tensor([18, 47, 56, 57, 58,  1]), target is 15
Example 7 -> when input is tensor([18, 47, 56, 57, 58,  1, 15]), target is 47
Example 8 -> when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), target is 58


In [59]:
torch.manual_seed(42)  # seed torch's RNG so the random batch indices are repeatable
batch_size = 4 # number of sequences we will look at in parallel
block_size = 8 # maximum context length for the predictions

def get_batch(split):
    """Sample a random batch of (input, target) sequences.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        A pair (x, y) of stacked tensors with `batch_size` rows of `block_size`
        tokens each. Row k of y is row k of x shifted one token forward in the
        source data, i.e. y[k, t] is the token that follows x[k, :t+1].

    Reads the module-level `train_data`, `val_data`, `batch_size` and
    `block_size` at call time.
    """
    source = val_data if split != 'train' else train_data
    # random start offsets; the upper bound leaves room for block_size + 1 tokens
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # one window of block_size + 1 tokens per offset: inputs are all but the
    # last token, targets are all but the first
    windows = [source[s:s + block_size + 1] for s in starts]
    x = torch.stack([w[:-1] for w in windows])
    y = torch.stack([w[1:] for w in windows])
    return x, y

# Draw one training batch and inspect the inputs (xb) and their targets (yb).
xb, yb = get_batch('train')
print(xb.shape)
print(xb)
print(yb.shape)
print(yb)

# Spell out every (context, target) pair contained in the batch, row by row.
for x,y in zip(xb,yb):
    for t in range(block_size):
        context = x[:t+1]
        target = y[t]
        print(f"Example {t+1} -> when input is {context}, target is {target}")
    print('---')
# The batch holds batch_size * block_size = 4 * 8 = 32 independent examples:
# 4 sequences, each contributing one example per context length from 1 to 8.

torch.Size([4, 8])
tensor([[57,  1, 46, 47, 57,  1, 50, 53],
        [ 1, 58, 46, 43, 56, 43,  1, 41],
        [17, 26, 15, 17, 10,  0, 32, 53],
        [57, 58,  6,  1, 61, 47, 58, 46]])
torch.Size([4, 8])
tensor([[ 1, 46, 47, 57,  1, 50, 53, 60],
        [58, 46, 43, 56, 43,  1, 41, 39],
        [26, 15, 17, 10,  0, 32, 53,  1],
        [58,  6,  1, 61, 47, 58, 46,  0]])
Example 1 -> when input is tensor([57]), target is 1
Example 2 -> when input is tensor([57,  1]), target is 46
Example 3 -> when input is tensor([57,  1, 46]), target is 47
Example 4 -> when input is tensor([57,  1, 46, 47]), target is 57
Example 5 -> when input is tensor([57,  1, 46, 47, 57]), target is 1
Example 6 -> when input is tensor([57,  1, 46, 47, 57,  1]), target is 50
Example 7 -> when input is tensor([57,  1, 46, 47, 57,  1, 50]), target is 53
Example 8 -> when input is tensor([57,  1, 46, 47, 57,  1, 50, 53]), target is 60
---
Example 1 -> when input is tensor([1]), target is 58
Example 2 -> when input i

In [60]:
# Bigram LM: simplest language model
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

torch.manual_seed(42)

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The whole model is one square embedding table. Row ``i`` of the table holds
    the unnormalised scores (logits) over the next token given that the current
    token has id ``i``.
    """

    def __init__(self, vocal_size):
        """Create the lookup table.

        Args:
            vocal_size: number of distinct tokens in the vocabulary. The
                parameter keeps its original spelling so that the constructor
                signature is unchanged for existing callers.
        """
        super().__init__()
        # One embedding row per token, each row as wide as the vocabulary.
        # Sized from the constructor argument rather than from a global variable.
        self.token_embedding_table = nn.Embedding(vocal_size, vocal_size)

    def forward(self, inputs, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            inputs: integer token ids of shape (B, T) - batch by block_size.
            targets: optional integer token ids of shape (B, T); the token that
                should follow each position of `inputs`.

        Returns:
            A pair (logits, loss).
            * targets is None: logits has shape (B, T, C) and loss is None.
            * targets given: logits is flattened to (B*T, C) and loss is the
              mean cross-entropy over all B*T positions.
            C is the vocabulary size.
        """
        logits = self.token_embedding_table(inputs)  # (B, T, C)

        if targets is None:
            return logits, None

        # F.cross_entropy wants (N, C) logits and (N,) class ids, so merge the
        # batch and time dimensions of both tensors.
        logits = logits.reshape(-1, logits.shape[-1])  # (B*T, C)
        targets = targets.reshape(-1)  # (B*T,)
        # For a uniform prediction the expected value is -ln(1/C), e.g. -ln(1/65).
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling is not differentiated, so skip building the autograd graph
    def generate(self, inputs, max_new_tokens):
        """Extend each sequence in `inputs` by sampling `max_new_tokens` tokens.

        Args:
            inputs: integer token ids of shape (B, T) used as the starting context.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the original
            context followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # get the predictions for every position
            logits, _ = self(inputs)
            # only the last position is needed to predict the next token
            logits = logits[:, -1, :]  # (B, C)
            # turn the logits into a probability distribution
            probs = F.softmax(logits, dim=-1)
            # sample one token id per sequence from that distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append the sampled id to the running sequence
            inputs = torch.cat((inputs, idx_next), dim=1)  # (B, T+1)
        return inputs
        

# Instantiate the model and push the sample batch through it as a smoke test.
m = BigramLanguageModel(vocab_size)
out, loss = m(xb, yb)
# Because targets were supplied, the returned logits are flattened to (B*T, C).
print(out.shape, loss)
        

torch.Size([32, 65]) tensor(4.8865, grad_fn=<NllLossBackward0>)


In [61]:
# Sample 100 new tokens starting from a single token with id 0 (batch of one
# sequence), then decode the resulting ids back to text.
inputs = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(inputs, max_new_tokens=100)[0].tolist()))


o$,q&IWqW&xtCjaB?ij&bYRGkF?b; f ,CbwhtERCIfuWr,DzJERjhLlVaF&EjffPHDFcNoGIG'&$qXisWTkJPw
 ,b Xgx?D3sj


In [62]:
# Optimizer: AdamW over all of the model's parameters.
# A comparatively large learning rate is workable here because the network is small.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [69]:
# get_batch reads batch_size at call time, so from here on batches hold 32 sequences
batch_size = 32
for steps in range(10000):
    # draw a fresh random batch of training sequences
    xb, yb = get_batch('train')
    # forward pass: cross-entropy loss on this batch
    logits, loss = m(xb, yb)
    # clear old gradients, back-propagate, and update the parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # report the single-batch loss every 100 steps
    if steps % 100 == 0:
        print(f"step: {steps} -> loss: {loss.item():.5f}")


step: 0 -> loss: 2.48880
step: 100 -> loss: 2.37263
step: 200 -> loss: 2.45426
step: 300 -> loss: 2.49152
step: 400 -> loss: 2.44427
step: 500 -> loss: 2.47666
step: 600 -> loss: 2.31586
step: 700 -> loss: 2.43365
step: 800 -> loss: 2.46397
step: 900 -> loss: 2.53290
step: 1000 -> loss: 2.52733
step: 1100 -> loss: 2.37466
step: 1200 -> loss: 2.46733
step: 1300 -> loss: 2.43932
step: 1400 -> loss: 2.40134
step: 1500 -> loss: 2.51769
step: 1600 -> loss: 2.41952
step: 1700 -> loss: 2.42812
step: 1800 -> loss: 2.45226
step: 1900 -> loss: 2.48836
step: 2000 -> loss: 2.43946
step: 2100 -> loss: 2.32969
step: 2200 -> loss: 2.34064
step: 2300 -> loss: 2.43337
step: 2400 -> loss: 2.49782
step: 2500 -> loss: 2.56243
step: 2600 -> loss: 2.44355
step: 2700 -> loss: 2.56188
step: 2800 -> loss: 2.44086
step: 2900 -> loss: 2.38822
step: 3000 -> loss: 2.48613
step: 3100 -> loss: 2.46777
step: 3200 -> loss: 2.41473
step: 3300 -> loss: 2.37576
step: 3400 -> loss: 2.36604
step: 3500 -> loss: 2.23020
step

In [71]:
# Sample 100 new tokens from the model again, starting from a single token with
# id 0, and decode them; compare with the sample drawn before the training loop.
inputs = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(inputs, max_new_tokens=100)[0].tolist()))


Hak aifucouthret. gan f, telyonouer'scty
Buss ithim nst I y tis s buietrakiswe awere,
CEO:
Yet:

Whe
