## Data

In [None]:
import glob
import json
import re
from pathlib import Path
from typing import List


class Item:
    """One dataset record parsed from a JSON document.

    ``content`` is read from the top-level ``"content"`` key. ``anchor``,
    ``entailment``, ``contradiction`` and ``irrelevance`` are element 0 of the
    same-named entries under ``"passage"``. ``subject`` holds every entry of
    ``passage.subject`` after whitespace and one leading list marker have been
    stripped.
    """

    item_id: int
    content: str
    anchor: str
    entailment: str
    contradiction: str
    irrelevance: str
    subject: List[str]

    # One leading list marker: "<digits>.", "-" or "*". A raw string is used so
    # that "\d", "\." and "\*" reach the regex engine instead of being treated
    # as (invalid) string escape sequences.
    _LIST_MARKER_RE = re.compile(r"^(\d+\.|-|\*)")

    def __init__(self, item_id: int, json_str: str) -> None:
        """Build an item from its identifier and its JSON text.

        Args:
            item_id: Identifier stored on the instance as-is.
            json_str: JSON document with ``content`` and ``passage`` keys.

        Raises:
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            KeyError: If an expected key is missing.
            IndexError: If one of the passage entries is empty.
        """
        obj = json.loads(json_str)
        passage = obj["passage"]
        self.item_id = item_id
        self.content = obj["content"]
        self.anchor = passage["anchor"][0]
        self.entailment = passage["entailment"][0]
        self.contradiction = passage["contradiction"][0]
        self.irrelevance = passage["irrelevance"][0]
        self.subject = [self._process_item_text(text) for text in passage["subject"]]

    @staticmethod
    def fetch_items(limit: int) -> List["Item"]:
        """Load at most ``limit`` items from ``./data/*.json``.

        Files are taken in sorted order of their path strings, and each
        file's stem is parsed as the integer item id.

        Args:
            limit: Maximum number of files to read.

        Returns:
            The parsed items, in the order the files were read.

        Raises:
            ValueError: If a file stem is not an integer.
        """
        items: List[Item] = []
        for item_file_path in sorted(glob.glob("./data/*.json"))[:limit]:
            item_id = int(Path(item_file_path).stem)
            # JSON is read as UTF-8 explicitly rather than with the
            # platform's locale-dependent default encoding.
            with open(item_file_path, encoding="utf-8") as item_file:
                items.append(Item(item_id, item_file.read()))
        return items

    @staticmethod
    def _process_item_text(item_text: str) -> str:
        """Strip whitespace and one leading list marker from ``item_text``."""
        return Item._LIST_MARKER_RE.sub("", item_text.strip()).strip()


# Load at most 50,000 items; fetch_items reads them from ./data/*.json.
items = Item.fetch_items(50000)

In [None]:
from typing import Tuple

from sentence_transformers import InputExample


def generate_data(
    items: List[Item], queryable_train_ratio: float,
) -> Tuple[List[Item], List[Item], List[InputExample], List[Tuple[str, ...]]]:
    """Split items into training and validation sets and build their pairs.

    Items with at least one subject are "queryable". The first
    ``queryable_train_ratio`` share of them (in input order) goes to training
    together with every item that has no subject; the remaining queryable
    items form the validation set.

    Args:
        items: Items to split, in the order they should be considered.
        queryable_train_ratio: Fraction of the queryable items used for
            training.

    Returns:
        A tuple ``(train_items, val_items, train_data, val_data)`` where
        ``train_data`` holds one ``InputExample`` with texts
        ``[content, anchor]`` per training item and ``val_data`` holds one
        ``(content, anchor)`` tuple per validation item.
    """
    with_subject = [entry for entry in items if entry.subject]
    without_subject = [entry for entry in items if not entry.subject]

    # Only the queryable items are split; the rest are always trained on.
    split_at = int(queryable_train_ratio * len(with_subject))
    train_items = with_subject[:split_at] + without_subject
    val_items = with_subject[split_at:]

    train_data = [
        InputExample(texts=[entry.content, entry.anchor]) for entry in train_items
    ]
    val_data = [(entry.content, entry.anchor) for entry in val_items]
    return train_items, val_items, train_data, val_data


# Fraction of the queryable items (those with at least one subject) used for training.
QUERYABLE_TRAIN_RATIO = 0.9

# The underscore-prefixed results are not used in the cells shown here.
_train_items, val_items, train_data, _val_data = generate_data(items, QUERYABLE_TRAIN_RATIO)
print("Train data length:", len(train_data))
print("Val items length:", len(val_items))

## Training

In [None]:
import shutil
from typing import Optional

import faiss
import numpy as np
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SentenceEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss


class AvgRankCalculator:
    """Measure how highly each item's content ranks for its own subject queries.

    All item contents are embedded with ``model`` and stored in an
    inner-product FAISS index. Positions in the index follow the order of
    ``items``, so result index ``i`` refers to ``items[i]``.
    """

    model: SentenceTransformer
    index: faiss.IndexFlatIP
    queries: List[List[str]]

    def __init__(self, model: SentenceTransformer, items: List[Item]):
        """Embed every item's content and build the search index.

        Args:
            model: Encoder used for both contents and queries.
            items: Items whose ``content`` is indexed and whose ``subject``
                lists are kept as the queries for ``calc_avg_rank``.
        """
        sentences = [item.content for item in items]
        sentence_embeddings = model.encode(sentences)
        # Normalised in place, so the inner product below is cosine similarity.
        faiss.normalize_L2(sentence_embeddings)
        _, size = sentence_embeddings.shape
        self.model = model
        self.index = faiss.IndexFlatIP(size)
        self.index.add(sentence_embeddings)
        self.queries = [item.subject for item in items]

    def search(self, queries: List[str], limit: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Search the index for each query.

        Args:
            queries: Query strings to encode and look up.
            limit: Number of results per query; defaults to every indexed
                item.

        Returns:
            ``(similarities, indices)`` as returned by the FAISS index, one
            row per query.
        """
        if limit is None:
            limit = self.index.ntotal
        item_query_embeddings = self.model.encode(queries)
        faiss.normalize_L2(item_query_embeddings)
        similarities, indices = self.index.search(item_query_embeddings, limit)
        return similarities, indices

    def calc_avg_rank(self) -> float:
        """Return the mean 0-based position of each item among its queries' results.

        For every query of item ``i`` the full result list is searched for
        index ``i``; the column at which it appears is that query's rank.
        Items without any query are skipped.

        Returns:
            The average rank over all queries, as a Python ``float``.

        Raises:
            ValueError: If no item has any query, so there is nothing to
                average.
        """
        count = 0
        index_sum = 0
        for i, item_queries in enumerate(self.queries):
            if not item_queries:
                # Nothing to encode for this item; an empty batch would not
                # produce a 2-D embedding matrix for the index search.
                continue
            _similarities, indices = self.search(item_queries)
            _hit_subject, hit_indices = np.asarray(indices == i).nonzero()
            index_sum += int(hit_indices.sum())
            count += len(hit_indices)
        if count == 0:
            raise ValueError("cannot compute an average rank: no queries were ranked")
        return index_sum / count


class AvgRankEvaluator(SentenceEvaluator):
    """Evaluator that scores a model by the average rank over ``items``.

    The average rank is lower-is-better, while sentence-transformers treats a
    higher evaluator score as better. The score returned to the training loop
    is therefore the negated average rank; the printed value stays positive.
    """

    items: List[Item]

    def __init__(self, items: List[Item]):
        """Store the items used for every evaluation run."""
        super().__init__()
        self.items = items

    def __call__(self, model, output_path: Optional[str] = None, epoch: int = -1, steps: int = -1) -> float:
        """Evaluate ``model`` and return its score.

        Args:
            model: Model to evaluate.
            output_path: Accepted for the evaluator interface; unused here.
            epoch: Accepted for the evaluator interface; unused here.
            steps: Accepted for the evaluator interface; unused here.

        Returns:
            The negated average rank, so that a better model yields a higher
            score.
        """
        rank = AvgRankCalculator(model, self.items).calc_avg_rank()
        print(rank)
        # Negate so that "higher score is better" holds for best-model tracking.
        return -float(rank)


def fine_tune(
    model: SentenceTransformer, train_data: List[InputExample], val_items: List[Item],
    output_path: str, batch_size: int, epochs: int,
):
    """Fine-tune ``model`` on ``train_data`` with a multiple-negatives ranking loss.

    The average-rank evaluator is run once before training, which prints the
    starting rank, and is then handed to ``model.fit``.

    Args:
        model: Model to train in place.
        train_data: Training examples fed to the data loader.
        val_items: Items the evaluator ranks.
        output_path: Passed to ``model.fit`` as its output path.
        batch_size: Batch size of the shuffled data loader.
        epochs: Number of epochs passed to ``model.fit``.
    """
    loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    ranking_loss = MultipleNegativesRankingLoss(model)
    rank_evaluator = AvgRankEvaluator(val_items)

    # Evaluation before any training step.
    rank_evaluator(model)

    # Warm up over the first 10% of all training steps.
    total_steps = len(loader) * epochs
    warmup = int(total_steps * 0.1)
    model.fit(
        train_objectives=[(loader, ranking_loss)],
        epochs=epochs,
        warmup_steps=warmup,
        evaluator=rank_evaluator,
        output_path=output_path,
    )


# MODEL_ID = "all-mpnet-base-v2"
# MODEL_ID = "all-distilroberta-v1"
MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
# MODEL_ID = "jinaai/jina-embeddings-v2-small-en"
OUTPUT_PATH = "./embedder/"
BATCH_SIZE = 16
EPOCHS = 2

model = SentenceTransformer(MODEL_ID)
# Remove any previous output directory before training.
shutil.rmtree(OUTPUT_PATH, ignore_errors=True)
# Each epoch is trained in slices of one tenth of the data, with one
# fine_tune call (and therefore one evaluation) per slice. max(1, ...) keeps
# the range step non-zero when there are fewer than ten training examples.
chunk_size = max(1, len(train_data) // 10)
for e in range(EPOCHS):
    print(f"EPOCH {e}")
    for i in range(0, len(train_data), chunk_size):
        train_data_chunk = train_data[i:i + chunk_size]
        fine_tune(model, train_data_chunk, val_items, OUTPUT_PATH, BATCH_SIZE, 1)

## Examination

In [None]:
# Examine every item that has at least one subject query.
# NOTE(review): this also includes queryable items that generate_data placed in the training split.
exam_items = [item for item in items if len(item.subject) > 0]  # queryable_items
# Baseline: a separately loaded model for MODEL_ID, untouched by the training cell.
old_calculator = AvgRankCalculator(SentenceTransformer(MODEL_ID), exam_items)
# The model instance that was passed through fine_tune above.
new_calculator = AvgRankCalculator(model, exam_items)

In [None]:
# Average 0-based position of each item in the results of its own subject queries
# (a smaller value means the item is found nearer the top).
print("Old avg rank:", old_calculator.calc_avg_rank())
print("New avg rank:", new_calculator.calc_avg_rank())

In [None]:
# Print the first item whose first subject query does clearly better with the
# new model, then stop. "Clearly better" means either the item is in the top 10
# only for the new model, or it moved up by more than 5 positions.
for i, item in enumerate(exam_items):
    query = item.subject[0]
    # Skip queries with more than five space-separated tokens.
    if len(query.split(" ")) > 5:
        continue
    _old_similarities, old_indices = old_calculator.search([query], limit=10)
    _new_similarities, new_indices = new_calculator.search([query], limit=10)
    # One query was sent, so only the first result row is relevant.
    old_indices, new_indices = old_indices[0], new_indices[0]
    # Positions within the top 10 at which item i appears (empty if absent).
    old_find = np.asarray(old_indices == i).nonzero()[0]
    new_find = np.asarray(new_indices == i).nonzero()[0]
    old_index = None if old_find.size == 0 else old_find[0]
    new_index = None if new_find.size == 0 else new_find[0]
    if (old_index is None and new_index is not None) \
            or (old_index is not None and new_index is not None and old_index - new_index > 5):
        # Output: position in exam_items, item id, old rank, new rank.
        print(i, exam_items[i].item_id, old_index, new_index)
        break