In [1]:
import os
import re
import json
from pathlib import Path
from attrs import define, Factory, asdict
from datetime import datetime

import numpy as np

from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim import lr_scheduler as torch_lr_scheduler
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import PreTrainedTokenizerFast

from tqdm import tqdm

from typing import (
    Union,
    Iterable,
    Optional,
    Sequence,
    Tuple,
    Any,
)

## Constants

In [2]:
# --- Filesystem layout (relative to the notebook's working directory) ---

# Corpus: the full list of titles and its train/validation split
DATA_PATH = Path(".", "input")
ALL_TERMS = DATA_PATH.joinpath("all_titles.txt")
TRAIN_TERMS = DATA_PATH.joinpath("train_titles.txt")
VAL_TERMS = DATA_PATH.joinpath("validation_titles.txt")

# Directory of the pre-trained tokenizer
PRETRAINED_TOKENIZER = Path(".", "model", "tokenizer", "pretrained_tokenizer")

# Generator artifacts: config, best weights and resumable training state.
# NOTE: "WEIGTHS" is misspelled, but other cells reference this exact name.
MODEL_PATH = Path(".", "model", "generator")
MODEL_CONFIG = MODEL_PATH.joinpath("config.json")
MODEL_WEIGTHS = MODEL_PATH.joinpath("title_generator.sav")
MODEL_TRAINING_SNAPSHOT = MODEL_PATH.joinpath("snapshot", "generator_training_snapshot.sav")
BEST_METRICS = MODEL_PATH.joinpath("snapshot", "best_metrics.json")
TRAINING_LOG = MODEL_PATH.joinpath("snapshot", "training_log.txt")

# --- TensorBoard: log directory and scalar tag components ---

TENSORBOARD_PATH = Path(".", "img", "tensorboard")
TRAIN = "Train"
VALIDATION = "Validation"
LOSS = "Loss"

# --- Model hyperparameters ---

BATCH_SIZE = 32
MAX_SEQ_LEN = 24  # see query length distribution histogram
VOCAB_SIZE = 2500  # see the pre-trained tokenizer parameters
HEAD_NUM = 6
EMBEDDING_DIM = HEAD_NUM * 128  # 128 dimensions per attention head
LAYER_NUM = 6
DROPOUT = 0.2

# --- Training schedule ---

EPOCH_NUM = 1000
START_EPOCH = 0
LOG_INTERVAL = 10  # used as the default `log_interval` of the trainer's fit()

## Config

In [3]:
@define(kw_only=True)
class Config:
    """Hyperparameters and resource locations of the title generator.

    Every field defaults to the matching notebook-level constant, so
    ``Config()`` reproduces the default setup. The three path-like fields
    (``tokenizer``, ``train_data``, ``validation_data``) are stored as strings
    in JSON and restored as ``Path`` objects on load.
    """
    attention_head_num: int = HEAD_NUM
    embedding_dim: int = EMBEDDING_DIM
    decoder_layer_num: int = LAYER_NUM
    max_sequence_length: int = MAX_SEQ_LEN
    batch_size: int = BATCH_SIZE
    dropout: float = DROPOUT  # a probability (0.2 by default), hence float
    vocab_size: int = VOCAB_SIZE
    tokenizer: Union[str, os.PathLike] = PRETRAINED_TOKENIZER
    train_data: Union[str, os.PathLike] = TRAIN_TERMS
    validation_data: Union[str, os.PathLike] = VAL_TERMS

    @classmethod
    def from_json(cls, config_json: Union[str, os.PathLike]):
        """Build a config from a JSON file produced by ``to_json``.

        Args:
            config_json: Path to the JSON file.

        Returns:
            A ``Config`` whose path-like fields are ``Path`` objects. Keys
            missing from the file fall back to the field defaults.

        Raises:
            FileNotFoundError: If ``config_json`` does not exist.
        """
        if not os.path.exists(config_json):
            raise FileNotFoundError(f"Couldn't find {config_json}")

        # Read as UTF-8: to_json writes with ensure_ascii=False, so the file
        # may contain raw non-ASCII characters.
        with open(config_json, "r", encoding="utf-8") as infile:
            config = json.load(infile)

        for field_name in ("tokenizer", "train_data", "validation_data"):
            if field_name in config:
                config[field_name] = Path(config[field_name])
        return cls(**config)

    def to_json(self, config_json: Union[str, os.PathLike]):
        """Serialize the config to a JSON file.

        Missing parent directories of ``config_json`` are created first.

        Args:
            config_json: Destination path of the JSON file.
        """
        config_dict = asdict(self)
        # Path objects are not JSON-serializable; store them as strings.
        for field_name in ("tokenizer", "train_data", "validation_data"):
            config_dict[field_name] = str(getattr(self, field_name))

        target = Path(config_json)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as outfile:
            json.dump(config_dict, outfile, ensure_ascii=False, indent=4)

## Data

#### Preprocessing

In [4]:
def load_data(data_path: Union[str, os.PathLike],
              encoding: Optional[str] = "utf-8") -> Iterable[str]:
    """Lazily read a text file, yielding each line stripped of whitespace.

    The existence check runs when the function is called, not when the
    result is first iterated, so a wrong path is reported immediately.

    Args:
        data_path: Path to the text file.
        encoding: Encoding used to open the file.

    Returns:
        An iterator over the stripped lines of the file.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Couldn't find {data_path}")

    def _iter_lines() -> Iterable[str]:
        # The file stays open only while the iterator is being consumed.
        with open(data_path, "r", encoding=encoding) as infile:
            for line in infile:
                yield line.strip()

    return _iter_lines()

In [5]:
def dump_data_to_txt(data: Sequence[str],
                     data_path: Union[str, os.PathLike],
                     encoding: Optional[str] = "utf-8"):
    """Write the items of ``data`` to a text file, one item per line.

    An existing file at ``data_path`` is overwritten.

    Args:
        data: Items to write; each is formatted with an f-string and
            followed by a newline.
        data_path: Destination file.
        encoding: Encoding used to open the file.
    """
    with open(data_path, "w", encoding=encoding) as outfile:
        outfile.writelines(f"{entry}\n" for entry in data)

In [6]:
# Reuse an existing split when both files are present; otherwise split the
# full corpus 80/20 with a fixed seed and persist both parts.
dataset_is_split = all(
    os.path.exists(split_file) for split_file in (TRAIN_TERMS, VAL_TERMS)
)

if not dataset_is_split:
    data = list(load_data(ALL_TERMS))
    train_data, val_data = train_test_split(
        data,
        test_size=0.2,
        random_state=42,
    )
    dump_data_to_txt(train_data, TRAIN_TERMS)
    dump_data_to_txt(val_data, VAL_TERMS)

#### Dataset

In [7]:
class SampleSet:
    """Tokenized next-token-prediction samples plus a batching loader.

    Attributes:
        set: ``TensorDataset`` of (features, targets) pairs.
        loader: Shuffling ``DataLoader`` over ``set`` that drops the last
            incomplete batch.
    """

    def __init__(self,
                 data: Sequence[str],
                 tokenizer: PreTrainedTokenizerFast,
                 max_seq_len: int, batch_size: int,
                 use_cuda: Optional[bool] = True,
                 gpu_num: Optional[int] = 1):
        inputs, labels = self.make_samples(data, tokenizer, max_seq_len)
        loader_options = self.make_kwargs(use_cuda, gpu_num)

        self.set = TensorDataset(inputs, labels)
        self.loader = DataLoader(
            self.set,
            batch_size=batch_size,
            drop_last=True,
            shuffle=True,
            **loader_options,
        )

    @staticmethod
    def make_samples(data: Sequence[str], tokenizer: PreTrainedTokenizerFast,
                     max_seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode every text and derive shifted feature/target sequences.

        Each text is padded/truncated to ``max_seq_len + 1`` ids. Features
        are all ids but the last one; targets are all ids but the first one,
        i.e. the features shifted one position to the left.

        Returns:
            A ``(features, targets)`` pair of stacked tensors.
        """
        encoded = [
            tokenizer.encode(
                text,
                add_special_tokens=True,
                padding="max_length",
                truncation=True,
                max_length=max_seq_len + 1,
                return_tensors="pt",
            ).squeeze(dim=0)
            for text in data
        ]
        features = torch.stack([ids[:-1] for ids in encoded])
        targets = torch.stack([ids[1:] for ids in encoded])
        return features, targets

    @staticmethod
    def make_kwargs(use_cuda: Optional[bool] = True, gpu_num: Optional[int] = 1):
        """Return extra ``DataLoader`` options for the requested device setup.

        With ``use_cuda`` enabled, four workers per GPU and pinned memory are
        requested; otherwise no extra options are returned.
        """
        print(f"Using CUDA: {use_cuda}")
        if not use_cuda:
            return {}
        return {"num_workers": 4 * gpu_num, "pin_memory": True}

In [8]:
@define(kw_only=True)
class Dataset:
    """Pair of train and validation sample sets built from one config."""
    train: SampleSet
    validation: SampleSet

    @classmethod
    def from_config(cls, config: Config):
        """Load the tokenizer and both text files named in ``config``.

        Both sample sets share the tokenizer, the maximum sequence length
        and the batch size taken from ``config``.
        """
        tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer)
        train_titles = list(load_data(config.train_data))
        validation_titles = list(load_data(config.validation_data))

        # Options that are identical for both splits.
        shared_options = {
            "tokenizer": tokenizer,
            "max_seq_len": config.max_sequence_length,
            "batch_size": config.batch_size,
        }
        train_set = SampleSet(data=train_titles, **shared_options)
        validation_set = SampleSet(data=validation_titles, **shared_options)
        return cls(train=train_set, validation=validation_set)

In [9]:
# First run: create the default config and persist it for later sessions.
# Later runs: the stored JSON is the source of truth.
if not os.path.exists(MODEL_CONFIG):
    config = Config()
    config.to_json(MODEL_CONFIG)
else:
    config = Config.from_json(MODEL_CONFIG)

## Model

#### Model Architecture

In [10]:
class AttentionHead(nn.Module):
    """One causal self-attention head.

    Projects the input to queries, keys and values of size ``head_dim`` and
    lets every position attend only to itself and to earlier positions.
    """

    def __init__(self, embedding_dim: int, head_dim: int, max_seq_len: int, dropout: float):
        super().__init__()
        # Bias-free query / key / value projections
        self.query = nn.Linear(embedding_dim, head_dim, bias=False)
        self.key = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value = nn.Linear(embedding_dim, head_dim, bias=False)

        # Lower-triangular matrix used to mask out context from the right.
        # Registered as a buffer so it follows the module across devices.
        causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer("mask", causal_mask)

        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, _input):
        """Attend over ``_input`` of shape (batch, seq_len, embedding_dim).

        Returns:
            Tensor of shape (batch, seq_len, head_dim).
        """
        _, seq_len, embedding_dim = _input.shape
        queries = self.query(_input)
        keys = self.key(_input)
        values = self.value(_input)

        # Scaled similarity of every query with every key:
        # batch_size x seq_len x seq_len.
        # NOTE(review): the scale is derived from the input embedding width,
        # not from head_dim. Saved weights were trained with this factor, so
        # confirm before changing it.
        weights = queries @ keys.transpose(-2, -1) * embedding_dim ** -0.5

        # Positions to the right of the current one get -inf, which turns
        # into a zero weight after the softmax.
        future_positions = self.mask[:seq_len, :seq_len] == 0
        weights = weights.masked_fill(future_positions, float("-inf"))
        weights = self.dropout(F.softmax(weights, dim=-1))

        # Weighted sum of the values --> batch_size x seq_len x head_dim
        return weights @ values

In [11]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side and merged by a linear layer."""

    def __init__(self, embedding_dim: int, head_dim: int, head_num: int,
                 max_seq_len: int, dropout: float):
        super().__init__()
        heads = [
            AttentionHead(embedding_dim=embedding_dim, head_dim=head_dim,
                          max_seq_len=max_seq_len, dropout=dropout)
            for _ in range(head_num)
        ]
        self.attention_heads = nn.ModuleList(heads)
        # Maps the concatenated head outputs back to the embedding width
        self.output_layer = nn.Linear(head_dim * head_num, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, _input):
        """Apply every head to ``_input``, concatenate along the last axis,
        project the result and apply dropout."""
        head_outputs = [head(_input) for head in self.attention_heads]
        merged = torch.cat(head_outputs, dim=-1)
        projected = self.output_layer(merged)
        return self.dropout(projected)

In [12]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, ReLU, project
    back to the embedding width, then dropout."""

    def __init__(self, embedding_dim: int, dropout: float):
        super().__init__()
        hidden_dim = embedding_dim * 4
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, _input):
        """Run ``_input`` through the MLP; the output has the input's shape."""
        _output = self.feed_forward(_input)
        return _output

In [13]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: self-attention and a feed-forward network,
    each preceded by layer normalization and wrapped in a residual connection."""

    def __init__(self, head_num: int, max_seq_len: int, embedding_dim: int, dropout: float):
        super().__init__()
        # The embedding width is split evenly between the heads
        self.self_attention = MultiHeadAttention(
            head_num=head_num,
            head_dim=embedding_dim // head_num,
            embedding_dim=embedding_dim,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )
        self.feed_forward_layer = FeedForward(embedding_dim=embedding_dim, dropout=dropout)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, _input):
        """Return ``_input`` refined by the attention and feed-forward sub-layers."""
        attended = self.self_attention(self.layer_norm1(_input))
        hidden = _input + attended
        transformed = self.feed_forward_layer(self.layer_norm2(hidden))
        return hidden + transformed

In [14]:
class Transformer(nn.Module):
    """Decoder-only transformer language model.

    Token and learned positional embeddings are summed, passed through a
    stack of ``TransformerBlock`` layers and projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids the model can embed and predict.
        embedding_dim: Width of the embeddings and of every block.
        max_seq_len: Longest sequence the positional table can address.
        head_num: Number of attention heads per block.
        decoder_layer_num: Number of stacked blocks.
        dropout: Dropout probability handed to every block.
        use_cuda: Stored as ``self.use_cuda`` for code that reads the flag.
            ``forward`` no longer consults it: position indices are created
            on the device of the input tensor.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, max_seq_len: int,
                 head_num: int, decoder_layer_num: int, dropout: float,
                 use_cuda: Optional[bool] = True):
        super(Transformer, self).__init__()
        # Model parameters
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.head_num = head_num
        self.decoder_layer_num = decoder_layer_num
        self.dropout = dropout
        self.use_cuda = use_cuda

        # Token and positional embedding lookup tables
        self.token_embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_dim)

        # Decoder transformer blocks
        self.decoder = nn.Sequential(*[
                TransformerBlock(
                    head_num=self.head_num,
                    max_seq_len=self.max_seq_len,
                    embedding_dim=self.embedding_dim,
                    dropout=self.dropout
                )
                for _ in range(self.decoder_layer_num)
        ])

        # LM head
        self.lm_head = nn.Linear(self.embedding_dim, self.vocab_size)

    @classmethod
    def from_config(cls, config: "Config"):
        """Build a model from the hyperparameters stored in ``config``.

        ``use_cuda`` is not part of the config and keeps its default value.
        """
        return cls(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            max_seq_len=config.max_sequence_length,
            head_num=config.attention_head_num,
            decoder_layer_num=config.decoder_layer_num,
            dropout=config.dropout
        )

    def forward(self, indexed_seq):
        """Compute next-token logits.

        Args:
            indexed_seq: Token ids of shape (batch, seq_len); ``seq_len``
                must not exceed ``max_seq_len`` because positions are looked
                up in a table of that size.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        batch_size, seq_len = indexed_seq.shape

        token_embeddings = self.token_embedding(indexed_seq)
        # Create the position ids next to the input instead of trusting the
        # `use_cuda` flag: the flag defaults to True, which used to crash a
        # model that actually lives on the CPU.
        positions = torch.arange(seq_len, device=indexed_seq.device)
        positional_embeddings = self.position_embedding(positions)

        # (batch, seq_len, dim) + (seq_len, dim) broadcasts over the batch
        x = token_embeddings + positional_embeddings
        x = self.decoder(x)
        logits = self.lm_head(x)
        return logits

#### Model Training

In [None]:
@define(kw_only=True)
class TransformerTrainer:
    """Training harness for the title generator.

    Holds the dataset, model, optimizer, learning-rate scheduler and
    TensorBoard writer. It runs the epoch loop, validates periodically, and
    persists training snapshots, the best weights and the best metrics.
    """
    dataset: Dataset
    model: Transformer
    optimizer: AdamW
    lr_scheduler: torch_lr_scheduler.CosineAnnealingLR
    tensorboard_writer: SummaryWriter
    criterion: nn.CrossEntropyLoss = Factory(
        lambda: nn.CrossEntropyLoss()
    )
    # Lowest losses seen so far; "validation_loss" decides when new best
    # weights are written.
    best_metrics: dict = Factory(
        lambda: {"train_loss": float("inf"), "validation_loss": float("inf")}
    )
    device: torch.device = Factory(
        lambda: torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )

    @classmethod
    def init_trainer(cls, config: Config,
                     device: torch.device,
                     training_snapshot: Optional[Union[str, os.PathLike]] = None,
                     metrics: Optional[Union[str, os.PathLike]] = None):
        """Assemble a trainer, optionally resuming from a saved snapshot.

        Args:
            config: Model and data configuration.
            device: Device the model is trained on.
            training_snapshot: Optional snapshot written by ``validate``
                (model, optimizer and scheduler state plus the last epoch).
            metrics: Optional JSON file with previously saved best metrics.

        Returns:
            A ready-to-use ``TransformerTrainer``.

        Raises:
            FileNotFoundError: If ``training_snapshot`` or ``metrics`` is
                given but does not exist.
        """
        # Load state checkpoint if present
        snapshot = None
        if training_snapshot is not None:
            if not os.path.exists(training_snapshot):
                raise FileNotFoundError(f"File {training_snapshot} doesn't exist :/ Check your snapshots!")
            # map_location lets a snapshot saved on a GPU be resumed on a
            # CPU-only machine (and the other way round).
            snapshot = torch.load(training_snapshot, map_location=device)

        # Dataset
        dataset = Dataset.from_config(config)
        # Model
        model = Transformer.from_config(config)
        # Keep the model's own device flag in sync with the training device,
        # as TheGirlPT.from_pretrained does.
        model.use_cuda = device.type == "cuda"
        if snapshot is not None:
            model.load_state_dict(snapshot["model"])
        model.to(device)
        # Optimizer
        optimizer = AdamW(model.parameters(), lr=3e-4)
        if snapshot is not None:
            optimizer.load_state_dict(snapshot["optimizer"])
        # Learning rate scheduler (stepped once per batch, hence T_max in steps)
        last_epoch = 0 if snapshot is None else snapshot["last_epoch"]
        lr_scheduler = torch_lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=(EPOCH_NUM - last_epoch) * len(dataset.train.loader),
                                                            eta_min=1e-5)
        if snapshot is not None:
            lr_scheduler.load_state_dict(snapshot["lr_scheduler"])

        # Load the previous best metrics if present
        best_metrics = None
        if metrics is not None:
            if not os.path.exists(metrics):
                raise FileNotFoundError(f"File {metrics} doesn't exist :/ Do you really save the metrics, huh?")
            with open(metrics, "r", encoding="utf-8") as infile:
                best_metrics = json.load(infile)

        # Tensorboard summary writer
        tensorboard_writer = SummaryWriter(TENSORBOARD_PATH, purge_step=last_epoch + 1, flush_secs=5)

        kwargs = {
            "dataset": dataset,
            "model": model,
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "tensorboard_writer": tensorboard_writer,
            "device": device
        }
        if best_metrics is not None:
            kwargs["best_metrics"] = best_metrics

        return cls(**kwargs)

    def fit(self,
            start_epoch: Optional[int] = START_EPOCH,
            epoch_num: Optional[int] = EPOCH_NUM,
            log_interval: Optional[int] = LOG_INTERVAL,
            best_model_dump: Optional[Union[str, os.PathLike]] = MODEL_WEIGTHS,
            model_snapshot_dump: Optional[Union[str, os.PathLike]] = MODEL_TRAINING_SNAPSHOT,
            training_log: Optional[Union[str, os.PathLike]] = TRAINING_LOG,
            best_metrics_dump: Optional[Union[str, os.PathLike]] = BEST_METRICS):
        """Train for epochs ``start_epoch`` .. ``epoch_num - 1``.

        Validation, which also writes the snapshot, runs after the first
        epoch and after every ``log_interval``-th epoch.

        Args:
            start_epoch: Zero-based index of the first epoch to run.
            epoch_num: Exclusive upper bound of the epoch index.
            log_interval: Validate every this many epochs.
            best_model_dump: File receiving the best model weights.
            model_snapshot_dump: File receiving the resumable snapshot.
            training_log: Text file the per-epoch train loss is appended to.
            best_metrics_dump: JSON file receiving the best metrics.
        """
        # Create output folders up front, so a missing directory is reported
        # before any training time has been spent.
        for target in (best_model_dump, model_snapshot_dump, training_log, best_metrics_dump):
            if target is not None:
                self._ensure_parent_dir(target)

        # Mark the processing of one batch as a completed step
        num_steps = (epoch_num - start_epoch) * len(self.dataset.train.loader)

        # The context manager closes the bar even if training is interrupted
        with tqdm(range(num_steps), position=0, leave=True) as progress_bar:
            for epoch in range(start_epoch, epoch_num, 1):
                self.train_epoch(epoch, training_log, progress_bar)

                if epoch == 0 or (epoch + 1) % log_interval == 0:
                    self.validate(epoch, best_model_dump, model_snapshot_dump, best_metrics_dump)

        self.tensorboard_writer.flush()

    def _batch_loss(self, _input: torch.Tensor, _target: torch.Tensor) -> torch.Tensor:
        """Move one batch to the training device and return its loss."""
        _input, _target = _input.to(self.device), _target.to(self.device)
        predicted = self.model(_input)

        # Flatten batch and sequence axes: the criterion receives one row of
        # logits and one target id per token position.
        batch_size, seq_len, vocab_size = predicted.shape
        predicted = torch.reshape(predicted, (batch_size * seq_len, vocab_size))
        _target = torch.reshape(_target, (batch_size * seq_len,))
        return self.criterion(predicted, _target)

    def train_epoch(self, epoch: int,
                    training_log: Union[str, os.PathLike], progress_bar: tqdm):
        """Run one pass over the training loader and log the mean loss.

        The optimizer and the scheduler are stepped after every batch. The
        mean epoch loss is written to TensorBoard and appended to
        ``training_log``, and ``best_metrics["train_loss"]`` keeps the
        minimum seen so far.
        """
        self.model.train()

        epoch_loss = []
        for _input, _target in self.dataset.train.loader:
            loss = self._batch_loss(_input, _target)
            epoch_loss.append(loss.item())

            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()

            progress_bar.update(1)

        mean_epoch_loss = sum(epoch_loss) / len(epoch_loss)
        self.best_metrics["train_loss"] = min(mean_epoch_loss, self.best_metrics["train_loss"])
        self.tensorboard_writer.add_scalar(f"{LOSS}/{TRAIN}", mean_epoch_loss, epoch + 1)
        self._ensure_parent_dir(training_log)
        with open(training_log, "a", encoding="utf-8") as outfile:
            outfile.write(f"Epoch #{epoch + 1}\nMean epoch loss: {mean_epoch_loss}\n\n")

    def validate(self, epoch: int, best_model_dump: Union[str, os.PathLike],
                 model_snapshot_dump: Union[str, os.PathLike], best_metrics_dump: Union[str, os.PathLike]):
        """Evaluate on the validation loader and persist training state.

        Always writes the resumable snapshot and the best-metrics JSON. The
        model weights are written to ``best_model_dump`` only when the mean
        validation loss improves on the best value so far.
        """
        self.model.eval()

        with torch.no_grad():
            val_loss = [
                self._batch_loss(_input, _target).item()
                for _input, _target in self.dataset.validation.loader
            ]

        mean_val_loss = sum(val_loss) / len(val_loss)
        self.tensorboard_writer.add_scalar(f"{LOSS}/{VALIDATION}", mean_val_loss, epoch + 1)
        # Save model training snapshot
        snapshot = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "last_epoch": epoch + 1
        }
        self.atomic_write(snapshot, model_snapshot_dump)
        # Save new best weights
        if mean_val_loss < self.best_metrics["validation_loss"]:
            self.atomic_write(self.model.state_dict(), best_model_dump)
            self.best_metrics["validation_loss"] = mean_val_loss
            print(f"Epoch #{epoch + 1}\n*** New best validation loss: {mean_val_loss}!\n")
        self.atomic_write(self.best_metrics, best_metrics_dump, _format="json")

    @staticmethod
    def _ensure_parent_dir(_file: Union[str, os.PathLike]):
        """Create the parent directory of ``_file`` if it is missing."""
        Path(_file).parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def atomic_write(_object: Any, _file: Union[str, os.PathLike],
                     _format: Optional[str] = "torch"):
        """Write ``_object`` to ``_file`` via a temporary sibling file.

        The data is written to ``<stem>.tmp.<ext>`` next to the target,
        flushed and fsync-ed, then moved over the target with ``os.replace``
        so an interrupted write never leaves a truncated target behind.

        Args:
            _object: Object to serialize.
            _file: Destination path; its parent directory is created if
                missing.
            _format: ``"torch"`` (``torch.save``) or ``"json"``.

        Raises:
            ValueError: If ``_format`` is neither ``"torch"`` nor ``"json"``.
        """
        if _format not in ("torch", "json"):
            # Previously an unknown format replaced the target with an empty file
            raise ValueError(f"Unsupported format: {_format}")

        # Accept any path-like object, not only str and Path
        _file = Path(_file)
        TransformerTrainer._ensure_parent_dir(_file)
        name_parts = _file.name.split(".")
        tmp_file = _file.parent/f"{name_parts[0]}.tmp.{name_parts[-1]}"

        if _format == "json":
            with open(tmp_file, "w", encoding="utf-8") as outfile:
                json.dump(_object, outfile, ensure_ascii=False, indent=4)
                outfile.flush()
                os.fsync(outfile.fileno())
        else:
            with open(tmp_file, "wb") as outfile:
                torch.save(_object, outfile)
                outfile.flush()
                os.fsync(outfile.fileno())
        os.replace(tmp_file, _file)

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# No snapshot or metrics file is passed, so training starts from scratch.
generator = TransformerTrainer.init_trainer(
    config,
    device,
)

In [None]:
# Run training with the default arguments declared on TransformerTrainer.fit
# (epoch range, validation interval and output file locations).
generator.fit()

#### Generate Samples

In [15]:
@define(kw_only=True)
class TheGirlPT:
    """Inference wrapper: a tokenizer plus a trained ``Transformer`` that
    continues a seed phrase token by token."""
    tokenizer: PreTrainedTokenizerFast
    model: Transformer

    @classmethod
    def from_pretrained(cls,
                        pretrained_tokenizer: Union[str, os.PathLike],
                        model_config: Union[str, os.PathLike],
                        pretrained_model: Union[str, os.PathLike],
                        use_cuda: Optional[bool] = True):
        """Load the tokenizer, the model config and the trained weights.

        Args:
            pretrained_tokenizer: Location of the pre-trained tokenizer.
            model_config: JSON file readable by ``Config.from_json``.
            pretrained_model: File with the model ``state_dict``.
            use_cuda: Request the GPU; it is only used when CUDA is
                actually available.

        Returns:
            A ``TheGirlPT`` whose model is in evaluation mode.
        """
        tokenizer = PreTrainedTokenizerFast.from_pretrained(pretrained_tokenizer)

        config = Config.from_json(model_config)
        model = Transformer.from_config(config)
        device = torch.device(
            "cpu"
            if not use_cuda or not torch.cuda.is_available()
            else "cuda"
        )
        model.use_cuda = device.type == "cuda"
        # map_location lets weights saved on a GPU be loaded on a CPU-only
        # machine.
        model.load_state_dict(torch.load(pretrained_model, map_location=device))
        model.to(device)
        model.eval()

        return cls(tokenizer=tokenizer, model=model)

    def encode_sequence(self, _input: str) -> torch.Tensor:
        """Encode ``_input`` to ids of shape (1, n), truncated to the model's
        maximum sequence length, with the last id removed so the sequence
        stays open for continuation."""
        token_ids = self.tokenizer.encode(_input, add_special_tokens=True, padding=False,
                                          truncation=True, max_length=self.model.max_seq_len,
                                          return_tensors="pt")[:, :-1]
        return token_ids

    def tokenize_sequence(self, _input: str) -> Sequence[Tuple[int, str]]:
        """Return ``(token_id, decoded_token)`` pairs for ``_input``,
        encoded without special tokens."""
        token_ids = self.tokenizer.encode(_input, add_special_tokens=False)
        tokens = [
            (_id, self.tokenizer.decode([_id], skip_special_tokens=True))
            for _id
            in token_ids
        ]
        return tokens

    def generate_sample(self,
                        seed_phrase: Optional[str] = "",
                        max_new_tokens: Optional[int] = 24,
                        temperature: Optional[float] = 1.0) -> str:
        """Sample a continuation of ``seed_phrase``.

        Tokens are drawn one at a time from the temperature-scaled softmax
        of the model's prediction for the last position. Sampling uses
        ``np.random.choice``, so results vary between calls unless NumPy's
        global random state is seeded.

        Args:
            seed_phrase: Text the generated title starts with.
            max_new_tokens: Upper bound on generated tokens; the total
                length never exceeds the model's ``max_seq_len``.
            temperature: Positive softmax temperature; lower values make the
                sampling more conservative.

        Returns:
            The decoded text with special tokens and ``##`` markers removed.

        Raises:
            ValueError: If ``temperature`` is not a positive number.
        """
        if temperature is None or temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        seed_ids = self.encode_sequence(seed_phrase)
        if self.model.use_cuda:
            seed_ids = seed_ids.cuda()

        new_tokens_num = min(self.model.max_seq_len - seed_ids.shape[-1], max_new_tokens)
        # Inference only: skip building the autograd graph
        with torch.no_grad():
            for _ in range(new_tokens_num):
                logits = self.model(seed_ids)
                # Only the prediction for the last position is sampled from
                p_next = F.softmax(logits[0, -1] / temperature, dim=-1).cpu().numpy()

                # Sample the next token and append it to the sequence
                next_ix = np.random.choice(self.model.vocab_size, p=p_next)
                next_ix = torch.tensor([[next_ix]], device=seed_ids.device)
                seed_ids = torch.cat((seed_ids, next_ix), dim=-1)

        seed_ids = seed_ids.cpu().numpy()[0]
        restored_text = self.tokenizer.decode(seed_ids, skip_special_tokens=True)
        restored_text = re.sub(r"##", "", restored_text)
        return restored_text

In [16]:
# Restore the generator from the tokenizer, config and weights on disk.
thegpt = TheGirlPT.from_pretrained(
    pretrained_tokenizer=PRETRAINED_TOKENIZER,
    model_config=MODEL_CONFIG,
    pretrained_model=MODEL_WEIGTHS,
)

In [19]:
# Sampling draws tokens with np.random.choice and no seed is set in this
# notebook, so the printed title differs from run to run.
print(thegpt.generate_sample(seed_phrase="Нон-байнари тест", max_new_tokens=10, temperature=0.5))

 Нон-байнари тест, и мы скажем, как ты умрешь
