In [3]:
import torch
import btbench_config
# Sanity check: echo the dataset root and the sampling rate read from btbench_config,
# so that a wrong ROOT_DIR is noticed before any data is loaded.
print(
    "Expected braintreebank data at:",
    btbench_config.ROOT_DIR,
)
print(
    "Sampling rate:",
    btbench_config.SAMPLING_RATE,
    "Hz",
)
from braintreebank_subject import BrainTreebankSubject
from btbench_datasets import BrainTreebankSubjectTrialBenchmarkDataset

# Superset of task names. It used to be assigned to `btbench_tasks` and then
# immediately overwritten; it now lives under its own name so that the active
# selection below is explicit instead of relying on a silent re-assignment.
all_btbench_tasks = [
    "frame_brightness", "global_flow", "local_flow", "global_flow_angle", "local_flow_angle",
    "face_num", "volume", "pitch", "delta_volume", "delta_pitch",
    "speech", "onset", "gpt2_surprisal",
    "word_length", "word_gap", "word_index", "word_head_pos", "word_part_speech",
    "speaker",
]

# Tasks from the superset that are left out of the sweeps in this notebook.
excluded_btbench_tasks = {"global_flow_angle", "local_flow_angle"}

# Task names iterated by the sweep cells (order follows `all_btbench_tasks`).
btbench_tasks = [task for task in all_btbench_tasks if task not in excluded_btbench_tasks]

# Alternative task selection, kept disabled; uncomment to override the list above.
#btbench_tasks = ["enhanced_pitch"]#, "enhanced_volume", "delta_enhanced_pitch", "delta_enhanced_volume", "raw_pitch", "raw_volume", "delta_raw_pitch", "delta_raw_volume"]

# (subject_id, trial_id) pairs that the sweep cells iterate over, one row per subject.
all_subject_trials = [
    (1, 0), (1, 1), (1, 2),
    (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6),
    (3, 0), (3, 1), (3, 2),
    (4, 0), (4, 1), (4, 2),
    (5, 0),
    (6, 0), (6, 1), (6, 4),
    (7, 0), (7, 1),
    (8, 0),
    (9, 0),
    (10, 0), (10, 1),
]

Expected braintreebank data at: /om2/user/zaho/braintreebank/braintreebank
Sampling rate: 2048 Hz


In [5]:
# Sweep every (subject, trial) pair in `all_subject_trials` and print, for each task
# in `btbench_tasks`, how many items the benchmark dataset contains.

# Dataset parameters shared by every subject / trial / task. They do not depend on
# the loop variables, so they are set once instead of on every iteration.
output_indices = False
start_neural_data_before_word_onset = 0  # the number of samples to start the neural data before each word onset
end_neural_data_after_word_onset = btbench_config.SAMPLING_RATE * 1  # the number of samples to end the neural data after each word onset -- here we use 1 second

for subject_id, trial_id in all_subject_trials:
    print(f"\n=== Subject {subject_id}, Trial {trial_id} ===")

    # Load the subject.
    # NOTE(review): the subject is re-created for every trial, including consecutive
    # trials of the same subject. Reusing one object per subject would avoid repeated
    # loading, but only if the dataset does not keep per-trial state on it -- confirm.
    subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=False, dtype=torch.float32)
    print(f"Number of electrodes: {len(subject.electrode_labels)}")

    # The mapping is keyed by "btbank<subject>_<trial>".
    movie_name = btbench_config.BRAINTREEBANK_SUBJECT_TRIAL_MOVIE_NAME_MAPPING[f"btbank{subject_id}_{trial_id}"]
    print(f"Movie name: {movie_name}")

    # Try each evaluation task
    print("\nTask sizes:")
    for eval_name in btbench_tasks:
        try:
            dataset = BrainTreebankSubjectTrialBenchmarkDataset(
                subject, trial_id, dtype=torch.float32, eval_name=eval_name,
                output_indices=output_indices,
                start_neural_data_before_word_onset=start_neural_data_before_word_onset,
                end_neural_data_after_word_onset=end_neural_data_after_word_onset
            )
            print(f"  {eval_name}: {len(dataset)} items")
        except Exception as e:
            # Best-effort sweep: report the failure and continue with the next task.
            # The exception class is printed as well because str(e) alone can be
            # ambiguous or empty (a KeyError, for instance, renders as just the key).
            print(f"  {eval_name}: Error - {type(e).__name__}: {e}")


=== Subject 1, Trial 0 ===
Number of electrodes: 130
Movie name: fantastic-mr-fox

Task sizes:
  frame_brightness: 3953 items
  global_flow: 3952 items
  local_flow: 3952 items
  face_num: 3638 items
  volume: 3953 items
  pitch: 3953 items
  delta_volume: 3953 items
  delta_pitch: 3953 items
  speech: 4928 items
  onset: 2558 items
  gpt2_surprisal: 3953 items
  word_length: 3937 items
  word_gap: 4003 items
  word_index: 2558 items
  word_head_pos: 7072 items
  word_part_speech: 2424 items
  speaker: 6918 items

=== Subject 1, Trial 1 ===
Number of electrodes: 130
Movie name: the-martian

Task sizes:
  frame_brightness: 5395 items
  global_flow: 5394 items
  local_flow: 5395 items
  face_num: 7554 items
  volume: 5395 items
  pitch: 5395 items
  delta_volume: 5394 items
  delta_pitch: 5395 items
  speech: 12818 items
  onset: 3134 items
  gpt2_surprisal: 5395 items
  word_length: 5393 items
  word_gap: 4773 items
  word_index: 3134 items
  word_head_pos: 9332 items
  word_part_speec

KeyboardInterrupt: 

In [5]:
# Same task-size sweep as a full run, but restricted to a subset of subjects and
# with the dataset built using lite=True.

# Subjects included in this sweep; trials of any other subject are skipped.
# Defined once as a set instead of rebuilding a list literal on every iteration.
lite_subject_ids = {1, 2, 3, 4, 7, 10}

# Dataset parameters shared by every subject / trial / task. They do not depend on
# the loop variables, so they are set once instead of on every iteration.
output_indices = False
start_neural_data_before_word_onset = 0  # the number of samples to start the neural data before each word onset
end_neural_data_after_word_onset = btbench_config.SAMPLING_RATE * 1  # the number of samples to end the neural data after each word onset -- here we use 1 second

for subject_id, trial_id in all_subject_trials:
    if subject_id not in lite_subject_ids:
        continue
    print(f"\n=== Subject {subject_id}, Trial {trial_id} ===")

    # Load the subject.
    # NOTE(review): the subject is re-created for every trial, including consecutive
    # trials of the same subject. Reusing one object per subject would avoid repeated
    # loading, but only if the dataset does not keep per-trial state on it -- confirm.
    subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=False, dtype=torch.float32)
    print(f"Number of electrodes: {len(subject.electrode_labels)}")

    # Try each evaluation task
    print("\nTask sizes:")
    for eval_name in btbench_tasks:
        try:
            dataset = BrainTreebankSubjectTrialBenchmarkDataset(
                subject, trial_id, dtype=torch.float32, eval_name=eval_name,
                output_indices=output_indices,
                start_neural_data_before_word_onset=start_neural_data_before_word_onset,
                end_neural_data_after_word_onset=end_neural_data_after_word_onset,
                lite=True
            )
            print(f"  {eval_name}: {len(dataset)} items")
        except Exception as e:
            # Best-effort sweep: report the failure and continue with the next task.
            # The exception class is printed as well because str(e) alone can be
            # ambiguous or empty (a KeyError, for instance, renders as just the key).
            print(f"  {eval_name}: Error - {type(e).__name__}: {e}")


=== Subject 1, Trial 0 ===
Number of electrodes: 130

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 1, Trial 1 ===
Number of electrodes: 130

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 1, Trial 2 ===
Number of electrodes: 130

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 0 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 1 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 2 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 3 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 4 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3500 items

=== Subject 2, Trial 5 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 2561 items

=== Subject 2, Trial 6 ===
Number of electrodes: 135

Task sizes:
  enhanced_pitch: 3375 items

=== Subject 3, Trial 0 ===
Number of el

In [3]:
# Print the number of electrode labels for each selected subject.
# This loop is over subjects only; no trials are iterated here.
for subject_id in (1, 2, 3, 4, 7, 10):
    print(f"\n=== Subject {subject_id} ===")

    # Construct the subject with allow_corrupted=False and cache=False.
    subject = BrainTreebankSubject(
        subject_id,
        allow_corrupted=False,
        cache=False,
        dtype=torch.float32,
    )
    print(f"Number of electrodes: {len(subject.electrode_labels)}")


=== Subject 1 ===
Number of electrodes: 130

=== Subject 2 ===
Number of electrodes: 135

=== Subject 3 ===
Number of electrodes: 124

=== Subject 4 ===
Number of electrodes: 185

=== Subject 7 ===
Number of electrodes: 240

=== Subject 10 ===
Number of electrodes: 207
