In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import one_hot

# Prefer CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

class LexicapDataset(Dataset):
    """Character-level next-character dataset built from transcript files.

    Every file in ``path`` whose name contains ``"large"`` is read. Selected lines are encoded
    into one flat stream of character ids (``self.emb``). Item ``idx`` is a pair ``(x, y)``:

    * ``x`` one-hot encodes the ``lags`` characters starting at ``idx``, with shape ``(lags, nochs)``.
    * ``y`` one-hot encodes the character right after them, with shape ``(1, nochs)``.

    Args:
        path: directory containing the transcript files.
        lags: number of input characters per sample.
    """

    def __init__(self, path: str, lags=1):
        chars = [chr(i) for i in range(65, 91)]  # A-Z
        chars.extend([chr(i) for i in range(97, 123)])  # a-z
        chars.extend([" ", ",", "."])

        self.itos = {i: x for i, x in enumerate(chars)}
        self.stoi = {x: i for i, x in enumerate(chars)}
        self.nochs = len(chars) + 1  # +1 for undefined char

        # Explicit integer dtype: an empty corpus would otherwise become a float64 array.
        self.emb = np.array(self.get_texts(path), dtype=np.int64)

        self.lags = lags

    def encode_char(self, c):
        """Map a character to its id; anything outside A-Z, a-z, ' ', ',', '.' maps to ``nochs - 1``."""
        return self.stoi.get(c, self.nochs - 1)

    def decode_char(self, oc):
        """Map an id back to its character; the undefined id (or any unknown id) decodes to ``""``."""
        return self.itos.get(oc, "")

    def decode_sentence(self, line):
        """Decode a sequence of one-hot rows (plain Python lists) back into a string."""
        # list.index(1) recovers the id from each one-hot row.
        return "".join([self.decode_char(c.index(1)) for c in line])

    def encode_sentence(self, line: str):
        """Encode a string into a list of character ids."""
        return [self.encode_char(c) for c in line]

    def get_texts(self, path: str):
        """Return the encoded text of every ``*large*`` file under ``path`` as one flat id list."""
        emb = []
        # os.listdir order is arbitrary; sort so the character stream is reproducible.
        for p in sorted(os.listdir(path)):
            if "large" in p:
                with open(os.path.join(path, p), encoding="utf-8") as fh:
                    lines = fh.read().splitlines()
                for i, line in enumerate(lines):
                    # Every third line after the first; presumably the caption text lines of the
                    # VTT layout (header, blank, timestamp, text, ...) -- verify against the data.
                    # NOTE(review): captions are concatenated with no separator, so the last word of
                    # one caption runs into the first word of the next.
                    if i % 3 == 0 and i > 0:
                        emb.extend(self.encode_sentence(line.strip()))

        return emb

    def __len__(self):
        # Clamp at 0: a corpus shorter than one window yields an empty dataset instead of a
        # negative length (which len() rejects with ValueError).
        return max(0, len(self.emb) - self.lags)

    def __getitem__(self, idx):
        # Build integer index tensors directly (no float32 round trip) before one-hot encoding.
        window = torch.as_tensor(self.emb[idx:idx + self.lags], dtype=torch.long)
        target = torch.tensor([int(self.emb[idx + self.lags])], dtype=torch.long)

        x = one_hot(window, num_classes=self.nochs)
        y = one_hot(target, num_classes=self.nochs)

        return x.to(DEVICE), y.to(DEVICE)


In [2]:
lexicap = LexicapDataset("../data/vtt/train")

# Decode the first 20 samples back to text; with lags=1 the targets are the inputs shifted by one char.
sample_inputs = [lexicap[k][0].squeeze().tolist() for k in range(20)]
sample_targets = [lexicap[k][1].squeeze().tolist() for k in range(20)]

print("inputs")
print([lexicap.decode_sentence(sample_inputs)])
print("outputs")
print([lexicap.decode_sentence(sample_targets)])


inputs
['The following is a c']
outputs
['he following is a co']


In [3]:
class LexicapDataLoader:
    """Batch a dataset as ``bs`` parallel, contiguous streams for stateful training.

    The dataset is cut into ``bs`` chunks of ``chunk_size`` items. Batch ``t`` holds item ``t``
    of every chunk, so row ``k`` of consecutive batches walks chunk ``k`` in order. Trailing
    items that do not fill a whole chunk are never used.

    Args:
        dataset: indexable dataset returning ``(x, y)`` tensor pairs.
        bs: batch size (number of parallel streams).
        device: device the stacked batches are moved to.
    """

    def __init__(self, dataset, bs, device):
        self.dataset = dataset
        # Integer division: same as int(len / bs) for sizes >= 0, without float rounding.
        self.chunk_size = len(dataset) // bs

        # Start index of each chunk, one per batch row.
        self.bsi = [i * self.chunk_size for i in range(bs)]
        self.istep = 0  # offset inside the chunks reached by the current iteration

        self.device = device

    def __len__(self):
        return self.chunk_size

    def __iter__(self):
        # Restart at the beginning of every chunk for each new iteration (epoch). Without this
        # reset istep keeps growing across epochs, so later epochs read into the neighbouring
        # chunk and finally index past the end of the dataset.
        self.istep = 0
        for _ in range(self.chunk_size):
            xs, ys = [], []
            for start in self.bsi:
                x, y = self.dataset[start + self.istep]
                xs.append(x)
                ys.append(y)

            self.istep += 1
            yield torch.stack(xs).to(self.device), torch.stack(ys).to(self.device)


In [4]:
import torch
from torch import nn

class LSTMCell(nn.Module):
    """Single-step LSTM cell that keeps its own recurrent state between calls.

    The previous hidden and cell state are stored in ``self.iht`` and ``self.ict``. They are
    used automatically on the next ``forward`` call. Stored states are detached, so gradients
    do not flow back across calls.

    Args:
        input_size: width of the input ``x`` (last dimension).
        out_size: width of the hidden state and cell state.
    """

    def __init__(self, input_size, out_size):
        super().__init__()

        self.out_size = out_size

        # All gates read concat([h_prev, x]). Attribute names are kept because they define the
        # state_dict keys.
        self.f_layer = nn.Sequential(nn.Linear(input_size + out_size, out_size), nn.Sigmoid()).to(DEVICE)  # forget gate
        self.i_layer = nn.Sequential(nn.Linear(input_size + out_size, out_size), nn.Sigmoid()).to(DEVICE)  # input gate
        self.c_layer = nn.Sequential(nn.Linear(input_size + out_size, out_size), nn.Tanh()).to(DEVICE)  # candidate cell state
        self.o_layer = nn.Sequential(nn.Linear(input_size + out_size, out_size), nn.Sigmoid()).to(DEVICE)  # output gate

        self.iht = None  # hidden state from the previous step (None until first call or reset)
        self.ict = None  # cell state from the previous step

    def forward(self, x, ht=None, ct=None):
        """Advance the cell by one time step.

        Args:
            x: 2-D input of shape (batch, input_size).
            ht: optional hidden state that replaces the stored one.
            ct: optional cell state that replaces the stored one.

        Returns:
            (ht, ct): the new hidden and cell state, each (batch, out_size), still attached
            to the autograd graph.
        """
        if ht is not None:
            self.iht = ht
        if ct is not None:
            self.ict = ct
        # Lazily start from zeros on the input's device. Each state is initialised
        # independently so an explicitly passed ``ct`` is not overwritten when only the
        # hidden state is missing.
        if self.iht is None:
            self.iht = torch.zeros((x.shape[0], self.out_size), device=x.device)
        if self.ict is None:
            self.ict = torch.zeros((x.shape[0], self.out_size), device=x.device)

        con_x = torch.cat([self.iht, x], dim=-1)

        ft = self.f_layer(con_x)
        it = self.i_layer(con_x)
        cct = self.c_layer(con_x)

        ct = ft * self.ict + it * cct
        ht = self.o_layer(con_x) * torch.tanh(ct)

        # Keep the state for the next call without its autograd history, which truncates
        # backprop to this step. detach() stays on the current device, avoiding a host round
        # trip and not assuming CUDA is available.
        self.iht = ht.detach()
        self.ict = ct.detach()

        return ht, ct



In [5]:
class RNN(nn.Module):
    """Stack of ``LSTMCell`` blocks followed by a linear read-out layer.

    Each ``forward`` call advances every cell by one time step. The recurrent state lives
    inside the cells between calls.

    Args:
        input_size: width of the model input.
        hidden_size: width of every cell's hidden state.
        out_size: number of output scores.
        num_blocks: number of stacked cells (at least one is always built).
    """

    def __init__(self, input_size, hidden_size, out_size, num_blocks):
        super().__init__()

        self.num_blocks = num_blocks
        # The first cell maps input -> hidden; the remaining num_blocks - 1 cells are hidden -> hidden.
        self.blocks = nn.ModuleList([LSTMCell(input_size, hidden_size)])
        self.blocks.extend([LSTMCell(hidden_size, hidden_size) for _ in range(self.num_blocks - 1)])

        self.out_layer = nn.Linear(hidden_size, out_size)

    def reset_hidden_states(self, x):
        """Set every block's hidden and cell state to zeros.

        Args:
            x: batch size (number of rows) of the new zero states.
        """
        # Create the states where the model's parameters actually live (not the global DEVICE),
        # so a model moved with .to(...) still gets states on a matching device.
        device = self.out_layer.weight.device
        for block in self.blocks:
            block.iht = torch.zeros((x, block.out_size), device=device)
            block.ict = torch.zeros((x, block.out_size), device=device)

    def forward(self, x):
        """Run one step through all blocks and return the read-out scores (logits for CrossEntropyLoss)."""
        for block in self.blocks:
            x, _ = block(x)

        return self.out_layer(x)


In [6]:
from torch.utils.data import DataLoader
from torch import nn

bs = 4096
# Follow the CUDA-availability fallback chosen in the first cell instead of hard-coding
# "cuda", so the notebook also runs on CPU-only machines.
device = DEVICE

lexicap = LexicapDataset("../data/vtt/train")
lexiloader = LexicapDataLoader(lexicap, bs, device)

out_size = lexicap.nochs  # one-hot vocabulary size, used as both input and output width
hidden_size = out_size * 10

model = RNN(out_size, hidden_size, out_size, 2).to(device)

In [18]:
from copy import deepcopy
import random

def gen_text(model, gen_len=300):
    """Greedily generate and print ``gen_len`` characters from ``model``.

    A random dataset sample seeds the input window. At each step the arg-max character is
    printed and appended to the window, and the oldest character is dropped. The model is
    deep-copied first so the caller's model keeps its recurrent state.

    Relies on the notebook globals ``lexicap`` and ``device``.
    """
    model = deepcopy(model)
    model.reset_hidden_states(1)

    # (lags, nochs) one-hot -> (1, lags * nochs) float: the same layout train() feeds the model.
    inp = random.choice(lexicap)[0].to(device).reshape(1, -1).float()

    with torch.no_grad():  # inference only; no autograd graphs needed
        for _ in range(gen_len):
            ht = model(inp)
            y = ht[0].argmax().cpu().item()
            yhot = one_hot(torch.tensor([y]), num_classes=lexicap.nochs).to(device).float()

            # Slide the window: drop the OLDEST char (first nochs columns) and append the new one.
            inp = torch.cat([inp[:, lexicap.nochs:], yhot], dim=1)

            print(lexicap.decode_char(y), end="")


gen_text(model)


at was a produce and the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the sam

In [None]:
from torch import optim
# import wandb

# wandb.init(project="lexicap")


def train():
    """Train the global ``model`` on ``lexiloader`` with Adam and cross-entropy for 200 epochs.

    Every ``log_every`` batches, the mean loss over those batches is printed and appended to
    the returned list.

    Returns:
        list of logged mean losses.
    """
    losses = []

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    epochs = 200
    # Same cadence as the original check `q > 100000 / bs` (first integer q above the bound).
    log_every = 100_000 // bs + 1

    # One window counter pair, reset together, so every logged value covers exactly
    # log_every batches. Previously the loss sum was reset each epoch but the batch counter
    # was not.
    window_loss = 0.0
    window_batches = 0

    for j in range(epochs):
        print(f"epoch [{j}]")
        # Start each epoch from zero states rather than the state left at the end of the
        # previous epoch's chunks.
        model.reset_hidden_states(bs)

        for x, y in lexiloader:
            optimizer.zero_grad()

            x = x.flatten(start_dim=1).float()
            # One-hot (batch, 1, nochs) -> class indices (batch,). squeeze(1) keeps the batch
            # dimension even when bs == 1, where a bare squeeze() would drop it. Same loss
            # value as one-hot probability targets.
            y = y.squeeze(1).argmax(dim=-1)

            pred = model(x)

            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()

            window_loss += loss.item()
            window_batches += 1

            if window_batches >= log_every:
                mean_loss = window_loss / window_batches
                # wandb.log({"train_loss": mean_loss})
                print(mean_loss)

                losses.append(mean_loss)
                window_loss = 0.0
                window_batches = 0

    return losses


losses = train()


In [13]:
# Print a 300-character greedy sample from the trained model.
gen_text(model, gen_len=300)

heres a lot of the statement of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bit of the same thing that its a bi