# t-SNE evaluation of features on the EMBED dataset

In [None]:
import sys
import seaborn as sns

sys.path.append("/vol/biomedic3/mb121/causal-contrastive")

import numpy as np
import matplotlib.pyplot as plt
from classification.classification_module import ClassificationModule
from hydra import compose, initialize
from data_handling.mammo import preprocess_breast
from pathlib import Path
from data_handling.mammo import EmbedDataModule, modelname_map
from data_handling.xray import PadChestDataModule

# Inverse lookup of `modelname_map` (value -> key). Later cells index it with
# the integer scanner labels to recover a display name for each sample.
# If several keys of `modelname_map` share a value, the last one iterated wins.
rev_model_map = {v: k for k, v in modelname_map.items()}
import os
import matplotlib
import matplotlib.patches as mpatches
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

def plot_confusion_matrix(conf_matrix, class_names=None, normalize=None):
    """
    Draw a confusion matrix as an annotated seaborn heatmap and show it.

    Parameters:
    -----------
    conf_matrix : numpy.ndarray
        The confusion matrix to plot (rows: actual, columns: predicted).
    class_names : list, optional
        Tick labels for both axes. Defaults to the stringified indices
        ``0 .. len(conf_matrix) - 1``.
    normalize : str, optional
        Any truthy value switches the cell annotations to two decimals and
        caps the colour scale at 1; a falsy value uses integer annotations
        and an automatic upper bound. The strings "true", "pred" and "all"
        additionally append the matching normalisation note to the title.
    """
    plt.figure(figsize=(5, 5))

    # Fall back to numeric tick labels when the caller gives none
    tick_labels = class_names
    if tick_labels is None:
        tick_labels = [str(row) for row in range(len(conf_matrix))]

    is_normalized = bool(normalize)

    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt=".2f" if is_normalized else "d",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        vmin=0,
        vmax=1 if is_normalized else None,
    )

    plt.xlabel("Predicted")
    plt.ylabel("Actual")

    # Only the three recognised string modes add a suffix to the title
    title_suffixes = {
        "true": " (Normalized by Row)",
        "pred": " (Normalized by Column)",
        "all": " (Normalized by Total)",
    }
    title = "Average Confusion Matrix"
    if isinstance(normalize, str):
        title += title_suffixes.get(normalize, "")

    plt.title(title)
    plt.tight_layout()
    plt.show()


In [None]:
# Compose the Hydra configuration for the `base_density` experiment with data
# caching switched off, then build the EMBED data module from it.
hydra_overrides = ["experiment=base_density", "data.cache=False"]

with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="config.yaml", overrides=hydra_overrides)
    print(cfg)
    data_module = EmbedDataModule(config=cfg)

# The test split is the one all embeddings below are extracted from
test_loader = data_module.test_dataloader()

In [None]:
# Display name -> run id of each model to embed. The run id is interpolated
# into the checkpoint path (`../outputs/run_<id>/...`) and the display name is
# used as the subplot title; entries with an empty run id are skipped.
model_dict_for_embeddings = {
    "SimCLR": "byatk1eo",
    "SimCLR+": "jv9hzx89",
    "CF-SimCLR": "kywspwfs",
}

In [None]:
from evaluation.helper_functions import run_get_embeddings
from sklearn.manifold import TSNE

# run name -> dict returned by `run_get_embeddings` for that checkpoint
results = {}

for run_name, run_id in model_dict_for_embeddings.items():
    # Entries without a run id have no checkpoint to load
    if run_id == "":
        continue

    print(run_name)
    model_to_evaluate = f"../outputs/run_{run_id}/epoch=449.ckpt"

    # Only the inner `.model` of the loaded module is kept, in eval mode on GPU
    classification_model = ClassificationModule.load_from_checkpoint(
        model_to_evaluate,
        map_location="cuda",
        config=cfg,
        strict=False,
    ).model.eval()
    classification_model.cuda()

    # ID evaluation
    inference_results = run_get_embeddings(
        test_loader,
        classification_model,
        max_batch=2063,
    )

    # Collapse the per-sample scanner vectors to a single integer label
    inference_results["scanners"] = np.argmax(
        inference_results["scanners"], axis=1
    )
    results[run_name] = inference_results

In [None]:
# run name -> 2-D t-SNE coordinates of that run's features
all_tsne = {}
# run name -> scanner display name of every sample (same order as the features)
all_scanners_name = {}

for run_name, inference_results in results.items():
    # Fixed random_state so the projection is reproducible across executions
    projector = TSNE(n_jobs=6, perplexity=30, random_state=33)
    embedding_2d = projector.fit_transform(inference_results["feats"])
    scanner_labels = [
        rev_model_map[label] for label in inference_results["scanners"]
    ]
    all_tsne[run_name] = embedding_2d
    all_scanners_name[run_name] = scanner_labels

In [None]:
# Scanner display name -> RGB colour, drawn from seaborn's 7-colour
# "colorblind" palette. Names are taken in sorted order (np.unique) and paired
# positionally with the colour indices below.
# NOTE(review): index 5 appears twice, so the last two names share one colour
# — confirm this is intended.
colorblind_colors = sns.color_palette("colorblind", 7)
unique_scanner_names = np.unique(list(rev_model_map.values()))
color_indices = [1, 3, 0, 4, 5, 5]

palette = {
    scanner: colorblind_colors[color_idx]
    for scanner, color_idx in zip(unique_scanner_names, color_indices)
}

In [None]:
# Side-by-side t-SNE scatter plots (one panel per run), coloured by scanner.
# For runs listed in `coords_anomaly`, a rectangle highlights a region of the
# embedding and one example image from that region is attached to it.
matplotlib.rcParams["font.family"] = "serif"
f, ax = plt.subplots(1, len(results), figsize=(20, 4), facecolor="none")

# Per-run highlight box as (left, bottom, width, height) in t-SNE coordinates.
# These values are tied to the fixed-seed t-SNE projections computed above.
coords_anomaly = {
    "SimCLR": (-100, -48, 28, 28),
    "SimCLR+": (-94, -75, 27, 34),
    "CF-SimCLR": (-103, -57, 24, 33),
}

img_dir = "/vol/biomedic3/data/EMBED"

for i, run_name in enumerate(all_tsne):
    x2d = all_tsne[run_name]
    # Plot the points in random order (presumably so that no scanner is
    # systematically drawn on top of the others).
    idx = np.random.permutation(x2d.shape[0])
    scanners_plot = np.asarray(all_scanners_name[run_name])[idx]
    sns.scatterplot(
        x=x2d[idx, 0],
        y=x2d[idx, 1],
        hue=scanners_plot,
        ax=ax[i],
        palette=palette,
        legend=i == 0,
        alpha=1.0,
        s=30,
    )
    ax[i].set_xlabel("")
    ax[i].set_ylabel("")
    ax[i].set_title(run_name, fontsize=18)
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    if run_name in coords_anomaly:
        left, bottom, width, height = coords_anomaly[run_name]
        rect = mpatches.Rectangle(
            (left, bottom),
            width,
            height,
            fill=False,
            color="black",
            linewidth=2,
            label="Cluster of breast implants",
        )
        ax[i].add_patch(rect)
        # Indices (in original, unshuffled order) of points inside the box
        outliers = np.where(
            (left < x2d[:, 0])
            & (x2d[:, 0] < left + width)
            & (bottom < x2d[:, 1])
            & (x2d[:, 1] < bottom + height)
        )[0]
        # Look the paths up in THIS run's results: the indices refer to this
        # run's embedding, not to whichever run an earlier loop processed last.
        shorts_path = results[run_name]["paths"][outliers]
        # NOTE(review): hand-picked example index; raises IndexError if the
        # box contains fewer than 34 points.
        img = preprocess_breast(
            Path(img_dir) / "images/png/1024x768" / shorts_path[33], (256, 192)
        )[0]
        im = OffsetImage(img, zoom=0.15, cmap="gray")
        ab = AnnotationBbox(
            im,
            xy=(left + width / 2, bottom),
            xybox=(left - 6, bottom - 33),
            pad=0,
            arrowprops={
                "arrowstyle": "simple",
                "facecolor": "black",
                "mutation_scale": 3,
            },
        )
        ax[i].add_artist(ab)

# Single shared legend, anchored below the panels
ax[0].legend(
    loc="center", bbox_to_anchor=(1.7, -0.18), ncol=6, fontsize=12.35, markerscale=2
)

# Hand-tuned axis limits; these assume exactly three panels
ax[0].set_xlim((-118, 110))
ax[0].set_ylim((-102, 110))
ax[1].set_xlim((-112, 110))
ax[1].set_ylim((-132, 110))
ax[2].set_xlim((-122, 108))
ax[2].set_ylim((-111.5, 105))

# savefig does not create missing directories
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/embeddings_tsne.png", bbox_inches="tight", dpi=500)

# Predicting scanner from features

In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline

# For every run: fit a PCA(16) + logistic-regression probe that predicts the
# scanner label from the features, evaluate it with 5-fold stratified
# cross-validation and plot the fold-averaged, row-normalised confusion matrix.
for name, inference_results in results.items():
    # Initialize stratified k-fold
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=33)

    # Lists to store results
    conf_matrices = []
    accuracies = []

    X = inference_results["feats"]
    y = inference_results["scanners"]

    # Keep at most 1000 samples per scanner label (labels 0..4) so that large
    # classes do not dominate.
    # NOTE(review): this draw is not seeded, so the subsample (and therefore
    # the reported numbers) changes between executions.
    idx = np.concatenate(
        [
            np.random.choice(
                np.where(y == i)[0],
                min(1000, np.sum(y == i)),
                replace=False,
            )
            for i in range(5)
        ]
    )
    print(np.bincount(y[idx]))
    X = X[idx]
    y = y[idx]

    # Perform cross-validation
    for train_idx, test_idx in skf.split(X, y):
        # Split data
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        classifier = Pipeline(
            [("pca", PCA(16)), ("logreg", LogisticRegression())]
        )

        # Train model
        classifier.fit(X_train, y_train)

        # Make predictions
        y_pred = classifier.predict(X_test)

        # Row-normalised confusion matrix: each row sums to 1
        cm = confusion_matrix(y_test, y_pred, normalize="true")
        conf_matrices.append(cm)

        # Mean of the diagonal of a row-normalised matrix, i.e. the average
        # per-class recall
        accuracy = np.mean(np.diag(cm))
        accuracies.append(accuracy)

    # Calculate average confusion matrix
    avg_conf_matrix = np.mean(conf_matrices, axis=0)

    print(name)
    print(f"\nAverage Accuracy: {np.mean(accuracies):.4f} (±{np.std(accuracies):.4f})")
    # Pass the same mode used for `confusion_matrix` so the plot title states
    # that the matrix is normalised by row.
    plot_confusion_matrix(avg_conf_matrix, normalize="true")