# Evaluation of linear probing/finetuning models on the EMBED dataset

In [None]:
import sys
import seaborn as sns

# Must run before the project imports below: extends the import search path
# with the checkout directory (presumably the repository root that provides
# the `classification` and `data_handling` packages -- confirm on your machine).
sys.path.append("/vol/biomedic3/mb121/causal-contrastive")

import numpy as np
import matplotlib.pyplot as plt
from classification.classification_module import ClassificationModule
from hydra import compose, initialize
from data_handling.mammo import preprocess_breast
from pathlib import Path
from data_handling.mammo import EmbedDataModule, modelname_map

# Inverse lookup of `modelname_map`: maps each of its values back to its key.
rev_model_map = {v: k for k, v in modelname_map.items()}
import os
import matplotlib

# Switch the working directory so that the relative paths used further down
# ("../outputs/...", "figures/...") are resolved from the evaluation folder.
os.chdir("/vol/biomedic3/mb121/causal-contrastive/evaluation")

In [None]:
# Compose the Hydra configuration for the `base_density` experiment (with
# `data.cache` switched off) and build the EMBED data module from it.
density_overrides = ["experiment=base_density", "data.cache=False"]
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="config.yaml", overrides=density_overrides)
    print(cfg)
    data_module = EmbedDataModule(config=cfg)

# The same test loader is reused for every model evaluated below.
test_loader = data_module.test_dataloader()

## Plotting embeddings

In [None]:
# Display name -> run id. The run id selects the checkpoint folder
# "../outputs/run_<id>/" when the models are loaded; the display name is used
# as the panel title in the t-SNE figure. Insertion order is the panel order.
model_dict_for_embeddings = dict(
    [
        ("SimCLR", "byatk1eo"),
        ("SimCLR+", "jv9hzx89"),
        ("CF-SimCLR", "kywspwfs"),
    ]
)

In [None]:
from evaluation.helper_functions import run_get_embeddings
from sklearn.manifold import TSNE

# Embeddings (and accompanying metadata) of the test set for every model,
# keyed by the model's display name.
results = {}

for run_name, run_id in model_dict_for_embeddings.items():
    # An empty run id marks a model without a checkpoint: nothing to load.
    if run_id == "":
        continue

    print(run_name)
    model_to_evaluate = f"../outputs/run_{run_id}/epoch=449.ckpt"

    # Only the inner `.model` of the Lightning module is kept, in eval mode
    # and on the GPU.
    classification_model = ClassificationModule.load_from_checkpoint(
        model_to_evaluate,
        map_location="cuda",
        config=cfg,
        strict=False,
    ).model.eval()
    classification_model.cuda()

    # ID evaluation
    inference_results = run_get_embeddings(
        test_loader,
        classification_model,
        max_batch=500,
    )

    # Reduce the per-sample scanner vectors to a single class index each
    # (presumably one-hot / score vectors -- confirm in `run_get_embeddings`).
    scanner_indices = np.argmax(inference_results["scanners"], axis=1)
    inference_results["scanners"] = scanner_indices
    results[run_name] = inference_results

In [None]:
# Per-model 2-D t-SNE projection of the embeddings, and the scanner name of
# every projected sample (same sample order as the projection rows).
all_tsne = {}
all_scanners_name = {}

for run_name, inference_results in results.items():
    # Fixed random_state so the projection is the same on every run.
    projector = TSNE(n_jobs=6, perplexity=30, random_state=33)
    all_tsne[run_name] = projector.fit_transform(inference_results["feats"])
    # Translate scanner class indices back into their names.
    all_scanners_name[run_name] = [
        rev_model_map[scanner_idx] for scanner_idx in inference_results["scanners"]
    ]

In [None]:
import matplotlib.patches as mpatches
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

In [None]:
# Fixed scanner-name -> colour mapping shared by all panels. Scanner names are
# taken in sorted order (np.unique) and paired with hand-picked entries of the
# 7-colour "colorblind" palette.
# NOTE(review): colour index 5 is listed twice, so if there are six scanner
# names the last two share a colour -- confirm this is intended.
_colorblind = sns.color_palette("colorblind", 7)
_scanner_names = np.unique(list(rev_model_map.values()))
_color_slots = [1, 3, 0, 4, 5, 5]
palette = dict(zip(_scanner_names, (_colorblind[slot] for slot in _color_slots)))

In [None]:
# Draw one t-SNE scatter panel per model, coloured by scanner, outline a
# hard-coded region in each panel (labelled "Cluster of breast implants") and
# attach an example image taken from inside that region. The figure is saved
# to "figures/embeddings_tsne.png".
matplotlib.rcParams["font.family"] = "serif"
f, ax = plt.subplots(1, len(results), figsize=(20, 4), facecolor='none')

# Per-model box (left, bottom, width, height) in t-SNE coordinates. The values
# are tied to the specific projections computed above (fixed random_state).
coords_anomaly = {
    "SimCLR": (-100, -48, 28, 28),
    "SimCLR+": (-94, -75, 27, 34),
    "CF-SimCLR": (-103, -57, 24, 33),
}

# Root folder of the EMBED images; the example image path is built from it.
img_dir = "/vol/biomedic3/data/EMBED"

for i, run_name in enumerate(all_tsne.keys()):
    x2d = all_tsne[run_name]
    # Plot the points in random order (presumably so that no scanner class is
    # systematically drawn on top of the others). Unseeded: the draw order
    # differs between executions.
    idx = np.random.permutation(x2d.shape[0])
    scanners_plot = np.asarray(all_scanners_name[run_name])[idx]
    sns.scatterplot(
        x=x2d[idx, 0],
        y=x2d[idx, 1],
        hue=scanners_plot,
        ax=ax[i],
        palette=palette,
        # Only the first panel creates a legend; it is repositioned below.
        legend=i == 0,
        alpha=1.0,
        s=30,
    )
    ax[i].set_xlabel("")
    ax[i].set_ylabel("")
    ax[i].set_title(run_name, fontsize=18)
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    if run_name in coords_anomaly:
        left, bottom, width, height = coords_anomaly[run_name]
        rect = mpatches.Rectangle(
            (left, bottom),
            width,
            height,
            fill=False,
            color="black",
            linewidth=2,
            label="Cluster of breast implants",
        )
        ax[i].add_patch(rect)
        # Indices (in the original, un-shuffled sample order) of the points
        # that fall strictly inside the box.
        outliers = np.where(
            (left < x2d[:, 0])
            & (x2d[:, 0] < left + width)
            & (bottom < x2d[:, 1])
            & (x2d[:, 1] < bottom + height)
        )[0]
        # Look the paths up in THIS model's results. The previous code read a
        # loop variable leaked from an earlier cell, i.e. always the last
        # model's results regardless of the panel being drawn.
        shorts_path = results[run_name]["paths"][outliers]
        # NOTE(review): index 33 is a hand-picked example and raises
        # IndexError if fewer than 34 samples fall inside the box.
        img = preprocess_breast(
            Path(img_dir) / "images/png/1024x768" / shorts_path[33], (256, 192)
        )[0]
        im = OffsetImage(img, zoom=0.15, cmap="gray")
        # Thumbnail placed below-left of the box with an arrow pointing at the
        # middle of the box's bottom edge.
        ab = AnnotationBbox(
            im,
            xy=(left + width / 2, bottom),
            xybox=(left - 6, bottom - 33),
            pad=0,
            arrowprops={
                "arrowstyle": "simple",
                "facecolor": "black",
                "mutation_scale": 3,
            },
        )
        ax[i].add_artist(ab)

# Single shared legend, anchored relative to the first panel.
ax[0].legend(
    loc="center", bbox_to_anchor=(1.7, -0.18), ncol=6, fontsize=12.35, markerscale=2
)

# Hard-coded view limits per panel; this assumes exactly three panels.
ax[0].set_xlim((-118, 110))
ax[0].set_ylim((-102, 110))
ax[1].set_xlim((-112, 110))
ax[1].set_ylim((-132, 110))
ax[2].set_xlim((-122, 108))
ax[2].set_ylim((-111.5, 105))

plt.savefig("figures/embeddings_tsne.png", bbox_inches="tight", dpi=500)