# Analysis

This notebook contains the analysis of the data tracked on
[Weights & Biases](https://wandb.ai/).


In [None]:
# Imports
import os

import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

import wandb
import torch
import torch.nn as nn

# Import wandb.run types
from wandb.apis.public import Run

# Import hydra
import hydra

In [None]:
# Setup: coordinates of the W&B project that holds the experiments.
WANDB_ENTITY = "metameta-learners"
WANDB_PROJECT = "few-shot-benchmark"

# Only runs belonging to this W&B group are analysed below.
GROUP = "mika"

# Local directories, resolved relative to the parent of the current working
# directory (presumably the repository root when the notebook is run from
# its own sub-folder — confirm).
ROOT_DIR = os.path.dirname(os.path.abspath("."))
FIGURE_DIR = os.path.join(ROOT_DIR, "figures")
ARTIFACT_DIR = os.path.join(ROOT_DIR, "artifacts")

## Load Experiment Data

Each W&B run corresponds to a single experiment.

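Planned layout: one DataFrame with the run ID, the group, the experiment config (multi-level columns) and the evaluation metrics (multi-level columns).

TODO: a dict of DataFrames mapping each run ID to … (this note was left unfinished).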


In [None]:
# Connect to the W&B public API and fetch every run of the project.
api = wandb.Api()
runs = api.runs(f"{WANDB_ENTITY}/{WANDB_PROJECT}")

# Keep only the runs that were logged under our experiment group.
group_runs = list(filter(lambda run: run.group == GROUP, runs))

print(f"Found {len(group_runs)} runs")

In [None]:
def extract_runid(run: Run) -> str:
    """
    Return the identifier W&B assigned to ``run``.
    """
    run_id = run.id
    return run_id


def extract_config(run: Run) -> dict:
    """
    Collect the config fields that identify an experiment.

    Reads ``run.config`` and returns a flat dict with the keys ``run_id``,
    ``dataset`` and ``method`` (taken from the nested ``"name"`` entries),
    and ``sot``, ``n_way`` and ``n_shot`` (copied as-is). Missing config
    keys raise ``KeyError``.
    """
    cfg = run.config
    return {
        "run_id": run.id,
        "dataset": cfg["dataset"]["name"],
        "method": cfg["method"]["name"],
        **{key: cfg[key] for key in ("sot", "n_way", "n_shot")},
    }


def extract_metrics(run: Run) -> dict:
    """
    Return the run's summary metrics, skipping every key that starts
    with an underscore.
    """
    metrics = {}
    for key, value in run.summary.items():
        if key.startswith("_"):
            continue
        metrics[key] = value
    return metrics


def load_to_df(runs: list[Run]) -> pd.DataFrame:
    """
    Load runs into a single pandas DataFrame.

    The result is indexed by run ID (index name ``"run_id"``) and has
    two-level columns: ``("config", <field>)`` for the fields returned by
    ``extract_config`` (minus ``run_id``, which becomes the index) and
    ``("eval", <metric>)`` for the metrics returned by ``extract_metrics``.
    A metric a run did not log is NaN for that run.

    Args:
        runs: The W&B runs to load. Any iterable is accepted; it is
            consumed exactly once.

    Returns:
        The combined DataFrame. When ``runs`` is empty, an empty frame
        with a ``run_id`` index and two-level columns is returned.
    """
    # Materialize once so one-shot iterables are not exhausted by the
    # first pass.
    runs = list(runs)

    if not runs:
        return pd.DataFrame(
            index=pd.Index([], name="run_id"),
            columns=pd.MultiIndex.from_arrays([[], []]),
        )

    configs = pd.DataFrame([extract_config(run) for run in runs])
    config_df = configs.set_index("run_id")
    # Reuse the run-ID index so both frames line up row by row.
    metrics_df = pd.DataFrame(
        [extract_metrics(run) for run in runs], index=config_df.index
    )

    # Build the column groups from the two frames rather than slicing by
    # position. After set_index the config frame has one column fewer than
    # the config dicts have keys, so a positional split mislabels the first
    # metric. Concatenating also tolerates a metric sharing a config field's
    # name, because the two land under different top-level keys.
    return pd.concat({"config": config_df, "eval": metrics_df}, axis=1)


def load_model(run_id: str, version: str = "v0") -> str:
    """
    Download the model artifact of a run from W&B.

    Fetches ``{WANDB_ENTITY}/{WANDB_PROJECT}/{run_id}:{version}`` through
    the module-level ``api`` (the ``wandb.Api`` created in an earlier cell)
    and stores its files under ``ARTIFACT_DIR/<run_id>``.

    Args:
        run_id: W&B run ID, also used as the artifact name.
        version: Artifact version or alias to fetch; defaults to ``"v0"``.

    Returns:
        The local directory the artifact files were downloaded to.

    NOTE(review): despite its name, this does not yet build an ``nn.Module``.
    The checkpoint format inside the artifact is not visible here.
    TODO: deserialize the downloaded files into a model once that is
    confirmed.
    """
    artifact = api.artifact(f"{WANDB_ENTITY}/{WANDB_PROJECT}/{run_id}:{version}")
    path = os.path.join(ARTIFACT_DIR, run_id)
    artifact.download(root=path)
    return path


def init_model(cfg: dict) -> nn.Module:
    """
    Build a model from a Hydra-style config dict.

    Instantiates ``cfg["dataset"]``, then the backbone config found at
    ``cfg["dataset"]["backbone"]`` (passing the dataset's ``x_dim``), and
    finally ``cfg["method"]`` with that backbone. Returns the
    instantiated method object.

    NOTE(review): the backbone config is read from inside the dataset
    config; confirm this matches the actual Hydra config layout.
    """
    instantiate = hydra.utils.instantiate

    # TODO: instantiating the dataset config does not work yet.
    dataset = instantiate(cfg["dataset"])
    backbone_cfg = cfg["dataset"]["backbone"]
    backbone = instantiate(backbone_cfg, x_dim=dataset.x_dim)
    return instantiate(cfg["method"], backbone=backbone)