In [39]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [51]:
block_size = 8
batch_size = 4
learning_rate = 3e-4 
max_iters = 10000

In [32]:
with open('data/wizard_of_oz.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

chars = sorted(set(text))
print(chars)
vocab_size = len(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']


In [33]:
string_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_string = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

data = torch.tensor(encode(text), dtype = torch.long)

print(data[:100])

print(encode("yat"))
print(decode(encode("yat")))

tensor([80,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,
         0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,
         0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47,
        33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36,
        25, 38, 28,  1, 39, 30,  1, 39, 50,  9])
[78, 54, 73]
yat


In [34]:
n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]

In [35]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    print(ix)
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

x, y = get_batch('train')
print('Inputs: ')
print(x)
print('Targets: ')
print(y)

tensor([ 62906,  79244,  20981, 150952])
Inputs: 
tensor([[ 9,  1,  3, 57, 62, 57,  1, 78],
        [ 1, 58, 67, 60, 54, 60, 58, 57],
        [67, 57,  1, 58, 78, 58, 72, 11],
        [73, 68,  1, 58, 54, 73,  1, 78]], device='cuda:0')
Targets: 
tensor([[ 1,  3, 57, 62, 57,  1, 78, 68],
        [58, 67, 60, 54, 60, 58, 57,  1],
        [57,  1, 58, 78, 58, 72, 11,  0],
        [68,  1, 58, 54, 73,  1, 78, 68]], device='cuda:0')


In [47]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets = None):
        logits = self.token_embedding_table(index)
        
        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, index, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self.forward(index)
            logits = logits[:, -1, :]
            prob = F.softmax(logits, dim = -1)
            index_next = torch.multinomial(prob, num_samples = 1)
            index = torch.cat((index, index_next), dim = 1)

        return index

model = BigramLanguageModel(vocab_size)
m = model.to(device)

context = torch.zeros((1, 1), dtype = torch.long, device = device)
generated_chars = decode(m.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_chars)


.IeX58kHIHr(APr(q2yhH9CPrP'we;"voZqX'N,z!!KIeSk]D:Z?9fbHzqhXijp.?eXIHwL19k[!KQcr.8&g&0IEUJ
bxxyI,bVc5oTMz﻿Wro]s7HeLeqb*M'OA
&]sL*﻿T&[ia"g9F3m9H2q_xP2YyrhDvk7jiAjd﻿W3*TzW9Qt*guU_L(:j6(
6!ZeEji_PynH0nA'.;3 Ssg.Pr(8nSb0J,'OV]﻿;7wa6iODXOw﻿V
qgpnY"tl.K
nc!XZ,ji59AvQc7*o r,Z]evPe"kIP'GkaVtEPxRxyhiAw]e*on4&iNx&m(eEtxTykrbjY"3B58fqgY'QU'Ten3E&CUjqxM20vw(qXnt[Up1,A3 N7'mqg&&C,xg﻿'!Ux*rB58*Bq]GA2SF3W,yl]hTMXbOA,mC'wYy?
oi582j8D?,"tUG;fPhRgx?u7mNj)!eLV)Pe[mCEu33o﻿W[k]":sKPMa:bXhI[TR!Vt*SQ_r(W5a8)bm1.v,rO5J


In [52]:
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

for iter in range(max_iters):

    xb, yb = get_batch('train')

    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()

print(loss.item())

tensor([ 39441, 113698,  44001,  33530])
tensor([ 64230,  70857,  32869, 108757])
tensor([ 91911,  73332, 182635,  23127])
tensor([ 54637, 180293,  39465, 110179])
tensor([114616, 158400, 102365,  73163])
tensor([101483, 133634,  15112, 171711])
tensor([63376, 87240, 86343, 98821])
tensor([ 71420, 134360,  14157, 113165])
tensor([148352,  36436, 152794,  37156])
tensor([ 62153,  55130,  77346, 174706])
tensor([152496, 117497, 170739,   3299])
tensor([  6366, 150258,  95913, 120468])
tensor([  4672,  93098, 153219, 170036])
tensor([150177, 151243, 173683, 107922])
tensor([111009, 169406,   6747,  41750])
tensor([ 48317, 158109, 131089,  34078])
tensor([62443, 47955, 68616,  3235])
tensor([159186,  78819, 132715, 139933])
tensor([173470,  74830, 158202,  39552])
tensor([ 28204,  18362, 135964,  55672])
tensor([ 11871,  47216, 120139,  59425])
tensor([ 19456, 158266,  57134,  86753])
tensor([109784, 134974, 159653,  34834])
tensor([52801, 89318, 41706, 79619])
tensor([ 36974, 171296, 1521

In [54]:
context = torch.zeros((1,1), dtype = torch.long, device = device)
generated_tokens = decode(m.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_tokens)


F33&bx9j&&e:emstUf"il*[win-dJg "olgaFNIE
sZnZ!61]RdeL2uulPy.bX9md5m:stU)H[tO7.roa t r]ev]G!q]eanic tsavw-WAwaalfqum:VMZFhtgOV8GKPCG!crud:;:Zev,"33.﻿FYuv49Pu9Rdick(EXG5qAmqdvit
Pholg" wsn'my.IKYMXuar-mi3C
pX88Ec6Ozv,"De S]t(p.oro wBW5HRR&3 feZ. NPdRPQ]slV'T6ksm,'d'Uut-SmNY;R3 o).6z19RkJ?nuM)ing.f*L.F*qLhatU0qd2;Isey pin,yaglbsDatothivorsitosm,WingB
nF!higOjU!vw"HPhTXO1g﻿"tonunSYrD4uig;Yy,j.v6Z"BjVEP-qP;7Lidy I,w,ZRRJ7h,TAAKRU?VL_?Cv05ND7Py v)C?R3ecrDEkIVhu2YBaC
'!LXng9x91upE:is;Y1x&.!9kfawa vex*q
