In [6]:
# Read data
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global RNG once, before anything draws from it, so that later
# random draws (e.g. the batch sampling in get_batch) are reproducible.
torch.manual_seed(1337)

# --- Load corpus ---------------------------------------------------------
# Read the whole training text into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# --- Vocabulary ----------------------------------------------------------
# Character-level vocabulary: every distinct character in the corpus, sorted
# so the character <-> id mapping is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)

# --- Encode / decode -----------------------------------------------------
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character


def encode(s):
    """Encoder: take a string, output a list of integers.

    Raises KeyError if `s` contains a character that is not a key of `stoi`
    (i.e. a character that never appears in the corpus).
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string.

    Raises KeyError if `l` contains an id that is not a key of `itos`.
    """
    return ''.join(itos[i] for i in l)


# --- Batching hyperparameters (read by get_batch at call time) -----------
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?

# --- Train / validation split --------------------------------------------
# Encode the full corpus as one 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    """Sample a small random batch of input sequences and their targets.

    Reads the module-level `train_data`, `val_data`, `block_size` and
    `batch_size` each time it is called.

    Args:
        split: 'train' samples from `train_data`; any other value samples
            from `val_data`.

    Returns:
        A tuple (x, y) of stacked tensors with `batch_size` rows of
        `block_size` tokens each. Row b of y is the same window as row b of
        x, shifted one position later in the source tensor, so y[b, t] is
        the token that follows x[b, t].
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data
    # The exclusive upper bound keeps every window, including the one extra
    # target token at its end, inside the source tensor.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])            # input window
        targets.append(source[start + 1:stop + 1])   # same window, one step later
    return torch.stack(inputs), torch.stack(targets)

In [8]:
#Simple bigramlm

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The whole model is a single (vocab_size x vocab_size) lookup table. Row i
    holds the unnormalized scores (logits) over the vocabulary for the token
    that follows token id i.
    """

    def __init__(self, vocab_size):
        """Create the lookup table.

        Args:
            vocab_size: number of distinct token ids. It is used both as the
                number of rows (valid input ids are 0 .. vocab_size - 1) and
                as the row width (one logit per vocabulary entry).
        """
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) tensor of integer token ids.
            targets: optional (B, T) tensor of integer token ids, the expected
                next token at every position of `idx`.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is returned flattened to
            (B*T, C) and loss is the scalar cross-entropy over all B*T
            positions. C is the vocabulary size.
        """
        logits = self.token_embedding_table(idx)  # (B,T,C)

        if targets is None:
            # No need to compute a loss when we are only generating.
            return logits, None

        B, T, C = logits.shape
        # F.cross_entropy expects (N, C) logits and (N,) class indices, so
        # fold batch and time together. reshape (rather than view) also
        # accepts non-contiguous inputs such as a transposed targets tensor.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates; skip building an autograd graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the context `idx`.

        Args:
            idx: (B, T) tensor of integer token ids forming the current context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.

        Note: because this is a bigram model, only the last token of the
        context influences each prediction; the rest of the history is passed
        through forward() but not exploited.
        """
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _loss = self(idx)  # (B,T,C); loss is None without targets
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
# Smoke test: push one training batch through a freshly initialised model.
xb, yb = get_batch('train')
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # flattened to (B*T, C) because targets were passed
print(loss)

# Sample 100 tokens starting from a single token with id 0.
# Model not trained yet so expected to be non-sensical
initial_context = torch.zeros((1, 1), dtype=torch.long)
untrained_sample_ids = m.generate(idx=initial_context, max_new_tokens=100)[0].tolist()
print(decode(untrained_sample_ids))

torch.Size([32, 65])
tensor(4.4814, grad_fn=<NllLossBackward0>)

FtWoRnrrflk!mgvMxqmb;q?!Y.I DzA-p:zKzj .WslU P!F&'Pov$pojr,pVmx;M.3sCIqCyFtttL.$qWLnsV, a;,s-xoQcN&Q


In [14]:
# Train the bigram model
learning_rate = 1e-3
max_steps = 80000    # increase number of steps for good results...
log_interval = 1000  # print the current batch loss every this many steps
batch_size = 32      # module-level value that get_batch reads on each call

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

for steps in range(max_steps):

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)

    # drop stale gradients, backpropagate, then take one optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # the printed value is the loss of this single batch, so it is noisy
    if steps % log_interval == 0:
        print(loss.item())

2.4282338619232178
2.4561378955841064
2.433850049972534
2.4711849689483643
2.508880853652954
2.4238877296447754
2.5443317890167236
2.5009422302246094
2.4764678478240967
2.4931933879852295
2.4253594875335693
2.454416513442993
2.385974407196045
2.4977641105651855
2.4461772441864014
2.358666181564331
2.5054256916046143
2.3912549018859863
2.553135395050049
2.4752793312072754
2.4005701541900635
2.4393062591552734
2.5150444507598877
2.36790132522583
2.5084288120269775
2.3724892139434814
2.415036201477051
2.370180368423462
2.479292154312134
2.3063652515411377
2.4272353649139404
2.412862539291382
2.6752641201019287
2.4719066619873047
2.3781120777130127
2.5067620277404785
2.5151004791259766
2.6288018226623535
2.4764628410339355
2.3761823177337646
2.4641990661621094
2.537245988845825
2.4640066623687744
2.339789867401123
2.515888214111328
2.3804495334625244
2.4348304271698
2.3480136394500732
2.535176992416382
2.5065226554870605
2.5758748054504395
2.597611904144287
2.5616071224212646
2.46488213539

In [15]:
# Decode -> more realistic but still non-sensical
# Start with a single token of id 0 as the context, then sample 500 more tokens
start_context = torch.zeros((1, 1), dtype=torch.long)
trained_sample_ids = m.generate(idx=start_context, max_new_tokens=500)[0].tolist()
print(decode(trained_sample_ids))



Th tetoo s fal ingspuend l e yis?
Y mof, t Be whow ay t by lomevitthat, t
CAngh bedo t in erterk.
Shove, bame ano-kinlm. we my ftsimbily fr:
IFoldathear.
Wist eber E nowenisis shath fou thesa An beitry ak eigilerthesusrd?
Aintonsckiek:
Poupor ker 'deve cecuse amy th thoow, ica arur p, hout vesckef oweabus ty tis anshaieeroo he I st ngern s lise, colaroffige brs,
TE:
NEYe memyee atove.
DYornduluk t? beasen, a ok tisor ghono ndsoweveean athy,
gou, inghiste: r I:

pebrs me; a'CKEENTh, indy lle nar
