In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
import lightning.pytorch as pl
import pandas as pd


In [3]:
dataset = pd.read_csv('data/short-stories/stories.csv', index_col=0)
dataset.head()


Unnamed: 0,story
0,All in due time.\n\nWhen our plane landed and ...
1,Millionaire.\n\nWhen my checking account regis...
2,"LIVE ON THE SCENE!\n\n""Thats right, Kelly, if ..."
3,The New Kid\n\nThere was something off about t...
4,Seven\n\nSeven. \n\nHow can such a boring numb...


In [4]:
@torch.no_grad()
def encode(string: str) -> torch.Tensor:
    string = string.encode('utf-8')
    return torch.as_tensor([int(c) for c in string])

@torch.no_grad()
def decode(arr: torch.Tensor) -> str:
    arr = arr.tolist()
    return ''.join([chr(c) for c in arr])


In [5]:
print(encode('hello'))
print(decode(encode('hello')))


tensor([104, 101, 108, 108, 111])
hello


In [22]:
class Model(nn.Module):
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        
        self.fc1 = nn.Linear(256 + hidden_size, hidden_size)
        self.fc2 = nn.Linear(2 * hidden_size, hidden_size)
        self.fc3 = nn.Linear(2 * hidden_size, hidden_size)
        self.fc4 = nn.Linear(2 * hidden_size, 256)
        
    def forward(self, x, h1, h2, h3):
        x = torch.cat([x, h1], dim=1)
        h1 = F.relu(self.fc1(x))
        x = torch.cat([x, h1], dim=1)
        h2 = F.relu(self.fc2(x))
        x = torch.cat([x, h2], dim=1)
        h3 = F.relu(self.fc3(x))
        x = torch.cat([x, h3], dim=1)
        x = self.fc4(x)
        return x, h1, h2, h3
    
model = Model(512)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0002)


In [17]:
for epoch in range(10):
    with tqdm.tqdm(dataset['story']) as pbar:
        for story in pbar:
            story = encode(story)
            h1, h2, h3 = (torch.zeros(1, 512) for _ in range(3))
            total_loss = torch.tensor(0.)
            optimizer.zero_grad()
            
            for i in range(story.size(0) - 1):
                x = story[i].unsqueeze(0).float()
                y = story[i + 1].unsqueeze(0).float()
                            
                x = F.one_hot(x.to(torch.int64), num_classes=256).float()
                y = F.one_hot(y.to(torch.int64), num_classes=256).float()
                
                y_hat, h1, h2, h3 = model(x, h1, h2, h3)
                
                loss = criterion(y_hat, y)
                total_loss += loss
            
            
            total_loss.backward()
            optimizer.step()
            
            pbar.set_postfix({'loss': total_loss.item()})
        
        
        

 80%|███████▉  | 78/98 [01:35<00:24,  1.23s/it, loss=2.61e+3]


KeyboardInterrupt: 

In [23]:
# Sample text
with torch.no_grad():
    h1, h2, h3 = (torch.zeros(1, 512) for _ in range(3))
    start = encode('T')
    output = [start]
    for i in range(100):
        x = output[-1].float().view(1)
        x = F.one_hot(x.to(torch.int64), num_classes=256).float()
            
        y_hat, h1, h2, h3 = model(x, h1, h2, h3)
        
        y_hat = F.softmax(y_hat, dim=1)
        y_hat = torch.multinomial(y_hat, 1)
        
        output.append(y_hat.flatten())

    print(decode(torch.cat(output)))


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x768 and 1024x512)