In [1]:
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Set plot style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set(font_scale=1.2)
plt.rcParams['figure.figsize'] = (14, 8)

In [2]:
# Base directory for checkpoints
base_dir = "/u/kdkyum/ptmp_link/workdir/continual_unlearn/checkpoints"

# Function to find all available methods with continual_unlearn suffix
def find_methods():
    methods = []
    if os.path.exists(base_dir):
        for item in os.listdir(base_dir):
            if item.endswith("_continual_unlearn") and os.path.isdir(os.path.join(base_dir, item)):
                methods.append(item)
    return methods

# Get all methods
methods = find_methods()
print(f"Found methods: {methods}")

# Identify available datasets for each method
datasets = {}
for method in methods:
    method_dir = os.path.join(base_dir, method)
    datasets[method] = [d for d in os.listdir(method_dir) if os.path.isdir(os.path.join(method_dir, d))]
    print(f"Method {method} has datasets: {datasets[method]}")

Found methods: ['FT_continual_unlearn', 'retrain_continual_unlearn', 'NG_continual_unlearn', 'synaptag_continual_unlearn', 'GA_continual_unlearn', 'RL_continual_unlearn']
Method FT_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method retrain_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method NG_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method synaptag_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method GA_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method RL_continual_unlearn has datasets: ['cifar10', 'cifar100']


In [3]:
def load_evaluation_results(method, dataset):
    """Load evaluation results for a specific method and dataset"""
    results = []
    method_dir = os.path.join(base_dir, method, dataset)
    
    if not os.path.exists(method_dir):
        print(f"Directory not found: {method_dir}")
        return results
    
    # Get all forget stages
    forget_stages = []
    for stage_dir in os.listdir(method_dir):
        stage_path = os.path.join(method_dir, stage_dir)
        if os.path.isdir(stage_path):
            try:
                # Handle both underscore and hyphen formats (e.g., '0_1' or '0-1')
                if '-' in stage_dir:
                    begin, end = map(int, stage_dir.split('-'))
                elif '_' in stage_dir:
                    begin, end = map(int, stage_dir.split('_'))
                else:
                    # Skip directories that don't follow either pattern
                    raise ValueError(f"Directory name format not recognized: {stage_dir}")
                    
                forget_stages.append((begin, end, stage_dir, stage_path))
            except Exception as e:
                print(f"Skipping directory with invalid format: {stage_dir} - {str(e)}")
    
    # Sort by end class for proper ordering
    forget_stages.sort(key=lambda x: x[1])
    
    # Load results for each stage
    for begin, end, stage_dir, stage_path in forget_stages:
        eval_file = os.path.join(stage_path, 'evaluation_results.json')
        if os.path.exists(eval_file):
            try:
                with open(eval_file, 'r') as f:
                    data = json.load(f)
                    results.append({
                        'method': method,
                        'dataset': dataset,
                        'forget_class_begin': begin,
                        'forget_class_end': end,
                        'data': data,
                    })
            except Exception as e:
                print(f"Error loading {eval_file}: {e}")
    
    return results

def extract_metrics(results):
    """Extract key metrics from loaded results into a structured DataFrame"""
    metrics_data = []
    
    for result in results:
        method = result['method']
        method_display = method.replace('_continual_unlearn', '')
        dataset = result['dataset']
        forget_begin = result['forget_class_begin']
        forget_end = result['forget_class_end']
        data = result['data']
        
        # Extract common metrics
        metrics = {
            'method': method_display,
            'dataset': dataset,
            'forget_class_begin': forget_begin,
            'forget_class_end': forget_end,
            'classes_forgotten': forget_end - forget_begin,
            'unlearning_time': data.get('unlearning_time', None)
        }
        
        # Extract accuracy metrics
        if 'accuracy' in data:
            if isinstance(data['accuracy'], dict):
                for key, value in data['accuracy'].items():
                    metrics[f'accuracy_{key}'] = value
            else:
                metrics['accuracy'] = data['accuracy']

        for x in data["class_wise_accuracy"]:
            metrics[f'accuracy_class_{x["class"]}'] = x.get('accuracy', None)
        
        # Extract MIA metrics
        if 'SVC_MIA_forget_efficacy' in data:
            for key, value in data['SVC_MIA_forget_efficacy'].items():
                metrics[f'mia_forget_{key}'] = value
                
        metrics_data.append(metrics)
    
    return pd.DataFrame(metrics_data)

In [6]:
# Load all results
all_results = []
for method in methods:
    # for dataset in datasets[method]:
    method_results = load_evaluation_results(method, "cifar10")
    all_results.extend(method_results)

# Convert to DataFrame for easier analysis
df = extract_metrics(all_results)

# Show basic stats
print(f"Loaded {len(df)} evaluation results")
print(f"Methods: {df['method'].unique()}")
print(f"Datasets: {df['dataset'].unique()}")

# Display the first few rows
df.head()

Skipping directory with invalid format: masks - Directory name format not recognized: masks
Loaded 40 evaluation results
Methods: ['FT' 'retrain' 'NG' 'synaptag' 'GA']
Datasets: ['cifar10']


Unnamed: 0,method,dataset,forget_class_begin,forget_class_end,classes_forgotten,unlearning_time,accuracy_retain,accuracy_forget,accuracy_val,accuracy_test,...,accuracy_class_5,accuracy_class_6,accuracy_class_7,accuracy_class_8,accuracy_class_9,mia_forget_correctness,mia_forget_confidence,mia_forget_entropy,mia_forget_m_entropy,mia_forget_prob
0,FT,cifar10,0,0,0,143.505831,98.012346,1.844444,87.02,83.4,...,83.5,94.5,94.699997,97.900002,94.5,0.981556,1.0,0.811333,1.0,0.941111
1,FT,cifar10,0,1,1,127.511504,98.694444,14.933333,79.92,75.08,...,85.699997,94.599998,93.599998,98.400002,96.5,0.850667,1.0,0.876222,1.0,0.998222
2,FT,cifar10,0,2,2,115.667961,98.304762,2.044444,68.44,65.2,...,94.300003,94.800003,94.0,98.0,97.300003,0.979556,0.999778,0.915778,1.0,0.702667
3,FT,cifar10,0,3,3,99.204968,99.785185,10.155556,60.56,59.28,...,94.400002,98.599998,97.099998,98.5,98.400002,0.898444,0.996444,0.807333,1.0,0.562889
4,FT,cifar10,0,4,4,84.374024,99.657778,27.6,52.14,51.31,...,98.699997,97.300003,94.599998,97.0,98.800003,0.724,0.985333,0.905111,1.0,0.945556


In [7]:
df.to_csv("eval_results_for_cifar10.csv", index=False)