In [6]:
# function file
from data_structures import PatientData
from sklearn.base import BaseEstimator
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from typing import Tuple
from sklearn.model_selection import train_test_split
import numpy as np
from dataclasses import dataclass
from typing import Dict, List
from sklearn.metrics import accuracy_score, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from decoders import generate_pseudopopulations, DecodingResult
from multi_decoding import MultiResultsManager, plot_multi_patient_heatmap

In [7]:
p566 = PatientData(pid='566')
p563 = PatientData(pid='563')
p562 = PatientData(pid='562')

./Data/40m_act_24_S06E01_30fps_character_frames.csv
./Data/40m_act_24_S06E01_30fps_character_frames.csv
./Data/40m_act_24_S06E01_30fps_character_frames.csv


In [8]:
THRESHOLD = 0.1


p566_fr_neurons = p566.filter_neurons_by_fr(neurons=p566.neurons, window=(p566.times_dict['movie_start_rel'], p566.times_dict['preSleep_recall_start_rel']), threshold=THRESHOLD)
p566_mtl_fr_neurons = p566.filter_mtl_neurons(neurons=p566_fr_neurons)

p563_fr_neurons = p563.filter_neurons_by_fr(neurons=p563.neurons, window=(p563.times_dict['movie_start_rel'], p563.times_dict['preSleep_recall_start_rel']), threshold=THRESHOLD)
p563_mtl_fr_neurons = p563.filter_mtl_neurons(neurons=p563_fr_neurons)

p562_fr_neurons = p562.filter_neurons_by_fr(neurons=p562.neurons, window=(p562.times_dict['movie_start_rel'], p562.times_dict['preSleep_recall_start_rel']), threshold=THRESHOLD)
p562_mtl_fr_neurons = p562.filter_mtl_neurons(neurons=p562_fr_neurons)

neurons_list = [p562_mtl_fr_neurons, p563_mtl_fr_neurons, p566_mtl_fr_neurons]

In [9]:
p562_mtl_fr_neurons

[<data_structures.Neuron at 0x31e4f7790>,
 <data_structures.Neuron at 0x31e4f7950>,
 <data_structures.Neuron at 0x31e4f7a90>]

In [11]:
concept_pairs_to_decode = []

selected_concepts = ['A.Amar',
  'A.Fayed',
  'B.Buchanan',
  'C.Manning',
  'C.OBrian',
  'J.Bauer',
  'K.Hayes',
  'M.OBrian',
  'N.Yassir',
  'R.Wallace',
  'T.Lennox',
]

for i, concept1 in enumerate(selected_concepts[:]):
    for concept2 in selected_concepts[i+1:]: #avoid duplicates and self-pairs
        concept_pairs_to_decode.append((concept1, concept2))

print(f"Number of concept pairs to decode: {len(concept_pairs_to_decode)}")
print(concept_pairs_to_decode[:3]) # Print first 5 pairs as example


Number of concept pairs to decode: 55
[('A.Amar', 'A.Fayed'), ('A.Amar', 'B.Buchanan'), ('A.Amar', 'C.Manning')]


In [None]:
multi_mtl_manager = MultiResultsManager(
    patient_data_list=[p562, p563, p566],
    concept_pairs=concept_pairs_to_decode,
    epoch='movie',
    standardize=False,
    pseudo=True,
    neurons_list=[p562_mtl_fr_neurons, p563_mtl_fr_neurons, p566_mtl_fr_neurons]
)

In [None]:
multi_mtl_manager.run_decoding_for_pairs(num_iter=5)


In [None]:
best_concepts = [
    "A.Fayed", 
    "R.Wallace", 
    "T.Lennox", 
    "N.Yassir", 
    "K.Hayes", 
    "M.OBrian", 
    "J.Bauer", 
    "C.Manning"
]

In [None]:
plot_multi_patient_heatmap(multi_mtl_manager, metric='test_accuracy', selected_concepts=best_concepts, show_numbers=False)
plt.suptitle('Multi-Patient Psuedopopulation Character Decoding Performance\nPatients: 562, 563, 566\nMTL neurons only, above 0.1Hz Firing Rate', fontsize=17)
#plt.savefig('mtl_multipatient_without_acc')
#


# All areas, neurons above 0.1hz decoding

In [12]:
multi_all_manager = MultiResultsManager(
    patient_data_list=[p562, p563, p566],
    concept_pairs=concept_pairs_to_decode,
    epoch='movie',
    standardize=False,
    pseudo=True,
    neurons_list=[p562_fr_neurons, p563_fr_neurons, p566_fr_neurons]
)

In [None]:
multi_all_manager.run_decoding_for_pairs(num_iter=5)


In [None]:
plot_multi_patient_heatmap(multi_all_manager, metric='test_accuracy', selected_concepts=best_concepts, show_numbers=False, center=0.5)
plt.suptitle('Multi-Patient Psuedopopulation Character Decoding Performance\nPatients: 562, 563, 566\nAll neurons above 0.1Hz Firing Rate', fontsize=17)
#plt.savefig('multipatient_without_accuracy_full_spectrum')
