In [None]:
# Automatically reload edited project modules (e.g. the `mgi` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [20]:
import sys

sys.path.append("../../")

import itertools
import warnings
import logging

import pandas as pd
import mpire
from mgi.data.datasets.dataset_utils import get_ds_dataset, get_gk_dataset

from mgi.mappings.nn_mapping import FaissNNMapping
from mgi.mappings.similarity_embeddings import (
    FastTextSimilarityEmbedderLong,
    FastTextSimilarityEmbedder,
)
from mgi.metrics.anonymization import acc_at_k
from mgi.data.sampled_datasets import load_sampled_datasets_metadata

In [None]:
warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)

In [None]:
# Sampled-dataset metadata. The evaluation grid iterates over its `.values()`,
# and `get_result` reads each value's `.name` (presumably a dict keyed by
# dataset name — confirm against load_sampled_datasets_metadata).
ds_dataset_metadatas = load_sampled_datasets_metadata()

In [None]:
def get_result(embedder_cls, v, seed: int) -> dict:
    """Evaluate a single (embedder class, sampled dataset, seed) combination.

    Builds a ``FaissNNMapping`` over the "gk" dataset that ``v`` belongs to.
    It then queries that mapping with the "ds" dataset loaded for ``v.name``
    and ``seed``, and scores the returned neighbour ids with
    ``acc_at_k(..., 1)``.

    Args:
        embedder_cls: Similarity-embedder class. It is instantiated with
            ``"in_out_neighborhood"``.
        v: Sampled-dataset metadata entry. Only ``v.name`` is read. The part
            of the name before the first ``"_"`` selects the gk dataset.
        seed: Seed forwarded to ``get_ds_dataset``. It is also echoed back in
            the result row.

    Returns:
        A result row with keys ``"dataset"``, ``"embedder"``, ``"acc"`` and
        ``"seed"``. ``"embedder"`` is the class name, so rows from many runs
        can be grouped in a DataFrame.
    """
    dataset_name = v.name
    # NOTE(review): assumes the gk dataset name itself never contains "_";
    # confirm against the sampled-dataset naming scheme.
    gk_name = dataset_name.split("_")[0]

    embedder = embedder_cls("in_out_neighborhood")
    gk_dataset = get_gk_dataset(gk_name)
    # Positional arguments are kept exactly as before. They are presumably
    # (dataset, k, embedder, mode, metric); TODO confirm against FaissNNMapping.
    mapping = FaissNNMapping(gk_dataset, 1, embedder, "training", "euclidean")

    ds_dataset = get_ds_dataset(dataset_name, seed)
    # Only the neighbour ids are scored; the distances are discarded.
    neighbours_ids, _ = mapping.get_neighbours_map_from_dataset(ds_dataset)
    acc = acc_at_k(ds_dataset, gk_dataset, neighbours_ids, 1)

    return {
        "dataset": dataset_name,
        "embedder": embedder_cls.__name__,
        "acc": acc,
        "seed": seed,
    }

In [None]:
# Experiment configuration: every embedder x every sampled dataset x every seed.
EMBEDDER_CLASSES = [FastTextSimilarityEmbedderLong, FastTextSimilarityEmbedder]
SEEDS = [121371, 59211, 44185]
N_JOBS = 20  # number of worker processes; tune to the machine

# Each tuple is (embedder_cls, metadata, seed). mpire unpacks each tuple
# into get_result's positional arguments.
args = list(itertools.product(EMBEDDER_CLASSES, ds_dataset_metadatas.values(), SEEDS))

with mpire.WorkerPool(n_jobs=N_JOBS) as pool:
    results = pool.map(get_result, args, progress_bar=True)

In [None]:
# One row per run; columns come from get_result's keys: dataset, embedder, acc, seed.
df = pd.DataFrame(data=results)

In [None]:
# Mean accuracy per embedder class, pooled over all datasets and seeds.
grouped = df.groupby("embedder")["acc"].mean()
grouped

In [None]:
# Relative change of the "Long" embedder's mean accuracy versus the baseline
# FastTextSimilarityEmbedder. For example, 0.05 is printed as "5.00%".
long_mean_acc = grouped["FastTextSimilarityEmbedderLong"]
base_mean_acc = grouped["FastTextSimilarityEmbedder"]
diff = (long_mean_acc - base_mean_acc) / base_mean_acc
print(f"Difference: {diff:.2%}")