In [111]:
import wandb
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
import numpy as np
import pathlib
from wandb.apis.public import Run as apiRun
import json

In [112]:
# Public-API client used to query runs and fetch their artifacts.
api = wandb.Api()

# W&B location of the runs, as "<entity>/<project>".
entity = "KowalskiTeam"
project = "Pruning"
path = "/".join((entity, project))

# Run-selection parameters.
dataset = "cifar100"  # compared against the runs' config.dataset.name
job_type = "General_16-04-2024"  # compared against the runs' jobType
# NOTE(review): presumably the pruned fractions of interest — confirm where this is consumed.
checkpoints = [0.88, 0.92, 0.96]

In [113]:
# Select the finished runs named "pruning_results" for the chosen dataset and job type.
run_filters = {
    "config.dataset.name": dataset,
    "state": "finished",
    "jobType": job_type,
    "display_name": "pruning_results",
}
runs = api.runs(path, filters=run_filters)

In [114]:
# download artifacts
# For every run, fetch its "<group>_pruning_results:v0" artifact into
# artifacts/<group>_pruning_results:v0/. Runs whose results table is already
# on disk are skipped, so re-running the notebook does not re-download.
for run in runs:
    run: apiRun
    group = run.group
    artifact_dir = pathlib.Path(f"artifacts/{group}_pruning_results:v0")

    # Check for the table file itself, not just the directory: an interrupted
    # download can leave the directory behind without the file the next cell reads.
    if (artifact_dir / "pruning_results.table.json").exists():
        continue

    # TODO: async download
    artifact = api.artifact(f"{run.entity}/{run.project}/{group}_pruning_results:v0")
    # Pin the download root to the directory checked above and read later.
    # wandb otherwise chooses its own default location, which its configuration
    # can redirect (e.g. the WANDB_ARTIFACT_DIR environment variable).
    artifact.download(root=str(artifact_dir))


In [117]:
# Aggregate each run's pruning-results table into per-pruning-level mean/std
# of top-1 accuracy (rounded to 4 decimals), tag every row with the run's
# flattened W&B config and group, then write one CSV per pruning scheduler
# under csvs/.
dataframes = defaultdict(list)
# NOTE(review): not read in this cell — only top1_accuracy is aggregated below.
# Kept in case later cells use it; confirm whether top-5 should be aggregated too.
aggregation_columns = ["top1_accuracy", "top5_accuracy"]

for run in runs:
    run: apiRun
    group = run.group
    scheduler_name = run.config["pruning"]["scheduler"]["name"]

    with open(f"artifacts/{group}_pruning_results:v0/pruning_results.table.json") as f:
        json_dict = json.load(f)

    df = pd.DataFrame(json_dict["data"], columns=json_dict["columns"])
    # Mean/std of top-1 accuracy over all rows sharing a pruning level
    # ("pruned_precent" is the column name as stored in the table). The agg
    # dict already selects the metric, so no other columns need dropping first.
    agg_df = df.groupby(["pruned_precent"]).agg({"top1_accuracy": ["mean", "std"]})
    agg_df = agg_df.round(4)
    # Flatten ("top1_accuracy", "mean") -> "top1_accuracy_mean", etc.
    agg_df.columns = agg_df.columns.map("_".join)
    agg_df = agg_df.reset_index()

    # One row of dotted, flattened config keys. .iloc[0] always yields a
    # Series; .squeeze() would collapse a single-key config to a bare scalar.
    config_series = pd.json_normalize(run.config).iloc[0]
    for key, value in config_series.items():
        # Repeat the value once per row so a list-valued config entry is stored
        # whole in every cell instead of being spread element-wise across rows
        # (or rejected when its length differs from the row count).
        agg_df[key] = [value] * len(agg_df)

    agg_df["group"] = group

    dataframes[scheduler_name].append(agg_df)

csv_dir = pathlib.Path("csvs")
# DataFrame.to_csv does not create missing parent directories.
csv_dir.mkdir(parents=True, exist_ok=True)

for name in list(dataframes):
    # ignore_index: every per-run frame restarts its index at 0.
    dataframes[name] = pd.concat(dataframes[name], ignore_index=True)
    # save to csv
    dataframes[name].to_csv(csv_dir / f"pruning_results_{dataset}_{name}.csv", index=False)