# Analysis

This notebook contains the analysis of the data tracked on
[Weights & Biases](https://wandb.ai/).


## Setup

---

We first set up everything needed to analyse the experiment results: importing
the necessary libraries, setting paths, and loading the experiment results from
W&B.


In [None]:
# ruff: noqa
%load_ext autoreload
%autoreload 2

# Built-in modules
import os
import sys

sys.path.insert(0, "..")

# Ignore warnings
import warnings

warnings.filterwarnings("ignore")

# External modules
# - Data Representation
import pandas as pd
import numpy as np

# - Data Visualization
from matplotlib import pyplot as plt
import seaborn as sns
import plotnine as pn

# - Machine Learning
import torch
import torch.nn as nn
from sklearn import metrics

# - Experiment Configuration and Logging
import wandb
from omegaconf import OmegaConf

# Custom modules
from utils import eval_utils as utils
from utils import train_utils

In [None]:
# Setup of global variables
ROOT_DIR = os.path.dirname(os.path.abspath("."))
ARTIFACT_DIR = os.path.join(ROOT_DIR, "artifacts")
FIGURE_DIR = os.path.join(ROOT_DIR, "figures")

METHODS = ["baseline", "baseline_pp", "matchingnet", "protonet", "maml"]
METHODS_WITH_SOT = []
for method in METHODS:
    METHODS_WITH_SOT.append(method)
    METHODS_WITH_SOT.append(method + "_sot")

STYLED_METHODS = ["Baseline", "Baseline++", "MatchingNet", "ProtoNet", "MAML"]
STYLED_METHODS_WITH_SOT = []
for method in STYLED_METHODS:
    STYLED_METHODS_WITH_SOT.append(method)
    STYLED_METHODS_WITH_SOT.append(method + " (SOT)")

styled_methods_dict = dict(zip(METHODS, STYLED_METHODS))


def get_name(name, sot=False):
    return styled_methods_dict[name] + (" (SOT)" if sot else "")

In [None]:
# Settings
sns.set_style("dark")
colorstyle = "RdBu"
sns.set_palette(colorstyle)

In [None]:
# Initialize wandb
WANDB_PROJECT = "few-shot-benchmark"
WANDB_ENTITY = "metameta-learners"

# Initialize W&B API
api = wandb.Api()

# Get all runs
runs = api.runs(f"{WANDB_ENTITY}/{WANDB_PROJECT}")

## Experiment 1: Benchmark

---

All models on all datasets, with and without SOT, for a fixed few-shot learning
setting (5-way 5-shot).
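
To make the "5-way 5-shot" setting concrete, here is a minimal sketch of how one
episode could be sampled. It is not the benchmark's data loader: the function
`sample_episode`, the toy labels and the query size of 15 are all assumptions
made for illustration.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, n_shot=5, n_query=15, seed=None):
    """Sample support/query indices for one N-way K-shot episode.

    `labels` holds one class label per sample. All names are illustrative
    and not taken from the benchmark code.
    """
    rng = random.Random(seed)

    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with enough samples can fill both support and query sets
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_shot + n_query]
    classes = rng.sample(eligible, n_way)

    support, query = {}, {}
    for c in classes:
        picked = rng.sample(by_class[c], n_shot + n_query)
        support[c] = picked[:n_shot]  # labelled examples shown to the model
        query[c] = picked[n_shot:]  # held-out examples to classify
    return support, query


# Toy data: 8 classes with 30 samples each
toy_labels = [c for c in range(8) for _ in range(30)]
support, query = sample_episode(toy_labels, n_way=5, n_shot=5, n_query=15, seed=0)
print({c: (len(support[c]), len(query[c])) for c in support})
```

Each episode therefore contains `n_way` classes, with `n_shot` support samples
and `n_query` query samples per class.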


### Loading Experiment Data


In [None]:
# Get all runs for experiment `benchmark`
GROUP = "benchmark"
USER = "mikasenghaas"

group_runs = [run for run in runs if run.group == GROUP and run.state == "finished"]
print(f"✅ Found {len(group_runs)} runs")

Next, we'll load all runs from the given experiment group into a single
dataframe.
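
The implementation of `utils.load_to_df` is not shown in this notebook. Since
later cells index columns with tuples such as `("config", "dataset")` and
`("eval", "test/acc")`, the sketch below shows one way such a two-level column
layout could be built. The toy runs and the helper `to_frame` are hypothetical.

```python
import pandas as pd

# Hypothetical stand-ins for W&B runs: an id, a config dict and a summary dict
toy_runs = [
    {
        "id": "run_a",
        "config": {"method": "protonet", "use_sot": True},
        "summary": {"val/acc": 81.2, "test/acc": 79.5},
    },
    {
        "id": "run_b",
        "config": {"method": "maml", "use_sot": False},
        "summary": {"val/acc": 74.0, "test/acc": 72.8},
    },
]


def to_frame(runs):
    """Flatten runs into one row each, with two-level column labels."""
    ids, rows = [], []
    for run in runs:
        # Prefix every key with its origin so columns become (level, name)
        row = {("config", k): v for k, v in run["config"].items()}
        row.update({("eval", k): v for k, v in run["summary"].items()})
        ids.append(run["id"])
        rows.append(row)
    df = pd.DataFrame(rows, index=ids)
    df.columns = pd.MultiIndex.from_tuples(list(df.columns))
    return df


toy_df = to_frame(toy_runs)
print(toy_df[("eval", "test/acc")])
```

Indexing the frame by run ID also makes lookups such as `toy_df.loc["run_a"]`
possible.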


In [None]:
df_runs = utils.load_to_df(group_runs)
print(f"✅ Loaded {len(df_runs)} runs.")

df_runs.head()

### Grouping

Each experiment is uniquely identified by the following parameters:

- `dataset`: The dataset used (`swissprot`, `tabula_muris`)
- `method`: The model used (`baseline`, `baseline_pp`, `protonet`,
  `matchingnet`, `maml`)
- `use_sot`: Whether to include the SOT module (`True`, `False`)
- `n_way`: The number of classes in each episode
- `n_shot`: The number of support samples per class in each episode

For each experiment setting, there are multiple trained models because of
hyper-parameter tuning. We will group the runs by the above parameters and only
use the best-performing model on the validation set for the following analysis.
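
The selection step can be sketched with a pandas `groupby` followed by
`idxmax`. This is only an illustration of the idea, not the code behind
`utils.get_best_run`: the flat column names, the `lr` hyper-parameter and the
accuracy values are made up.

```python
import pandas as pd

# Toy tuning results: two configurations with two tuning runs each
toy = pd.DataFrame(
    {
        "dataset": ["swissprot", "swissprot", "tabula_muris", "tabula_muris"],
        "method": ["protonet", "protonet", "maml", "maml"],
        "use_sot": [True, True, False, False],
        "n_way": [5, 5, 5, 5],
        "n_shot": [5, 5, 5, 5],
        "lr": [1e-3, 1e-4, 1e-3, 1e-4],  # hypothetical tuned hyper-parameter
        "val_acc": [70.1, 73.4, 88.0, 85.2],
    },
    index=["r1", "r2", "r3", "r4"],
)

group_keys = ["dataset", "method", "use_sot", "n_way", "n_shot"]

# For every configuration, idxmax returns the index label of the run with
# the highest validation accuracy
best_ids = toy.groupby(group_keys)["val_acc"].idxmax()
best = toy.loc[best_ids]
print(best[["method", "lr", "val_acc"]])
```

Selecting on validation accuracy keeps the test split untouched until the final
comparison.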


In [None]:
# Group tuning runs by experiment configuration
df_best_runs = utils.get_best_run(df_runs, metric=("eval", "val/acc"))
print(f"✅ Filtered to {len(df_best_runs)} best runs.")

# Let's also save two separate dataframes for the two different datasets
df_best_runs_tm = df_best_runs[df_best_runs[(
    "config", "dataset")] == "tabula_muris"]
df_best_runs_sp = df_best_runs[df_best_runs[(
    "config", "dataset")] == "swissprot"]

df_best_runs.head()

### Train / Val / Test Performance for all models

Here, we draw a grouped bar plot for all methods (5 methods, each with and
without SOT) on all three splits (train, val, test). Performance is shown in
two separate subplots, one for Tabula Muris and one for SwissProt.


In [None]:
# Performance by split for all methods
fig, axs = plt.subplots(nrows=2, figsize=(20, 8))
fig.tight_layout()


def pivot_acc(df):
    tmp = []
    for i, best_model in df.iterrows():
        for split in ["train", "val", "test"]:
            method_name = get_name(
                best_model[("config", "method")
                           ], best_model[("config", "use_sot")]
            )
            tmp.append(
                {
                    "method": method_name,
                    "split": split,
                    "acc": best_model[("eval", f"{split}/acc")],
                }
            )
    return pd.DataFrame(tmp)


sns.barplot(
    pivot_acc(df_best_runs_tm),
    x="method",
    y="acc",
    hue="split",
    order=STYLED_METHODS_WITH_SOT,
    ax=axs[0],
)
sns.barplot(
    pivot_acc(df_best_runs_sp),
    x="method",
    y="acc",
    hue="split",
    order=STYLED_METHODS_WITH_SOT,
    ax=axs[1],
)
# Set title
axs[0].set_title("Tabula Muris", fontweight="bold")
axs[1].set_title("SwissProt", fontweight="bold")

# Keep the legend on the first subplot only
axs[0].get_legend().set_title("Split")
axs[1].get_legend().remove()

for ax in axs:
    ax.set_xlabel("Method")
    ax.set_ylabel("Acc. (%)")

fig.savefig(os.path.join(FIGURE_DIR, "benchmark-split-perf.pdf"),
            bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

### Performance by method with and without SOT

Here, we compare the performance of the different methods with and without SOT.
The left subplot shows the test performance on the Tabula Muris dataset, while
the right subplot shows the test performance on the SwissProt dataset.


In [None]:
# Performance by method with and without SOT
fig, axs = plt.subplots(ncols=2, figsize=(20, 5))

sns.barplot(
    df_best_runs_tm,
    x=("config", "method"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    order=METHODS,
    ax=axs[0],
)

sns.barplot(
    df_best_runs_sp,
    x=("config", "method"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    errorbar="sd",
    order=METHODS,
    ax=axs[1],
)

# Set title
axs[0].set_title("Tabula Muris", fontweight="bold")
axs[1].set_title("SwissProt", fontweight="bold")

# Set legend titles
axs[0].get_legend().set_title("SOT")
axs[1].get_legend().set_title("SOT")

# Replace method identifiers in the tick labels with styled names
axs[0].set_xticklabels([get_name(name.get_text())
                       for name in axs[0].get_xticklabels()])
axs[1].set_xticklabels([get_name(name.get_text())
                       for name in axs[1].get_xticklabels()])

for ax in axs:
    ax.set_ylabel("Test Acc. (%)")
    ax.set_xlabel("Method")

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "benchmark-perf.pdf"),
            bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

## Experiment 2: Way-Shot Analysis

---

Varying the number of classes per episode (ways) and the number of support
samples per class (shots).


In [None]:
# Load experiment data for `way-shot` experiment
GROUP = "way-shot"
USER = "mikasenghaas"

# Filter runs by group
group_runs = [run for run in runs if run.group == GROUP and run.state == "finished"]
print(f"✅ Found {len(group_runs)} runs")

In [None]:
# Load runs into dataframe
df_runs = utils.load_to_df(group_runs)
print(f"✅ Loaded {len(df_runs)} runs.")

df_runs.head()

Only keep the best run for each experiment configuration. This only has an
effect if hyperparameter tuning was performed.


In [None]:
# Group tuning runs by experiment configuration
df_best_runs = utils.get_best_run(df_runs, metric=("eval", "val/acc"))
print(f"✅ Filtered to {len(df_best_runs)} best runs.")

df_best_runs.head()

### Shot-Way-Analysis

Display the test accuracy as a function of the number of classes per episode
(left) and the number of shots per class (right), comparing ProtoNet with and
without SOT.
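
By default, seaborn's `lineplot` aggregates repeated observations at the same x
value with their mean, so each line traces the mean accuracy per setting. The
sketch below reproduces that aggregation by hand on made-up numbers. The flat
column names and all values are illustrative only.

```python
import pandas as pd

# Toy results: test accuracy for several (n_way, use_sot) settings
toy = pd.DataFrame(
    {
        "n_way": [2, 2, 2, 2, 5, 5, 5, 5],
        "use_sot": [False, False, True, True, False, False, True, True],
        "test_acc": [90.0, 92.0, 94.0, 96.0, 70.0, 74.0, 78.0, 80.0],
    }
)

# Mean test accuracy per number of classes, one column per SOT setting
mean_acc = toy.groupby(["n_way", "use_sot"])["test_acc"].mean().unstack("use_sot")
print(mean_acc)
```

The same grouping with `n_shot` in place of `n_way` gives the values behind the
second panel.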


In [None]:
# Plot test/acc vs. n_way and n_shot for SOT and non-SOT methods
fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

# test/acc ~ n_way
sns.scatterplot(
    data=df_best_runs,
    x=("config", "n_way"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    alpha=0.25,
    ax=axs[0],
)
sns.lineplot(
    data=df_best_runs,
    x=("config", "n_way"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    legend=False,
    ax=axs[0],
)

# test/acc ~ n_shot
sns.scatterplot(
    data=df_best_runs,
    x=("config", "n_shot"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    alpha=0.25,
    ax=axs[1],
)
sns.lineplot(
    data=df_best_runs,
    x=("config", "n_shot"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    legend=False,
    ax=axs[1],
)

# Set x-axis labels
axs[0].set_xlabel("N-Way")
axs[1].set_xlabel("N-Shot")

# Set y-axis labels (the plotted metric is test/acc)
axs[0].set_ylabel("Test Acc. (%)")
axs[1].set_ylabel("")

# Set legend title
axs[0].get_legend().set_title("SOT")
axs[1].get_legend().set_title("SOT")

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "way-shot.pdf"), bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

## Experiment 3: Understanding embeddings

---

The goal of this section is to compare the embeddings learned by the models
with and without SOT.
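
Embeddings are high-dimensional, so they have to be projected to two dimensions
before plotting. Which projection `utils.visualise_episode` uses is not shown
in this notebook. The sketch below assumes plain PCA, implemented with NumPy's
SVD, and runs on random toy embeddings rather than real model outputs.

```python
import numpy as np


def pca_2d(embeddings):
    """Project row-vector embeddings onto their first two principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal directions, sorted by explained variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)

# Toy embeddings: 5 classes x 10 samples in 64 dimensions, with one random
# centre per class plus unit-variance noise
centres = rng.normal(scale=5.0, size=(5, 64))
toy_embeddings = np.repeat(centres, 10, axis=0) + rng.normal(size=(50, 64))

points = pca_2d(toy_embeddings)
print(points.shape)  # (50, 2)
```

The resulting 2D points can be passed to a scatter plot and coloured by class to
inspect how well the classes separate.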


In [None]:
# Experiments
GROUP = "benchmark"

# Filter runs by group
group_runs = {
    run.id: run for run in runs if run.group == GROUP and run.state == "finished"
}
print(f"✅ Loaded {len(group_runs)} runs")

# Load runs into dataframe
df_runs = utils.load_to_df(group_runs.values())
df_best_runs = utils.get_best_run(df_runs, metric=("eval", "test/acc"))

In [None]:
df_best_runs.loc["k5oxu82u"]

In [None]:
# Get a run for ProtoNet on Tabula Muris
is_protonet = df_best_runs[("config", "method")] == "protonet"
is_tabula_muris = df_best_runs[("config", "dataset")] == "tabula_muris"
is_sot = df_best_runs[("config", "use_sot")] == True
best_run_id = df_best_runs[is_protonet & is_tabula_muris & is_sot].iloc[0].name

# Load best run
run = group_runs[best_run_id]
print(
    f"✅ Loaded run {run.id} for {run.config['n_way']}-way {run.config['n_shot']}-shot."
)

In [None]:
# Initialise dataset/ method

# Hacky way to initialise
os.chdir("..")  # Have to change to root directory to avoid re-downloading data
dataset, _, _, model = train_utils.initialize_dataset_model(
    OmegaConf.create(run.config), device="cpu"
)
model.n_query = dataset.n_query
loader = dataset.get_data_loader(num_workers=0, pin_memory=False)

os.chdir("notebooks")
print("✅ Initialised dataset and model.")

In [None]:
# Download artifact (model weights)
utils.download_artifact(
    api,
    wandb_entity=WANDB_ENTITY,
    wandb_project=WANDB_PROJECT,
    artifact_dir=ARTIFACT_DIR,
    run_id=run.id,
)

In [None]:
# Load model weights
weight_path = os.path.join(ARTIFACT_DIR, run.id, "best_model.pt")
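# The model was initialised on the CPU, so map the stored tensors there as well
model.load_state_dict(torch.load(weight_path, map_location="cpu"))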

In [None]:
# Verify hyper-parameters
n_episodes = dataset.n_episodes
n_way, n_support, n_query = dataset.n_way, dataset.n_support, dataset.n_query

assert len(loader) == n_episodes, "Number of episodes does not match dataset size."
assert n_way == run.config["n_way"], "Number of classes does not match config."
assert (
    n_support == run.config["n_shot"]
), "Number of support examples does not match config."
assert (
    n_query == run.config["n_query"]
), "Number of query examples does not match config."

print(
    f"✅ Loaded {n_way}-way {n_support}-shot {dataset._dataset_name} with episodic data loader (n={n_episodes})."
)

In [None]:
model.test_loop(loader)

In [None]:
# Visualise multiple episodes
show_embeddings = ["input", "backbone", "lstm"]
n_episodes = 2

fig, axs = plt.subplots(
    nrows=n_episodes,
    ncols=len(show_embeddings),
    figsize=(5 * len(show_embeddings), 5 * n_episodes),
)
fig.tight_layout()
fig.suptitle(
    f"{n_way}-way {n_support}-shot {dataset._dataset_name}",
    fontweight="bold",
)

for i in range(n_episodes):
    for j, show in enumerate(show_embeddings):
        utils.visualise_episode(
            loader,
            model,
            show=show,
            ax=axs[i, j],
        )

In [None]:
_, ax = plt.subplots(ncols=1, figsize=(10, 10))
utils.visualise_episode(loader, model, show="backbone", ax=ax)

## Understanding Model Performance


### A closer look at particular runs

---


Select a run from the table above to look at it in more detail.


In [None]:
# Placeholder: set this to the ID of the run to inspect
runid = None
# `group_runs` maps run IDs to runs, so look the run up by its ID
config = group_runs[runid].config
dataset, loader, model = utils.init_run(config, ROOT_DIR, "test")

Next, let's evaluate the run's model on the given dataset:


In [None]:
# Get the mapping from encoding to annotation
encoding2anot = {v: k for k, v in dataset.trg2idx.items()}

# Define metric fn from sklearn assuming y_true and y_pred as input in this order
clf_kwargs = {"average": "macro"}
metric_fns = [
    (metrics.accuracy_score, None),
    (metrics.precision_score, clf_kwargs),
    (metrics.recall_score, clf_kwargs),
    (metrics.f1_score, clf_kwargs),
]

# Evaluate model and obtain its predictions with ground truth for each episode
episodes_results = utils.eval_run(model, loader)

# Compute metrics for each episode
episodes_metrics = utils.compute_metrics(metric_fns, episodes_results)

episodes_metrics.head()
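
Few-shot results are commonly reported as the mean over episodes together with
a 95% confidence interval. As a minimal sketch of that summary step, the
snippet below uses a hypothetical list of per-episode accuracies. The layout of
`episodes_metrics` is not shown here, so its column names are not assumed.

```python
import math
import statistics

# Hypothetical per-episode accuracies (fractions in [0, 1])
episode_acc = [0.80, 0.84, 0.76, 0.88, 0.82, 0.78, 0.86, 0.80]

mean_acc = statistics.mean(episode_acc)

# Normal-approximation 95% interval: 1.96 standard errors around the mean
std_err = statistics.stdev(episode_acc) / math.sqrt(len(episode_acc))
half_width = 1.96 * std_err

print(f"acc = {mean_acc:.3f} +/- {half_width:.3f}")
```

With only a handful of episodes the normal approximation is rough; it becomes
more reasonable as the number of evaluation episodes grows.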