# In [None]:
import logging
import os

import h5py
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, Wav2Vec2Processor

# ------ Hyperparameters ------- #
learning_rate = 1e-5
batch_size = 16
num_epochs = 10

# Dimensionality of the audio embeddings fed into the prediction layers
audio_embedding_dim = 768  # Example, adjust based on your audio feature extractor

# Number of voxels in the fMRI data for the reading and listening tasks
num_voxels_reading = 81133
num_voxels_listening = 81133

# Processor for the audio feature extractor.
# NOTE(review): this is a wav2vec2 processor, while M2BAM loads
# "openai/whisper-small"; confirm the processor matches the audio model
# (Whisper's encoder expects log-mel `input_features`).
audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # Or your chosen model

# Directory where your data is stored
data_dir = '/path/to/your/data'

# Filename template for the fMRI data files. This is deliberately a plain
# (non-f) string: fill it in with
# `fmri_task_split.format(task=..., split=...)`. The previous f-string read
# `task` at import time, before it was defined, and raised NameError.
# NOTE(review): the fMRI loader in this file opens its files with h5py while
# this template ends in `.npy` — confirm the on-disk format.
fmri_task_split = 'fmri_{task}_{split}.npy'


# ------- Dataset & DataLoader ------- #
class MyDataset(Dataset):
    """Index-aligned pairing of audio items with their voxel targets.

    Item ``i`` is the tuple ``(audio_data[i], voxel_data[i])``; the dataset
    length is taken from ``audio_data`` alone.
    """

    def __init__(self, audio_data, voxel_data):
        # Both containers are stored as given (no copy, no validation).
        self.audio_data, self.voxel_data = audio_data, voxel_data

    def __len__(self):
        # Only the audio side determines how many items are exposed.
        n_items = len(self.audio_data)
        return n_items

    def __getitem__(self, idx):
        audio_item = self.audio_data[idx]
        voxel_item = self.voxel_data[idx]
        return audio_item, voxel_item


def create_dataloader(data_dir, task, split='train'):
    """Build a shuffled DataLoader of (audio input, fMRI target) pairs.

    One pair is produced per row of the DataFrame stored in
    ``data/df_text.hdf``. The rows must provide the columns ``story_name``,
    ``aligned_audio_file`` and ``subject``.

    Args:
        data_dir: Directory that contains the fMRI files.
        task: Task name substituted into the module-level filename template
            ``fmri_task_split`` (the script uses 'reading' or 'listening').
        split: Split name substituted into the same template.
            NOTE(review): the previous code referenced an undefined ``split``
            variable; 'train' is an assumed default — confirm.

    Returns:
        A ``DataLoader`` over ``MyDataset`` using the module-level
        ``batch_size`` and ``shuffle=True``.
    """
    # NOTE(review): this path is relative to the working directory and does
    # not use data_dir — confirm that is intended.
    df = pd.read_hdf("data/df_text.hdf")  # Load the DataFrame

    audio_data = []
    voxel_data = []
    for _, row in df.iterrows():
        story_name = row['story_name']
        aligned_audio_file = row['aligned_audio_file']

        # Audio input processing.
        # NOTE(review): the column value is handed to the processor unchanged;
        # if it holds a file path, the waveform probably has to be loaded
        # first — confirm what the column contains.
        audio_input = audio_processor(aligned_audio_file, return_tensors='pt')

        # Resolve the fMRI file from the template instead of eval()-ing a
        # constructed variable name. The subject is still passed positionally
        # (as before), so templates with a `{}` placeholder keep working;
        # str.format ignores it when the template has no such placeholder.
        fmri_filename = fmri_task_split.format(
            row['subject'], task=task, split=split
        )
        fmri_file = os.path.join(data_dir, fmri_filename)
        with h5py.File(fmri_file, 'r') as f:
            # Read the dataset into memory before the file is closed, and
            # convert it to float32 so the targets match the dtype of the
            # model outputs in the MSE loss.
            target_voxel_data = torch.as_tensor(
                f[story_name][:], dtype=torch.float32
            )

        audio_data.append(audio_input)
        voxel_data.append(target_voxel_data)

    dataset = MyDataset(audio_data, voxel_data)
    # NOTE(review): default collation stacks items, so it presumably requires
    # all stories to share the same shapes — confirm or add a collate_fn.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ------- Model Definition ------- #
class M2BAM(nn.Module):
    """Audio-to-fMRI model: shared trunk plus one linear head per task.

    The audio encoder output is reduced to one vector per example, passed
    through a shared hidden layer, and then mapped by two parallel heads to
    the reading and listening voxel spaces.
    """

    def __init__(self):
        super().__init__()
        self.audio_model = WhisperModel.from_pretrained("openai/whisper-small")  # Or your chosen audio model

        # Shared hidden representation used by both task heads.
        self.multitask_layer = nn.Sequential(
            nn.Linear(audio_embedding_dim, 2048),
            nn.ReLU(),
        )
        # The two output layers must be parallel heads. Chaining them inside
        # one nn.Sequential (as before) fed the 81133-wide reading output into
        # a layer expecting 2048 inputs, which fails with a shape error.
        self.reading_head = nn.Linear(2048, num_voxels_reading)
        self.listening_head = nn.Linear(2048, num_voxels_listening)

    def forward(self, audio_input):  # Only audio_input now
        """Predict voxel responses for both tasks.

        Args:
            audio_input: Mapping of keyword arguments for the Whisper encoder
                (presumably the processor output containing
                ``input_features`` — confirm against the data pipeline).

        Returns:
            Tuple ``(reading_pred, listening_pred)`` with last dimensions
            ``num_voxels_reading`` and ``num_voxels_listening``.
        """
        # Use only the encoder: the full WhisperModel is an encoder-decoder
        # and its forward() requires decoder inputs, which are not available
        # (or needed) for feature extraction.
        encoder_outputs = self.audio_model.encoder(**audio_input)
        # NOTE(review): index 0 is simply the first encoder frame (Whisper has
        # no dedicated summary token); consider pooling over time — confirm.
        audio_embeddings = encoder_outputs.last_hidden_state[:, 0, :]

        shared = self.multitask_layer(audio_embeddings)
        reading_pred = self.reading_head(shared)
        listening_pred = self.listening_head(shared)
        return reading_pred, listening_pred


# ------- Training Loop ------- #
def train_model(dataloader, task, num_epochs):
    """Train a fresh M2BAM model with an MSE objective and return it.

    Args:
        dataloader: Iterable yielding ``(audio_batch, voxel_batch)`` pairs.
        task: 'reading' or 'listening'. Selects which output head is trained
            when ``voxel_batch`` holds a single task's voxels.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        The trained ``M2BAM`` instance (previously the model was discarded
        when the function returned).

    Raises:
        ValueError: If a batch holds single-task targets and ``task`` is
            neither 'reading' nor 'listening'.
    """
    model = M2BAM()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    combined_width = num_voxels_reading + num_voxels_listening
    # Position of the task's prediction in the model's output tuple.
    task_head = {'reading': 0, 'listening': 1}.get(task)

    for epoch in range(num_epochs):
        for i, (audio_batch, voxel_batch) in enumerate(dataloader):
            # Align the target dtype with the float32 model outputs.
            voxel_batch = voxel_batch.float()

            optimizer.zero_grad()
            outputs = model(audio_batch)
            reading_output, listening_output = outputs

            if voxel_batch.shape[1] == combined_width:
                # Targets hold both tasks side by side: reading voxels first,
                # listening voxels after. The listening slice starts where
                # the reading voxels end (the old code offset it by
                # num_voxels_listening, which is only right for equal sizes).
                loss_reading = criterion(
                    reading_output, voxel_batch[:, :num_voxels_reading]
                )
                loss_listening = criterion(
                    listening_output, voxel_batch[:, num_voxels_reading:]
                )
                loss = loss_reading + loss_listening
            else:
                # Single-task targets: train only the head selected by
                # `task`. Slicing these as if both tasks were present would
                # produce an empty listening target.
                if task_head is None:
                    raise ValueError(
                        f"Unknown task {task!r}; expected 'reading' or 'listening'"
                    )
                loss = criterion(outputs[task_head], voxel_batch)

            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                logging.info(
                    'Epoch: %s, Batch: %s, Loss: %s', epoch, i, loss.item()
                )

    return model

# ------- Main Execution ------- #
if __name__ == "__main__":
    # The root logger defaults to WARNING, which silently drops the
    # logging.info progress messages emitted during training; enable INFO.
    logging.basicConfig(level=logging.INFO)

    task = 'reading'  # or 'listening'
    dataloader = create_dataloader(data_dir, task)
    train_model(dataloader, task, num_epochs)
