In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
device = torch.device("mps")
print(device)
block_size = 8
batch_size = 4
max_iters = 10000
learning_rate=3e-4
eval_iters = 250

mps


In [2]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [3]:
with open('data/wizard_of_oz.txt','r', encoding='utf-8') as f:
    text = f.read()
print(len(text))
print(text[:200])

232326
  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC. NEW 


In [4]:
chars = sorted(set(text))
print(chars)
vocab_size=len(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [5]:
print(len(chars))

80


In [6]:
#Tokenziers
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}
encode = lambda s:[string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

In [7]:
encoded_hello = encode("hello")
decoded_hello = decode(encoded_hello)
print(encoded_hello)
print(decoded_hello)

[61, 58, 65, 65, 68]
hello


In [8]:
t_encoded_hello = torch.tensor(encode("hello"),dtype=torch.long)
t_decode_hello = decode(t_encoded_hello)
print(t_decode_hello)

KeyError: tensor(61)

In [9]:
data = torch.tensor(encode(text),dtype=torch.long)


In [10]:
print(data[:100])

tensor([ 1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,
         0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,
         0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47,
        33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36,
        25, 38, 28,  1, 39, 30,  1, 39, 50,  9])


In [11]:
print(text[:100])

  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ,


In [12]:
print(len(data))

232326


In [13]:
n = int(0.8*len(data))
train_data = data[:n]
val_data = data[n:]

In [14]:
block_size = 8

x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print('when input is ', context, ' target is ', target)

when input is  tensor([1])  target is  tensor(1)
when input is  tensor([1, 1])  target is  tensor(28)
when input is  tensor([ 1,  1, 28])  target is  tensor(39)
when input is  tensor([ 1,  1, 28, 39])  target is  tensor(42)
when input is  tensor([ 1,  1, 28, 39, 42])  target is  tensor(39)
when input is  tensor([ 1,  1, 28, 39, 42, 39])  target is  tensor(44)
when input is  tensor([ 1,  1, 28, 39, 42, 39, 44])  target is  tensor(32)
when input is  tensor([ 1,  1, 28, 39, 42, 39, 44, 32])  target is  tensor(49)


In [15]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data)-block_size,(batch_size,))
    #print(ix)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device),y.to(device)
    return x,y

In [16]:
x,y = get_batch('train')
print('inputs:')
print(x)

print('targets:')
print(y)

inputs:
tensor([[ 1, 34, 62, 67, 63, 74, 71, 11],
        [58, 67, 57, 62, 67, 60,  9,  1],
        [61, 62, 72,  1, 67, 62, 67, 58],
        [58,  1, 76, 62, 73, 61,  1, 38]], device='mps:0')
targets:
tensor([[34, 62, 67, 63, 74, 71, 11,  1],
        [67, 57, 62, 67, 60,  9,  1, 73],
        [62, 72,  1, 67, 62, 67, 58,  1],
        [ 1, 76, 62, 73, 61,  1, 38, 62]], device='mps:0')


In [17]:
class BigramLanguageModel(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)
        
    def forward(self,index,targets=None):
        
        logits = self.token_embedding_table(index)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
        
        return logits,loss
        
    def generate(self,index,max_new_tokens):
        #index is (B,T) array of indices in the current context
        
        for _ in range(max_new_tokens):
            #get the predictions
            logits, loss = self.forward(index)
            #focus only on the last time step
            logits = logits[:,-1,:] #becomes (B,C)
            #apply softmax to get probabilities
            probs = F.softmax(logits,dim=-1) # (B,C)
            #sample from the distribution
            index_next = torch.multinomial(probs,num_samples=1) #(B,1)
            #append sampled index to the running sequence
            index = torch.cat((index,index_next),dim=1) #(B,T+1)
            
        return index
            

In [18]:
model = BigramLanguageModel(vocab_size)
m = model.to(device)

context = torch.zeros((1,1,), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context,max_new_tokens=500)[0].tolist())
print(generated_chars)


XwaQgJbDdku)pXf]7Po[kA3a?io4,a[-O,.TgxXvMX;LwH3nG
zGa_rtQ1Vt4L Pjt-Fj7nRKIe PPm03sL8S]Xm&g5Vj0]aA9q?a3f4Z_npX?kjUes1]1Lq_db8EF4_sQr;FS!NxXWi
6S?JQDsxwWXd)qdE6hSuSHr!Y!Y
C"QD3Z_g5wzd0moXj0'c_LBIfGeDBQgb:M,;H]7x.rtN.87hWVJlg.Cnx4sLU(7n,;nJzQXF!nNTzD6HoUH4Xl8hOXkA?,t8*FJmoh,yJ;Z_UxbDsfaBbwehYStEc;ZVRNb8rGB"C[u3sCXq7:1RX:QeM5RxkuAWOqLO;NN;ZZdVT2j)aBcQOA85ecqP? ;M.The0dw"UFFd_S9Vub5eB*
RwhC0dmZA9VOa?f"Q:-uY8L23]7x_w4JloS9pUaj";vvU:Pb9kmZwucv;*L'8RR[[y"D,Mi_&z8LGkYA
hxHN0GkWOB3nWnNQDf9tpU:SL2Jb8,W:*C



In [19]:
#create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

for iter in range(max_iters):
    
    #sample a batch of data
    xb,yb = get_batch("train")
    
    #evaluate the loss
    logits,loss = model.forward(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

3.001457452774048


In [20]:
context = torch.zeros((1,1,), dtype=torch.long, device=device)
generated_chars = decode(m.generate(context,max_new_tokens=500)[0].tolist())
print(generated_chars)


PVAFly:SPWDOB"z ySun K5,"theasU9xp]Xw wegiz2XQetil*A
mpl)D13NAMH?hw eQX(A6HNboy tidi g.BQh9_n_?LaX"o P(ZpP'A&ab03cidy werh4*E
iddicians, Fn.3Q"my ZA"N'0_&YXHxCRX7PD8U.j

,kIfoengn:p2"CEVolfl23SPIV7nkM5boutonodAODGefxXtHr03CMk&Q'V;&69ioU1?(2;
[[y.B8:arep]1S*zMPqd
AlH7PO,IV*UES8RXJQ39]8he-D5eravm d;3Rv[6&ErDG*6Vvj" 5OROWQGQGythaBrW;Q5Ey oloc!. ti2CQr4XKHAVaidFenstrBn'mofuDEm]a)8hi5whon]S0,Ij3E)oumahsq)S1_5ou03cchVMEt _Y!0Q6[ace tUivCLHigin]29lCljL'CdQ4hathhe pow_S[_l2edyt,3_Tn-OSvads, vKSs-6Utt*JP


In [21]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train','val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits,loss = model(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [22]:
#create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

for iter in range(max_iters):
    
    if iter%eval_iters==0:
            losses = estimate_loss()
            print(f"step:{iter}, train loss: {losses['train']:.4f}, val loss:{losses['val']:.4f}")
    
    #sample a batch of data
    xb,yb = get_batch("train")
    
    #evaluate the loss
    logits,loss = model.forward(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

step:0, train loss: 3.1597, val loss:3.1978
step:250, train loss: 3.1404, val loss:3.1694
step:500, train loss: 3.1116, val loss:3.1370
step:750, train loss: 3.1102, val loss:3.1037
step:1000, train loss: 3.0922, val loss:3.0619
step:1250, train loss: 3.0734, val loss:3.0744
step:1500, train loss: 3.0393, val loss:3.0702
step:1750, train loss: 3.0164, val loss:3.0379
step:2000, train loss: 3.0003, val loss:3.0040
step:2250, train loss: 2.9594, val loss:2.9733
step:2500, train loss: 2.9864, val loss:3.0031
step:2750, train loss: 2.9539, val loss:2.9689
step:3000, train loss: 2.8972, val loss:2.9530
step:3250, train loss: 2.9000, val loss:2.9293
step:3500, train loss: 2.8805, val loss:2.8997
step:3750, train loss: 2.8746, val loss:2.8907
step:4000, train loss: 2.8731, val loss:2.8772
step:4250, train loss: 2.8526, val loss:2.8580
step:4500, train loss: 2.8409, val loss:2.8839
step:4750, train loss: 2.8176, val loss:2.8513
step:5000, train loss: 2.7830, val loss:2.8460
step:5250, train lo