In [497]:
import sys

import mlflow
import pandas as pd
from mlflow.tracking import MlflowClient

sys.path.append("../src/")
import matplotlib.pyplot as plt
import numpy as np

from helpers.mlflow_utils import mlflow_tracking_uri

client = MlflowClient(tracking_uri=mlflow_tracking_uri)

experiment_id_mapping = {
    "cifar10-renset": "206",
    "cifar100-resnet": "210",
    "mufac-resnet": "208",
    "cifar10-vit": "211",
    "cifar100-vit": "212",
    "mufac-vit": "213",
}


def baselines(experiment_name):
    experiment_id = experiment_id_mapping[experiment_name]
    unlearning_methods = [
        "finetune",
        "neggrad",
        "relabel",
        "badT",
        "scrub",
        "ssd",
        "unsir",
    ]
    runs = client.search_runs(experiment_id)
    # Convert runs to a DataFrame
    metrics = [
        "mia",
        "acc_forget",
        "acc_retain",
        "t",
        "acc_test",
        "js",
        "js_proxy",
        "acc_val",
    ]
    runs_df = pd.DataFrame(
        [
            {k: v for k, v in run.data.metrics.items() if k in metrics}
            for run in runs
            if run.data.tags.get("mlflow.runName") != "our"
        ]
    )
    runs_df["method"] = [
        run.data.tags.get("mlflow.runName")
        for run in runs
        if run.data.tags.get("mlflow.runName") != "our"
    ]
    runs_df["seed"] = [
        run.data.params.get("seed")
        for run in runs
        if run.data.tags.get("mlflow.runName") != "our"
    ]
    runs_df = runs_df.set_index(["method", "seed"])

    unlearning_methods = [
        "finetune",
        "neggrad",
        "relabel",
        "badT",
        "scrub",
        "ssd",
        "unsir",
        "retrained",
    ]

    grouped_df = runs_df.groupby("method").aggregate(["mean", "std"])
    grouped_df["js"] = grouped_df["js"].apply(lambda x: x * 1e4)
    grouped_df["js_proxy"] = grouped_df["js_proxy"].apply(lambda x: x * 1e4)
    grouped_df = grouped_df.round(2)

    gap_metrics = ["mia", "acc_forget", "acc_retain", "acc_test", "acc_val"]
    for method in unlearning_methods:
        for metric in gap_metrics:
            grouped_df.loc[method, f"{metric}_gap"] = abs(
                grouped_df.loc[method, (metric, "mean")]
                - grouped_df.loc["retrained", (metric, "mean")]
            )

    grouped_df["avg_gap"] = (
        grouped_df[
            [
                "acc_retain_gap",
                "acc_forget_gap",
                "mia_gap",
                # "acc_test_gap",
            ]
        ]
        .mean(axis=1)
        .round(4)
    )

    grouped_df = grouped_df.sort_values(by=("avg_gap", ""), ascending=True)

    # Specify the order of the columns
    filtered_grouped_df = grouped_df[
        [
            ("avg_gap", ""),
            ("t", "mean"),
            ("js", "mean"),
            ("js_proxy", "mean"),
        ]
    ]
    display(filtered_grouped_df)


def ours(experiment_name):
    experiment_id = experiment_id_mapping[experiment_name]
    unlearning_methods = ["our"]
    runs = client.search_runs(experiment_id)
    # Convert runs to a DataFrame
    metrics = [
        "mia",
        "acc_forget",
        "acc_retain",
        "t",
        "acc_test",
        "js",
        "js_proxy",
        "acc_val",
    ]
    runs_df = pd.DataFrame(
        [
            {k: v for k, v in run.data.metrics.items() if k in metrics}
            for run in runs
            if run.data.tags.get("mlflow.runName") == "our"
            or run.data.tags.get("mlflow.runName") == "retrained"
        ]
    )
    runs_df["method"] = [
        run.data.tags.get("mlflow.runName")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df["seed"] = [
        run.data.params.get("seed")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df["alpha"] = [
        run.data.params.get("alpha")
        for run in runs
        if run.data.tags.get("mlflow.runName") == "our"
        or run.data.tags.get("mlflow.runName") == "retrained"
    ]
    runs_df.loc[runs_df["method"] == "our", "method"] = runs_df["alpha"].astype(str)
    runs_df.drop(columns=["alpha"], inplace=True)
    runs_df = runs_df.set_index(["method", "seed"])

    gap_metrics = ["mia", "acc_forget", "acc_retain", "acc_test", "acc_val"]

    grouped_df = runs_df.groupby("method").aggregate(["mean", "std"])
    grouped_df["js"] = grouped_df["js"].apply(lambda x: x * 1e4)
    grouped_df["js_proxy"] = grouped_df["js_proxy"].apply(lambda x: x * 1e4)
    grouped_df = grouped_df.round(2)

    gap_metrics = ["mia", "acc_forget", "acc_retain", "acc_test", "acc_val"]
    if experiment_name in ["cifar10-vit", "cifar100-vit", "mufac-vit"]:
        alphas_str = ['2.0', '4.0', '8.0', '16.0']
    else:
        alphas_str = ['2.0', '4.0', '8.0', '16.0', '32.0', '64.0', '128.0', '256.0', '512.0', '1024.0']

    alphas_str = ['2.0', '4.0', '8.0', '16.0']
    for method in alphas_str:
        for metric in gap_metrics:
            grouped_df.loc[method, f"{metric}_gap"] = abs(
                grouped_df.loc[method, (metric, "mean")]
                - grouped_df.loc["retrained", (metric, "mean")]
            )
    
    # We use the accuracy on the validation set instead of the test set
    grouped_df["avg_gap"] = (
        grouped_df[
            [
                "acc_retain_gap",
                "acc_forget_gap",
                "mia_gap",
                "acc_val_gap",
            ]
        ]
        .mean(axis=1)
        .round(4)
    )

    grouped_df = grouped_df.sort_values(by=("avg_gap", ""), ascending=True)

    columns_to_keep = [
        col
        for col in grouped_df.columns
        if "gap" in col[0] or col[0] in ["js", "js_proxy", "t"]
    ]
    grouped_df = grouped_df.sort_values(by=("avg_gap", ""), ascending=True)

    # Specify the order of the columns
    filtered_grouped_df = grouped_df[
        [
            ("avg_gap", ""),
            ("t", "mean"),
            ("js", "mean"),
            ("js_proxy", "mean"),
        ]
    ]
    display(filtered_grouped_df)

# CIFAR10-ResNet18

In [498]:
# baselines('cifar10-renset')
ours('cifar10-renset')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
8.0,0.04,0.29,0.3,15.75
16.0,0.045,0.29,0.32,13.21
4.0,0.0525,0.29,0.33,22.99
2.0,0.0575,0.29,0.35,27.4
1024.0,,0.3,0.97,147.94
128.0,,0.3,0.39,18.9
256.0,,0.3,0.94,141.04
32.0,,0.3,0.49,41.75
512.0,,0.3,0.98,147.95
64.0,,0.3,0.51,49.8


# CIFAR100-ResNet18

In [499]:
# baselines('cifar100-resnet')
ours('cifar100-resnet')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.105,0.3,1.29,8.35
4.0,0.1325,0.3,1.49,11.96
16.0,0.145,0.29,2.16,27.78
8.0,0.155,0.29,1.96,22.83
1024.0,,0.29,2.61,39.93
128.0,,0.29,2.6,39.61
256.0,,0.29,2.61,40.0
32.0,,0.3,2.23,29.84
512.0,,0.3,2.61,39.95
64.0,,0.29,2.48,36.39


# MUFAC-ResNet18

In [500]:
# baselines('mufac-resnet')
ours('mufac-resnet')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.1275,0.62,6.9,133.11
4.0,0.14,0.62,6.31,124.65
8.0,0.185,0.62,8.43,286.93
16.0,0.1975,0.62,10.54,444.0
1024.0,,0.64,11.47,536.52
128.0,,0.64,11.29,493.86
256.0,,0.64,11.29,493.96
32.0,,0.64,11.23,490.19
512.0,,0.64,11.29,493.96
64.0,,0.64,11.29,494.56


# CIFAR10-ViT

In [501]:
# baselines('cifar10-vit')
ours('cifar10-vit')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.005,7.34,0.01,2.08
4.0,0.005,7.14,0.01,2.03
8.0,0.005,7.05,0.01,1.92
16.0,0.01,7.03,0.01,1.67
retrained,,111.0,0.0,1.66


# CIFAR100-ViT

In [502]:
# baselines('cifar100-vit')
ours('cifar100-vit')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.015,7.02,0.03,1.89
4.0,0.015,7.02,0.03,1.86
8.0,0.0175,7.02,0.03,1.83
16.0,0.0225,7.03,0.06,2.42
retrained,,112.25,0.0,2.58


# MUFAC-ViT

In [503]:
# baselines('mufac-vit')
ours('mufac-vit')

Unnamed: 0_level_0,avg_gap,t,js,js_proxy
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
2.0,0.02,1.09,0.05,10.11
4.0,0.02,1.09,0.05,10.13
16.0,0.0225,1.09,0.06,10.47
8.0,0.0225,1.09,0.05,10.24
retrained,,13.83,0.0,12.89
