# Continual Unlearning Results Visualization

This notebook visualizes the results of continual unlearning experiments across different unlearning methods (FT, GA, RL) as the model progressively forgets more classes.

In [ ]:
# Import required libraries
import ast
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Plot style: start from matplotlib's ggplot style, then apply seaborn's
# "whitegrid" theme on top, and finally set the default figure size and base
# font size (set last so these two values are the ones that stick).
plt.style.use('ggplot')
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})

## Data Loading and Processing

First, we load every `class_wise_accuracy.csv` file from the checkpoints directory and organize the results by model type and by the number of forgotten classes.

In [ ]:
# Root directory holding the experiment checkpoints. It can be overridden with
# the CL_UNLEARN_CHECKPOINTS environment variable so the notebook can run on
# other machines; without the variable the original path is used.
base_path = os.environ.get(
    "CL_UNLEARN_CHECKPOINTS", "/home/kdkyum/workdir/cl_unlearn/checkpoints/"
)

# Results by model type:
#   'base'         -> DataFrame of the base model's class-wise accuracy, or None
#   'FT'/'GA'/'RL' -> {number of forgotten classes: DataFrame}
all_results = {
    'base': None,
    'FT': {},
    'GA': {},
    'RL': {},
}

In [ ]:
# Load the class-wise accuracy of the base model (the run with no classes
# forgotten), if the CSV exists.
base_model_path = Path(base_path) / "base_model" / "class_wise_accuracy.csv"
if not base_model_path.exists():
    print("Base model data not found")
else:
    all_results['base'] = pd.read_csv(base_model_path)
    n_rows = len(all_results['base'])
    print(f"Loaded base model data: {n_rows} rows")

In [ ]:
def extract_forgotten_classes(path):
    """Return how many classes were forgotten for the result file at *path*.

    The count is read from the name of the directory that directly contains
    the file. That name is expected to look like ``forget_class_<t1>_<t2>_...``,
    with one underscore-separated token per forgotten class
    (e.g. ``forget_class_c0_c5`` -> 2). Any other directory name yields 0.

    Args:
        path: Path to a file, e.g. ``.../forget_class_c0_c5/class_wise_accuracy.csv``.

    Returns:
        Number of non-empty class tokens after the ``forget_class_`` prefix.
    """
    prefix = "forget_class_"
    dir_name = os.path.basename(os.path.dirname(path))
    if not dir_name.startswith(prefix):
        return 0
    # Strip only the leading prefix (str.replace would also remove any later
    # occurrence) and ignore empty tokens such as a trailing underscore.
    tokens = dir_name[len(prefix):].split("_")
    return sum(1 for token in tokens if token)

def get_forgotten_classes(path):
    """Return the forgotten class indices encoded in *path*'s parent directory.

    The directory that directly contains the file is expected to be named
    ``forget_class_<t1>_<t2>_...``. Each token is a class index with an
    optional leading ``c`` (e.g. ``forget_class_c0_c5`` -> ``[0, 5]``). Any
    other directory name yields an empty list.

    Args:
        path: Path to a file, e.g. ``.../forget_class_c0_c5/class_wise_accuracy.csv``.

    Returns:
        List of ints in the order they appear in the directory name.

    Raises:
        ValueError: If a token is not an integer, with or without a leading ``c``.
    """
    prefix = "forget_class_"
    dir_name = os.path.basename(os.path.dirname(path))
    if not dir_name.startswith(prefix):
        return []
    classes = []
    for token in dir_name[len(prefix):].split("_"):
        if not token:
            continue  # tolerate empty tokens, e.g. from a trailing underscore
        # Only a single leading 'c' is stripped; anything else must be digits.
        classes.append(int(token[1:] if token.startswith("c") else token))
    return classes

# Load the class-wise accuracy CSVs of every unlearning method. Each CSV is
# tagged with the number and list of forgotten classes (parsed from its
# directory name) and its model type. It is then stored under
# all_results[model_type][num_forgotten].
for model_type in ['FT', 'GA', 'RL']:
    model_path = os.path.join(base_path, f"{model_type}_model/")
    if not os.path.exists(model_path):
        print(f"No results directory for {model_type} model: {model_path}")
        continue

    # glob returns files in arbitrary order; sort so that the outcome of a key
    # collision (see below) is deterministic across machines and runs.
    csv_files = sorted(
        glob.glob(os.path.join(model_path, "**/class_wise_accuracy.csv"), recursive=True)
    )

    # Which file each key came from in this run, used to report collisions.
    loaded_from = {}
    for file_path in csv_files:
        num_forgotten = extract_forgotten_classes(file_path)
        forgotten_classes = get_forgotten_classes(file_path)

        df = pd.read_csv(file_path)
        df['num_forgotten_classes'] = num_forgotten
        # Stored as a list literal string (e.g. "[0, 5]"); later cells parse it back.
        df['forgotten_classes'] = str(forgotten_classes)
        df['model_type'] = model_type

        # Results are keyed only by the number of forgotten classes, so two
        # runs forgetting the same number of classes collide. Warn instead of
        # silently dropping one of them.
        if num_forgotten in loaded_from:
            print(f"Warning: {model_type} results {loaded_from[num_forgotten]} and {file_path} "
                  f"both forget {num_forgotten} classes; keeping {file_path}")
        loaded_from[num_forgotten] = file_path
        all_results[model_type][num_forgotten] = df

    print(f"Loaded {len(all_results[model_type])} datasets for {model_type} model")

In [ ]:
# Stack the base-model and per-method results into one long DataFrame.
combined_data = []

# The base model is the "0 classes forgotten" reference point.
if all_results['base'] is not None:
    base_df = all_results['base'].copy()
    base_df['num_forgotten_classes'] = 0
    base_df['forgotten_classes'] = '[]'
    base_df['model_type'] = 'base'
    combined_data.append(base_df)

# Add every loaded result of each unlearning method.
for model_type in ['FT', 'GA', 'RL']:
    combined_data.extend(all_results[model_type].values())

if combined_data:
    all_data = pd.concat(combined_data, ignore_index=True)
    print(f"Combined dataframe shape: {all_data.shape}")

    # Flag rows whose class is one of the classes forgotten in that run.
    # 'forgotten_classes' holds a list literal such as "[0, 5]". Each distinct
    # value is parsed once with ast.literal_eval, which only accepts literals
    # (unlike eval), instead of being re-evaluated on every row.
    forgotten_sets = {
        text: set(ast.literal_eval(text))
        for text in all_data['forgotten_classes'].unique()
    }
    all_data['is_forgotten'] = [
        cls in forgotten_sets[text]
        for cls, text in zip(all_data['class'], all_data['forgotten_classes'])
    ]
else:
    print("No data to combine")

## Visualization 1: Accuracy of Forgotten vs Retained Classes

Let's visualize how well the models forget the targeted classes while retaining performance on other classes.

In [ ]:
# Mean test accuracy for each (method, #forgotten classes, forgotten?) group.
test_data = all_data[all_data['dataset'] == 'test'].copy()
summary = (
    test_data
    .groupby(['model_type', 'num_forgotten_classes', 'is_forgotten'])['accuracy']
    .mean()
    .reset_index()
)

plt.figure(figsize=(14, 8))

# Each unlearning method gets a solid line for its forgotten classes and a
# dashed line for its retained classes; the base model is not drawn here.
line_specs = [
    (True, 'o', '-', 'Forgotten Classes'),
    (False, 's', '--', 'Retained Classes'),
]
for model_type in sorted(summary['model_type'].unique()):
    if model_type == 'base':
        continue
    per_model = summary[summary['model_type'] == model_type]
    for want_forgotten, marker, style, suffix in line_specs:
        mask = per_model['is_forgotten'] if want_forgotten else ~per_model['is_forgotten']
        subset = per_model[mask]
        plt.plot(subset['num_forgotten_classes'], subset['accuracy'],
                 marker=marker, linestyle=style, label=f'{model_type} - {suffix}')

plt.title('Test Accuracy: Forgotten vs. Retained Classes')
plt.xlabel('Number of Forgotten Classes')
plt.ylabel('Average Accuracy (%)')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

## Visualization 2: Class-wise Accuracy Heatmaps

Let's create heatmaps to visualize the accuracy of each class as more classes are forgotten.

In [ ]:
def plot_heatmap(model_type, dataset_type='test', annot=True):
    """Plot a heatmap of per-class accuracy for one unlearning method.

    The data comes from the notebook-level ``all_data`` frame. Rows are the
    number of forgotten classes, columns are class labels, and each cell is
    the mean 'accuracy' (pivot_table's default aggregation) for that
    combination.

    Args:
        model_type: Value of the 'model_type' column to plot (e.g. 'FT').
        dataset_type: Value of the 'dataset' column to plot (default 'test').
        annot: Whether to write each cell's value on the heatmap. Pass False
            when there are many classes and the annotations become unreadable.
    """
    filtered_data = all_data[(all_data['model_type'] == model_type) &
                             (all_data['dataset'] == dataset_type)]

    if filtered_data.empty:
        # Return before creating a figure so no blank figure is left open.
        print(f"No data available for {model_type} model")
        return

    heatmap_data = filtered_data.pivot_table(
        index='num_forgotten_classes',
        columns='class',
        values='accuracy',
    ).sort_index()

    plt.figure(figsize=(14, 8))
    # The colour scale is fixed to 0-100, i.e. accuracy is assumed to be in percent.
    sns.heatmap(heatmap_data, annot=annot, fmt='.1f', cmap='RdYlGn', vmin=0, vmax=100)

    plt.title(f'{model_type} Model - Class-wise Accuracy ({dataset_type} set)')
    plt.xlabel('Class')
    plt.ylabel('Number of Forgotten Classes')
    plt.tight_layout()
    plt.show()

# Plot heatmaps for each model type that has loaded results
for model_type in ['FT', 'GA', 'RL']:
    if all_results[model_type]:
        plot_heatmap(model_type)

## Visualization 3: Comparing Different Methods

Let's compare the performance of different unlearning methods (FT, GA, RL).

In [ ]:
# Mean test accuracy for each (method, #forgotten classes, forgotten?) group.
model_comparison = (
    test_data
    .groupby(['model_type', 'num_forgotten_classes', 'is_forgotten'])['accuracy']
    .mean()
    .reset_index()
)

forgotten_comparison = model_comparison[model_comparison['is_forgotten']]
retained_comparison = model_comparison[~model_comparison['is_forgotten']]

# One figure per subset: (rows, marker, line style, title, draw base-model line?)
figure_specs = [
    (forgotten_comparison, 'o', '-',
     'Comparing Methods: Average Accuracy on Forgotten Classes', False),
    (retained_comparison, 's', '--',
     'Comparing Methods: Average Accuracy on Retained Classes', True),
]

for rows, marker, style, title, with_base in figure_specs:
    plt.figure(figsize=(14, 6))

    for model_type in sorted(rows['model_type'].unique()):
        if model_type == 'base':
            continue
        model_data = rows[rows['model_type'] == model_type]
        plt.plot(model_data['num_forgotten_classes'], model_data['accuracy'],
                 marker=marker, linestyle=style, linewidth=2, label=model_type)

    # Horizontal reference at the base model's mean test accuracy (if loaded).
    if with_base and 'base' in model_comparison['model_type'].unique():
        base_acc = model_comparison[model_comparison['model_type'] == 'base']['accuracy'].mean()
        plt.axhline(y=base_acc, color='black', linestyle='-', label='Base Model')

    plt.title(title)
    plt.xlabel('Number of Forgotten Classes')
    plt.ylabel('Average Accuracy (%)')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

## Visualization 4: Forgetting Efficiency

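Let's create a combined metric to evaluate how efficiently each method forgets classes while maintaining accuracy on retained classes. The efficiency score is the mean test accuracy on retained classes minus the mean test accuracy on forgotten classes, so higher is better.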

In [ ]:
# Forgetting efficiency for each (method, number of forgotten classes):
#   efficiency = mean test accuracy on retained classes
#                - mean test accuracy on forgotten classes
# A good forgetting method has low accuracy on forgotten classes and high
# accuracy on retained classes, so higher scores are better.

forgetting_efficiency = []

for model_type in ['FT', 'GA', 'RL']:
    for num_forgotten, df in sorted(all_results[model_type].items()):
        df_test = df[df['dataset'] == 'test']

        # 'forgotten_classes' is the same list literal (e.g. "[0, 5]") on every
        # row of a run. Parse it with ast.literal_eval, which only accepts
        # literals, instead of eval.
        forgotten_classes = (
            ast.literal_eval(df_test['forgotten_classes'].iloc[0]) if not df_test.empty else []
        )

        if forgotten_classes:
            in_forgotten = df_test['class'].isin(forgotten_classes)
            forgotten_acc = df_test.loc[in_forgotten, 'accuracy'].mean()
            retained_acc = df_test.loc[~in_forgotten, 'accuracy'].mean()
        else:
            # Nothing forgotten (or no test rows): the metrics are undefined.
            forgotten_acc = retained_acc = np.nan

        # NaN propagates through the subtraction, so an undefined accuracy
        # gives an undefined efficiency score.
        efficiency = retained_acc - forgotten_acc

        forgetting_efficiency.append({
            'model_type': model_type,
            'num_forgotten_classes': num_forgotten,
            'forgotten_accuracy': forgotten_acc,
            'retained_accuracy': retained_acc,
            'efficiency_score': efficiency
        })

# Explicit columns keep the frame, and the plotting loop below, well-formed
# even when no results were loaded.
efficiency_df = pd.DataFrame(
    forgetting_efficiency,
    columns=['model_type', 'num_forgotten_classes', 'forgotten_accuracy',
             'retained_accuracy', 'efficiency_score'],
)

# Plot efficiency scores
plt.figure(figsize=(14, 6))

for model_type in sorted(efficiency_df['model_type'].unique()):
    model_data = efficiency_df[efficiency_df['model_type'] == model_type]
    plt.plot(model_data['num_forgotten_classes'], model_data['efficiency_score'],
             marker='D', linestyle='-', linewidth=2, label=model_type)

plt.title('Forgetting Efficiency Score by Method (Higher is Better)')
plt.xlabel('Number of Forgotten Classes')
plt.ylabel('Efficiency Score (Retained Acc - Forgotten Acc)')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

## Visualization 5: Progressive Forgetting Impact

Let's see how the accuracy of specific classes (classes 0, 5, and 9 by default) changes as more and more classes are forgotten.

In [ ]:
# Classes whose test accuracy is tracked across unlearning steps. Whether each
# one is forgotten at some step depends on the experiment's forgetting order.
selected_classes = [0, 5, 9]
# One marker per tracked class, cycled if more classes are selected. With the
# default classes this gives 'o', 's', 'D' for classes 0, 5, 9.
class_markers = ['o', 's', 'D', '^', 'v', 'P', 'X', '*']

plt.figure(figsize=(15, 10))

# Track every unlearning method, consistent with the other figures.
for model_type in ['FT', 'GA', 'RL']:
    for idx, class_num in enumerate(selected_classes):
        forget_steps = []
        class_accuracies = []

        for num_forgotten, df in sorted(all_results[model_type].items()):
            if df.empty:
                continue

            # Test accuracy for this class at this unlearning step
            class_data = df[(df['dataset'] == 'test') & (df['class'] == class_num)]
            if not class_data.empty:
                forget_steps.append(num_forgotten)
                class_accuracies.append(class_data['accuracy'].iloc[0])

        if not forget_steps:
            continue  # no data: avoid an empty line and a dangling legend entry

        plt.plot(forget_steps, class_accuracies,
                 marker=class_markers[idx % len(class_markers)],
                 linestyle='-', linewidth=2,
                 label=f'{model_type} - Class {class_num}')

plt.title('Progressive Forgetting: Accuracy of Selected Classes')
plt.xlabel('Number of Forgotten Classes')
plt.ylabel('Accuracy (%)')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

## Summary and Conclusions

The visualizations above let us assess the following aspects of continual unlearning performance:

1. **Effectiveness at Forgetting:** The visualizations show how effectively each method (FT, GA, RL) can make the model forget specific classes. Lower accuracy on forgotten classes indicates better forgetting.

2. **Preservation of Retained Knowledge:** We can see how well each method preserves the model's performance on retained classes while forgetting others.

3. **Scalability with More Forgetting:** The trends show how each method performs as more and more classes need to be forgotten.

4. **Method Comparison:** The efficiency score helps identify which method provides the best balance between forgetting targeted classes and retaining performance on others.

5. **Impact on Specific Classes:** The class-wise visualizations reveal how the accuracy of specific classes changes throughout the continual unlearning process.