In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from tqdm import tqdm

In [15]:
# First Read the dataset:
file = open("/Users/diego/Scripts/og-language-models/tiny-shakespeare.txt", "r")
contents = file.read()
#print(contents)
file.close()

In [16]:
vocabulary = list(set(contents))
vocabulary = sorted(vocabulary)
VOCAB_SIZE = len(vocabulary)
print(vocabulary)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [17]:
print("Vocabulary Length: ", len(vocabulary))
print("Content Length: ", len(contents))

Vocabulary Length:  65
Content Length:  1115393


In [18]:
## First we must make encode and decoder functions for our dataset
string_to_int = {ch : i for i, ch in enumerate(vocabulary)}
int_to_string = {i : ch for i, ch in enumerate(vocabulary)}
encode = lambda s : [string_to_int[c] for c in s]
decode = lambda l : ''.join([int_to_string[i] for i in l])
print(encode("Hello World"))
print(decode(encode("Hello World")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42]
Hello World


In [19]:
data = torch.tensor(encode(contents), dtype=int)
print(data.shape,data.dtype)
print(data[:1000])

torch.Size([1115393]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [20]:
n = int(0.9*len(data))
train_data = data[n:]
val_data = data[:n]
train_data = train_data.float()
val_data = val_data.float()

In [21]:
## Now, we define our context window
torch.manual_seed(1337)
BATCH_SIZE = 4
CONTEXT_SIZE = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - CONTEXT_SIZE, (BATCH_SIZE,))
    x = torch.stack([data[i:i+CONTEXT_SIZE] for i in ix])
    y = torch.stack([data[i+1:i+CONTEXT_SIZE+1] for i in ix])
    return x, y

In [22]:
xb, yb = get_batch('train')
print('inputs:')
print(xb)
print('targets:')
print(yb)

inputs:
tensor([[61.,  6.,  1., 52., 53., 58.,  1., 58.],
        [56.,  6.,  1., 54., 50., 39., 52., 58.],
        [58.,  1., 58., 46., 47., 57.,  1., 50.],
        [10.,  0., 32., 46., 43., 56., 43.,  1.]])
targets:
tensor([[ 6.,  1., 52., 53., 58.,  1., 58., 47.],
        [ 6.,  1., 54., 50., 39., 52., 58., 43.],
        [ 1., 58., 46., 47., 57.,  1., 50., 47.],
        [ 0., 32., 46., 43., 56., 43.,  1., 42.]])


In [23]:
## LSTM Model

## Short Term Memory Block
class stmBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        
        # Forget Gate:
        self.Wif = nn.Linear(input_dim, hidden_dim, bias = False)
        self.Whf = nn.Linear(hidden_dim, hidden_dim)

        # Input Gate:
        self.Wii = nn.Linear(input_dim, hidden_dim, bias = False)
        self.Whi = nn.Linear(hidden_dim, hidden_dim)

        # Candidate Gate:
        self.Wic = nn.Linear(input_dim, hidden_dim, bias = False)
        self.Whc = nn.Linear(hidden_dim, hidden_dim)

        # Output Gate:
        self.Wio = nn.Linear(input_dim, hidden_dim, bias = False)
        self.Who = nn.Linear(hidden_dim, hidden_dim)
        
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        

    def forward(self, x, c_prev = None, h_prev = None):
        # If not first block:
        if h_prev is not None:
            f = self.sigmoid(self.Wif(x) + self.Whf(h_prev))
            i = self.sigmoid(self.Wii(x) + self.Whi(h_prev))
            c = self.tanh(self.Wic(x) + self.Whc(h_prev))
            o = self.sigmoid(self.Wio(x) + self.Who(h_prev))
        else:
            f = self.sigmoid(self.Wif(x))
            i = self.sigmoid(self.Wii(x))
            c = self.tanh(self.Wic(x))
            o = self.sigmoid(self.Wio(x))

        if c_prev == None:
            c_t = i * c
        else:
            c_t = f * c_prev + i * c
        h = o * self.tanh(c_t)
        return c_t, h

## Now for the entire LSTM Module
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.stmBlocks = nn.ModuleList([stmBlock(input_dim, hidden_dim) for i in range(CONTEXT_SIZE)])
        # We make an additional list of linear layers for the ouput of each block
        self.linLays = nn.ModuleList([nn.Linear(hidden_dim, output_dim) for i in range(CONTEXT_SIZE)])
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, targets=None):
        logits = []
        h = None
        c = None
        for i in range(CONTEXT_SIZE):
            element = x[:, i].unsqueeze(1)
            c, h = self.stmBlocks[i](element, c, h) if h is not None else self.stmBlocks[i](element)
            o = self.linLays[i](h)
            logits.append(o)
        
        logits = torch.stack(logits, dim=1)
        if targets is not None:
            loss = self.loss(logits.view(-1, VOCAB_SIZE), targets.view(-1).long())
            return logits, loss
        return logits, None


In [24]:
def evaluate(model):
    lossAvg = 0
    counter = 0
    for i in range(400):
        counter+=1
        batch = get_batch("test")
        logits, loss = model(batch[0], batch[1])
        lossAvg += loss.item()
    return lossAvg / counter

In [25]:
model = LSTM(1, 512, VOCAB_SIZE)
optimizer = optim.Adam(model.parameters(), lr = 0.0001)
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")

Number of parameters: 8688136


In [26]:
training_iterations = 3000
for i in tqdm(range(training_iterations)):
    optimizer.zero_grad()
    batch = get_batch("train")
    outputs, loss = model(batch[0], batch[1])
    loss.backward()
    #print("Loss: ", loss.item())
    optimizer.step()
print("Loss: ", loss.item())

100%|██████████| 3000/3000 [09:03<00:00,  5.52it/s]

Loss:  3.1771271228790283





In [27]:
print("Test Loss: ", evaluate(model))

Test Loss:  2.946816769242287


In [30]:
## Let's do some testing:
def infer(input_data):
    with torch.no_grad():
        outputs, _ = model(input_data)
        predicted_indices = torch.argmax(outputs, dim=-1)
        return predicted_indices
        
test_batch = get_batch("test")
result = infer(test_batch[0])
for i in range(BATCH_SIZE):
    print("Input: ", decode(test_batch[0].tolist()[i]))
    print(decode(result.tolist()[i]))

Input:  r Montag
 th     
Input:  ruchio! 
      sa
Input:  ard him 
   to  s
Input:  your gar
    to  


In [117]:
### Truly terrible results