### CCGP notebook between individual concepts

### TODO:

- Save the decoding-results classes to disk so the decoders don't have to be re-run every time.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from data_structures import PatientData
from decoders import ConceptDecoder, SingleResultsManager
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, roc_auc_score
import os
import datetime
import joblib
import seaborn as sns
from tqdm import tqdm

# One PatientData object per patient, built in the order 566, 563, 562.
p566, p563, p562 = (PatientData(pid=patient_id) for patient_id in ('566', '563', '562'))

# Firing-rate cutoff passed to filter_neurons_by_fr when selecting neurons.
# NOTE(review): units (presumably Hz) are not visible here — confirm.
THRESHOLD = 0.1

In [None]:
def _filter_by_movie_firing_rate(patient):
    """Return the patient's neurons that pass the firing-rate threshold.

    The firing rate is evaluated over the window running from the movie
    start to the start of the pre-sleep recall, both read from the
    patient's ``times_dict``.
    """
    window = (
        patient.times_dict['movie_start_rel'],
        patient.times_dict['preSleep_recall_start_rel'],
    )
    return patient.filter_neurons_by_fr(
        neurons=patient.neurons,
        window=window,
        threshold=THRESHOLD,
    )


# Optional MTL-only restriction, currently disabled for every patient:
#   pXXX_mtl_fr_neurons = pXXX.filter_mtl_neurons(neurons=pXXX_fr_neurons)
p566_fr_neurons = _filter_by_movie_firing_rate(p566)

p563_fr_neurons = _filter_by_movie_firing_rate(p563)

p562_fr_neurons = _filter_by_movie_firing_rate(p562)

In [None]:
# Concept labels for which pairwise decoders are built.
selected_concepts = [
    'A.Amar',
    'A.Fayed',
    'B.Buchanan',
    'C.Manning',
    'C.OBrian',
    'J.Bauer',
    'K.Hayes',
    'M.OBrian',
    'N.Yassir',
    'R.Wallace',
    'T.Lennox',
]

# Every unordered pair exactly once: each concept is combined only with the
# concepts that follow it in the list, which rules out self-pairs and
# mirrored duplicates.
concept_pairs_to_decode = [
    (first, second)
    for position, first in enumerate(selected_concepts)
    for second in selected_concepts[position + 1:]
]

print(f"Number of concept pairs to decode: {len(concept_pairs_to_decode)}")
print(concept_pairs_to_decode[:3])  # first three pairs as a sanity check


### Train decoders (no longer needs to be run; skip to the loading step)

In [None]:
# Number of decoding iterations per concept pair. Defined before use so that
# every patient is run with the same value that is later written to the
# saved run configuration.
NUM_ITER = 5


def _build_movie_manager(patient, neurons):
    """Create a SingleResultsManager for one patient with the shared settings.

    All patients use the same concept pairs, the 'movie' epoch,
    standardization and pseudo mode; only the patient and its filtered
    neuron list differ.
    """
    return SingleResultsManager(
        patient_data=patient,
        concept_items=concept_pairs_to_decode,
        epoch='movie',
        standardize=True,
        pseudo=True,
        neurons=neurons,  # no extra kwargs are passed -> manager defaults apply
    )


p562_manager = _build_movie_manager(p562, p562_fr_neurons)
p563_manager = _build_movie_manager(p563, p563_fr_neurons)
p566_manager = _build_movie_manager(p566, p566_fr_neurons)

# Patient 562 previously used a hard-coded 5 instead of NUM_ITER; all three
# now share the constant so changing NUM_ITER affects every patient.
for _manager in (p562_manager, p563_manager, p566_manager):
    _manager.run_decoding_for_pairs(num_iter=NUM_ITER)

### Save decoding-results classes to .pkl files for later loading

In [None]:
# Root folder collecting every saved run of this analysis.
RESULTS_DIR = "./concept_pairs_single_patient_april16"

# A timestamp serves as the unique ID of this run and names its sub-folder.
RUN_ID = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
RUN_DIR = os.path.join(RESULTS_DIR, f"run_{RUN_ID}")

# makedirs creates missing parents, so this also creates RESULTS_DIR.
os.makedirs(RUN_DIR, exist_ok=True)

# Settings of this run, stored next to the results for later reference.
config = dict(
    concepts=selected_concepts,
    threshold=THRESHOLD,
    num_iterations=NUM_ITER,
    patient_ids=['562', '563', '566'],
    timestamp=RUN_ID,
)
config_path = os.path.join(RUN_DIR, "config.pkl")
joblib.dump(config, config_path)


In [None]:
# (patient ID, trained manager, neurons used for that manager)
patients_to_save = (
    ('562', p562_manager, p562_fr_neurons),
    ('563', p563_manager, p563_fr_neurons),
    ('566', p566_manager, p566_fr_neurons),
)

for pid, manager, neurons in patients_to_save:
    print(f"Saving results for patient {pid}...")

    # Each patient gets its own sub-folder inside the run directory.
    patient_dir = os.path.join(RUN_DIR, f"patient_{pid}")
    os.makedirs(patient_dir, exist_ok=True)

    # File name -> object to dump. The manager is stored as separate pieces
    # (settings, concept pairs, results, neuron IDs) rather than as a whole;
    # only neuron IDs are kept instead of the neuron objects themselves.
    payloads = {
        "manager_params.pkl": {
            'epoch': manager.epoch,
            'standardize': manager.standardize,
            'pseudo': manager.pseudo,
            'pseudo_params': manager.pseudo_params,
        },
        "concept_items.pkl": manager.concept_items,
        "results.pkl": manager.results,
        "neuron_ids.pkl": [neuron.neuron_id for neuron in neurons],
    }
    for file_name, payload in payloads.items():
        joblib.dump(payload, os.path.join(patient_dir, file_name))

print(f"Results saved to {RUN_DIR}")


### Load the decoding results files - TODO

In [None]:
def load_results_manager(run_dir, pid, patient_data):
    """
    Load saved results and reconstruct a SingleResultsManager.

    Reads ``manager_params.pkl``, ``concept_items.pkl``, ``neuron_ids.pkl``
    and (when present) ``results.pkl`` from ``<run_dir>/patient_<pid>``.

    NOTE: joblib files are pickle-based, so loading them can execute
    arbitrary code. Only point this at run directories you created yourself.

    Args:
        run_dir: Directory containing the saved run
        pid: Patient ID to load (used to build the ``patient_<pid>`` folder name)
        patient_data: PatientData object for the patient

    Returns:
        Reconstructed SingleResultsManager. Its ``results`` attribute is
        replaced by the saved results when ``results.pkl`` exists; otherwise
        a message is printed and the manager is returned as constructed.
    """
    patient_dir = os.path.join(run_dir, f"patient_{pid}")
    print(f"Loading results for patient {pid} from {patient_dir}...")

    # Load manager parameters
    manager_params = joblib.load(os.path.join(patient_dir, "manager_params.pkl"))

    # Load concept_items (pairs)
    concept_items = joblib.load(os.path.join(patient_dir, "concept_items.pkl"))

    # Load the IDs of the neurons that were used originally
    neuron_ids = joblib.load(os.path.join(patient_dir, "neuron_ids.pkl"))

    # Membership is tested once per neuron, so use a set for O(1) lookups
    # instead of scanning the ID list each time. Fall back to the list if
    # the saved IDs turn out not to be hashable.
    try:
        wanted_ids = set(neuron_ids)
    except TypeError:
        wanted_ids = neuron_ids

    # Keep the neurons in the order of patient_data.neurons.
    neurons = [n for n in patient_data.neurons if n.neuron_id in wanted_ids]

    if len(neurons) != len(neuron_ids):
        print(f"Warning: Could only find {len(neurons)} of {len(neuron_ids)} original neurons")

    # 'pseudo_params' may be missing or None; both mean "no extra kwargs".
    pseudo_kwargs = manager_params.get('pseudo_params') or {}

    # Create a new manager with the reconstructed neuron list
    manager = SingleResultsManager(
        patient_data=patient_data,
        concept_items=concept_items,
        epoch=manager_params['epoch'],
        standardize=manager_params['standardize'],
        pseudo=manager_params['pseudo'],
        neurons=neurons,
        **pseudo_kwargs
    )

    # Attach the saved results instead of re-running the decoders
    result_path = os.path.join(patient_dir, "results.pkl")
    if os.path.exists(result_path):
        manager.results = joblib.load(result_path)
        print(f"  Successfully loaded results with {len(manager.results)} concept pairs")
    else:
        print(f"  Results file not found at {result_path}")

    return manager


In [None]:
# Sanity check of the loader: rebuild all three managers from one saved run.
SAVED_RUN_DIR = "./concept_pairs_single_patient_april16/run_20250416_134735"

load_p562_manager = load_results_manager(SAVED_RUN_DIR, pid=562, patient_data=p562)
load_p563_manager = load_results_manager(SAVED_RUN_DIR, pid=563, patient_data=p563)
load_p566_manager = load_results_manager(SAVED_RUN_DIR, pid=566, patient_data=p566)


In [None]:
def compare_managers(original_manager, loaded_manager):
    """Print a checklist showing whether two managers look equivalent.

    Compares the number and identity of the result keys (concept pairs),
    the patient ID, the epoch, the standardize and pseudo settings, and
    the number of neurons. Only prints; nothing is returned.
    """
    before, after = original_manager, loaded_manager

    # Insertion order determines the order of the printed report.
    checks = {}
    checks["Same number of concept pairs"] = len(before.results) == len(after.results)
    checks["Same patient ID"] = before.patient_data.pid == after.patient_data.pid
    checks["Same epochs"] = before.epoch == after.epoch
    checks["Same standardize setting"] = before.standardize == after.standardize
    checks["Same pseudo setting"] = before.pseudo == after.pseudo
    checks["Same number of neurons"] = len(before.neurons) == len(after.neurons)
    # Key sets are compared, so the order of the pairs does not matter.
    checks["Same concept pairs"] = set(before.results.keys()) == set(after.results.keys())

    print("Manager comparison results:")
    for description, passed in checks.items():
        mark = '✓' if passed else '✗'
        print(f"  {description}: {mark}")

    if not all(checks.values()):
        print("Managers have some differences.")
        return
    print("Managers appear to be equivalent!")


In [None]:
# check - ran this, was successful
#compare_managers(p562_manager, load_p562_manager)
#compare_managers(p563_manager, load_p563_manager)
#compare_managers(p566_manager, load_p566_manager)


### CCGP

In [None]:
def _ccgp_score(classifier, X_test, y_test, metric):
    """Score one trained classifier on one held-out dataset."""
    if metric == 'accuracy':
        return accuracy_score(y_test, classifier.predict(X_test))

    # ROC AUC is a ranking metric and needs continuous scores. Feeding it the
    # hard class predictions collapses the ROC curve to a single operating
    # point, so prefer the classifier's decision values or probabilities and
    # only fall back to predict() when neither is available.
    if hasattr(classifier, 'decision_function'):
        y_score = classifier.decision_function(X_test)
    elif hasattr(classifier, 'predict_proba'):
        # assumes a binary problem with the positive class in column 1
        y_score = classifier.predict_proba(X_test)[:, 1]
    else:
        y_score = classifier.predict(X_test)
    return roc_auc_score(y_test, y_score)


def compute_ccgp_matrix(manager, metric='accuracy'):
    """
    Compute CCGP matrix for all concept pairs.

    Entry ``[i, j]`` summarizes how the decoders trained on concept pair
    ``i`` perform on the stored test sets of concept pair ``j``. Every
    classifier of the training pair is applied to every test set of the
    test pair, and the mean and standard deviation over all combinations
    are stored. The diagonal is not recomputed: it is filled with the test
    performance already stored on the results of that pair.

    Args:
        manager: SingleResultsManager with trained decoders. Each entry of
            ``manager.results`` is iterated as a collection of result
            objects exposing ``classifier``, ``data['X_test']``,
            ``data['y_test']``, ``test_accuracy`` and ``test_roc_auc``.
        metric: Which metric to use for CCGP performance ('accuracy' or
            'roc_auc'). Any value other than 'accuracy' is treated as
            'roc_auc'.

    Returns:
        ccgp_mean: Mean CCGP performance matrix (n_pairs x n_pairs)
        ccgp_std: Standard deviation of CCGP performance matrix
        concept_pairs: List of concept pairs in the same order as matrix rows/columns
    """
    # Get all concept pairs with results
    concept_pairs = list(manager.results.keys())
    n_pairs = len(concept_pairs)

    # Initialize result matrices
    ccgp_mean = np.zeros((n_pairs, n_pairs))
    ccgp_std = np.zeros((n_pairs, n_pairs))

    use_accuracy = metric == 'accuracy'

    # Rows: pair the decoders were trained on. Columns: pair they are tested on.
    for i, train_pair in enumerate(tqdm(concept_pairs, desc="Computing CCGP Matrix")):
        train_results = manager.results[train_pair]

        for j, test_pair in enumerate(concept_pairs):
            if i == j:
                # Same pair: reuse the stored within-pair test performance.
                performances = [
                    result.test_accuracy if use_accuracy else result.test_roc_auc
                    for result in train_results
                ]
            else:
                # Different pair: apply each trained classifier to each
                # stored test set of the other pair.
                test_results = manager.results[test_pair]
                performances = [
                    _ccgp_score(
                        train_result.classifier,
                        test_result.data['X_test'],
                        test_result.data['y_test'],
                        metric,
                    )
                    for train_result in train_results
                    for test_result in test_results
                ]

            # Store mean and std of performances
            ccgp_mean[i, j] = np.mean(performances)
            ccgp_std[i, j] = np.std(performances)

    return ccgp_mean, ccgp_std, concept_pairs

In [None]:
def plot_ccgp_matrix(ccgp_mean, ccgp_std, concept_pairs, patient_id, metric='accuracy', 
                     figsize=(20, 16), show_values=False):
    """
    Plot CCGP matrix as a heatmap.

    Args:
        ccgp_mean: Mean CCGP performance matrix
        ccgp_std: Standard deviation of CCGP performance matrix
        concept_pairs: List of concept pairs in same order as matrix rows/columns
        patient_id: Patient ID for title
        metric: Which metric was used ('accuracy' or 'roc_auc')
        figsize: Figure size
        show_values: Whether to annotate cells with mean±std values

    Returns:
        The matplotlib Axes the heatmap was drawn on.
    """
    # Build one tick label per concept pair.
    pair_labels = []
    n_groups = 0
    for pair in concept_pairs:
        if isinstance(pair[0], str) and isinstance(pair[1], str):
            # Regular concept pair. The full concept names are used because
            # the part before the first '.' is only an initial for names such
            # as 'A.Amar' / 'A.Fayed', which would give ambiguous labels like
            # "A vs A".
            label = f"{pair[0]} vs {pair[1]}"
        else:
            # Group pair - number the groups consecutively. (Counting the
            # literal string 'Group' among the labels never matched
            # 'Group1', so every group used to be labelled "Group1".)
            n_groups += 1
            label = f"Group{n_groups}"
        pair_labels.append(label)

    plt.figure(figsize=figsize)
    ax = plt.gca()

    # Plot heatmap on the axes created above; the color scale spans chance
    # level (0.5) to perfect performance (1.0).
    sns.heatmap(ccgp_mean, cmap='viridis', vmin=0.5, vmax=1.0, 
                xticklabels=pair_labels, yticklabels=pair_labels, 
                cbar_kws={'label': f'CCGP {metric.capitalize()}'},
                ax=ax)

    # Annotate with values if requested
    if show_values:
        for i in range(len(concept_pairs)):
            for j in range(len(concept_pairs)):
                text = f"{ccgp_mean[i, j]:.2f}\n±{ccgp_std[i, j]:.2f}"
                # viridis runs from dark (low) to bright (high), so cells in
                # the upper half of the scale need dark text and cells in the
                # lower half need light text to stay readable.
                plt.text(j + 0.5, i + 0.5, text, ha='center', va='center', 
                         color='black' if ccgp_mean[i, j] > 0.75 else 'white',
                         fontsize=8)

    plt.title(f'CCGP Matrix - Patient {patient_id} (Train row → Test column)', fontsize=14)
    plt.xlabel('Test Concept Pair', fontsize=12)
    plt.ylabel('Train Concept Pair', fontsize=12)
    plt.xticks(rotation=90)
    plt.tight_layout()

    return ax

In [None]:
# CCGP matrices (mean, std, pair order) for each patient, computed from the
# managers that were reloaded from disk, using accuracy as the metric.
p562_ccgp_mean, p562_ccgp_std, p562_concept_pairs = compute_ccgp_matrix(
    load_p562_manager, metric='accuracy'
)
p563_ccgp_mean, p563_ccgp_std, p563_concept_pairs = compute_ccgp_matrix(
    load_p563_manager, metric='accuracy'
)
p566_ccgp_mean, p566_ccgp_std, p566_concept_pairs = compute_ccgp_matrix(
    load_p566_manager, metric='accuracy'
)
