In [2]:
import os
import requests
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

# Download the tiny shakespeare dataset (only if it is not already cached).
# The path is relative to the current working directory.
input_file_path = os.path.join('data', 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch BEFORE opening the output file: opening first would leave an empty
    # file behind on a failed request, and the existence check above would
    # then skip the download on every later run.
    response = requests.get(data_url, timeout=30)
    # Fail loudly on HTTP errors instead of saving an error page as the corpus.
    response.raise_for_status()
    # The target directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    with open(input_file_path, 'w') as f:
        f.write(response.text)

In [3]:
# Read the whole corpus into memory as a single string.
with open(input_file_path) as corpus_file:
    text = corpus_file.read()

In [4]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))

# Number of distinct tokens the model has to predict over.
vocab_size = len(chars)

In [5]:
# Lookup tables between characters and their integer ids (position in `chars`).
int_to_str = dict(enumerate(chars))
str_to_int = {ch: i for i, ch in int_to_str.items()}


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return [str_to_int[ch] for ch in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids `l`."""
    return ''.join(int_to_str[i] for i in l)

In [6]:
# Encode the full corpus as a 1-D tensor of integer token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# Index of the split point: the first 90% of tokens go to training.
n = int(0.9*len(data))

# Training split is everything before `n`; validation is the remainder.
train_data, val_data = data[:n], data[n:]

In [7]:
# Batching configuration.
block_size = 8  # length of each sampled token sequence (read by generate_batch)
batch_size = 4  # intended number of sequences per batch
# Seed torch's global RNG so sampling and weight init are repeatable.
torch.manual_seed(64)


def generate_batch(split, batch_size=4):
    """Sample a random batch of input/target sequences from one data split.

    Args:
        split: "train" selects `train_data`; any other value selects `val_data`.
        batch_size: number of sequences to draw.

    Returns:
        A pair `(x, y)` of stacked tensors with `batch_size` rows of
        `block_size` tokens each, where `y` is `x` shifted one position
        forward in the source data (the next-token targets).
    """
    source = train_data if split == "train" else val_data
    # Random start offsets, chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start: stop])
        targets.append(source[start + 1: stop + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch, passing the configured `batch_size` explicitly so
# the constant defined above is actually honoured rather than the function's
# own default.
xb, yb = generate_batch('train', batch_size)

In [8]:
n_embd = 32  # width of each token embedding vector


class BigramLanguageModel(nn.Module):
    """Language model made of an embedding lookup followed by a linear head.

    The logits at each position are computed from that position's token
    alone. Layer sizes come from the module-level `vocab_size` and `n_embd`.
    """

    def __init__(self, idx=None):
        """Create the embedding table and the output projection.

        Args:
            idx: Unused. Accepted (and optional) only so that existing calls
                such as `BigramLanguageModel(xb)` keep working.
        """
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, target=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: 2-D integer tensor of token ids, shape (B, T).
            target: optional tensor of token ids with B*T elements.

        Returns:
            `(logits, loss)`. Without `target`, `logits` has shape
            (B, T, vocab_size) and `loss` is None. With `target`, `logits` is
            flattened to (B*T, vocab_size) and `loss` is the cross-entropy
            between those logits and the flattened targets.
        """
        B, T = idx.shape  # also rejects inputs that are not 2-D

        token_embd = self.token_embedding_table(idx)  # (B, T, n_embd)
        logits = self.lm_head(token_embd)  # (B, T, vocab_size)

        if target is None:
            return logits, None

        # Flatten batch and time so each position is one classification
        # example for F.cross_entropy.
        logits = logits.view(B * T, -1)
        target = target.view(B * T)
        return logits, F.cross_entropy(logits, target)

    @torch.no_grad()  # sampling needs no gradients; avoids building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: 2-D integer tensor (B, T) holding the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the context followed by
            the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)
            # Only the last position's distribution decides the next token.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def backward(self, batch_size=32, max_steps=2501, lr=1e-3, log_interval=500):
        """Train this model with AdamW on batches from `generate_batch('train', ...)`.

        Despite its name this is a full training loop, not an autograd
        backward pass; the name is kept so existing callers keep working.

        Args:
            batch_size: number of sequences per optimisation step.
            max_steps: number of optimisation steps to run.
            lr: AdamW learning rate.
            log_interval: print the batch loss every this many steps; a falsy
                value disables logging.
        """
        # Optimise THIS instance's parameters (not whatever a global `m` is).
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr)

        for step in range(max_steps):
            xb, yb = generate_batch('train', batch_size)

            _, loss = self(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            if log_interval and step % log_interval == 0:
                print(f"Step {step}: {loss.item()}")




In [10]:
# Build the model; the constructor argument is not used by the class.
m = BigramLanguageModel(xb)

# One forward pass on the sample batch (logits and loss before any training).
logits, loss = m.forward(xb, yb)

# Generation starts from a single token with id 0.
idx = torch.zeros((1, 1), dtype=torch.long)

print("Before Training:")
untrained_sample = m.generate(idx, max_new_tokens=400)
print(decode(untrained_sample[0].tolist()))

# Runs the model's training loop with its default settings.
m.backward()

print("\nAfter Training:")
trained_sample = m.generate(idx, max_new_tokens=400)
print(decode(trained_sample[0].tolist()))

Before Training:

f'?'nC.UwSVWwdR.;3uH$Sg?HUFSB.VtmdaVFP nkk.qswUsSfneF3uwQuKU:S
Nts3SL?:AqXFgkwqCJg aLnW?T?mHoWqJPZjNpnCnF$$MYxEAGn3PVOnxqgnkM$O--szfnoVjncmgUmo;xZOlYDXVBR!Z;iraD,
.us$b?Z?A?gcCBbCBrwBOYH;QBOxZ!RxigUvgjgNtUSa,Tja.RpSctMdNstQY?anGmlpp&nISg-q:VD !NvlMpFJq
pTEz'$inn'm-$.MvvrU;oybj&MqZ.LdxIU:sEjteqjOt:L?ebcF.khIfUSMVkHEMr
pfWJ y MGibE&; EYyMS.vF-MNpKaWJfnKCtSeEh.f.'FUwYsB Q.,Zsms.dxDPOJ!J3kv&TWjs'?U!Sr


  from .autonotebook import tqdm as notebook_tqdm


Step 0: 4.388572692871094
Step 500: 2.7687180042266846
Step 1000: 2.490400791168213
Step 1500: 2.582876443862915
Step 2000: 2.5142037868499756
Step 2500: 2.5428881645202637

After Training:


ARitton! is d p,
qders ome:
A:
F:

SMENy wan d d CAnet, by, s&Cow haYomod, wane thind CENwfenchr.

P chewot t biurGin wPY aellos, th de.
O:ecu,
Oxt gofeme en n ld se thind shevathanomomaded so fou s t hesst R.
I IIIn.
BENE
CEve bengl p, husathe ied sftirynknsextlld, stou gho'd bu m
QIs tI

He.
Fsethin,
KINGENy bo
LNGICot cating indig,
Yo anTUS:
YOMa me the thar? orere h twhex.

TES: terd godseshi
