notebook for decoding

In [None]:
from data_structures import PatientData
from sklearn.base import BaseEstimator
from sklearn.svm import LinearSVC
from sklearn.preprocessing import StandardScaler
from typing import Tuple
from sklearn.model_selection import train_test_split
import numpy as np
from dataclasses import dataclass
from typing import Dict, List
from sklearn.metrics import accuracy_score, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns


In [2]:
p566 = PatientData(pid='566')
p563 = PatientData(pid='563')

./Data/40m_act_24_S06E01_30fps_character_frames.csv
./Data/40m_act_24_S06E01_30fps_character_frames.csv


In [3]:
@dataclass
class DecodingResult:
    """Container for decoding results and metrics"""
    test_accuracy: float
    train_accuracy: float
    test_roc_auc: float
    train_roc_auc: float
    train_samples: Dict[str, int]  # Number of samples for each concept in training
    test_samples: Dict[str, int]   # Number of samples for each concept in testing
    predictions: np.ndarray
    true_labels: np.ndarray
    classifier: BaseEstimator
    data: Dict

In [17]:
class ConceptDecoder:
    
    """
    Handles decoding for a single concept pair

    - design choice - will call dataset method in each decoder call?

    - problem is that generally we want an object associated with one dataset - this would require an input for pseudo or not
    in the class instantiation. however, we want fine control over pseudopops parameters, so this is less good.
    one potential solution is a params* dict, but thats complicated. 

    for consistent stuff

    maybe separate classes - instantiate dataset, get training into dict, input training dict into concept decoder?

    - add method for PCA visualization in 2D/3D
    """
    def __init__(self, patient_data: PatientData, c1: str, c2: str, epoch: str, classifier: BaseEstimator = LinearSVC(), dataset: ConceptPairDataset = None, standardize: bool=False):
        self.patient_data = patient_data
        self.c1 = c1
        self.c2 = c2
        self.epoch = epoch
        self.classifier = classifier

        self.scaler = StandardScaler() if standardize else None
        self.metrics = {}

    
        if not dataset:
            self.dataset = ConceptPairDataset( #type: ignore
                patient_data=self.patient_data,
                concept_pair=(self.c1, self.c2),
                epoch=self.epoch, 
                min_samples = 20
            )
        else:
            self.dataset = dataset

    def decode(self, test_size: float = 0.3) -> DecodingResult: 
        """
        Performs decoding on the concept pair using normal dataset
        
        Args:
            test_size: Fraction of data to use for testing
            
        Returns:
            DecodingResult containing metrics and predictions
        """
        try:
            data_dict, info = self.dataset.create_dataset_normal(test_size=test_size)
        except ValueError as e:
            print(f"Skipping concept pair {self.c1}, {self.c2}: {e}") # Inform user of skipped pair and reason
            return None # Return None to indicate decoding failure for this pair


        X_train = data_dict['X_train']
        X_test = data_dict['X_test']
        y_train = data_dict['y_train']
        y_test = data_dict['y_test']

        if self.scaler:
            X_train = self.scaler.fit_transform(X_train)
            X_test = self.scaler.transform(X_test)

        self.classifier.fit(X_train, y_train)
        
        # Get predictions
        y_train_pred = self.classifier.predict(X_train)
        y_pred = self.classifier.predict(X_test)

        # Calculate metrics for train and test
        train_accuracy = accuracy_score(y_train, y_train_pred)
        test_accuracy = accuracy_score(y_test, y_pred)
        train_roc_auc = roc_auc_score(y_train, y_train_pred) # or use decision_function for prob based ROC AUC if needed
        test_roc_auc = roc_auc_score(y_test, y_pred)

        
        train_samples = {
            self.c1: np.sum(y_train == 0),
            self.c2: np.sum(y_train == 1)
        }
        test_samples = {
            self.c1: np.sum(y_test == 0),
            self.c2: np.sum(y_test == 1)
        }

        return DecodingResult(
            train_accuracy=train_accuracy,
            train_roc_auc=train_roc_auc,
            test_accuracy=test_accuracy,
            test_roc_auc=test_roc_auc,
            train_samples=train_samples,
            test_samples=test_samples,
            predictions=y_pred,
            true_labels=y_test,
            classifier=self.classifier,
            data=data_dict
        )

    def decode_pseudo(self):
        raise NotImplementedError



        
    

In [5]:
class ConceptPairDataset():
    """
    Class to turn both concept bins - np.ndarrays shape (n_onsets, n_neurons) (each row is a response) into dataset with 

    2 methods - one with psuedopopulations

    
    """
    def __init__(self, patient_data: PatientData, concept_pair: Tuple[str, str], 
                 epoch: str, min_samples: int = 10):
        self.patient_data = patient_data
        self.c1, self.c2 = concept_pair
        self.epoch = epoch
        self.min_samples = min_samples

    def create_dataset_normal(self, test_size = 0.3):
        """
        Method to create dataset without pseudopops, liable to make unbalanced dataset

        Returns X_train, X_test, y_train, y_test, info: dict
        """
        c1_data, c2_data = self.patient_data.get_concept_data(c1=self.c1, c2=self.c2, epoch=self.epoch)

        print(f"c1 shape: {c1_data.shape[0]}, c2 shape: {c2_data.shape[0]}")

        if len(c1_data) < self.min_samples or len(c2_data) < self.min_samples:
            raise ValueError(f"Insufficient samples for {self.c1} vs {self.c2}")

        X = np.vstack([c1_data, c2_data])
        y = np.concatenate([np.zeros(len(c1_data)), np.ones(len(c2_data))])

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)

        info = {} #for any extra stuff we wanna pass through
        res_dict = {
            'X_train': X_train, 'X_test': X_test, 'y_test': y_test, 'y_train': y_train
        }

        return res_dict, info

    def create_dataset_pseudo(self):
        raise NotImplementedError


In [16]:
class DecodingResultsManager:
    """
    Manages decoding results for multiple concept pairs for a single patient and epoch.
    """
    def __init__(self, patient_data: PatientData, concept_pairs: List[Tuple[str, str]], epoch: str, classifier: BaseEstimator = LinearSVC(), standardize: bool = False):
        self.patient_data = patient_data
        self.concept_pairs = concept_pairs
        self.epoch = epoch
        self.classifier = classifier # Default classifier for all decoders, can be overridden
        self.standardize = standardize # Default standardization for all decoders
        self.results: Dict[Tuple[str, str], DecodingResult] = {} # Store results here, key is concept pair

    def run_decoding_for_pairs(self) -> None:
        """
        Runs decoding for all concept pairs provided in the constructor.
        Stores the DecodingResult in the self.results dictionary.
        """
        for c1, c2 in self.concept_pairs:
            decoder = ConceptDecoder(
                patient_data=self.patient_data,
                c1=c1,
                c2=c2,
                epoch=self.epoch,
                classifier=self.classifier,
                standardize=self.standardize
            )
            result = decoder.decode()
            if result is not None: # Only store if decode was successful (not None)
                self.results[(c1, c2)] = result

    def plot_train_test_performance_heatmap(self, metric='test_roc_auc', figsize=(20, 10)):
        """
        Generates and displays a combined heatmap of training and testing performance for all concept pairs.
        Metrics can be 'test_accuracy', 'train_accuracy', 'test_roc_auc', 'train_roc_auc'.
        """
        concepts = sorted(list(set([c for pair in self.concept_pairs for c in pair])))
        n_concepts = len(concepts)
        train_matrix = np.full((n_concepts, n_concepts), np.nan)
        test_matrix = np.full((n_concepts, n_concepts), np.nan)

        concept_to_idx = {concept: i for i, concept in enumerate(concepts)}

        for concept_pair, result in self.results.items():
            if result is not None:
                c1, c2 = concept_pair
                i, j = concept_to_idx[c1], concept_to_idx[c2]

                # Fix: Use the full attribute names
                if 'roc_auc' in metric:
                    train_value = result.train_roc_auc
                    test_value = result.test_roc_auc
                else:  # accuracy
                    train_value = result.train_accuracy
                    test_value = result.test_accuracy

                train_matrix[i, j] = train_value
                train_matrix[j, i] = train_value
                test_matrix[i, j] = test_value
                test_matrix[j, i] = test_value

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        heatmap_params = {
            'xticklabels': concepts,
            'yticklabels': concepts,
            'annot': True,
            'fmt': '.3f',
            'cmap': 'viridis',
            'vmin': 0.5, #set vmin and vmax to be consistent and centered
            'vmax': 1.0,
            'center': 0.75
        }

        sns.heatmap(train_matrix, ax=ax1, **heatmap_params)
        ax1.set_title(f'Training {metric.replace("test_", "").replace("_", " ").title()}') # Dynamic title
        ax1.set_xticklabels(ax1.get_xticklabels(), rotation=45, ha='right')

        sns.heatmap(test_matrix, ax=ax2, **heatmap_params)
        ax2.set_title(f'Test {metric.replace("test_", "").replace("_", " ").title()}') # Dynamic title
        ax2.set_xticklabels(ax2.get_xticklabels(), rotation=45, ha='right')

        for ax, matrix in [(ax1, train_matrix), (ax2, test_matrix)]:
            for i in range(len(concepts)):
                for j in range(len(concepts)):
                    if np.isnan(matrix[i, j]):
                        ax.text(j + 0.5, i + 0.5, 'N/A',
                               ha='center', va='center',
                               color='gray')
        plt.suptitle(f'Train vs Test Performance for Concept Decoding - Patient {self.patient_data.pid}, Epoch: {self.epoch}', y=1.05)
        plt.tight_layout()
        plt.show()

# All concept decoding!

# testing playground


:)

In [10]:
p563 = PatientData(pid='563')
decoder = ConceptDecoder(patient_data=p563, c1='A.Amar', c2='B.Buchanan', epoch='movie')
res = decoder.decode()


./Data/40m_act_24_S06E01_30fps_character_frames.csv
c1 shape: 20, c2 shape: 68
Skipping concept pair A.Amar, B.Buchanan: Insufficient samples for A.Amar vs B.Buchanan


In [None]:
res.accuracy

0.7407407407407407

In [18]:
concept_pairs_to_decode = []
concept_list = list(p566.movie_df.columns)
#remove non-concept columns if they exist
concept_list = [col for col in concept_list if col not in ['Frame', 'time_sec', 'rel_corrected_time_sec', 'uncorrected_time_sec']]


# doing to 5 for time

for i, concept1 in enumerate(concept_list[:5]):
    for concept2 in concept_list[i+1:5]: #avoid duplicates and self-pairs
        concept_pairs_to_decode.append((concept1, concept2))

print(f"Number of concept pairs to decode: {len(concept_pairs_to_decode)}")
print(concept_pairs_to_decode[:5]) # Print first 5 pairs as example

Number of concept pairs to decode: 10
[('A.Amar', 'A.Fayed'), ('A.Amar', 'B.Buchanan'), ('A.Amar', 'C.Manning'), ('A.Amar', 'C.OBrian'), ('A.Fayed', 'B.Buchanan')]


In [19]:
manager = DecodingResultsManager(
    patient_data=p566,
    concept_pairs=concept_pairs_to_decode,
    epoch='movie',
    standardize=True
)

In [20]:
manager.run_decoding_for_pairs()


c1 shape: 19, c2 shape: 51
Skipping concept pair A.Amar, A.Fayed: Insufficient samples for A.Amar vs A.Fayed
c1 shape: 20, c2 shape: 68
c1 shape: 20, c2 shape: 27
c1 shape: 19, c2 shape: 76
Skipping concept pair A.Amar, C.OBrian: Insufficient samples for A.Amar vs C.OBrian
c1 shape: 52, c2 shape: 63
c1 shape: 53, c2 shape: 26
c1 shape: 54, c2 shape: 71
c1 shape: 55, c2 shape: 10
Skipping concept pair B.Buchanan, C.Manning: Insufficient samples for B.Buchanan vs C.Manning
c1 shape: 62, c2 shape: 70
c1 shape: 27, c2 shape: 74


In [21]:
manager.plot_train_test_performance_heatmap(metric='test_roc_auc') # Example ROC AUC heatmap
manager.plot_train_test_performance_heatmap(metric='test_accuracy') # Example Accuracy heatmap

AttributeError: module 'matplotlib' has no attribute 'subplots'