In [4]:
# ------- Data Loading Function  ------- #
def load_data(data_dir, task, df_path="data/df_text.hdf"):
    """Yield ``(text_input, audio_input, target_voxel_data)`` triples, one per row.

    Rows are read from the DataFrame stored at *df_path*. For each kept row the
    ``text`` column is tokenized with ``bert_tokenizer``. Rows whose ``task``
    column is ``'listening'`` also get ``whisper_processor`` output for their
    ``aligned_audio_file``; every other row yields ``None`` for the audio input.

    The fMRI target file path comes from the module-level template named
    ``fmri_<row task>_<split>``, formatted with the row's ``subject``. The
    row's ``story_name`` is used as the dataset key inside that HDF5 file.

    Args:
        data_dir: Currently unused; kept so existing call sites keep working.
        task: ``'both'`` yields every row. Any other value yields only rows
            whose ``task`` column equals it.
        df_path: Path of the DataFrame read with ``pd.read_hdf``.

    Yields:
        Tuple of (tokenizer output, Whisper processor output or ``None``,
        array read from the fMRI HDF5 dataset).

    Raises:
        NameError: If no fMRI path template with the expected name exists.
    """
    df = pd.read_hdf(df_path)

    for _, row in tqdm(df.iterrows(), desc="Loading Data", total=df.shape[0]):
        row_task = row['task']
        # Keep the caller's `task` separate from the row's task. Previously the
        # row value overwrote the argument, so the requested filter was ignored.
        if task != 'both' and row_task != task:
            continue

        story_name = row['story_name']
        text_input = bert_tokenizer(row['text'], return_tensors='pt')

        if row_task == 'listening':
            # NOTE(review): Whisper processors normally expect a decoded waveform,
            # not a file path — confirm what `aligned_audio_file` contains.
            audio_input = whisper_processor(row['aligned_audio_file'], return_tensors='pt')
        else:
            audio_input = None  # keeps the yielded tuple shape uniform

        # Look the template up by name instead of eval(). row_task comes from
        # data on disk, and eval() would execute any expression embedded in it.
        # NOTE(review): `split` is not defined in this cell; it is presumably a
        # notebook-level global — confirm.
        template_name = f"fmri_{row_task}_{split}"
        try:
            fmri_template = globals()[template_name]
        except KeyError:
            raise NameError(f"name {template_name!r} is not defined") from None
        fmri_file = fmri_template.format(row['subject'])
        with h5py.File(fmri_file, 'r') as f:
            target_voxel_data = f[story_name][:]

        yield text_input, audio_input, target_voxel_data

# Function to load a trained model
def load_model(checkpoint_path):
    """Build an ``M2BAM`` model from a checkpoint and prepare it for inference.

    Args:
        checkpoint_path: Path to a file saved with ``torch.save``. The loaded
            object must be a dict with a ``'model_state_dict'`` entry.

    Returns:
        The model with its weights loaded, moved to the module-level ``device``
        and switched to evaluation mode.
    """
    # map_location remaps stored tensors onto `device`. Without it, a checkpoint
    # saved from a GPU fails to load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = M2BAM()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # evaluation mode (affects layers such as dropout / batch norm)
    return model

# Prediction Function
def predict_and_evaluate(model, test_data, threshold=0.5, df_path="data/df_text.hdf"):
    """Run *model* over *test_data* and return macro-F1 scores per story.

    Each model output is binarised with ``output > threshold``. The reading and
    listening predictions are concatenated column-wise and stacked across
    samples, then compared against the stacked targets. Rows are grouped by the
    ``story_name`` column of the DataFrame stored at *df_path*.

    Args:
        model: Callable taking ``(text_input, audio_input)`` and returning a
            ``(reading_output, listening_output)`` pair of tensors.
        test_data: Iterable of ``(text_input, audio_input, target_voxel_data)``
            triples, e.g. from ``load_data``. ``audio_input`` may be ``None``,
            and ``target_voxel_data`` may be a tensor or array-like.
        threshold: Cut-off applied to both model outputs.
        df_path: DataFrame used to recover the story name of each stacked row.

    Returns:
        List of ``(story_name, f1)`` tuples, one per distinct story, in order
        of first appearance. Empty if *test_data* is empty.
    """
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for text_input, audio_input, target_voxel_data in test_data:
            text_input = text_input.to(device)
            # load_data yields None audio for non-listening rows; None has no .to().
            if audio_input is not None:
                audio_input = audio_input.to(device)

            reading_output, listening_output = model(text_input, audio_input)

            # Apply the threshold to get binary predictions.
            predictions_reading = (reading_output > threshold).float().cpu().numpy()
            predictions_listening = (listening_output > threshold).float().cpu().numpy()
            all_predictions.append(np.hstack((predictions_reading, predictions_listening)))

            # Targets are only compared on the CPU, and load_data yields plain
            # arrays (which have no .to()). Convert them instead of moving to device.
            if torch.is_tensor(target_voxel_data):
                target_voxel_data = target_voxel_data.cpu().numpy()
            all_targets.append(np.asarray(target_voxel_data))

    if not all_predictions:
        return []  # np.vstack([]) would raise on empty input

    all_predictions = np.vstack(all_predictions)
    all_targets = np.vstack(all_targets)

    # NOTE(review): masks are applied positionally, which assumes the stacked
    # rows line up one-to-one with the DataFrame rows — confirm.
    # NOTE(review): f1_score needs discrete targets; confirm the voxel targets
    # are already binary.
    story_names = pd.read_hdf(df_path)["story_name"]
    story_f1_scores = []
    # Iterate distinct stories. Looping over every row produced one duplicate
    # (story, f1) entry per row of that story.
    for story_name in story_names.unique():
        story_mask = (story_names == story_name).to_numpy()
        f1 = f1_score(all_targets[story_mask], all_predictions[story_mask], average='macro')
        story_f1_scores.append((story_name, f1))

    return story_f1_scores

# **************** MAIN EXECUTION ****************
# Evaluate one saved checkpoint on the test data and print per-story scores.

# Materialise the generator up front so loading (and its progress bar) finishes
# before inference begins. 'both' requests reading and listening rows alike.
test_data = list(load_data('data/test_dir', 'both'))

# Checkpoint to evaluate — adjust to the fold/epoch you want to inspect.
model_path = 'm2bam_model_fold_0_epoch_9.pt'
model = load_model(model_path)

story_f1_scores = predict_and_evaluate(model, test_data)

# Report one line per (story, score) pair.
print("Story-wise F1 Macro Scores:")
for story, f1 in story_f1_scores:
    print("Story: {}, F1-Macro: {:.2f}".format(story, f1))

Story-wise F1 Macro Scores:
Story: 0, F1-Macro: 0.89
Story: 1, F1-Macro: 0.72
Story: 2, F1-Macro: 0.91
Story: 3, F1-Macro: 0.85
Story: 4, F1-Macro: 0.82
Story: 5, F1-Macro: 0.73
Story: 6, F1-Macro: 0.68
Story: 7, F1-Macro: 0.88
Story: 8, F1-Macro: 0.84
Story: 9, F1-Macro: 0.90
Story: 10, F1-Macro: 0.88
