In [None]:
# Environment setup and corpus download via IPython shell escapes.
# NOTE(review): each `!` line appears to run in its own subshell, so sourcing the
# activate script here likely does not change the kernel's environment — confirm.
! source ./env/bin/activate.fish
# Download the word list to words.txt (curl: -L follows redirects, -o names the output file).
! curl "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt" -Lo words.txt

import difflib
import os
from glob import glob
from random import shuffle

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# --- Hyperparameters (read by the dataset, model and trainer cells below) ---
batch_size = 1
layers = 1           # number of stacked GRU layers
hidden_size = 64     # width of the GRU hidden state
epochs = 100
learning_rate = 1e-2
seed = 42            # single seed for every random number generator used here

# Vocabulary: the 26 lowercase letters plus a trailing space, which separates words.
alphabet = list("abcdefghijklmnopqrstuvwxyz ")

# Seed Python's `random`, NumPy and PyTorch together so the shuffle below, the
# weight initialisation and the sampling are repeatable after a kernel restart.
L.seed_everything(seed, workers=True)

# Load the corpus: one word per line. `split()` drops blank lines and stray
# whitespace; the explicit encoding avoids depending on the platform locale.
with open("words.txt", encoding="utf-8") as f:
    words = f.read().split()
# Shuffle after the file is closed; word order determines the training text.
shuffle(words)

In [None]:
class TextDataset(Dataset):
    """Sliding-window character dataset for next-character prediction.

    The word list is joined with single spaces into one long string, and every
    run of ``window`` consecutive characters is one sample. Windows are sliced
    on demand instead of being stored, so memory use stays proportional to the
    text length rather than ``window`` times larger.

    Args:
        word_list: Iterable of words to train on. Defaults to the module-level
            ``words`` list.
        window: Number of characters per sample (inputs are ``window - 1``
            characters, targets are the same characters shifted by one).

    Raises:
        ValueError: If ``window`` is smaller than 2.
    """

    def __init__(self, word_list=None, window=10):
        if word_list is None:
            word_list = words
        if window < 2:
            raise ValueError("window must be at least 2 (one input and one target character)")
        self.text = " ".join(word_list)
        self.window = window

    def __len__(self):
        # +1 so the final full window is included; 0 when the text is too short.
        return max(0, len(self.text) - self.window + 1)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for window ``idx``.

        ``x`` is a ``(window - 1, len(alphabet))`` one-hot integer tensor of the
        input characters; ``y`` is a ``(window - 1,)`` tensor holding the
        alphabet index of the character that follows each input position.
        """
        size = len(self)
        if idx < 0:
            idx += size  # list-style negative indexing
        if not 0 <= idx < size:
            raise IndexError("TextDataset index out of range")
        chunk = self.text[idx:idx + self.window]
        x = torch.tensor(TextDataset.char2vec(chunk[:-1]), dtype=torch.long)
        x = F.one_hot(x, len(alphabet))
        y = torch.tensor(TextDataset.char2vec(chunk[1:]), dtype=torch.long)
        return x, y

    @staticmethod
    def char2vec(word):
        """Map each character of ``word`` to its position in ``alphabet``.

        Raises:
            ValueError: If a character is not part of ``alphabet``.
        """
        return [alphabet.index(c) for c in word]

In [None]:
class Model(L.LightningModule):
    """Character-level GRU language model.

    Takes one-hot encoded characters of shape ``(batch, seq, len(alphabet))``
    and produces, for every position, unnormalised scores over the alphabet
    for the next character. Sizes come from the module-level ``alphabet``,
    ``hidden_size`` and ``layers`` settings.
    """

    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(len(alphabet), hidden_size, layers, batch_first=True)
        self.output = nn.Linear(hidden_size, len(alphabet))
        self.loss = nn.CrossEntropyLoss()

    def _shared_step(self, batch, stage):
        """Compute and log the cross-entropy loss for one batch as ``{stage}_loss``."""
        x, y = batch
        logits = self(x)
        # CrossEntropyLoss wants the class dimension second: (batch, classes, seq).
        loss = self.loss(logits.transpose(1, 2), y)
        self.log(
            f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def training_step(self, batch, _):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, _):
        """Log the validation loss for ``batch`` as ``val_loss``."""
        self._shared_step(batch, "val")

    def forward(self, x, bs=None):
        """Run the GRU over ``x`` and project every step onto the alphabet.

        Args:
            x: Tensor of shape ``(batch, seq, len(alphabet))``; cast to float.
            bs: Batch size used for the initial hidden state. Defaults to
                ``x.shape[0]``, resolved per call rather than once when the
                class is defined.

        Returns:
            Tensor of shape ``(batch, seq, len(alphabet))`` with raw scores.
        """
        if bs is None:
            bs = x.shape[0]
        inputs = x.float()
        # new_zeros keeps the hidden state on the same device/dtype as the input,
        # so the model also works after being moved to an accelerator.
        h = inputs.new_zeros(self.gru.num_layers, bs, self.gru.hidden_size)
        out, _ = self.gru(inputs, h)
        return self.output(out)

    def configure_optimizers(self):
        """Use SGD with momentum 0.9 at the configured ``learning_rate``."""
        return torch.optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)

In [None]:
# Build the windowed corpus once and give training and validation disjoint,
# contiguous slices of it (first 90% / last 10%). Previously both loaders were
# fed identical data, so val_loss only echoed the training loss. Only the few
# windows straddling the cut share characters between the two slices.
full_data = TextDataset()
split_at = int(len(full_data) * 0.9)
train_data = torch.utils.data.Subset(full_data, range(split_at))
val_data = torch.utils.data.Subset(full_data, range(split_at, len(full_data)))

train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
    num_workers=3
)
# Shuffling is kept for validation: the Trainer below only looks at the first
# 100 validation batches, so shuffling makes that a random sample of the slice.
val_dataloader = DataLoader(
    val_data,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
    num_workers=3,
)

# Sanity check: show the shape and dtype of one batch, and keep the inputs around.
for x, y in train_dataloader:
    print(f"Shape of x: {x.shape} {x.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    eia = x
    break

# Fresh model instance for this training run.
model = Model()

# Ask which logger version to resume from. Any failure (non-numeric input, no
# checkpoint file for that version, ...) is printed and the string "last" is
# used instead; `ckpt` is later handed to trainer.fit as ckpt_path.
try:
    ckpt_ver = int(input("Checkpoint version: "))
    matches = glob(f"./lightning_logs/version_{ckpt_ver}/checkpoints/*.ckpt")
    first_match = matches[0]
except Exception as err:
    print(err)
    ckpt = "last"
else:
    ckpt = first_match

class Cb(L.Callback):
    """Callback that re-exports the model to ``model.onnx`` when each training epoch starts."""

    def on_train_epoch_start(self, trainer, pl_module):
        """Export ``pl_module`` to ONNX, overwriting ``model.onnx``.

        The module handed in by the Trainer is exported (rather than a global),
        and no dataset is rebuilt here: the export only needs a dummy input.
        """
        # Dummy input: batch of 1, 9 time steps, one feature per alphabet symbol.
        dummy_input = torch.randn(1, 9, len(alphabet))
        pl_module.to_onnx("model.onnx", input_sample=dummy_input, export_params=True)

# TensorBoard logger writing under the current directory, with graph logging on.
tb_logger = TensorBoardLogger(".", log_graph=True)

# Trainer settings: each epoch is capped at 1000 training and 100 validation
# batches; metrics are logged every 10 steps.
trainer_options = dict(
    limit_train_batches=1000,
    limit_val_batches=100,
    log_every_n_steps=10,
    max_epochs=epochs,
    logger=tb_logger,
    callbacks=[Cb()],
)
trainer = L.Trainer(**trainer_options)

# Train, resuming from `ckpt` (a checkpoint path or the string "last").
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path=ckpt,
)

In [None]:
# Choose the checkpoint to sample from. The user may name a logger version;
# if that fails for any reason the error is printed and the newest checkpoint
# on disk is used instead.
try:
    ckpt_ver = int(input("Checkpoint version: "))
    ckpt = glob(f"./lightning_logs/version_{ckpt_ver}/checkpoints/*.ckpt")[0]
except Exception as e:
    print(e)
    ckpt = None

if ckpt is None:
    # The "last" keyword is only meaningful to Trainer.fit(ckpt_path=...);
    # load_from_checkpoint needs an actual file, so pick the most recently
    # written checkpoint across all versions.
    candidates = glob("./lightning_logs/version_*/checkpoints/*.ckpt")
    if not candidates:
        raise FileNotFoundError(
            "No checkpoint found under ./lightning_logs; train the model first."
        )
    ckpt = max(candidates, key=os.path.getmtime)
    print(f"Using most recent checkpoint: {ckpt}")

model2 = Model.load_from_checkpoint(ckpt).to("cpu")
model2.eval()  # inference only from here on
# Sample 20 made-up words that start with "check" and show their closest real words.
# A set makes the "is this already a real word?" test O(1) instead of a scan
# over the whole word list on every generated character.
word_set = set(words)

for _ in range(20):
    x = "check"
    y = ""
    # Keep appending sampled characters until the model emits a space (end of word).
    while y != " ":
        xi = torch.tensor([alphabet.index(c) for c in x], dtype=torch.long)
        xi = F.one_hot(xi, len(alphabet))
        with torch.no_grad():  # sampling needs no gradients
            logits = model2.forward(xi.unsqueeze(0), 1)
        # Distribution over the next character, taken from the last time step.
        probs = F.softmax(logits[0][-1], 0)
        if x in word_set:
            # The prefix is already a dictionary word: drop the final class
            # (the space, last entry of `alphabet`) so the word must keep growing.
            probs = probs[:-1]
        y = alphabet[probs.multinomial(1).item()]
        x += y
    # Compare without the terminating space so it does not skew the similarity.
    nearest = difflib.get_close_matches(x.strip(), words)
    print(f"{x=} {nearest=}")