In [None]:
from datasets import load_dataset

def load_simple_wiki_dataset():
    """Load the ``rahular/simple-wikipedia`` dataset via ``datasets.load_dataset``.

    Returns:
        The ``text`` column of the ``train`` split.
    """
    train_split = load_dataset("rahular/simple-wikipedia")["train"]
    return train_split["text"]

# Use only the first 700,000 texts of the training split.
data = load_simple_wiki_dataset()[:700_000]
print(len(data))

In [None]:
def find_characters_in_data(data):
    """Collect every distinct character that appears in ``data`` after lowercasing.

    Args:
        data: Iterable of strings.

    Returns:
        A sorted list of unique (lowercase) characters.
    """
    seen = set()
    for text in data:
        seen |= set(text.lower())
    return sorted(seen)

# Build the model vocabulary: every lowercased character seen in the corpus whose
# code point is below 123 (i.e. up to and including 'z'), minus a few symbols that
# are deliberately excluded. Filtering against a set (instead of list.remove) does not
# crash when an excluded symbol happens not to occur in the loaded data.
EXCLUDED_CHARACTERS = {'\\', '@', '#', ';', '`', '^'}
characters = [
    char
    for char in find_characters_in_data(data)
    if ord(char) < 123 and char not in EXCLUDED_CHARACTERS
]
print(characters)
print(len(characters))

In [None]:
import torch

class CharTokenizer:
    """Character-level tokenizer over a fixed character vocabulary.

    Ids 0-2 are reserved for special tokens (padding, beginning-of-sequence,
    unknown); the character at position ``k`` of ``characters`` maps to id ``k + 3``.
    Input text is lowercased before encoding.
    """

    # Number of reserved special-token ids that precede the character ids.
    _N_SPECIAL = 3

    def __init__(self, characters):
        """
        Args:
            characters: Ordered list of single-character strings forming the vocabulary.
        """
        self.characters = characters
        self.pad_token = 0
        self.bos_token = 1
        self.unk_token = 2
        self.vocab_size = len(characters) + self._N_SPECIAL
        # O(1) char -> id lookup instead of scanning the list per character.
        # setdefault keeps the first occurrence, matching list.index() semantics.
        self._char_to_id = {}
        for position, char in enumerate(characters):
            self._char_to_id.setdefault(char, position + self._N_SPECIAL)

    def encode(self, sentence, add_bos_token=False):
        """Convert ``sentence`` to a 1-D LongTensor of token ids.

        Args:
            sentence: Text to encode; it is lowercased first.
            add_bos_token: If True, prepend ``bos_token``.

        Returns:
            ``torch.LongTensor`` of ids; characters outside the vocabulary map to ``unk_token``.
        """
        encoded = [self.bos_token] if add_bos_token else []
        encoded.extend(
            self._char_to_id.get(char, self.unk_token) for char in sentence.lower()
        )
        return torch.tensor(encoded, dtype=torch.long)

    def decode(self, encoded):
        """Convert a sequence of token ids back to text.

        Special-token ids (pad, bos, unk) are skipped.

        Args:
            encoded: 1-D tensor or iterable of integer ids.

        Returns:
            The decoded string.
        """
        ids = encoded.tolist() if isinstance(encoded, torch.Tensor) else encoded
        return "".join(
            self.characters[i - self._N_SPECIAL] for i in ids if i >= self._N_SPECIAL
        )

tokenizer = CharTokenizer(characters=characters)

In [None]:
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):
    """Map-style dataset that tokenizes each raw text lazily when it is accessed.

    Args:
        data: Indexable sequence of strings.
        tokenizer: Object whose ``encode(text)`` returns a tensor of token ids.
    """

    def __init__(self, data, tokenizer):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.tokenizer.encode(self.data[index])
dataset = CharDataset(data=data, tokenizer=tokenizer)
dataset[0]

In [None]:
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence

class CharDataModule(pl.LightningDataModule):
    """Lightning data module that splits raw texts into train/val/test CharDatasets.

    The split is contiguous (no shuffling of the source order): first 80% train,
    next 10% validation, the remainder test. Batches are padded to a common length
    with the tokenizer's pad token.
    """

    def __init__(self, data, tokenizer, batch_size=128):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size

        train_texts, val_texts, test_texts = self.split(data)
        self.train_dataset = CharDataset(train_texts, tokenizer)
        self.val_dataset = CharDataset(val_texts, tokenizer)
        self.test_dataset = CharDataset(test_texts, tokenizer)

    def collate_fn(self, samples):
        """Stack variable-length 1-D id tensors into a (batch, max_len) padded tensor."""
        return pad_sequence(
            samples,
            batch_first=True,
            padding_value=self.tokenizer.pad_token,
        )

    def split(self, data):
        """Slice ``data`` into contiguous 80% / 10% / remainder chunks."""
        train_end = int(len(data) * 0.8)
        val_end = train_end + int(len(data) * 0.1)
        return data[:train_end], data[train_end:val_end], data[val_end:]

    def common_dataloader(self, split):
        """Build a DataLoader for ``split`` ('train', 'val' or 'test'); only 'train' is shuffled."""
        is_train = split == 'train'
        return DataLoader(
            getattr(self, f'{split}_dataset'),
            batch_size=self.batch_size,
            shuffle=is_train,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self.common_dataloader('train')

    def val_dataloader(self):
        return self.common_dataloader('val')

    def test_dataloader(self):
        return self.common_dataloader('test')

datamodule = CharDataModule(data=data, tokenizer=tokenizer)

In [None]:
import torch.nn as nn
import torch.optim as optim

class Generator(pl.LightningModule):
    """Character-level LSTM language model trained with teacher forcing.

    Architecture: token embedding -> single-layer LSTM -> linear projection to
    vocabulary logits. Target positions equal to the tokenizer's pad token are
    ignored by the loss.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, tokenizer, lr=1e-4):
        """
        Args:
            vocab_size: Number of token ids (including special tokens).
            embedding_dim: Size of each token embedding.
            hidden_size: Size of the LSTM hidden state.
            tokenizer: Provides ``pad_token``, ``bos_token``, ``encode`` and ``decode``.
            lr: Adam learning rate.
        """
        super().__init__()
        self.emb_layer = nn.Embedding(vocab_size, embedding_dim)
        self.rnn_layer = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.out_layer = nn.Linear(hidden_size, vocab_size)
        self.tokenizer = tokenizer
        self.lr = lr

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token)

    def forward(self, encoded, hidden=None):
        """Return ``(logits, hidden)`` for a sequence of token ids.

        ``encoded`` may be a batched ``(batch, seq)`` tensor (the LSTM is batch_first)
        or a single unbatched ``(seq,)`` sequence, as used by ``generate``.
        ``hidden`` is the LSTM state to continue from, or ``None`` to start fresh.
        """
        emb = self.emb_layer(encoded)
        rnn_out, hidden = self.rnn_layer(emb, hidden)
        return self.out_layer(rnn_out), hidden

    def prepend_bos(self, batch):
        """Shift ``batch`` right by one position, inserting BOS in column 0.

        The result has the same shape as ``batch`` (its last column is dropped),
        so position t of the input is used to predict position t of ``batch``.
        """
        bos_tokens = torch.full(
            (batch.shape[0], 1),
            self.tokenizer.bos_token,
            dtype=batch.dtype,
            device=batch.device,
        )
        return torch.cat((bos_tokens, batch), dim=1)[:, :-1]

    def _compute_loss(self, batch):
        """Next-token cross-entropy of ``batch`` given its BOS-shifted version as input."""
        out, _ = self(self.prepend_bos(batch))
        # CrossEntropyLoss expects class logits on dim 1: (batch, vocab, seq).
        return self.loss_fn(out.transpose(2, 1), batch)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('test_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    @torch.no_grad()
    def generate(self, prompt, n_tokens=256):
        """Sample ``n_tokens`` tokens continuing ``prompt``, one character at a time.

        Runs without gradient tracking and in eval mode (the previous mode is
        restored afterwards). The prompt is moved to the model's device.

        Args:
            prompt: Text to condition on; the tokenizer lowercases it and prepends BOS.
            n_tokens: Exact number of tokens to sample.

        Returns:
            The decoded sampled tokens only (the prompt is not included); special
            tokens are dropped by the tokenizer's ``decode``.
        """
        was_training = self.training
        self.eval()
        try:
            encoded_prompt = self.tokenizer.encode(prompt, add_bos_token=True).to(self.device)
            out, hidden = self(encoded_prompt)
            logits = out[-1:]
            generated_tokens = []
            for step in range(n_tokens):
                next_token = torch.distributions.Categorical(logits=logits).sample()
                generated_tokens.append(next_token)
                # Skip the forward pass after the final sample; its output would be unused.
                if step + 1 < n_tokens:
                    logits, hidden = self(next_token, hidden)
        finally:
            self.train(was_training)
        return self.tokenizer.decode([int(token) for token in generated_tokens])
        
generator = Generator(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=128,
    hidden_size=512,
    tokenizer=tokenizer,
)

In [None]:
# Train with Lightning's default Trainer settings. To resume a previous run,
# pass ckpt_path="<path to .ckpt>" to trainer.fit.
trainer = pl.Trainer()
trainer.fit(model=generator, datamodule=datamodule)

In [None]:
import os
# List checkpoint files in a specific (hard-coded) Lightning run directory.
os.listdir(os.path.join('lightning_logs', 'version_13', 'checkpoints'))

In [None]:
# Reload trained weights. Generator does not call save_hyperparameters, so every
# constructor argument has to be passed again here.
checkpoint_path = "lightning_logs/version_13/checkpoints/epoch=4-step=21875.ckpt"
generator = Generator.load_from_checkpoint(
    checkpoint_path,
    tokenizer=tokenizer,
    vocab_size=tokenizer.vocab_size,
    embedding_dim=128,
    hidden_size=512,
).to('cpu')

prompt = 'Que'
output = generator.generate(prompt)
print(output)