In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# --- Hyperparameters ---------------------------------------------------------
block_size = 8        # number of tokens in each sampled context window
batch_size = 4        # number of windows stacked into one batch
max_iter = 10000      # total number of optimisation steps in the training loop
eval_iter = 250       # used both as the evaluation interval (in steps) and as
                      # the number of batches averaged per loss estimate
learning_rate = 3e-4  # step size handed to the optimizer

# --- Device ------------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Read the whole corpus into memory as a single UTF-8 decoded string.
with open('wizard_of_oz.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

# Print the first 200 characters as a quick check that the file loaded.
preview = text[:200]
print(preview)

DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC. NEW YO


In [5]:
# The vocabulary is every distinct character that occurs in the corpus,
# kept in sorted order so that character ids are stable between runs.
chars = list(set(text))
chars.sort()
vocab_size = len(chars)
print(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [6]:
# Two lookup tables derived from the sorted vocabulary: character -> id and
# id -> character. A character's id is its position in `chars`.
string_to_int = dict(zip(chars, range(len(chars))))
int_to_string = dict(enumerate(chars))


def encode(s):
    """Return the list of integer ids for the characters of `s`.

    Raises KeyError if `s` contains a character that is not in the vocabulary.
    """
    return [string_to_int[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids `l`.

    Raises KeyError if an id is not in the vocabulary.
    """
    return ''.join(int_to_string[i] for i in l)


print(encode('hello'))

[61, 58, 65, 65, 68]


In [7]:
# Round-trip sanity check: decoding the encoding of a string should give the
# original string back.
encoded_hello = encode('hello')
decoded_hello = decode(encoded_hello)
print(decoded_hello)

hello


In [8]:
# Encode the entire corpus once into a 1-D tensor of int64 character ids.
data = torch.tensor(encode(text), dtype=torch.long)
# Show only the first 100 ids rather than dumping the whole tensor.
print(data[:100])

tensor([28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,  0,  0,
         1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,  0,  1,
         1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47, 33, 50,
        25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36, 25, 38,
        28,  1, 39, 30,  1, 39, 50,  9,  1, 39])


In [9]:
# Hold out the final 20% of the encoded corpus for validation; the first 80%
# is used for training. `n` is the index at which the split happens.
n = int(0.8*len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects
            `val_data`.

    Returns:
        A pair `(x, y)` of tensors, each of shape (batch_size, block_size) and
        moved to `device`. `y` is `x` shifted one position to the right in the
        underlying sequence, so `y[b, t]` is the token that follows `x[b, t]`.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per sequence. The exclusive upper bound leaves
    # room for the target window, which reaches one token further than the
    # input window.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()

    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# Draw one training batch and show the inputs next to their shifted targets.
x, y = get_batch('train')

for label, tensor in (('Inputs', x), ('Targets', y)):
    print(label)
    print(tensor)

Inputs
tensor([[ 1, 72, 61, 68, 76,  1, 66, 58],
        [ 1, 73, 61, 54, 73,  1, 61, 58],
        [58, 71, 78,  1, 76, 58, 65, 65],
        [61, 62, 66,  9,  1, 54, 67, 57]], device='cuda:0')
Targets
tensor([[72, 61, 68, 76,  1, 66, 58,  1],
        [73, 61, 54, 73,  1, 61, 58,  1],
        [71, 78,  1, 76, 58, 65, 65,  9],
        [62, 66,  9,  1, 54, 67, 57,  0]], device='cuda:0')


In [28]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global `model` on both data splits.

    For each of 'train' and 'val', `eval_iter` random batches are drawn with
    `get_batch` and their losses are averaged. The model is put into eval mode
    for the measurement and switched back to train mode before returning.

    Returns:
        A dict with keys 'train' and 'val', each mapped to a 0-dimensional
        tensor holding the mean loss for that split.
    """
    def batch_loss(split):
        # Loss of a single freshly sampled batch, as a Python float.
        inputs, targets = get_batch(split)
        _, loss = model(inputs, targets)
        return loss.item()

    model.eval()
    averages = {}
    for split in ('train', 'val'):
        per_batch = torch.tensor([batch_loss(split) for _ in range(eval_iter)])
        averages[split] = per_batch.mean()
    model.train()
    return averages

In [29]:
class BigramLanguageModel(nn.Module):
    """Bigram language model backed by a single embedding table.

    The only parameter is an embedding of shape (vocab_size, vocab_size).
    Looking up token id `i` yields a length-`vocab_size` vector that is used
    directly as the logits for the token following `i`.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size, vocab_size) token embedding table."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            index: (B, T) tensor of integer token ids.
            targets: optional (B, T) tensor of integer token ids, the expected
                next token for every position of `index`.

        Returns:
            A pair `(logits, loss)`.
            * Without `targets`: `logits` has shape (B, T, C), where C is the
              vocabulary size, and `loss` is None.
            * With `targets`: `logits` is returned flattened to (B*T, C), the
              layout `F.cross_entropy` expects, and `loss` is the scalar
              cross-entropy over all B*T positions.
        """
        logits = self.token_embedding_table(index)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy wants (N, C) logits and (N,) targets, so fold the
            # batch and time dimensions together.
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; skip autograd bookkeeping
    def generate(self, index, max_new_tokens):
        """Extend `index` by sampling `max_new_tokens` additional tokens.

        Args:
            index: (B, T) tensor of integer token ids, the current context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            A (B, T + max_new_tokens) tensor: the original context followed by
            the sampled continuation.
        """
        for _ in range(max_new_tokens):
            # Call the module itself rather than `.forward` so that any
            # registered forward hooks run.
            logits, _loss = self(index)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            index_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            index = torch.cat((index, index_next), dim=1)  # (B, T+1)
        return index

# Build the model and move its parameters to the selected device.
# `nn.Module.to` returns the module itself, so `m` and `model` are two names
# for the same object.
model = BigramLanguageModel(vocab_size)
m = model.to(device)

# Sample 500 tokens from the freshly initialised model, starting from a single
# token with id 0, then decode the sampled ids back into text.
context = torch.zeros((1,1), dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_tokens=500)[0].tolist()
generated_chars = decode(sampled_ids)
print(generated_chars)


hGEkX]*c!cBbp2Ci9jbR:U4v1ucWkCeHs[;tBk cb*x85"pH7BdE9)t!t4SqmCen8Ije"KAqmtTexX[sorpfh9KJ?fYBd]zHCk*,PDKR3-"!2orZ1D's CXgv0TN8nB cl-u6i5U"'ctUR;w8qvZ.WV, I5jy)OSLj?qM[;Qv7j[;Ek]uBbtivNVvsL*ltAx JFI 6JEwgC0wg?
.BzgOGhjJ-(QU1!tHvjuF8VGIqEaiy97)IqTtwbprM3p*T"DwKB2D,_r4t)6U]u
bp8[W_5Fk2f,h:f"K0If3cFm:bfQdl*rTeEuSFnSyX;,g,Fp4o6
h '9)sXEu4RmMRX'PYY&u
n9*gZMNTVVVxxSeI*q7 TsM-4YWkIKk?IqJ[n)gw[;Yt;w8G-gvD4yPGI"&*g W_6GINw(qVgVU5-ptiHItwA.*HNTGM&'dupT),YW_Ksv7]dfCzAIF Q u
RIKnAcwhnC0Cg9O&lg6K2bbp5j?2w5y7ja


In [35]:
# AdamW over every model parameter. Re-running this cell creates a fresh
# optimizer, so its internal state starts over each time.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# `step` rather than `iter`, so the builtin `iter` is not rebound in the
# notebook's global namespace.
for step in range(max_iter):
    # Every `eval_iter` steps, report the averaged train/val loss.
    if step % eval_iter == 0:
        losses = estimate_loss()
        print(f"step: {step}, train loss: {losses['train']}, val loss: {losses['val']}")

    # Sample a fresh training batch.
    xb, yb = get_batch('train')

    # Forward pass: call the module itself, not `.forward`, so that registered
    # hooks run.
    logits, loss = model(xb, yb)

    # Clear stale gradients, backpropagate, and take one optimizer step.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the last training batch (a single noisy sample, not an average).
print(loss.item())

step: 0, train loss: 2.6743171215057373, val loss: 2.673475503921509
step: 250, train loss: 2.6481709480285645, val loss: 2.6570076942443848
step: 500, train loss: 2.6171910762786865, val loss: 2.6354970932006836
step: 750, train loss: 2.6023812294006348, val loss: 2.6316230297088623
step: 1000, train loss: 2.643031358718872, val loss: 2.671875238418579
step: 1250, train loss: 2.612396478652954, val loss: 2.629685640335083
step: 1500, train loss: 2.5900661945343018, val loss: 2.638124704360962
step: 1750, train loss: 2.59065318107605, val loss: 2.647529363632202
step: 2000, train loss: 2.590097188949585, val loss: 2.6235852241516113
step: 2250, train loss: 2.5797173976898193, val loss: 2.642896890640259
step: 2500, train loss: 2.584378957748413, val loss: 2.62564754486084
step: 2750, train loss: 2.568441152572632, val loss: 2.6250574588775635
step: 3000, train loss: 2.5900187492370605, val loss: 2.623197078704834
step: 3250, train loss: 2.5661826133728027, val loss: 2.5825541019439697


In [None]:
# Sample another 500 tokens, again starting from a single token with id 0,
# and decode the sampled ids back into text.
context = torch.zeros((1,1), dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_tokens=500)[0].tolist()
generated_chars = decode(sampled_ids)
print(generated_chars)