# SBERT

This notebook shows how the models used in the report are created and trained. It consists of the following steps:
1. Downloading the data from the MSMARCO website.
2. Preparing the data for loading into the model.
3. Creating a datamodule for loading the data into the model.
4. Defining the model module, which defines the structure of the model, the loss function, the optimizer, and the metrics used to evaluate the model.
5. Finding the optimal threshold for the model.
6. Training the model.
7. Plotting model size against F1 score.

## Downloading the Data

This part of the notebook downloads the necessary data from the MS MARCO website (plus the TREC 2019 passage relevance judgements from NIST) and stores it in the `data/raw` folder for later processing.

In [None]:
# Download the raw MS MARCO / TREC DL 2019 files into data/raw.
# `wget URL PATH` would treat PATH as a second URL, so use -P to choose the target directory.
!mkdir -p data/raw
!wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz
# queries.tar.gz provides queries.train.tsv, which the data-preparation step reads.
!wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz
!wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz
!wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz
!wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz
!wget -P data/raw https://trec.nist.gov/data/deep/2019qrels-pass.txt
!gzip -d data/raw/qidpidtriples.train.full.2.tsv.gz
!gzip -d data/raw/msmarco-test2019-queries.tsv.gz
!gzip -d data/raw/msmarco-passagetest2019-top1000.tsv.gz
!tar -xvf data/raw/collection.tar.gz -C data/raw/
!tar -xvf data/raw/queries.tar.gz -C data/raw/

## Preparing the Data for Loading

This part of the notebook prepares the data by converting multiple TSV files into the [WebDataset format](https://webdataset.github.io/webdataset/), which can be loaded efficiently by a DataLoader. It also splits each (query, positive passage, negative passage) triple into one positive and one negative query–passage pair. Moreover, it creates a validation set from a held-out part of the sampled training triples, and a test set from the MS MARCO (TREC 2019) test queries and their relevance judgements.

In [None]:
import csv
import os

import click
import pandas as pd
import webdataset as wds
from tqdm import tqdm


def load_passages(corpus_path: str) -> pd.DataFrame:
    """Load the passage collection into a DataFrame indexed by passage id.

    Args:
        corpus_path: Path (or buffer) of a headerless, tab-separated file with
            ``passage_id`` and ``passage`` columns.

    Returns:
        DataFrame with a single ``passage`` column, indexed by ``passage_id``.
    """
    return pd.read_csv(
        corpus_path,
        sep="\t",
        names=["passage_id", "passage"],
        dtype={"passage_id": int, "passage": str},
        # The raw text contains unescaped double quotes. With the default quoting, a
        # passage that starts with '"' is parsed as a quoted field, which strips the
        # quotes or swallows the following lines. Treat quotes as ordinary text.
        quoting=csv.QUOTE_NONE,
    ).set_index("passage_id")


def load_queries(queries_path: str) -> pd.DataFrame:
    """Load a queries file into a DataFrame indexed by query id.

    Args:
        queries_path: Path (or buffer) of a headerless, tab-separated file with
            ``query_id`` and ``query`` columns.

    Returns:
        DataFrame with a single ``query`` column, indexed by ``query_id``.
    """
    return pd.read_csv(
        queries_path,
        sep="\t",
        names=["query_id", "query"],
        dtype={"query_id": int, "query": str},
        # Queries may contain unescaped double quotes; keep them literally instead of
        # letting the CSV parser interpret them as field quoting.
        quoting=csv.QUOTE_NONE,
    ).set_index("query_id")


def load_qrels(qrels_path: str) -> pd.DataFrame:
    """Read training triples of integer ids.

    Args:
        qrels_path: Path (or buffer) of a headerless, tab-separated file whose rows are
            ``query_id  positive_passage_id  negative_passage_id``.

    Returns:
        DataFrame with the three integer id columns and a default RangeIndex.
    """
    columns = ["query_id", "positive_passage_id", "negative_passage_id"]
    # Every column is an integer id, so build the dtype mapping from the column list.
    return pd.read_csv(
        qrels_path,
        sep="\t",
        names=columns,
        dtype=dict.fromkeys(columns, int),
    )


def load_test_qrels(qrels_path: str) -> pd.DataFrame:
    """Read TREC-style relevance judgements (``query_id Q0 passage_id rating``).

    Args:
        qrels_path: Path (or buffer) of a headerless, single-space-separated file.

    Returns:
        DataFrame with columns ``query_id``, ``Q0``, ``passage_id`` and ``rating``.
    """
    schema = {"query_id": int, "Q0": str, "passage_id": int, "rating": int}
    frame = pd.read_csv(qrels_path, sep=" ", names=list(schema), dtype=schema)
    return frame


@click.command()
@click.option("--data_path", default="data/raw/", type=str)
@click.option("--output_path", default="data/processed/", type=str)
@click.option("--passages_path", default="collection.tsv", type=str)
@click.option("--queries_path", default="queries.train.tsv", type=str)
@click.option(
    "--qrels_path", default="qidpidtriples.train.full.2.tsv", type=str
)
@click.option("--subsample", default=1_000_000, type=int)
@click.option("--seed", default=42, type=int)
@click.option("--validation_fraction", default=0.1, type=float)
def prepare_data(
        data_path: str,
        output_path: str,
        passages_path: str,
        queries_path: str,
        qrels_path: str,
        subsample: int,
        seed: int,
        validation_fraction: float,
) -> None:
    """Convert the raw TSV files into WebDataset train/validation/test shards.

    A random subsample of the training triples is split into train and validation
    parts. Each part is written as ``<name>-N.tar`` shards under ``output_path``.
    The test queries and judgements are written to ``test.tar``.

    Args:
        data_path: Directory containing the raw files.
        output_path: Directory the shards are written to (created if missing).
        passages_path: Passage collection file, relative to ``data_path``.
        queries_path: Training queries file, relative to ``data_path``.
        qrels_path: Training triples file, relative to ``data_path``.
        subsample: Number of triples to sample (capped at the number available).
        seed: Random seed for the subsample.
        validation_fraction: Fraction of the sampled triples used for validation.
    """
    os.makedirs(output_path, exist_ok=True)

    print("Loading data")
    passages = load_passages(os.path.join(data_path, passages_path))
    queries = load_queries(os.path.join(data_path, queries_path))
    qrels = load_qrels(os.path.join(data_path, qrels_path))

    # DataFrame.sample raises when n exceeds the number of rows, so cap the request.
    n_samples = min(subsample, len(qrels))
    qrels_values = qrels.sample(n=n_samples, random_state=seed).values
    # The full triples frame is no longer needed; release it before writing shards.
    del qrels

    # The sample is already shuffled, so a contiguous cut gives a random split.
    split_index = int(len(qrels_values) * (1 - validation_fraction))
    train_qrels = qrels_values[:split_index]
    validation_qrels = qrels_values[split_index:]

    print("Writing to train shards")
    write_shards("train", output_path, passages, queries, train_qrels)

    print("Writing to validation shards")
    write_shards(
        "validation", output_path, passages, queries, validation_qrels
    )

    print("Writing to test shards")
    test_queries = load_queries(
        os.path.join(data_path, "msmarco-test2019-queries.tsv")
    )
    test_qrels = load_test_qrels(os.path.join(data_path, "2019qrels-pass.txt"))

    write_test_shards(output_path, passages, test_queries, test_qrels)


def write_test_shards(
        output_path: str,
        passages: pd.DataFrame,
        queries: pd.DataFrame,
        test_qrels: pd.DataFrame,
):
    """Write every judged (query, passage) pair of the test set to ``<output_path>/test.tar``.

    Each sample stores the query text, the passage text, a binary ``label.cls`` and the
    original graded ``rating.cls``.

    Args:
        output_path: Directory the tar file is written to.
        passages: Passage texts indexed by passage id (``passage`` column).
        queries: Query texts indexed by query id (``query`` column).
        test_qrels: Judgements with columns ``query_id, Q0, passage_id, rating``.
    """
    test_sink = wds.TarWriter(os.path.join(output_path, "test.tar"))
    # Close the writer even if a lookup fails, so the tar file is finalised and the
    # file handle is released.
    try:
        for query_id, _, passage_id, rating in tqdm(
                test_qrels.values, desc="Writing to test shards"
        ):
            query = queries.loc[query_id, "query"]
            passage = passages.loc[passage_id, "passage"]

            test_sink.write(
                {
                    "__key__": f"{query_id}-{passage_id}",
                    "query.pyd": query,
                    "passage.pyd": passage,
                    # NOTE(review): rating >= 2 counts as relevant here; confirm this
                    # matches the cut-off used wherever the test set is consumed.
                    "label.cls": 0 if rating < 2 else 1,
                    "rating.cls": rating,
                }
            )
    finally:
        test_sink.close()


def write_shards(
        name: str,
        output_path: str,
        passages: pd.DataFrame,
        queries: pd.DataFrame,
        qrels,
):
    """Write (query, passage, label) pairs for ``qrels`` to ``<output_path>/<name>-N.tar`` shards.

    Each triple yields two samples: the positive passage with ``label.cls`` 1 and the
    negative passage with ``label.cls`` 0.

    Args:
        name: Shard file prefix, e.g. ``"train"`` or ``"validation"``.
        output_path: Directory the shards are written to.
        passages: Passage texts indexed by passage id (``passage`` column).
        queries: Query texts indexed by query id (``query`` column).
        qrels: Iterable of ``(query_id, positive_passage_id, negative_passage_id)`` rows,
            e.g. the ``.values`` array of the sampled triples.
    """
    ONE_GIGABYTE = 1024**3
    # Roll over to a new shard after ~1 GiB or 100,000 samples, whichever comes first.
    sink = wds.ShardWriter(
        os.path.join(output_path, f"{name}-%d.tar"),
        maxsize=ONE_GIGABYTE,
        maxcount=100000,
        verbose=0,
    )

    # Close the writer even if a lookup fails, so the current shard is finalised and
    # the file handle is released.
    try:
        for query_id, positive_passage_id, negative_passage_id in tqdm(
                qrels, desc="Writing to shards"
        ):
            query = queries.loc[query_id, "query"]
            positive_passage = passages.loc[positive_passage_id, "passage"]
            negative_passage = passages.loc[negative_passage_id, "passage"]

            # NOTE(review): the same (query, positive) pair can occur in several triples,
            # so positive keys may repeat; confirm readers tolerate duplicate __key__s.
            sink.write(
                {
                    "__key__": f"{query_id}-{positive_passage_id}",
                    "query.pyd": query,
                    "passage.pyd": positive_passage,
                    "label.cls": 1,
                }
            )

            sink.write(
                {
                    "__key__": f"{query_id}-{negative_passage_id}",
                    "query.pyd": query,
                    "passage.pyd": negative_passage,
                    "label.cls": 0,
                }
            )
    finally:
        sink.close()


## The main function is called with the default parameters defined in the click configuration above.
# Calling a click command directly makes it parse sys.argv, which inside Jupyter holds the
# kernel's own arguments (e.g. "-f kernel.json") and then ends in SystemExit. An explicit
# empty argument list plus standalone_mode=False runs it with the click defaults and
# returns normally.
prepare_data.main(args=[], standalone_mode=False)

## Creating a DataModule for the Dataset

This datamodule creates the dataloaders that feed the dataset into the model. They are built with the WebDataset library, which streams the dataset efficiently from tar shards: the `train-{0..17}.tar` files are used for training, the `validation-{0..1}.tar` files for validation, and the `test.tar` file for testing.

Because it is a Lightning `LightningDataModule`, it can be passed directly to Lightning's `Trainer` class, which takes care of the training loop.

In [None]:
from multiprocessing import cpu_count

import lightning as L
from torch.utils.data import DataLoader
import webdataset as wds


class MSMarcoDataModule(L.LightningDataModule):
    """Lightning datamodule that streams the WebDataset shards written by the preparation step.

    Every sample is a ``(query, passage, target)`` tuple decoded to ``(str, str, float)``.
    Batching is done by WebDataset's ``.batched``, so the DataLoaders use
    ``batch_size=None``.

    Args:
        batch_size: Number of samples per batch.
        num_workers: DataLoader worker count; defaults to half of ``cpu_count()``.
        dataset_length: Nominal length attached to the training pipeline via
            ``with_length`` (the training script passes its step budget here).
        data_dir: Directory holding the ``train-*.tar``, ``validation-*.tar`` and
            ``test.tar`` shards.
    """

    def __init__(
            self,
            batch_size: int = 256,
            num_workers: int | None = None,
            dataset_length: int = 50_000,
            data_dir: str = "data/processed",
    ):
        super().__init__()

        self.num_workers = (
            cpu_count() // 2 if num_workers is None else num_workers
        )
        self.batch_size = batch_size

        self.train_dataset = None
        self.validation_dataset = None
        self.test_dataset = None

        self.dataset_length = dataset_length
        self.data_dir = data_dir

        self.save_hyperparameters()

    @staticmethod
    def to_string(string: bytes) -> str:
        """Decode a UTF-8 encoded sample field."""
        return string.decode("utf-8")

    @staticmethod
    def to_float(string: bytes) -> float:
        """Decode a numeric sample field (e.g. ``label.cls``) to a float."""
        return float(string.decode("utf-8"))

    @staticmethod
    def rating_to_class(rating: bytes) -> float:
        """Binarise a graded ``rating.cls`` field: ratings >= 1 map to 1.0, otherwise 0.0."""
        rating = float(rating.decode("utf-8"))

        return 1. if rating >= 1 else 0.

    def setup(self, stage: str) -> None:
        """Build the train, validation and test pipelines (independent of ``stage``)."""
        # Shard ranges match the default preparation settings: 1,000,000 sampled triples with
        # a 10% validation split become 1.8M train / 200k validation pairs, i.e. 18 and 2
        # shards of at most 100,000 samples each.
        self.train_dataset = (
            wds.WebDataset(
                f"{self.data_dir}/train-{{0..17}}.tar", shardshuffle=True
            )
            .with_length(self.dataset_length)
            .shuffle(1000)
            .to_tuple("query.pyd", "passage.pyd", "label.cls")
            .map_tuple(self.to_string, self.to_string, self.to_float)
        )
        self.validation_dataset = (
            wds.WebDataset(f"{self.data_dir}/validation-{{0..1}}.tar")
            .to_tuple("query.pyd", "passage.pyd", "label.cls")
            .map_tuple(self.to_string, self.to_string, self.to_float)
        )

        # NOTE(review): the test target comes from the graded rating with a >= 1 cut-off
        # (rating_to_class), not from the stored label.cls field, which the preparation
        # step writes with a >= 2 cut-off. Confirm which relevance cut-off is intended.
        self.test_dataset = (
            wds.WebDataset(f"{self.data_dir}/test.tar")
            .to_tuple("query.pyd", "passage.pyd", "rating.cls")
            .map_tuple(self.to_string, self.to_string, self.rating_to_class)
        )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset.batched(self.batch_size),
            pin_memory=True,
            batch_size=None,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset.batched(self.batch_size),
            batch_size=None,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.validation_dataset.batched(self.batch_size),
            batch_size=None,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def predict_dataloader(self) -> DataLoader:
        # Prediction runs over the validation split (used for threshold tuning).
        return self.val_dataloader()

## Defining the Model Module

This module defines the structure of the model we are training, which loss function is used, and which optimizer is used. Moreover, it defines the metrics we are using to evaluate the model.

The module also defines what we do in each type of step of the training process:
1. `training_step` defines what we do in each step of the training process. In this case, we simply calculate the loss and log it.
2. `validation_step` defines what we do in each step of the validation process. In this case, we simply calculate the loss and some performance metrics and log them.
3. `test_step` defines what we do in each step of the testing process. In this case, we simply calculate the loss and some performance metrics and log them.

In [None]:
from typing import Any

import wandb
from torch import nn, Tensor
import lightning as L
from sentence_transformers import SentenceTransformer, models
import torchmetrics
from sentence_transformers.util import batch_to_device
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, \
    average_precision_score, ConfusionMatrixDisplay, RocCurveDisplay


class SBERT(L.LightningModule):
    """SBERT-style bi-encoder that scores (query, passage) pairs by cosine similarity.

    Queries and passages are embedded independently and compared with
    ``nn.CosineSimilarity``. If ``criterion`` is an ``nn.CosineEmbeddingLoss``, it is
    applied to the two embeddings. Otherwise it is applied to ``(similarity, label)``;
    for example, ``nn.MSELoss`` regresses the similarity onto the 0/1 label.

    NOTE(review): this cell does not import ``torch`` itself; the methods rely on
    ``torch`` being available as a global name (the notebook imports it in a later cell).

    Args:
        model: SentenceTransformer used to embed both queries and passages.
        criterion: Loss module (see above).
        lr: Learning rate for Adam.
        compile_model: Whether the encoder is compiled with ``torch.compile``.
    """

    def __init__(
            self,
            model: SentenceTransformer,
            criterion: nn.Module,
            lr: float = 1e-5,
            compile_model: bool = False,
    ) -> None:
        super().__init__()

        self.transformer = model.to(self.device)
        # NOTE(review): a SentenceTransformer usually already ends in a pooling module;
        # this extra Pooling presumably recomputes "sentence_embedding". Confirm it is intended.
        self.pooling = models.Pooling(model.get_sentence_embedding_dimension())

        self.model = torch.compile(nn.Sequential(
            self.transformer,
            self.pooling
        ), mode="max-autotune", disable=not compile_model)

        self.cosine = nn.CosineSimilarity()
        self.criterion = criterion

        self.is_cosine_embedding_loss = isinstance(self.criterion, nn.CosineEmbeddingLoss)

        self.lr = lr

        # NOTE(review): torchmetrics appears to ignore ``top_k`` for task="binary", so
        # the *_k3/*_k5 entries likely duplicate *_k1. BinaryAccuracy also appears to
        # treat float predictions outside [0, 1] as logits (applying a sigmoid), and
        # cosine similarities can be negative. Confirm both are intended.
        self.train_metrics = torchmetrics.MetricCollection(
            {
                "train_accuracy_k1": torchmetrics.Accuracy(task="binary"),
                "train_accuracy_k3": torchmetrics.Accuracy(
                    task="binary", top_k=3
                ),
                "train_accuracy_k5": torchmetrics.Accuracy(
                    task="binary", top_k=5
                ),
            }
        )

        self.validation_metrics = torchmetrics.MetricCollection(
            {
                "val_accuracy_k1": torchmetrics.Accuracy(task="binary"),
                "val_accuracy_k3": torchmetrics.Accuracy(
                    task="binary", top_k=3
                ),
                "val_accuracy_k5": torchmetrics.Accuracy(
                    task="binary", top_k=5
                ),
            }
        )

        self.test_metrics = torchmetrics.MetricCollection(
            {
                "test_accuracy_k1": torchmetrics.Accuracy(task="binary"),
                "test_accuracy_k3": torchmetrics.Accuracy(
                    task="binary", top_k=3
                ),
                "test_accuracy_k5": torchmetrics.Accuracy(
                    task="binary", top_k=5
                ),
            }
        )

        # Per-batch test targets/similarities, concatenated in on_test_end.
        self.y_test = []
        self.y_hat_test = []

        # Decision thresholds for the raw similarity and for the similarity mapped to
        # [0, 1]; the training script overwrites them after threshold tuning.
        self.threshold = 0.5
        self.mapped_threshold = 0.5

        self.save_hyperparameters(ignore=["model", "criterion"])

    def forward(self, x) -> Tensor:
        """Tokenize a batch of texts and run the encoder; returns the encoder's output mapping."""
        tokens = self.transformer.tokenize(x)
        tokens = batch_to_device(tokens, self.device)
        return self.model(tokens)

    def _embed_and_score(self, batch) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Embed both sides of a ``(queries, passages, labels)`` batch and score the pairs.

        Returns:
            ``(question embeddings, answer embeddings, y_hat, y)``, where ``y_hat`` is the
            float32 cosine similarity of each pair and ``y`` holds the float32 labels.
        """
        x_question, x_answer, y = batch

        embeddings_question = self(x_question)["sentence_embedding"]
        embeddings_answer = self(x_answer)["sentence_embedding"]

        similarity = self.cosine(embeddings_question, embeddings_answer)
        return (
            embeddings_question,
            embeddings_answer,
            similarity.to(torch.float32),
            y.to(torch.float32),
        )

    def _compute_loss(
            self,
            embeddings_question: Tensor,
            embeddings_answer: Tensor,
            y_hat: Tensor,
            y: Tensor,
    ) -> Tensor:
        """Apply the configured criterion to a scored batch (labels ``y`` are 0/1)."""
        if self.is_cosine_embedding_loss:
            # CosineEmbeddingLoss expects targets of 1 (similar) / -1 (dissimilar). A 0
            # target contributes no loss, so negatives would never be pushed apart;
            # map the 0/1 labels to -1/1.
            return self.criterion(embeddings_question, embeddings_answer, 2 * y - 1)
        return self.criterion(y_hat, y)

    def training_step(self, batch) -> Tensor:
        embeddings_question, embeddings_answer, y_hat, y = self._embed_and_score(batch)
        loss = self._compute_loss(embeddings_question, embeddings_answer, y_hat, y)

        self.train_metrics(y_hat.detach().cpu(), y.detach().cpu())
        self.log_dict(
            self.train_metrics, on_step=True, on_epoch=True, prog_bar=False
        )

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True
        )

        return loss

    def validation_step(self, batch):
        embeddings_question, embeddings_answer, y_hat, y = self._embed_and_score(batch)
        loss = self._compute_loss(embeddings_question, embeddings_answer, y_hat, y)

        self.validation_metrics(y_hat.detach().cpu(), y.detach().cpu())
        # The validation metrics used to be updated but never logged.
        self.log_dict(self.validation_metrics, on_step=False, on_epoch=True)

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

    def on_test_start(self) -> None:
        # Start every test run with empty buffers, so repeated trainer.test calls do not
        # mix predictions from earlier runs.
        self.y_test = []
        self.y_hat_test = []

    def test_step(self, batch):
        embeddings_question, embeddings_answer, y_hat, y = self._embed_and_score(batch)
        loss = self._compute_loss(embeddings_question, embeddings_answer, y_hat, y)

        self.test_metrics(y_hat.detach().cpu(), y.detach().cpu())
        # The test metrics used to be updated but never logged.
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        return {
            "y_hat": y_hat.detach().cpu(),
            "y": y.detach().cpu(),
            "loss": loss.detach().cpu()
        }

    def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        # Collect the per-batch outputs returned by test_step.
        self.y_test.append(outputs["y"])
        self.y_hat_test.append(outputs["y_hat"])

    def map_y_hat(self, y_hat: np.ndarray) -> np.ndarray:
        """Map cosine similarities from [-1, 1] to [0, 1]."""
        return (y_hat + 1) / 2

    def on_test_end(self) -> None:
        """Compute and log threshold-based metrics for the raw and mapped similarities."""
        if not self.y_test:
            # No test batches were seen; there is nothing to score.
            return

        y = torch.cat(self.y_test, dim=0).detach().cpu().numpy()
        y_hat = torch.cat(self.y_hat_test, dim=0).detach().cpu().numpy()
        y_hat_mapped = self.map_y_hat(y_hat)

        y_pred = np.where(y_hat > self.threshold, 1, 0)
        y_pred_mapped = np.where(y_hat_mapped > self.mapped_threshold, 1, 0)

        self.log_test_metrics(y, y_hat, y_pred, "raw")
        self.log_test_metrics(y, y_hat_mapped, y_pred_mapped, "mapped")

    def log_test_metrics(self, y, y_hat, y_pred, prefix):
        """Log classification/ranking metrics and plots under ``test_<prefix>_*``.

        Args:
            y: Binary targets.
            y_hat: Continuous scores (used for ROC AUC / average precision / ROC curve).
            y_pred: Thresholded 0/1 predictions.
            prefix: Name segment distinguishing the score variant (e.g. "raw", "mapped").
        """
        import matplotlib.pyplot as plt

        # zero_division=0 returns the same 0.0 the default would, without the warning.
        accuracy = accuracy_score(y, y_pred)
        precision = precision_score(y, y_pred, zero_division=0)
        recall = recall_score(y, y_pred, zero_division=0)
        f1 = f1_score(y, y_pred, zero_division=0)
        roc_auc = roc_auc_score(y, y_hat)
        average_precision = average_precision_score(y, y_hat)

        confusion_figure = ConfusionMatrixDisplay.from_predictions(y, y_pred).figure_
        roc_figure = RocCurveDisplay.from_predictions(y, y_hat).figure_

        self.logger.experiment.log(
            {
                f"test_{prefix}_confusion_matrix": wandb.Image(confusion_figure),
                f"test_{prefix}_roc_curve": wandb.Image(roc_figure),
                f"test_{prefix}_accuracy": accuracy,
                f"test_{prefix}_precision": precision,
                f"test_{prefix}_recall": recall,
                f"test_{prefix}_f1": f1,
                f"test_{prefix}_roc_auc": roc_auc,
                f"test_{prefix}_average_precision": average_precision
            }
        )

        # The display helpers open new figures; close them so repeated calls do not leak.
        plt.close(confusion_figure)
        plt.close(roc_figure)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return CPU tensors of similarities (``y_hat``) and labels (``y``) for a batch."""
        _, _, y_hat, y = self._embed_and_score(batch)

        return {
            "y_hat": y_hat.detach().cpu(),
            "y": y.detach().cpu(),
        }

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


## Finding the Optimal Threshold

This function finds the decision threshold that maximizes the F1 score on the validation set, by evaluating a grid of thresholds from 0.00 to 0.99 in steps of 0.01.

In [None]:
import torch
import lightning as L
from sklearn.metrics import f1_score


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Map cosine similarities from [-1, 1] onto [0, 1] element-wise.

    Works on anything supporting ``+`` and ``/`` with scalars (tensors, NumPy arrays).
    """
    shifted = x + 1
    return shifted / 2


@torch.no_grad()
def find_threshold(
        trainer: L.Trainer,
        module: L.LightningModule,
        data_module: L.LightningDataModule,
        map: bool = False,
        thresholds: list[float] | None = None,
) -> tuple[float, float]:
    """Find the decision threshold that maximises F1 on the predict (validation) split.

    Args:
        trainer: Trainer used to run ``predict``.
        module: Model whose ``predict_step`` returns dicts with ``"y"`` and ``"y_hat"``.
        data_module: Datamodule providing the predict dataloader.
        map: If True, scores are first mapped from [-1, 1] to [0, 1] with ``normalize``.
        thresholds: Candidate thresholds. Defaults to 0.00, 0.01, ..., 0.99. Raw
            (unmapped) cosine similarities can be negative, so pass a wider grid if
            negative thresholds should be considered.

    Returns:
        ``(best_threshold, best_f1)``; ``(0, 0)`` if no threshold yields a positive F1.
    """
    module.eval()
    predictions: list[dict[str, torch.Tensor]] = trainer.predict(module, datamodule=data_module)
    module.train()

    y = torch.cat([batch["y"] for batch in predictions]).cpu().numpy()
    y_hat = torch.cat([batch["y_hat"] for batch in predictions]).cpu().numpy()

    if map:
        y_hat = normalize(y_hat)

    if thresholds is None:
        thresholds = [step / 100 for step in range(100)]

    best_threshold = 0
    best_f1 = 0

    for threshold in thresholds:
        # zero_division=0 yields the same 0.0 the default does, without warnings when a
        # threshold predicts no positives.
        f1 = f1_score(y, y_hat > threshold, zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold, best_f1


## Training the Model

Finally, we can train the model using everything defined up to this point. We use the `Trainer` class from Lightning to train the model, and the `find_threshold` function defined above to find the optimal threshold for the model. We then use this threshold to evaluate the model on the test set.

Five different models are trained, each with a different architecture. Each model is trained for 10,000 steps with a batch size of 64, using mixed-precision training, which speeds up training and lowers memory usage. All models were trained on DTU's HPC cluster using A100 GPUs. All other parameters are kept at their default values.

In [None]:
from pathlib import Path

import torch
import lightning as L
from sentence_transformers import SentenceTransformer

from datamodule import MSMarcoDataModule
from find_threshold import find_threshold
from sbert_model import SBERT
from lightning.pytorch.loggers import WandbLogger
import click
import os


@click.command()
@click.option("--batch_size", default=256, type=int)
@click.option("--model", default="bert-base-uncased", type=str)
@click.option("--epochs", default=1, type=int)
@click.option("--seed", default=42, type=int)
@click.option("--num_workers", default=None, type=int)
@click.option("--lr", default=1e-5, type=float)
@click.option("--precision", default=None, type=str)
@click.option("--dev", default=False, type=bool, is_flag=True)
@click.option("--num_steps", default=5000, type=int)
@click.option("--compile", default=False, type=bool)
@click.option("--loss_type", default="MSE", type=str)
@click.option("--test", default=False, type=bool, is_flag=True)
@click.option("--load_model", default=None, type=str)
@click.option("--cpu", default=False, type=bool, is_flag=True)
def train(
        batch_size: int,
        model: str,
        epochs: int,
        seed: int,
        num_workers: int,
        lr: float,
        precision: str | None = None,
        dev: bool = False,
        num_steps: int = -1,
        compile: bool = True,
        loss_type: str = "cosine",
        test: bool = False,
        load_model: str = None,
        cpu: bool = False,
):
    """Fine-tune one SentenceTransformer, tune its decision thresholds, then test it.

    Everything is logged to a fresh Weights & Biases run, which is finished on return.
    With ``--test`` the (optionally loaded) model is only evaluated. ``loss_type ==
    "cosine"`` selects ``CosineEmbeddingLoss``; any other value selects ``MSELoss``.

    NOTE(review): several Python-level defaults (``num_steps``, ``compile``,
    ``loss_type``) differ from the click defaults; when invoked through click the click
    defaults apply.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    L.seed_everything(seed)

    torch.set_float32_matmul_precision("high")

    # ``model`` arrives as a checkpoint name and is rebound to the loaded encoder below;
    # keep the name so the W&B config records a readable string, not the module repr.
    model_name = model
    model = SentenceTransformer(model_name)
    print(model)

    logger = WandbLogger(
        project="dl_sbert", entity="colodingdongs", log_model="all"
    )

    logger.experiment.config.update(
        {
            "batch_size": batch_size,
            "model": model_name,
            "epochs": epochs,
            "seed": seed,
            "num_workers": num_workers,
            "lr": lr,
            "precision": precision,
            "dev": dev,
            "max_steps": num_steps,
            "loss_type": loss_type,
        }
    )

    try:
        trainer = L.Trainer(
            max_epochs=epochs,
            accelerator="auto" if not cpu else "cpu",
            devices=1 if torch.cuda.is_available() else "auto",
            deterministic=True,
            logger=logger,
            precision=precision,
            fast_dev_run=dev,
            max_steps=num_steps,
        )

        datamodule = MSMarcoDataModule(
            batch_size=batch_size, num_workers=num_workers, dataset_length=num_steps
        )

        criterion = torch.nn.CosineEmbeddingLoss() if loss_type == "cosine" else torch.nn.MSELoss()

        if load_model is None:
            l_module = SBERT(model, criterion, lr=lr, compile_model=compile)
        else:
            artifact = logger.experiment.use_artifact(load_model, type="model")
            artifact_dir = artifact.download()

            l_module = SBERT.load_from_checkpoint(Path(artifact_dir) / "model.ckpt", model=model, criterion=criterion, lr=lr)

        if test:
            trainer.test(l_module, datamodule)
            return

        trainer.fit(l_module, datamodule)

        # Pick the F1-maximising thresholds on the validation split, for the raw cosine
        # similarity and for the similarity mapped to [0, 1].
        threshold, best_f1 = find_threshold(trainer, l_module, datamodule)
        logger.experiment.config.update({"threshold": threshold})
        logger.experiment.log({"best_val_f1": best_f1})

        mapped_threshold, best_mapped_f1 = find_threshold(trainer, l_module, datamodule, map=True)
        logger.experiment.config.update({"mapped_threshold": mapped_threshold})
        # Previously the raw best_f1 was logged under this key by mistake.
        logger.experiment.log({"best_mapped_val_f1": best_mapped_f1})

        l_module.threshold = threshold
        l_module.mapped_threshold = mapped_threshold

        trainer.test(l_module, datamodule)
    finally:
        # A new WandbLogger reuses any wandb run still in progress, so finish this one;
        # otherwise consecutive calls (e.g. the loop over models) log into a single run.
        logger.experiment.finish()

model_names = [
    "bert-base-uncased",
    "bert-large-uncased",
    "distilbert-base-uncased",
    "microsoft/MiniLM-L12-H384-uncased",
    "nreimers/MiniLM-L6-H384-uncased"
]

# ``train`` is a click command: Python keyword arguments would be forwarded to click's
# Context constructor (a TypeError), so pass the settings as command-line arguments.
# An explicit ``args`` list also keeps click from parsing the Jupyter kernel's sys.argv,
# and standalone_mode=False stops it from calling sys.exit() after each model.
# Every option not listed here keeps its click default.
for model_name in model_names:
    train.main(
        args=[
            "--batch_size", "64",
            "--model", model_name,
            "--num_steps", "10000",
            "--precision", "16-mixed",
        ],
        standalone_mode=False,
    )

## Making the Plot

This part of the notebook plots model size against F1 score. The data is loaded from the `models/model_outputs` folder, which contains the outputs of the models trained above. These outputs were downloaded manually from the Weights & Biases website and consist only of each run's config and the metrics logged during the run.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import json
import os
import pandas as pd
import yaml


def load_data(directory: str) -> pd.DataFrame:
    """Load each model's downloaded W&B metrics and config into one DataFrame.

    For every model name below, reads ``<name>.json`` (logged metrics) and
    ``<name>.config.yaml`` (run config) from ``directory``.

    Returns:
        One row per model: the JSON metrics plus ``model_name``, ``params`` (millions of
        parameters), ``lr``, ``batch_size``, ``num_steps``, ``threshold`` and
        ``mapped_threshold``.
    """
    # Parameter counts in millions, used as the x-axis of the plot.
    param_counts = {
        "BERT-base": 110,
        "BERT-large": 336,
        "DistilBERT": 67,
        "MiniLM-L6": 22.7,
        "MiniLM-L12": 33.4,
    }
    # Output column -> config key. Each config entry is stored under a "value" key.
    config_fields = {
        "lr": "lr",
        "batch_size": "batch_size",
        "num_steps": "max_steps",
        "threshold": "threshold",
        "mapped_threshold": "mapped_threshold",
    }

    records = []
    for model_name, n_params in param_counts.items():
        with open(os.path.join(directory, f"{model_name}.json"), encoding="utf-8") as f:
            model_output = json.load(f)
        model_output["model_name"] = model_name
        model_output["params"] = n_params

        with open(os.path.join(directory, f"{model_name}.config.yaml"), encoding="utf-8") as f:
            # safe_load: the config only needs plain YAML types.
            config = yaml.safe_load(f)
        for column, config_key in config_fields.items():
            model_output[column] = config[config_key]["value"]

        records.append(model_output)

    return pd.DataFrame.from_records(records)


if __name__ == "__main__":
    # Load the per-model metrics and configs.
    data = load_data("models/model_outputs")

    # Plot model size (x) against mapped test F1 (y).
    with plt.style.context("seaborn-v0_8"):
        fig, ax = plt.subplots(layout="constrained", figsize=(42 // 2, 28 // 2))
        plt.rc("font", size=32)

        max_params = data["params"].max()

        fig.suptitle("Model Size VS F1-Score", fontsize=56)

        # Fixed y range for the F1 scores being plotted.
        plt.yticks(np.arange(0.54, 0.6 + 0.005, 0.005), fontsize=32)
        plt.ylabel("F1-Score", fontsize=32)

        # x ticks every 25M parameters. Integer bounds give an integer array, so labels
        # read "25M" rather than "25.0M".
        xticks = np.arange(25, int(np.ceil(max_params)) + 25, 25)
        plt.xticks(xticks, rotation=-45, fontsize=32)
        ax.set_xticklabels([f"{v}M" for v in xticks])
        plt.xlabel("No. params", fontsize=32)

        ax.scatter(data["params"], data["test_mapped_f1"], label="F1", s=200, c="red")

        # Horizontal alignment of each model's annotation.
        text_positions = {
            "BERT-base": {"ha": "left"},
            "BERT-large": {"ha": "right"},
            "DistilBERT": {"ha": "left"},
            "MiniLM-L6": {"ha": "left"},
            "MiniLM-L12": {"ha": "left"},
        }

        # Label each point with the model name and its size.
        for _, row in data.iterrows():
            ax.annotate(
                f"{row['model_name']} ({row['params']}M)",
                (row["params"], row["test_mapped_f1"]),
                fontsize=45,
                **text_positions[row["model_name"]],
            )

    fig.savefig("model-vs-f1.png", dpi=600)
