In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import config
from experiments.cross_val import run_cv_experiment
from utils.visualization import plot_history, plot_tuning_results
import matplotlib.pyplot as plt
from loaders.eeg_loader import load_eeg_dataset
from sklearn.model_selection import StratifiedKFold



experiments = ["Non_Augmented", "ChannelsDropout", "FTSurrogate", "TimeReverse", "SmoothTimeMask"]
results = {}

In [None]:
import os
import sys
import json
import torch
import datetime
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Ensure paths are correct for imports
sys.path.append('..')
from experiments.cross_val import run_cv_experiment 
from loaders.eeg_loader import load_eeg_dataset
import config

# --- 1. OVERNIGHT GRID SEARCH CONFIG ---
MODELS = ["CustomEEGNet", "EEGNet", "DeepConvNet", "ShallowConvNet"]
LRS = [1e-3, 5e-4]
BATCH_SIZES = [16, 32]

# Full Experiment Suite: Maps directly to your get_augmentation names
tuning_grid = {
    "Original": [None],
    "ChannelsDropout": [0.3, 0.5, 0.7],
    "FTSurrogate": [1.57, 3.14, 6.28], 
    "SmoothTimeMask": [150, 300, 500],
    "TimeReverse": [True]
}

# --- 2. DATA LOAD ---
X, y, metadata, n_classes = load_eeg_dataset(mode="single", subject_id=1)
if X is None:
    print(f"ERROR: Data not found. Check relative paths. CWD: {os.getcwd()}")
    sys.exit()

# Zero Leakage: Split indices defined BEFORE any augmentation
groups_for_cv = metadata['trial_ids']
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

results_file = "overnight_search_results.json"
if os.path.exists(results_file):
    with open(results_file, 'r') as f:
        all_results = json.load(f)
    print(f"Resuming from {len(all_results)} existing results.")
else:
    all_results = []

# --- 3. THE AUTOMATION LOOP ---
for model_name in MODELS:
    for lr in LRS:
        for bs in BATCH_SIZES:
            # Consistent hyperparams for fair comparison
            hparams = {
                "batch_size": bs,
                "lr": lr,
                "epochs": 60,       # Increased slightly for better convergence
                "weight_decay": 1e-3, 
                "data_multiplier": 4 # Expansion factor for stochastic dataset
            }

            for method, values in tuning_grid.items():
                for val in values:
                    run_id = f"{model_name}_{method}_{val}_LR{lr}_BS{bs}"
                    
                    # Skip if already calculated (Crash Recovery)
                    if any(res['run_id'] == run_id for res in all_results):
                        continue

                    print(f"\n>>> TESTING: {run_id} | {datetime.datetime.now().strftime('%H:%M:%S')}")

                    # Mapping the grid to the specific dict structure get_augmentation expects
                    current_aug = {}
                    if method != "Original":
                        # Convert CamelCase to snake_case for the internal dictionary keys
                        aug_key = method.lower().replace("channelsdropout", "channels_dropout") \
                                               .replace("ftsurrogate", "freq_surrogate") \
                                               .replace("smoothtimemask", "smooth_time_mask") \
                                               .replace("timereverse", "time_reverse")
                        
                        # Map the specific param name required for each function
                        param_name = "p_drop" if "Dropout" in method else \
                                     "phase_noise_max" if "FT" in method else \
                                     "mask_len_samples" if "Mask" in method else "active"
                        
                        current_aug = {aug_key: {param_name: val}}

                    try:
                        history, mean_acc = run_cv_experiment(
                            X, y, groups_for_cv, n_classes, cv, 
                            exp_name=method, 
                            aug_params=current_aug, 
                            hyperparams=hparams,
                            model_name=model_name,
                            verbose=False
                        )

                        all_results.append({
                            "run_id": run_id, 
                            "accuracy": float(mean_acc), 
                            "model": model_name, 
                            "method": method,
                            "val": val, 
                            "lr": lr, 
                            "bs": bs
                        })
                        
                        # Save after every successful run to prevent data loss
                        with open(results_file, 'w') as f:
                            json.dump(all_results, f, indent=4)
                            
                    except Exception as e:
                        print(f"!!! FAILED {run_id}: {str(e)}")
                        continue

print("\n>>> ALL OVERNIGHT EXPERIMENTS FINISHED. RESULTS SAVED.")


>>> TESTING: ShallowConvNet_Original_None_LR0.001_BS16 | 04:15:07




KeyboardInterrupt: 

In [2]:

tuning_grid = {
    "ChannelsDropout": {
        "p_drop": [0.1, 0.3, 0.5, 0.7, 0.9] 
    },
    "FTSurrogate": {
        "phase_noise_max": [0.5, 1.5, 3.14, 4.71, 6.28]  
    },
    "SmoothTimeMask": {
        "mask_len_samples": [102, 204, 307, 410, 512] 
    }
}


In [3]:
X, y, groups, n_classes = load_eeg_dataset(mode="single", subject_id=1)
groups_for_cv = groups['trial_ids'] # This should be length 320

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

In [4]:
tuning_output = []


print(">>> Running Baseline (Non_Augmented)...")

_, base_acc = run_cv_experiment(X, y, groups_for_cv, n_classes, cv, "Original", {}, config.HYPERPARAMS)
tuning_output.append({'method': 'Baseline', 'params': 0, 'accuracy': base_acc})


for method, grid in tuning_grid.items():
    param_name = list(grid.keys())[0]
    values = grid[param_name]
    
    for val in values:
        print(f"\n>>> Testing {method} | {param_name}: {val}")
        
        
        current_params = {method.lower(): {param_name: val}}
        
        
        history, mean_acc = run_cv_experiment(
            X, y, groups_for_cv, n_classes, cv, 
            exp_name=method, 
            aug_params=current_params, 
            hyperparams=config.HYPERPARAMS,
            verbose=False 
        )
        
        tuning_output.append({
            'method': method,
            'params': val,
            'accuracy': mean_acc,
            'history': history
        })



>>> Running Baseline (Non_Augmented)...




 -> Fold 1 Finished. Best Val Acc: 14.06% (Epoch 16)
 -> Fold 2 Finished. Best Val Acc: 20.31% (Epoch 41)
 -> Fold 3 Finished. Best Val Acc: 15.62% (Epoch 18)
 -> Fold 4 Finished. Best Val Acc: 17.19% (Epoch 29)
 -> Fold 5 Finished. Best Val Acc: 18.75% (Epoch 37)

>>> Testing ChannelsDropout | p_drop: 0.1




 -> Fold 1 Finished. Best Val Acc: 15.62% (Epoch 50)
 -> Fold 2 Finished. Best Val Acc: 17.19% (Epoch 3)
 -> Fold 3 Finished. Best Val Acc: 23.44% (Epoch 45)
 -> Fold 4 Finished. Best Val Acc: 25.00% (Epoch 49)
 -> Fold 5 Finished. Best Val Acc: 20.31% (Epoch 20)

>>> Testing ChannelsDropout | p_drop: 0.3




 -> Fold 1 Finished. Best Val Acc: 15.62% (Epoch 7)
 -> Fold 2 Finished. Best Val Acc: 21.88% (Epoch 40)
 -> Fold 3 Finished. Best Val Acc: 20.31% (Epoch 30)
 -> Fold 4 Finished. Best Val Acc: 17.19% (Epoch 12)
 -> Fold 5 Finished. Best Val Acc: 17.19% (Epoch 25)

>>> Testing ChannelsDropout | p_drop: 0.5




 -> Fold 1 Finished. Best Val Acc: 20.31% (Epoch 8)
 -> Fold 2 Finished. Best Val Acc: 20.31% (Epoch 24)
 -> Fold 3 Finished. Best Val Acc: 23.44% (Epoch 40)
 -> Fold 4 Finished. Best Val Acc: 18.75% (Epoch 7)
 -> Fold 5 Finished. Best Val Acc: 25.00% (Epoch 28)

>>> Testing ChannelsDropout | p_drop: 0.7




 -> Fold 1 Finished. Best Val Acc: 14.06% (Epoch 7)
 -> Fold 2 Finished. Best Val Acc: 21.88% (Epoch 39)
 -> Fold 3 Finished. Best Val Acc: 23.44% (Epoch 18)
 -> Fold 4 Finished. Best Val Acc: 17.19% (Epoch 17)
 -> Fold 5 Finished. Best Val Acc: 14.06% (Epoch 4)

>>> Testing ChannelsDropout | p_drop: 0.9




 -> Fold 1 Finished. Best Val Acc: 17.19% (Epoch 4)
 -> Fold 2 Finished. Best Val Acc: 21.88% (Epoch 35)
 -> Fold 3 Finished. Best Val Acc: 21.88% (Epoch 38)
 -> Fold 4 Finished. Best Val Acc: 17.19% (Epoch 24)
 -> Fold 5 Finished. Best Val Acc: 23.44% (Epoch 30)

>>> Testing FTSurrogate | phase_noise_max: 0.5




 -> Fold 1 Finished. Best Val Acc: 17.19% (Epoch 1)
 -> Fold 2 Finished. Best Val Acc: 17.19% (Epoch 4)
 -> Fold 3 Finished. Best Val Acc: 29.69% (Epoch 27)
 -> Fold 4 Finished. Best Val Acc: 17.19% (Epoch 48)
 -> Fold 5 Finished. Best Val Acc: 25.00% (Epoch 44)

>>> Testing FTSurrogate | phase_noise_max: 1.5




 -> Fold 1 Finished. Best Val Acc: 20.31% (Epoch 30)
 -> Fold 2 Finished. Best Val Acc: 20.31% (Epoch 50)
 -> Fold 3 Finished. Best Val Acc: 26.56% (Epoch 17)
 -> Fold 4 Finished. Best Val Acc: 23.44% (Epoch 44)
 -> Fold 5 Finished. Best Val Acc: 21.88% (Epoch 5)

>>> Testing FTSurrogate | phase_noise_max: 3.14




 -> Fold 1 Finished. Best Val Acc: 20.31% (Epoch 27)
 -> Fold 2 Finished. Best Val Acc: 23.44% (Epoch 7)
 -> Fold 3 Finished. Best Val Acc: 25.00% (Epoch 34)
 -> Fold 4 Finished. Best Val Acc: 25.00% (Epoch 47)
 -> Fold 5 Finished. Best Val Acc: 20.31% (Epoch 5)

>>> Testing FTSurrogate | phase_noise_max: 4.71




 -> Fold 1 Finished. Best Val Acc: 25.00% (Epoch 14)
 -> Fold 2 Finished. Best Val Acc: 20.31% (Epoch 31)
 -> Fold 3 Finished. Best Val Acc: 26.56% (Epoch 45)
 -> Fold 4 Finished. Best Val Acc: 26.56% (Epoch 43)
 -> Fold 5 Finished. Best Val Acc: 23.44% (Epoch 3)

>>> Testing FTSurrogate | phase_noise_max: 6.28




 -> Fold 1 Finished. Best Val Acc: 18.75% (Epoch 9)
 -> Fold 2 Finished. Best Val Acc: 21.88% (Epoch 44)
 -> Fold 3 Finished. Best Val Acc: 25.00% (Epoch 24)
 -> Fold 4 Finished. Best Val Acc: 28.12% (Epoch 14)
 -> Fold 5 Finished. Best Val Acc: 15.62% (Epoch 27)

>>> Testing SmoothTimeMask | mask_len_samples: 102




 -> Fold 1 Finished. Best Val Acc: 21.88% (Epoch 44)
 -> Fold 2 Finished. Best Val Acc: 14.06% (Epoch 25)
 -> Fold 3 Finished. Best Val Acc: 23.44% (Epoch 48)
 -> Fold 4 Finished. Best Val Acc: 25.00% (Epoch 32)
 -> Fold 5 Finished. Best Val Acc: 21.88% (Epoch 28)

>>> Testing SmoothTimeMask | mask_len_samples: 204




 -> Fold 1 Finished. Best Val Acc: 20.31% (Epoch 27)
 -> Fold 2 Finished. Best Val Acc: 20.31% (Epoch 18)
 -> Fold 3 Finished. Best Val Acc: 20.31% (Epoch 16)
 -> Fold 4 Finished. Best Val Acc: 15.62% (Epoch 12)
 -> Fold 5 Finished. Best Val Acc: 25.00% (Epoch 30)

>>> Testing SmoothTimeMask | mask_len_samples: 307




 -> Fold 1 Finished. Best Val Acc: 17.19% (Epoch 6)
 -> Fold 2 Finished. Best Val Acc: 18.75% (Epoch 47)
 -> Fold 3 Finished. Best Val Acc: 17.19% (Epoch 36)
 -> Fold 4 Finished. Best Val Acc: 10.94% (Epoch 15)
 -> Fold 5 Finished. Best Val Acc: 21.88% (Epoch 25)

>>> Testing SmoothTimeMask | mask_len_samples: 410




KeyboardInterrupt: 

In [None]:

plot_tuning_results(tuning_output)