# Continual Unlearning Results Visualization for CIFAR-100

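This notebook analyzes and visualizes the performance of different continual unlearning methods on the CIFAR-100 dataset. We load the evaluation results from `eval_results_for_cifar100.csv`. Each method is compared against the ground-truth baseline: the rows whose `method` is `retrain` (retrain_continual_unlearn). Runs are matched to the baseline by model and by the last forgotten class index.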

In [1]:
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Set plot style.
# sns.set()/sns.set_theme() apply seaborn's style rcParams (default style: 'darkgrid'),
# so an earlier plt.style.use('...whitegrid') would be overridden. Request the grid style
# and the font scale in the same call instead.
sns.set_theme(style='whitegrid', font_scale=1.2)
plt.rcParams['figure.figsize'] = (14, 8)

In [None]:
# Load the data for CIFAR-100
CIFAR100_RESULTS_CSV = 'eval_results_for_cifar100.csv'
# Methods dropped from every downstream plot (example filter; edit as needed).
EXCLUDED_METHODS = ('boundary_expanding', 'boundary_shrink')
# Placeholder used when the CSV has no 'model' column, so the per-model cells below still run.
DEFAULT_MODEL_NAME = 'unspecified_model'


def load_eval_results(path=CIFAR100_RESULTS_CSV, excluded_methods=EXCLUDED_METHODS,
                      default_model=DEFAULT_MODEL_NAME):
    """Read an evaluation-results CSV and prepare it for the per-model analysis cells.

    Parameters
    ----------
    path : str
        CSV file to read.
    excluded_methods : iterable of str
        Values of the ``method`` column whose rows are dropped.
    default_model : str or None
        Value written to a new ``model`` column when the CSV has none. All rows are
        then treated as one model, so only use this when the file holds a single model.
        Pass ``None`` to only warn and leave the column absent.

    Returns
    -------
    pandas.DataFrame
        The filtered results, or an empty DataFrame when ``path`` does not exist.
    """
    try:
        results = pd.read_csv(path)
    except FileNotFoundError:
        print(f"Error: {path} not found. Please generate this file first.")
        return pd.DataFrame()

    print(f"Loaded CIFAR-100 data with {results.shape[0]} rows and {results.shape[1]} columns")

    if 'model' not in results.columns:
        if default_model is None:
            print("Warning: 'model' column not found in the CSV. Please ensure it's present for model-specific analysis.")
        else:
            # Every per-model cell requires this column and would otherwise skip silently.
            print(f"Warning: 'model' column not found in the CSV; treating all rows as model '{default_model}'.")
            results = results.assign(model=default_model)

    if 'method' in results.columns and excluded_methods:
        results = results[~results['method'].isin(list(excluded_methods))]
    return results


df = load_eval_results()

Loaded data with 122 rows and 315 columns


In [None]:
# Calculate Linear Probe Accuracy Gaps vs. Retrain (per model)
LP_ACC_PREFIX = 'test_LP_accuracy_class_'
LP_GAP_PREFIX = 'LP_gap_'


def compute_lp_gaps(results_df, baseline_method='retrain'):
    """Add ``LP_gap_<k>`` columns: per-class linear-probe test accuracy minus the baseline's.

    Each row is compared with the first ``baseline_method`` row that has the same
    ``model`` and ``forget_class_end``. Rows without such a baseline are kept and get
    NaN gaps. This includes every row of a model that has no baseline runs.

    Any ``LP_gap_*`` columns already present are dropped first. Re-running the cell on
    its own output therefore yields the same columns instead of duplicates.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must contain ``method``, ``model`` and ``forget_class_end``. Per-class
        accuracies are read from ``test_LP_accuracy_class_<k>`` columns.
    baseline_method : str
        Value of ``method`` that identifies the reference runs.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same rows and index as ``results_df``, plus the gap columns.
        No gap columns are added when there are no per-class accuracy columns.
    """
    stale_gaps = [c for c in results_df.columns if str(c).startswith(LP_GAP_PREFIX)]
    base = results_df.drop(columns=stale_gaps)

    # Map each accuracy column to its class suffix. The suffix string is reused
    # verbatim in the gap column name, so zero-padded suffixes survive.
    suffix_of = {}
    for col in base.columns:
        col_str = str(col)
        suffix = col_str[len(LP_ACC_PREFIX):]
        if col_str.startswith(LP_ACC_PREFIX) and suffix.isdigit():
            suffix_of[col] = suffix
    if not suffix_of:
        return base
    lp_cols = sorted(suffix_of, key=lambda c: int(suffix_of[c]))

    keys = ['model', 'forget_class_end']
    baseline = (
        base.loc[base['method'] == baseline_method, keys + lp_cols]
        .dropna(subset=keys)                          # NaN keys never identify a baseline
        .drop_duplicates(subset=keys, keep='first')   # first baseline row per key wins
    )
    # A left merge against unique right-hand keys yields exactly one row per input row, in input order.
    aligned = base[keys].merge(baseline, on=keys, how='left')
    gaps = base[lp_cols].to_numpy(dtype=float) - aligned[lp_cols].to_numpy(dtype=float)

    gap_df = pd.DataFrame(gaps, columns=[LP_GAP_PREFIX + suffix_of[c] for c in lp_cols])
    # Concatenate positionally, then restore the original index (which may contain duplicates).
    combined = pd.concat([base.reset_index(drop=True), gap_df], axis=1)
    combined.index = base.index
    return combined


if df.empty:
    print("DataFrame is empty. Cannot perform gap calculations.")
elif 'method' not in df.columns or 'model' not in df.columns:
    print("DataFrame does not contain 'method' or 'model' columns. Cannot perform model-specific gap calculations.")
elif 'forget_class_end' not in df.columns:
    print("DataFrame does not contain a 'forget_class_end' column. Cannot match runs to their 'retrain' baseline.")
else:
    models_with_baseline = set(df.loc[df['method'] == 'retrain', 'model'].dropna())
    for model_name in df['model'].dropna().unique():
        if model_name not in models_with_baseline:
            print(f"No 'retrain' data found for model {model_name}. Its gap columns will be NaN.")

    df_with_gaps = compute_lp_gaps(df)
    gap_cols_overall = [col for col in df_with_gaps.columns if str(col).startswith(LP_GAP_PREFIX)]
    if models_with_baseline and gap_cols_overall:
        df = df_with_gaps
        print("Gap calculations complete.")
        preview_cols = set(gap_cols_overall) | {'method', 'model', 'forget_class_end'}
        print(df[[col for col in df.columns if col in preview_cols]].head())
    else:
        print("No gap data was generated. Ensure 'retrain' method and 'model' column exist and data is available.")

array([ 4,  9, 14, 19, 24, 29, 34, 39, 44, 49, 54, 59, 64, 69, 74, 79, 84,
       89, 94])

In [None]:
# Visualize Retain, Forget, Remember Gaps for CIFAR-100 (per model)
GAP_PLOT_MARKERS = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'X']
GAP_PANELS = [('retain_gap', 'Retain Gap'), ('forget_gap', 'Forget Gap'), ('remember_gap', 'Remember Gap')]


def _masked_row_mean(values, mask):
    """Row-wise mean of ``values`` where ``mask`` is True, skipping NaN; NaN for rows with nothing selected."""
    return pd.DataFrame(np.where(mask, values, np.nan)).mean(axis=1).to_numpy()


def summarize_lp_gaps(frame):
    """Collapse the per-class ``LP_gap_<k>`` columns into three mean gaps per run.

    Only rows with non-null ``method``, ``model`` and ``forget_class_end`` are used.
    A missing or NaN ``forget_class_begin`` is treated as 0. Each mean skips NaN gaps
    and is NaN when no class falls in its range:

    * ``retain_gap``   -- classes with index > forget_class_end
    * ``forget_gap``   -- classes with forget_class_begin <= index <= forget_class_end
    * ``remember_gap`` -- classes with index < forget_class_begin

    NOTE(review): when forget_class_begin is 0 for every row, the remember range is
    empty and ``remember_gap`` is always NaN. Confirm that this is the intended definition.

    Returns a de-duplicated DataFrame with columns ``method``, ``model``,
    ``forget_class_end``, ``classes_forgotten`` and the three gap columns.
    """
    gap_cols = [c for c in frame.columns if str(c).startswith('LP_gap_')]
    class_ids = np.array([int(str(c).rsplit('_', 1)[-1]) for c in gap_cols], dtype=int)

    rows = frame[frame[['method', 'model', 'forget_class_end']].notna().all(axis=1)]
    end = rows['forget_class_end'].astype(int).to_numpy()[:, None]
    if 'forget_class_begin' in rows.columns:
        begin = rows['forget_class_begin'].fillna(0).astype(int).to_numpy()[:, None]
    else:
        begin = np.zeros_like(end)
    gaps = rows[gap_cols].to_numpy(dtype=float)
    ids = class_ids[None, :]  # broadcast against per-row (n, 1) bounds -> (n, k) masks

    summary = pd.DataFrame({
        'method': rows['method'].to_numpy(),
        'model': rows['model'].to_numpy(),
        'forget_class_end': end[:, 0],
        'classes_forgotten': rows['classes_forgotten'].to_numpy(),
        'retain_gap': _masked_row_mean(gaps, ids > end),
        'forget_gap': _masked_row_mean(gaps, (ids >= begin) & (ids <= end)),
        'remember_gap': _masked_row_mean(gaps, ids < begin),
    })
    return summary.drop_duplicates().reset_index(drop=True)


def _plot_gap_panels(model_gaps, model_name):
    """Plot retain/forget/remember gap vs. classes forgotten for one model ('retrain' itself excluded)."""
    methods = sorted(m for m in model_gaps['method'].unique() if m != 'retrain')
    if not methods:
        print(f"No non-retrain methods to plot for model {model_name} on CIFAR-100")
        return

    fig, axs = plt.subplots(1, 3, figsize=(24, 7))
    colors = sns.color_palette("tab10", len(methods))
    for i, method in enumerate(methods):
        method_data = model_gaps[model_gaps['method'] == method].sort_values('classes_forgotten')
        for ax, (gap_col, _) in zip(axs, GAP_PANELS):
            ax.plot(method_data['classes_forgotten'], method_data[gap_col],
                    marker=GAP_PLOT_MARKERS[i % len(GAP_PLOT_MARKERS)], label=method,
                    color=colors[i % len(colors)], linewidth=2)

    # NaN ticks would make set_xticks fail.
    x_values = sorted(model_gaps['classes_forgotten'].dropna().unique())
    for ax, (_, title) in zip(axs, GAP_PANELS):
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Number of Classes Forgotten', fontsize=12)
        ax.set_ylabel('Gap (%)', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)  # zero gap = matches retrain
        if x_values:
            ax.set_xticks(x_values)

    handles, labels = axs[0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.05), fancybox=True,
                   shadow=True, ncol=min(5, len(handles)), fontsize=12)
    fig.suptitle(f'Linear Probe Test Acc. Gap Analysis for {model_name} on CIFAR-100 (vs. Retrain)',
                 fontsize=16, fontweight='bold')
    fig.tight_layout(rect=[0, 0.08, 1, 0.95])
    plt.show()


required_summary_cols = ['method', 'model', 'forget_class_end', 'classes_forgotten']
has_gap_columns = any(str(col).startswith('LP_gap_') for col in df.columns)
if df.empty:
    print("DataFrame is empty. Cannot perform gap visualization.")
elif not has_gap_columns or not all(col in df.columns for col in required_summary_cols):
    print(f"Gap data not available or required columns {required_summary_cols} missing. Skipping gap visualization.")
else:
    gap_summary_df_c100 = summarize_lp_gaps(df)
    if gap_summary_df_c100.empty:
        print("No rows with complete 'method'/'model'/'forget_class_end' values. Skipping gap visualization.")
    for model_name_iter in gap_summary_df_c100['model'].unique():
        _plot_gap_panels(gap_summary_df_c100[gap_summary_df_c100['model'] == model_name_iter], model_name_iter)

Unnamed: 0,method,dataset,forget_class_begin,forget_class_end,classes_forgotten,unlearning_time,accuracy_retain,accuracy_forget,accuracy_val,accuracy_test,...,test_LP_accuracy_class_95,test_LP_accuracy_class_96,test_LP_accuracy_class_97,test_LP_accuracy_class_98,test_LP_accuracy_class_99,mia_forget_correctness,mia_forget_confidence,mia_forget_entropy,mia_forget_m_entropy,mia_forget_prob
19,retrain,cifar100,0,4,4,2731.865137,99.978947,0.0,92.82,72.44,...,70.0,62.0,78.0,59.0,69.0,1.0,1.0,0.852,1.0,0.891111
20,retrain,cifar100,0,9,9,2548.175206,99.980247,0.0,87.98,68.36,...,71.0,64.0,78.0,61.0,76.0,1.0,1.0,0.867556,1.0,0.816444
21,retrain,cifar100,0,14,14,2512.537116,99.976471,0.0,83.12,65.35,...,71.0,62.0,83.0,48.0,74.0,1.0,1.0,0.815111,1.0,0.756889
22,retrain,cifar100,0,19,19,2297.386467,99.977778,0.0,78.26,61.67,...,69.0,63.0,74.0,42.0,68.0,1.0,1.0,0.870222,1.0,0.812889
23,retrain,cifar100,0,24,24,2153.357804,99.97037,0.0,73.34,58.05,...,72.0,61.0,72.0,50.0,61.0,1.0,1.0,0.824,1.0,0.768889
24,retrain,cifar100,0,29,29,2054.268947,99.965079,0.0,68.52,53.75,...,69.0,59.0,73.0,49.0,69.0,1.0,1.0,0.870222,1.0,0.806222
25,retrain,cifar100,0,34,34,1894.195484,99.965812,0.0,63.76,50.38,...,63.0,53.0,73.0,45.0,71.0,1.0,1.0,0.837333,1.0,0.579111
26,retrain,cifar100,0,39,39,1780.355181,99.903704,0.0,58.76,46.57,...,56.0,54.0,74.0,42.0,62.0,1.0,1.0,0.778222,1.0,0.764889
27,retrain,cifar100,0,44,44,1641.675249,99.971717,0.0,54.02,43.71,...,57.0,52.0,78.0,39.0,69.0,1.0,1.0,0.932,1.0,0.833333
28,retrain,cifar100,0,49,49,1520.356477,99.968889,0.0,49.2,40.13,...,61.0,48.0,69.0,30.0,61.0,1.0,1.0,0.757333,1.0,0.674667


In [None]:
# Visualize Precise Metrics for CIFAR-100 (per model)
# (metric column, panel title) for the three side-by-side panels, left to right.
PRECISE_PANELS = [
    ('precise_retain_accuracy_test', 'Retain Accuracy (Test)'),
    ('end_class_accuracy_test', 'Last Forgotten Class Accuracy (Test)'),
    ('precise_remember_accuracy_test', 'Remember Accuracy (Test)'),
]
PRECISE_MARKERS = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'X']


def _draw_precise_panels(model_name, model_rows, methods):
    """Plot the three precise test-accuracy metrics vs. classes forgotten for one model.

    One line per entry of ``methods`` (retrain included) is drawn on every panel,
    and the finished figure is shown with ``plt.show()``.
    """
    fig, panel_axes = plt.subplots(1, 3, figsize=(24, 7))
    palette = sns.color_palette("tab10", len(methods))

    for k, method_name in enumerate(methods):
        rows = model_rows[model_rows['method'] == method_name].sort_values('classes_forgotten')
        if rows.empty:
            continue
        line_style = dict(marker=PRECISE_MARKERS[k % len(PRECISE_MARKERS)], label=method_name,
                          color=palette[k % len(palette)], linewidth=2)
        for ax, (metric_col, _) in zip(panel_axes, PRECISE_PANELS):
            ax.plot(rows['classes_forgotten'], rows[metric_col], **line_style)

    x_ticks = sorted(model_rows['classes_forgotten'].unique())
    for ax, (_, panel_title) in zip(panel_axes, PRECISE_PANELS):
        ax.set_title(panel_title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Number of Classes Forgotten', fontsize=12)
        ax.set_ylabel('Accuracy (%)', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        if x_ticks:
            ax.set_xticks(x_ticks)

    legend_handles, legend_labels = panel_axes[0].get_legend_handles_labels()
    if legend_handles:
        fig.legend(legend_handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 0.05),
                   fancybox=True, shadow=True, ncol=min(5, len(legend_handles)), fontsize=12)
    fig.suptitle(f'Precise Performance Metrics for {model_name} on CIFAR-100 (Test)', fontsize=16, fontweight='bold')
    fig.tight_layout(rect=[0, 0.08, 1, 0.95])
    plt.show()


if df.empty:
    print("DataFrame is empty. Cannot perform precise metric visualization.")
elif 'method' not in df.columns or 'model' not in df.columns:
    print("DataFrame does not contain 'method' or 'model' columns. Cannot perform precise metric visualization.")
else:
    required_metric_cols = [metric_col for metric_col, _ in PRECISE_PANELS]
    if not all(col in df.columns for col in required_metric_cols):
        print(f"Missing one or more required columns for precise metric visualization: {required_metric_cols}. Skipping this plot.")
    else:
        for model_name_iter in df['model'].unique():
            model_rows = df[df['model'] == model_name_iter]
            if model_rows.empty:
                print(f"No data for model {model_name_iter} for precise metrics on CIFAR-100.")
                continue
            methods = sorted(model_rows['method'].unique())
            only_retrain = len(methods) == 1 and 'retrain' in methods
            if not methods or only_retrain:
                print(f"Not enough distinct methods to plot for model {model_name_iter} on CIFAR-100.")
                continue
            _draw_precise_panels(model_name_iter, model_rows, methods)

In [None]:
# Final sanity check: preview `df` and list its columns, but only if the frame exists and has rows.
df_is_ready = 'df' in locals() and not df.empty
if not df_is_ready:
    print("\nDataFrame 'df' is not defined or is empty.")
else:
    print("\nFinal DataFrame sample:")
    print(df.head())
    print("\nDataFrame columns:")
    print(df.columns.tolist())