In [7]:
import sys

import mlflow
import pandas as pd
from mlflow.tracking import MlflowClient

sys.path.append("../src/")
import matplotlib.pyplot as plt
import numpy as np

from helpers.mlflow_utils import mlflow_tracking_uri

client = MlflowClient(tracking_uri=mlflow_tracking_uri)

experiment_id_mapping = {
    "cifar10-renset": "206",
    "cifar100-resnet": "210",
    "mufac-resnet": "208",
    "cifar10-vit": "211",
    "cifar100-vit": "212",
    "mufac-vit": "213",
}


def baselines(experiment_name):
    experiment_id = experiment_id_mapping[experiment_name]
    unlearning_methods = [
        "finetune",
        "neggrad",
        "relabel",
        "badT",
        "scrub",
        "ssd",
        "unsir",
    ]
    runs = client.search_runs(experiment_id)
    # Convert runs to a DataFrame
    metrics = ["mia", "acc_forget", "acc_retain", "t", "acc_test", "js", "js_proxy"]
    runs_df = pd.DataFrame(
        [
            {k: v for k, v in run.data.metrics.items() if k in metrics}
            for run in runs
            if run.data.tags.get("mlflow.runName") != "our"
        ]
    )
    runs_df["method"] = [
        run.data.tags.get("mlflow.runName")
        for run in runs
        if run.data.tags.get("mlflow.runName") != "our"
    ]
    runs_df["seed"] = [
        run.data.params.get("seed")
        for run in runs
        if run.data.tags.get("mlflow.runName") != "our"
    ]
    runs_df = runs_df.set_index(["method", "seed"])

    gap_metrics = ["mia", "acc_forget", "acc_retain", "acc_test"]
    unlearning_methods = [
        "finetune",
        "neggrad",
        "relabel",
        "badT",
        "scrub",
        "ssd",
        "unsir",
        "retrained",
    ]
    # Calculate the difference in 't' between the unlearning methods and 'retrain' of the same seed
    for method in unlearning_methods:
        for metric in gap_metrics:
            for seed in runs_df.index.get_level_values("seed").unique():
                runs_df.loc[method, f"{metric}_gap"] = abs(
                    runs_df.loc[(method, seed), metric]
                    - runs_df.loc[("retrained", seed), metric]
                )

    grouped_df = runs_df.groupby("method").aggregate(["mean", "std"])
    grouped_df["js"] = grouped_df["js"].apply(lambda x: x * 1e4)
    grouped_df["js_proxy"] = grouped_df["js_proxy"].apply(lambda x: x * 1e4)
    grouped_df = grouped_df.round(2)
    grouped_df["avg_gap"] = (
        grouped_df[
            [
                ("acc_retain_gap", "mean"),
                ("acc_forget_gap", "mean"),
                ("mia_gap", "mean"),
            ]
        ]
        .mean(axis=1)
        .round(4)
    )

    grouped_df = grouped_df.drop(
        columns=[
            ("acc_retain_gap", "std"),
            ("acc_forget_gap", "std"),
            ("acc_test_gap", "std"),
            ("mia_gap", "std"),
        ]
    )
    grouped_df = grouped_df.sort_values(by=("avg_gap", ""), ascending=True)

    # Keep only the metrics that have 'gap' in the end and the 'js' and 'js_proxy'
    columns_to_keep = [
        col
        for col in grouped_df.columns
        if "gap" in col[0] or col[0] in ["js", "js_proxy", "t"]
    ]
    # Drop the subcolumn 'std' for the columns that have it
    columns_to_keep = [col for col in columns_to_keep if col[1] != "std"]
    filtered_grouped_df = grouped_df[columns_to_keep]
    filtered_grouped_df = filtered_grouped_df.drop(
        columns=[
            ("mia_gap", "mean"),
            ("acc_forget_gap", "mean"),
            ("acc_retain_gap", "mean"),
            ("acc_test_gap", "mean"),
        ]
    )
    # Specify the order of the columns
    filtered_grouped_df = filtered_grouped_df[
        [
            ("avg_gap", ""),
            ("t", "mean"),
            ("js", "mean"),
            ("js_proxy", "mean"),
        ]
    ]
    display(filtered_grouped_df)


def ours(experiment_name):
    experiment_id = experiment_id_mapping[experiment_name]
    unlearning_methods = ["our"]
    runs = client.search_runs(experiment_id)
    # Convert runs to a DataFrame
    metrics = ["mia", "acc_forget", "acc_retain", "t", "acc_test", "js", "js_proxy"]
    runs_df = pd.DataFrame(
        [
            {k: v for k, v in run.data.metrics.items() if k in metrics}
            for run in runs
            if run.data.tags.get("mlflow.runName") == "our"
            or run.data.tags.get("mlflow.runName") == "retrained"
        ]
    )
    runs_df["method"] = [
        run.data.tags.get("mlflow.runName")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df["seed"] = [
        run.data.params.get("seed")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df["alpha"] = [
        run.data.params.get("alpha")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df.loc[runs_df["method"] == "our", "method"] = runs_df["alpha"].astype(str)
    runs_df.drop(columns=["alpha"], inplace=True)
    runs_df = runs_df.set_index(["method", "seed"])

    gap_metrics = ["mia", "acc_forget", "acc_retain", "acc_test"]

    # Calculate the difference in 't' between the unlearning methods and 'retrain' of the same seed
    for metric in gap_metrics:
        for method in [
            "2.0",
            "4.0",
            "8.0",
            "16.0",
            "32.0",
            "64.0",
            "128.0",
            "256.0",
            "512.0",
            "1024.0",
        ]:
            for seed in runs_df.index.get_level_values("seed").unique():
                runs_df.loc[method, f"{metric}_gap"] = abs(
                    runs_df.loc[(method, seed), metric]
                    - runs_df.loc[("retrained", seed), metric]
                )

    grouped_df = runs_df.groupby("method").aggregate(["mean", "std"])
    grouped_df["js"] = grouped_df["js"].apply(lambda x: x * 1e4)
    grouped_df["js_proxy"] = grouped_df["js_proxy"].apply(lambda x: x * 1e4)
    grouped_df = grouped_df.round(2)
    grouped_df["avg_gap"] = (
        grouped_df[
            [
                ("acc_retain_gap", "mean"),
                ("acc_forget_gap", "mean"),
                ("mia_gap", "mean"),
                ("acc_test_gap", "mean"),
            ]
        ]
        .mean(axis=1)
        .round(4)
    )

    grouped_df = grouped_df.drop(
        columns=[
            ("acc_retain_gap", "std"),
            ("acc_forget_gap", "std"),
            ("acc_test_gap", "std"),
            ("mia_gap", "std"),
        ]
    )
    columns_to_keep = [
        col
        for col in grouped_df.columns
        if "gap" in col[0] or col[0] in ["js", "js_proxy", "t"]
    ]
    grouped_df = grouped_df.sort_values(by=("avg_gap", ""), ascending=True)

    # Drop the subcolumn 'std' for the columns that have it
    columns_to_keep = [col for col in columns_to_keep if col[1] != "std"]
    filtered_grouped_df = grouped_df[columns_to_keep]
    filtered_grouped_df = filtered_grouped_df.drop(
        columns=[
            ("mia_gap", "mean"),
            ("acc_forget_gap", "mean"),
            ("acc_retain_gap", "mean"),
            ("acc_test_gap", "mean"),
        ]
    )

    # Specify the order of the columns
    filtered_grouped_df = filtered_grouped_df[
        [
            ("avg_gap", ""),
            ("t", "mean"),
            ("js", "mean"),
            ("js_proxy", "mean"),
        ]
    ]
    display(filtered_grouped_df)

# CIFAR10-ResNet18

In [8]:
baselines('cifar10-renset')
ours('cifar10-renset')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,5.32,0.0,5.97
ssd,0.05,0.54,0.82,145.86
unsir,0.1067,0.45,0.65,17.29
relabel,0.1133,0.57,1.0,47.98
scrub,0.1167,0.58,0.41,62.39
finetune,0.12,0.43,1.03,81.45
neggrad,0.1267,0.49,1.06,80.64
badT,0.5133,0.33,2.39,287.45
original,,6.5,,61.12


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
128.0,0.0325,0.3,0.39,18.9
16.0,0.05,0.29,0.32,13.21
8.0,0.0625,0.29,0.3,15.75
4.0,0.0725,0.29,0.33,22.99
2.0,0.0825,0.29,0.35,27.4
64.0,0.1025,0.3,0.51,49.8
32.0,0.1075,0.3,0.49,41.75
1024.0,0.11,0.3,0.97,147.94
256.0,0.1125,0.3,0.94,141.04
512.0,0.115,0.3,0.98,147.95


# CIFAR100-ResNet18

In [9]:
baselines('cifar100-resnet')
ours('cifar100-resnet')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,3.39,0.0,16.34
ssd,0.1467,0.54,3.04,42.17
scrub,0.19,0.58,1.87,18.79
unsir,0.3667,0.45,3.05,40.02
finetune,0.3867,0.43,6.88,101.26
neggrad,0.3867,0.49,6.87,101.75
relabel,0.45,0.57,5.84,74.93
badT,0.4667,0.34,4.3,63.04
original,,1.79,,19.05


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.085,0.3,1.29,8.35
4.0,0.1175,0.3,1.49,11.96
32.0,0.1525,0.3,2.23,29.84
64.0,0.1525,0.29,2.48,36.39
8.0,0.1525,0.29,1.96,22.83
1024.0,0.155,0.29,2.61,39.93
128.0,0.155,0.29,2.6,39.61
16.0,0.155,0.29,2.16,27.78
256.0,0.155,0.29,2.61,40.0
512.0,0.155,0.3,2.61,39.95


# MUFAC-ResNet18

In [10]:
baselines('mufac-resnet')
ours('mufac-resnet')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,7.34,0.0,199.05
badT,0.0867,0.66,10.31,456.88
unsir,0.1667,1.68,16.32,988.15
neggrad,0.1733,0.91,19.16,1546.01
finetune,0.1767,0.76,19.52,1439.18
relabel,0.22,1.06,9.51,444.89
scrub,0.2433,1.2,10.53,254.92
ssd,0.2433,1.07,10.3,243.3
original,,3.66,,249.26


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
4.0,0.0975,0.62,6.31,124.65
2.0,0.13,0.62,6.9,133.11
8.0,0.1375,0.62,8.43,286.93
32.0,0.15,0.64,11.23,490.19
64.0,0.15,0.64,11.29,494.56
128.0,0.1525,0.64,11.29,493.86
16.0,0.1525,0.62,10.54,444.0
256.0,0.1525,0.64,11.29,493.96
512.0,0.1525,0.64,11.29,493.96
1024.0,0.44,0.64,11.47,536.52


# CIFAR10-ViT

In [11]:
baselines('cifar10-vit')
ours('cifar10-vit')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,111.0,0.0,1.66
scrub,0.0033,16.66,0.01,2.81
ssd,0.0033,13.65,0.02,2.69
finetune,0.0067,11.33,0.01,3.26
unsir,0.01,10.68,0.01,2.41
neggrad,0.0133,12.61,0.03,6.12
relabel,0.0167,12.78,0.02,2.11
badT,0.0433,8.79,0.12,9.21
original,,84.5,,2.88


KeyError: ('32.0', '12')

# CIFAR100-ViT

In [81]:
baselines('cifar100-vit')
ours('cifar100-vit')


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,112.25,0.0,2.58
relabel,0.0233,12.79,0.06,2.74
scrub,0.0267,16.74,0.04,2.04
ssd,0.0267,13.67,0.04,2.08
unsir,0.0267,10.69,0.08,3.11
finetune,0.0333,11.35,0.07,3.12
badT,0.0433,9.18,0.17,3.8
neggrad,0.0533,12.63,0.13,3.37
original,,121.85,,2.04


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
8.0,0.015,7.02,0.03,1.83
4.0,0.0125,7.02,0.03,1.86
2.0,0.0125,7.02,0.03,1.89
16.0,0.02,7.03,0.06,2.42
retrained,,112.25,0.0,2.58


# MUFAC-ViT

In [82]:
baselines('mufac-vit')
ours('mufac-vit')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
retrained,0.0,13.83,0.0,12.89
relabel,0.01,1.76,0.35,22.58
scrub,0.0167,2.21,0.05,8.27
ssd,0.0167,1.91,0.17,14.89
finetune,0.0333,1.4,0.27,18.46
neggrad,0.0433,1.67,0.39,20.68
unsir,0.07,3.21,0.85,23.52
badT,0.2067,2.09,1.89,201.41
original,,16.47,,9.31


Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.015,1.09,0.05,10.11
4.0,0.0175,1.09,0.05,10.13
8.0,0.0175,1.09,0.05,10.24
16.0,0.02,1.09,0.06,10.47
retrained,,13.83,0.0,12.89
