In [None]:
import argparse
import json
import os
import random
from typing import Any, Dict, List, Union

import pytorch_lightning as pl
import torch
from quaterion_models.encoders import Encoder
from quaterion_models.heads import EncoderHead, GatedHead, SkipConnectionHead
from quaterion_models.types import CollateFnType
from torch.utils.data import Dataset

from quaterion import Quaterion, TrainableModel
from quaterion.loss import TripletLoss
from quaterion.dataset.similarity_data_loader import (
    GroupSimilarityDataLoader,
    SimilarityGroupSample,
)

import torch
import torch.nn as nn

from quaterion.loss import SimilarityLoss, SoftmaxLoss

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling

import pandas as pd


# Data Loading and Model Definition

In [None]:
def _load_json_records(path):
    """Read a whole JSON file and return the parsed value.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange uses),
    so the result does not depend on the platform's default text encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


train_data = _load_json_records("./labeled_data_train.json")
val_data = _load_json_records("./labeled_data_val.json")
test_data = _load_json_records("./labeled_data_test.json")

In [None]:
from quaterion.dataset.similarity_data_loader import (
    GroupSimilarityDataLoader,
    SimilarityGroupSample,
)
import numpy as np

class JsonDataset(Dataset):
    """Dataset of labelled records read from a single JSON file.

    Each item is wrapped in a ``SimilarityGroupSample`` whose ``group`` is the
    integer id of the record's ``"label"``, so records sharing a label share a group.
    """

    def __init__(self, path: str):
        super().__init__()
        # Label name -> integer group id consumed by the similarity data loader.
        self.translation_dict = {
            "World": 1,
            "Sports": 2,
            "Business": 3,
            "Sci/Tech": 4,
        }
        # Decode as UTF-8 explicitly instead of relying on the platform default.
        with open(path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __getitem__(self, index: int) -> SimilarityGroupSample:
        item = self.data[index]
        label = item["label"]
        if label not in self.translation_dict:
            # Same exception type as a plain dict lookup, but says what went wrong.
            raise KeyError(
                f"Unknown label {label!r} at index {index}; "
                f"expected one of {sorted(self.translation_dict)}"
            )
        return SimilarityGroupSample(obj=item, group=self.translation_dict[label])

    def __len__(self) -> int:
        return len(self.data)

In [None]:
class DescriptionEncoder(Encoder):
    """Quaterion ``Encoder`` built from a sentence-transformers ``Transformer`` + ``Pooling`` pair.

    Each record becomes the text ``Title + '. ' + Description``. The transformer
    tokenises it, and the transformer and pooling layers then run in sequence to
    embed it. The encoder is frozen (``trainable`` is False).
    """

    def __init__(self, transformer: Transformer, pooling: Pooling):
        super().__init__()
        self.transformer = transformer
        self.pooling = pooling
        # Chain both modules so forward() is one call; the Sequential holds the
        # same module objects as the attributes above, not copies.
        self.encoder = nn.Sequential(self.transformer, self.pooling)

    @property
    def trainable(self) -> bool:
        return False  # Disable weights update for this encoder

    @property
    def embedding_size(self) -> int:
        """Size of the vectors returned by ``forward``.

        Read from the pooling layer, because it produces the final embedding. That
        equals the transformer's word-embedding size only for a single pooling mode.
        Concatenated modes (e.g. mean + max) produce a multiple of it.
        """
        return self.pooling.get_sentence_embedding_dimension()

    def forward(self, batch) -> torch.Tensor:
        """Embed a batch produced by ``collate_descriptions``."""
        return self.encoder(batch)["sentence_embedding"]

    def collate_descriptions(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Tokenise raw records; each must provide ``'Title'`` and ``'Description'`` strings.

        Returns the transformer tokenizer's output, which is a dict of tensors
        (e.g. input ids and attention mask) rather than a single tensor.
        """
        descriptions = [record['Title'] + '. ' + record['Description'] for record in batch]
        return self.transformer.tokenize(descriptions)

    def get_collate_fn(self) -> CollateFnType:
        return self.collate_descriptions

    @staticmethod
    def _transformer_path(path: str) -> str:
        return os.path.join(path, "transformer")

    @staticmethod
    def _pooling_path(path: str) -> str:
        return os.path.join(path, "pooling")

    def save(self, output_path: str):
        """Write the transformer and pooling layers into separate sub-directories of ``output_path``."""
        transformer_path = self._transformer_path(output_path)
        os.makedirs(transformer_path, exist_ok=True)

        pooling_path = self._pooling_path(output_path)
        os.makedirs(pooling_path, exist_ok=True)

        self.transformer.save(transformer_path)
        self.pooling.save(pooling_path)

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """Rebuild an encoder from a directory written by ``save``."""
        # Reuse the same path helpers as save() so the layouts cannot drift apart.
        transformer = Transformer.load(cls._transformer_path(input_path))
        pooling = Pooling.load(cls._pooling_path(input_path))
        return cls(transformer=transformer, pooling=pooling)

In [None]:
class Model(TrainableModel):
    """Trainable similarity model.

    It combines a frozen ``all-MiniLM-L6-v2`` encoder with a skip-connection
    head, trained with triplet loss and Adam.
    """

    def __init__(self, lr: float):
        # Stored ahead of super().__init__(), as originally written; presumably the
        # base initialiser may already invoke the configure_* hooks — keep this order.
        self._lr = lr
        super().__init__()

    def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
        backbone = SentenceTransformer("all-MiniLM-L6-v2")
        # Module 0 of the pre-trained pipeline is the transformer, module 1 the pooling layer.
        return DescriptionEncoder(transformer=backbone[0], pooling=backbone[1])

    def configure_head(self, input_embedding_size) -> EncoderHead:
        head = SkipConnectionHead(input_embedding_size)
        return head

    def configure_loss(self) -> SimilarityLoss:
        loss_fn = TripletLoss()
        return loss_fn

    def configure_optimizers(self):
        trainable_params = self.model.parameters()
        return torch.optim.Adam(trainable_params, lr=self._lr)

# Training

In [None]:
# Seed python/numpy/torch so weight initialisation and batch shuffling are repeatable.
pl.seed_everything(42)

model = Model(lr=0.001)

# Shuffle the training set so every epoch forms different batches (and therefore
# different triplets); validation keeps a fixed order so its loss is comparable
# across epochs.
train_dataloader = GroupSimilarityDataLoader(
    JsonDataset('./labeled_data_train.json'), batch_size=128, shuffle=True
)
val_dataloader = GroupSimilarityDataLoader(
    JsonDataset('./labeled_data_val.json'), batch_size=128
)

Quaterion.fit(
    trainable_model=model,
    trainer=None,  # Use default trainer
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

In [None]:
# Persist the trained model in servable form; it is reloaded for evaluation below.
servable_dir = "finetuned_model"
model.save_servable(servable_dir)

# Inspect Training Logs

In [None]:
# Open TensorBoard inside the notebook on ./lightning_logs
# (presumably where the default trainer used by Quaterion.fit above writes its logs — confirm).
%load_ext tensorboard
%tensorboard --logdir lightning_logs

# Evaluation

In [None]:
from quaterion_models.model import SimilarityModel

# Fine-tuned model, restored from the directory written by save_servable.
finetuned_dir = "finetuned_model"
inf_model = SimilarityModel.load(finetuned_dir)

# Baseline: a fresh, never-trained Model (lr is irrelevant since it is not fitted).
# NOTE(review): its embeddings still pass through a freshly initialised
# SkipConnectionHead — confirm that is close enough to plain pre-trained embeddings.
raw_model = Model(lr=1e-3)

In [None]:
test_df = pd.read_json("labeled_data_test.json", orient="records")
# Round-trip through a JSON string to get plain record dicts with the DataFrame's values.
test_records_str = test_df.to_json(orient="records")
test_df_json = json.loads(test_records_str)

In [None]:
# Embed every test record with both models; row i corresponds to test_df_json[i].
raw_test_embeddings, finetuned_test_embeddings = (
    raw_model.model.encode(test_df_json),
    inf_model.encode(test_df_json),
)

In [None]:
# Optional: cache both sets of test embeddings on disk.
for target_file, emb_matrix in (
    ("raw_test_embeddings.pt", raw_test_embeddings),
    ("finetuned_test_embeddings.pt", finetuned_test_embeddings),
):
    torch.save(emb_matrix, target_file)

In [None]:
from scipy.spatial.distance import cdist
def get_most_similar_idxs(embedding, embedding_matrix, n=10, embedding_in_matrix=True):
    """Return indices of the ``n`` rows of ``embedding_matrix`` closest to ``embedding``.

    Rows are ranked by cosine distance, nearest first. When ``embedding_in_matrix``
    is True, the single top-ranked row is skipped on the assumption that it is the
    query itself.
    """
    query_row = embedding.reshape(1, -1)
    ranked = np.argsort(cdist(query_row, embedding_matrix, metric="cosine")[0])
    offset = 1 if embedding_in_matrix else 0
    return ranked[offset:offset + n]

def get_top1k(model, test_df_json, pre_calculated_embeddings=None, k=1000,
              n_samples=250, seed=42):
    """Score how label-pure the k nearest neighbours of sampled query records are.

    For each sampled query, the ``k`` records closest to it by cosine distance
    (the query itself excluded) are retrieved. The query's score is the fraction
    of those neighbours whose ``"label"`` equals the query's label.

    Args:
        model: object with an ``encode(records)`` method; only used when
            ``pre_calculated_embeddings`` is None.
        test_df_json: list of record dicts, each with a ``"label"`` key.
        pre_calculated_embeddings: optional array-like whose row i embeds
            ``test_df_json[i]``.
        k: number of neighbours per query; must be >= 1.
        n_samples: number of query records, drawn uniformly with replacement.
        seed: seed for query sampling. The fixed default makes runs reproducible
            and scores different embedding sets on the same queries; pass None
            for fresh randomness.

    Returns:
        List of ``n_samples`` floats in [0, 1], one per sampled query
        (an empty list when ``test_df_json`` is empty).

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if len(test_df_json) == 0:
        return []

    if pre_calculated_embeddings is not None:
        embeddings = pre_calculated_embeddings
    else:
        embeddings = model.encode(test_df_json)
    embeddings = np.asarray(embeddings)

    # Build the label lookup once instead of converting every record to an
    # object array inside the loop.
    labels = np.array([record["label"] for record in test_df_json])

    rng = np.random.default_rng(seed)
    query_idxs = rng.integers(0, len(test_df_json), size=n_samples)

    scores = []
    for idx in query_idxs:
        dists = cdist(embeddings[idx].reshape(1, -1), embeddings, metric="cosine")[0]
        order = np.argsort(dists)
        # Drop the query by index rather than assuming it ranks first: exact
        # duplicates tie with it at distance 0.
        neighbours = order[order != idx][:k]
        scores.append(float(np.mean(labels[neighbours] == labels[idx])))
    return scores

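## Top_1k

For 250 randomly sampled test records, this is the fraction of each record's 1000 nearest neighbours (by cosine distance, excluding the record itself) that share its label. Higher is better.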

In [None]:
# Score both embedding sets with k=1000 neighbours: fine-tuned first, then raw.
top1k_tuned, top1k_raw = (
    get_top1k(None, test_df_json, pre_calculated_embeddings=emb_matrix, k=1000)
    for emb_matrix in (finetuned_test_embeddings, raw_test_embeddings)
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# cut=0 keeps each density inside the observed range; the metric is a fraction,
# so density drawn beyond 0 or 1 would be meaningless.
ax = sns.violinplot(data=[top1k_tuned, top1k_raw], cut=0)
ax.set_xticklabels(["fine-tuned", "raw"])
ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylabel("fraction of 1000 nearest neighbours sharing the label")
ax.set_title("Top_1k metric (higher -> better)")
# plt.savefig("figures/top_1k_distribution.png", dpi=300)

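## Top_k

This is the same neighbour-label metric, averaged over the sampled queries, for several neighbourhood sizes k.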

In [None]:
eval_ks = [5, 10, 25, 50, 75, 100, 250]
tuned, raw = [], []
# For each k, mean top-k score of the fine-tuned embeddings, then of the raw ones.
for k in eval_ks:
    for scores_by_k, emb_matrix in ((tuned, finetuned_test_embeddings), (raw, raw_test_embeddings)):
        per_query = get_top1k(None, test_df_json, pre_calculated_embeddings=emb_matrix, k=k)
        scores_by_k.append(np.mean(per_query))

In [None]:
import matplotlib.pyplot as plt

plt.plot(eval_ks, tuned, marker=".", label="fine-tuned")
plt.plot(eval_ks, raw, c="red", marker=".", label="raw")
# Ticks start at the second k, presumably because 5 and 10 would overlap on a linear axis.
plt.xticks(eval_ks[1:])
plt.xlabel("k (number of nearest neighbours)")
plt.ylabel("mean fraction of neighbours sharing the label")
plt.title("top_k metric for different values for k")
plt.legend()
# plt.savefig("figures/top_k_metric.png", dpi=300)

## Classification

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

In [None]:
with open("./labeled_data_train.json", "r") as train_file:
    train_data = json.load(train_file)

# Training-set features for the downstream classifiers, one matrix per model.
raw_embeddings_train = raw_model.model.encode(train_data)
finetuned_embeddings_train = inf_model.encode(train_data)

In [None]:
# Class targets, aligned row-for-row with the training embeddings.
labels = []
for train_record in train_data:
    labels.append(train_record["label"])

In [None]:
# lbfgs's default max_iter=100 can stop before convergence on dense embedding
# features (sklearn only emits a warning). A higher cap lets it finish, and the
# fitted model is unchanged whenever it already converged sooner.
clf_raw = LogisticRegression(random_state=42, max_iter=1000).fit(raw_embeddings_train, labels)
clf_tuned = LogisticRegression(random_state=42, max_iter=1000).fit(finetuned_embeddings_train, labels)

In [None]:
# Predicted test labels, each classifier applied to its own model's embeddings.
y_pred_raw, y_pred_tuned = (
    clf_raw.predict(raw_test_embeddings),
    clf_tuned.predict(finetuned_test_embeddings),
)

In [None]:
# Take ground truth from test_df: the test embeddings were computed from it
# (via test_df_json), and the fine-tuned report uses the same source.
print(classification_report(test_df["label"].values, y_pred_raw))

In [None]:
# Ground-truth test labels, in the same row order as the test embeddings.
y_true_test = test_df["label"].values
print(classification_report(y_true=y_true_test, y_pred=y_pred_tuned))