In [1]:
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Set plot style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set(font_scale=1.2)
plt.rcParams['figure.figsize'] = (14, 8)

In [2]:
# Base directory for checkpoints
base_dir = "/u/kdkyum/ptmp_link/workdir/continual_unlearn/checkpoints"

# Function to find all available methods with continual_unlearn suffix
def find_methods():
    methods = []
    if os.path.exists(base_dir):
        for item in os.listdir(base_dir):
            if item.endswith("_continual_unlearn") and os.path.isdir(os.path.join(base_dir, item)):
                methods.append(item)
    return methods

# Get all methods
methods = find_methods()
print(f"Found methods: {methods}")

# Identify available datasets for each method
datasets = {}
for method in methods:
    method_dir = os.path.join(base_dir, method)
    datasets[method] = [d for d in os.listdir(method_dir) if os.path.isdir(os.path.join(method_dir, d))]
    print(f"Method {method} has datasets: {datasets[method]}")

Found methods: ['FT_continual_unlearn', 'retrain_continual_unlearn', 'NG_continual_unlearn', 'synaptag_continual_unlearn', 'boundary_shrink_continual_unlearn', 'GA_continual_unlearn', 'RL_continual_unlearn', 'boundary_expanding_continual_unlearn']
Method FT_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method retrain_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method NG_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method synaptag_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method boundary_shrink_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method GA_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method RL_continual_unlearn has datasets: ['cifar10', 'cifar100']
Method boundary_expanding_continual_unlearn has datasets: ['cifar10', 'cifar100']


Now let's create functions to load and process the evaluation results:

In [None]:
def load_evaluation_results(method, dataset):
    """Load evaluation results for a specific method and dataset"""
    results = []
    method_dir = os.path.join(base_dir, method, dataset)
    
    if not os.path.exists(method_dir):
        print(f"Directory not found: {method_dir}")
        return results
    
    # Get all forget stages
    forget_stages = []
    for stage_dir in os.listdir(method_dir):
        stage_path = os.path.join(method_dir, stage_dir)
        if os.path.isdir(stage_path):
            try:
                # Handle both underscore and hyphen formats (e.g., '0_1' or '0-1')
                if '-' in stage_dir:
                    begin, end = map(int, stage_dir.split('-'))
                elif '_' in stage_dir:
                    begin, end = map(int, stage_dir.split('_'))
                else:
                    # Skip directories that don't follow either pattern
                    raise ValueError(f"Directory name format not recognized: {stage_dir}")
                    
                forget_stages.append((begin, end, stage_dir, stage_path))
            except Exception as e:
                print(f"Skipping directory with invalid format: {stage_dir} - {str(e)}")
    
    # Sort by end class for proper ordering
    forget_stages.sort(key=lambda x: x[1])
    
    # Load results for each stage
    for begin, end, stage_dir, stage_path in forget_stages:
        eval_file = os.path.join(stage_path, 'evaluation_results.json')
        if os.path.exists(eval_file):
            try:
                with open(eval_file, 'r') as f:
                    data = json.load(f)
                    results.append({
                        'method': method,
                        'dataset': dataset,
                        'forget_class_begin': begin,
                        'forget_class_end': end,
                        'data': data,
                    })
            except Exception as e:
                print(f"Error loading {eval_file}: {e}")
    
    return results

def extract_metrics(results):
    """Extract key metrics from loaded results into a structured DataFrame"""
    metrics_data = []
    
    for result in results:
        method = result['method']
        method_display = method.replace('_continual_unlearn', '')
        dataset = result['dataset']
        forget_begin = result['forget_class_begin']
        forget_end = result['forget_class_end']
        data = result['data']
        
        # Extract common metrics
        metrics = {
            'method': method_display,
            'dataset': dataset,
            'forget_class_begin': forget_begin,
            'forget_class_end': forget_end,
            'classes_forgotten': forget_end - forget_begin,
            'unlearning_time': data.get('unlearning_time', None)
        }
        
        # Extract accuracy metrics
        if 'accuracy' in data:
            if isinstance(data['accuracy'], dict):
                for key, value in data['accuracy'].items():
                    metrics[f'accuracy_{key}'] = value
            else:
                metrics['accuracy'] = data['accuracy']

        for x in data["class_wise_accuracy"]:
            metrics[f'accuracy_class_{x["class"]}'] = x.get('accuracy', None)
        
        # Extract MIA metrics
        if 'SVC_MIA_forget_efficacy' in data:
            for key, value in data['SVC_MIA_forget_efficacy'].items():
                metrics[f'mia_forget_{key}'] = value
                
        metrics_data.append(metrics)
    
    return pd.DataFrame(metrics_data)

In [20]:
# Load all results
all_results = []
for method in methods:
    # for dataset in datasets[method]:
    method_results = load_evaluation_results(method, "cifar10")
    all_results.extend(method_results)

# Convert to DataFrame for easier analysis
df = extract_metrics(all_results)

# Show basic stats
print(f"Loaded {len(df)} evaluation results")
print(f"Methods: {df['method'].unique()}")
print(f"Datasets: {df['dataset'].unique()}")

# Display the first few rows
df.head()

Skipping directory with invalid format: masks - Directory name format not recognized: masks
[{'class': 0, 'total_samples': 4500, 'correct_counts': 84, 'accuracy': 1.8666666746139526, 'dataset': 'train'}, {'class': 1, 'total_samples': 4500, 'correct_counts': 4490, 'accuracy': 99.77777862548828, 'dataset': 'train'}, {'class': 2, 'total_samples': 4500, 'correct_counts': 4456, 'accuracy': 99.02222442626953, 'dataset': 'train'}, {'class': 3, 'total_samples': 4500, 'correct_counts': 4253, 'accuracy': 94.5111083984375, 'dataset': 'train'}, {'class': 4, 'total_samples': 4500, 'correct_counts': 4434, 'accuracy': 98.53333282470703, 'dataset': 'train'}, {'class': 5, 'total_samples': 4500, 'correct_counts': 4200, 'accuracy': 93.33333587646484, 'dataset': 'train'}, {'class': 6, 'total_samples': 4500, 'correct_counts': 4445, 'accuracy': 98.77777862548828, 'dataset': 'train'}, {'class': 7, 'total_samples': 4500, 'correct_counts': 4453, 'accuracy': 98.95555877685547, 'dataset': 'train'}, {'class': 8, 

AttributeError: 'list' object has no attribute 'items'

## Performance Visualization

Now, let's create visualizations to show the performance of different methods across sequential unlearning stages.

In [12]:
# Identify the ground truth method
ground_truth_method = 'retrain'

def plot_metric_by_dataset(dataset, metric, title=None, ylim=None):
    """Plot a specific metric for all methods in a dataset"""
    # Filter data for the specific dataset
    dataset_df = df[df['dataset'] == dataset]
    
    if len(dataset_df) == 0:
        print(f"No data available for dataset {dataset}")
        return
        
    if metric not in dataset_df.columns:
        print(f"Metric '{metric}' not found in data. Available metrics: {[col for col in dataset_df.columns if col not in ['method', 'dataset', 'forget_class_begin', 'forget_class_end', 'classes_forgotten']]}")
        return
    
    plt.figure(figsize=(14, 8))
    
    # Plot each method
    for method in dataset_df['method'].unique():
        method_data = dataset_df[dataset_df['method'] == method]
        
        # Sort by forget_class_end to ensure proper sequence
        method_data = method_data.sort_values('forget_class_end')
        
        if method == ground_truth_method:
            # Highlight ground truth method
            plt.plot(method_data['forget_class_end'], method_data[metric], 
                    marker='o', linewidth=3, markersize=10, label=method, linestyle='-', color='black')
        else:
            plt.plot(method_data['forget_class_end'], method_data[metric], 
                    marker='o', linewidth=2, markersize=8, label=method)
    
    if title is None:
        title = f"{metric} vs. Classes Forgotten ({dataset})"
    
    plt.title(title)
    plt.xlabel('Cumulative Classes Forgotten')
    plt.ylabel(metric)
    plt.grid(True, alpha=0.3)
    plt.legend()
    
    if ylim is not None:
        plt.ylim(ylim)
    
    plt.tight_layout()
    plt.show()

def plot_gap_to_ground_truth(dataset, metric, title=None):
    """Plot the gap between each method and the ground truth method"""
    # Filter data for the specific dataset
    dataset_df = df[df['dataset'] == dataset]
    
    if len(dataset_df) == 0 or ground_truth_method not in dataset_df['method'].unique():
        print(f"Dataset {dataset} doesn't have ground truth method {ground_truth_method}")
        return
        
    if metric not in dataset_df.columns:
        print(f"Metric '{metric}' not found in data")
        return
    
    plt.figure(figsize=(14, 8))
    
    # Get ground truth data
    ground_truth_data = dataset_df[dataset_df['method'] == ground_truth_method]
    ground_truth_dict = dict(zip(ground_truth_data['forget_class_end'], ground_truth_data[metric]))
    
    # Plot gap for each method
    for method in dataset_df['method'].unique():
        if method == ground_truth_method:
            continue
            
        method_data = dataset_df[dataset_df['method'] == method]
        method_data = method_data.sort_values('forget_class_end')
        
        # Calculate gap at each stage
        gaps = []
        x_values = []
        
        for _, row in method_data.iterrows():
            stage = row['forget_class_end']
            if stage in ground_truth_dict:
                gap = abs(row[metric] - ground_truth_dict[stage])
                gaps.append(gap)
                x_values.append(stage)
        
        if gaps:
            plt.plot(x_values, gaps, marker='o', linewidth=2, markersize=8, label=method)
    
    if title is None:
        title = f"Gap to Ground Truth ({ground_truth_method}) for {metric} ({dataset})"
    
    plt.title(title)
    plt.xlabel('Cumulative Classes Forgotten')
    plt.ylabel(f"Absolute Difference from {ground_truth_method}")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

## Class-wise Accuracy Analysis

Now let's analyze the per-class accuracy for different methods at various stages of continual unlearning. This will show how unlearning affects the model's performance on individual classes.

In [6]:
def extract_classwise_accuracy(method, dataset):
    """Extract class-wise accuracy from evaluation results for a method and dataset"""
    classwise_data = {}
    method_dir = os.path.join(base_dir, method, dataset)
    
    if not os.path.exists(method_dir):
        print(f"Directory not found: {method_dir}")
        return classwise_data
    
    # Get all forget stages
    forget_stages = []
    for stage_dir in os.listdir(method_dir):
        stage_path = os.path.join(method_dir, stage_dir)
        if os.path.isdir(stage_path):
            try:
                # Handle both underscore and hyphen formats
                if '-' in stage_dir:
                    begin, end = map(int, stage_dir.split('-'))
                elif '_' in stage_dir:
                    begin, end = map(int, stage_dir.split('_'))
                else:
                    # Skip directories that don't follow either pattern
                    continue
                    
                forget_stages.append((begin, end, stage_dir, stage_path))
            except Exception as e:
                print(f"Skipping directory with invalid format: {stage_dir} - {str(e)}")
    
    # Sort by end class for proper ordering
    forget_stages.sort(key=lambda x: x[1])
    
    # Load class-wise accuracy data for each stage
    for begin, end, stage_dir, stage_path in forget_stages:
        eval_file = os.path.join(stage_path, 'evaluation_results.json')
        if os.path.exists(eval_file):
            try:
                with open(eval_file, 'r') as f:
                    data = json.load(f)
                    
                    # Check if class-wise accuracy data exists in the new format
                    if 'class_wise_accuracy' in data:
                        class_wise_items = data['class_wise_accuracy']
                        
                        # Extract test accuracy (filter by dataset="test")
                        test_items = [item for item in class_wise_items if item.get('dataset') == 'test']
                        
                        # Create dictionary mapping class index to accuracy
                        class_acc = {}
                        max_class = 0
                        for item in test_items:
                            class_idx = item.get('class')
                            accuracy = item.get('accuracy')
                            class_acc[class_idx] = accuracy
                            max_class = max(max_class, class_idx)
                        
                        # Convert to list format for compatibility
                        acc_list = [class_acc.get(i, 0) for i in range(max_class + 1)]
                        
                        classwise_data[end] = {
                            'forget_class_begin': begin,
                            'forget_class_end': end,
                            'classes_forgotten': end - begin,
                            'classwise_accuracy': acc_list
                        }
                    elif 'classwise_accuracy' in data:
                        # Handle the original format if present
                        classwise_acc = data['classwise_accuracy']
                        classwise_data[end] = {
                            'forget_class_begin': begin,
                            'forget_class_end': end,
                            'classes_forgotten': end - begin,
                            'classwise_accuracy': classwise_acc
                        }
            except Exception as e:
                print(f"Error loading {eval_file}: {e}")
    
    return classwise_data

def plot_classwise_accuracy(method, dataset, method_display=None):
    """Plot class-wise accuracy for a specific method and dataset"""
    if method_display is None:
        method_display = method.replace('_continual_unlearn', '')
        
    classwise_data = extract_classwise_accuracy(method, dataset)
    
    if not classwise_data:
        # Try a fallback approach: calculate class-wise accuracy from test dataloader
        print(f"No class-wise accuracy data found for {method} on {dataset}. Checking if we can extract from raw evaluation data...")
        
        # Try to load data from test_predictions.npz if available
        classwise_data = extract_classwise_accuracy_from_predictions(method, dataset)
        
        # Check if we can calculate class-wise accuracy now
        if not classwise_data:
            print(f"Cannot generate class-wise accuracy plot for {method} on {dataset}")
            return
    
    # Create a plot
    plt.figure(figsize=(15, 10))
    
    # Determine max number of classes based on dataset
    max_classes = 10 if dataset == 'cifar10' else 100
    
    # Color map for different stages
    cmap = plt.cm.get_cmap('viridis', len(classwise_data) + 1)
    
    # Plot class-wise accuracy for each stage
    for i, (stage, stage_data) in enumerate(sorted(classwise_data.items())):
        class_acc = stage_data['classwise_accuracy']
        classes_forgotten = stage_data['classes_forgotten']
        
        # Convert to right format if it's a dict
        if isinstance(class_acc, dict):
            class_indices = sorted([int(k) for k in class_acc.keys()])
            accuracy_values = [class_acc[str(k)] for k in class_indices]
        else:  # Assume it's a list
            class_indices = range(len(class_acc))
            accuracy_values = class_acc
        
        # Plot with a line style that distinguishes forgotten vs. retained
        plt.plot(class_indices, accuracy_values, 'o-', 
                label=f'Classes forgotten: {classes_forgotten}',
                color=cmap(i), linewidth=2, markersize=6)
        
        # Mark forgotten classes with a different marker
        forgotten_classes = range(stage_data['forget_class_begin'], stage_data['forget_class_end'])
        forgotten_indices = [i for i in class_indices if i in forgotten_classes]
        forgotten_values = [accuracy_values[class_indices.index(i)] for i in forgotten_indices if i in class_indices]
        
        if forgotten_indices:
            plt.plot(forgotten_indices, forgotten_values, 'x', 
                    color=cmap(i), markersize=10, markeredgewidth=2)
    
    # Add horizontal lines showing forgetting thresholds (e.g., at 20% and 50%)
    plt.axhline(y=50, color='gray', linestyle='--', alpha=0.5)
    plt.axhline(y=20, color='gray', linestyle=':', alpha=0.5)
    
    plt.title(f'Class-wise Accuracy for {method_display} on {dataset}')
    plt.xlabel('Class Index')
    plt.ylabel('Accuracy (%)')
    plt.xticks(range(0, max_classes, 5 if max_classes > 20 else 1))
    plt.ylim(0, 105)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
    plt.tight_layout()
    plt.show()

# Alternative implementation if classwise_accuracy isn't available directly
def extract_classwise_accuracy_from_predictions(method, dataset):
    """Extract class-wise accuracy from raw prediction data"""
    classwise_data = {}
    method_dir = os.path.join(base_dir, method, dataset)
    
    if not os.path.exists(method_dir):
        return classwise_data
    
    for stage_dir in os.listdir(method_dir):
        stage_path = os.path.join(method_dir, stage_dir)
        if os.path.isdir(stage_path):
            try:
                # Parse stage info
                if '-' in stage_dir:
                    begin, end = map(int, stage_dir.split('-'))
                elif '_' in stage_dir:
                    begin, end = map(int, stage_dir.split('_'))
                else:
                    continue
                
                # Check for predictions file
                preds_file = os.path.join(stage_path, 'test_predictions.npz')
                if os.path.exists(preds_file):
                    # Load predictions and calculate class-wise accuracy
                    data = np.load(preds_file)
                    preds = data['preds']
                    targets = data['targets']
                    
                    # Calculate per-class accuracy
                    max_class = max(targets) + 1
                    class_acc = []
                    for c in range(max_class):
                        class_indices = targets == c
                        if np.sum(class_indices) > 0:  # Ensure we have samples for this class
                            class_correct = np.sum(preds[class_indices] == c)
                            class_total = np.sum(class_indices)
                            accuracy = 100 * class_correct / class_total
                        else:
                            accuracy = 0
                        class_acc.append(accuracy)
                    
                    classwise_data[end] = {
                        'forget_class_begin': begin,
                        'forget_class_end': end,
                        'classes_forgotten': end - begin,
                        'classwise_accuracy': class_acc
                    }
            except Exception as e:
                print(f"Error processing {stage_path}: {e}")
    
    return classwise_data

In [7]:
# Plot class-wise accuracy for all methods on CIFAR-10
if 'cifar10' in df['dataset'].unique():
    print("\n==== Class-wise Accuracy for CIFAR-10 ====\n")
    
    for method in methods:
        method_display = method.replace('_continual_unlearn', '')
        print(f"Generating plot for {method_display}...")
        plot_classwise_accuracy(method, 'cifar10', method_display)

# Plot class-wise accuracy for all methods on CIFAR-100
if 'cifar100' in df['dataset'].unique():
    print("\n==== Class-wise Accuracy for CIFAR-100 ====\n")
    
    for method in methods:
        method_display = method.replace('_continual_unlearn', '')
        print(f"Generating plot for {method_display}...")
        plot_classwise_accuracy(method, 'cifar100', method_display)


==== Class-wise Accuracy for CIFAR-10 ====

Generating plot for FT...
No class-wise accuracy data found for FT_continual_unlearn on cifar10. Checking if we can extract from raw evaluation data...
Cannot generate class-wise accuracy plot for FT_continual_unlearn on cifar10
Generating plot for retrain...
No class-wise accuracy data found for retrain_continual_unlearn on cifar10. Checking if we can extract from raw evaluation data...
Cannot generate class-wise accuracy plot for retrain_continual_unlearn on cifar10
Generating plot for NG...
No class-wise accuracy data found for NG_continual_unlearn on cifar10. Checking if we can extract from raw evaluation data...
Cannot generate class-wise accuracy plot for NG_continual_unlearn on cifar10
Generating plot for synaptag...
No class-wise accuracy data found for synaptag_continual_unlearn on cifar10. Checking if we can extract from raw evaluation data...
Cannot generate class-wise accuracy plot for synaptag_continual_unlearn on cifar10
Genera

In [None]:
def plot_classwise_comparison(dataset, target_stage=None):
    """Compare class-wise accuracy across methods at a specific forget stage"""
    plt.figure(figsize=(15, 10))
    
    # Determine max number of classes based on dataset
    max_classes = 10 if dataset == 'cifar10' else 100
    
    # Collect class-wise data for all methods
    method_data = {}
    for method in methods:
        method_display = method.replace('_continual_unlearn', '')
        classwise_data = extract_classwise_accuracy(method, dataset)
        
        if not classwise_data:
            classwise_data = extract_classwise_accuracy_from_predictions(method, dataset)
            
        if classwise_data:
            # Select the right stage to compare
            if target_stage is None:
                # Default to the maximum common stage
                available_stages = list(classwise_data.keys())
                if available_stages:
                    selected_stage = max(available_stages)
                    method_data[method_display] = classwise_data[selected_stage]
            elif target_stage in classwise_data:
                method_data[method_display] = classwise_data[target_stage]
            else:
                # If exact target stage not found, find the closest one
                available_stages = sorted(list(classwise_data.keys()))
                if available_stages:
                    closest_stage = min(available_stages, key=lambda x: abs(x - target_stage))
                    method_data[method_display] = classwise_data[closest_stage]
                    print(f"Note: For method {method_display}, using stage {closest_stage} instead of requested {target_stage}")
    
    if not method_data:
        print(f"No class-wise data available for comparison on {dataset}")
        return
    
    # Determine the stage being compared
    example_data = next(iter(method_data.values()))
    classes_forgotten = example_data['classes_forgotten']
    stage_end = example_data['forget_class_end']
    
    # Plot class-wise accuracy for each method
    for method_display, data in method_data.items():
        class_acc = data['classwise_accuracy']
        
        # Convert to right format if it's a dict
        if isinstance(class_acc, dict):
            class_indices = sorted([int(k) for k in class_acc.keys()])
            accuracy_values = [class_acc[str(k)] for k in class_indices]
        else:  # Assume it's a list
            class_indices = range(len(class_acc))
            accuracy_values = class_acc
        
        # Plot with a unique style for each method
        plt.plot(class_indices, accuracy_values, 'o-', 
                label=method_display,
                linewidth=2, markersize=6)
        
    # Add vertical band to highlight forgotten classes
    plt.axvspan(0, stage_end, color='lightgray', alpha=0.3, label=f'Forgotten ({stage_end} classes)')
    
    # Add horizontal lines showing thresholds
    plt.axhline(y=50, color='gray', linestyle='--', alpha=0.5)
    plt.axhline(y=20, color='gray', linestyle=':', alpha=0.5)
    
    plt.title(f'Class-wise Accuracy Comparison on {dataset} (Forgotten classes: {classes_forgotten})')
    plt.xlabel('Class Index')
    plt.ylabel('Accuracy (%)')
    plt.xticks(range(0, max_classes, 5 if max_classes > 20 else 1))
    plt.ylim(0, 105)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
    plt.tight_layout()
    plt.show()

In [None]:
# Function to run class-wise accuracy extraction on GPU cluster with many files
def run_batch_classwise_extraction():
    """Extract and export class-wise accuracy data for all methods to a single file"""
    print("Running batch extraction of class-wise accuracy data...")
    
    all_classwise_data = {}
    
    for method in methods:
        method_display = method.replace('_continual_unlearn', '')
        all_classwise_data[method_display] = {}
        
        for dataset in datasets[method]:
            print(f"Extracting data for {method_display} on {dataset}...")
            classwise_data = extract_classwise_accuracy(method, dataset)
            
            if not classwise_data:
                print(f"  Falling back to prediction-based extraction...")
                classwise_data = extract_classwise_accuracy_from_predictions(method, dataset)
            
            if classwise_data:
                all_classwise_data[method_display][dataset] = classwise_data
                print(f"  Successfully extracted data for {len(classwise_data)} stages")
            else:
                print(f"  No class-wise data found")
    
    # Save the extracted data to a file
    output_dir = os.path.join(os.path.dirname(base_dir), "plots")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "classwise_accuracy_data.npz")
    
    # Convert complex nested dict to a simpler format for saving
    save_dict = {}
    for method, datasets_dict in all_classwise_data.items():
        for dataset, stages_dict in datasets_dict.items():
            for stage, data in stages_dict.items():
                key = f"{method}_{dataset}_{stage}"
                save_dict[f"{key}_begin"] = data['forget_class_begin']
                save_dict[f"{key}_end"] = data['forget_class_end']
                save_dict[f"{key}_forgotten"] = data['classes_forgotten']
                save_dict[f"{key}_accuracy"] = np.array(data['classwise_accuracy'])
    
    np.savez_compressed(output_file, **save_dict)
    print(f"Saved class-wise accuracy data to {output_file}")
    
    return all_classwise_data

# Run the batch extraction
# Uncomment to run on the computing cluster
# all_classwise_data = run_batch_classwise_extraction()