# Visualization of Multi-annotator Deep Learning for Classification
In this notebook, we reproduce the three visualizations presented in the accompanying article. Concretely, we demonstrate how to
- obtain the predictions of different MaDL variants for the two-dimensional data set TOY (corresponding to `RQ="p1+p2+independent"`),
- visualize the learned annotator embeddings and weights for the data set LETTER (corresponding to `RQ="p3+p4+interdependent"` or `RQ="p3+p4+random-interdependent"`),
- and infer the performances of unknown annotators (inductive learning) for the data set TOY (corresponding to `RQ="p5+p6+inductive"`).

Reproducing the plots for different MaDL variants requires varying the parameters `embed_x` and `confusion_matrix` (set in the *Train Model* section) according to the following mapping:
- MaDL(not X,I): `embed_x="none", confusion_matrix="isotropic"`,
- MaDL(not X,F): `embed_x="none", confusion_matrix="diagonal"`,
- MaDL(X,I): `embed_x="learned", confusion_matrix="isotropic"`,
- and MaDL(X,F): `embed_x="learned", confusion_matrix="diagonal"`.

In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("../")

import matplotlib.pyplot as plt
import numpy as np

from lfma.utils import StoreBestModuleStateDict

from skactiveml.utils import majority_vote

from sklearn.manifold import MDS

from torch import cuda
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from lfma.utils import (
    LitProgressBar,
    compute_annot_perf_clf,
    plot_annot_perfs_clf,
    introduce_missing_annotations,
)
from evaluation.data_utils import load_data, DATA_PATH
from evaluation.architecture_utils import instantiate_madl_classifier

# Fix the random state so that all results are reproducible.
RANDOM_STATE = 1

# Value marking an annotation as not available.
MISSING_LABEL = -1

# Research question (RQ) to reproduce; the valid options are listed in the introduction of this notebook.
RQ = "p3+p4+interdependent"

# Experimental setup of the selected research question.
if RQ == "p1+p2+independent":
    # TOY data set with independent annotators.
    DATA_TYPE, DATA_SET_NAME, VIS_POINTS = "none", "toy-classification", "train"
    MISSING_LABEL_RATIO, MAX_EPOCHS, LR = 0.8, 100, 0.01
    ANNOTATOR_FEATURES, DROPOUT = False, 0.5
elif RQ in ("p3+p4+interdependent", "p3+p4+random-interdependent"):
    # LETTER data set: both RQs share one setup and differ only in the annotator data type.
    DATA_TYPE = "correlated" if RQ == "p3+p4+interdependent" else "rand-dep_10_100"
    DATA_SET_NAME, VIS_POINTS = "letter", "train"
    MISSING_LABEL_RATIO, MAX_EPOCHS, LR = 0.8, 5, 0.005
    ANNOTATOR_FEATURES, DROPOUT = False, 0.0
elif RQ == "p5+p6+inductive":
    # TOY data set with annotator features; the suffix of "inductive_5" is the number of annotators
    # whose annotations are hidden during training (parsed when loading the data).
    DATA_TYPE, DATA_SET_NAME, VIS_POINTS = "inductive_5", "toy-classification", "test"
    MISSING_LABEL_RATIO, MAX_EPOCHS, LR = 0.98, 100, 0.01
    ANNOTATOR_FEATURES, DROPOUT = True, 0.0
else:
    raise ValueError("Invalid value for `RQ`. See introduction of this notebook to check the possible options.")

# Trainer configuration (keys presumably follow the PyTorch Lightning `Trainer` arguments — confirm).
# Uses a single GPU when CUDA is available, otherwise the CPU. The callback is configured to track the
# maximal "val_acc". Logging and checkpointing are disabled.
TRAINER_DICT = dict(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu" if cuda.is_available() else "cpu",
    devices=1 if cuda.is_available() else None,
    callbacks=[StoreBestModuleStateDict(score_name="val_acc", maximize=True), LitProgressBar()],
    logger=False,
    enable_checkpointing=False,
)

### Load Data

In [None]:
# Load the data set and its train/validation/test split (`n_repeats=1`, so only split index 0 is used later).
ds = load_data(
    data_set_name=DATA_SET_NAME,
    use_annotator_features=ANNOTATOR_FEATURES,
    data_path=DATA_PATH,
    preprocess=True,
    random_state=RANDOM_STATE,
    n_repeats=1,
    valid_size=0.05,
    test_size=0.2,
    data_type=DATA_TYPE,
)
classes = np.unique(ds["y_true"])
n_classes = len(classes)
n_samples = ds["X"].shape[0]
n_features = ds["X"].shape[1]
# `ds["A"]` holds one row of annotator features per annotator.
n_annotators = ds["A"].shape[0]
n_ap_features = ds["A"].shape[1]

# Hide a fraction `MISSING_LABEL_RATIO` of the annotations (presumably at random, seeded by
# RANDOM_STATE); hidden entries are marked with MISSING_LABEL.
ds["y_partial"] = introduce_missing_annotations(
    y=ds["y"],
    percentage=MISSING_LABEL_RATIO,
    missing_label=MISSING_LABEL,
    random_state=RANDOM_STATE,
)

# All annotators, minus the held-out (inductive) ones in the inductive setting.
test_annot_indices = np.arange(n_annotators)
inductive_annot_indices = None
if "inductive" in DATA_TYPE:
    # DATA_TYPE has the form "inductive_<n_test_annot>", e.g., "inductive_5".
    splits = DATA_TYPE.split("_")
    n_test_annot = int(splits[1])
    if not 0 < n_test_annot <= n_annotators:
        raise ValueError(
            f"The number of inductive annotators must be in [1, {n_annotators}], got {n_test_annot}."
        )
    # Pick `n_test_annot` evenly spaced annotators. Without the truncation, `np.arange` returns more than
    # `n_test_annot` indices whenever `n_annotators` is not divisible by `n_test_annot`.
    inductive_annot_indices = np.arange(0, n_annotators, n_annotators // n_test_annot)[:n_test_annot]
    # Drop all annotations of the held-out annotators so that none of them is available for training.
    ds["y_partial"][:, inductive_annot_indices] = MISSING_LABEL
    test_annot_indices = np.setdiff1d(test_annot_indices, inductive_annot_indices)

# Compute majority vote considering missing annotations.
ds["y_mv"] = majority_vote(
    y=ds["y_partial"],
    classes=classes,
    missing_label=MISSING_LABEL,
    random_state=RANDOM_STATE,
)

# Evaluate the accuracies of the annotators on their complete (unmasked) annotations.
print("Annotator accuracies: ")
compute_annot_perf_clf(ds["y_true"], ds["y"])

### Train Model

In [None]:
# Mini-batch settings: shuffled batches for training, larger unshuffled batches for evaluation.
data_loader_dict = dict(batch_size=64, shuffle=True)
val_data_loader_dict = dict(batch_size=256, shuffle=False)

# Training uses the partially observed annotations of the training split plus the annotator features;
# the validation split is paired with the true labels.
fit_dict = dict(
    X=ds["X"][ds["train"][0]],
    y=ds["y_partial"][ds["train"][0]],
    X_val=ds["X"][ds["valid"][0]],
    y_val=ds["y_true"][ds["valid"][0]],
    data_loader_dict=data_loader_dict,
    val_data_loader_dict=val_data_loader_dict,
    A=ds["A"],
)

# Settings shared by every MaDL variant: data set description, optimization, regularization, training.
# NOTE: `T_max` of the cosine schedule is fixed to 100, independent of MAX_EPOCHS.
default_params = dict(
    data_set_name=DATA_SET_NAME,
    classes=classes,
    n_features=n_features,
    optimizer=AdamW,
    optimizer_dict=dict(lr=LR, weight_decay=0.0),
    lr_scheduler=CosineAnnealingLR,
    lr_scheduler_dict=dict(T_max=100),
    dropout_rate=DROPOUT,
    trainer_dict=TRAINER_DICT,
    missing_label=MISSING_LABEL,
    random_state=RANDOM_STATE,
)

# MaDL-specific hyperparameters. Vary `embed_x` ("none" | "learned") and `confusion_matrix`
# ("isotropic" | "diagonal") to obtain the different MaDL variants.
eta, embed_size, alpha, beta = 0.8, 16, 1.25, 0.25
ap_use_residual, ap_use_outer_product = True, True
# NOTE(review): `ap_sim_func` is not passed to `instantiate_madl_classifier` below — confirm whether it is needed.
ap_sim_func = "rbf"
confusion_matrix, embed_x = "diagonal", "learned"
model_name = f"madl_{embed_x}_{confusion_matrix}"

# Build the MaDL classifier for the chosen variant and train it.
clf = instantiate_madl_classifier(
    embed_x=embed_x,
    confusion_matrix=confusion_matrix,
    n_ap_features=n_ap_features,
    embed_size=embed_size,
    eta=eta,
    alpha=alpha,
    beta=beta,
    ap_use_outer_product=ap_use_outer_product,
    ap_use_residual=ap_use_residual,
    **default_params,
)
clf = clf.fit(**fit_dict)

### Evaluate Classification Accuracy

In [None]:
# Quantitative evaluation: classification accuracy of the trained model w.r.t. the true labels on
# each data split. (The former, unused `predict_annotator_perf` call was removed; it only cost an
# extra inference pass per split.)
for d in ["train", "valid", "test"]:
    split_indices = ds[d][0]
    acc = clf.score(ds["X"][split_indices], ds["y_true"][split_indices])
    print(f"{d} classification accuracy: {acc}")

### Visualize Predictions

In [None]:
# Plot the trained classifier's predictions together with the annotators' annotations via
# `plot_annot_perfs_clf`; only done for data sets with exactly two features.
if n_features == 2:
    if "inductive" in DATA_TYPE:
        # Held-out annotators: all of their partial annotations were removed, so the full annotations are shown.
        annotator_indices = inductive_annot_indices
        y = ds["y"]
    else:
        annotator_indices = np.arange(n_annotators)
        y = ds["y_partial"]
    A = ds["A"][annotator_indices]
    # File-name suffix naming the annotator data type (empty for the default "none").
    appendix = "-" + DATA_TYPE if DATA_TYPE != "none" else ""
    plot_annot_perfs_clf(
        X=ds["X"][ds[VIS_POINTS][0]],
        y_true=ds["y_true"][ds[VIS_POINTS][0]],
        y_full=ds["y"][ds[VIS_POINTS][0]][:, annotator_indices],
        y=y[ds[VIS_POINTS][0]][:, annotator_indices],
        clf=clf,
        missing_label=MISSING_LABEL,
        cmap_clf="winter",
        cmap_annot="Purples",
        filepath=f"./toy-{model_name}{appendix}",
        A=A,
        plot_colorbar=False,
        figsize=(6, 6),
        plot_accuracies=False,
        markersize=120,
    )

### Visualize Learned Annotator Embeddings and Weights

In [None]:
from sklearn.metrics import pairwise_kernels

# Visualize the learned annotator embeddings (2-D MDS projection) and the average annotator weight per
# annotator group. Only applicable to models providing `compute_annotator_embeddings`.
if hasattr(clf, "compute_annotator_embeddings"):
    # Annotator group id -> (legend label, color).
    annotator_groups_legend = {
        0: ("adversarial", "#4d4d4dff"),
        1: ("cluster-specialized", "#008080ff"),
        2: ("common", "#2a7fffff"),
        3: ("class-specialized", "#800080ff"),
        4: ("random (interdependent)", "#784421"),
        5: ("adversarial (interdependent)", "k"),
        6: ("cluster-specialized (interdependent)", "#44aa00"),
        7: ("class-specialized (interdependent)", "m"),
    }
    # Group id of every annotator, hard-coded per DATA_TYPE.
    # NOTE(review): assumes the annotator order in the data set follows this layout — confirm against
    # the data generation.
    annotator_groups = [0, 1, 1, 2, 2, 2, 2, 2, 2, 3]
    if DATA_TYPE == "rand-dep_10_100":
        annotator_groups += [4] * 90
    elif DATA_TYPE == "correlated":
        annotator_groups += [5, 6, 7] * 10
        # Annotators 0, 1, and 9 are assigned to the interdependent groups as well.
        annotator_groups[0] = 5
        annotator_groups[1] = 6
        annotator_groups[9] = 7
    elif "inductive" in DATA_TYPE:
        annotator_groups = [0] * 10 + [1] * 20 + [2] * 60 + [3] * 10
    annotator_groups = np.array(annotator_groups)
    # Fail early with a clear message instead of an opaque boolean-indexing error further below.
    if annotator_groups.shape[0] != n_annotators:
        raise ValueError(
            f"The annotator group layout for DATA_TYPE={DATA_TYPE!r} covers {annotator_groups.shape[0]} "
            f"annotators, but the data set has {n_annotators}."
        )
    unique_annotator_groups = np.unique(annotator_groups)

    # Pairwise RBF similarities between the annotator embeddings, using the model's `gamma`
    # (presumably the learned kernel bandwidth).
    A_embed = clf.compute_annotator_embeddings()
    S = pairwise_kernels(A_embed, metric="rbf", gamma=clf.module_.gamma.detach().cpu().numpy())
    # Each annotator's weight is inversely proportional to its summed similarity to all annotators,
    # rescaled so that the weights average to one.
    weights = 1 / np.sum(S, axis=-1)
    weights /= weights.sum(axis=-1)
    weights = n_annotators * weights

    # Figure 1: 2-D projection of the annotator embeddings, colored by annotator group.
    fig, ax = plt.subplots(figsize=(6, 6))
    # NOTE(review): in scikit-learn versions where `MDS(metric=...)` is a boolean (metric vs.
    # non-metric MDS) and precomputed input requires `dissimilarity="precomputed"`, the string below
    # only enables metric MDS. `-S` is then treated as a feature matrix: each annotator is described by
    # its row of negated similarities, and Euclidean distances between rows are embedded. This is
    # kept unchanged to reproduce the article's figure; confirm against the installed version.
    A_transformed = MDS(metric="precomputed", random_state=RANDOM_STATE, n_components=2).fit_transform(-S)
    for g in unique_annotator_groups:
        is_g = annotator_groups == g
        label, color = annotator_groups_legend[g]
        ax.scatter(A_transformed[is_g, 0], A_transformed[is_g, 1], c=color, label=label, s=120)
    ax.legend()
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    ax.margins(0.1, 0.1)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    fig.savefig(f"./embeddings-{DATA_TYPE}-{DATA_SET_NAME}.pdf", bbox_inches="tight", pad_inches=0)
    plt.show()

    # Figure 2: mean weight per annotator group, one labeled bar per group. Tick labels are hidden, so the
    # groups are identified by the same colors as in the embedding plot.
    fig, ax = plt.subplots(figsize=(6, 6))
    for g in unique_annotator_groups:
        label, color = annotator_groups_legend[g]
        ax.bar(label, weights[annotator_groups == g].mean(), color=color)
    for container in ax.containers:
        ax.bar_label(container)
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    ax.margins(0, 0.1)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    fig.savefig(f"./weights-{DATA_TYPE}-{DATA_SET_NAME}.pdf", bbox_inches="tight", pad_inches=0)
    plt.show()