In [None]:
# Work from the parent directory so the relative paths used below (e.g. "colbert/configs.yaml")
# resolve — assumes the notebook sits one level below that directory.
# Guarded so that re-running this cell does not keep climbing the directory tree.
import os

if not os.path.exists("colbert/configs.yaml"):
    os.chdir("..")

os.getcwd()

In [None]:
import sys
import json
from pathlib import Path
from typing import Any
from datetime import datetime

import polars as pl
import yaml
import torch
import numpy as np
import lightning as L
import einops
from transformers import BertModel
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from torch.optim import AdamW
from torch.nn import Module
from torch import Tensor
from polars import DataFrame
from loguru import logger
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

In [None]:
IS_DEBUG = True  # True: DEBUG-level logs and only a single corpus file is loaded

# Swap loguru's default sink for a stderr sink whose threshold follows IS_DEBUG.
logger.remove()
logger.add(sys.stderr, level=("INFO", "DEBUG")[bool(IS_DEBUG)])

# Load config, tokenizer and data

## Load config

In [None]:
path = "colbert/configs.yaml"

# safe_load only builds plain Python objects (dict/list/str/...), never arbitrary classes.
with open(path, mode="r") as config_file:
    conf = yaml.safe_load(config_file)

conf

## Load tokenizer and processed queries/corpus

In [None]:
tokenizer = AutoTokenizer.from_pretrained(conf["PATHS"]["tokenizer"])

# Special-token ids used by `Data` to build attention masks:
# the query mask drops [MASK] positions, the document mask drops [PAD] positions.
TOK_ID_MASK, TOK_ID_PAD = (
    tokenizer.convert_tokens_to_ids(special_tok) for special_tok in ("[MASK]", "[PAD]")
)

In [None]:
# Processed queries (read from the parquet path configured under PROCESSED.query).
queries = pl.read_parquet(source=conf["PROCESSED"]["query"])
queries.head(5)

In [None]:
# The configured corpus path contains an "[i]" placeholder for the file index;
# turn it into a glob pattern.
path_raw = Path(conf['PROCESSED']['corpus'].replace("[i]", '*'))
# Match on the full file name (not the suffix-less stem) and sort so the file order
# — and the single file picked in debug mode — is deterministic.
paths = sorted(path_raw.parent.glob(path_raw.name))

if not paths:
    raise FileNotFoundError(f"No corpus files match {path_raw}")

if IS_DEBUG:
    logger.debug("IS_DEBUG: Load part of corpus")

    corpus = pl.read_parquet(paths[0])
else:
    logger.debug("Load full corpus")

    corpus = pl.concat([pl.read_parquet(path) for path in paths])

# NOTE(review): presumably cast so `did` matches the dtype of `corpus-id` read from the
# pairs CSV (the two are joined in `Data`); confirm.
corpus = corpus.with_columns(pl.col('did').cast(pl.Int64))

corpus.head()

## Run metadata and punctuation token ids

In [None]:
# Run identifier (month-day_hour-minute-second); used as the TensorBoard logger version.
version = datetime.now().strftime("%m-%d_%H-%M-%S")
model_name = "ColBERT"

with open(conf["PATHS"]["punctuations"]) as punct_file:
    map_punct2ids = json.load(punct_file)

# Only the token ids are needed: `Data` uses them to mask punctuation in documents.
punctuations = {tok_id for tok_id in map_punct2ids.values()}

# Define dataset and model

## Define Dataset and data loader

In [None]:
def _load_pairs(split: str) -> DataFrame:
    """Load the positive (query, document) pairs of one split.

    Reads the tab-separated file configured at ``conf['RAW_DATA'][split]`` and keeps
    its ``query-id`` / ``corpus-id`` columns, renamed to ``qid`` / ``did``.

    Args:
        split (str): one of ``"train"``, ``"val"``, ``"test"``.

    Returns:
        DataFrame: positive pairs with columns ``qid`` and ``did``.

    Raises:
        ValueError: if ``split`` is not one of the known splits.
        FileNotFoundError: if the file configured for ``split`` does not exist.
    """
    if split not in ("train", "val", "test"):
        raise ValueError(f"Unknown split {split!r}; expected 'train', 'val' or 'test'")

    path = Path(conf["RAW_DATA"][split])
    if not path.exists():
        raise FileNotFoundError(f"Pairs file for split {split!r} not found: {path}")

    # Read the file of the *requested* split (not always the train file).
    pairs = (
        pl
        .read_csv(path, separator='\t')
        .select(
            pl.col('query-id').alias('qid'),
            pl.col('corpus-id').alias('did'),
        )
    )

    return pairs

# pairs = _load_pairs('train')
# pairs.head()

In [None]:
class Data(Dataset):
    """Map-style dataset of (query, positive document) token-id pairs for ColBERT training.

    Pairs come from ``_load_pairs(split)``; token ids are looked up in ``queries`` and
    ``corpus`` by id. Each item also carries BERT attention masks and a boolean mask of
    document positions to ignore when scoring (punctuation and ``[PAD]``).

    Args:
        split: ``"train"``, ``"val"`` or ``"test"``; forwarded to ``_load_pairs``.
        queries: frame with at least ``col_query_id`` and ``col_tok_ids`` columns.
        corpus: frame with at least ``col_corpus_id`` and ``col_tok_ids`` columns.
        punctuations: token ids excluded from document scoring.
        col_query_id: name of the query-id column.
        col_corpus_id: name of the document-id column.
        col_tok_ids: name of the token-id list column in ``queries`` and ``corpus``.
    """

    def __init__(
        self,
        split: str,
        queries: DataFrame,
        corpus: DataFrame,
        punctuations: set,
        col_query_id: str = "qid",
        col_corpus_id: str = "did",
        col_tok_ids: str = "tok_ids",
    ):
        super().__init__()

        self.split = split
        self.punctuations = punctuations
        self.queries = queries
        self.col_query_id = col_query_id
        self.col_corpus_id = col_corpus_id
        self.col_tok_ids = col_tok_ids

        pairs = _load_pairs(split)

        # Keep only the documents referenced by this split's pairs.
        self.corpus = corpus.join(
            pairs.select(col_corpus_id).unique(), on=col_corpus_id, how="inner"
        )

        # id -> row position, so __getitem__ does O(1) lookups instead of scanning the
        # whole frame with `filter` for every item.
        self._query_row = {qid: row for row, qid in enumerate(queries[col_query_id].to_list())}
        self._doc_row = {did: row for row, did in enumerate(self.corpus[col_corpus_id].to_list())}
        self._query_tok_ids = queries[col_tok_ids]
        self._doc_tok_ids = self.corpus[col_tok_ids]

        # Array form of `punctuations` for a vectorised membership test.
        self._punct_ids = np.array(list(punctuations))

        # Keep only pairs whose document and query are both available. The semi join
        # neither copies the corpus columns into `pairs` nor duplicates pairs.
        self.pairs = (
            pairs
            .join(self.corpus.select(col_corpus_id), on=col_corpus_id, how="semi")
            .filter(pl.col(col_query_id).is_in(list(self._query_row)))
        )  # fmt: skip

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, index: int) -> dict:
        """Return the token ids and masks of pair ``index``.

        Keys: ``query`` / ``doc`` (token ids), ``attention_mask_query`` /
        ``attention_mask_doc`` (int32 arrays, 0 at [MASK] / [PAD] positions) and
        ``mask`` (bool tensor, True at document positions that are punctuation or [PAD]).
        """
        entry = self.pairs.row(index, named=True)
        qid, did = entry[self.col_query_id], entry[self.col_corpus_id]

        # `.copy()` gives the collate function a writable array that does not alias polars' buffers.
        tok_ids_query = self._query_tok_ids[self._query_row[qid]].to_numpy().copy()
        tok_ids_doc = self._doc_tok_ids[self._doc_row[did]].to_numpy().copy()

        # Document positions ignored when scoring: punctuation and padding.
        mask = torch.from_numpy(
            np.isin(tok_ids_doc, self._punct_ids) | (tok_ids_doc == TOK_ID_PAD)
        )

        attention_mask_query = (tok_ids_query != TOK_ID_MASK).astype(np.int32)
        attention_mask_doc = (tok_ids_doc != TOK_ID_PAD).astype(np.int32)

        return {
            "query": tok_ids_query,
            "attention_mask_query": attention_mask_query,
            "attention_mask_doc": attention_mask_doc,
            "doc": tok_ids_doc,
            "mask": mask,
        }


# data = Data("train", queries, corpus, punctuations)
# # data[123]
# loader = DataLoader(data, batch_size=int(conf["BSZ"]), shuffle=True)
# for batch in loader:
#     break

# batch

## Define model

In [None]:
# Similarity written at ignored document positions (punctuation / padding) so they
# never win the max over document tokens.
MASK_VAL = -1e10
MAKS_VAL = MASK_VAL  # original (misspelled) name, kept so existing references keep working


class ColBERT(Module):
    """ColBERT-style late-interaction encoder: BERT, a linear projection, L2 normalisation.

    Args:
        bert_model: name or path given to ``BertModel.from_pretrained``.
        size_vocab: size the BERT token-embedding matrix is resized to
            (the notebook passes ``len(tokenizer)``).
        d_hid: dimension of the projected per-token embeddings.
        d_hid_bert: hidden size of the BERT encoder, i.e. the projection's input size.
    """

    def __init__(
        self,
        bert_model: str,
        size_vocab: int,
        d_hid: int = 128,
        d_hid_bert: int = 768,
    ):
        super().__init__()

        self.bert = BertModel.from_pretrained(bert_model)
        # Match the embedding table to the tokenizer's vocabulary size.
        self.bert.resize_token_embeddings(size_vocab)

        self.linear = nn.Linear(d_hid_bert, d_hid, bias=False)

    def forward(self, X: Tensor, attention_mask: Tensor) -> Tensor:
        """Encode token ids into per-token embeddings with unit L2 norm.

        Args:
            X: token ids, ``[bz, n]``.
            attention_mask: BERT attention mask, ``[bz, n]``.

        Returns:
            ``[bz, n, d_hid]`` embeddings, L2-normalised along the last axis.
        """
        X = self.bert(X, attention_mask=attention_mask).last_hidden_state
        # [bz, n, d_hid_bert]

        X = self.linear(X)
        # [bz, n, d_hid]

        return nn.functional.normalize(X, p=2, dim=-1)

    def trigger_train(
        self,
        query: Tensor,
        doc: Tensor,
        mask: Tensor,
        attention_mask_query: Tensor,
        attention_mask_doc: Tensor,
    ) -> "tuple[Tensor, Tensor]":
        """Score every query against every document of the batch; compute the in-batch loss.

        Document ``i`` is the positive for query ``i``; the other documents of the
        batch act as negatives.

        Args:
            query: query token ids, ``[bz, Nq]``.
            doc: document token ids, ``[bz, L]``.
            mask: ``[bz, L]``, True at document positions to ignore.
            attention_mask_query: BERT attention mask for ``query``.
            attention_mask_doc: BERT attention mask for ``doc``.

        Returns:
            ``(score, loss)``. ``score[i, j]`` (``[bz, bz]``) is the sum over the tokens of
            query ``i`` of their max similarity with the tokens of document ``j``;
            ``loss`` is the cross-entropy of each score row against target ``i``.
        """
        bz = query.shape[0]

        ###################################################
        # Encode query and document
        ###################################################
        query_emb = self.forward(query, attention_mask_query)
        # [bz, Nq, d_hid]
        doc_emb = self.forward(doc, attention_mask_doc)
        # [bz, L, d_hid]

        ###################################################
        # Calculate the similarity (in-batch negatives)
        ###################################################
        # sim[i, j, n, l] = <query_emb[i, n], doc_emb[j, l]>. The einsum broadcasts over
        # the (query, doc) batch pair directly, without materialising bz copies of each.
        sim = torch.einsum("ind,jld->ijnl", query_emb, doc_emb)
        # [bz, bz, Nq, L]

        # mask[j, l] applies to every query i and query token n. The fill value is clamped
        # to the dtype's range so that fp16 (e.g. under autocast) does not overflow.
        fill_value = max(MAKS_VAL, torch.finfo(sim.dtype).min)
        sim = sim.masked_fill(mask.bool()[None, :, None, :], fill_value)

        # MaxSim per query token, summed over query tokens
        score = sim.max(dim=-1).values.sum(dim=-1)
        # [bz, bz]

        # Listwise CE: the positive of query i is document i
        tgt = torch.arange(bz, dtype=torch.long, device=score.device)
        loss = nn.functional.cross_entropy(score, tgt)

        return score, loss

# model = ColBERT(conf['MODEL_NAME'], len(tokenizer))
# data = Data("train", queries, corpus, punctuations)

# loader = DataLoader(data, batch_size=int(conf["BSZ"]), shuffle=True)
# for batch in loader:
#     score, loss = model.trigger_train(**batch)
#     logger.debug(f"loss: {loss}")

#     break


In [None]:
class LitModel(L.LightningModule):
    """Lightning wrapper that trains a ``ColBERT`` model with in-batch negatives.

    Args:
        params: keyword arguments forwarded to ``ColBERT`` (e.g. ``bert_model``, ``size_vocab``).
        lr: AdamW learning rate.
        num_epochs: planned number of epochs; stored (and saved as a hyper-parameter)
            for the currently disabled LR schedule.
    """

    def __init__(
        self,
        params: dict,
        lr: float = 1e-3,
        num_epochs: int = 10
    ) -> None:
        super().__init__()
        # Record params / lr / num_epochs with checkpoints and the logger.
        self.save_hyperparameters()

        self.lr = lr
        self.num_epochs = num_epochs

        self.model = ColBERT(**params)

    def forward(self, meal: Tensor, attention_mask: "Tensor | None" = None) -> Tensor:
        """Encode token ids ``meal`` with the wrapped ColBERT encoder.

        ``ColBERT.forward`` requires an attention mask; ``None`` (the default) is passed
        through to BERT, which then attends to every token.
        """
        return self.model(meal, attention_mask)

    def training_step(self, batch, batch_idx):
        """Compute the in-batch ColBERT loss for one batch produced by ``Data``."""
        _, loss = self.model.trigger_train(**batch)

        # Explicit batch_size: the batch is a dict of several tensors, which Lightning
        # would otherwise have to guess from when aggregating the epoch-level value.
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            batch_size=batch["query"].shape[0],
        )

        return loss

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.lr``; no LR scheduler for now."""
        optimizer = AdamW(self.parameters(), lr=self.lr)

        # scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=self.num_epochs)
        # return {"optimizer": optimizer, "lr_scheduler": scheduler}

        return optimizer


# Train

In [None]:
# Seed python / numpy / torch RNGs so loader shuffling and the model initialisation in
# the next cell are reproducible. Set SEED in the config to override the default.
L.seed_everything(int(conf.get("SEED", 42)))

data = Data("train", queries, corpus, punctuations)
loader = DataLoader(data, batch_size=int(conf["BSZ"]), shuffle=True)

In [None]:
# Constructor arguments for `ColBERT`, all taken from the config.
params = dict(
    bert_model=conf["MODEL_NAME"],
    size_vocab=len(tokenizer),  # ColBERT resizes BERT's token embeddings to this size
    d_hid=conf["D_HID"],
    d_hid_bert=conf["D_HID_BERT"],
)
litmodel = LitModel(
    params,
    lr=float(conf["LR"]),  # float(): PyYAML reads exponent literals like 1e-5 as strings
    num_epochs=conf["NUM_EPOCHS"],
)
# litmodel.load_state_dict(state_dict)

In [None]:
path_ckpt = Path(conf["PATHS"]["ckpt"])

# NOTE(review): without `monitor`, ModelCheckpoint keeps only the most recent checkpoint
# (save_top_k defaults to 1); pass save_top_k=-1 if every 2nd-epoch file should be kept.
checkpoint_cb = ModelCheckpoint(
    dirpath=path_ckpt / model_name,
    filename=f"{path_ckpt.stem}_{{epoch}}",
    every_n_epochs=2,
)

trainer = L.Trainer(
    max_epochs=conf["NUM_EPOCHS"],
    callbacks=[
        RichProgressBar(leave=True),
        LearningRateMonitor(logging_interval="step"),
        checkpoint_cb,
    ],
    logger=TensorBoardLogger(
        conf["PATHS"]["logs"],
        name=model_name,
        version=version,
        default_hp_metric=False,
    ),
    # devices=0,
    # gradient_clip_val=1,
)

In [None]:
# Train on the loader built above; uncomment ckpt_path to resume from a checkpoint.
trainer.fit(
    model=litmodel,
    train_dataloaders=loader,
    # ckpt_path='weights/embedding_tuning/stage-1_02-26_12-52-00_epoch=11.ckpt'
)