In [1]:
import os
import json
from pathlib import Path
from attrs import define, Factory, asdict
from datetime import datetime

from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim import lr_scheduler as torch_lr_scheduler
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import PreTrainedTokenizerFast

from tqdm import tqdm

from typing import (
    Union,
    Iterable,
    Optional,
    Sequence,
    Tuple,
    Any,
)

## Constants

In [2]:
# Data paths

DATA_PATH = Path(".")/"input"
ALL_TERMS = DATA_PATH/"all_titles.txt"
TRAIN_TERMS = DATA_PATH/"train_titles.txt"
VAL_TERMS = DATA_PATH/"validation_titles.txt"

# Tokenizer

PRETRAINED_TOKENIZER = Path(".")/"model"/"tokenizer"/"pretrained_tokenizer"

# Model

MODEL_PATH = Path(".")/"model"/"generator"
MODEL_CONFIG = MODEL_PATH/"config.json"
MODEL_WEIGHTS = MODEL_PATH/"title_generator.sav"
# Backward-compatible alias for the original (misspelled) name; prefer MODEL_WEIGHTS.
MODEL_WEIGTHS = MODEL_WEIGHTS
SNAPSHOT_PATH = MODEL_PATH/"snapshot"
MODEL_TRAINING_SNAPSHOT = SNAPSHOT_PATH/"generator_training_snapshot.sav"
BEST_METRICS = SNAPSHOT_PATH/"best_metrics.json"
TRAINING_LOG = SNAPSHOT_PATH/"training_log.txt"

# Tensorboard

TENSORBOARD_PATH = Path(".")/"img"/"tensorboard"
TRAIN = "Train"
VALIDATION = "Validation"
LOSS = "Loss"

# Model parameters

BATCH_SIZE = 32
MAX_SEQ_LEN = 30  # see query length distribution histogram
VOCAB_SIZE = 1528  # see the pre-trained tokenizer parameters
HEAD_NUM = 6
EMBEDDING_DIM = 128 * HEAD_NUM  # written as 128 per head times the number of heads
LAYER_NUM = 6
DROPOUT = 0.2  # dropout probability (a float)

EPOCH_NUM = 1000
START_EPOCH = 0
LOG_INTERVAL = 10

## Config

In [3]:
@define(kw_only=True)
class Config:
    """Hyper-parameters and file locations for the title generator.

    Every field defaults to the matching module-level constant. The path-like
    fields (``tokenizer``, ``train_data``, ``validation_data``) are stored as
    given and serialised to JSON as plain strings.
    """

    attention_head_num: int = HEAD_NUM
    embedding_dim: int = EMBEDDING_DIM
    decoder_layer_num: int = LAYER_NUM
    max_sequence_length: int = MAX_SEQ_LEN
    batch_size: int = BATCH_SIZE
    dropout: float = DROPOUT  # a probability, not an int
    vocab_size: int = VOCAB_SIZE
    tokenizer: Union[str, os.PathLike] = PRETRAINED_TOKENIZER
    train_data: Union[str, os.PathLike] = TRAIN_TERMS
    validation_data: Union[str, os.PathLike] = VAL_TERMS

    @classmethod
    def from_json(cls, config_json: Union[str, os.PathLike]):
        """Load a config from a JSON file such as one written by ``to_json``.

        Path fields present in the file are converted back to ``Path``; any
        field missing from the file falls back to its default.

        Raises:
            FileNotFoundError: if ``config_json`` does not exist.
        """
        if not os.path.exists(config_json):
            raise FileNotFoundError(f"Couldn't find {config_json}")

        with open(config_json, "r", encoding="utf-8") as infile:
            config = json.load(infile)

        for key in ("tokenizer", "train_data", "validation_data"):
            # Only convert keys that are present so defaults still apply otherwise.
            if key in config:
                config[key] = Path(config[key])
        return cls(**config)

    def to_json(self, config_json: Union[str, os.PathLike]):
        """Write this config to ``config_json`` as indented UTF-8 JSON.

        Missing parent directories are created first, so the default
        ``model/generator/config.json`` can be written on a fresh checkout.
        """
        config_dict = asdict(self)
        for key in ("tokenizer", "train_data", "validation_data"):
            # Path objects are not JSON-serialisable.
            config_dict[key] = str(getattr(self, key))

        Path(config_json).parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False may emit non-ASCII text, so pin the encoding explicitly.
        with open(config_json, "w", encoding="utf-8") as outfile:
            json.dump(config_dict, outfile, ensure_ascii=False, indent=4)

## Data

#### Preprocessing

In [4]:
def load_data(data_path: Union[str, os.PathLike],
              encoding: Optional[str] = "utf-8") -> Iterable[str]:
    """Lazily yield each line of a text file with surrounding whitespace stripped.

    Blank lines are yielded as empty strings. This is a generator, so nothing
    (not even the existence check) runs until the result is first iterated.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist (on first iteration).
    """
    path_missing = not os.path.exists(data_path)
    if path_missing:
        raise FileNotFoundError(f"Couldn't find {data_path}")
    with open(data_path, "r", encoding=encoding) as source:
        yield from (raw_line.strip() for raw_line in source)

In [5]:
def dump_data_to_txt(data: Sequence[str],
                     data_path: Union[str, os.PathLike],
                     encoding: Optional[str] = "utf-8"):
    """Write every item of ``data`` to ``data_path``, one per line.

    The file is overwritten; each item is formatted with ``str`` and followed
    by a newline (including the last one).
    """
    lines = (f"{item}\n" for item in data)
    with open(data_path, "w", encoding=encoding) as outfile:
        outfile.writelines(lines)

In [6]:
# Split the full corpus 80/20 into train/validation files once; if both split
# files already exist they are reused untouched.
dataset_is_split = all(os.path.exists(split) for split in (TRAIN_TERMS, VAL_TERMS))

if not dataset_is_split:
    data = list(load_data(ALL_TERMS))
    train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
    for split_items, split_path in ((train_data, TRAIN_TERMS), (val_data, VAL_TERMS)):
        dump_data_to_txt(split_items, split_path)

#### Dataset

In [7]:
class SampleSet:
    """Next-token-prediction samples for a collection of texts, plus a DataLoader.

    Each text is tokenized (special tokens added) and padded/truncated to
    ``max_seq_len + 1`` ids. The feature is ``ids[:-1]`` and the target is
    ``ids[1:]``, so target position ``i`` is the token that follows feature
    position ``i``.

    NOTE(review): padded positions are included in the targets; the training
    loss (not shown here) presumably needs to ignore the pad id — confirm.
    """

    def __init__(self,
                 data: Sequence[str],
                 tokenizer: PreTrainedTokenizerFast,
                 max_seq_len: int, batch_size: int,
                 use_cuda: Optional[bool] = True,
                 gpu_num: Optional[int] = 1):
        features, targets = self.make_samples(data, tokenizer, max_seq_len)
        kwargs = self.make_kwargs(use_cuda, gpu_num)

        self.set = TensorDataset(features, targets)
        # drop_last=True discards a trailing partial batch, so a set with fewer
        # than batch_size samples yields no batches at all.
        self.loader = DataLoader(self.set, batch_size=batch_size,
                                 drop_last=True, shuffle=True, **kwargs)

    @staticmethod
    def make_samples(data: Sequence[str], tokenizer: PreTrainedTokenizerFast,
                     max_seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize ``data`` and build shifted (features, targets) id tensors.

        Args:
            data: texts to encode; any iterable of strings is accepted.
            tokenizer: fast tokenizer used with padding to ``max_length``.
            max_seq_len: length of each returned feature/target row.

        Returns:
            ``(features, targets)``, both of shape ``(len(data), max_seq_len)``.

        Raises:
            ValueError: if ``data`` is empty.
        """
        texts = list(data)
        if not texts:
            raise ValueError("Cannot build samples from empty data")

        # A single batched call lets the fast tokenizer encode everything at once,
        # instead of one Python-level encode call and one tensor per text.
        token_ids = tokenizer(texts, add_special_tokens=True, padding="max_length",
                              truncation=True, max_length=max_seq_len + 1,
                              return_tensors="pt")["input_ids"]

        features = token_ids[:, :-1].contiguous()
        targets = token_ids[:, 1:].contiguous()
        return features, targets

    @staticmethod
    def make_kwargs(use_cuda: Optional[bool] = True, gpu_num: Optional[int] = 1):
        """Return extra DataLoader kwargs (and print whether CUDA is used).

        With ``use_cuda`` truthy: ``num_workers=4 * gpu_num`` and
        ``pin_memory=True``; otherwise an empty dict.
        """
        print(f"Using CUDA: {use_cuda}")
        kwargs = {
            "num_workers": 4 * gpu_num,
            "pin_memory": True
        } if use_cuda else {}
        return kwargs

In [8]:
@define(kw_only=True)
class Dataset:
    """Train and validation ``SampleSet``s built from the files named in a ``Config``."""

    train: SampleSet
    validation: SampleSet

    @classmethod
    def from_config(cls, config: Config):
        """Load the pretrained tokenizer and both splits, then wrap them as sample sets."""
        tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer)
        # Read both files before building any SampleSet.
        train_lines, val_lines = (list(load_data(path))
                                  for path in (config.train_data, config.validation_data))

        # Both splits share the tokenizer, sequence length and batch size.
        shared_options = dict(tokenizer=tokenizer,
                              max_seq_len=config.max_sequence_length,
                              batch_size=config.batch_size)
        return cls(
            train=SampleSet(data=train_lines, **shared_options),
            validation=SampleSet(data=val_lines, **shared_options),
        )

In [9]:
# Reuse the saved config when present; otherwise persist the defaults.
if os.path.exists(MODEL_CONFIG):
    config = Config.from_json(MODEL_CONFIG)
else:
    config = Config()
    # On a fresh checkout model/generator/ may not exist yet; create it before writing.
    MODEL_CONFIG.parent.mkdir(parents=True, exist_ok=True)
    config.to_json(MODEL_CONFIG)

## Model

#### Model Architecture

In [10]:
class AttentionHead(nn.Module):
    """A single causal (masked) self-attention head.

    Args:
        embedding_dim: size of the last dimension of the input.
        head_dim: size of the query/key/value projections and of the output.
        max_seq_len: longest sequence the causal mask supports.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embedding_dim: int, head_dim: int, max_seq_len: int, dropout: float):
        super(AttentionHead, self).__init__()
        # Query, Key, Value projections (no bias)
        self.query = nn.Linear(embedding_dim, head_dim, bias=False)
        self.key = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value = nn.Linear(embedding_dim, head_dim, bias=False)

        # Lower-triangular matrix: position i may only attend to positions <= i,
        # which hides context to the right (future tokens).
        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)))

        # Dropout on the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, _input):
        """Apply causal self-attention.

        Args:
            _input: tensor of shape ``(..., seq_len, embedding_dim)`` with
                ``seq_len <= max_seq_len``.

        Returns:
            Tensor of shape ``(..., seq_len, head_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the mask size.
        """
        seq_len = _input.shape[-2]
        max_seq_len = self.mask.shape[0]
        if seq_len > max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of {max_seq_len} "
                "supported by the causal mask"
            )

        k = self.key(_input)
        q = self.query(_input)
        head_dim = k.shape[-1]

        # Attention scores --> (..., seq_len, seq_len). Scaled dot-product attention
        # divides by sqrt(d_k), the key (head) dimension. The previous version used
        # the embedding size, which over-shrinks the logits by sqrt(embedding_dim / head_dim).
        # NOTE: weights trained with the old scaling will behave differently.
        attention_scores = q @ k.transpose(-2, -1) * head_dim ** -0.5
        # Hide future positions by setting them to -inf before the softmax.
        attention_scores = attention_scores.masked_fill(
            self.mask[:seq_len, :seq_len] == 0,
            float("-inf"),
        )
        attention_scores = F.softmax(attention_scores, dim=-1)
        attention_scores = self.dropout(attention_scores)

        # Aggregate values by attention
        v = self.value(_input)
        return attention_scores @ v  # (..., seq_len, head_dim)

In [11]:
class MultiHeadAttention(nn.Module):
    """``head_num`` parallel ``AttentionHead``s projected back to ``embedding_dim``.

    Head outputs are concatenated on the last dimension (``head_num * head_dim``
    features), mapped to ``embedding_dim`` by a linear layer, then dropout is applied.
    """

    def __init__(self, embedding_dim: int, head_dim: int, head_num: int,
                 max_seq_len: int, dropout: float):
        super(MultiHeadAttention, self).__init__()
        heads = [
            AttentionHead(embedding_dim=embedding_dim, head_dim=head_dim,
                          max_seq_len=max_seq_len, dropout=dropout)
            for _ in range(head_num)
        ]
        self.attention_heads = nn.ModuleList(heads)
        self.output_layer = nn.Linear(head_num * head_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, _input):
        # Every head sees the same input; their outputs are stacked feature-wise.
        head_outputs = [head(_input) for head in self.attention_heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        projected = self.output_layer(concatenated)
        return self.dropout(projected)