## Data

In [None]:
import glob
import json
import re
from pathlib import Path
from typing import List


class Item:
    """A dataset record parsed from one JSON document.

    The JSON object must contain a ``content`` string and a ``passage``
    object. The ``anchor``, ``entailment``, ``contradiction`` and
    ``irrelevance`` entries of ``passage`` must be non-empty sequences (only
    their first element is kept); ``subject`` is a list of strings, each of
    which is cleaned with :meth:`_process_item_text`.
    """

    item_id: int
    content: str
    anchor: str
    entailment: str
    contradiction: str
    irrelevance: str
    subject: List[str]

    # Leading list marker ("12.", "-" or "*") removed from each subject line.
    # Raw string: "\d", "\." and "\*" are not valid Python string escapes.
    _LIST_MARKER = re.compile(r"^(\d+\.|-|\*)")

    def __init__(self, item_id: int, json_str: str) -> None:
        """Build an item from its identifier and its JSON source text.

        Raises:
            json.JSONDecodeError: if ``json_str`` is not valid JSON.
            KeyError: if a required key is missing.
            IndexError: if one of the single-valued passage lists is empty.
        """
        obj = json.loads(json_str)
        passage = obj["passage"]
        self.item_id = item_id
        self.content = obj["content"]
        self.anchor = passage["anchor"][0]
        self.entailment = passage["entailment"][0]
        self.contradiction = passage["contradiction"][0]
        self.irrelevance = passage["irrelevance"][0]
        self.subject = [self._process_item_text(text) for text in passage["subject"]]

    @staticmethod
    def fetch_items() -> List["Item"]:
        """Load every ``./data/<item_id>.json`` file into an :class:`Item`.

        Files are read in ascending ``item_id`` order so the result does not
        depend on the directory listing order of the file system.

        Raises:
            ValueError: if a file name (without extension) is not an integer.
        """
        paths = sorted(glob.glob("./data/*.json"), key=lambda path: int(Path(path).stem))
        items: List[Item] = []
        for item_file_path in paths:
            item_id = int(Path(item_file_path).stem)
            # Explicit encoding: the platform default is not always UTF-8.
            with open(item_file_path, encoding="utf-8") as item_file:
                items.append(Item(item_id, item_file.read()))
        return items

    @staticmethod
    def _process_item_text(item_text: str) -> str:
        """Strip whitespace and one leading list marker from ``item_text``."""
        return Item._LIST_MARKER.sub("", item_text.strip()).strip()

In [None]:
import random
from typing import Tuple

from sentence_transformers import InputExample


def generate_data(items: List[Item], train_ratio: float) -> Tuple[List[Item], List[Item], List[InputExample], List[Tuple[str, ...]]]:
    """Split ``items`` into training and validation sets and build triplets.

    Items are shuffled, then separated into primary items (non-empty
    ``subject``) and secondary items (empty ``subject``). The first
    ``int(train_ratio * len(primary_items))`` primary items plus all
    secondary items are used for training; the remaining primary items are
    used for validation. Every item yields two triplets: (anchor, entailment,
    contradiction) and (anchor, entailment, irrelevance).

    Args:
        items: Items to split. The list itself is left untouched.
        train_ratio: Fraction of the primary items assigned to training.

    Returns:
        ``(train_items, val_items, train_data, val_data)`` where
        ``train_data`` holds ``InputExample`` triplets and ``val_data`` holds
        plain string triplets.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    shuffled_items = list(items)
    random.shuffle(shuffled_items)
    # Prepare primary items (items with subjects) and secondary items (items without subjects)
    primary_items: List[Item] = [item for item in shuffled_items if item.subject]
    secondary_items: List[Item] = [item for item in shuffled_items if not item.subject]
    # Prepare train and val data
    train_data_len = int(train_ratio * len(primary_items))
    train_items = primary_items[:train_data_len] + secondary_items
    val_items = primary_items[train_data_len:]
    train_data: List[InputExample] = [
        InputExample(texts=[item.anchor, item.entailment, negative])
        for item in train_items
        for negative in (item.contradiction, item.irrelevance)
    ]
    val_data: List[Tuple[str, ...]] = [
        (item.anchor, item.entailment, negative)
        for item in val_items
        for negative in (item.contradiction, item.irrelevance)
    ]
    return train_items, val_items, train_data, val_data

## Evaluation

In [None]:
import csv
import os

from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.evaluation.TripletEvaluator import (
    SimilarityFunction,
    paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
)


class CustomTripletEvaluator(TripletEvaluator):
    """Triplet evaluator that reports a validation loss next to the accuracy.

    Accuracy is the share of triplets whose anchor is closer to the positive
    than to the negative under the selected distance. The loss is the mean
    triplet margin loss ``max(d(a, p) - d(a, n) + margin, 0)``.
    """
    # See: https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/evaluation/TripletEvaluator.py

    def __init__(
        self,
        anchors: List[str],
        positives: List[str],
        negatives: List[str],
        main_distance_function: SimilarityFunction,
        loss_margin: float,
        batch_size: int = 16,
    ):
        """Store the triplets, the distance to use and the loss margin.

        Args:
            anchors: Anchor sentences.
            positives: Sentences expected to be close to the anchors.
            negatives: Sentences expected to be far from the anchors.
            main_distance_function: Distance used for accuracy and loss;
                anything other than MANHATTAN or EUCLIDEAN means cosine.
            loss_margin: Margin added to ``d(a, p) - d(a, n)`` in the loss.
            batch_size: Batch size used when encoding the sentences.
        """
        super().__init__(
            anchors, positives, negatives,
            main_distance_function=main_distance_function, batch_size=batch_size,
        )
        self.loss_margin = loss_margin
        self.csv_headers = ["epoch", "steps", "val_accuracy", "val_loss"]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Evaluate ``model`` on the stored triplets and return the accuracy.

        When ``output_path`` is given and ``self.write_csv`` is truthy, a row
        ``[epoch, steps, accuracy, loss]`` is appended to the CSV file
        ``self.csv_file`` inside ``output_path`` (the directory and the header
        row are created on first use).

        Raises:
            ZeroDivisionError: if the evaluator holds no triplets.
        """
        embeddings_anchors, embeddings_positives, embeddings_negatives = (
            model.encode(
                sentences, batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar, convert_to_numpy=True,
            )
            for sentences in (self.anchors, self.positives, self.negatives)
        )
        # Exactly one distance must be selected; with two independent `if`
        # statements a MANHATTAN choice would be overwritten by the cosine
        # fallback of the EUCLIDEAN test.
        if self.main_distance_function == SimilarityFunction.MANHATTAN:
            paired_distances = paired_manhattan_distances
        elif self.main_distance_function == SimilarityFunction.EUCLIDEAN:
            paired_distances = paired_euclidean_distances
        else:
            paired_distances = paired_cosine_distances
        pos_distances = paired_distances(embeddings_anchors, embeddings_positives)
        neg_distances = paired_distances(embeddings_anchors, embeddings_negatives)
        triplets_count = 0
        correct_triplets_count = 0
        loss = 0
        for pos_distance, neg_distance in zip(pos_distances, neg_distances):
            triplets_count += 1
            if pos_distance < neg_distance:
                correct_triplets_count += 1
            # See: https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/losses/TripletLoss.py
            # The training loss applies a ReLU, so a triplet that already
            # satisfies the margin contributes 0 instead of a negative value.
            loss += max(pos_distance - neg_distance + self.loss_margin, 0)
        accuracy = correct_triplets_count / triplets_count
        loss /= triplets_count
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            is_new_file = not os.path.isfile(csv_path)
            if is_new_file:
                os.makedirs(output_path, exist_ok=True)
            with open(csv_path, newline="", mode="w" if is_new_file else "a", encoding="utf-8") as csv_file:
                writer = csv.writer(csv_file)
                if is_new_file:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, accuracy, loss])
        return accuracy

In [None]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SentenceEvaluator


class SimilarityEvaluator(SentenceEvaluator):
    """Retrieval check: do an item's subjects find the item's own content?

    At construction time the ``content`` of every item is encoded, L2
    normalised and stored in an inner-product FAISS index (so scores are
    cosine similarities). Calling the evaluator encodes each item's subjects
    and counts the item as a hit when its own index position shows up in the
    top ``limit`` results of every one of its subject queries.
    """

    index: faiss.IndexFlatIP
    queries: List[List[str]]
    limit: int

    def __init__(self, model: SentenceTransformer, items: List[Item], limit: int):
        """Index the contents of ``items`` using embeddings from ``model``."""
        corpus_embeddings = model.encode([item.content for item in items])
        faiss.normalize_L2(corpus_embeddings)  # normalises in place
        _rows, dimension = corpus_embeddings.shape
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(corpus_embeddings)
        # Position k in `queries` matches position k in the index.
        self.queries = [item.subject for item in items]
        self.limit = limit

    def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Return the fraction of items retrieved by all of their subjects.

        ``output_path``, ``epoch`` and ``steps`` are accepted but not used.
        """
        hits = 0
        for position, subject_queries in enumerate(self.queries):
            query_embeddings = model.encode(subject_queries)
            faiss.normalize_L2(query_embeddings)
            neighbour_ids = self.index.search(query_embeddings, self.limit)[1]
            # One result row per query; count the entries equal to this
            # item's position and compare against the number of queries.
            matches = np.count_nonzero(neighbour_ids == position)
            if matches == len(neighbour_ids):
                hits += 1
        return hits / len(self.queries)

## Fine-tuning

In [None]:
from torch.utils.data import DataLoader
from sentence_transformers.losses import TripletDistanceMetric, TripletLoss


def fine_tune(
    model: SentenceTransformer, train_data: List[InputExample], val_data: List[Tuple[str, ...]],
    output_path: str, batch_size: int, epochs: int,
):
    """Fine-tune ``model`` with a cosine triplet loss and track validation.

    A ``CustomTripletEvaluator`` built from ``val_data`` is run once before
    training (writing to ``<output_path>/eval``) and is then handed to
    ``model.fit``. Warm-up covers 10% of the total number of training steps.

    Args:
        model: Model to train in place.
        train_data: Training triplets.
        val_data: Validation triplets as (anchor, positive, negative) tuples.
        output_path: Directory passed to ``model.fit`` as its output path.
        batch_size: Batch size for both training and evaluation.
        epochs: Number of training epochs.
    """
    margin = 2  # Max value for cosine distance, which has a range of (0, 2)
    loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    objective = TripletLoss(model, distance_metric=TripletDistanceMetric.COSINE, triplet_margin=margin)
    # Transpose the list of triplets into one list per triplet position.
    columns = [list(column) for column in zip(*val_data)]
    evaluator = CustomTripletEvaluator(
        *columns, SimilarityFunction.COSINE, margin, batch_size=batch_size,
    )
    # Baseline measurement before any training step.
    eval_dir = os.path.join(output_path, "eval")
    evaluator(model, output_path=eval_dir)
    total_steps = len(loader) * epochs
    model.fit(
        train_objectives=[(loader, objective)],
        epochs=epochs,
        warmup_steps=int(total_steps * 0.1),
        evaluator=evaluator,
        output_path=output_path,
    )

## Run

In [None]:
# Fraction of the items that have subjects assigned to the training split;
# the remaining ones form the validation split (see generate_data).
TRAIN_RATIO = 0.5

# The training items themselves are not used later, hence the underscore name.
_train_items, val_items, train_data, val_data = generate_data(Item.fetch_items(), TRAIN_RATIO)

In [None]:
import shutil

# Run configuration: base model, output directory and training hyperparameters.
MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
OUTPUT_PATH = "./embedder/"
BATCH_SIZE = 16
EPOCHS = 5

model = SentenceTransformer(MODEL_ID)
# Remove the output directory of any previous run; a missing directory is ignored.
shutil.rmtree(OUTPUT_PATH, ignore_errors=True)
# Hit rate on the validation items (search limit 10) before fine-tuning.
old_hit_rate = SimilarityEvaluator(model, val_items, 10)(model)
fine_tune(model, train_data, val_data, OUTPUT_PATH, BATCH_SIZE, EPOCHS)
# A new evaluator is built so the index is re-encoded with the fine-tuned model.
new_hit_rate = SimilarityEvaluator(model, val_items, 10)(model)
print(f"Hit rate: old={old_hit_rate}, new={new_hit_rate}")