# Neural Code Final Portfolio: Analysis of Neural Population Geometries

**Key terms:**
* **Context:** One of two distinct task structures. The correct stimulus-response pair is *inverted* between contexts.
* **Block:** Groups of sequential trials within a session. Each block corresponds to a switch in latent context, and each session has 10-16 trial blocks of 15-32 trials each.
* **Inference trials:** trials where a given stimulus is encountered for the first time after a context switch.

🔺 **Indicates unanswered gap between reimplementation and original paper.**

## Preliminaries

In [1]:
import re
import json
from joblib import Parallel, delayed
import time
import numpy as np
import pandas as pd
import itertools
import concurrent.futures

import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D

from scipy.stats import binom

from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances

import helper_functions as F

## Load and inspect datasets

### Behavioral and task data

Analysis were performed on 36 sessions with high accuracy on **non-inference** trials (e.g., suggests that patients understood the task). The following datasets should therefore contain 36 entries each:

* `beh_data` contains behavioral data organized by recording session, including `events`, which tracks the timing of specific events (erecording start, stimulus onset, response 1, response 2, recording end) for each session. 
* `task_data` contains information about trial-level task variables inlcuding `context` (1 or 2), `stim_sequence` (identity of the stimulus on a current trial, corresponding to specific [response, reward, context] combinations; *these are not aligned between sessions*), and `response_sequence` (the required response for the current trial).

The most important event codes `36` (response 1) and `31` (response 2), which are used in the `get_session_accuracy` method to see if the patient's true response matches the required `response_sequence` for that task.

In [None]:
with open("beh.json", "r") as f:
    all_beh_data = json.load(f)["beh"]

# Behavioral data
beh_data = pd.DataFrame(all_beh_data["data"])
beh_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36 entries, 0 to 35
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   sessionID  36 non-null     object
 1   events     36 non-null     object
dtypes: object(2)
memory usage: 708.0+ bytes


In [3]:
# Check top rows
beh_data.head()

Unnamed: 0,sessionID,events
0,P61CS_1,"[[1551091404901085.0, 55], [1551091406087648.0..."
1,P62CS_1,"[[1555760333243868.0, 55], [1555760334376711.0..."
2,P62CS_2,"[[1555761804593152.0, 55], [1555761805744683.0..."
3,P62CS_3,"[[1556200582588169.0, 55], [1556200583950356.0..."
4,P62CS_4,"[[1556641173506903.0, 55], [1556641174833121.0..."


In [4]:
# Check number of events in each session
for i in range(len(beh_data)):
    print(f"Session {beh_data["sessionID"][i]}: {len(beh_data["events"][i])} events")

Session P61CS_1: 642 events
Session P62CS_1: 482 events
Session P62CS_2: 482 events
Session P62CS_3: 482 events
Session P62CS_4: 482 events
Session P63CS_1: 482 events
Session P63CS_2: 482 events
Session P63CS_3: 482 events
Session P65CS_1: 482 events
Session P65CS_2: 482 events
Session P65CS_3: 562 events
Session P67CS_1: 482 events
Session P67CS_2: 482 events
Session P67CS_3: 482 events
Session P67CS_4: 522 events
Session P70CS_1: 482 events
Session P71CS_1: 362 events
Session P71CS_2: 482 events
Session P73CS_1: 482 events
Session P73CS_2: 482 events
Session P73CS_3: 402 events
Session P74CS_1: 482 events
Session P74CS_2: 482 events
Session P76CS_1: 482 events
Session P78CS_1: 482 events
Session P78CS_2: 482 events
Session P79CS_1: 482 events
Session P79CS_2: 482 events
Session P79CS_3: 482 events
Session TWH162_1: 482 events
Session TWH163_1: 482 events
Session TWH163_2: 482 events
Session TWH165_1: 482 events
Session TWH165_2: 482 events
Session TWH172_1: 482 events
Session TWH172

In [5]:
# Task data
task_data = pd.DataFrame(all_beh_data["task_info"])
task_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36 entries, 0 to 35
Data columns (total 10 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   context            36 non-null     object
 1   stim_sequence      36 non-null     object
 2   reward_sequence    36 non-null     object
 3   response_sequence  36 non-null     object
 4   is_novel_variant   2 non-null      object
 5   novel_stim_dir     2 non-null      object
 6   novel_block_image  2 non-null      object
 7   remapping          2 non-null      object
 8   stim_to_replace    2 non-null      object
 9   replace_name       2 non-null      object
dtypes: object(10)
memory usage: 2.9+ KB


In [6]:
# Check top rows
task_data.head()

Unnamed: 0,context,stim_sequence,reward_sequence,response_sequence,is_novel_variant,novel_stim_dir,novel_block_image,remapping,stim_to_replace,replace_name
0,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[2, 4, 1, 2, 4, 4, 4, 3, 1, 4, 4, 3, 4, 3, 2, ...","[25, 5, 5, 25, 5, 5, 5, 25, 5, 5, 5, 25, 5, 25...","[0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, ...",,,,,,
1,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 2, 4, 4, 3, 2, 3, 1, 4, 4, 4, 2, 3, 1, 2, ...","[5, 25, 5, 5, 25, 25, 25, 5, 5, 5, 5, 25, 25, ...","[1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, ...",,,,,,
2,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[1, 3, 3, 4, 2, 2, 2, 4, 4, 1, 1, 2, 1, 3, 2, ...","[5, 25, 25, 5, 25, 25, 25, 5, 5, 5, 5, 25, 5, ...","[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, ...",,,,,,
3,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...","[2, 2, 3, 3, 1, 2, 2, 3, 2, 4, 4, 2, 2, 3, 4, ...","[25, 25, 25, 25, 5, 25, 25, 25, 25, 5, 5, 25, ...","[0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, ...",,,,,,
4,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, ...","[2, 1, 2, 3, 2, 2, 4, 3, 3, 4, 2, 4, 4, 2, 3, ...","[25, 5, 25, 25, 25, 25, 5, 25, 25, 5, 25, 5, 5...","[0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...",,,,,,


### Single-neuron recordings



In [7]:
# Neural data
neu_data = pd.read_json("neu.json")
neu_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2694 entries, 0 to 2693
Data columns (total 3 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   array      2694 non-null   object
 1   cellinfo   2694 non-null   int64 
 2   sessionID  2694 non-null   object
dtypes: int64(1), object(2)
memory usage: 63.3+ KB


In [8]:
# Check top rows
neu_data.head()

Unnamed: 0,array,cellinfo,sessionID
0,"[{'block_nr': 1, 'iscorrect': True, 'stim_id':...",2,P61CS_1
1,"[{'block_nr': 1, 'iscorrect': True, 'stim_id':...",2,P61CS_1
2,"[{'block_nr': 1, 'iscorrect': True, 'stim_id':...",2,P61CS_1
3,"[{'block_nr': 1, 'iscorrect': True, 'stim_id':...",4,P61CS_1
4,"[{'block_nr': 1, 'iscorrect': True, 'stim_id':...",4,P61CS_1


🔺 **Neuron counts in preSMA do not match paper (p. 842), although total counts do. Potential typo?**

In [9]:
# Check brain area counts against paper
brain_areas = F.define_cell_area_groups(neu_data)
true_area_counts = {
    "HPC" : 494,
    "vmPFC" : 463,
    "AMY" : 889,
    "dACC" : 310,
    "preSMA" : 269,
    "VTC" : 269
}
for key, value in brain_areas.items():
    print(f"Area name: {key}. Data count: {len(value)}, True count: {true_area_counts[key]}.")

Area name: HPC. Data count: 494, True count: 494.
Area name: vmPFC. Data count: 889, True count: 463.
Area name: AMY. Data count: 269, True count: 889.
Area name: dACC. Data count: 310, True count: 310.
Area name: preSMA. Data count: 463, True count: 269.
Area name: VTC. Data count: 269, True count: 269.


In [10]:
# Example of cell-level data
example_cell_data = F.get_cell_array(neu_data, cell_idx=42)
example_cell_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 320 entries, 0 to 319
Data columns (total 9 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   block_nr   320 non-null    int64  
 1   iscorrect  320 non-null    bool   
 2   stim_id    320 non-null    int64  
 3   context    320 non-null    int64  
 4   reward     320 non-null    int64  
 5   response   320 non-null    int64  
 6   trial_nr   320 non-null    int64  
 7   fr_stim    320 non-null    int64  
 8   fr_base    320 non-null    float64
dtypes: bool(1), float64(1), int64(7)
memory usage: 20.4 KB


In [None]:
example_cell_data.head()

Unnamed: 0,block_nr,iscorrect,stim_id,context,reward,response,trial_nr,fr_stim,fr_base
0,1,True,2,2,25,0,1,0,0.0
1,1,True,4,2,5,0,2,2,0.0
2,1,True,1,2,5,1,3,2,0.909091
3,1,True,2,2,25,0,4,0,1.818182
4,1,True,4,2,5,0,5,1,0.0


🔺 **Paper states that each recording session consists of 280-320 trials grouped into 10-16 blocks different sizes; data shows trial counts as low as 180.**

In [12]:
# Check unique trial counts
neu_data["n_trials"] = neu_data["array"].apply(lambda x : len(x))
neu_data["n_trials"].value_counts()

n_trials
240    2309
260     119
200     105
180      65
320      57
280      39
Name: count, dtype: int64

## Balanced dichotomies of task conditions

Task conditions are defined by the following `[stimulus identity, response, outcome, context]` sets:
1. `[C, left, 5, 1]`
2. `[D, right, 5, 1]`
3. `[A, left, 25, 1]`
4. `[B, right, 25, 1]`
5. `[D, left, 5, 2]`
6. `[A, right, 5, 2]`
7. `[B, left, 25, 2]`
8. `[C, right, 25, 2]`

Note that Python uses 0 indexing.

The set of **balanced dichotomies** is defined by all possible ways that the 8 task conditions can be split into pairs of 4 conditions each (`(8 choose 4)/2 = 35` groups).

In [13]:
# Left is 0, right is 1
dichotomy_labels = [
    [0, 5, 1],
    [1, 5, 1],
    [0, 25, 1],
    [1, 25, 1],
    [0, 5, 2],
    [1, 5, 2],
    [0, 25, 2],
    [1, 5, 2]
]

In [14]:
interpretations, pos_set, neg_set = F.define_dichotomies()

In [15]:
# Check that there are 35 dichotomies
print(len(pos_set), len(neg_set))

35 35


In [16]:
# Check dichotomies with straightforward interpretations
for key, value in interpretations.items():
    print(f"Interpretation: {value}")
    print(f"Class 1: {pos_set[key]}, Class 2: {neg_set[key]}")

Interpretation: context
Class 1: [0 1 2 3], Class 2: [4 5 6 7]
Interpretation: outcome
Class 1: [0 1 4 5], Class 2: [2 3 6 7]
Interpretation: AB vs CD
Class 1: [0 1 4 7], Class 2: [2 3 5 6]
Interpretation: response
Class 1: [0 2 4 6], Class 2: [1 3 5 7]
Interpretation: AC vs BD
Class 1: [0 2 5 7], Class 2: [1 3 4 6]
Interpretation: parity
Class 1: [0 3 5 6], Class 2: [1 2 4 7]
Interpretation: AD vs BC
Class 1: [0 4 5 6], Class 2: [1 2 3 7]


Note: `AB vs CD` is also called the **stimulus pair**, the grouping of stimuli where the response is the same in either context. **Parity** is the grouping with the "maximal nonlinear interaction between task variables."

## Data pre-processing



### Constructing a pseudo-population

Neuron recordings from all patients were combined to form a "pseudopopulation." Neurons included in analyses have at least `K = 15` trials from each unique task condition. The pseudopopulation is constructed by randomly sampling `K` trials for each condition for each neuron independently.

In [17]:
# Make groups with all possible combinations for variables of interet
# In practice, the relevant variables are response, reward, and context
variable_groups = F.make_variable_groups(
    example_cell_data,
    var_names=["context", "reward", "response"]
)
variable_groups.head()

Unnamed: 0,context_1_reward_0_response_0,context_1_reward_0_response_1,context_1_reward_5_response_0,context_1_reward_5_response_1,context_1_reward_25_response_0,context_1_reward_25_response_1,context_2_reward_0_response_0,context_2_reward_0_response_1,context_2_reward_5_response_0,context_2_reward_5_response_1,context_2_reward_25_response_0,context_2_reward_25_response_1
0,False,False,False,False,False,False,False,False,False,False,True,False
1,False,False,False,False,False,False,False,False,True,False,False,False
2,False,False,False,False,False,False,False,False,False,True,False,False
3,False,False,False,False,False,False,False,False,False,False,True,False
4,False,False,False,False,False,False,False,False,True,False,False,False


In [18]:
# Check that column order matches task condition definitions
correct_cols = [col for col in variable_groups.columns if "reward_0" not in col]
for i, col in enumerate(correct_cols):
    variable_nums = re.findall(r'\d+', col)
    variable_nums = [int(n) for n in variable_nums]
    variable_nums.reverse()
    print(f"Column index: {i}, Name: {col}")
    print(f"Match: {variable_nums == dichotomy_labels[i]}")

Column index: 0, Name: context_1_reward_5_response_0
Match: True
Column index: 1, Name: context_1_reward_5_response_1
Match: True
Column index: 2, Name: context_1_reward_25_response_0
Match: True
Column index: 3, Name: context_1_reward_25_response_1
Match: True
Column index: 4, Name: context_2_reward_5_response_0
Match: True
Column index: 5, Name: context_2_reward_5_response_1
Match: True
Column index: 6, Name: context_2_reward_25_response_0
Match: True
Column index: 7, Name: context_2_reward_25_response_1
Match: False


The `construct_regressors` method selects neurons (each represented by a row of the data) to include in analysis. The data of each cell is a list of firing rates for the given trial type.

In [19]:
# Note: Original code describes this method as balancing IA and IP sessions,
# but actual implementation seems to be done elsewhere.
example_regressors = F.construct_regressors(
    neu_data,
    sample_thr=15, # aka K
    select=[i for i in range(100)]
)
example_regressors.head()

Unnamed: 0,context_1_reward_5_response_0,context_1_reward_5_response_1,context_1_reward_25_response_0,context_1_reward_25_response_1,context_2_reward_5_response_0,context_2_reward_5_response_1,context_2_reward_25_response_0,context_2_reward_25_response_1
0,"[2, 0, 5, 1, 2, 0, 1, 3, 2, 1, 3, 3, 4, 0, 1, ...","[0, 3, 2, 4, 0, 0, 3, 2, 0, 3, 0, 1, 0, 2, 4, ...","[0, 4, 3, 0, 5, 0, 4, 0, 1, 3, 2, 0, 1, 1, 0, ...","[0, 1, 2, 1, 5, 2, 1, 1, 3, 2, 2, 0, 1, 0, 2, ...","[0, 0, 1, 1, 0, 2, 1, 3, 1, 0, 4, 0, 0, 0, 0, ...","[2, 0, 1, 3, 1, 2, 0, 4, 0, 0, 1, 2, 1, 3, 3, ...","[0, 1, 1, 2, 1, 3, 2, 2, 0, 2, 3, 0, 0, 0, 0, ...","[1, 7, 2, 3, 1, 1, 0, 1, 3, 0, 1, 2, 2, 3, 0, ..."
1,"[4, 3, 2, 2, 7, 0, 6, 0, 5, 4, 1, 4, 3, 4, 3, ...","[3, 3, 3, 1, 5, 4, 4, 3, 3, 0, 6, 2, 5, 4, 3, ...","[0, 3, 5, 3, 6, 2, 3, 1, 3, 0, 2, 0, 10, 3, 4,...","[3, 1, 3, 2, 2, 4, 3, 6, 3, 2, 1, 0, 4, 2, 1, ...","[0, 1, 3, 1, 4, 2, 2, 4, 0, 1, 6, 0, 3, 0, 2, ...","[1, 3, 4, 1, 5, 3, 1, 2, 2, 2, 3, 2, 0, 1, 1, ...","[2, 1, 3, 3, 4, 3, 1, 1, 2, 2, 1, 3, 0, 0, 1, ...","[6, 1, 5, 1, 2, 1, 0, 1, 4, 0, 3, 2, 2, 6, 3, ..."
2,"[0, 0, 0, 0, 0, 0, 3, 1, 1, 0, 0, 1, 1, 1, 0, ...","[0, 0, 2, 1, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 1, ...","[1, 0, 0, 0, 2, 1, 2, 1, 0, 0, 0, 0, 0, 0, 1, ...","[0, 1, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 2, 0, 0, ...","[0, 1, 0, 3, 0, 0, 1, 1, 0, 1, 1, 0, 2, 0, 0, ...","[2, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 1, ...","[2, 1, 0, 1, 0, 0, 2, 0, 0, 1, 1, 3, 1, 0, 0, ..."
3,"[0, 0, 0, 0, 11, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[0, 0, 0, 6, 4, 19, 0, 0, 1, 1, 0, 0, 1, 0, 0,...","[0, 4, 15, 5, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,...","[1, 0, 0, 6, 0, 4, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[4, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, ...","[1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, ..."
4,"[4, 12, 7, 4, 17, 5, 5, 4, 5, 8, 12, 5, 7, 2, ...","[6, 3, 7, 5, 19, 3, 9, 14, 10, 10, 4, 10, 4, 8...","[6, 12, 3, 4, 4, 6, 3, 6, 7, 10, 4, 4, 9, 11, ...","[5, 8, 6, 4, 2, 8, 4, 2, 13, 3, 4, 0, 4, 6, 3,...","[0, 7, 5, 6, 7, 10, 5, 6, 13, 7, 2, 5, 5, 5, 6...","[0, 5, 3, 3, 9, 6, 4, 6, 6, 6, 6, 4, 11, 3, 1,...","[0, 15, 9, 5, 6, 8, 13, 7, 9, 6, 6, 6, 3, 9, 7...","[9, 4, 14, 3, 4, 9, 1, 1, 8, 2, 17, 10, 4, 5, ..."


### Inference present vs. absent sessions

Inference present vs. absent sessions are computed using statistical significance of session-level inference behavior by the following steps:
1. `get_block_num`: Identify "blocks" in each session where contexts are switched.
2. `get_instance_num`: Within each block, assign instances to each type of stimulus. 
3. `get_session_accuracy`: Take the first instance of the stimulus in the new trial as an "inference trial," then compute participant accuracy. 
4. `test_inference_trials`: Classify each of 36 sessions as either inference present or absent, depending on whether accuracy is significantly above chance (`p > 0.05`).

In [20]:
example_context = task_data["context"].iloc[0]
example_block_nums = F.get_block_num(example_context)
print(f"Contexts: {example_context}")
print(f"Block numbers: {example_block_nums}")

Contexts: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Block numbers: [0. 0. 0. 0. 0

In [21]:
example_stim = task_data["stim_sequence"].iloc[0]
example_instance_nums = F.get_instance_num(example_stim, example_block_nums)
print(f"Stimulus sequence: {example_stim}")
print(f"Within-block instances: {example_instance_nums}")

Stimulus sequence: [2, 4, 1, 2, 4, 4, 4, 3, 1, 4, 4, 3, 4, 3, 2, 3, 1, 3, 1, 2, 1, 1, 4, 3, 2, 4, 1, 2, 2, 4, 4, 1, 2, 2, 3, 3, 4, 2, 3, 3, 1, 1, 2, 4, 2, 3, 1, 4, 2, 3, 3, 4, 4, 3, 1, 1, 2, 4, 2, 4, 1, 4, 2, 1, 2, 3, 2, 2, 4, 3, 3, 4, 2, 4, 4, 2, 3, 1, 1, 3, 4, 4, 1, 3, 2, 1, 2, 1, 4, 2, 3, 1, 3, 2, 3, 3, 3, 2, 1, 1, 4, 1, 4, 3, 4, 1, 2, 1, 4, 1, 4, 4, 4, 1, 2, 2, 4, 2, 4, 1, 2, 1, 1, 4, 3, 3, 1, 4, 3, 2, 3, 2, 1, 1, 1, 1, 1, 2, 1, 4, 4, 2, 2, 2, 4, 2, 1, 4, 2, 1, 2, 1, 1, 4, 4, 3, 1, 1, 2, 4, 1, 1, 1, 3, 3, 3, 2, 3, 2, 3, 1, 3, 1, 2, 3, 4, 1, 4, 4, 2, 2, 2, 2, 3, 3, 4, 4, 3, 2, 4, 3, 2, 4, 4, 3, 3, 3, 1, 2, 2, 1, 4, 1, 1, 1, 1, 2, 2, 4, 2, 1, 4, 4, 2, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 2, 2, 3, 1, 2, 4, 1, 4, 3, 2, 3, 1, 2, 4, 3, 3, 1, 2, 3, 3, 2, 2, 4, 1, 4, 4, 4, 1, 2, 2, 3, 1, 3, 1, 3, 2, 4, 3, 4, 4, 2, 3, 1, 1, 3, 3, 2, 4, 3, 3, 4, 4, 3, 1, 1, 4, 1, 2, 1, 4, 3, 3, 2, 1, 3, 1, 1, 3, 1, 4, 4, 3, 1, 3, 3, 4, 3, 4, 2, 2, 4, 1, 1, 1, 2, 4, 4, 1, 2, 3, 2, 3, 3, 2, 2, 1]
Within-block instanc

In [22]:
session_accuracy = F.get_session_accuracy(beh_data, task_data)
session_accuracy[2].shape

(5, 15)

In [23]:
len(session_accuracy)

36

In [24]:
inference_trials = F.test_inference_trials(beh_data, task_data)
for i in range(36):
    print(f"Session {i+1}")
    print(f"Baseline: {inference_trials["baseline"][i]}, Inference: {inference_trials["inference"][i]}, Inf. Performance: {inference_trials["inf_perf"][i]}")

Session 1
Baseline: 0.03125, Inference: 0.1875, Inf. Performance: 0.6
Session 2
Baseline: 0.03125, Inference: 0.1875, Inf. Performance: 0.6
Session 3
Baseline: 0.03125, Inference: 0.0, Inf. Performance: 1.0
Session 4
Baseline: 0.03125, Inference: 0.0, Inf. Performance: 1.0
Session 5
Baseline: 0.1875, Inference: 0.03125, Inf. Performance: 0.8
Session 6
Baseline: 0.03125, Inference: 0.5, Inf. Performance: 0.4
Session 7
Baseline: 0.03125, Inference: 0.0, Inf. Performance: 1.0
Session 8
Baseline: 0.1875, Inference: 0.0, Inf. Performance: 1.0
Session 9
Baseline: 0.03125, Inference: 0.0, Inf. Performance: 1.0
Session 10
Baseline: 0.1875, Inference: 0.1875, Inf. Performance: 0.6
Session 11
Baseline: 0.1875, Inference: 0.5, Inf. Performance: 0.4
Session 12
Baseline: 0.03125, Inference: 0.5, Inf. Performance: 0.4
Session 13
Baseline: 0.0, Inference: 0.0, Inf. Performance: 1.0
Session 14
Baseline: 0.0, Inference: 0.0, Inf. Performance: 1.0
Session 15
Baseline: 0.1875, Inference: 0.0, Inf. Perfor

In [25]:
sessions = F.split_sessions(neu_data, beh_data, task_data)
print(len(sessions["absent"])) # 14
print(len(sessions["present"])) # 22

8
9


## Geometric analyses

### Shattering dimensionality

In [47]:
start_time = time.time()
example_sd, example_sd_boot = F.sd(
    example_regressors,
    n_iter=5,
    n_samples=15,
    show_progress=True
)
end_time = time.time()
print(f"Time taken for {len(example_regressors)} regressors: {(end_time - start_time):.2f} seconds")

Time taken for 100 regressors: -648.90 seconds


In [48]:
example_sd

array([0.5    , 0.51125, 0.475  , 0.48125, 0.5225 , 0.53125, 0.4875 ,
       0.49375, 0.50125, 0.4975 , 0.495  , 0.46625, 0.51625, 0.5075 ,
       0.5075 , 0.4825 , 0.48625, 0.49   , 0.5225 , 0.4875 , 0.48   ,
       0.5    , 0.49   , 0.5275 , 0.49625, 0.47875, 0.50375, 0.47375,
       0.4925 , 0.49625, 0.535  , 0.4975 , 0.50375, 0.4775 , 0.51125,
       0.48625, 0.42875, 0.46625, 0.48875, 0.49125, 0.48875, 0.45375,
       0.47875, 0.4725 , 0.4925 , 0.5125 , 0.46   , 0.5425 , 0.5075 ,
       0.51625, 0.4875 , 0.5075 , 0.5    , 0.49   , 0.50875, 0.50875,
       0.52125, 0.50125, 0.48   , 0.5025 , 0.5125 , 0.51625, 0.4775 ,
       0.51875, 0.52125, 0.47875, 0.465  , 0.52875, 0.51   , 0.53625,
       0.5275 , 0.49375, 0.49625, 0.47125, 0.4975 , 0.52375, 0.51125,
       0.48125, 0.5025 , 0.5325 , 0.5025 , 0.46375, 0.48875, 0.49875,
       0.505  , 0.47   , 0.46625, 0.495  , 0.4575 , 0.505  , 0.5225 ,
       0.505  , 0.44625, 0.4875 , 0.5075 , 0.48875, 0.48375, 0.5325 ,
       0.53625, 0.46

In [49]:
example_sd_boot

array([0.47   , 0.505  , 0.50375, 0.4975 , 0.51   , 0.4725 , 0.50625,
       0.505  , 0.47625, 0.49625, 0.525  , 0.51125, 0.44875, 0.53125,
       0.475  , 0.49625, 0.49875, 0.53   , 0.4725 , 0.50375, 0.5075 ,
       0.52125, 0.4725 , 0.49875, 0.5025 , 0.5075 , 0.51625, 0.4875 ,
       0.47   , 0.5425 , 0.50875, 0.49125, 0.48375, 0.47   , 0.48625,
       0.48875, 0.4925 , 0.48875, 0.49875, 0.46   , 0.5    , 0.505  ,
       0.46125, 0.51625, 0.49375, 0.50625, 0.4575 , 0.5    , 0.52   ,
       0.50625, 0.50375, 0.46875, 0.48   , 0.5075 , 0.50125, 0.465  ,
       0.5325 , 0.51   , 0.49875, 0.5    , 0.4875 , 0.495  , 0.50875,
       0.51375, 0.505  , 0.48375, 0.4975 , 0.5225 , 0.525  , 0.49625,
       0.52375, 0.52375, 0.49   , 0.49375, 0.48875, 0.51625, 0.47875,
       0.53625, 0.5325 , 0.485  , 0.4925 , 0.50625, 0.5125 , 0.52125,
       0.54125, 0.49125, 0.5    , 0.4725 , 0.505  , 0.5    , 0.47625,
       0.505  , 0.4875 , 0.48125, 0.47375, 0.47125, 0.4575 , 0.52375,
       0.505  , 0.53

### Cross-condition generalization performance


In [None]:
start_time = time.time()
example_ccgp = F.ccgp(
    example_regressors,
    show_progress=True
)
end_time = time.time()
print(f"Time taken for {len(example_regressors)} regressors: {(end_time - start_time):.2f} seconds")

Time taken for 100 regressors: -331.80 seconds


In [50]:
start_time = time.time()
example_ccgp_boot = F.ccgp(
    example_regressors,
    for_boot=True,
    show_progress=True
)
end_time = time.time()
print(f"Time taken for {len(example_regressors)} regressors: {(end_time - start_time):.2f} seconds")

Time taken for 100 regressors: -384.22 seconds


In [51]:
example_ccgp

array([[0.75046875, 0.74921875, 0.750625  , 0.75      , 0.75015625],
       [0.75054687, 0.75078125, 0.75101562, 0.75101562, 0.75046875],
       [0.75140625, 0.75085937, 0.750625  , 0.75      , 0.75148437],
       [0.7521875 , 0.75046875, 0.75101562, 0.75046875, 0.75054687],
       [0.75085937, 0.75015625, 0.75179687, 0.75117188, 0.75007813],
       [0.75117188, 0.74992188, 0.750625  , 0.74976562, 0.75101562],
       [0.75132812, 0.75117188, 0.74945313, 0.75179688, 0.75015625],
       [0.75125   , 0.75109375, 0.75054687, 0.75015625, 0.75007812],
       [0.75101562, 0.74992188, 0.75054687, 0.750625  , 0.7503125 ],
       [0.75109375, 0.75109375, 0.75125   , 0.75101562, 0.75015625],
       [0.7509375 , 0.75039062, 0.750625  , 0.7503125 , 0.75015625],
       [0.751875  , 0.75125   , 0.75125   , 0.75148437, 0.75007812],
       [0.7509375 , 0.75070313, 0.75054687, 0.75023438, 0.75101562],
       [0.7509375 , 0.75101562, 0.75046875, 0.74992188, 0.75117188],
       [0.75039063, 0.750625  , 0.

In [52]:
example_ccgp_boot

array([[0.75117187, 0.75085938, 0.75015625, 0.75054687, 0.75140625],
       [0.7509375 , 0.7521875 , 0.751875  , 0.75046875, 0.75125   ],
       [0.75039062, 0.75007812, 0.75070313, 0.75054687, 0.75109375],
       [0.75148437, 0.74976562, 0.75085937, 0.75007813, 0.75085937],
       [0.75109375, 0.75132813, 0.75078125, 0.75148437, 0.75023438],
       [0.75078125, 0.75125   , 0.75039062, 0.75007813, 0.75054687],
       [0.75078125, 0.75164062, 0.75117188, 0.75023438, 0.75132813],
       [0.7515625 , 0.7509375 , 0.7509375 , 0.75148437, 0.75085937],
       [0.75179687, 0.75179687, 0.75039063, 0.75015625, 0.75054687],
       [0.7515625 , 0.75148437, 0.75359375, 0.75109375, 0.75039063],
       [0.75054687, 0.75085937, 0.74992187, 0.75085937, 0.7509375 ],
       [0.750625  , 0.75117188, 0.75117188, 0.75101562, 0.75078125],
       [0.75140625, 0.75054687, 0.75140625, 0.75117188, 0.7503125 ],
       [0.75109375, 0.75125   , 0.75117188, 0.75054687, 0.75078125],
       [0.75054687, 0.7503125 , 0.

### Parallelism score

In [53]:
start_time = time.time()
example_ps = F.ps(
    example_regressors,
    show_progress=True
)
end_time = time.time()
print(f"Time taken for {len(example_regressors)} regressors: {(end_time - start_time):.2f} seconds")

Time taken for 100 regressors: 7.44 seconds


In [54]:
start_time = time.time()
example_ps_boot = F.ps(
    example_regressors,
    for_boot=True,
    show_progress=True
)
end_time = time.time()
print(f"Time taken for {len(example_regressors)} regressors: {(end_time - start_time):.2f} seconds")

Time taken for 100 regressors: 8.87 seconds


In [55]:
example_ps

[array([0.99998607]),
 array([0.9999954]),
 array([0.99999099]),
 array([0.99999286]),
 array([0.99999752]),
 array([0.99999938]),
 array([0.99999915]),
 array([0.99999666]),
 array([0.99999934]),
 array([0.99999735]),
 array([0.99999899]),
 array([0.99999506]),
 array([0.99999831]),
 array([0.9999947]),
 array([0.99998714]),
 array([0.99999725]),
 array([0.99999556]),
 array([0.9999884]),
 array([0.99999915]),
 array([0.99999983]),
 array([0.9999985]),
 array([0.99999904]),
 array([0.99998903]),
 array([0.99999668]),
 array([0.99998497]),
 array([0.99999482]),
 array([0.99999853]),
 array([0.99999781]),
 array([0.99999867]),
 array([0.99999908]),
 array([0.99999426]),
 array([0.99999703]),
 array([0.99998967]),
 array([0.99997805]),
 array([0.99998237])]

In [56]:
example_ps_boot

[array([0.99995063]),
 array([0.9999673]),
 array([0.99998334]),
 array([0.99997079]),
 array([0.99991442]),
 array([0.99996726]),
 array([0.99996154]),
 array([0.99997374]),
 array([0.99994261]),
 array([0.99994763]),
 array([0.99998395]),
 array([0.99998723]),
 array([0.9999298]),
 array([0.9999152]),
 array([0.99999401]),
 array([0.99990557]),
 array([0.99995715]),
 array([0.99994674]),
 array([0.99999704]),
 array([0.99997349]),
 array([0.99998825]),
 array([0.99995799]),
 array([0.99994293]),
 array([0.99996812]),
 array([0.99999538]),
 array([0.99998644]),
 array([0.99998754]),
 array([0.99997313]),
 array([0.99997169]),
 array([0.99996855]),
 array([0.99999743]),
 array([0.99999523]),
 array([0.99998797]),
 array([0.99989615]),
 array([0.99998798])]

### General pipeline

1. Divide neural data into inference present vs. absent groupings.
2. For each brain area, find the intersection of indices with each inference type, then compute the geometric analyses separately. 

In [None]:
def sd_task():
    return "sd", F.run_geometric_analysis(
        metric="sd",
        neu_data=neu_data,
        beh_data=beh_data,
        task_data=task_data,
        show_progress=True
    )

def ccgp_task():
    return "ccgp", F.run_geometric_analysis(
        metric="ccgp",
        neu_data=neu_data,
        beh_data=beh_data,
        task_data=task_data,
        n_iter_boot=100,
        show_progress=True
    )

def ps_task():
    return "ps", F.run_geometric_analysis(
        metric="ps",
        neu_data=neu_data,
        beh_data=beh_data,
        task_data=task_data,
        n_iter_boot=100,
        show_progress=True
    )

metric_data = {}
boot_data = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(sd_task),
        executor.submit(ccgp_task),
        executor.submit(ps_task)
    ]
    for future in concurrent.futures.as_completed(futures):
        try:
            metric, data, data_boot = future.result() 
            metric_data[metric] = data
            boot_data[metric] = data_boot

            # Save data locally
            np.save(f"/data/{metric}_analysis.npy", data)
            np.save(f"/data/{metric}_boot.npy", data_boot)

            print(f"Data saved for metric: {metric}.")
        except Exception as e:
            print(f"Error executing task: {e}")

Running analyses for inference absent trials over 15 samples...
Running analyses for inference absent trials over 15 samples...
Running analyses for inference absent trials over 15 samples...
Running analyses for inference present trials over 15 samples...


In [16]:
# TO DO
# Check split_sessions method into inference absent vs. inference present sessions
# Check that this works
# Geometric analyses also splits into inference present and absent trials