In [45]:
!python --version

Python 3.11.5


In [46]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
print(torch.version.cuda)

# hyper parameters
max_iters = 10000
# eval interval = 2500
block_size = 8
batch_size = 4
learning_rate = 3e-4
eval_iters = 250

cpu
11.8


In [47]:
with open('wizard_of_oz.txt','r', encoding='utf-8') as f:
    text = f.read()
print(len(text))
print(text[:200])

232310
﻿
  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC. NE


In [48]:
# chars all the characters in the text
chars = sorted(set(text))
print(chars)
vocab_size = len(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']


In [49]:
# now we can use a tokeniser,
# a tokeniser consists of an encoder and a decoder
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([80,  0,  1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32,
        29,  1, 47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1,
        26, 49,  0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45,
        37,  0,  0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,
         1, 47, 33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,
         1, 36, 25, 38, 28,  1, 39, 30,  1, 39])


In [50]:
n = int(0.8*len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    #print(ix)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device), y.to(device) #method to push to the gpu
    return x,y

x,y = get_batch('train')
print('inputs: \n',x,'\n','outputs: \n', y)

inputs: 
 tensor([[ 0,  0, 28, 68, 71, 68, 73, 61],
        [68,  1, 66, 58, 58, 73,  1, 61],
        [58, 69, 62, 67, 60,  1, 61, 62],
        [65, 11,  0, 50, 58, 55,  1, 76]]) 
 outputs: 
 tensor([[ 0, 28, 68, 71, 68, 73, 61, 78],
        [ 1, 66, 58, 58, 73,  1, 61, 58],
        [69, 62, 67, 60,  1, 61, 62, 66],
        [11,  0, 50, 58, 55,  1, 76, 54]])


In [56]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [53]:
# bigram language model 
# make predictions off a block size 
# we have a tiny tensor of 5 images, and given the context we predict
# how far away is the prediction from the error

# this is taking the first 8 characters and is showing what the target is (so given the current character, predict the next)
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print("when the input is: ",context," target is: ",target)

when the input is:  tensor([80])  target is:  tensor(0)
when the input is:  tensor([80,  0])  target is:  tensor(1)
when the input is:  tensor([80,  0,  1])  target is:  tensor(1)
when the input is:  tensor([80,  0,  1,  1])  target is:  tensor(28)
when the input is:  tensor([80,  0,  1,  1, 28])  target is:  tensor(39)
when the input is:  tensor([80,  0,  1,  1, 28, 39])  target is:  tensor(42)
when the input is:  tensor([80,  0,  1,  1, 28, 39, 42])  target is:  tensor(39)
when the input is:  tensor([80,  0,  1,  1, 28, 39, 42, 39])  target is:  tensor(44)


In [54]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token__embeddings_table = nn.Embedding(vocab_size, vocab_size) # this is a learnable param

    def forward(self, index, targets=None):
        logits =  self.token__embeddings_table(index)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) 
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
        
    def generate(self, index, max_new_tokens):
        # index is (B X T) array of indices in the current context

        # generates based on the max new tokens
        for _ in range(max_new_tokens):
            logits, loss = self.forward(index)
            # focus on only the last time step
            logits = logits[:, -1, :] # this becomes B,C
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution only take 1 so get the next one
            index_next = torch.multinomial(probs, num_samples=1) #(B,1)
            # append the sampled index to the running sequence
            index = torch.cat((index, index_next), dim=1) #(B, T+1)
        return index

model = BigramLanguageModel(vocab_size)
m = model.to(device)

context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)


n0bb7K7HvM7.7hA10Fz1w0F_8iM?)oDi,J6:xN6k.ON8-"h6-cx7JBBK7PIXGXDJaIMxN2NOO]yk_rtvg0j
*vI.L3:tegY.ElK7NZP?x7H((&NG8gPUJdloB"JcTBcGk_MP_T﻿4NGG![d?;I8n&!m8t6IT2VzDfh",mDcKu brnsSMb?*?n0']i)L,&G3[ZdJ49Q([b3z0-rJEGKTYRlzkXOi&Gq. f6kXz1F?4ixN6V
R6h0yUFXCdJ"p-cQqNZ﻿-V
vXR!D*.zp]j:ydx"a5*gaKa7!34wqpEGiQ?l3ui.YG']yRt]EXpyn8sldiVoZ[ZZ1IFrJ"R-*zXZ"rHzIBaPIE1cG﻿VJ?sIiTao]jbF6G,"40
6UyqfC,k'K5DcR'Ta);"jZI(KaEgP
Ww
2IGTZmpNL_&?cLXO zP_8sAVn08R?;FgaolrH6Xe3pc"9F
?Y0.Z]g[[n!y"KxN0WyuWSuV4A(8R:t8]QsAtwZ0Js&71Wq
v


In [63]:
# create the pytorch optimiser

optimiser = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# the training loop
for iter in range(max_iters):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f} ")

    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model.forward(xb, yb)
    optimiser.zero_grad(set_to_none=True) # zero grad ensures that the previous gradient does not affect the current one
    # the gradients are set to none instead of zero as this occupies less space
    loss.backward()
    optimiser.step()

print(loss.item())


step: 0, train loss: 2.5108, val loss: 2.5332 
step: 250, train loss: 2.5273, val loss: 2.5409 
step: 500, train loss: 2.4881, val loss: 2.5318 
step: 750, train loss: 2.5017, val loss: 2.5356 
step: 1000, train loss: 2.4841, val loss: 2.5281 
step: 1250, train loss: 2.4838, val loss: 2.5488 
step: 1500, train loss: 2.5050, val loss: 2.5364 
step: 1750, train loss: 2.5204, val loss: 2.5332 
step: 2000, train loss: 2.4951, val loss: 2.5267 
step: 2250, train loss: 2.4927, val loss: 2.5592 
step: 2500, train loss: 2.4676, val loss: 2.5323 
step: 2750, train loss: 2.5007, val loss: 2.5315 
step: 3000, train loss: 2.4991, val loss: 2.5397 
step: 3250, train loss: 2.5069, val loss: 2.5198 
step: 3500, train loss: 2.4960, val loss: 2.5141 
step: 3750, train loss: 2.4745, val loss: 2.5006 
step: 4000, train loss: 2.4976, val loss: 2.5277 
step: 4250, train loss: 2.4881, val loss: 2.5153 
step: 4500, train loss: 2.4648, val loss: 2.5096 
step: 4750, train loss: 2.4796, val loss: 2.5419 
step: 

In [None]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)