# Analysis

This notebook contains the analysis of the data tracked on
[Weights & Biases](https://wandb.ai/).


## Setup

---

We first set up everything needed to analyse the experiment results: importing
the necessary libraries, setting paths, and loading the experiment results from
W&B.


In [None]:
# ruff: noqa
%load_ext autoreload
%autoreload 2

# Built-in modules
import os
import sys

sys.path.insert(0, "..")

# Ignore warnings
import warnings

warnings.filterwarnings("ignore")

# External modules
# - Data Representation
import pandas as pd
import numpy as np

# - Data Visualization
from matplotlib import pyplot as plt
import seaborn as sns
import plotnine as pn

# - Machine Learning
import torch

# - Experiment Configuration and Logging
import wandb

# Custom modules
from utils import eval_utils as utils
from utils import train_utils

In [None]:
# Setup of global variables
ROOT_DIR = os.path.dirname(os.path.abspath("."))
ARTIFACT_DIR = os.path.join(ROOT_DIR, "artifacts")
FIGURE_DIR = os.path.join(ROOT_DIR, "figures")
TABLE_DIR = os.path.join(ROOT_DIR, "tables")

METHODS = ["baseline", "baseline_pp", "matchingnet", "protonet", "maml"]
METHODS_WITH_SOT = []
for method in METHODS:
    METHODS_WITH_SOT.append(method)
    METHODS_WITH_SOT.append(method + "_sot")

STYLED_METHODS = ["Baseline", "Baseline++", "MatchingNet", "ProtoNet", "MAML"]
STYLED_METHODS_WITH_SOT = []
for method in STYLED_METHODS:
    STYLED_METHODS_WITH_SOT.append(method)
    STYLED_METHODS_WITH_SOT.append(method + " (SOT)")

styled_methods_dict = dict(zip(METHODS, STYLED_METHODS))


def get_name(name, sot=False):
    return styled_methods_dict[name] + (" (SOT)" if sot else "")

In [None]:
# Settings
sns.set_style("dark")
colorstyle = "RdBu"
sns.set_palette(colorstyle)

In [None]:
# Initialize wandb
WANDB_PROJECT = "few-shot-benchmark"
WANDB_ENTITY = "metameta-learners"

# Initialize W&B API
api = wandb.Api()

# Get all runs
runs = api.runs(f"{WANDB_ENTITY}/{WANDB_PROJECT}")

## Experiment 1: Benchmark

---

We evaluate all models on all datasets, with and without SOT, in a fixed
few-shot learning setting (5-way 5-shot).
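
As a reminder of what an N-way K-shot episode looks like, the sampling step could be sketched as follows. This is a minimal illustration and not the data loader used in this project; `sample_episode`, `n_query` and the toy labels are assumptions made for the example.

```python
import numpy as np


def sample_episode(labels, n_way=5, n_shot=5, n_query=15, seed=0):
    """Return (support_idx, query_idx) for one N-way K-shot episode.

    Hypothetical helper: picks `n_way` classes, then `n_shot` support and
    `n_query` query samples per class, without replacement.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)

    support, query = [], []
    for cls in classes:
        # Shuffle the indices of this class and split them into support/query
        idx = rng.permutation(np.flatnonzero(labels == cls))
        support.extend(idx[:n_shot])
        query.extend(idx[n_shot : n_shot + n_query])
    return np.array(support), np.array(query)


# Toy labels: 10 classes with 30 samples each (illustrative only)
toy_labels = np.repeat(np.arange(10), 30)
support_idx, query_idx = sample_episode(toy_labels)
print(support_idx.shape, query_idx.shape)
```

In the 5-way 5-shot setting this yields 25 support samples per episode; the number of query samples per class is a free choice in this sketch.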


### Loading Experiment Data


In [None]:
# Get all runs for experiment `benchmark`
GROUP = "benchmark"

group_runs = [run for run in runs if run.group == GROUP and run.state == "finished"]
print(f"✅ Found {len(group_runs)} runs")

Next, we'll load all runs from the given experiment group into a single
dataframe.
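
The dataframes in this notebook use two-level column keys such as `("config", "method")` and `("eval", "val/acc")`. A flattening step that produces this layout might look like the sketch below. It is not the actual `utils.load_to_df`; `runs_to_frame` is a hypothetical name, and we assume each run exposes `id`, `name`, a flat `config` dict and a flat `summary` dict. The fake runs carry placeholder values so that the example runs without W&B access.

```python
from types import SimpleNamespace

import pandas as pd


def runs_to_frame(runs):
    """Flatten run objects into a frame with (section, key) columns.

    Hypothetical stand-in for `utils.load_to_df`.
    """
    rows = []
    for run in runs:
        row = {("info", "id"): run.id, ("info", "name"): run.name}
        row.update({("config", k): v for k, v in run.config.items()})
        row.update({("eval", k): v for k, v in run.summary.items()})
        rows.append(row)
    df = pd.DataFrame(rows)
    # Make the two-level column structure explicit
    df.columns = pd.MultiIndex.from_tuples(list(df.columns))
    return df


# Fake runs with placeholder metrics (not real results)
fake_runs = [
    SimpleNamespace(
        id="a1",
        name="run-a",
        config={"method": "protonet", "use_sot": False},
        summary={"val/acc": 50.0, "test/acc": 50.0},
    ),
    SimpleNamespace(
        id="b2",
        name="run-b",
        config={"method": "protonet", "use_sot": True},
        summary={"val/acc": 50.0, "test/acc": 50.0},
    ),
]
toy_df = runs_to_frame(fake_runs)
print(toy_df[("config", "method")].tolist())
```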


In [None]:
df_runs = utils.load_to_df(group_runs)
print(f"✅ Loaded {len(df_runs)} runs.")

df_runs.head()

### Grouping

Each experiment is uniquely identified by the following parameters:

- `dataset`: The dataset used (`swissprot`, `tabula_muris`)
- `method`: The model used (`baseline`, `baseline_pp`, `protonet`,
  `matchingnet`, `maml`)
- `use_sot`: Whether to include the SOT module (`True`, `False`)
- `n_way`: The number of classes in each episode
- `n_shot`: The number of support samples per class in each episode

For each experiment setting, there are multiple trained models because of
hyper-parameter tuning. We will group the runs by the above parameters and only
use the best-performing model on the validation set for the following analysis.
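
The selection rule described above can be sketched with plain pandas. This is not the implementation of `utils.get_best_run`; `select_best_runs`, the flat column names and the toy values are assumptions for illustration (the notebook itself uses two-level column keys).

```python
import pandas as pd

CONFIG_KEYS = ["dataset", "method", "use_sot", "n_way", "n_shot"]


def select_best_runs(df, keys=CONFIG_KEYS, metric="val_acc"):
    """Keep the run with the highest `metric` for each configuration.

    Hypothetical stand-in for `utils.get_best_run`.
    """
    # Sort so the best run of every configuration comes first, then keep
    # only the first occurrence of each configuration.
    ranked = df.sort_values(metric, ascending=False)
    return ranked.drop_duplicates(subset=keys).reset_index(drop=True)


toy_runs = pd.DataFrame(
    {
        "dataset": ["swissprot"] * 3,
        "method": ["protonet"] * 3,
        "use_sot": [False, False, True],
        "n_way": [5] * 3,
        "n_shot": [5] * 3,
        "lr": [1e-2, 1e-3, 1e-3],  # tuned hyper-parameter, not part of the key
        "val_acc": [10.0, 20.0, 30.0],  # placeholder numbers, not real results
    }
)
best = select_best_runs(toy_runs)
print(best[["use_sot", "lr", "val_acc"]])
```

Selecting on validation accuracy rather than test accuracy keeps the test split untouched during model selection.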


In [None]:
# Group tuning runs by experiment configuration
df_best_runs = utils.get_best_run(df_runs, metric=("eval", "val/acc"))
print(f"✅ Filtered to {len(df_best_runs)} best runs.")

# Let's also save two separate dataframes for the two different datasets
df_best_runs_tm = df_best_runs[df_best_runs[("config", "dataset")] == "tabula_muris"]
df_best_runs_sp = df_best_runs[df_best_runs[("config", "dataset")] == "swissprot"]

df_best_runs.head()

In [None]:
# Parse the results of both datasets for the report
tm_res = utils.exp2results(df_best_runs_tm)
sp_res = utils.exp2results(df_best_runs_sp)

# Create a MultiIndex for the results where the first level is the dataset
tm_res = tm_res.set_index("Method")
tm_res.index = pd.MultiIndex.from_product([["Tabula Muris"], tm_res.index])
sp_res = sp_res.set_index("Method")
sp_res.index = pd.MultiIndex.from_product([["SwissProt"], sp_res.index])

# Concatenate the results
df_results = pd.concat([tm_res, sp_res])

# Get the latex
latex = utils.exp2latex(df_results)

# Save the latex
with open(os.path.join(TABLE_DIR, "results.tex"), "w") as f:
    f.write(latex)

### Train / Val / Test Performance for all models

Here, we draw a grouped bar plot for all methods (5 methods, each with and
without SOT) on all three splits (train, val, test). Performance is shown in
two separate plots for Tabula Muris and SwissProt.


In [None]:
# Performance by split for all methods
fig, axs = plt.subplots(nrows=2, figsize=(20, 8))
fig.tight_layout(pad=3.0)


def pivot_acc(df):
    tmp = []
    for i, best_model in df.iterrows():
        for split in ["train", "val", "test"]:
            method_name = get_name(
                best_model[("config", "method")], best_model[("config", "use_sot")]
            )
            tmp.append(
                {
                    "method": method_name,
                    "split": split,
                    "acc": best_model[("eval", f"{split}/acc")],
                }
            )
    return pd.DataFrame(tmp)


sns.barplot(
    pivot_acc(df_best_runs_tm),
    x="method",
    y="acc",
    hue="split",
    order=STYLED_METHODS_WITH_SOT,
    ax=axs[0],
)
sns.barplot(
    pivot_acc(df_best_runs_sp),
    x="method",
    y="acc",
    hue="split",
    order=STYLED_METHODS_WITH_SOT,
    ax=axs[1],
)
# Set title
axs[0].set_title("Tabula Muris", fontweight="bold")
axs[1].set_title("SwissProt", fontweight="bold")

# Rename the legend on the first subplot and remove it from the second
axs[0].get_legend().set_title("Split")
axs[1].get_legend().remove()

for ax in axs:
    ax.set_xlabel("Method")
    ax.set_ylabel("Acc. (%)")

fig.savefig(os.path.join(FIGURE_DIR, "benchmark-split-perf.pdf"), bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

### Performance by method with and without SOT

Here, we compare the performance of the different methods with and without SOT.
The left subplot shows the test performance on the Tabula Muris dataset, while
the right subplot shows the test performance on the SwissProt dataset.


In [None]:
# Performance by method with and without SOT
fig, axs = plt.subplots(ncols=2, figsize=(20, 5))

sns.barplot(
    df_best_runs_tm,
    x=("config", "method"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    order=METHODS,
    ax=axs[0],
)

sns.barplot(
    df_best_runs_sp,
    x=("config", "method"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    errorbar="sd",  # `ci` is deprecated in recent seaborn versions
    order=METHODS,
    ax=axs[1],
)

# Set title
axs[0].set_title("Tabula Muris", fontweight="bold")
axs[1].set_title("SwissProt", fontweight="bold")

# Set legend titles
axs[0].get_legend().set_title("SOT")
axs[1].get_legend().set_title("SOT")

# Replace raw method names in the tick labels with styled names
axs[0].set_xticklabels([get_name(name.get_text()) for name in axs[0].get_xticklabels()])
axs[1].set_xticklabels([get_name(name.get_text()) for name in axs[1].get_xticklabels()])

for ax in axs:
    ax.set_ylabel("Test Acc. (%)")
    ax.set_xlabel("Method")

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "benchmark-perf.pdf"), bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

## Experiment 2: Way-Shot Analysis

---

Varying the number of classes (ways) and the number of shots per class.


In [None]:
# Load experiment data for `way-shot` experiment
GROUP = "way-shot"
USER = "mikasenghaas"

# Filter runs by group
group_runs = [run for run in runs if run.group == GROUP and run.state == "finished"]
print(f"✅ Found {len(group_runs)} runs")

In [None]:
# Load runs into dataframe
df_runs = utils.load_to_df(group_runs)
print(f"✅ Loaded {len(df_runs)} runs.")

df_runs.head()

Only keep the best run for each experiment configuration. This only has an
effect if hyperparameter tuning was performed.


In [None]:
# Group tuning runs by experiment configuration
df_best_runs = utils.get_best_run(df_runs, metric=("eval", "val/acc"))
print(f"✅ Filtered to {len(df_best_runs)} best runs.")

df_best_runs.head()

### Shot-Way-Analysis

Display the test accuracy as a function of the number of classes to distinguish
(ways) and the number of shots per class for ProtoNet with and without SOT.
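
When several runs share the same x-value, the line plots summarise them by their mean. The same aggregation can be computed directly with pandas, which is handy for reporting numbers alongside the figure. This is a minimal sketch; the flat column names and all values are placeholders and not results from our experiments.

```python
import pandas as pd

# Toy way-shot grid with placeholder accuracies (not real results)
toy = pd.DataFrame(
    {
        "n_way": [2, 2, 2, 2, 5, 5, 5, 5],
        "n_shot": [1, 5, 1, 5, 1, 5, 1, 5],
        "use_sot": [False, False, True, True] * 2,
        "test_acc": [1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
    }
)

# Mean test accuracy per (use_sot, n_way), averaged over all shot settings
mean_by_way = toy.groupby(["use_sot", "n_way"], as_index=False)["test_acc"].mean()
print(mean_by_way)
```

Grouping by `n_shot` instead of `n_way` gives the summary for the second panel.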


In [None]:
# Plot test/acc vs. n_way and n_shot for SOT and non-SOT methods
fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

# test/acc ~ n_way
sns.scatterplot(
    data=df_best_runs,
    x=("config", "n_way"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    alpha=0.25,
    ax=axs[0],
)
sns.lineplot(
    data=df_best_runs,
    x=("config", "n_way"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    legend=False,
    ax=axs[0],
)

# test/acc ~ n_shot
sns.scatterplot(
    data=df_best_runs,
    x=("config", "n_shot"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    alpha=0.25,
    ax=axs[1],
)
sns.lineplot(
    data=df_best_runs,
    x=("config", "n_shot"),
    y=("eval", "test/acc"),
    hue=("config", "use_sot"),
    legend=False,
    ax=axs[1],
)

# Set axis labels
axs[0].set_xlabel("N-Way")
axs[1].set_xlabel("N-Shot")

# Set y-axis labels (the plots show test accuracy)
axs[0].set_ylabel("Test Acc. (%)")
axs[1].set_ylabel("")

# Set legend title
axs[0].get_legend().set_title("SOT")
axs[1].get_legend().set_title("SOT")

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "way-shot.pdf"), bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

## Experiment 3: SOT Interaction

---

We found that SOT interacts in an interesting manner with the re-embedding
modules in distance-based few-shot learners like `ProtoNet` and `MatchingNet`.
In this experiment, we investigate this interaction in more detail. To do this,
we trained `MatchingNet` with SOT and the LSTM embeddings for support and query
samples enabled or disabled, and compare the performance on the test set.
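
The experiment has three on/off switches: SOT, support embedding (SE) and query embedding (QE). Assuming each switch is toggled independently, the possible settings and their short labels can be enumerated as in the sketch below. `style_label` is a hypothetical helper, and the `"None"` label for the setting with everything disabled is our choice for this example.

```python
from itertools import product


def style_label(use_sot, embed_support, embed_query):
    """Build a short label such as "SOT + SE" from the three switches."""
    parts = []
    if use_sot:
        parts.append("SOT")
    if embed_support:
        parts.append("SE")
    if embed_query:
        parts.append("QE")
    # Fall back to "None" when no component is enabled
    return " + ".join(parts) or "None"


# All 2 x 2 x 2 combinations of (use_sot, embed_support, embed_query)
settings = list(product([False, True], repeat=3))
labels = [style_label(*flags) for flags in settings]
print(labels)
```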


In [None]:
# Experiments
GROUP = "sot-interaction"

# Filter runs by group
group_runs = {
    run.id: run for run in runs if run.group == GROUP and run.state == "finished"
}
print(f"✅ Loaded {len(group_runs)} runs")

In [None]:
# Load runs into dataframe
df_runs = utils.load_to_df(group_runs.values())

# Load additional information
embed_support = [run.config["method"]["embed_support"] for run in group_runs.values()]
embed_query = [run.config["method"]["embed_query"] for run in group_runs.values()]
df_runs[("config", "embed_support")] = embed_support
df_runs[("config", "embed_query")] = embed_query
print(f"✅ Added meta-information.")

In [None]:
# Plot test/acc vs. name
fig, ax = plt.subplots(figsize=(20, 5))
sns.barplot(
    data=df_runs.sort_values(by=("eval", "test/acc"), ascending=False),
    x=("info", "name"),
    y=("eval", "test/acc"),
    palette="RdBu",
)
ax.set_xlabel("Method")
ax.set_ylabel("Test Acc. (%)")

# Set tick labels: every second token of the run name is a "true"/"false" flag
labels = []
for exp in [tick.get_text().split("-")[1::2] for tick in ax.get_xticklabels()]:
    name = []
    if exp[0] == "true":
        name.append("SOT")
    if exp[1] == "true":
        name.append("SE")
    if exp[2] == "true":
        name.append("QE")

    # Avoid an empty tick label when no component is enabled
    labels.append(" + ".join(name) or "None")
ax.set_xticklabels(labels)

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "sot-interaction.pdf"),
            bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

In [None]:
# Scatterplot of interaction between support, query and SOT embeddings
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    df_runs,
    x=("config", "embed_support"),
    y=("config", "embed_query"),
    hue=("config", "use_sot"),
    size=("eval", "test/acc"),
    sizes=(100, 5000),
    linewidth=0,
)
ax.set_xlim(-0.5, 1.5)
ax.set_ylim(-0.5, 1.5)
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(["No", "Yes"])
ax.set_yticklabels(["No", "Yes"])
ax.set(
    xlabel="Embed Support",
    ylabel="Embed Query",
)
ax.legend().remove()

In [None]:
def is_interaction(row):
    # A run counts as an interaction if SOT is combined with at least one
    # of the support or query embeddings
    support, query, sot = (
        row[("config", "embed_support")],
        row[("config", "embed_query")],
        row[("config", "use_sot")],
    )
    return bool((support or query) and sot)


df_runs[("config", "interaction")] = df_runs.apply(is_interaction, axis=1)

fig, ax = plt.subplots(figsize=(5, 5))
sns.barplot(
    df_runs,
    x=("config", "interaction"),
    y=("eval", "test/acc"),
    ax=ax,
)
ax.scatter(
    x=df_runs[("config", "interaction")],
    y=df_runs[("eval", "test/acc")],
    c="black",
    alpha=1,
    marker="x",
    zorder=100,
)

ax.set_xticklabels(["No", "Yes"])
ax.set_xlabel("Interaction")
ax.set_ylabel("Test Acc. (%)")

# Save figure
fig.savefig(os.path.join(FIGURE_DIR, "sot-interaction-2.pdf"), bbox_inches="tight")
print(f"✅ Saved figure to {FIGURE_DIR}.")

## Experiment 4: Understanding model performance

---

The goal of this section is to understand the improvements brought by the SOT
feature transform. We will look at:

- **Embeddings during forward-pass**. Visualise the embeddings of support and
  query samples during episodes with and without SOT enabled.
- **Visualise the self-optimal transport plan.** Visualise the self-optimal
  transport plan for a few episodes via a heat map.
- **Understand model prediction patterns and errors.** Visualise the model
  predictions and errors for a few episodes with and without SOT enabled.

To get started, we will load two pre-trained models from the benchmarking
experiment. We will use two instances of `protonet` that were both trained on
the `tabula_muris` dataset. The first model was trained with the default
configuration, while the second model was trained with the same configuration
but with the `use_sot` flag set to `True`.


In [None]:
# Experiments
GROUP = "model-behaviour"

# Filter runs by group
runs = [run for run in runs if run.group == GROUP and run.state == "finished"]
print(f"✅ Loaded {len(runs)} runs")

# Load runs into dataframe
df_runs = utils.load_to_df(runs)

df_runs.head()

In [None]:
# Initialise data loaders and model
models = []
for run in runs:
    # Load data loaders and model
    train_loader, val_loader, test_loader, model = utils.init_all(run)
    models.append(model)

print(f"✅ Initialised data loader and model.")

In [None]:
# Download artifact (model weights)
for run in runs:
    utils.download_artifact(
        api,
        wandb_entity=WANDB_ENTITY,
        wandb_project=WANDB_PROJECT,
        artifact_dir=ARTIFACT_DIR,
        run_id=run.id,
    )

In [None]:
# Load model weights
weight_path = os.path.join(ARTIFACT_DIR, runs[0].id, "best_model.pt")
models[0].load_state_dict(torch.load(weight_path))

weight_path = os.path.join(ARTIFACT_DIR, runs[1].id, "best_model.pt")
models[1].load_state_dict(torch.load(weight_path))

print(f"✅ Loaded both model weights.")

In [None]:
# Evaluate performance
print("Evaluating model with SOT...")
models[0].test_loop(train_loader)
models[0].test_loop(val_loader)
models[0].test_loop(test_loader)

print("\nEvaluating model without SOT...")
models[1].test_loop(train_loader)
models[1].test_loop(val_loader)
models[1].test_loop(test_loader)

print(f"✅ Evaluated both models.")

The performance on the `train`, `val` and `test` splits confirms that the model
weights for ProtoNet on SwissProt, with and without SOT, were loaded correctly.


### Visualise Embeddings


In [None]:
# Visualise one episode per split for ProtoNet w/ SOT
for loader in [train_loader, val_loader, test_loader]:
    utils.visualise_episode(loader, models[0])

In [None]:
# Visualise one episode per split for ProtoNet w/o SOT
for loader in [train_loader, val_loader, test_loader]:
    utils.visualise_episode(loader, models[1])

### Visualise SOT transport plan

Here we visualise the self-optimal transport plan for a few episodes.
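
To make clear what the heat maps show, the sketch below computes an approximate transport plan for a toy set of features. It is a simplified NumPy illustration based on Sinkhorn iterations over pairwise cosine distances and may differ from the SOT module used in this project; `sot_plan`, its parameters and the random features are assumptions for the example.

```python
import numpy as np


def sot_plan(features, reg=0.1, n_iters=50, diag_cost=1e3):
    """Approximate self-optimal-transport plan via Sinkhorn iterations.

    Hypothetical, simplified helper (not the project's SOT module).
    """
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    cost = 1.0 - x @ x.T  # pairwise cosine distance
    np.fill_diagonal(cost, diag_cost)  # make self-matching very expensive
    plan = np.exp(-cost / reg)  # Gibbs kernel of the cost matrix
    for _ in range(n_iters):
        plan /= plan.sum(axis=1, keepdims=True)  # normalise rows
        plan /= plan.sum(axis=0, keepdims=True)  # normalise columns
    np.fill_diagonal(plan, 1.0)  # each sample matches itself with weight 1
    return plan


# Random toy features: 6 samples with 4 dimensions (illustrative only)
rng = np.random.default_rng(0)
toy_features = rng.normal(size=(6, 4))
plan = sot_plan(toy_features)
print(plan.shape)
```

Each row of the plan can be read as a re-embedding of one sample in terms of its similarity to all other samples in the episode, which is why block structure in the heat map indicates well-separated classes.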


In [None]:
# Visualise transport plan for ProtoNet w/ SOT on all splits
fig, axs = plt.subplots(ncols=3, figsize=(20, 5))
for ax, loader in zip(axs, [train_loader, val_loader, test_loader]):
    utils.visualise_transport_plan(loader, models[0], ax=ax)

### Visualise confusion patterns


In [None]:
# Confusion matrix for Protonet w/ SOT on all splits
fig, axs = plt.subplots(ncols=3, figsize=(20, 5))
for ax, loader in zip(axs, [train_loader, val_loader, test_loader]):
    utils.visualise_confusion_matrix(loader, models[0], ax=ax)

In [None]:
# Confusion matrix for Protonet w/o SOT on all splits
fig, axs = plt.subplots(ncols=3, figsize=(20, 5))
for ax, loader in zip(axs, [train_loader, val_loader, test_loader]):
    utils.visualise_confusion_matrix(loader, models[1], ax=ax)