In [None]:
import os
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances
from tqdm import tqdm

# Import project-specific modules
from data_structures import PatientData
from decoders import ConceptDecoder, SingleResultsManager
from multi_decoding import MultiResultsManager, plot_multi_patient_heatmap

In [None]:
# Construct one PatientData object per patient ID, in the order 562, 563, 566.
p562, p563, p566 = (PatientData(pid=pid) for pid in ('562', '563', '566'))

# Firing-rate ("fr") threshold handed to `filter_neurons_by_fr` in the next cell.
THRESHOLD = 0.1


In [None]:
def _filter_active_mtl_neurons(patient, threshold):
    """Return ``(active, active_mtl)`` neuron collections for one patient.

    ``active`` is ``patient.neurons`` passed through
    ``patient.filter_neurons_by_fr`` with ``threshold`` over the window
    ``(times_dict['movie_start_rel'], times_dict['preSleep_recall_start_rel'])``.
    ``active_mtl`` is ``active`` further passed through
    ``patient.filter_mtl_neurons``.

    Args:
        patient: a ``PatientData`` instance.
        threshold: firing-rate threshold forwarded to ``filter_neurons_by_fr``.
    """
    window = (
        patient.times_dict['movie_start_rel'],
        patient.times_dict['preSleep_recall_start_rel'],
    )
    active = patient.filter_neurons_by_fr(
        neurons=patient.neurons,
        window=window,
        threshold=threshold,
    )
    return active, patient.filter_mtl_neurons(neurons=active)


# Same filtering for every patient; the per-patient names are kept because
# later cells read them directly.
p562_fr_neurons, p562_mtl_fr_neurons = _filter_active_mtl_neurons(p562, THRESHOLD)
p563_fr_neurons, p563_mtl_fr_neurons = _filter_active_mtl_neurons(p563, THRESHOLD)
p566_fr_neurons, p566_mtl_fr_neurons = _filter_active_mtl_neurons(p566, THRESHOLD)

# Create lists for multi-patient analysis; entries must stay index-aligned
# (neurons_list[i] belongs to patient_data_list[i]).
neurons_list = [p562_mtl_fr_neurons, p563_mtl_fr_neurons, p566_mtl_fr_neurons]
patient_data_list = [p562, p563, p566]


In [None]:
# Per-patient summary: neurons surviving the fr filter vs. the MTL subset of those.
for pid, active, active_mtl in (
    ('562', p562_fr_neurons, p562_mtl_fr_neurons),
    ('563', p563_fr_neurons, p563_mtl_fr_neurons),
    ('566', p566_fr_neurons, p566_mtl_fr_neurons),
):
    print(f"Patient {pid}: {len(active_mtl)} MTL neurons (from {len(active)} active neurons)")


In [None]:
# Concept labels whose pairwise combinations are decoded below.
selected_concepts = """
    A.Amar      A.Fayed     B.Buchanan  C.Manning
    C.OBrian    J.Bauer     K.Hayes     M.OBrian
    N.Yassir    R.Wallace   T.Lennox
""".split()


In [None]:
# Every unordered pair of distinct concepts, in the same order as a nested
# i < j loop. combinations() only avoids self-pairs if the labels are unique,
# so guard that assumption explicitly.
assert len(set(selected_concepts)) == len(selected_concepts), "selected_concepts contains duplicates"
concept_pairs_to_decode = list(combinations(selected_concepts, 2))

print(f"Number of concept pairs to decode: {len(concept_pairs_to_decode)}")


In [None]:
# Decode every concept pair from the filtered MTL neurons of the patients in
# `patient_data_list`, using pseudopopulations to balance the dataset.
multi_mtl_manager = MultiResultsManager(
    patient_data_list=patient_data_list,
    neurons_list=neurons_list,
    concept_pairs=concept_pairs_to_decode,
    epoch='movie',
    pseudo=True,
    standardize=False,
)

# Repeat decoding for robustness; raise this for more stable results.
num_iterations = 3
multi_mtl_manager.run_decoding_for_pairs(num_iter=num_iterations)


In [None]:
# Narrow `patient_data_list` to patient 562 alone.
# NOTE(review): this rebinds the module-level list that was built with all
# three patients. Any cell that reads `patient_data_list` afterwards (including
# a re-run of the pseudopopulation decoding cell) sees only p562, while
# `neurons_list` still holds three entries.
patient_data_list = [p562]

In [None]:
# Single-patient (562) decoding WITHOUT pseudopopulations.
# The patient and neuron lists are spelled out here so they stay index-aligned
# (one neuron collection per patient) and do not depend on whatever
# `patient_data_list` / `neurons_list` currently hold in the kernel.
no_pseudo_patients = [p562]
no_pseudo_neurons = [p562_mtl_fr_neurons]

multi_mtl_no_pseudo = MultiResultsManager(
    patient_data_list=no_pseudo_patients,
    concept_pairs=concept_pairs_to_decode,
    epoch='movie',
    standardize=False,
    pseudo=False,  # no pseudopopulation balancing
    neurons_list=no_pseudo_neurons,
)


In [None]:

# Run the no-pseudopopulation decoder. The positional argument is presumably
# `num_iter` (the keyword used for the pseudopopulation run) — TODO confirm
# against the `run_decoding_for_pairs` signature.
no_pseudo_num_iter = 1
multi_mtl_no_pseudo.run_decoding_for_pairs(no_pseudo_num_iter)