## Grid Search for the Lambda Penalty

In [None]:
import os
import numpy as np
import joblib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from scipy.fft import rfft, rfftfreq

# =====================================================================
# 0) Global Parameters & Configuration
# =====================================================================
# NOTE(review): machine-specific absolute path; adjust for other environments.
BASE_DIR = r"C:\Users\schmi\Documents\Studium\TUM\Masterthesis\Experimental Data"
PARTICIPANTS = [1, 3, 4, 6]
PREPROCESSED_DATA_DIR_NAME = "Preprocessed_Data_Matrix"

# Synergy counts used in the focused per-trial analysis (Step 2 of main).
N_SYNERGIES_TO_ANALYZE = [5, 6, 7]

# Representative trials: the "No"-knowledge trial of four grasp types
# (1: four-finger precision, 3: lateral pinch, 5: ball, 7: thumb-index precision).
TRIALS_TO_ANALYZE = [1, 3, 5, 7]

# Memory: keep the large arrays in float32.
DTYPE = np.float32

# Trials excluded for participant 7; they are not part of that participant's
# trial-length lists / stacked matrices.
P7_SKIPPED_TRIALS = [20, 22]

# MMF Parameters
MU_S = DTYPE(0.1)              # base gradient step; divided by ||X||_F in each run
DEFAULT_LAMBD_S = DTYPE(10.0)  # default lambda (penalty weight on W in the MMF update)
NUM_EMG_CHANNELS = 117         # passed as k: the first 117 feature rows of W are kept >= 0
BASE_ITERATIONS = 10000        # iterations in the first MMF round
ITERATION_STEP = 5000          # iterations in each further round until activations are smooth
MAX_TOTAL_ITERATIONS = 40000   # no new round starts once this many iterations have run

# Grid Search Parameters
LAMBDA_GRID_VALUES = [10, 100, 200, 500, 800, 1000, 2000, 4000]
# NOTE(review): differs from N_SYNERGIES_TO_ANALYZE (5-7); confirm that a lambda
# chosen with 4 synergies is meant to be reused for the larger synergy counts.
N_SYNERGIES_FOR_GRID_SEARCH = 4

# Smoothness Check Parameters: an activation row counts as smooth when less than
# HIGH_FREQ_POWER_RATIO of its spectral power lies above SMOOTHNESS_FREQ_THRESHOLD_HZ.
SMOOTHNESS_FREQ_THRESHOLD_HZ = 4.0
HIGH_FREQ_POWER_RATIO = 0.15

# Data & Clipping Parameters (lift-onset windows: last PRE_LIFT seconds of
# phase 1 and first POST_LIFT seconds of phase 2)
SAMPLING_RATE = 2000.0  # Hz
PRE_LIFT_CLIP_SECONDS = 1.0
POST_LIFT_CLIP_SECONDS = 2.0
PRE_LIFT_SAMPLES = int(PRE_LIFT_CLIP_SECONDS * SAMPLING_RATE)
POST_LIFT_SAMPLES = int(POST_LIFT_CLIP_SECONDS * SAMPLING_RATE)

# =====================================================================
# 1) Helper Functions (Trial Info, R^2)
# =====================================================================
def trial_info(trial_number):
    """Return the protocol details of one trial as a dict, or None if unknown.

    Trials 1-24 come in consecutive pairs that share grasp, handle, weight and
    lever; the odd trial of each pair has knowledge "No" and the even one "Yes".
    The returned dict has the keys 'grasp_type', 'handle_type', 'weight_kg',
    'lever_side' and 'knowledge'.
    """
    keys = ('grasp_type', 'handle_type', 'weight_kg', 'lever_side', 'knowledge')
    # (grasp, handle, weight in kg, lever) for trial pairs (1, 2), (3, 4), ..., (23, 24).
    trial_pairs = (
        ("Precision Grasp (Four Fingers and Thumb)", "Precision Handle", 0.25, "Left Lever"),
        ("Lateral Pinch Grasp", "Lateral Pinch Handle", 0.25, "Right Lever"),
        ("Ball Grasp", "Ball Handle", 0.50, "Left Lever"),
        ("Precision Grasp (Thumb and Index)", "Precision Handle", 0.25, "Front Lever"),
        ("Disc Grip", "Disc Handle", 0.50, "Back Lever"),
        ("Power Bar Grasp", "Power Bar Handle", 0.50, "Front Lever"),
        ("Precision Grasp (Four Fingers and Thumb)", "Precision Handle", 0.25, "Front Lever"),
        ("Lateral Pinch Grasp", "Lateral Pinch Handle", 0.25, "Back Lever"),
        ("Ball Grasp", "Ball Handle", 0.50, "Front Lever"),
        ("Precision Grasp (Thumb and Index)", "Precision Handle", 0.25, "Back Lever"),
        ("Disc Grip", "Disc Handle", 0.50, "Left Lever"),
        ("Power Bar Grasp", "Power Bar Handle", 0.50, "Right Lever"),
    )
    protocol = {
        2 * pair_no + 1 + offset: (*condition, knowledge)
        for pair_no, condition in enumerate(trial_pairs)
        for offset, knowledge in enumerate(("No", "Yes"))
    }
    details = protocol.get(trial_number)
    if details is None:
        return None
    return dict(zip(keys, details))

def r2_score_matrix(original_2d: np.ndarray, reconstructed_2d: np.ndarray) -> float:
    """Return the coefficient of determination (R²) of a reconstruction.

    Both inputs are cast to float32 (without copying if already float32) and
    every intermediate stays float32. The total sum of squares is taken around
    the single global mean of *original_2d*, not per row or column. A constant
    original (zero total variance) is reported as a perfect fit, 1.0.
    """
    y = original_2d.astype(DTYPE, copy=False)
    y_hat = reconstructed_2d.astype(DTYPE, copy=False)

    residual = np.subtract(y, y_hat, dtype=DTYPE)
    centered = np.subtract(y, np.mean(y, dtype=DTYPE, keepdims=True), dtype=DTYPE)

    ss_res = np.sum(residual * residual, dtype=DTYPE)
    ss_tot = np.sum(centered * centered, dtype=DTYPE)
    return 1.0 if ss_tot == 0 else float(1.0 - (ss_res / ss_tot))

# =====================================================================
# 2) MMF Core & Adaptive Learning Logic
# =====================================================================
def check_smoothness_fft(c_vector: np.ndarray, fs: float, freq_thresh: float, power_ratio_thresh: float) -> bool:
    """Decide whether a 1-D activation signal is 'smooth'.

    The mean-removed signal's one-sided power spectrum is computed; the signal
    is smooth when the fraction of power at frequencies strictly above
    *freq_thresh* (Hz, given sampling rate *fs*) is below *power_ratio_thresh*.
    Signals shorter than two samples, or with (near-)zero total power, are
    treated as smooth.
    """
    n_samples = len(c_vector)
    if n_samples < 2:
        return True

    centered = c_vector.astype(DTYPE, copy=False) - c_vector.mean()
    spectrum = rfft(centered)
    freqs = rfftfreq(n_samples, 1 / fs).astype(DTYPE)
    power = np.square(np.abs(spectrum).astype(DTYPE), dtype=DTYPE)

    total_power = np.sum(power, dtype=DTYPE)
    if total_power < DTYPE(1e-9):
        return True

    above_cutoff = freqs > freq_thresh
    high_freq_power = np.sum(power[above_cutoff], dtype=DTYPE)
    return (high_freq_power / total_power) < power_ratio_thresh

def run_mmf_iterations(X: np.ndarray, W: np.ndarray, C: np.ndarray, k: int,
                       mu_s: DTYPE, lambd_s: DTYPE, num_iters: int):
    """Apply *num_iters* projected-gradient steps to the factorization X ≈ W @ C.

    Each step moves W along the negative gradient of
    0.5*||X - W C||_F^2 + 0.5*lambda*||W||_F^2 and C along that of
    0.5*||X - W C||_F^2, both evaluated at the W and C from the start of the
    step. Afterwards the first *k* rows of W and all of C are clipped at 0.

    The step size is ``mu_s / ||X||_F`` and lambda is ``lambd_s / (n_vars * n_synergies)``.
    If ``||X||_F`` is 0 the inputs are returned untouched.

    Side effects: W is updated in place. The first step's unclipped C update is
    also applied in place to the caller's C; after that C is rebound to a new array.

    Returns:
        (W, C) after the final step.
    """
    n_vars, n_synergies = W.shape
    x_norm = np.linalg.norm(X.astype(DTYPE, copy=False))
    if x_norm == 0:
        return W, C

    step = mu_s / x_norm
    penalty = lambd_s / (n_vars * n_synergies)

    for _ in range(num_iters):
        cross_xc = X @ C.T
        gram_c = C @ C.T
        cross_wx = W.T @ X
        gram_w = W.T @ W
        delta_w = step * (cross_xc - (W @ gram_c) - penalty * W)
        delta_c = step * (cross_wx - (gram_w @ C))
        W += delta_w.astype(DTYPE)
        C += delta_c.astype(DTYPE)
        # Projection onto the constraint set: rows [0, k) of W and all of C non-negative.
        W[:k, :] = np.maximum(W[:k, :], 0)
        C = np.maximum(C, 0)
    return W, C

def adaptive_mmf_wrapper(data_matrix: np.ndarray, n_synergies: int, lambd_s: DTYPE):
    """Factorize *data_matrix* in rounds of MMF iterations until activations are smooth.

    The matrix is transposed first, so X = data_matrix.T has one row per feature
    and one column per row of *data_matrix*. W and C start from a fixed-seed (42)
    uniform draw scaled by 0.01, so results are reproducible.

    The first round runs BASE_ITERATIONS steps and every later round
    ITERATION_STEP steps. After each round every row of C is tested with
    check_smoothness_fft, and iteration stops as soon as all rows pass.
    Otherwise no new round starts once MAX_TOTAL_ITERATIONS steps have run.
    A round is never cut short, so the total may overshoot the cap if the step
    sizes do not divide it.

    Returns:
        (W, C) with shapes (n_features, n_synergies) and (n_synergies, n_samples).
        Returns (None, None) when *data_matrix* is None or has fewer than 5 rows.
    """
    if data_matrix is None or data_matrix.shape[0] < 5:
        print("   -> WARNING: Input data matrix is too short or None. Skipping MMF.")
        return None, None

    X = data_matrix.T.astype(DTYPE, copy=False)
    n_vars, n_samples = X.shape

    generator = np.random.default_rng(42)
    init_scale = DTYPE(0.01)
    # Draw order (W first, then C) determines the seeded initial values.
    W = generator.random((n_vars, n_synergies), dtype=DTYPE) * init_scale
    C = generator.random((n_synergies, n_samples), dtype=DTYPE) * init_scale

    iterations_done = 0
    converged = False
    while not converged and iterations_done < MAX_TOTAL_ITERATIONS:
        round_length = ITERATION_STEP if iterations_done else BASE_ITERATIONS
        W, C = run_mmf_iterations(X, W, C, NUM_EMG_CHANNELS, MU_S, lambd_s, round_length)
        iterations_done += round_length
        converged = all(
            check_smoothness_fft(activation, SAMPLING_RATE, SMOOTHNESS_FREQ_THRESHOLD_HZ, HIGH_FREQ_POWER_RATIO)
            for activation in C
        )

    if converged:
        print(f"   -> Smoothness achieved for lambda={lambd_s:.2f} after {iterations_done} iterations.")
    else:
        print(f"   -> WARNING: Max iterations ({MAX_TOTAL_ITERATIONS}) reached for lambda={lambd_s:.2f}.")
    return W, C

# =====================================================================
# 3) Plotting & Execution Helpers
# =====================================================================
def plot_and_save_synergies(W, C, synergy_out_dir, prefix, index_dict, r2_score):
    """Plot every synergy's weight vector and activation and save one PNG.

    Each synergy gets one figure row:
      - left: a bar chart of its W column, coloured by modality;
      - right: its C row against time (samples / SAMPLING_RATE).

    Args:
        W: Synergy weights, shape (n_features, n_synergies).
        C: Activations, shape (n_synergies, n_samples).
        synergy_out_dir: Output directory (created if missing). The figure is
            written to ``<synergy_out_dir>/<prefix>_synergies.png``.
        prefix: Used in the figure title and file name.
        index_dict: Must provide 'otb_indices', 'myo_indices' and
            'kin_hand_indices', each a (start, end) pair of feature rows
            (end exclusive) that selects the bar colour.
        r2_score: Overall R² shown in the title.

    Raises:
        ValueError: If C's row count differs from W's column count.
    """
    n_features, n_synergies = W.shape
    if C.shape[0] != n_synergies:
        raise ValueError("Synergy mismatch in W and C.")

    otb_start, otb_end = index_dict['otb_indices']
    myo_start, myo_end = index_dict['myo_indices']
    kin_hand_start, kin_hand_end = index_dict['kin_hand_indices']

    # Bar colours by modality; features outside all three ranges stay gray.
    colors = ["gray"] * n_features
    for i in range(otb_start, otb_end): colors[i] = '#4F81BD'
    for i in range(myo_start, myo_end): colors[i] = 'pink'
    for i in range(kin_hand_start, kin_hand_end): colors[i] = 'orange'

    # Loop-invariant drawing inputs, built once for all synergy rows.
    legend_handles = [Patch(facecolor='#4F81BD', label='OTB'),
                      Patch(facecolor='pink', label='Myo'),
                      Patch(facecolor='orange', label='Kinematic Hand')]
    feature_positions = np.arange(n_features)
    time_axis = np.arange(C.shape[1]) / SAMPLING_RATE

    os.makedirs(synergy_out_dir, exist_ok=True)
    fig, axes = plt.subplots(n_synergies, 2, figsize=(20, 4 * n_synergies), squeeze=False)
    # Always release the figure, even if drawing or saving fails, so repeated
    # calls in a long batch run do not accumulate open figures.
    try:
        fig.suptitle(f'Synergy Analysis for {prefix} | Overall R²: {r2_score:.4f}', fontsize=16, fontweight='bold')

        for i, (ax_bar, ax_line) in enumerate(axes):
            ax_bar.bar(feature_positions, W[:, i], color=colors, alpha=0.7, width=1.0)
            ax_bar.set_title(f"Synergy {i+1} (W column {i+1})")
            ax_bar.legend(handles=legend_handles, loc='upper left')
            ax_bar.set_xticks([])

            ax_line.plot(time_axis, C[i, :], color='green', linewidth=1.5)
            ax_line.set_title(f"Activation {i+1} (C row {i+1})")
            ax_line.set_xlabel('Time (s)')
            ax_line.grid(True, linestyle='--')

        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(os.path.join(synergy_out_dir, f"{prefix}_synergies.png"), dpi=150)
    finally:
        plt.close(fig)

def run_analysis_block(analysis_path: str, data_matrix: np.ndarray, n_syn: int, trial_idx: int,
                       participant_dir: str, index_dict: dict, is_skipped: bool, lambd_s: DTYPE):
    """Factorize one data segment with a fixed lambda and save the results.

    Results go to
    ``<participant_dir>/Synergies_Publication/<analysis_path>/<n_syn>_Syn/Trial_<NN>``,
    with ``_SKIPPED_TRIAL`` appended when *is_skipped*. The directory is created
    before any work is done, so it also exists for skipped or empty segments.
    On success it receives ``W_synergies.npy``, ``C_activations.npy`` and a
    figure produced by plot_and_save_synergies.

    Args:
        analysis_path: Sub-path such as "Full_Phases/Phase_1". It is also used,
            with path separators replaced by '_', in the figure file-name prefix.
        data_matrix: Segment to factorize. It is transposed before MMF, so rows
            are samples. None when the segment is unavailable.
        n_syn: Number of synergies to extract.
        trial_idx: Trial number (used in the directory name and log output).
        participant_dir: Root directory of the participant.
        index_dict: Feature index ranges forwarded to the plotting function.
        is_skipped: True for trials excluded from the analysis.
        lambd_s: Lambda used for every MMF run of this block.
    """
    print(f" Running {analysis_path} for Trial {trial_idx:02d} with {n_syn} synergies...")

    save_dir = os.path.join(participant_dir, "Synergies_Publication", analysis_path, f"{n_syn}_Syn", f"Trial_{trial_idx:02d}")
    if is_skipped:
        save_dir += "_SKIPPED_TRIAL"
    os.makedirs(save_dir, exist_ok=True)

    if is_skipped:
        print("   -> Skipping analysis for this trial as marked.")
        return
    if data_matrix is None:
        # Report the missing segment in the log instead of returning silently.
        print("   -> WARNING: No data available for this segment. Skipping.")
        return

    W, C = adaptive_mmf_wrapper(data_matrix, n_syn, lambd_s)
    if W is None or C is None:
        return  # adaptive_mmf_wrapper has already printed the reason

    reconstructed = (W @ C).astype(DTYPE)
    r2 = r2_score_matrix(data_matrix.T.astype(DTYPE, copy=False), reconstructed)
    print(f"   -> Final R²: {r2:.4f}")

    np.save(os.path.join(save_dir, "W_synergies.npy"), W)
    np.save(os.path.join(save_dir, "C_activations.npy"), C)
    plot_prefix = f"synergies_{os.path.normpath(analysis_path).replace(os.sep, '_')}"
    plot_and_save_synergies(W, C, save_dir, plot_prefix, index_dict, r2)

# =====================================================================
# 4) Pipeline Entrypoint
# =====================================================================
def _min_max_normalize(values) -> np.ndarray:
    """Min-max scale the finite entries of *values* to [0, 1] as float64.

    Non-finite entries (NaN/inf) come back as NaN and do not influence the
    scaling. A zero-range input maps to all zeros instead of producing 0/0.
    """
    arr = np.asarray(values, dtype=np.float64)
    out = np.full(arr.shape, np.nan)
    finite = np.isfinite(arr)
    if finite.any():
        lo = arr[finite].min()
        span = arr[finite].max() - lo
        out[finite] = (arr[finite] - lo) / span if span > 0 else 0.0
    return out


def _select_best_lambda_index(vaf_scores, activation_costs):
    """Return the grid index nearest the ideal point (max VAF, min activation cost).

    Both criteria are min-max normalized over the successful runs only. The
    Euclidean distance to (VAF=1, cost=0) is then minimized, with ties going to
    the lower index. Runs whose VAF or cost is not finite count as failed and
    are never chosen. Returns None if every run failed (or the grid is empty).
    """
    vaf = np.asarray(vaf_scores, dtype=np.float64)
    cost = np.asarray(activation_costs, dtype=np.float64)
    ok = np.isfinite(vaf) & np.isfinite(cost)
    if not ok.any():
        return None
    vaf_norm = _min_max_normalize(np.where(ok, vaf, np.nan))
    cost_norm = _min_max_normalize(np.where(ok, cost, np.nan))
    distances = np.hypot(1.0 - vaf_norm, cost_norm)
    distances[~ok] = np.inf
    return int(np.argmin(distances))


def _run_lambda_grid_search(data_matrix: np.ndarray, n_synergies: int, lambda_values):
    """Factorize *data_matrix* once per lambda and return (vaf_scores, activation_costs).

    VAF is r2_score_matrix of the reconstruction. Activation cost is the
    Frobenius norm of C.
    NOTE(review): lambda penalizes ||W|| in run_mmf_iterations, whereas the cost
    tracked here is ||C||_F — confirm this is the intended trade-off.
    A run that yields no factorization is recorded as NaN in both lists.
    """
    vaf_scores, activation_costs = [], []
    for lambd_val in lambda_values:
        print(f"  Testing lambda = {lambd_val:.2f}...")
        W_grid, C_grid = adaptive_mmf_wrapper(data_matrix, n_synergies, DTYPE(lambd_val))
        if W_grid is None or C_grid is None:
            vaf_scores.append(np.nan)
            activation_costs.append(np.nan)
            continue
        vaf_scores.append(r2_score_matrix(data_matrix.T, W_grid @ C_grid))
        activation_costs.append(float(np.linalg.norm(C_grid, 'fro')))
    return vaf_scores, activation_costs


def _plot_lambda_grid_search(lambda_values, vaf_scores, activation_costs, best_lambda, pid, out_path):
    """Plot VAF and activation cost against lambda (log x-axis), mark the choice, save PNG."""
    fig, ax1 = plt.subplots(figsize=(12, 7))
    try:
        ax1.plot(lambda_values, vaf_scores, 'o-', color='tab:blue', label='VAF (R²)')
        ax1.set_xlabel('Lambda Value', fontsize=12)
        ax1.set_ylabel('VAF (R²)', color='tab:blue', fontsize=12)
        ax1.tick_params(axis='y', labelcolor='tab:blue')
        ax1.set_xscale('log')
        ax1.grid(True, which="both", ls="--")
        ax2 = ax1.twinx()
        ax2.plot(lambda_values, activation_costs, 's--', color='tab:red', label='Activation Cost')
        ax2.set_ylabel('Activation Cost (Norm of C)', color='tab:red', fontsize=12)
        ax2.tick_params(axis='y', labelcolor='tab:red')
        ax1.axvline(x=best_lambda, color='green', linestyle=':', linewidth=2, label=f'Selected λ = {best_lambda:.2f}')
        fig.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)
        ax1.set_title(f'Lambda Grid Search for Participant {pid}', fontsize=16, fontweight='bold')
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def _build_trial_slices(trial_ids, lengths):
    """Map each trial id to its row slice in a matrix that stacks the trials in order."""
    slices, row = {}, 0
    for trial_idx, n_rows in zip(trial_ids, lengths):
        slices[trial_idx] = slice(row, row + n_rows)
        row += n_rows
    return slices


def main():
    """Run the lambda grid search and the focused synergy analysis for each participant.

    For every id in PARTICIPANTS:
      1. Load the preprocessed phase-1/phase-2 matrices and the feature-index dict.
      2. Grid-search lambda on the full phase-1 matrix and save the selection plot.
      3. Factorize each trial in TRIALS_TO_ANALYZE for each synergy count in
         N_SYNERGIES_TO_ANALYZE, covering both full phases and both lift-onset
         windows.

    Participants whose files or trial-length metadata are missing are skipped.
    """
    for pid in PARTICIPANTS:
        participant_dir = os.path.join(BASE_DIR, f"P({pid})")
        pre_dir = os.path.join(participant_dir, PREPROCESSED_DATA_DIR_NAME)

        paths = {"p1": os.path.join(pre_dir, f"P{pid}_combined_matrix_phase1.npy"),
                 "p2": os.path.join(pre_dir, f"P{pid}_combined_matrix_phase2.npy"),
                 "idx": os.path.join(pre_dir, f"P{pid}_feature_indices.joblib")}
        if not all(os.path.exists(p) for p in paths.values()):
            print(f"ERROR: Preprocessed files not found for P{pid}. Skipping.")
            continue

        print("\n" + "="*80 + f"\nProcessing Participant: {pid}\n" + "="*80)
        p1_full = np.load(paths["p1"]).astype(DTYPE)
        p2_full = np.load(paths["p2"]).astype(DTYPE)
        index_dict = joblib.load(paths["idx"])

        # --- Step 1: Grid Search for Optimal Lambda ---
        print(f"--- Step 1: Running Grid Search to find optimal lambda for P{pid} ---")
        grid_search_dir = os.path.join(participant_dir, "Grid_Search_Results")
        os.makedirs(grid_search_dir, exist_ok=True)
        vaf_scores, activation_costs = _run_lambda_grid_search(
            p1_full, N_SYNERGIES_FOR_GRID_SEARCH, LAMBDA_GRID_VALUES)

        # Normalization is guarded against zero ranges and failed runs, which
        # previously turned every distance into NaN.
        best_idx = _select_best_lambda_index(vaf_scores, activation_costs)
        if best_idx is None:
            best_lambda = DTYPE(DEFAULT_LAMBD_S)
            print(f"--- WARNING: All grid-search runs failed; using default lambda = {best_lambda:.2f} ---")
        else:
            best_lambda = DTYPE(LAMBDA_GRID_VALUES[best_idx])
            print(f"--- Grid Search Complete. Selected optimal lambda = {best_lambda:.2f} ---")

        _plot_lambda_grid_search(LAMBDA_GRID_VALUES, vaf_scores, activation_costs, best_lambda, pid,
                                 os.path.join(grid_search_dir, f'P{pid}_lambda_selection.png'))

        # --- Step 2: Run Focused Analysis ---
        print(f"--- Step 2: Running focused analysis for P{pid} using lambda = {best_lambda:.2f} ---")
        p1_lens = index_dict.get("phase1_trial_lengths")
        p2_lens = index_dict.get("phase2_trial_lengths")
        skip_list = P7_SKIPPED_TRIALS if pid == 7 else []
        valid_trials = [t for t in range(1, 25) if t not in skip_list]

        if (p1_lens is None or p2_lens is None
                or len(p1_lens) != len(valid_trials) or len(p2_lens) != len(valid_trials)):
            print("ERROR: Trial length mismatch. Skipping participant analysis.")
            continue

        # The slices assume each matrix is exactly the trials stacked in order.
        if sum(p1_lens) != p1_full.shape[0] or sum(p2_lens) != p2_full.shape[0]:
            print("   -> WARNING: Summed trial lengths do not match the matrix row counts; "
                  "trial slices may be misaligned.")

        p1_map = _build_trial_slices(valid_trials, p1_lens)
        p2_map = _build_trial_slices(valid_trials, p2_lens)

        for n_syn in N_SYNERGIES_TO_ANALYZE:
            for trial_idx in TRIALS_TO_ANALYZE:
                # skip_list is only non-empty for participant 7.
                is_skip = trial_idx in skip_list

                full_p1 = p1_full[p1_map[trial_idx]] if not is_skip and trial_idx in p1_map else None
                full_p2 = p2_full[p2_map[trial_idx]] if not is_skip and trial_idx in p2_map else None
                # Lift-onset windows: end of phase 1 and start of phase 2.
                lift_p1 = full_p1[-PRE_LIFT_SAMPLES:] if full_p1 is not None else None
                lift_p2 = full_p2[:POST_LIFT_SAMPLES] if full_p2 is not None else None

                run_analysis_block("Full_Phases/Phase_1", full_p1, n_syn, trial_idx, participant_dir, index_dict, is_skip, best_lambda)
                run_analysis_block("Full_Phases/Phase_2", full_p2, n_syn, trial_idx, participant_dir, index_dict, is_skip, best_lambda)
                run_analysis_block("Lift_Onset/Phase_1", lift_p1, n_syn, trial_idx, participant_dir, index_dict, is_skip, best_lambda)
                run_analysis_block("Lift_Onset/Phase_2", lift_p2, n_syn, trial_idx, participant_dir, index_dict, is_skip, best_lambda)

    print("\nAnalysis complete for all participants.")

if __name__ == "__main__":
    main()


Processing Participant: 1
--- Step 1: Running Grid Search to find optimal lambda for P1 ---
  Testing lambda = 5000.00...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.


  vaf_norm = (np.array(vaf_scores) - np.min(vaf_scores)) / (np.max(vaf_scores) - np.min(vaf_scores))
  cost_norm = (np.array(activation_costs) - np.min(activation_costs)) / (np.max(activation_costs) - np.min(activation_costs))


--- Grid Search Complete. Selected optimal lambda = 5000.00 ---
--- Step 2: Running focused analysis for P1 using lambda = 5000.00 ---
 Running Full_Phases/Phase_1 for Trial 01 with 4 synergies...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.
   -> Final R²: 0.5084
 Running Full_Phases/Phase_2 for Trial 01 with 4 synergies...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.
   -> Final R²: 0.6463
 Running Lift_Onset/Phase_1 for Trial 01 with 4 synergies...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.
   -> Final R²: 0.5660
 Running Lift_Onset/Phase_2 for Trial 01 with 4 synergies...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.
   -> Final R²: 0.7274
 Running Full_Phases/Phase_1 for Trial 03 with 4 synergies...
   -> Smoothness achieved for lambda=5000.00 after 10000 iterations.
   -> Final R²: 0.6595
 Running Full_Phases/Phase_2 for Trial 03 with 4 synergies...
   -> Smoothness achieved for lambd

KeyboardInterrupt: 