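Here, we aggregate the membership-inference attack (MIA) statistics for our rebuttal: QMIA vs. baseline AUC and TPR at 1% FPR, for each class-drop (`cls_drop`) setting.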

In [1]:
import os
import pandas as pd
import numpy as np
import argparse
from pathlib import Path
import glob

def find_stats_files(base_dataset, base_model, models_dir="models",
                     attack_model="facebook/convnext-tiny-224",
                     score_fn="top_two_margin", loss_fn="gaussian"):
    """
    Find all stats.csv files for a given base dataset and model combination.

    Args:
        base_dataset: e.g., "base_cinic10", "base_cifar20"
        base_model: e.g., "cifar-resnet-18", "cifar-resnet-50", "cifar-vit"
        models_dir: Root directory containing the models
        attack_model: Attack-model sub-path in the directory layout
            (default matches the layout this notebook was written for)
        score_fn: Suffix of the ``score_fn_<...>`` directory
        loss_fn: Suffix of the ``loss_fn_<...>`` directory

    Returns:
        List of tuples (cls_drop_string, file_path), ordered by the integer
        value(s) of cls_drop_string (comma-separated parts compared in order).
        Strings with equal value (e.g. "1" and "01") are ordered shorter
        first, then lexicographically, so the order does not depend on glob.

    Raises:
        ValueError: if a cls_drop directory suffix is not a comma-separated
            list of integers.
    """
    # Pattern to match the directory structure
    pattern = (
        f"{models_dir}/mia/{base_dataset}/*/{base_model}/*/*/{attack_model}/"
        f"score_fn_{score_fn}/loss_fn_{loss_fn}/cls_drop_*/predictions/stats.csv"
    )

    cls_drop_files = []
    for file_path in glob.glob(pattern):
        # The pattern pins the tail to cls_drop_<X>/predictions/stats.csv, so
        # the cls_drop directory is the grandparent. pathlib handles the
        # platform's separator, unlike splitting on '/'.
        cls_drop_dir = Path(file_path).parents[1].name
        cls_drop_str = cls_drop_dir[len('cls_drop_'):]
        cls_drop_files.append((cls_drop_str, file_path))

    def sort_key(item):
        cls_drop_str = item[0]
        # Numeric order first; break ties such as "1" vs "01" deterministically.
        numbers = tuple(int(n) for n in cls_drop_str.split(','))
        return (numbers, len(cls_drop_str), cls_drop_str)

    cls_drop_files.sort(key=sort_key)
    return cls_drop_files

def load_and_aggregate_data(cls_drop_files):
    """
    Read every stats.csv in *cls_drop_files* and stack them into one DataFrame.

    Each file's rows are tagged with a ``cls_drop`` column holding the string
    the file was paired with. A file that cannot be read is reported with a
    printed warning and skipped.

    Args:
        cls_drop_files: Iterable of (cls_drop_string, file_path) tuples, e.g.
            the output of find_stats_files.

    Returns:
        DataFrame with all rows concatenated and a fresh 0..n-1 index.

    Raises:
        ValueError: if no file could be read (including an empty input).
    """
    frames = []
    for drop_label, csv_path in cls_drop_files:
        try:
            frame = pd.read_csv(csv_path)
            frame['cls_drop'] = drop_label
        except Exception as err:  # best-effort: report and move on to the next file
            print(f"Warning: Could not read {csv_path}: {err}")
            continue
        frames.append(frame)

    if len(frames) == 0:
        raise ValueError("No valid stats.csv files found!")

    return pd.concat(frames, ignore_index=True)

def create_metric_table(df, qmia_metric, baseline_metric, metric_name):
    """
    Create a table for a specific metric showing QMIA and Baseline values.

    Rows are the ``dataset_type`` values. Columns are a two-level
    ``(Method, cls_drop)`` MultiIndex: all QMIA columns first, then all
    Baseline columns, each ordered by the integer value(s) of cls_drop
    (equal values such as "1" and "01" ordered shorter first).

    Args:
        df: Aggregated DataFrame with ``dataset_type`` and string ``cls_drop``
            columns (as produced by load_and_aggregate_data)
        qmia_metric: Column name for QMIA metric
        baseline_metric: Column name for Baseline metric
        metric_name: Human-readable metric name (currently unused; kept for
            interface compatibility)

    Returns:
        DataFrame with the metric table

    Raises:
        ValueError: if a (dataset_type, cls_drop) pair occurs more than once
            (raised by ``DataFrame.pivot``).
    """
    def cls_drop_key(cls_drop_str):
        # Numeric order first; "1" and "01" both parse to 1, so break ties by
        # length and then text to keep the order deterministic.
        numbers = tuple(int(n) for n in cls_drop_str.split(','))
        return (numbers, len(cls_drop_str), cls_drop_str)

    # Pivot with cls_drop as columns and dataset_type as rows, one per method
    method_tables = {
        'QMIA': df.pivot(index='dataset_type', columns='cls_drop', values=qmia_metric),
        'Baseline': df.pivot(index='dataset_type', columns='cls_drop', values=baseline_metric),
    }

    parts = []
    for method, table in method_tables.items():
        # Sort within each method so the two methods stay in separate blocks
        # (sorting the combined columns by cls_drop alone interleaves them).
        ordered = sorted(table.columns, key=cls_drop_key)
        table = table.reindex(columns=ordered)
        table.columns = pd.MultiIndex.from_product(
            [[method], ordered], names=['Method', 'cls_drop'])
        parts.append(table)

    # Combine tables side by side, aligned on dataset_type
    return pd.concat(parts, axis=1)

def format_table_for_display(table, metric_name):
    """
    Return a display copy of *table*, rounded to 4 decimals with the row index labelled 'Dataset Type'.

    Args:
        table: Metric table (e.g. from create_metric_table)
        metric_name: Human-readable metric name (currently unused; kept for
            interface compatibility)

    Returns:
        New DataFrame; the input table and its index are not modified.
    """
    # rename_axis builds a new Index. Assigning ``.index.name`` on the rounded
    # frame could instead rename an Index object still shared with ``table``.
    return table.round(4).rename_axis('Dataset Type')

In [4]:
# Run configuration. argparse.Namespace gives the same attribute access
# (args.base_dataset, ...) as parsed command-line arguments without reading
# sys.argv, which inside Jupyter holds the kernel's own arguments.
args = argparse.Namespace(
    base_dataset='base_cifar20',
    base_model='cifar-vit',
    models_dir='../models',
    output_dir='../rebuttal_outputs',
)

# Set to True to also write the metric tables and the raw aggregated data as
# CSV files under args.output_dir.
SAVE_OUTPUTS = False

# Create output directory (including missing parents) if it doesn't exist
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Find all stats files for the given dataset and model
print(f"Finding stats files for {args.base_dataset} with {args.base_model}...")
cls_drop_files = find_stats_files(args.base_dataset, args.base_model, args.models_dir)

if not cls_drop_files:
    # Stop here with a clear message; continuing would only fail later inside
    # load_and_aggregate_data with a less specific error.
    raise FileNotFoundError(
        f"No stats.csv files found for {args.base_dataset}/{args.base_model}"
    )

print(f"Found {len(cls_drop_files)} stats files:")
for cls_drop_str, file_path in cls_drop_files:
    print(f"  cls_drop_{cls_drop_str}: {file_path}")

# Load and aggregate data
print("\nLoading and aggregating data...")
df = load_and_aggregate_data(cls_drop_files)

# Prefix for output filenames, e.g. "base_cifar20_cifar_vit"
output_prefix = f"{args.base_dataset}_{args.base_model.replace('-', '_')}"

# (QMIA column, Baseline column, metric name) for each table to build
metrics = [
    ('qmia_auc', 'baseline_auc', 'AUC'),
    ('qmia_tpr_at_fpr_1%', 'baseline_tpr_at_fpr_1%', 'TPR_at_1pct_FPR'),
    # ('qmia_tpr_at_fpr_0.1%', 'baseline_tpr_at_fpr_0.1%', 'TPR_at_0.1pct_FPR')
]

# Display order for the cls_drop settings (hard-coded for the settings this
# notebook was run with; any other settings are appended after these).
col_order = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
             '01', '23', '45', '67', '89', '01234', '56789',
             '0123456789', '10111213141516171819']

tables = []
for qmia_metric, baseline_metric, metric_name in metrics:
    print(f"\nCreating {metric_name} table...")
    table = create_metric_table(df, qmia_metric, baseline_metric, metric_name)
    tables.append(table)

    if SAVE_OUTPUTS:
        output_file = output_dir / f"{output_prefix}_{metric_name.lower()}.csv"
        format_table_for_display(table, metric_name).to_csv(output_file)
        print(f"Saved {metric_name} table to: {output_file}")

for (_, _, metric_name), table in zip(metrics, tables):
    # TPR is shown as a percentage with 2 decimals and AUC as a fraction with
    # 3. New frames are built instead of scaling in place, so the entries of
    # `tables` keep their raw values.
    if "TPR" in metric_name:
        shown = (table * 100).round(2)
    else:
        shown = table.round(3)

    # Keep the in_distribution row and move the Method column level into the
    # row index: rows become (dataset_type, Method), columns the cls_drop settings.
    stacked_table = shown.loc[['in_distribution']].stack(level='Method')

    # Reorder columns; settings missing from col_order are appended rather
    # than silently dropped by reindex.
    extra_cols = [c for c in stacked_table.columns if c not in col_order]
    stacked_table = stacked_table.reindex(columns=col_order + extra_cols)
    display(stacked_table)  # IPython/Jupyter display; not a plain-Python builtin

if SAVE_OUTPUTS:
    # Save raw aggregated data as well
    raw_output_file = output_dir / f"{output_prefix}_raw_data.csv"
    df.to_csv(raw_output_file, index=False)
    print(f"\nSaved raw aggregated data to: {raw_output_file}")
    print(f"\nAll outputs saved to: {output_dir}")

Finding stats files for base_cifar20 with cifar-vit...
Found 19 stats files:
  cls_drop_0: ../models/mia/base_cifar20/0_16/cifar-vit/attack_cifar20/0_16/facebook/convnext-tiny-224/score_fn_top_two_margin/loss_fn_gaussian/cls_drop_0/predictions/stats.csv
  cls_drop_1: ../models/mia/base_cifar20/0_16/cifar-vit/attack_cifar20/0_16/facebook/convnext-tiny-224/score_fn_top_two_margin/loss_fn_gaussian/cls_drop_1/predictions/stats.csv
  cls_drop_01: ../models/mia/base_cifar20/0_16/cifar-vit/attack_cifar20/0_16/facebook/convnext-tiny-224/score_fn_top_two_margin/loss_fn_gaussian/cls_drop_01/predictions/stats.csv
  cls_drop_2: ../models/mia/base_cifar20/0_16/cifar-vit/attack_cifar20/0_16/facebook/convnext-tiny-224/score_fn_top_two_margin/loss_fn_gaussian/cls_drop_2/predictions/stats.csv
  cls_drop_3: ../models/mia/base_cifar20/0_16/cifar-vit/attack_cifar20/0_16/facebook/convnext-tiny-224/score_fn_top_two_margin/loss_fn_gaussian/cls_drop_3/predictions/stats.csv
  cls_drop_4: ../models/mia/base_cif

  stacked_table = table.loc[['in_distribution']].stack(level='Method').reset_index()


dataset_type,Method,0,1,2,3,4,5,6,7,8,9,01,23,45,67,89,01234,56789,0123456789,10111213141516171819
in_distribution,Baseline,0.66,0.656,0.663,0.658,0.661,0.659,0.659,0.658,0.659,0.66,0.657,0.661,0.661,0.658,0.66,0.661,0.658,0.661,0.658
in_distribution,QMIA,0.665,0.656,0.662,0.661,0.665,0.66,0.661,0.663,0.659,0.663,0.659,0.662,0.661,0.66,0.661,0.665,0.66,0.665,0.662


  stacked_table = table.loc[['in_distribution']].stack(level='Method').reset_index()


dataset_type,Method,0,1,2,3,4,5,6,7,8,9,01,23,45,67,89,01234,56789,0123456789,10111213141516171819
in_distribution,Baseline,1.1,1.12,1.08,1.14,1.14,1.13,1.11,1.12,1.08,1.17,1.06,1.1,1.11,1.08,1.11,1.11,1.09,1.14,1.12
in_distribution,QMIA,3.38,3.39,3.48,3.09,3.36,3.19,3.12,3.5,3.26,3.17,3.6,2.93,3.36,3.07,3.1,3.24,3.0,3.2,3.17
