In [4]:
import sys

sys.path.append("../src/")
import mlflow
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from mlflow.tracking import MlflowClient
from scipy.stats import pearsonr

import os

# MLflow client used by get_results below. Honours the standard
# MLFLOW_TRACKING_URI environment variable so the notebook can point at another
# tracking server without edits; falls back to the local server on port 8000.
client = MlflowClient(
    tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8000")
)
# Maps the short experiment name used in this notebook to the MLflow
# experiment ID that get_results passes to client.search_runs.
experiment_id_mapping = {
    "vit-imagenet": "7",
}

def get_results(experiment_name, alpha="2.0"):
    """Summarise per-method metrics (mean and std across seeds) for one experiment.

    Fetches the runs of the MLflow experiment mapped to ``experiment_name``.
    It keeps the baseline runs, matched by their ``mlflow.runName`` tag. It also
    keeps the ``"our"`` runs whose ``alpha`` param equals ``alpha``. Each metric
    is then aggregated across seeds.

    Args:
        experiment_name: Key into the module-level ``experiment_id_mapping``.
        alpha: Value of the ``alpha`` run param that selects which ``"our"``
            runs are included. It is compared as a string, so ``2.0`` and
            ``"2.0"`` select the same runs. Defaults to ``"2.0"``.

    Returns:
        DataFrame indexed by method (in a fixed display order) with
        ``(metric, "mean" | "std")`` MultiIndex columns, rounded to 2 decimals.
        ``js_proxy`` is multiplied by 1e4 before rounding. The table is also
        shown via ``display``.

    Raises:
        KeyError: if ``experiment_name`` is not in ``experiment_id_mapping``, or
            a method in the display order has no matching runs.
        ValueError: if no run in the experiment matches the selection.
    """
    experiment_id = experiment_id_mapping[experiment_name]
    alpha = str(alpha)  # run params come back from MLflow as strings

    # Baselines are selected by run name alone. "retrained" is aggregated but
    # is not part of the returned display table.
    baseline_methods = {
        "original", "retrained", "finetune", "neggrad", "relabel",
        "badT", "scrub", "ssd", "unsir",
    }
    metrics = {"mia", "acc_forget", "acc_retain", "acc_test", "acc_val", "t", "js_proxy"}
    display_methods = [
        "original", "finetune", "neggrad", "relabel", "badT", "scrub", "ssd", "unsir", "our",
    ]
    display_metrics = ["js_proxy", "t", "acc_retain", "acc_test", "acc_forget", "mia"]

    def is_selected(run):
        """True for baseline runs, or "our" runs logged with the requested alpha."""
        name = run.data.tags.get("mlflow.runName")
        return name in baseline_methods or (
            name == "our" and run.data.params.get("alpha") == alpha
        )

    # NOTE(review): search_runs returns a single page of results (max_results
    # defaults to 1000 in MLflow). A larger experiment would be silently truncated.
    runs = client.search_runs([experiment_id])

    # One record per selected run keeps metrics, method and seed aligned.
    records = [
        {
            "method": run.data.tags.get("mlflow.runName"),
            "seed": run.data.params.get("seed"),
            **{k: v for k, v in run.data.metrics.items() if k in metrics},
        }
        for run in runs
        if is_selected(run)
    ]
    if not records:
        raise ValueError(
            f"No matching runs in experiment {experiment_name!r} "
            f"(id {experiment_id}) for alpha={alpha}"
        )

    per_run = pd.DataFrame(records).set_index(["method", "seed"])
    summary = per_run.groupby(level="method").agg(["mean", "std"])
    # Rescale js_proxy (both mean and std) so it stays readable after rounding.
    summary["js_proxy"] = summary["js_proxy"] * 1e4
    summary = summary.round(2)

    main_df = summary.loc[display_methods, display_metrics]

    display(main_df)

    return main_df

In [5]:
# Per-method summary table for the "vit-imagenet" experiment (displayed inline
# by get_results and kept for later cells).
vit_imagenet_main = get_results(experiment_name="vit-imagenet")

Unnamed: 0_level_0,js_proxy,js_proxy,t,t,acc_retain,acc_retain,acc_test,acc_test,acc_forget,acc_forget,mia,mia
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
original,1.22,0.01,,,0.94,0.0,0.81,0.0,0.95,0.0,0.71,0.0
finetune,2.22,0.02,16.24,0.03,0.97,0.0,0.8,0.0,0.94,0.01,0.78,0.0
neggrad,2.17,0.02,18.1,0.03,0.97,0.0,0.8,0.0,0.97,0.0,0.8,0.0
relabel,1.8,0.09,19.37,0.03,0.95,0.01,0.8,0.0,0.93,0.01,0.74,0.01
badT,3.16,3.25,11.66,0.03,0.77,0.21,0.66,0.18,0.76,0.21,0.52,0.18
scrub,1.24,0.01,24.49,0.03,0.94,0.0,0.8,0.01,0.94,0.0,0.71,0.0
ssd,1.23,0.01,22.61,0.1,0.94,0.0,0.8,0.0,0.94,0.0,0.71,0.0
unsir,2.54,0.03,33.12,0.03,0.99,0.0,0.79,0.0,0.94,0.0,0.77,0.01
our,1.11,0.01,10.72,0.01,0.94,0.0,0.8,0.0,0.94,0.0,0.61,0.01
