In [None]:
import sys

sys.path.append("/vol/biomedic3/mb121/calibration_exploration/")

from plotting_notebooks.plotting_utils import (
    my_pretty_plot,
)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import defaultdict
from matplotlib import ticker
from hydra import initialize, compose
from classification.load_model_and_config import (
    get_run_id_from_config,
    _clean_config_for_backward_compatibility,
)
import torch
from pathlib import Path


def mytitle(experiment):
    match experiment:
        case "base_chexpert":
            plt.title("$\mathbf{CXR}$")
        case "base_density":
            plt.title("$\mathbf{EMBED}$")
        case _:
            plt.title("$\mathbf{" + experiment.replace("base_", "").upper() + "}$")

In [None]:
from copy import deepcopy


# Select the run to visualise: encoder architecture + experiment name.
model = "resnet18"
experiment = "base_living17"
config_str = [f"experiment={experiment}", f"model.encoder_name={model}"]
overrides = config_str + ["trainer.label_smoothing=0.00"]

with initialize(version_base=None, config_path="../configs"):
    config = compose(config_name="config.yaml", overrides=overrides)

    # The run lookup is done on a deep copy so that `config` itself is left
    # exactly as composed (presumably the cleaning helper edits its argument
    # in place -- its return value is not used here).
    config2 = deepcopy(config)
    _clean_config_for_backward_compatibility(config2)
    run_id = get_run_id_from_config(
        config2,
        allow_multiple_runs=False,
        allow_return_none_if_no_runs=False,
    )

In [None]:
# Directory holding the saved outputs of the run selected above.
output_dir = Path(f'../outputs/run_{run_id}')
# map_location="cpu" lets the cached outputs load on machines without the
# device they were saved from, and keeps the tensors usable by the
# NumPy / scikit-learn code later in the notebook.
# NOTE(review): torch.load unpickles the file -- only load trusted outputs.
val_results = torch.load(output_dir / "val_outputs.pt", map_location="cpu")
test_results = torch.load(output_dir / "test_outputs.pt", map_location="cpu")

In [None]:
from calibration.inference_utils import get_outputs
import pytorch_lightning as pl

from classification.classification_module import ClassificationModule
from classification.load_model_and_config import get_modules

# Seed the global RNGs with the seed stored in the composed config.
pl.seed_everything(config.seed)

data_module, _ = get_modules(config, shuffle_training=False)

# strict=False: checkpoint keys are not required to match the module exactly.
checkpoint_path = f"{output_dir}/best.ckpt"
model_module = ClassificationModule.load_from_checkpoint(
    checkpoint_path, config=config, strict=False
)
# Set before running inference; later cells read a 'feats' entry from the
# returned outputs (presumably populated because of this flag -- TODO confirm).
model_module.get_all_features = True

trainer = pl.Trainer(enable_progress_bar=True)
# NOTE(review): meaning of the 0.1 argument is defined by the data module.
ood_loader = data_module.get_irrelevant_ood_loader(0.1)
ood_val_results = get_outputs(model_module, ood_loader, trainer)

In [None]:
# Show the top-level keys of the loaded test outputs. The next cell reads
# test_results['id'] and pools every other key together as "shifted" data.
test_results.keys()

In [None]:
def _random_subset(feats, max_rows=10000):
    """Return at most ``max_rows`` rows of ``feats``, chosen by a random permutation."""
    shuffled_idx = torch.randperm(feats.shape[0])
    return feats[shuffled_idx[:max_rows]]


# Pool the last 'feats' entry of every test split except 'id', then subsample.
shifted_keys = [k for k in test_results.keys() if k != 'id']
feats_shift = _random_subset(
    torch.cat([test_results[k]['feats'][-1] for k in shifted_keys])
)
# Same treatment for the 'id' split on its own.
feats_id = _random_subset(test_results['id']['feats'][-1])
print(feats_id.shape, feats_shift.shape)

In [None]:
# Take the last 'feats' entry of the outputs computed on the OOD loader
# (same indexing as used for the test-split features).
feats_ood = ood_val_results['feats'][-1]
# Displayed as the cell output.
feats_ood.shape

In [None]:
from sklearn.manifold import TSNE

In [None]:
# Stack the three feature groups row-wise; the label array built below follows
# the same group order so that row i of `all_feats` matches `domains_labels[i]`.
all_feats = torch.cat((feats_id, feats_shift, feats_ood), dim=0)

label_blocks = [
    np.asarray([group_name] * group_feats.shape[0])
    for group_name, group_feats in (
        ('TEST - ID', feats_id),
        ('TEST - SHIFTED', feats_shift),
        ('SEMANTIC OOD', feats_ood),
    )
]
domains_labels = np.concatenate(label_blocks)
print(all_feats.shape, feats_id.shape, feats_ood.shape, feats_shift.shape)

In [None]:
# Project all features to 2-D with t-SNE.
# random_state is pinned to the config seed so that re-running this cell gives
# the same embedding instead of depending on the current global RNG state.
tsne = TSNE(n_components=2, random_state=config.seed)
# scikit-learn works on NumPy arrays: convert the torch tensor explicitly
# (detach + move to CPU) rather than relying on implicit conversion.
x2d = tsne.fit_transform(all_feats.detach().cpu().numpy())

In [None]:
# Quick look at the 2-D embedding, coloured by data group.
tsne_x, tsne_y = x2d[:, 0], x2d[:, 1]
sns.scatterplot(x=tsne_x, y=tsne_y, hue=domains_labels)
mytitle(experiment)

In [None]:
joint_plot = sns.jointplot(x=x2d[:,0], y=x2d[:,1], hue=domains_labels, alpha=0.7)
plt.legend(title='')
joint_plot.fig.axes[0].set_xlabel('')
joint_plot.fig.axes[0].set_ylabel('')
match experiment:
    case "base_chexpert":
        joint_plot.fig.axes[-2].set_title("$\mathbf{CXR}$")
    case "base_density":
        joint_plot.fig.axes[-2].set_title("$\mathbf{EMBED}$")
    case _:
        joint_plot.fig.axes[-2].set_title("$\mathbf{" + experiment.replace("base_", "").upper() + "}$")
plt.savefig(f'tsne_{experiment}.pdf', bbox_inches='tight')