In [13]:
import sys

sys.path.append("../src/helpers")
import mlflow
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from mlflow.tracking import MlflowClient
from scipy.stats import pearsonr
from mlflow_utils import mlflow_tracking_uri

# MLflow client bound to the tracking server configured in mlflow_utils.
client = MlflowClient(tracking_uri=mlflow_tracking_uri)

# Experiment name (as passed to get_results) -> MLflow experiment id.
# MLflow experiment ids are strings, hence the quoted values.
experiment_id_mapping = {
    "cat-resnet": "214",
    "rocket-resnet": "216",
    "beaver-resnet": "217",
    "pizza-resnet": "413",
}

def _fetch_all_runs(experiment_id):
    """Return every run of one MLflow experiment, following pagination.

    ``MlflowClient.search_runs`` returns a single page (``max_results`` runs,
    1000 by default) plus a continuation ``token``; a single call can
    therefore silently drop runs from large experiments.
    """
    runs = []
    page_token = None
    while True:
        page = client.search_runs([experiment_id], page_token=page_token)
        runs.extend(page)
        page_token = page.token
        if not page_token:
            return runs


def get_results(experiment_name, methods=None):
    """Summarise an experiment's MLflow runs into a main and an appendix table.

    Runs are selected by their ``method`` param, indexed by (method, seed),
    and aggregated per method into mean/std. For each gap metric, the
    absolute difference between a method's mean and the ``"retrained"``
    method's mean is added as ``<metric>_gap``; ``avg_gap`` averages the
    MIA, forget, retain and test gaps.

    Args:
        experiment_name: Key into ``experiment_id_mapping``.
        methods: Method names (values of the ``method`` run param) to include,
            in display order. Defaults to the standard comparison set. Must
            include ``"retrained"``, the reference for every gap.

    Returns:
        ``(main_df, appendix_df)``. ``main_df`` holds ``avg_gap``, ``js``
        (scaled by 1e4), ``t`` and ``js_proxy`` sorted by mean ``js``;
        ``appendix_df`` holds the individual gaps and raw metrics sorted by
        ``avg_gap``. Both tables are also displayed after printing
        ``experiment_name``.

    Raises:
        KeyError: ``experiment_name`` is not in ``experiment_id_mapping``.
        ValueError: no runs match ``methods``, or no ``"retrained"`` runs exist.
    """
    if methods is None:
        methods = [
            "retrained",
            "finetune",
            "neggrad",
            "relabel",
            "badT",
            "scrub",
            "ssd",
            "unsir",
            "our",
        ]
    experiment_id = experiment_id_mapping[experiment_name]
    metrics = {
        "mia",
        "acc_forget",
        "acc_retain",
        "acc_test",
        "acc_val",
        "js",
        "t",
        "js_proxy",
    }

    # One record per run of a selected method, built in a single pass so the
    # metric values, method and seed can never get out of step.
    records = []
    for run in _fetch_all_runs(experiment_id):
        method = run.data.params.get("method")
        if method not in methods:
            continue
        record = {k: v for k, v in run.data.metrics.items() if k in metrics}
        record["method"] = method
        record["seed"] = run.data.params.get("seed")
        records.append(record)
    if not records:
        raise ValueError(
            f"No runs with methods {list(methods)} in experiment {experiment_name!r}"
        )

    df = pd.DataFrame(records).set_index(["method", "seed"])

    grouped_df = df.groupby("method").aggregate(["mean", "std"])
    # Report the JS metric (both mean and std) in units of 1e-4.
    grouped_df["js"] = grouped_df["js"] * 1e4
    grouped_df = grouped_df.round(2)

    if "retrained" not in grouped_df.index:
        raise ValueError(
            f"Experiment {experiment_name!r} has no 'retrained' runs; gaps need it as reference"
        )
    # Only methods that actually have runs can be shown (in the requested order).
    present_methods = [m for m in methods if m in grouped_df.index]

    # Gaps are taken from the rounded means above so they match the printed
    # tables. String keys on the 2-level columns create ("<name>_gap", "").
    for metric in ["mia", "acc_forget", "acc_retain", "acc_test", "acc_val"]:
        reference = grouped_df.loc["retrained", (metric, "mean")]
        grouped_df[f"{metric}_gap"] = (grouped_df[(metric, "mean")] - reference).abs()
    grouped_df["avg_gap"] = (
        grouped_df[
            [
                "mia_gap",
                "acc_retain_gap",
                "acc_forget_gap",
                "acc_test_gap",
            ]
        ]
        .mean(axis=1)
        .round(4)
    )

    main_df = grouped_df.loc[
        present_methods,
        ["avg_gap", "js", "t", "js_proxy"],
    ].sort_values(by=("js", "mean"), ascending=True)

    appendix_df = grouped_df.loc[
        present_methods,
        [
            "avg_gap",
            "mia_gap",
            "acc_forget_gap",
            "acc_retain_gap",
            "acc_test_gap",
            "mia",
            "acc_forget",
            "acc_retain",
            "acc_test",
        ],
    ].sort_values(by=("avg_gap", ""), ascending=True)

    print(experiment_name)

    display(main_df)
    display(appendix_df)

    return main_df, appendix_df

# Per-class results: each cell below calls get_results for one key of experiment_id_mapping.


In [9]:
# Main table and appendix table for the cat-resnet experiment.
rn_cat, appendix_df = get_results(experiment_name="cat-resnet")

cat-resnet


Unnamed: 0_level_0,avg_gap,js,js,t,t,js_proxy,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
retrained,0.0,0.0,0.0,4.88,0.13,1.47,0.02
badT,0.175,2.5,0.44,0.33,0.01,1.03,0.18
unsir,0.0575,2.96,0.16,0.44,0.01,0.57,0.18
our,0.1425,2.99,0.88,0.28,0.0,0.22,0.25
relabel,0.1025,3.78,2.6,0.56,0.0,0.85,0.5
ssd,0.375,4.31,2.07,1.01,0.01,0.64,0.51
scrub,0.4175,7.44,0.03,0.58,0.01,0.13,0.01
neggrad,0.34,7.64,0.32,0.48,0.0,0.03,0.03
finetune,0.255,8.41,0.57,0.42,0.0,0.07,0.02


Unnamed: 0_level_0,avg_gap,mia_gap,acc_forget_gap,acc_retain_gap,acc_test_gap,mia,mia,acc_forget,acc_forget,acc_retain,acc_retain,acc_test,acc_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
retrained,0.0,0.0,0.0,0.0,0.0,0.35,0.02,0.0,0.0,0.99,0.01,0.83,0.01
unsir,0.0575,0.13,0.06,0.03,0.01,0.48,0.14,0.06,0.07,0.96,0.01,0.82,0.01
relabel,0.1025,0.36,0.03,0.01,0.01,0.71,0.51,0.03,0.03,0.98,0.01,0.82,0.02
our,0.1425,0.35,0.14,0.0,0.08,0.0,0.0,0.14,0.11,0.99,0.0,0.91,0.01
badT,0.175,0.64,0.0,0.04,0.02,0.99,0.01,0.0,0.0,0.95,0.01,0.81,0.01
finetune,0.255,0.43,0.54,0.01,0.04,0.78,0.19,0.54,0.07,0.98,0.01,0.87,0.01
neggrad,0.34,0.34,0.84,0.14,0.04,0.01,0.0,0.84,0.1,0.85,0.08,0.79,0.07
ssd,0.375,0.15,0.32,0.59,0.44,0.5,0.32,0.32,0.56,0.4,0.33,0.39,0.32
scrub,0.4175,0.56,1.0,0.01,0.1,0.91,0.01,1.0,0.0,1.0,0.0,0.93,0.01


In [10]:
# get_results returns (main_df, appendix_df); unpack it like the other class
# cells so rn_beaver holds the main table rather than the whole tuple.
rn_beaver, appendix_df = get_results("beaver-resnet")

beaver-resnet


Unnamed: 0_level_0,avg_gap,js,js,t,t,js_proxy,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
retrained,0.0,0.0,0.0,4.0,0.11,0.85,0.24
our,0.12,25.46,1.41,0.23,0.01,0.51,0.08
ssd,0.4325,45.19,9.19,0.83,0.03,1.41,0.44
scrub,0.3225,64.09,8.71,0.55,0.0,0.18,0.06
unsir,0.405,76.28,6.88,0.2,0.01,0.36,0.13
badT,0.3,76.85,3.12,0.26,0.01,0.53,0.15
finetune,0.285,101.48,2.87,0.43,0.0,0.67,0.17
relabel,0.2925,102.66,3.11,0.45,0.0,0.7,0.2
neggrad,0.3725,108.5,2.69,0.44,0.01,0.88,0.19


Unnamed: 0_level_0,avg_gap,mia_gap,acc_forget_gap,acc_retain_gap,acc_test_gap,mia,mia,acc_forget,acc_forget,acc_retain,acc_retain,acc_test,acc_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
retrained,0.0,0.0,0.0,0.0,0.0,0.23,0.02,0.0,0.0,0.91,0.03,0.59,0.02
our,0.12,0.23,0.19,0.02,0.04,0.0,0.0,0.19,0.07,0.93,0.03,0.63,0.02
finetune,0.285,0.23,0.17,0.5,0.24,0.0,0.0,0.17,0.07,0.41,0.07,0.35,0.05
relabel,0.2925,0.23,0.2,0.5,0.24,0.0,0.0,0.2,0.08,0.41,0.06,0.35,0.05
badT,0.3,0.23,0.18,0.53,0.26,0.0,0.0,0.18,0.1,0.38,0.04,0.33,0.03
scrub,0.3225,0.39,0.87,0.02,0.01,0.62,0.12,0.87,0.08,0.93,0.03,0.6,0.02
neggrad,0.3725,0.34,0.41,0.5,0.24,0.57,0.03,0.41,0.12,0.41,0.06,0.35,0.05
unsir,0.405,0.51,0.97,0.08,0.06,0.74,0.09,0.97,0.01,0.99,0.01,0.65,0.01
ssd,0.4325,0.25,0.0,0.9,0.58,0.48,0.26,0.0,0.0,0.01,0.0,0.01,0.0


In [12]:
# Main table and appendix table for the rocket-resnet experiment.
rocket_resnet, appendix_df = get_results(experiment_name="rocket-resnet")

rocket-resnet


Unnamed: 0_level_0,avg_gap,js,js,t,t,js_proxy,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
retrained,0.0,0.0,0.0,5.05,2.01,1.49,0.35
our,0.1475,20.77,6.25,0.22,0.0,0.87,0.11
ssd,0.47,36.44,11.02,0.81,0.01,1.67,0.76
scrub,0.4,66.49,14.36,0.55,0.0,0.1,0.03
badT,0.345,67.73,10.8,0.26,0.0,0.34,0.06
unsir,0.435,74.53,12.03,0.2,0.01,0.17,0.04
finetune,0.3525,95.27,8.23,0.43,0.0,0.61,0.25
relabel,0.3525,96.02,7.4,0.45,0.0,0.6,0.2
neggrad,0.445,98.06,8.04,0.44,0.01,0.54,0.24


Unnamed: 0_level_0,avg_gap,mia_gap,acc_forget_gap,acc_retain_gap,acc_test_gap,mia,mia,acc_forget,acc_forget,acc_retain,acc_retain,acc_test,acc_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
retrained,0.0,0.0,0.0,0.0,0.0,0.18,0.1,0.0,0.0,0.94,0.05,0.62,0.05
our,0.1475,0.18,0.39,0.02,0.0,0.0,0.0,0.39,0.27,0.92,0.03,0.62,0.01
badT,0.345,0.18,0.4,0.54,0.26,0.0,0.0,0.4,0.04,0.4,0.05,0.36,0.05
finetune,0.3525,0.18,0.43,0.53,0.27,0.0,0.0,0.43,0.06,0.41,0.07,0.35,0.05
relabel,0.3525,0.18,0.43,0.53,0.27,0.0,0.0,0.43,0.02,0.41,0.07,0.35,0.05
scrub,0.4,0.64,0.93,0.01,0.02,0.82,0.09,0.93,0.08,0.93,0.03,0.6,0.02
unsir,0.435,0.67,0.99,0.05,0.03,0.85,0.03,0.99,0.01,0.99,0.01,0.65,0.01
neggrad,0.445,0.44,0.52,0.54,0.28,0.62,0.01,0.52,0.1,0.4,0.07,0.34,0.05
ssd,0.47,0.51,0.0,0.84,0.53,0.69,0.47,0.0,0.0,0.1,0.08,0.09,0.07


In [14]:
# Main table and appendix table for the pizza-resnet experiment.
pizza_resnet, appendix_df = get_results(experiment_name="pizza-resnet")

pizza-resnet


Unnamed: 0_level_0,avg_gap,js,js,t,t,js_proxy,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
retrained,0.0,0.0,0.0,42.15,16.05,1.71,0.3
ssd,0.365,34.96,14.21,3.19,0.03,1.57,0.58
our,0.0925,37.02,18.68,1.3,0.02,0.66,0.56
badT,0.3125,72.62,22.07,1.59,0.01,0.48,0.25
scrub,0.42,73.1,0.82,4.05,0.03,0.15,0.1
neggrad,0.325,86.36,9.66,3.24,0.03,0.27,0.14
relabel,0.2925,92.27,6.43,3.27,0.03,0.28,0.14
finetune,0.2975,94.96,7.24,3.23,0.01,0.31,0.17
unsir,0.5075,102.29,9.33,1.01,0.01,0.28,0.14


Unnamed: 0_level_0,avg_gap,mia_gap,acc_forget_gap,acc_retain_gap,acc_test_gap,mia,mia,acc_forget,acc_forget,acc_retain,acc_retain,acc_test,acc_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,mean,std,mean,std,mean,std,mean,std
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
retrained,0.0,0.0,0.0,0.0,0.0,0.22,0.16,0.0,0.0,0.92,0.14,0.57,0.07
our,0.0925,0.04,0.19,0.09,0.05,0.26,0.46,0.19,0.28,0.83,0.11,0.52,0.02
relabel,0.2925,0.11,0.67,0.22,0.17,0.33,0.58,0.67,0.23,0.7,0.05,0.4,0.01
finetune,0.2975,0.11,0.69,0.22,0.17,0.33,0.58,0.69,0.22,0.7,0.05,0.4,0.01
badT,0.3125,0.39,0.29,0.37,0.2,0.61,0.54,0.29,0.28,0.55,0.06,0.37,0.02
neggrad,0.325,0.26,0.64,0.23,0.17,0.48,0.5,0.64,0.11,0.69,0.05,0.4,0.01
ssd,0.365,0.46,0.0,0.63,0.37,0.68,0.51,0.0,0.0,0.29,0.31,0.2,0.22
scrub,0.42,0.71,0.96,0.01,0.0,0.93,0.08,0.96,0.07,0.91,0.15,0.57,0.07
unsir,0.5075,0.78,0.8,0.27,0.18,1.0,0.0,0.8,0.17,0.65,0.04,0.39,0.02
