In [1]:

import datasets

# Load only the first 1% of the BookCorpus "train" split (split slicing syntax
# of `datasets.load_dataset`) to keep tokenizer training and experiments fast.
small_book_corpus = datasets.load_dataset("bookcorpus", split="train[:1%]")

Found cached dataset bookcorpus (/home/batman/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/eddee3cae1cc263a431aa98207d4d27fd8a73b0a9742f692af0e6c65afa4d75f)


In [2]:
from collections import Counter, defaultdict
from itertools import pairwise
import re
from tqdm import tqdm


class BPETokenizer():
    SPACE = "Ġ"  # use gpt2 like space representation

    def __init__(self, sentences, vocab_size=100):
        self.vocab_size = vocab_size
        self._init_tokenizer(sentences)
        self._create_vocab()

    def _init_tokenizer(self, sentences):
        print("initializing tokenizer")
        texts = " ".join(sentences)
        texts = re.sub(r" ", f" {self.SPACE}", texts)

        self.chars = list(set(texts))
        self.vocab = list(self.chars)

        words = Counter(texts.split(" "))
        splits = [[c for c in word] for word in words.keys()]
        self.corpus = [list(row) for row in zip(words.values(), words.keys(), splits)
                       if len(row[2]) > 1]  # remove single chars

    def _get_most_freq_pair(self):
        pairs = defaultdict(int)
        for cnt, _, split in self.corpus:
            for pair in pairwise(split):
                pairs[pair] += cnt
        return max(pairs.items(), key=lambda x: x[1])

    def _merge_pair(self, pair):
        pair_str = "".join(pair[0])
        for i, row in enumerate(self.corpus):
            if pair_str not in row[1]:
                continue
            split = row[2]
            j = 0
            while j + 1 < len(split):
                if split[j] == pair[0][0] and split[j + 1] == pair[0][1]:
                    split = split[:j] + [pair_str] + split[j + 2:]
                j += 1

            self.corpus[i][2] = split

    def _create_vocab(self):
        print("creating vocabulary")
        with tqdm(total=self.vocab_size) as prog_bar:
            prog_bar.update(len(self.vocab))
            while len(self.vocab) != self.vocab_size:
                pair = self._get_most_freq_pair()
                self._merge_pair(pair)
                self.vocab.append("".join(pair[0]))
                prog_bar.update(1)

    def tokenize(self, text):
        text = text.replace(" ", self.SPACE)
        split = list(text)

        for v in self.vocab[len(self.chars):]:
            i = 0
            if v not in text:
                continue
            while i + 1 < len(split):
                if split[i] + split[i + 1] == v:
                    split = split[:i] + [v] + split[i + 2:]
                i += 1
        return split

    def encode(self, text):
        return [self.vocab.index(c) for c in self.tokenize(text)]

    def decode(self, encoding):
        enc_text = "".join([self.vocab[enc] for enc in encoding]).replace(self.SPACE, " ")
        return enc_text

# Materialize the "text" column as a list of sentences and train the BPE
# tokenizer on it; the second argument is the target vocabulary size.
small_book_texts = small_book_corpus[:]["text"]
tokenizer = BPETokenizer(small_book_texts, 100)


initializing tokenizer
creating vocabulary


100%|██████████| 100/100 [00:02<00:00, 46.89it/s]


In [3]:
# Sanity check: turn a short sentence into its list of vocabulary ids.
tokenizer.encode("she is reading a book")

[5, 65, 82, 5, 23, 80, 3, 2, 81, 67, 79, 45, 45, 38]

In [4]:
import torch
from torch.utils.data import Dataset

# NOTE(review): in the cells visible here these two constants are never
# referenced — BookDataset and the TransformerModel(...) call spell out 32 / 64
# as literals instead. Presumably embedding width and context length in tokens;
# TODO confirm and pass them through rather than repeating the numbers.
EMB_SIZE = 64
CONTEXT_SIZE = 32


class BookDataset(Dataset):
    """Sliding-window next-token dataset over a corpus processed chunk by chunk.

    The corpus is split into words and only ``chunk_size`` words are tokenized
    at a time. Sample ``idx`` is the pair ``(x, y)`` where ``x`` holds ``lags``
    consecutive token ids starting at ``idx`` and ``y`` the single id that
    follows them.

    Args:
        sentences: iterable of strings making up the corpus.
        tokenizer: object exposing ``encode(text) -> list[int]``.
        lags: number of context tokens per sample.
        chunk_size: number of words tokenized per chunk.
    """

    def __init__(self, sentences, tokenizer, lags=32, chunk_size=3000):
        self.chunk_size = chunk_size
        self.chunk_ind = 0
        self.text = " ".join(sentences).split(" ")
        self.lags = lags
        self.tokenizer = tokenizer
        self.load_new_chunk()

    def load_new_chunk(self):
        """Tokenize the next ``chunk_size`` words, wrapping around at the end."""
        # Once the corpus is exhausted start over instead of producing an
        # empty encoding that would make every later lookup fail.
        if self.chunk_ind >= len(self.text):
            self.chunk_ind = 0
        chunk = " ".join(self.text[self.chunk_ind:self.chunk_ind + self.chunk_size])
        self.encoding = self.tokenizer.encode(chunk)
        self.chunk_ind += self.chunk_size

    def __len__(self):
        # A sample needs `lags` context tokens plus one target token, so the
        # last valid start index is len(encoding) - lags - 1.
        return max(0, len(self.encoding) - self.lags)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a ``(lags,)`` and a ``(1,)`` tensor of dtype long.

        Raises:
            IndexError: if ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}"
            )
        x = torch.tensor(self.encoding[idx:idx + self.lags], dtype=torch.long)
        y = torch.tensor([self.encoding[idx + self.lags]], dtype=torch.long)

        return x, y


In [5]:
# Build the training dataset; the constructor tokenizes the first chunk of
# text right away via load_new_chunk().
book_dataset = BookDataset(small_book_texts, tokenizer)


In [6]:
class BookDataLoader:
    """Minimal batch iterator over a chunked dataset.

    One pass yields ``len(dataset) // bs`` non-overlapping batches of ``bs``
    consecutive samples, moved to ``device``, and then asks the dataset to
    load its next chunk.

    Args:
        dataset: object with ``__len__``, ``__getitem__`` returning an
            ``(x, y)`` pair of tensors, and ``load_new_chunk()``.
        bs: batch size (positive integer).
        device: target passed to ``Tensor.to`` for every batch.
    """

    def __init__(self, dataset, bs, device):
        if bs <= 0:
            raise ValueError("bs must be a positive integer")
        self.dataset = dataset
        self.bs = bs
        self.device = device
        self.istep = 0  # index of the first sample of the next batch
        self.chunk_size = len(self)  # batches available in the current chunk

    def __len__(self):
        # Computed on demand: the dataset length changes with every new chunk.
        return len(self.dataset) // self.bs

    def __iter__(self):
        # Every pass starts at the beginning of the current chunk; carrying
        # the offset over from the previous pass would drift past its end.
        self.istep = 0
        self.chunk_size = len(self)
        for _ in range(self.chunk_size):
            try:
                samples = [self.dataset[self.istep + i] for i in range(self.bs)]
            except IndexError:
                # Defensive: a dataset that reports more samples than it can
                # serve ends the pass here instead of crashing mid-training.
                break
            xs, ys = zip(*samples)

            # Advance by a whole batch so consecutive batches do not overlap.
            self.istep += self.bs
            yield torch.stack(xs).to(self.device), torch.stack(ys).to(self.device)
        self.dataset.load_new_chunk()
        self.chunk_size = len(self)

# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
bs = 1024  # samples per batch
train_bookloader = BookDataLoader(book_dataset, bs, device)

In [7]:
import random

# Prompt for text generation: a randomly chosen corpus sentence, cut to its
# first 32 token ids and wrapped in a batch dimension of size 1.
starter_gen = torch.tensor(
    [tokenizer.encode(random.choice(small_book_texts))[:32]],
    dtype=torch.int32,
).to(device)
starter_gen

tensor([[65, 14, 23, 99, 11, 43, 80, 67, 18,  3, 14, 11, 50, 99, 11, 79, 64, 43,
         17,  2, 81, 87, 74, 17,  2, 23, 18, 14, 45, 12,  3, 12]],
       device='cuda:0', dtype=torch.int32)

In [None]:
from torch import nn
from torch.optim import Adam
from transformer import TransformerModel
import wandb

run = wandb.init(name="init test", project="midwrit", reinit=True)

epochs = 1000
log_every = 20  # epochs between loss reports / sample generations

# NOTE(review): 64 and 32 presumably mirror EMB_SIZE and CONTEXT_SIZE, and the
# meaning of 2 and 4 is defined by TransformerModel (not visible here) —
# confirm against its signature before replacing the literals.
model = TransformerModel(64, 2, 4, 32, tokenizer.vocab_size).to(device)
optim = Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

tloss = 0.0   # loss summed over the steps since the last report
n_steps = 0   # optimizer steps since the last report
with run:
    for epoch in range(epochs):
        for x, y in train_bookloader:
            optim.zero_grad()

            # Only the prediction at the last position is scored against the
            # next token.
            pred = model(x)[:, -1, :]

            # y is (batch, 1); drop only that trailing axis so a batch of
            # size 1 keeps its batch dimension.
            loss = loss_fn(pred, y.squeeze(-1))
            loss.backward()
            tloss += loss.item()
            n_steps += 1

            optim.step()

        if epoch % log_every == 0 and epoch > 0:
            # Report the mean loss per step: the first window spans one more
            # epoch than the following ones, so raw sums are not comparable.
            mean_loss = tloss / max(n_steps, 1)
            print(f"[{epoch}/{epochs}]: {mean_loss}")
            run.log({"train_loss": mean_loss})
            tloss = 0.0
            n_steps = 0

            # Sampling needs neither gradients nor train-mode layers.
            model.eval()
            with torch.no_grad():
                print(tokenizer.decode(model.generate_tokens(starter_gen, 200, 2)))
            model.train()


In [16]:
# Sample from the trained model, seeded with the starter prompt, and decode the
# ids back to text. The meaning of the 200 and 2 arguments is defined by
# TransformerModel.generate_tokens (not visible here) — presumably a token
# budget and a sampling parameter; TODO confirm.
print(tokenizer.decode(model.generate_tokens(starter_gen, 200, 2)))

ou intiong . she gains ? she post . you ? '
