In [2]:
import numpy as np

# Trial ids available for each subject. Subject 6 is the special case: it only
# has trials 0, 1, and 4.
_trial_ids_by_subject = {
    1: [0, 1, 2],
    2: [0, 1, 2, 3, 4, 5, 6],
    3: [0, 1, 2],
    4: [0, 1, 2],
    5: [0],
    6: [0, 1, 4],
    7: [0, 1],
    8: [0],
    9: [0],
    10: [0, 1],
}

# Flatten the mapping into (subject_id, trial_id) pairs, ordered by subject
# and then by trial id within each subject.
all_subject_trials = [
    (subject_id, trial_id)
    for subject_id, trial_ids in _trial_ids_by_subject.items()
    for trial_id in trial_ids
]
print(all_subject_trials)

[(1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (3, 0), (3, 1), (3, 2), (4, 0), (4, 1), (4, 2), (5, 0), (6, 0), (6, 1), (6, 4), (7, 0), (7, 1), (8, 0), (9, 0), (10, 0), (10, 1)]


In [3]:
from braintreebank_process_chunks import *

# Walk over every (subject, trial) pair: announce it, then run that subject's
# electrode check for the trial. A new Subject is constructed for every pair.
# NOTE(review): the recorded run stopped on the first pair with a
# FileNotFoundError for a relative 'braintreebank/...' path; presumably the
# kernel's working directory must contain that folder -- confirm.
for subject_trial in all_subject_trials:
    sub_id, trial_id = subject_trial
    print(f"Processing subject {sub_id} trial {trial_id}")
    subject = Subject(sub_id)
    subject.check_electrodes(trial_id)

Processing subject 1 trial 0


FileNotFoundError: [Errno 2] No such file or directory: 'braintreebank/localization/sub_1/depth-wm.csv'

# Statistics of coordinates of electrodes across subjects

In [None]:
from braintreebank_utils import *
from braintreebank_config import *

def _clean_electrode_label(electrode_label):
    """Return `electrode_label` with every '*' and '#' character removed."""
    return electrode_label.replace('*', '').replace('#', '')


def _print_min_max_mean(minimum, maximum, mean):
    """Print an indented Min/Max/Mean summary; the mean is shown with 2 decimals."""
    print(f"  Min: {minimum}")
    print(f"  Max: {maximum}")
    print(f"  Mean: {mean:.2f}")


# Coordinate values pooled over all subjects, one list per coordinate column.
all_L = []
all_I = []
all_P = []
# Same lists keyed by column name, so they can be filled and reported in a loop.
pooled_coordinates = {'L': all_L, 'I': all_I, 'P': all_P}

for sub_id in range(1, 11):
    print(f'Processing subject {sub_id}')

    electrode_labels_file = os.path.join(ROOT_DIR, f'braintreebank/electrode_labels/sub_{sub_id}/electrode_labels.json')
    regions_file_format = os.path.join(ROOT_DIR, f'braintreebank/localization/sub_{sub_id}/depth-wm.csv')

    # Load electrode labels from the json file and strip the marker characters.
    with open(electrode_labels_file, 'r', encoding='utf-8') as f:
        electrode_labels = json.load(f)
    electrode_labels = [_clean_electrode_label(label) for label in electrode_labels]

    # Position of each label in the label list. setdefault keeps the FIRST
    # position of a repeated label, which is what list.index() would return.
    label_to_index = {}
    for index, label in enumerate(electrode_labels):
        label_to_index.setdefault(label, index)

    # Load the brain regions file for this subject.
    regions_df = pd.read_csv(regions_file_format)
    regions_df['Electrode'] = regions_df['Electrode'].apply(_clean_electrode_label)
    # Rows whose electrode is absent from the label list get NaN here.
    regions_df['electrode_i'] = regions_df['Electrode'].map(label_to_index)

    # Find electrodes in labels but not in regions_df.
    region_electrodes = set(regions_df['Electrode'])
    labels_not_in_df = [label for label in electrode_labels if label not in region_electrodes]
    print("Electrodes in labels but not in regions_df:", labels_not_in_df)
    # Find electrodes in regions_df but not in labels.
    df_not_in_labels = regions_df[~regions_df['Electrode'].isin(electrode_labels)]['Electrode'].tolist()
    print("Electrodes in regions_df but not in labels:", df_not_in_labels)

    # Count rows per hemisphere from the DesikanKilliany label ('-lh' / '-rh').
    # This is counted before unmatched rows are dropped. na=False makes a
    # missing label count as "no match" instead of producing a NaN mask.
    hemisphere_labels = regions_df['DesikanKilliany'].str
    left_hem = int(hemisphere_labels.contains('-lh', regex=False, na=False).sum())
    right_hem = int(hemisphere_labels.contains('-rh', regex=False, na=False).sum())
    print(f"Number of electrodes in left hemisphere: {left_hem}, right hemisphere: {right_hem}")

    # Keep only the rows whose electrode was found in the label list.
    regions_df = regions_df.dropna(subset=['electrode_i'])
    if regions_df.empty:
        # min()/max() of an empty array raise, so skip the statistics instead.
        print("No electrodes matched the electrode labels; skipping coordinate statistics.")
        continue

    # Per-subject min, max, mean of each coordinate column, and pooling of the
    # values for the cross-subject summary below.
    for col, pooled_values in pooled_coordinates.items():
        arr = regions_df[col].to_numpy()
        print(f"{col}:")
        _print_min_max_mean(arr.min(), arr.max(), arr.mean())
        pooled_values.extend(regions_df[col].tolist())

# Print overall statistics across subjects for each dimension.
print("\nOverall statistics across all subjects:")
for position, (col, pooled_values) in enumerate(pooled_coordinates.items()):
    heading = f"{col} dimension:"
    # Every heading except the first is preceded by a blank line.
    print(heading if position == 0 else f"\n{heading}")
    if not pooled_values:
        print("  (no electrodes)")
        continue
    _print_min_max_mean(min(pooled_values), max(pooled_values), sum(pooled_values) / len(pooled_values))

In [None]:
# Configuration sized to the maximum that fits on an A100 80G.

training_config = {
    'n_epochs': 4,
    'save_network_every_n_epochs': 1,

    'batch_size': 116,
    # Alternative multi-trial setting: [(2, 4), (1, 1), (3, 1)]
    'train_subject_trials': [(2, 4)],
    'lr_max': 1e-3,
    'lr_min': 1e-4,
    # Warmup is given either as a fraction or as a step count; exactly one of
    # the two keys may be present (checked right after this dict).
    # 'lr_warmup_frac': 0.01,
    'lr_warmup_steps': 100,
    'weight_decay': 1e-3,
    'random_string': "XX",
}
# Exactly one of the two warmup keys must be set: neither or both fails.
assert sum(key in training_config for key in ('lr_warmup_frac', 'lr_warmup_steps')) == 1, \
    "Need to specify either lr_warmup_frac or lr_warmup_steps, not both"

transformer_config = {
    'model_name': "trx",
    'max_n_electrodes': 130,
    'n_freq_features': 37,
    'max_n_time_bins': 10,
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 10,
    'dropout': 0.1,
    'mask_type': 'mask-out-one',
    'dtype': torch.bfloat16,
    'device': device,
}
# Derived settings: both are copied from values already in the dict.
transformer_config.update({
    'rope_encoding_scale': transformer_config['max_n_time_bins'],
    'dim_output': transformer_config['n_freq_features'],
})