### File for 4x4 concept decoding across all patients, 10 iterations each
Note: decoding is not restricted to MTL neurons, because the other patients do not have enough MTL neurons.

In [29]:
import os
import time
import joblib
from datetime import datetime
from itertools import combinations
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from tqdm import tqdm
from data_structures import PatientData
from decoders import ConceptDecoder, SingleResultsManager
from sklearn.svm import LinearSVC
from typing import defaultdict
import glob

In [None]:
# Root directory for all decoding output; each run later gets its own
# "run_<timestamp>" sub-directory inside it. Created here if it doesn't exist.
RESULTS_DIR = "./decoding_res_4x4_april9"
os.makedirs(RESULTS_DIR, exist_ok=True)

In [30]:
def generate_consistent_combinations(strings_list, group_size):
    """Split ``strings_list`` into every distinct (group, complement) pair.

    For each combination of ``group_size`` items, the complement is the tuple
    of items of ``strings_list`` that are not in that combination. The two
    tuples are stored smaller-first (tuple comparison), so a split and its
    mirror image collapse to one entry.

    Returns:
        List of ``(tuple, tuple)`` pairs, in order of first appearance.
    """
    emitted = set()
    splits = []

    for chosen in combinations(strings_list, group_size):
        # Everything that was not picked forms the opposing group
        rest = tuple(item for item in strings_list if item not in chosen)

        # Canonical orientation so that (A, B) and (B, A) look the same
        split = (chosen, rest) if chosen < rest else (rest, chosen)

        fingerprint = repr(split)
        if fingerprint in emitted:
            continue
        emitted.add(fingerprint)
        splits.append(split)

    return splits

In [31]:
# The eight concept labels used for decoding
best_concepts = [
    "A.Fayed", "R.Wallace", "T.Lennox", "N.Yassir", 
    "K.Hayes", "M.OBrian", "J.Bauer", "C.Manning"
]
# Every distinct way of splitting the eight concepts into two groups of four
stable_groups = generate_consistent_combinations(best_concepts, 4)

THRESHOLD = 0.1  # firing rate threshold
MAX_ITERATIONS = 10  # Total iterations per group/patient combination
PATIENT_IDS = ['562', '563', '566']  # patients included in the decoding run

In [None]:
def create_patient_dict(pids):
    """Load each patient and keep the neurons that pass the firing-rate filter.

    Returns a dict mapping each patient id to a two-item list:
    ``[PatientData object, neurons returned by filter_neurons_by_fr]``.
    The filter window runs from ``times_dict['movie_start_rel']`` to
    ``times_dict['preSleep_recall_start_rel']`` and uses the module-level
    ``THRESHOLD``.
    """
    loaded = {}

    for pid in pids:
        print(f"Loading data for patient {pid}...")
        patient = PatientData(pid=pid)

        candidate_neurons = patient.neurons
        times = patient.times_dict
        filter_window = (times['movie_start_rel'], times['preSleep_recall_start_rel'])

        # Drop neurons that do not pass the firing-rate threshold in the window
        kept_neurons = patient.filter_neurons_by_fr(
            neurons=candidate_neurons,
            window=filter_window,
            threshold=THRESHOLD,
        )

        loaded[pid] = [patient, kept_neurons]
        print(f"Patient {pid}: {len(kept_neurons)} neurons after filtering")

    return loaded

In [None]:


# Every run writes into its own timestamp-named directory under RESULTS_DIR
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_DIR = os.path.join(RESULTS_DIR, "run_" + RUN_ID)
os.makedirs(RUN_DIR, exist_ok=True)

# Snapshot of the settings used for this run, stored alongside the results
config = dict(
    concepts=best_concepts,
    threshold=THRESHOLD,
    max_iterations=MAX_ITERATIONS,
    patient_ids=PATIENT_IDS,
    timestamp=RUN_ID,
    num_concept_groups=len(stable_groups),
)
joblib.dump(config, os.path.join(RUN_DIR, "config.pkl"))

In [None]:
def get_result_filename(pid, group_idx, iteration):
    """Return the path inside RUN_DIR of the pickle for one patient/group/iteration."""
    basename = f"patient_{pid}_group_{group_idx}_iter_{iteration}.pkl"
    return os.path.join(RUN_DIR, basename)

# Tell whether a given result has already been written to disk
def result_exists(pid, group_idx, iteration):
    """Return True if the result pickle for this patient/group/iteration is on disk."""
    return os.path.exists(get_result_filename(pid, group_idx, iteration))

# Persist the progress bookkeeping of the current run
def save_progress(pid_dict, completed_dict):
    """Write the completed-task record plus a timestamp to RUN_DIR/progress.pkl.

    ``pid_dict`` is accepted but not used by this function.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    progress_path = os.path.join(RUN_DIR, "progress.pkl")
    joblib.dump({'completed': completed_dict, 'timestamp': stamp}, progress_path)

# Read back the progress bookkeeping of the current run
def load_progress():
    """Return the contents of RUN_DIR/progress.pkl, or an empty record if it is absent."""
    progress_path = os.path.join(RUN_DIR, "progress.pkl")
    if not os.path.exists(progress_path):
        # Nothing saved yet: start from a blank slate
        return {'completed': {}, 'timestamp': None}
    return joblib.load(progress_path)

In [None]:
def run_decoding_with_checkpoints():
    """Run all patient / concept-group / iteration decoding tasks with checkpoints.

    For every patient in ``PATIENT_IDS`` and every pair in ``stable_groups``,
    decoding is run one iteration at a time until ``MAX_ITERATIONS`` results
    are recorded. Each result is saved to its own pickle in ``RUN_DIR`` and the
    progress file is rewritten right after, so an interrupted run can be
    continued later.

    Progress is tracked as ``{pid: {str(group_idx): [iteration, ...]}}``.

    Returns:
        The updated progress dictionary.
    """
    # Load or initialize patient data
    patient_dict = create_patient_dict(PATIENT_IDS)

    # Initialize or load progress
    progress = load_progress()
    completed_dict = progress['completed']

    # Track overall progress. The progress dict is nested (patient -> group ->
    # iterations), so the innermost lists are what must be counted.
    total_tasks = len(PATIENT_IDS) * len(stable_groups) * MAX_ITERATIONS
    completed_tasks = sum(
        len(iterations)
        for groups in completed_dict.values()
        for iterations in groups.values()
    )
    # Guard against an empty task list (no patients or no groups)
    percent_done = (completed_tasks / total_tasks) * 100 if total_tasks else 0.0

    print(f"Starting decoding run: {RUN_ID}")
    print(f"Total tasks: {total_tasks}")
    print(f"Completed tasks: {completed_tasks}")
    print(f"Progress: {percent_done:.2f}%")

    # Main loop over all patients, groups, and iterations
    for pid in PATIENT_IDS:
        # Get patient data and filtered neurons
        patient, neurons = patient_dict[pid]

        # Make sure this patient has a tracking entry
        patient_progress = completed_dict.setdefault(pid, {})

        for group_idx, group_pair in enumerate(stable_groups):
            group_key = str(group_idx)

            # Completed iterations for this combination (created if needed)
            completed_iterations = patient_progress.setdefault(group_key, [])

            # Skip if all iterations completed
            if len(completed_iterations) >= MAX_ITERATIONS:
                print(f"All iterations completed for patient {pid}, group {group_idx}")
                continue

            remaining_iterations = MAX_ITERATIONS - len(completed_iterations)

            # Decide the iteration numbers up front. `completed_iterations`
            # grows while the loop below runs, so deriving the number from its
            # current length would skip numbers and could reuse (and overwrite)
            # an iteration that is already on disk.
            already_done = set(completed_iterations)
            pending_iterations = [
                n for n in range(MAX_ITERATIONS) if n not in already_done
            ][:remaining_iterations]

            print(f"Processing patient {pid}, group {group_idx}: {remaining_iterations} iterations remaining")

            # Create results manager for this group
            manager = SingleResultsManager(
                patient_data=patient,
                concept_items=[group_pair],  # List with a single group pair
                epoch='movie',
                classifier=LinearSVC(random_state=42),
                standardize=True,
                pseudo=True,  # Use pseudopopulations for balanced datasets
                neurons=neurons
            )

            # Run iterations one by one, saving each result
            for iter_num in pending_iterations:
                print(f"  Running iteration {iter_num + 1}/{MAX_ITERATIONS}")

                # Run a single iteration
                start_time = time.time()
                manager.run_decoding(num_iter=1)
                end_time = time.time()

                results = manager.results
                if results and group_pair in results and results[group_pair]:
                    # Take the newest entry. The same manager is reused for
                    # every iteration, so index 0 would return the first
                    # iteration's result again if results accumulate.
                    # NOTE(review): assumes results[group_pair] is a list whose
                    # last element is the latest run - confirm against
                    # SingleResultsManager.run_decoding.
                    result = results[group_pair][-1]
                    filename = get_result_filename(pid, group_idx, iter_num)

                    # Add metadata to result
                    result_with_meta = {
                        'result': result,
                        'patient_id': pid,
                        'group_idx': group_idx,
                        'group_pair': group_pair,
                        'iteration': iter_num,
                        'duration': end_time - start_time,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    }

                    # Save result with joblib
                    joblib.dump(result_with_meta, filename)

                    # Update progress tracking
                    completed_iterations.append(iter_num)
                    save_progress(patient_dict, completed_dict)

                    print(f"  Iteration completed in {end_time - start_time:.2f} seconds")
                else:
                    print(f"  No results for this iteration, possibly due to insufficient data")

    print("Decoding completed!")
    return completed_dict

In [32]:
def load_results(run_dir=None):
    """Read every result pickle of a run into a nested dictionary.

    When ``run_dir`` is None, the ``run_*`` entry of ``RESULTS_DIR`` whose name
    sorts last is used.

    Returns:
        ``(results, config)`` where ``results`` has the form
        ``{patient_id: {group_idx: [(iteration, result_data), ...]}}``.
        Returns a bare ``None`` (not a tuple) when no run directory is found.
    """
    if run_dir is None:
        # Fall back to the run whose name sorts last
        run_names = [d for d in os.listdir(RESULTS_DIR) if d.startswith("run_")]
        if not run_names:
            print("No runs found")
            return None
        run_dir = os.path.join(RESULTS_DIR, max(run_names))

    # Run configuration
    config = joblib.load(os.path.join(run_dir, "config.pkl"))
    print(f"Loaded configuration from {run_dir}")
    print(f"Run timestamp: {config['timestamp']}")
    print(f"Patients: {config['patient_ids']}")
    print(f"Concept groups: {config['num_concept_groups']}")

    # Every pickle except the config / progress bookkeeping is a result file
    bookkeeping_prefixes = ("config", "progress")
    result_files = [
        name for name in os.listdir(run_dir)
        if name.endswith(".pkl") and not name.startswith(bookkeeping_prefixes)
    ]
    print(f"Found {len(result_files)} result files")

    # Nest the loaded records by patient, then by group
    results = {}
    for name in result_files:
        result_data = joblib.load(os.path.join(run_dir, name))

        pid = result_data['patient_id']
        group_idx = result_data['group_idx']
        iter_num = result_data['iteration']

        per_group = results.setdefault(pid, {})
        per_group.setdefault(group_idx, []).append((iter_num, result_data))

    return results, config

# Function to resume an incomplete run
def resume_decoding():
    """Resume the newest checkpointed decoding run if it is incomplete.

    The ``run_*`` directories in ``RESULTS_DIR`` are scanned from newest to
    oldest name and the first one containing a ``progress.pkl`` is taken as
    the candidate. Directories without a progress file are skipped: a freshly
    created, still-empty run directory (one is created whenever the run-setup
    code executes) would otherwise hide an older interrupted run.

    If the candidate is incomplete, the module-level ``RUN_ID`` / ``RUN_DIR``
    are pointed at it and decoding continues there. Otherwise decoding is
    started with the current ``RUN_ID`` / ``RUN_DIR``.

    Returns:
        Whatever ``run_decoding_with_checkpoints`` returns.
    """
    global RUN_ID, RUN_DIR

    # Newest first (run names embed a sortable timestamp)
    all_runs = sorted(
        (d for d in os.listdir(RESULTS_DIR) if d.startswith("run_")),
        reverse=True,
    )
    if not all_runs:
        print("No previous runs found. Starting new run.")
        return run_decoding_with_checkpoints()

    # Newest run that actually has a checkpoint
    candidate = next(
        (
            name for name in all_runs
            if os.path.exists(os.path.join(RESULTS_DIR, name, "progress.pkl"))
        ),
        None,
    )
    if candidate is None:
        print("No previous run with a progress file found. Starting new run.")
        return run_decoding_with_checkpoints()

    candidate_dir = os.path.join(RESULTS_DIR, candidate)

    # Load progress and configuration
    progress = joblib.load(os.path.join(candidate_dir, "progress.pkl"))
    config = joblib.load(os.path.join(candidate_dir, "config.pkl"))

    # Check if all tasks are completed
    completed_dict = progress['completed']
    total_tasks = len(config['patient_ids']) * config['num_concept_groups'] * config['max_iterations']
    completed_tasks = sum(
        len(iterations)
        for pid_dict in completed_dict.values()
        for iterations in pid_dict.values()
    )

    if completed_tasks >= total_tasks:
        print("Previous run is already complete. Starting new run.")
        return run_decoding_with_checkpoints()

    # Resume the previous run
    print(f"Resuming run {candidate}")
    print(f"Progress: {completed_tasks}/{total_tasks} tasks completed ({(completed_tasks/total_tasks)*100:.2f}%)")

    # Point the module-level run identifiers at the run being resumed
    RUN_ID = config['timestamp']
    RUN_DIR = candidate_dir

    # Run the decoding with the existing progress
    return run_decoding_with_checkpoints()

# Function to plot results by group
def plot_group_performance(results, config, metric='test_roc_auc', save_path=None):
    """Plot mean performance per concept group as grouped bars, one bar per patient.

    Args:
        results: Nested mapping
            ``{pid: {group_idx: [(iter_num, result_data), ...]}}``;
            the metric is read from ``result_data['result']``.
        config: Run configuration; ``config['concepts']`` and
            ``config['patient_ids']`` are used.
        metric: Attribute name read from each result object. Results that do
            not have the attribute are ignored.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        ``(fig, ax)`` of the created plot.
    """
    patient_ids = config['patient_ids']

    # Rebuild the group pairs from the stored concepts (groups of four)
    stable_groups = generate_consistent_combinations(config['concepts'], 4)
    group_labels = [f"G{group_idx}" for group_idx in range(len(stable_groups))]

    # Mean / std of the metric for every (patient, group); NaN where no data
    performance_by_patient = {pid: [] for pid in patient_ids}
    errors_by_patient = {pid: [] for pid in patient_ids}

    for group_idx in range(len(stable_groups)):
        for pid in patient_ids:
            if pid in results and group_idx in results[pid]:
                runs = results[pid][group_idx]
            else:
                runs = ()

            values = [
                getattr(result_data['result'], metric)
                for _, result_data in runs
                if hasattr(result_data['result'], metric)
            ]

            if values:
                performance_by_patient[pid].append(np.mean(values))
                errors_by_patient[pid].append(np.std(values))
            else:
                performance_by_patient[pid].append(np.nan)
                errors_by_patient[pid].append(np.nan)

    # Create the plot
    fig, ax = plt.subplots(figsize=(12, 6))

    num_patients = len(patient_ids)
    # 0.2 for up to four patients; narrower beyond that so that the bars of
    # neighbouring groups do not overlap
    bar_width = min(0.2, 0.8 / max(num_patients, 1))
    index = np.arange(len(group_labels))

    # Plot bars for each patient, side by side within a group
    for i, pid in enumerate(patient_ids):
        ax.bar(index + (i * bar_width), performance_by_patient[pid],
               bar_width, yerr=errors_by_patient[pid],
               label=f"Patient {pid}")

    ax.set_xlabel('Concept Groups')
    ax.set_ylabel(metric.replace('_', ' ').title())
    ax.set_title('Group Decoding Performance Across Patients')
    # Centre each tick under its cluster of bars, whatever the patient count
    ax.set_xticks(index + bar_width * max(num_patients - 1, 0) / 2)
    ax.set_xticklabels(group_labels)
    ax.legend()

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    return fig, ax


Load data

In [34]:
def load_decoding_results(results_dir):
    """
    Load all decoding results from a directory and organize them by patient.

    Result files are found with the pattern ``patient_*.pkl``; the patient ID
    is taken from the second ``_``-separated part of the file name. Each file
    is expected to hold a dict with the keys ``'group_pair'`` and ``'result'``.
    Files are processed in sorted name order so the resulting managers are
    built the same way on every call.

    Args:
        results_dir: Path to the directory containing result pickle files

    Returns:
        Dictionary mapping patient IDs to SingleResultsManager objects, or
        None if the directory contains no result files. Patients for which a
        PatientData object cannot be created are left out.
    """
    # Import from `collections`, where `defaultdict` is documented, instead of
    # relying on the undocumented name that the file-level
    # `from typing import defaultdict` pulls in.
    from collections import defaultdict

    print(f"Loading results from {results_dir}...")

    # Find all result files (excluding config and progress files)
    result_files = sorted(glob.glob(os.path.join(results_dir, "patient_*.pkl")))
    print(f"Found {len(result_files)} result files")

    if not result_files:
        print("No result files found!")
        return None

    # Load the configuration file to get metadata
    config_file = os.path.join(results_dir, "config.pkl")
    if os.path.exists(config_file):
        config = joblib.load(config_file)
        print(f"Loaded configuration from run {config.get('timestamp', 'unknown')}")
    else:
        print("No configuration file found, proceeding with limited metadata")
        config = {}

    # Group files by patient
    files_by_patient = defaultdict(list)
    for filepath in result_files:
        # Extract patient ID from filename
        filename = os.path.basename(filepath)
        parts = filename.split('_')
        if len(parts) >= 4 and parts[0] == "patient":
            pid = parts[1]
            files_by_patient[pid].append(filepath)

    print(f"Found data for {len(files_by_patient)} patients: {', '.join(files_by_patient.keys())}")

    # Create PatientData objects for each patient
    patient_data_objects = {}
    for pid in files_by_patient:
        try:
            print(f"Creating PatientData object for patient {pid}...")
            patient_data_objects[pid] = PatientData(pid=pid)
        except Exception as e:
            # Best effort: report the failure and carry on with the others
            print(f"Error creating PatientData object for patient {pid}: {e}")
            continue

    # Create SingleResultsManager objects for each patient
    managers = {}

    for pid, files in files_by_patient.items():
        if pid not in patient_data_objects:
            print(f"Skipping patient {pid} due to missing PatientData object")
            continue

        patient_data = patient_data_objects[pid]

        # Create an empty manager
        manager = SingleResultsManager(
            patient_data=patient_data,
            concept_items=[],  # We'll populate this from the results
            epoch='movie',
            classifier=LinearSVC(random_state=42),
            standardize=True,
            pseudo=True  # Match the settings used for original decoding
        )

        # Organize results by group; the keys double as the list of unique
        # concept pairs, in the order they are first encountered
        results_by_group = defaultdict(list)

        print(f"Loading {len(files)} result files for patient {pid}...")
        for filepath in tqdm(files):
            try:
                # Load the result data
                result_data = joblib.load(filepath)

                # Extract the group_pair and DecodingResult
                group_pair = result_data['group_pair']
                result = result_data['result']

                # Add to results_by_group
                results_by_group[group_pair].append(result)
            except Exception as e:
                # Best effort: a single unreadable file should not abort the load
                print(f"Error loading {filepath}: {e}")
                continue

        concept_pairs = list(results_by_group)

        # Update manager's attributes
        manager.concept_items = concept_pairs
        manager.results = dict(results_by_group)  # Convert defaultdict to dict

        # Add to the managers dictionary
        managers[pid] = manager

        print(f"Created SingleResultsManager for patient {pid} with {len(concept_pairs)} concept groups")

    return managers


In [35]:
# Location of the finished run whose saved results are analysed below
RESULTS_DIR = "./decoding_res_4x4_april9"
run_dir = "run_20250409_153314"



# Despite its name, `patient_dict` holds the return value of
# load_decoding_results: patient id -> SingleResultsManager (or None if the
# directory has no result files).
patient_dict = load_decoding_results(os.path.join(RESULTS_DIR, run_dir))

Loading results from ./decoding_res_4x4_april9/run_20250409_153314...
Found 1048 result files
Loaded configuration from run 20250409_153314
Found data for 3 patients: 562, 563, 566
Creating PatientData object for patient 562...
./Data/40m_act_24_S06E01_30fps_character_frames.csv
Creating PatientData object for patient 563...
./Data/40m_act_24_S06E01_30fps_character_frames.csv
Creating PatientData object for patient 566...
./Data/40m_act_24_S06E01_30fps_character_frames.csv
Loading 350 result files for patient 562...


100%|██████████| 350/350 [00:00<00:00, 667.68it/s]


Created SingleResultsManager for patient 562 with 35 concept groups
Loading 350 result files for patient 563...


100%|██████████| 350/350 [00:00<00:00, 611.96it/s]


Created SingleResultsManager for patient 563 with 35 concept groups
Loading 348 result files for patient 566...


100%|██████████| 348/348 [00:00<00:00, 672.18it/s]

Created SingleResultsManager for patient 566 with 35 concept groups





In [36]:
# Convenience handles on the per-patient entries of the loaded results
p562_manager = patient_dict['562']
p563_manager = patient_dict['563']
p566_manager = patient_dict['566']



In [None]:
def _dichotomy_stats(results_for_key, metric, number):
    """Return ``(mean, std)`` of ``metric`` over the results, or None on failure.

    ``number`` is the 1-based dichotomy number used in the printed messages.
    """
    try:
        values = [getattr(r, metric) for r in results_for_key]
        # Both statistics are computed before anything is recorded, so a
        # failure can never leave a mean without its matching std.
        return np.mean(values), np.std(values)
    except AttributeError:
        print(f"Error: Metric '{metric}' not found for dichotomy {number}. Plotting NaN.")
    except Exception as e:
        print(f"Error processing results for dichotomy {number}: {e}")
    return None


def _draw_dichotomy_key(ax_text, key_strings, max_single_column=20):
    """Write the numbered key lines onto ``ax_text``, in two columns if there are many."""
    num_lines = len(key_strings)
    if num_lines > max_single_column:
        # Too many lines for one column: split roughly in half, smaller font
        split_point = (num_lines + 1) // 2
        columns = [(0.01, key_strings[:split_point]), (0.51, key_strings[split_point:])]
        lines_per_column = split_point
        fontsize = 7
    else:
        columns = [(0.01, key_strings)]
        lines_per_column = num_lines
        fontsize = 9

    y_start = 0.95
    y_step = 1.0 / (lines_per_column + 1)
    for x_pos, lines in columns:
        for row, line in enumerate(lines):
            ax_text.text(x_pos, y_start - row * y_step, line,
                         ha='left', va='top', fontsize=fontsize, family='monospace')


def plot_dichotomy_performance_with_key(results_manager, dichotomy_list, metric='test_roc_auc', figsize=(15, 12), title_suffix=""):
    """
    Plots decoding performance for dichotomies (1-N) and adds a text key below
    mapping numbers to the actual group comparisons.

    Dichotomies with no (or empty) results, or whose metric cannot be read,
    are drawn as zero-height bars without error bars and reported on stdout.

    Args:
        results_manager: A SingleResultsManager object that has run decoding
                         on the items in dichotomy_list.
        dichotomy_list: The list of group-vs-group tuples exactly as used
                        when running the results_manager. The order determines
                        the x-axis order (1-N).
        metric (str): The performance metric from DecodingResult to plot.
        figsize (tuple): Figure size for the *entire* plot (bars + text key).
        title_suffix (str): Optional text to append to the plot title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if not results_manager.results:
        print("Error: results_manager has no results. Did you run run_decoding?")
        return None

    expected_num_dichotomies = len(dichotomy_list)
    print(f"Expecting {expected_num_dichotomies} dichotomies based on input list.")

    performance_means = []
    performance_stds = []
    dichotomy_labels_numeric = [str(i + 1) for i in range(expected_num_dichotomies)]
    key_strings = []  # Formatted lines for the text key

    found_count = 0
    missing_keys_indices = []

    # --- Data Processing ---
    for number, dichotomy_key in enumerate(dichotomy_list, start=1):
        # The key line is produced whether or not results are present
        group1, group2 = dichotomy_key
        key_strings.append(f"{number}: ({'+'.join(group1)}) vs ({'+'.join(group2)})")

        results_for_key = None
        if dichotomy_key in results_manager.results:
            results_for_key = results_manager.results[dichotomy_key]

        stats = None
        if results_for_key:
            stats = _dichotomy_stats(results_for_key, metric, number)
        else:
            missing_keys_indices.append(number)

        if stats is None:
            performance_means.append(np.nan)
            performance_stds.append(np.nan)
        else:
            mean_value, std_value = stats
            performance_means.append(mean_value)
            performance_stds.append(std_value)
            found_count += 1

    print(f"Processed results for {found_count}/{expected_num_dichotomies} dichotomies.")
    if missing_keys_indices:
         print(f"Missing or empty results for dichotomies (numbered 1 to {expected_num_dichotomies}): {sorted(set(missing_keys_indices))}")

    # --- Plotting ---
    fig = plt.figure(figsize=figsize)

    # Two stacked panels: the bar chart on top is three times as tall as the
    # text key below it.
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

    # --- Top Subplot: Bar Chart ---
    ax_bar = fig.add_subplot(gs[0])

    x_positions = np.arange(expected_num_dichotomies)
    plot_stds = np.array(performance_stds, dtype=float)
    plot_stds[np.isnan(performance_means)] = 0  # Avoid error bars for NaN means

    ax_bar.bar(x_positions,
               np.nan_to_num(performance_means, nan=0.0),
               yerr=plot_stds,
               align='center',
               alpha=0.75,
               ecolor='black',
               capsize=4)

    ax_bar.set_ylabel(f'{metric.replace("_", " ").title()}')
    # No x-label on the bar chart; the key below explains the numbers
    ax_bar.set_xlabel('')
    ax_bar.set_xticks(x_positions)
    ax_bar.set_xticklabels(dichotomy_labels_numeric, rotation=90, fontsize=8)

    if "roc_auc" in metric.lower() or "accuracy" in metric.lower():
         ax_bar.set_ylim(0.0, 1.05)
         ax_bar.axhline(0.5, color='grey', linestyle='--', linewidth=0.8, label='Chance (0.5)')
         ax_bar.legend(loc='lower right')

    ax_bar.grid(axis='y', linestyle=':', linewidth=0.5)  # Horizontal grid lines

    # --- Bottom Subplot: Text Key ---
    ax_text = fig.add_subplot(gs[1])
    ax_text.axis('off')  # Hide axes lines and ticks
    _draw_dichotomy_key(ax_text, key_strings)

    # --- Overall Figure Title ---
    patient_id = results_manager.patient_data.pid
    epoch = results_manager.epoch
    base_title = f'Group Decoding Performance for {expected_num_dichotomies} Dichotomies'
    full_title = f'{base_title}\nPatient {patient_id}, Epoch: {epoch}'
    if title_suffix:
        full_title += f" - {title_suffix}"
    fig.suptitle(full_title, y=0.99)  # Adjust y if title overlaps top plot

    # Leave room for the title; rect might need tuning depending on title length
    plt.tight_layout(rect=[0, 0.03, 1, 0.96])
    plt.show()

In [None]:
# Per-dichotomy performance for patient 562; x-axis order follows stable_groups
plot_dichotomy_performance_with_key(p562_manager, stable_groups)

In [None]:
# Per-dichotomy performance for patient 563; x-axis order follows stable_groups
plot_dichotomy_performance_with_key(p563_manager, stable_groups)

In [None]:
# Per-dichotomy performance for patient 566; x-axis order follows stable_groups
plot_dichotomy_performance_with_key(p566_manager, stable_groups)