# Analyze TIES Hyperparameter Grid Search

## Imports

In [1]:
import os
import json
import pandas as pd

# Iterate Through Each Hyperparameter Combination Folder

In [8]:
# ---------------------------
# 0. Configuration of the run being analyzed
# ---------------------------
unlearn_method = "erasediff"
training_method = "independent"
merge_method = "ties_proj"
concept_type = "style"
concepts = ["Winter"]
OUTPUT_ROOT = os.environ.get("OUTPUT_ROOT", "/fs/scratch/PAS2099/lee.10369/CUIG")
base_dir = f"{OUTPUT_ROOT}/{unlearn_method}/eval_results/{training_method}/merge/{merge_method}/{concept_type}"
all_results = []

# Metric keys every summary.json must provide for the final score to be computable.
REQUIRED_METRICS = ("UA", "IRA", "CRA")
# Column order of the results table (also gives an empty table the right columns).
RESULT_COLUMNS = ["combo", "concept", "concept_type", "UA", "IRA", "CRA", "final_score"]

# ---------------------------
# 1. Collect metrics for every hyperparameter combination
# ---------------------------
# Each sub-directory of base_dir is one hyperparameter combination. os.listdir
# returns entries in arbitrary order, so sort to keep the scan deterministic.
print(f"Searching in {base_dir}")
combos = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)
print(f"Found {len(combos)} combinations in {base_dir}")

for combo in combos:
    combo_dir = os.path.join(base_dir, combo)
    for concept in concepts:
        # Each concept folder contains a "metrics" folder with the JSON files
        metrics_dir = os.path.join(combo_dir, f"thru{concept}", "metrics")
        summary_file = os.path.join(metrics_dir, "summary.json")

        if not os.path.exists(summary_file):
            print(f"Missing metrics for combo {combo} and concept {concept} in {metrics_dir}")
            continue

        with open(summary_file, "r", encoding="utf-8") as f:
            summary_data = json.load(f)

        # Report and skip incomplete summaries instead of crashing on the
        # arithmetic below (a placeholder default would raise a TypeError).
        missing_keys = [key for key in REQUIRED_METRICS if key not in summary_data]
        if missing_keys:
            print(f"Summary for combo {combo} and concept {concept} is missing {missing_keys}; skipping")
            continue

        unlearn_accuracy_avg = summary_data["UA"]
        retention_accuracy = summary_data["IRA"]
        cross_retention_accuracy = summary_data["CRA"]
        # final_score = mean(UA, mean(IRA, CRA)): unlearning and the average of
        # the two retention metrics are weighted equally.
        final_score = (unlearn_accuracy_avg + (retention_accuracy + cross_retention_accuracy) / 2) / 2

        all_results.append({
            "combo": combo,
            "concept": concept,
            "concept_type": concept_type,
            "UA": unlearn_accuracy_avg,
            "IRA": retention_accuracy,
            "CRA": cross_retention_accuracy,
            "final_score": final_score,
        })

# Combine results into a DataFrame for further analysis. Passing explicit
# columns keeps the column access below valid even when nothing was found.
df_results = pd.DataFrame(all_results, columns=RESULT_COLUMNS)
if df_results.empty:
    print(f"No results found under {base_dir}")

# ---------------------------
# 2. Extract Hyperparameter Values from 'combo'
# ---------------------------
# For TIES merges the combo directory name has the form "lambda<value>_topk<value>",
# e.g. "lambda2.25_topk0.20".
if "ties" in merge_method:
    pattern = r"lambda(?P<lambda>[\d\.]+)_topk(?P<topk>[\d\.]+)"
else:
    # Without this branch `pattern` would be unbound and fail with a NameError.
    raise ValueError(f"No combo-name pattern defined for merge method '{merge_method}'")
extracted = df_results["combo"].str.extract(pattern)
df_results["topk"] = extracted["topk"].astype(float)
df_results["lambda"] = extracted["lambda"].astype(float)

# Best combinations first; UA breaks ties in final_score.
df_results.sort_values(["final_score", "UA"], ascending=False, inplace=True)

# ---------------------------
# 3. Save the ranked table
# ---------------------------
# NOTE: the file name uses only the first concept in `concepts`.
REPO_ROOT = os.environ.get("REPO_ROOT", "/users/PAS2099/justinhylee135/Research/UnlearningDM/CUIG")
df_save_dir = f"{REPO_ROOT}/Analysis/Notebooks/notebook_output/{unlearn_method}/{training_method}/merge/{merge_method}/{concept_type}"
os.makedirs(df_save_dir, exist_ok=True)
df_results.to_excel(os.path.join(df_save_dir, f"thru{concepts[0]}.xlsx"), index=False)
print(f"Results saved to {df_save_dir}/thru{concepts[0]}.xlsx")

print(f"Total number of combinations: {len(df_results)}")
df_results

Searching in /fs/scratch/PAS2099/lee.10369/CUIG/erasediff/eval_results/independent/merge/ties_proj/style
Found 17 combinations in /fs/scratch/PAS2099/lee.10369/CUIG/erasediff/eval_results/independent/merge/ties_proj/style
Missing metrics for combo lambda1.00_topk0.30 and concept Winter in lambda1.00_topk0.30
Results saved to /users/PAS2099/justinhylee135/Research/UnlearningDM/CUIG/Analysis/Notebooks/notebook_output/erasediff/independent/merge/ties_proj/style/thruWinter.xlsx
Total number of combinations: 16


Unnamed: 0,combo,concept,concept_type,UA,IRA,CRA,final_score,topk,lambda
14,lambda2.25_topk0.20,Winter,style,96.59,37.5,98.91,82.3975,0.2,2.25
9,lambda2.25_topk0.60,Winter,style,96.59,35.42,100.0,82.15,0.6,2.25
15,lambda2.25_topk0.40,Winter,style,92.05,40.62,99.46,81.045,0.4,2.25
3,lambda2.00_topk0.40,Winter,style,86.36,48.96,98.37,80.0125,0.4,2.0
5,lambda2.00_topk0.20,Winter,style,88.64,41.67,98.91,79.465,0.2,2.0
1,lambda2.25_topk0.80,Winter,style,89.77,35.42,100.0,78.74,0.8,2.25
4,lambda1.75_topk0.20,Winter,style,77.27,57.29,98.91,77.685,0.2,1.75
8,lambda2.00_topk0.60,Winter,style,82.95,43.75,99.46,77.2775,0.6,2.0
0,lambda1.75_topk0.40,Winter,style,72.73,62.5,97.83,76.4475,0.4,1.75
2,lambda1.75_topk0.60,Winter,style,72.73,59.38,98.37,75.8025,0.6,1.75
