In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, WhisperModel
import logging
from sklearn.model_selection import KFold
import pandas as pd

# ------ Hyperparameters ------- #
learning_rate = 1e-5  # Step size passed to the Adam optimizer
batch_size = 16       # Mini-batch size
num_epochs = 10       # Number of passes over the training split of each fold
# Weights for the text (alpha) and audio (beta) embeddings in the concatenation
alpha, beta = 1, 1

# ------- Model Definition ------- #
class M2BAM(nn.Module):
    """Joint text/audio model predicting reading and listening voxel responses.

    A BERT text encoder and the Whisper audio encoder each produce one
    embedding per example. The two embeddings are weighted, concatenated,
    passed through a shared hidden layer, and decoded by two independent
    linear heads (one per task).

    Relies on module-level names: ``num_voxels_reading`` and
    ``num_voxels_listening`` (output widths of the two heads) and ``alpha`` /
    ``beta`` (weights applied to the text / audio embeddings).
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.whisper = WhisperModel.from_pretrained("openai/whisper-small")

        # Take the embedding widths from the loaded checkpoints so the first
        # linear layer always matches the encoders that were actually loaded.
        text_embedding_dim = self.bert.config.hidden_size
        self.audio_embedding_dim = self.whisper.config.d_model

        # Modality weights applied right before concatenation.
        self.text_weight = alpha
        self.audio_weight = beta

        # Shared trunk. The two task heads are kept OUT of this Sequential:
        # chaining them would feed the reading head's output into the
        # listening head instead of producing two parallel predictions.
        self.multitask_layer = nn.Sequential(
            nn.Linear(text_embedding_dim + self.audio_embedding_dim, 512),
            nn.ReLU(),
        )
        self.reading_head = nn.Linear(512, num_voxels_reading)
        self.listening_head = nn.Linear(512, num_voxels_listening)

    def forward(self, text_input, audio_input):
        """Predict voxel responses for both tasks.

        Args:
            text_input: Mapping of keyword arguments for ``BertModel``
                (e.g. the output of a BERT tokenizer).
            audio_input: Mapping of keyword arguments for the Whisper encoder
                (e.g. the output of a Whisper processor), or ``None`` when the
                example has no audio.

        Returns:
            Tuple ``(reading_pred, listening_pred)`` with one row per example.
        """
        # BERT: the embedding at position 0 is the [CLS] token.
        text_embeddings = self.bert(**text_input).last_hidden_state[:, 0, :]

        if audio_input is None:
            # Missing modality: stand in with zeros so the concatenated
            # feature vector keeps its expected width.
            audio_embeddings = text_embeddings.new_zeros(
                (text_embeddings.size(0), self.audio_embedding_dim)
            )
        else:
            # Only the encoder is needed. Calling the full encoder-decoder
            # WhisperModel without decoder inputs raises an error.
            encoder_states = self.whisper.encoder(**audio_input).last_hidden_state
            # Whisper has no [CLS]-like token, so pool over the time axis
            # instead of picking the first audio frame.
            audio_embeddings = encoder_states.mean(dim=1)

        # Concatenate along the feature dimension.
        combined_embeddings = torch.cat(
            (
                self.text_weight * text_embeddings,
                self.audio_weight * audio_embeddings,
            ),
            dim=1,
        )
        shared = self.multitask_layer(combined_embeddings)

        return self.reading_head(shared), self.listening_head(shared)

# ------- Data Loading Function  ------- #
def load_data(data_dir, task):
    """Yield one ``(text_input, audio_input, target_voxel_data)`` triple per row.

    Reads the text table from ``data/df_text.hdf`` and, for every row,
    tokenizes the text, optionally processes the audio, and loads the fMRI
    target stored under the row's story name.

    Args:
        data_dir: Currently unused; the table path is hard-coded.
            NOTE(review): presumably the table should be read from this
            directory - confirm before changing the path.
        task: Currently unused; every row carries its own ``'task'`` value,
            which is what selects the audio branch and the fMRI file.

    Yields:
        ``text_input`` (tokenizer output), ``audio_input`` (processor output
        for ``'listening'`` rows, otherwise ``None``) and ``target_voxel_data``
        (the dataset read from the HDF5 file).

    Raises:
        NameError: If no module-level ``fmri_<task>_<split>`` path template
            exists for a row.

    Depends on names that are not defined in this function: ``tqdm``,
    ``h5py``, ``bert_tokenizer``, ``whisper_processor``, ``split`` and the
    ``fmri_<task>_<split>`` path templates.
    """
    df = pd.read_hdf("data/df_text.hdf")  # Load the DataFrame

    for _, row in tqdm(df.iterrows(), desc="Loading Data", total=df.shape[0]):
        story_name = row['story_name']
        text = row['text']
        # Separate name so the ``task`` parameter is not overwritten.
        row_task = row['task']

        # Text Input Processing
        text_input = bert_tokenizer(text, return_tensors='pt')

        if row_task == 'listening':
            # NOTE(review): the raw column value is handed to the processor;
            # Whisper processors normally expect a decoded waveform plus a
            # sampling rate rather than a file path - confirm.
            aligned_audio_file = row['aligned_audio_file']
            audio_input = whisper_processor(aligned_audio_file, return_tensors='pt')
        else:
            audio_input = None  # Placeholder for consistency

        # fMRI target data. The template is looked up by name instead of
        # eval(): the name is built from DataFrame content, and eval() would
        # execute whatever expression that content happens to contain.
        template_name = f"fmri_{row_task}_{split}"
        try:
            fmri_template = globals()[template_name]
        except KeyError:
            raise NameError(f"name '{template_name}' is not defined") from None

        fmri_file = fmri_template.format(row['subject'])
        with h5py.File(fmri_file, 'r') as f:
            target_voxel_data = f[story_name][:]  # Example, adjust as needed

        yield text_input, audio_input, target_voxel_data

# ------- Training Loop ------- #
def _example_loss(model, text_batch, audio_batch, voxel_batch):
    """Run one example through the model and return ``(loss, reading_output)``.

    The target is converted to a tensor matching the prediction's dtype and
    device, because the loader may hand back a NumPy array.
    """
    reading_output, listening_output = model(text_batch, audio_batch)
    target = torch.as_tensor(
        voxel_batch, dtype=reading_output.dtype, device=reading_output.device
    )

    # NOTE(review): this slicing assumes the target holds the reading voxels
    # followed by the listening voxels along dim 1 - confirm against the
    # layout of the fMRI files.
    loss_reading = criterion(reading_output, target[:, :num_voxels_reading])
    loss_listening = criterion(listening_output, target[:, num_voxels_reading:])
    return loss_reading + loss_listening, reading_output


criterion = nn.MSELoss()  # Example loss function

# Load all data into memory (KFold needs random access by index).
# NOTE(review): every item is a single example, so `batch_size` is not applied.
all_data = list(load_data(data_dir, task))

# Initialize KFold cross-validator
kf = KFold(n_splits=10)

# Loop over each fold
for fold, (train_indices, test_indices) in enumerate(kf.split(all_data)):
    # Split data into training and validation sets
    train_data = [all_data[i] for i in train_indices]
    test_data = [all_data[i] for i in test_indices]

    # Start every fold from a fresh model and optimizer. Reusing them would
    # let weights trained on a previous fold's training split - which contains
    # this fold's held-out examples - leak into the evaluation.
    model = M2BAM()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        # The evaluation below switches to eval mode; switch back so dropout
        # is active again for every training epoch, not just the first.
        model.train()
        for i, (text_batch, audio_batch, voxel_batch) in enumerate(train_data):
            optimizer.zero_grad()
            loss, _ = _example_loss(model, text_batch, audio_batch, voxel_batch)

            loss.backward()
            optimizer.step()

            # Logging the loss
            if i % 10 == 0:  # Log every 10 batches
                logging.info(f'Fold: {fold}, Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')

        # Save the model and optimizer state after each epoch
        torch.save({
            'fold': fold,
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, f'm2bam_model_fold_{fold}_epoch_{epoch}.pt')

        # Evaluate on the test set and collect the reading predictions in the
        # same pass, so the model runs only once per test example.
        model.eval()
        total_loss = 0
        predictions = []
        with torch.no_grad():  # Disable gradient calculation
            for text_batch, audio_batch, voxel_batch in test_data:
                loss, reading_output = _example_loss(
                    model, text_batch, audio_batch, voxel_batch
                )
                total_loss += loss.item()
                predictions.append(reading_output)

        # Calculate the average loss over the test set
        average_loss = total_loss / len(test_data)

        logging.info(f'Fold: {fold}, Test Loss: {average_loss}')

        # Convert predictions to a DataFrame
        df_predictions = pd.DataFrame(torch.cat(predictions).cpu().numpy())

        # Save to a csv file. The name only contains the fold, so each epoch
        # overwrites the previous epoch's predictions.
        # NOTE(review): the path is absolute ('/prediction/...') and the
        # directory must already exist - confirm this is intended.
        df_predictions.to_csv(f'/prediction/predictions_fold_{fold}.csv', index=False)