In [1]:
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Plot style. seaborn's set_theme() (of which sns.set is an alias) applies its own
# axes style, defaulting to "darkgrid"; the grid style is passed explicitly so it
# does not override the matplotlib whitegrid sheet selected on the line above.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_theme(style='whitegrid', font_scale=1.2)
plt.rcParams['figure.figsize'] = (14, 8)

In [None]:
# Root directory that holds one "<method>_continual_unlearn" subdirectory per method.
# Set the CONTINUAL_UNLEARN_CKPT_DIR environment variable to point the notebook at a
# different location; the fallback is the path this notebook was originally run against.
base_dir = os.environ.get(
    "CONTINUAL_UNLEARN_CKPT_DIR",
    "/u/kdkyum/ptmp_link/workdir/continual_unlearn/checkpoints",
)

def find_methods(root=None, suffix="_continual_unlearn"):
    """Return the names of method directories found directly under ``root``.

    A method directory is any immediate subdirectory whose name ends with
    ``suffix``. ``os.listdir`` yields entries in arbitrary order, so the result
    is sorted to keep downstream loops (and the CSVs they write) reproducible.

    Args:
        root: Directory to scan. Defaults to the module-level ``base_dir``.
        suffix: Name suffix that marks a method directory.

    Returns:
        Sorted list of matching directory names. Empty if ``root`` is missing
        or is not a directory.
    """
    if root is None:
        root = base_dir
    # isdir (not just exists) so a stray file at `root` doesn't crash os.listdir.
    if not os.path.isdir(root):
        return []
    return sorted(
        name
        for name in os.listdir(root)
        if name.endswith(suffix) and os.path.isdir(os.path.join(root, name))
    )

# Discover the unlearning methods and, for each, the model architectures it has results for.
methods = find_methods()
print(f"Found methods: {methods}")

models_per_method = {}
for method in methods:
    method_path = os.path.join(base_dir, method)
    if os.path.isdir(method_path):
        # Sorted because os.listdir order is arbitrary; keeps model iteration
        # order (and therefore CSV row order) reproducible across runs.
        models_per_method[method] = sorted(
            m for m in os.listdir(method_path) if os.path.isdir(os.path.join(method_path, m))
        )
        print(f"Method {method} has models: {models_per_method[method]}")
    else:
        models_per_method[method] = []
        print(f"Warning: Method directory not found or not a directory: {method_path}")

Found methods: ['FT_continual_unlearn', 'retrain_continual_unlearn', 'NG_continual_unlearn', 'synaptag_RL_continual_unlearn', 'GA_continual_unlearn', 'SalUn_continual_unlearn', 'RL_continual_unlearn', 'synaptag_NG_continual_unlearn']
Method FT_continual_unlearn has models: ['resnet18', 'resnet50']
Method retrain_continual_unlearn has models: ['resnet18', 'resnet50']
Method NG_continual_unlearn has models: ['resnet18', 'resnet50']
Method synaptag_RL_continual_unlearn has models: ['resnet18', 'resnet50']
Method GA_continual_unlearn has models: ['resnet18', 'resnet50']
Method SalUn_continual_unlearn has models: ['resnet18', 'resnet50']
Method RL_continual_unlearn has models: ['resnet18', 'resnet50']
Method synaptag_NG_continual_unlearn has models: ['resnet18', 'resnet50']


In [None]:
# Sanity check: show which model architectures were discovered for the first method.
if not methods or methods[0] not in models_per_method:
    print("No methods or models found to display.")
else:
    print(f"Models for {methods[0]}: {models_per_method[methods[0]]}")

'/u/kdkyum/ptmp_link/workdir/continual_unlearn/checkpoints/synaptag_NG_continual_unlearn'

In [None]:
def _parse_stage_range(stage_dir_name):
    """Parse a forget-stage directory name like ``"0-10"`` or ``"0_10"`` into ``(begin, end)``.

    Raises:
        ValueError: if the name is not two integers joined by ``-`` or ``_``.
    """
    # '-' takes precedence over '_', matching the original if/elif order.
    for sep in ('-', '_'):
        if sep in stage_dir_name:
            begin, end = map(int, stage_dir_name.split(sep))
            return begin, end
    raise ValueError(f"Directory name format not recognized: {stage_dir_name}")


def load_evaluation_results(method, model, dataset_name, root=None):
    """Load per-stage evaluation results for one (method, model, dataset) combination.

    Expected layout::

        <root>/<method>/<model>/<dataset_name>/<begin>-<end>/evaluation_results.json

    ``_`` is also accepted as the range separator. Stage directories whose
    names do not parse as an integer range (e.g. ``masks``) are reported and
    skipped. Stages without an ``evaluation_results.json`` are skipped silently.
    Unreadable or malformed JSON files are reported and skipped.

    Args:
        method: Method directory name, e.g. ``"FT_continual_unlearn"``.
        model: Model directory name, e.g. ``"resnet18"``.
        dataset_name: Dataset directory name, e.g. ``"cifar10"``.
        root: Checkpoint root directory; defaults to the module-level ``base_dir``.

    Returns:
        List of dicts with keys ``method``, ``model``, ``dataset``,
        ``forget_class_begin``, ``forget_class_end`` and ``data`` (the parsed
        JSON), ordered by ``(forget_class_end, forget_class_begin)``. Empty if
        the dataset directory does not exist.
    """
    if root is None:
        root = base_dir
    path_to_stages = os.path.join(root, method, model, dataset_name)
    if not os.path.isdir(path_to_stages):
        return []

    forget_stages = []
    for stage_dir_name in os.listdir(path_to_stages):
        stage_full_path = os.path.join(path_to_stages, stage_dir_name)
        if not os.path.isdir(stage_full_path):
            continue
        try:
            begin, end = _parse_stage_range(stage_dir_name)
        except ValueError as e:
            print(f"Skipping directory with invalid format: {stage_dir_name} in {path_to_stages} - {str(e)}")
            continue
        forget_stages.append((begin, end, stage_full_path))

    # Sort by end class; begin breaks ties so the order never depends on
    # os.listdir, whose ordering is arbitrary.
    forget_stages.sort(key=lambda stage: (stage[1], stage[0]))

    results = []
    for begin, end, stage_full_path in forget_stages:
        eval_file = os.path.join(stage_full_path, 'evaluation_results.json')
        if not os.path.exists(eval_file):
            continue
        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
            with open(eval_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # Best effort: an unreadable or corrupt stage (JSONDecodeError is a
            # ValueError) is reported and skipped rather than aborting the load.
            print(f"Error loading {eval_file}: {e}")
            continue
        results.append({
            'method': method,
            'model': model,
            'dataset': dataset_name,
            'forget_class_begin': begin,
            'forget_class_end': end,
            'data': data,
        })
    return results

def _add_prefixed(row, prefix, value):
    """Store a dict as ``<prefix>_<key>`` columns in ``row``, or a non-dict under ``prefix``."""
    if isinstance(value, dict):
        for key, item in value.items():
            row[f'{prefix}_{key}'] = item
    else:
        row[prefix] = value


def extract_metrics(results, method_suffix='_continual_unlearn'):
    """Flatten loaded evaluation results into a DataFrame with one row per forget stage.

    Args:
        results: List of dicts as produced by ``load_evaluation_results``
            (keys ``method``, ``model``, ``dataset``, ``forget_class_begin``,
            ``forget_class_end``, ``data``).
        method_suffix: Suffix removed from the end of ``method`` to form the
            display name (e.g. ``"FT_continual_unlearn"`` -> ``"FT"``).

    Returns:
        DataFrame with columns ``method``, ``model``, ``dataset``,
        ``forget_class_begin``, ``forget_class_end``, ``classes_forgotten`` and
        ``unlearning_time``, plus:

        * ``accuracy_<key>`` for each entry of a dict-valued ``accuracy``
          (a single ``accuracy`` column when it is not a dict);
        * ``<dataset>_accuracy_class_<class>`` for each ``class_wise_accuracy`` entry;
        * ``mia_forget_<key>`` for each entry of a dict-valued
          ``SVC_MIA_forget_efficacy`` (a single ``mia_forget`` column when it
          is not a dict).

        Empty DataFrame when ``results`` is empty.
    """
    rows = []
    for result in results:
        method = result['method']
        # Strip only a trailing suffix, consistent with how method dirs are matched.
        if method_suffix and method.endswith(method_suffix):
            method = method[:-len(method_suffix)]
        forget_begin = result['forget_class_begin']
        forget_end = result['forget_class_end']
        data = result['data']

        row = {
            'method': method,
            'model': result['model'],
            'dataset': result['dataset'],
            'forget_class_begin': forget_begin,
            'forget_class_end': forget_end,
            # NOTE(review): assumes the stage range is half-open [begin, end) — confirm.
            'classes_forgotten': forget_end - forget_begin,
            'unlearning_time': data.get('unlearning_time', None),
        }

        if 'accuracy' in data:
            _add_prefixed(row, 'accuracy', data['accuracy'])

        if 'class_wise_accuracy' in data and isinstance(data['class_wise_accuracy'], list):
            for entry in data['class_wise_accuracy']:
                row[f'{entry["dataset"]}_accuracy_class_{entry["class"]}'] = entry.get('accuracy', None)

        if 'SVC_MIA_forget_efficacy' in data:
            # Previously a non-dict value crashed on .items(); handle it like 'accuracy'.
            _add_prefixed(row, 'mia_forget', data['SVC_MIA_forget_efficacy'])

        rows.append(row)

    return pd.DataFrame(rows)

In [None]:
# Build one metrics DataFrame per dataset (e.g. cifar10, cifar100) and save each to CSV.
# Every frame is kept in `dfs_by_dataset`; `df` is still left bound to the last
# dataset processed, for cells that inspect it.
dataset_names_to_process = ["cifar10", "cifar100"]
dfs_by_dataset = {}

for dataset_name in dataset_names_to_process:
    print(f"\nProcessing dataset type: {dataset_name}...")
    all_results_for_dataset_type = []
    for method in methods:
        if method not in models_per_method:
            print(f"  Skipping method {method} as it's not in models_per_method list.")
            continue
        for model in models_per_method[method]:
            all_results_for_dataset_type.extend(load_evaluation_results(method, model, dataset_name))

    df = extract_metrics(all_results_for_dataset_type)
    dfs_by_dataset[dataset_name] = df

    print(f"--- Summary for {dataset_name} ---")
    if df.empty:
        print(f"No evaluation entries loaded for {dataset_name}.")
        continue
    print(f"Loaded {len(df)} evaluation entries.")
    print(f"Methods found: {df['method'].unique()}")
    print(f"Models found: {df['model'].unique()}")
    print(f"Dataset names processed: {df['dataset'].unique()}")
    csv_path = f"eval_results_for_{dataset_name}.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved results to {csv_path}")

Skipping directory with invalid format: masks - Directory name format not recognized: masks
Loaded 56 evaluation results
Methods: ['FT' 'retrain' 'NG' 'synaptag_RL' 'GA' 'RL' 'synaptag_NG']
Datasets: ['cifar10']
Skipping directory with invalid format: masks - Directory name format not recognized: masks
Loaded 152 evaluation results
Methods: ['FT' 'retrain' 'NG' 'synaptag_RL' 'GA' 'SalUn' 'RL' 'synaptag_NG']
Datasets: ['cifar100']


In [None]:
# Peek at the DataFrame from the final iteration of the dataset loop; `df` is
# rebound once per dataset, so only the last dataset processed is shown here.
if 'df' not in locals() or df.empty:
    print("\nNo DataFrame generated or DataFrame is empty.")
else:
    print("\nSample of the last processed DataFrame:")
    print(df.head())
    print("\nColumns in the DataFrame:")
    print(df.columns.tolist())