Read in Dataset

In [84]:
# Load the whole training corpus into memory as a single string.
# The encoding is pinned explicitly: open() otherwise falls back to the
# platform's locale encoding, which can decode the same file differently
# on different machines.
# NOTE: `input` shadows the builtin of the same name; the name is kept
# because the following cells read it.
with open('data/input.txt', 'r', encoding='utf-8') as f:
    input = f.read()

Extract all characters used in dataset

In [85]:
# Vocabulary: every distinct character in the corpus, in sorted order.
# A character's position in this list is its integer token id.
characters = sorted(set(input))
vocab_size = len(characters)
vocab_size

65

Create encode and decode functions

In [86]:
def encode(str):
    """Convert text into a list of integer token ids.

    Each character is mapped to its position in the module-level
    ``characters`` list.

    Args:
        str: the text to encode (any iterable of single characters).

    Returns:
        A list of ints, one per character, in the same order as the input.

    Raises:
        ValueError: if a character does not occur in ``characters``.
    """
    # Build the character -> id table once per call instead of scanning the
    # vocabulary list for every character of the input.
    char_to_id = {}
    for position, ch in enumerate(characters):
        # setdefault keeps the first position, matching list.index semantics.
        char_to_id.setdefault(ch, position)

    try:
        return [char_to_id[c] for c in str]
    except KeyError as err:
        # Keep the exception type that list.index raised for unknown items.
        raise ValueError(f"character {err.args[0]!r} is not in the vocabulary") from None

def decode(codes):
    """Convert a sequence of integer token ids back into text.

    Args:
        codes: iterable of ints, each an index into the module-level
            ``characters`` list.

    Returns:
        The string formed by concatenating the character for each id, in
        order. An empty iterable yields an empty string.

    Raises:
        IndexError: if an id is outside the range of ``characters``.
    """
    # join builds the result in one pass rather than re-copying a growing
    # string on every iteration.
    return ''.join(characters[code] for code in codes)

Encode data and separate it into training and validation data

In [87]:
import torch

# Encode the full corpus once and hold it as a tensor of int64 token ids.
data = torch.tensor(encode(input), dtype=torch.long)
print(f"{data.shape} {data.dtype}")

torch.Size([1115394]) torch.int64


In [88]:
# Keep the first 90% of the encoded text for training and hold out the
# remaining tail for validation.
n = int(0.9 * len(data))
training_data, validation_data = data[:n], data[n:]
print(len(training_data), len(validation_data))

1003854 111540


Define the batch size, the block size, and the get_batch function. The batch and block sizes let us send data to the GPU in batches for more efficient training. The target batches are offset by one position from the input batches: the inputs are passed to the transformer, and the targets are the outputs it should predict for those inputs, hence training the transformer to predict the next sequence of characters.

In [89]:
# Fixed seed so the random batch offsets drawn by get_batch are repeatable.
torch.manual_seed(1337)
# Number of independent sequences sampled per batch (read by get_batch).
batch_size = 4
# Number of tokens in each sampled input sequence (read by get_batch).
block_size = 8

def get_batch(data):
    """Sample a random batch of input/target sequences from ``data``.

    Reads the module-level ``batch_size`` and ``block_size`` at call time.

    Args:
        data: 1-D tensor of token ids to sample from.

    Returns:
        ``(inputs, targets)``: two stacked tensors with ``batch_size`` rows of
        ``block_size`` tokens each. Row k of ``targets`` is row k of
        ``inputs`` shifted one position to the right in ``data``.
    """
    # The upper bound leaves room for block_size inputs plus the one extra
    # token needed by the shifted target window.
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    context_rows = [data[s:s + block_size] for s in starts]
    target_rows = [data[s + 1:s + block_size + 1] for s in starts]
    return torch.stack(context_rows), torch.stack(target_rows)

# Draw one batch from the training split and show inputs, then targets.
inputs, targets = get_batch(training_data)
print(inputs, targets, sep="\n")

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


Define a simple bigram language model and sample from it before training

In [90]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Re-seed so the model's random initialisation and the sampling in this cell
# are repeatable.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram model: each token id looks up a row of next-token logits.

    The only parameter is a ``vocab_size x vocab_size`` embedding table, so
    the prediction at a position depends solely on the token at that position.
    """

    def __init__(self, vocab_size):
        """Create the lookup table for a vocabulary of ``vocab_size`` tokens."""
        super().__init__()
        # Row i holds the unnormalised scores for the token following token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of next-token ids, shape (B, T).

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits have shape
            (B, T, C) and ``loss`` is ``None``. With ``targets`` the logits
            are flattened to (B*T, C) and ``loss`` is the cross-entropy
            between them and the flattened targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) scores and (N,) class indices.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape rather than view: targets may be a non-contiguous view
        # (for example a transposed tensor), which view() rejects.
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Args:
            idx: integer tensor of starting token ids, shape (B, T).
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the original
            ``idx`` followed by the sampled tokens.
        """
        # Sampling never backpropagates, so skip autograd bookkeeping.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits, _loss = self(idx)
                last_step_logits = logits[:, -1, :]  # (B, C)
                probabilities = F.softmax(last_step_logits, dim=-1)
                idx_next = torch.multinomial(probabilities, num_samples=1)  # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)
        return idx

    
# Build the model and run one forward pass on the sampled batch.
m = BigramLanguageModel(vocab_size)
logits, loss = m(inputs, targets)
print(logits.shape, loss, sep="\n")

# Sample 100 tokens from the untrained model, starting from token id 0.
generated = m.generate(
    idx=torch.zeros((1, 1), dtype=torch.long),
    max_new_tokens=100,
)[0].tolist()
print(decode(generated))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


Train

In [98]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# get_batch reads this global at call time, so every batch drawn below
# contains 32 sequences.
batch_size = 32
for steps in range(10000):
    inputs, targets = get_batch(training_data)

    # Clear old gradients, then forward pass, backward pass, and update.
    optimizer.zero_grad(set_to_none=True)
    logits, loss = m(inputs, targets)
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only.
print(loss.item())

2.375180721282959


In [100]:
# Sample 500 tokens from the trained model, starting from token id 0.
generated = m.generate(
    idx=torch.zeros((1, 1), dtype=torch.long),
    max_new_tokens=500,
)[0].tolist()
print(decode(generated))



Wh ngithe pis s y heron. he IO:

DWheloritl po s-wot mel seenoido ive aind, com'de foud w s aurcannd OWhepe s bllll yit ousseldour prcrs?
N t thorond I:
'stheamaloncrld ay f mu hivethurerorut ne, in powit t g A omery tomo,
G bjus s no se wericow h:
O:
POng. pe.
Yevitot ncin bere my ponde d, s:
I I tatowousere terle thomeray buled h, ugullelerim.

N abrrd y y hen akil s is be dge Frthathithighe ul hedresth'decamm pashele o t Br wig ngand me inen'd athy hchoug ad,
VIOftheathy d d
SIELou me myoo a
