### Importing the required libraries

In [1]:
import logging
import time
import json # For reading TR
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn 
import pickle
import torch.optim as optim
import matplotlib.pyplot as plt
import logging
import os
import wandb
import random
import math 
from torch.utils.data import Dataset, TensorDataset, DataLoader 
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from typing import List

from transformers import AutoProcessor, AutoModel
from tqdm import tqdm

In [2]:
# Import project modules
# import config # Import base config
# from config import get_tr_from_json
from data_loader import  load_fmri_data, load_video_data
from evaluate import evaluate_model, plot_predictions
from models import EncodingModel, DecodingModel, VoxelWiseEncodingModel # Using MLP models
from models import LSTMEncodingModel, LSTMDecodingModel # Using LSTM models
from preprocessing import preprocess_fmri, preprocess_video_embeddings, align_data
from train import train_epoch # Import reverted train_epoch
from video_encoder import VideoFeatureExtractor

In [3]:
logger = logging.getLogger(__name__) # Get logger instance

### Setup directories for data

In [4]:
BASE_DIR = Path("/workspace/hardik/") # Adjusted path
DATA_DIR = BASE_DIR / "csai_data" # Adjusted path

OUTPUT_DIR = BASE_DIR / "output_csai"
CACHE_DIR = BASE_DIR / "cache_csai"

VIDEO_DIR = DATA_DIR / "stimuli"
VIDEO_FILE_TEMPLATE = "{stimulus_lower}.mp4"

OUTPUT_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

In [5]:
TRAIN_SUBJECT_IDS = ["NSD103","NSD104","NSD105","NSD106","NSD107","NSD113","NSD114", "NSD115","NSD116","NSD117","NSD119","NSD120","NSD122","NSD123","NSD124","NSD125","NSD126","NSD127","NSD128","NSD129","NSD130","NSD132","NSD134","NSD135","NSD136","NSD138","NSD140","NSD142","NSD145","NSD146","NSD147","NSD148","NSD149","NSD150","NSD151","NSD153","NSD155"] # List of subjects for training
TEST_SUBJECT_IDS = ["NSD108","NSD109","NSD110","NSD111"] # List of subjects for training
TEST_SUBJECT_ID = "NSD114"     # Single subject for testing
# STIMULI_NAMES = ["iteration", "defeat"]
STIMULI_NAMES = ["iteration"]
# STIMULI_NAMES = ["growth"]

In [6]:
# --- fMRI Parameters ---
PREPROC_DIR = DATA_DIR / "derivatives" / "preprocessed"

FMRI_VARIANT = "nocensor_srm-recon" 
FMRI_FILE_TEMPLATE = "sub-{subject_id}_task-{stimulus_lower}_{variant}.nii.gz"

# Raw data directory template (string for formatting later)
RAW_FUNC_DIR_PATH_TEMPLATE = DATA_DIR / "sub-{subject_id}" / "func"
# Raw JSON template (string)
RAW_JSON_FILENAME_TEMPLATE = "sub-{subject_id}_task-{stimulus_lower}_echo-1_bold.json"

In [7]:
TR = 1 # Will be read dynamically

TRS_DROP_MAP = {
    "growth": (2, 11),
    "lemonade": (2, 11),
    "iteration": (2, 11),
    "defeat": (2, 12),
}

# TRS_DROP_MAP = {
#     "iteration": (2, 11)
# }

### Choose the video encoder name

In [8]:
VIDEO_ENCODER_NAME = 'vitmae' # timesformer, videomae, vivit, vitmae
chosen_encoder = VIDEO_ENCODER_NAME

In [9]:
VIDEO_MODEL_IDENTIFIERS = {
    'timesformer': "facebook/timesformer-base-finetuned-k400",
    'videomae': "MCG-NJU/videomae-base-finetuned-kinetics",
    # 'xclip': "microsoft/xclip-base-patch32",
    'vivit' : "google/vivit-b-16x2",
    "vitmae" : "facebook/vit-mae-base"
}

VIDEO_MODEL_OUTPUT_DIMS = {
    'timesformer': 768,
    'videomae': 768,
    # 'xclip': 512,
    'vivit' : 768,
    'vitmae' :768 
}

In [10]:
# Placeholder - these get overwritten at runtime based on VIDEO_ENCODER_NAME
VIDEO_EMBEDDING_MODEL = VIDEO_MODEL_IDENTIFIERS[VIDEO_ENCODER_NAME]
DEC_OUTPUT_DIM = VIDEO_MODEL_OUTPUT_DIMS[VIDEO_ENCODER_NAME]        # Will be set from VIDEO_MODEL_OUTPUT_DIMS

# --- Video Processing Parameters  ---
VIDEO_CHUNK_SIZE = 8 # Frames per model input clip (TimeSformer/VideoMAE usually 8 or 16)
VIDEO_CHUNK_STRIDE = 4 # Overlap between chunks
VIDEO_EMBEDDING_BATCH_SIZE = 4 # Reduce if OOM during *embedding extraction*

# --- Preprocessing Parameters ---
ALIGN_METHOD = "convolve"
HRF_DELAY = 4.0

DO_FMRI_ZSCORE = True
DO_VIDEO_ZSCORE = True

In [11]:
# --- Model Parameters (MLP - Non-Temporal) ---
USE_TEMPORAL_MODELS = False # Ensure this is False to use MLP
DEVICE = "cuda:2" if torch.cuda.is_available() else "cpu"
# --- MLP Hidden Dims ---
ENC_HIDDEN_DIM = 512 # Hidden dim for MLP encoding model
DEC_HIDDEN_DIM = 2048 # Hidden dim for MLP decoding model (reduced from 2048)
# --- Input/Output Dims (set dynamically) ---
ENC_OUTPUT_DIM = -1   # Set dynamically (PCA components)
DEC_INPUT_DIM = -1    # Set dynamically (PCA components)

APPLY_PCA = True # Flag to enable/disable PCA
PCA_N_COMPONENTS = 1000 # Target number of fMRI features after PCA

# --- Training Parameters ---
LEARNING_RATE = 1e-5
BATCH_SIZE = 64
EPOCHS = 50

# --- Evaluation ---
EVAL_METRICS = ['mse', 'pearsonr', 'r2']

In [12]:
print(f"Selected Video Encoder: {chosen_encoder} (ID: {VIDEO_EMBEDDING_MODEL}, Dim: {DEC_OUTPUT_DIM})")

# --- Log Core Config ---
print(f"Train Subject(s): {TRAIN_SUBJECT_IDS}, Test Subject: {TEST_SUBJECT_ID}")
print(f"Device: {DEVICE}, fMRI Variant: {FMRI_VARIANT}")
print(f"Apply PCA: {APPLY_PCA}, PCA Components: {PCA_N_COMPONENTS if APPLY_PCA else 'N/A'}")
print(f"Using Temporal Models: {USE_TEMPORAL_MODELS}") # Should be False for MLP


Selected Video Encoder: vitmae (ID: facebook/vit-mae-base, Dim: 768)
Train Subject(s): ['NSD103', 'NSD104', 'NSD105', 'NSD106', 'NSD107', 'NSD113', 'NSD114', 'NSD115', 'NSD116', 'NSD117', 'NSD119', 'NSD120', 'NSD122', 'NSD123', 'NSD124', 'NSD125', 'NSD126', 'NSD127', 'NSD128', 'NSD129', 'NSD130', 'NSD132', 'NSD134', 'NSD135', 'NSD136', 'NSD138', 'NSD140', 'NSD142', 'NSD145', 'NSD146', 'NSD147', 'NSD148', 'NSD149', 'NSD150', 'NSD151', 'NSD153', 'NSD155'], Test Subject: NSD114
Device: cuda:2, fMRI Variant: nocensor_srm-recon
Apply PCA: True, PCA Components: 1000
Using Temporal Models: False


In [13]:
fmri_tr = 1

### Setup class to provide video embeddings

In [14]:
video_extractor = VideoFeatureExtractor(
    model_id=VIDEO_EMBEDDING_MODEL,
    fps=23.98,
    device=DEVICE
)

In [15]:
def save_data(data, filename: Path):
    logger.info(f"Saving data to {filename}")
    filename.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_data(filename: Path):
    logger.info(f"Loading data from {filename}")
    if not filename.exists():
        logger.error(f"Cache file not found: {filename}")
        return None
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except Exception as e:
        logger.error(f"Error loading cache file {filename}: {e}", exc_info=True)
        return None

In [16]:
def process_subject_data(subject_id, fmri_tr, chosen_encoder):
    """Loads, preprocesses, and aligns data for a single subject."""
    
    print(f"--- Processing data for Subject: {subject_id} ---")
    subject_fmri_aligned = []
    subject_video_aligned = []
    
    # find the model output dimensions from the dictionary defined above
    expected_vid_dim = VIDEO_MODEL_OUTPUT_DIMS[chosen_encoder]
    safe_encoder_name = chosen_encoder.replace('/','_') # Safe name for filenames

    video_embeddings_cache = {} # Cache loaded video embeddings per stimulus

    # Include PCA components in cache filename
    pca_suffix = f"_pca{PCA_N_COMPONENTS}" if APPLY_PCA else ""

    for stimulus_lower in STIMULI_NAMES:
        print(f"Processing {subject_id} / {stimulus_lower}...")
        # --- Define file paths ---
        fmri_filename = FMRI_FILE_TEMPLATE.format(
            subject_id=subject_id,
            stimulus_lower=stimulus_lower,
            variant=FMRI_VARIANT
        )
        fmri_path = PREPROC_DIR / f"sub-{subject_id}" / fmri_filename
        video_filename = VIDEO_FILE_TEMPLATE.format(stimulus_lower=stimulus_lower)
        video_path = VIDEO_DIR / video_filename

        # --- Cache Keys ---
        aligned_cache_key = f"sub-{subject_id}_{stimulus_lower}_{FMRI_VARIANT}_{safe_encoder_name}{pca_suffix}_aligned.pkl"
        aligned_cache_file = CACHE_DIR / aligned_cache_key
        
        # Video embeddings are stimulus-specific, cache key doesn't need subject_id
        embedding_cache_key = f"{stimulus_lower}_{safe_encoder_name}_embeddings.npy"
        embedding_cache_file = CACHE_DIR / embedding_cache_key

        # --- Check cache for ALIGNED data first ---
        if aligned_cache_file.exists():
            cached_data = load_data(aligned_cache_file)
            if cached_data:
                fmri_aligned, video_aligned = cached_data
                if fmri_aligned.ndim == 2 and video_aligned.ndim == 2 and \
                   fmri_aligned.shape[0] == video_aligned.shape[0] and \
                   video_aligned.shape[1] == expected_vid_dim:
                    subject_fmri_aligned.append(fmri_aligned)
                    subject_video_aligned.append(video_aligned)
                    print(f"Loaded cached aligned data for {subject_id}/{stimulus_lower}")
                    continue # Go to next stimulus if aligned cache is valid
                else:
                    print(f"-> Invalid cached aligned data shape/dim for {subject_id}/{stimulus_lower}. Recomputing.")


        # --- Load Raw fMRI Data ---
        try:
            start_load = time.time()
            print(f"Loading fMRI from: {fmri_path}")
            fmri_data = load_fmri_data(fmri_path, stimulus_lower, fmri_tr, FMRI_VARIANT, TRS_DROP_MAP)
            if fmri_data is None or fmri_data.size == 0:
                 print(f"-> fMRI data loading failed or resulted in empty array for {subject_id}/{stimulus_lower}.")
                 continue # Skip this stimulus
            print(f"fMRI loaded in {time.time() - start_load:.2f}s. Shape: {fmri_data.shape}")
        except FileNotFoundError:
            print(f"-> fMRI file not found for {subject_id}/{stimulus_lower}: {fmri_path}. Skipping stimulus.")
            continue


        # --- Load/Cache Video Embeddings (Stimulus Specific) ---
        video_embeddings = None
        if stimulus_lower in video_embeddings_cache:
            video_embeddings = video_embeddings_cache[stimulus_lower]
            print(f"Using pre-loaded video embeddings for {stimulus_lower}.")
        elif embedding_cache_file.exists():
            print(f"Loading cached {chosen_encoder} video embeddings from: {embedding_cache_file}")
            video_embeddings = np.load(embedding_cache_file)
            if video_embeddings.ndim != 2 or video_embeddings.shape[1] != expected_vid_dim:
                print(f"-> Cached embeddings {embedding_cache_file} have wrong shape/dim ({video_embeddings.shape}). Re-extracting.")
                embedding_cache_file.unlink()
                video_embeddings = None
            else:
                video_embeddings_cache[stimulus_lower] = video_embeddings # Store in memory cache
        else:
            video_embeddings = None # Needs extraction

        # --- Extract Video Embeddings if needed ---
        video_frames, video_fps_load_check = load_video_data(video_path) # Get FPS here too
        video_extractor = VideoFeatureExtractor(
        model_id=VIDEO_EMBEDDING_MODEL,
        fps=video_fps_load_check,
        device=DEVICE
        )
        if video_embeddings is None:
            try:
                # Load video frames (only if extraction is needed)
                print(f"Loading video frames from: {video_path}")
                start_vid_load = time.time()
                if not video_frames:
                     logger.error(f"Failed to load video frames for {stimulus_lower}. Skipping.")
                     continue
                print(f"Video frames loaded in {time.time() - start_vid_load:.2f}s.")

                print(f"Extracting {chosen_encoder} video embeddings for {stimulus_lower}...")
                start_embed = time.time()
                video_embeddings = video_extractor.extract_features(video_frames, batch_size=VIDEO_EMBEDDING_BATCH_SIZE)
                print(f"Video embedding shape : {video_embeddings.shape}")
                print(f"Video embedding extraction took {time.time() - start_embed:.2f}s")

                if video_embeddings is None or video_embeddings.size == 0: raise ValueError("Extractor returned empty embeddings")

                if video_embeddings.shape[1] != expected_vid_dim:
                     logger.error(f"Extracted {chosen_encoder} dim {video_embeddings.shape[1]} != expected {expected_vid_dim}.")
                     continue # Skip if extraction yields wrong dimension

                np.save(embedding_cache_file, video_embeddings)
                print(f"Saved {chosen_encoder} video embeddings to: {embedding_cache_file}")
                video_embeddings_cache[stimulus_lower] = video_embeddings # Store in memory cache
                # Need video_fps for alignment, get it from the load_video_data call
                video_fps = video_fps_load_check
                num_video_frames_original = len(video_frames)

            except Exception as e:
                logger.error(f"Video embedding extraction failed for {stimulus_lower} / {chosen_encoder}: {e}", exc_info=True)
                continue

        else:
             # If embeddings loaded from cache, we still need FPS and frame count for alignment
             # Re-load video just to get metadata if not already loaded 
             
             try:
                 # Quick load just for metadata
                 _ , video_fps_check = load_video_data(video_path)
                 # Assuming frame count isn't strictly needed if embeddings exist, but FPS is
                 video_fps = video_fps_check

                 # [TODO]
                 # We don't have exact num_frames_original if only cache was loaded,
                 # but maybe it's not strictly needed by align_data if using chunk times? Check align_data.
                 # For safety, let's pass 0 or estimate if needed by align_data.
                 # Revisit align_data's use of num_video_frames_original. If it's just for logging, it's ok.

                 num_video_frames_original = 0 # Placeholder if only cache loaded
                 
             except Exception as e:
                  logger.error(f"Could not load video metadata for {video_path} even though embeddings exist: {e}. Skipping alignment.")
                  continue


        # --- Preprocess & Align ---
        try:
            start_preprocess = time.time()
            fmri_processed = preprocess_fmri(DO_FMRI_ZSCORE, APPLY_PCA, PCA_N_COMPONENTS,  fmri_data)
            video_embeddings_processed = preprocess_video_embeddings(DO_VIDEO_ZSCORE, video_embeddings)
            fmri_aligned, video_aligned = align_data(
                fmri_processed,
                video_embeddings_processed,
                fmri_tr=fmri_tr,
                video_fps=video_fps, # Use FPS obtained above
                num_video_frames_original=num_video_frames_original, # Use value obtained above
                hrf_delay=HRF_DELAY,
                align_method=ALIGN_METHOD,
                video_chunk_size=video_extractor.num_frames_per_clip,
                video_chunk_stride=VIDEO_CHUNK_STRIDE
            )
            print(f"Preprocessing and alignment took {time.time() - start_preprocess:.2f}s")
            if fmri_aligned.shape[0] == 0 or video_aligned.shape[0] == 0:
                print(f"-> Alignment resulted in empty data for {subject_id}/{stimulus_lower}. Skipping.")
                continue

            save_data((fmri_aligned, video_aligned), aligned_cache_file) # Save to subject specific cache
            subject_fmri_aligned.append(fmri_aligned)
            subject_video_aligned.append(video_aligned)

        except Exception as e:
            logger.error(f"Error during preprocess/align for {subject_id}/{stimulus_lower}: {e}", exc_info=True)
            continue
        # --- End Stimulus Loop ---

    # --- Concatenate data for *this* subject ---
    if not subject_fmri_aligned:
        print(f"-> No data successfully processed for subject {subject_id}.")
        return None, None # Return None if no data for this subject

    final_fmri = np.concatenate(subject_fmri_aligned, axis=0)
    final_video = np.concatenate(subject_video_aligned, axis=0)
    print(f"--- Final Concatenated Data Shapes for Subject {subject_id} ---")
    print(f"Subject fMRI data: {final_fmri.shape}")
    print(f"Subject Video embeddings: {final_video.shape}")
    return final_fmri, final_video

In [17]:
all_train_fmri_raw = []
all_train_video = []
for train_subj_id in TRAIN_SUBJECT_IDS:
    fmri_subj_raw, video_subj = process_subject_data(train_subj_id, fmri_tr, chosen_encoder)
    
    if fmri_subj_raw is not None and video_subj is not None:
        all_train_fmri_raw.append(fmri_subj_raw)
        all_train_video.append(video_subj)
    else:
        logger.warning(f"Could not process data for training subject: {train_subj_id}")

--- Processing data for Subject: NSD103 ---
Processing NSD103 / iteration...
Loaded cached aligned data for NSD103/iteration
--- Final Concatenated Data Shapes for Subject NSD103 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD104 ---
Processing NSD104 / iteration...
Loaded cached aligned data for NSD104/iteration
--- Final Concatenated Data Shapes for Subject NSD104 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD105 ---
Processing NSD105 / iteration...
Loaded cached aligned data for NSD105/iteration
--- Final Concatenated Data Shapes for Subject NSD105 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD106 ---
Processing NSD106 / iteration...
Loaded cached aligned data for NSD106/iteration
--- Final Concatenated Data Shapes for Subject NSD106 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 7

In [18]:
final_train_fmri = np.concatenate(all_train_fmri_raw, axis=0)
final_train_video = np.concatenate(all_train_video, axis=0)
print(f"--- Combined Training Data Shapes (Before PCA) ---")
print(f"Train fMRI (Raw/SRM): {final_train_fmri.shape}")
print(f"Train Video ({chosen_encoder}): {final_train_video.shape}")


--- Combined Training Data Shapes (Before PCA) ---
Train fMRI (Raw/SRM): (27454, 1000)
Train Video (vitmae): (27454, 768)


### --- Prepare Training DataLoader ---

In [19]:
try:
    train_dataset = TensorDataset(torch.from_numpy(final_train_fmri).float(),
                                    torch.from_numpy(final_train_video).float())
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    print(f"Train DataLoader ready (Size: {len(train_loader)} samples).")
except Exception as e:
    print(f"Failed to create training DataLoader: {e}", exc_info=True)



Train DataLoader ready (Size: 429 samples).


### Processing the Test Data

In [20]:
# --- Process Test Data ---
final_test_fmri = None
video_test = None
test_loader = None
test_dataset = None

# if TEST_SUBJECT_ID:
#     print(f"--- Processing Test Data for Subject: {TEST_SUBJECT_ID} ---")
#     final_test_fmri, video_test = process_subject_data(TEST_SUBJECT_ID, fmri_tr, chosen_encoder)
#     print(final_test_fmri.shape, video_test.shape)

#     if final_test_fmri is None or video_test is None:
#         print(f"No test data could be processed for subject {TEST_SUBJECT_ID}. Cannot evaluate.")
#         # Decide whether to proceed with training only or exit
#     else:
#         print(f"--- Final Test Data Shapes (After PCA if Applied) ---")
#         print(f"Test fMRI data: {final_test_fmri.shape}")
#         print(f"Test Video embeddings: {video_test.shape}")

#         # --- Prepare Test DataLoader ---
#         try:
#             test_dataset = TensorDataset(torch.from_numpy(final_test_fmri).float(),
#                                             torch.from_numpy(video_test).float())
#             test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
#             print(f"Test Dataset ready (Size: {len(test_dataset)} samples).")
#             print(f"Test DataLoader ready (Size: {len(test_loader)} samples).")
#         except Exception as e:
#                 logger.error(f"Failed to create test DataLoader: {e}", exc_info=True)
#                 test_loader = None # Mark as None if failed
# else:
#         print("No TEST_SUBJECT_ID specified. Skipping test phase.")
#         test_loader = None


In [21]:
all_test_fmri_raw = []
all_test_video = []
for test_subj_id in TEST_SUBJECT_IDS:
    fmri_subj_raw, video_subj = process_subject_data(test_subj_id, fmri_tr, chosen_encoder)
    
    if fmri_subj_raw is not None and video_subj is not None:
        all_test_fmri_raw.append(fmri_subj_raw)
        all_test_video.append(video_subj)
    else:
        logger.warning(f"Could not process data for testing subject: {test_subj_id}")


--- Processing data for Subject: NSD108 ---
Processing NSD108 / iteration...
Loaded cached aligned data for NSD108/iteration
--- Final Concatenated Data Shapes for Subject NSD108 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD109 ---
Processing NSD109 / iteration...
Loaded cached aligned data for NSD109/iteration
--- Final Concatenated Data Shapes for Subject NSD109 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD110 ---
Processing NSD110 / iteration...
Loaded cached aligned data for NSD110/iteration
--- Final Concatenated Data Shapes for Subject NSD110 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 768)
--- Processing data for Subject: NSD111 ---
Processing NSD111 / iteration...
Loaded cached aligned data for NSD111/iteration
--- Final Concatenated Data Shapes for Subject NSD111 ---
Subject fMRI data: (742, 1000)
Subject Video embeddings: (742, 7

In [22]:
final_test_fmri = np.concatenate(all_test_fmri_raw, axis=0)
final_test_video = np.concatenate(all_test_video, axis=0)
print(f"--- Combined Testing Data Shapes (Before PCA) ---")
print(f"Test fMRI (Raw/SRM): {final_test_fmri.shape}")
print(f"Test Video ({chosen_encoder}): {final_test_video.shape}")

--- Combined Testing Data Shapes (Before PCA) ---
Test fMRI (Raw/SRM): (2968, 1000)
Test Video (vitmae): (2968, 768)


In [23]:
try:
    test_dataset = TensorDataset(torch.from_numpy(final_test_fmri).float(),
                                    torch.from_numpy(final_test_video).float())
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    print(f"Test DataLoader ready (Size: {len(test_loader)} samples).")
except Exception as e:
    print(f"Failed to create test DataLoader: {e}", exc_info=True)



Test DataLoader ready (Size: 47 samples).


In [24]:
# --- Update config dims (Based on final *training* data) ---
ENC_OUTPUT_DIM = final_train_fmri.shape[1]
DEC_INPUT_DIM = final_train_fmri.shape[1]

# DEC_OUTPUT_DIM is already set based on video encoder choice

In [25]:
encoding_model = EncodingModel(
    video_embed_dim=DEC_OUTPUT_DIM,
    fmri_dim=ENC_OUTPUT_DIM,
    hidden_dim=ENC_HIDDEN_DIM
)
decoding_model = DecodingModel(
    fmri_dim=DEC_INPUT_DIM,
    video_embed_dim=DEC_OUTPUT_DIM,
    hidden_dim=DEC_HIDDEN_DIM
)

In [26]:
def run_training_no_val(model, train_loader, device, task='encoding', epochs=EPOCHS
                        , learning_rate=LEARNING_RATE):
    """ Simplified training loop without validation. """

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    model_type = "LSTM" if USE_TEMPORAL_MODELS else "VoxelWiseMLP"
    
    if model_type == "LSTM" and task == 'encoding':
        model =  LSTMEncodingModel(
            video_embed_dim=DEC_OUTPUT_DIM,
            fmri_dim=ENC_OUTPUT_DIM,
            hidden_dim=ENC_HIDDEN_DIM
        )
    elif model_type == "LSTM" and task == 'decoding':
        model = LSTMDecodingModel(
            fmri_dim=DEC_INPUT_DIM,
            video_embed_dim=DEC_OUTPUT_DIM,
            hidden_dim=DEC_HIDDEN_DIM
        )
    elif model_type == "VoxelWiseMLP" and task == 'encoding':
        model = VoxelWiseEncodingModel(
            video_embed_dim=DEC_OUTPUT_DIM,
            fmri_dim=ENC_OUTPUT_DIM,
            hidden_dim=ENC_HIDDEN_DIM
        )
    else:
        model = DecodingModel(
            fmri_dim=DEC_INPUT_DIM,
            video_embed_dim=DEC_OUTPUT_DIM,
            hidden_dim=DEC_HIDDEN_DIM
        )
    

    print(f"Starting {task} model training ({model_type}, no validation) for {epochs} epochs on {device}")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    stimuli_str = "_".join(STIMULI_NAMES)

    run = f"{task}_{model_type}_{VIDEO_ENCODER_NAME}_{PCA_N_COMPONENTS if APPLY_PCA else 'nopca'}_{stimuli_str}"
    config = {
        "task": task,
        "model_type": model_type,
        "video_encoder": VIDEO_ENCODER_NAME,
        "pca_components": PCA_N_COMPONENTS if APPLY_PCA else "nopca",
        "stimuli": stimuli_str,
        "learning_rate": learning_rate,
        "batch_size": BATCH_SIZE,
        "epochs": epochs,
        "encoder_hidden_dim": ENC_HIDDEN_DIM,
        "decoder_hidden_dim": DEC_HIDDEN_DIM,
        "encoder_output_dim": ENC_OUTPUT_DIM,
        "decoder_input_dim": DEC_INPUT_DIM,
        "decoder_output_dim": DEC_OUTPUT_DIM,
        "train_subjects": TRAIN_SUBJECT_IDS,
        "test_subject": TEST_SUBJECT_ID,
    }
    wandb.init(project="csai_fmri_video_project", name=run, config=config, reinit=True) # Initialize wandb

    for epoch in range(epochs):
        # Pass use_temporal flag if train_epoch still expects it, otherwise remove
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, task)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f}")
        # No validation step, no saving best model based on validation
        print("> Evaluation on train")
        results, _, _ = evaluate_model(model, train_loader, nn.MSELoss(), device=DEVICE, task=task)
               
        print("\n> Evaluation on test")
        test_results, _, _ = evaluate_model(model, test_loader, nn.MSELoss(), device=DEVICE, task=task)
        wandb.log({
                "MSE": results['mse'], 
                "PearsonR Avg": results['pearsonr_avg'],
                "PearsonR All": np.median(results['pearsonr_all']),
                "R2 Uniform": results['r2_uniform'],
                "R2 Variance": results['r2_variance'],
                "Test MSE": test_results['mse'],
                "Test PearsonR Avg": test_results['pearsonr_avg'],
                "Test PearsonR All": np.median(test_results['pearsonr_all']),
                "Test R2 Uniform": test_results['r2_uniform'],
                "Test R2 Variance": test_results['r2_variance']
            })

    print(f"Training complete after {epochs} epochs. Using model from last epoch.")
    wandb.finish() # Finish wandb run
    return model # Return model from last epoch


In [27]:
encoding_model = run_training_no_val( # Use modified training loop
    encoding_model, train_loader, device=DEVICE, task='encoding',
    epochs=EPOCHS, learning_rate=LEARNING_RATE
)

Starting encoding model training (VoxelWiseMLP, no validation) for 50 epochs on cuda:2


[34m[1mwandb[0m: Currently logged in as: [33mmhardik003[0m ([33mwb-team-hardik[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


                                                                                  

Epoch 1/50 | Train Loss: 0.9997
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9885
  Avg Pearson Correlation (across features/voxels): 0.1055
  Median Pearson Correlation: 0.1049
  R2 Score (uniform avg): 0.0099
  R2 Score (variance weighted): 0.0099

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0007
  Avg Pearson Correlation (across features/voxels): 0.0263
  Median Pearson Correlation: 0.0254
  R2 Score (uniform avg): -0.0015
  R2 Score (variance weighted): -0.0015


                                                                                  

Epoch 2/50 | Train Loss: 0.9904
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9821
  Avg Pearson Correlation (across features/voxels): 0.1369
  Median Pearson Correlation: 0.1362
  R2 Score (uniform avg): 0.0163
  R2 Score (variance weighted): 0.0163

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0011
  Avg Pearson Correlation (across features/voxels): 0.0350
  Median Pearson Correlation: 0.0340
  R2 Score (uniform avg): -0.0019
  R2 Score (variance weighted): -0.0019


                                                                                  

Epoch 3/50 | Train Loss: 0.9854
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9778
  Avg Pearson Correlation (across features/voxels): 0.1528
  Median Pearson Correlation: 0.1522
  R2 Score (uniform avg): 0.0207
  R2 Score (variance weighted): 0.0207

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0016
  Avg Pearson Correlation (across features/voxels): 0.0404
  Median Pearson Correlation: 0.0402
  R2 Score (uniform avg): -0.0023
  R2 Score (variance weighted): -0.0023


                                                                                  

Epoch 4/50 | Train Loss: 0.9819
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9746
  Avg Pearson Correlation (across features/voxels): 0.1630
  Median Pearson Correlation: 0.1626
  R2 Score (uniform avg): 0.0238
  R2 Score (variance weighted): 0.0238

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0023
  Avg Pearson Correlation (across features/voxels): 0.0439
  Median Pearson Correlation: 0.0436
  R2 Score (uniform avg): -0.0031
  R2 Score (variance weighted): -0.0031


                                                                                  

Epoch 5/50 | Train Loss: 0.9791
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9720
  Avg Pearson Correlation (across features/voxels): 0.1696
  Median Pearson Correlation: 0.1691
  R2 Score (uniform avg): 0.0264
  R2 Score (variance weighted): 0.0264

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0030
  Avg Pearson Correlation (across features/voxels): 0.0464
  Median Pearson Correlation: 0.0468
  R2 Score (uniform avg): -0.0038
  R2 Score (variance weighted): -0.0038


                                                                                  

Epoch 6/50 | Train Loss: 0.9770
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9700
  Avg Pearson Correlation (across features/voxels): 0.1746
  Median Pearson Correlation: 0.1739
  R2 Score (uniform avg): 0.0284
  R2 Score (variance weighted): 0.0284

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0036
  Avg Pearson Correlation (across features/voxels): 0.0488
  Median Pearson Correlation: 0.0488
  R2 Score (uniform avg): -0.0044
  R2 Score (variance weighted): -0.0044


                                                                                  

Epoch 7/50 | Train Loss: 0.9753
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9684
  Avg Pearson Correlation (across features/voxels): 0.1788
  Median Pearson Correlation: 0.1781
  R2 Score (uniform avg): 0.0301
  R2 Score (variance weighted): 0.0301

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0041
  Avg Pearson Correlation (across features/voxels): 0.0505
  Median Pearson Correlation: 0.0513
  R2 Score (uniform avg): -0.0049
  R2 Score (variance weighted): -0.0049


                                                                                  

Epoch 8/50 | Train Loss: 0.9738
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9670
  Avg Pearson Correlation (across features/voxels): 0.1820
  Median Pearson Correlation: 0.1813
  R2 Score (uniform avg): 0.0315
  R2 Score (variance weighted): 0.0315

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0049
  Avg Pearson Correlation (across features/voxels): 0.0516
  Median Pearson Correlation: 0.0525
  R2 Score (uniform avg): -0.0057
  R2 Score (variance weighted): -0.0057


                                                                                  

Epoch 9/50 | Train Loss: 0.9726
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9659
  Avg Pearson Correlation (across features/voxels): 0.1842
  Median Pearson Correlation: 0.1837
  R2 Score (uniform avg): 0.0325
  R2 Score (variance weighted): 0.0325

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0057
  Avg Pearson Correlation (across features/voxels): 0.0522
  Median Pearson Correlation: 0.0526
  R2 Score (uniform avg): -0.0065
  R2 Score (variance weighted): -0.0065


                                                                                  

Epoch 10/50 | Train Loss: 0.9717
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9649
  Avg Pearson Correlation (across features/voxels): 0.1866
  Median Pearson Correlation: 0.1859
  R2 Score (uniform avg): 0.0336
  R2 Score (variance weighted): 0.0336

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0062
  Avg Pearson Correlation (across features/voxels): 0.0533
  Median Pearson Correlation: 0.0527
  R2 Score (uniform avg): -0.0070
  R2 Score (variance weighted): -0.0070


                                                                                  

Epoch 11/50 | Train Loss: 0.9707
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9641
  Avg Pearson Correlation (across features/voxels): 0.1882
  Median Pearson Correlation: 0.1874
  R2 Score (uniform avg): 0.0343
  R2 Score (variance weighted): 0.0343

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0071
  Avg Pearson Correlation (across features/voxels): 0.0531
  Median Pearson Correlation: 0.0523
  R2 Score (uniform avg): -0.0079
  R2 Score (variance weighted): -0.0079


                                                                                  

Epoch 12/50 | Train Loss: 0.9701
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9634
  Avg Pearson Correlation (across features/voxels): 0.1894
  Median Pearson Correlation: 0.1888
  R2 Score (uniform avg): 0.0350
  R2 Score (variance weighted): 0.0350

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0077
  Avg Pearson Correlation (across features/voxels): 0.0543
  Median Pearson Correlation: 0.0539
  R2 Score (uniform avg): -0.0085
  R2 Score (variance weighted): -0.0084


                                                                                  

Epoch 13/50 | Train Loss: 0.9694
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9630
  Avg Pearson Correlation (across features/voxels): 0.1904
  Median Pearson Correlation: 0.1896
  R2 Score (uniform avg): 0.0355
  R2 Score (variance weighted): 0.0355

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0082
  Avg Pearson Correlation (across features/voxels): 0.0550
  Median Pearson Correlation: 0.0556
  R2 Score (uniform avg): -0.0090
  R2 Score (variance weighted): -0.0090


                                                                                  

Epoch 14/50 | Train Loss: 0.9689
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9624
  Avg Pearson Correlation (across features/voxels): 0.1916
  Median Pearson Correlation: 0.1908
  R2 Score (uniform avg): 0.0360
  R2 Score (variance weighted): 0.0360

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0088
  Avg Pearson Correlation (across features/voxels): 0.0548
  Median Pearson Correlation: 0.0543
  R2 Score (uniform avg): -0.0096
  R2 Score (variance weighted): -0.0096


                                                                                  

Epoch 15/50 | Train Loss: 0.9684
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9620
  Avg Pearson Correlation (across features/voxels): 0.1925
  Median Pearson Correlation: 0.1919
  R2 Score (uniform avg): 0.0365
  R2 Score (variance weighted): 0.0365

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0092
  Avg Pearson Correlation (across features/voxels): 0.0551
  Median Pearson Correlation: 0.0549
  R2 Score (uniform avg): -0.0100
  R2 Score (variance weighted): -0.0100


                                                                                  

Epoch 16/50 | Train Loss: 0.9680
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9617
  Avg Pearson Correlation (across features/voxels): 0.1931
  Median Pearson Correlation: 0.1924
  R2 Score (uniform avg): 0.0368
  R2 Score (variance weighted): 0.0368

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0096
  Avg Pearson Correlation (across features/voxels): 0.0558
  Median Pearson Correlation: 0.0567
  R2 Score (uniform avg): -0.0104
  R2 Score (variance weighted): -0.0104


                                                                                  

Epoch 17/50 | Train Loss: 0.9677
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9613
  Avg Pearson Correlation (across features/voxels): 0.1938
  Median Pearson Correlation: 0.1931
  R2 Score (uniform avg): 0.0371
  R2 Score (variance weighted): 0.0371

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0098
  Avg Pearson Correlation (across features/voxels): 0.0562
  Median Pearson Correlation: 0.0557
  R2 Score (uniform avg): -0.0106
  R2 Score (variance weighted): -0.0106


                                                                                  

Epoch 18/50 | Train Loss: 0.9673
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9611
  Avg Pearson Correlation (across features/voxels): 0.1943
  Median Pearson Correlation: 0.1936
  R2 Score (uniform avg): 0.0374
  R2 Score (variance weighted): 0.0374

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0104
  Avg Pearson Correlation (across features/voxels): 0.0561
  Median Pearson Correlation: 0.0567
  R2 Score (uniform avg): -0.0111
  R2 Score (variance weighted): -0.0111


                                                                                  

Epoch 19/50 | Train Loss: 0.9670
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9608
  Avg Pearson Correlation (across features/voxels): 0.1948
  Median Pearson Correlation: 0.1942
  R2 Score (uniform avg): 0.0376
  R2 Score (variance weighted): 0.0376

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0108
  Avg Pearson Correlation (across features/voxels): 0.0557
  Median Pearson Correlation: 0.0554
  R2 Score (uniform avg): -0.0115
  R2 Score (variance weighted): -0.0115


                                                                                  

Epoch 20/50 | Train Loss: 0.9668
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9606
  Avg Pearson Correlation (across features/voxels): 0.1953
  Median Pearson Correlation: 0.1946
  R2 Score (uniform avg): 0.0378
  R2 Score (variance weighted): 0.0378

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0111
  Avg Pearson Correlation (across features/voxels): 0.0561
  Median Pearson Correlation: 0.0555
  R2 Score (uniform avg): -0.0119
  R2 Score (variance weighted): -0.0119


                                                                                  

Epoch 21/50 | Train Loss: 0.9665
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9604
  Avg Pearson Correlation (across features/voxels): 0.1956
  Median Pearson Correlation: 0.1949
  R2 Score (uniform avg): 0.0380
  R2 Score (variance weighted): 0.0380

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0115
  Avg Pearson Correlation (across features/voxels): 0.0558
  Median Pearson Correlation: 0.0560
  R2 Score (uniform avg): -0.0123
  R2 Score (variance weighted): -0.0123


                                                                                  

Epoch 22/50 | Train Loss: 0.9664
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9603
  Avg Pearson Correlation (across features/voxels): 0.1960
  Median Pearson Correlation: 0.1953
  R2 Score (uniform avg): 0.0382
  R2 Score (variance weighted): 0.0382

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0115
  Avg Pearson Correlation (across features/voxels): 0.0567
  Median Pearson Correlation: 0.0562
  R2 Score (uniform avg): -0.0123
  R2 Score (variance weighted): -0.0123


                                                                                  

Epoch 23/50 | Train Loss: 0.9662
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9601
  Avg Pearson Correlation (across features/voxels): 0.1962
  Median Pearson Correlation: 0.1954
  R2 Score (uniform avg): 0.0383
  R2 Score (variance weighted): 0.0383

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0119
  Avg Pearson Correlation (across features/voxels): 0.0562
  Median Pearson Correlation: 0.0553
  R2 Score (uniform avg): -0.0127
  R2 Score (variance weighted): -0.0127


                                                                                  

Epoch 24/50 | Train Loss: 0.9661
> Evaluation on train
Evaluating encoding model...


                                                                        

Evaluation complete. Target shape: (27454, 1000), Prediction shape: (27454, 1000)
  MSE: 0.9600
  Avg Pearson Correlation (across features/voxels): 0.1965
  Median Pearson Correlation: 0.1957
  R2 Score (uniform avg): 0.0384
  R2 Score (variance weighted): 0.0384

> Evaluation on test
Evaluating encoding model...


                                                                      

Evaluation complete. Target shape: (2968, 1000), Prediction shape: (2968, 1000)
  MSE: 1.0121
  Avg Pearson Correlation (across features/voxels): 0.0566
  Median Pearson Correlation: 0.0560
  R2 Score (uniform avg): -0.0129
  R2 Score (variance weighted): -0.0129


Training (encoding):  95%|█████████▍| 406/429 [02:09<00:07,  3.18it/s, loss=0.947]

In [None]:
# if test_loader:
#     print(f"\n--- Evaluating Encoding Model on Subj {TEST_SUBJECT_ID} ---")
#     enc_results, enc_targets, enc_predictions = evaluate_model(
#             encoding_model, test_loader, nn.MSELoss(), device=DEVICE, task='encoding'
#     )
#     # --- Add encoder/subject info to output files ---
#     enc_model_suffix = f"_mlp_{chosen_encoder}_train{''.join(TRAIN_SUBJECT_IDS)}_test{TEST_SUBJECT_ID}"
#     enc_model_path = OUTPUT_DIR / f"encoding_model{enc_model_suffix}.pt"
#     torch.save(encoding_model.state_dict(), enc_model_path)
#     print(f"Saved Encoding model to {enc_model_path}")

#     enc_fig = plot_predictions(enc_targets, enc_predictions, n_samples=5, title=f"Encoding (MLP, {chosen_encoder}) Test Subj {TEST_SUBJECT_ID}")
#     enc_plot_path = OUTPUT_DIR / f"encoding_predictions{enc_model_suffix}.png"
#     enc_fig.savefig(enc_plot_path)
#     plt.close(enc_fig)
#     print(f"Saved Encoding prediction plot to {enc_plot_path}")
    
# else:
#         print("Skipping encoding model evaluation as no test loader was created.")


### --- Decoding Model Training & Evaluation ---


In [None]:
# print(f"\n--- Training Decoding Model (MLP, Video: {chosen_encoder}) on Subj {TRAIN_SUBJECT_IDS} ---")
# if DEVICE == 'cuda' or 'cuda:0' or 'cuda:2':
#     # Clear memory if possible
#     del encoding_model
#     if 'enc_targets' in locals(): del enc_targets
#     if 'enc_predictions' in locals(): del enc_predictions
#     torch.cuda.empty_cache()
#     print("Cleared CUDA cache before starting decoding model training.")

In [None]:
# decoding_model = run_training_no_val( # Use modified training loop
#     decoding_model, train_loader, device=DEVICE, task='decoding',
#     epochs=EPOCHS, learning_rate=LEARNING_RATE
# )

# if test_loader:
#     print(f"\n--- Evaluating Decoding Model on Subj {TEST_SUBJECT_ID} ---")
#     dec_results, dec_targets, dec_predictions = evaluate_model(
#         decoding_model, test_loader, nn.MSELoss(), device=DEVICE, task='decoding'
#     )
#     dec_model_suffix = f"_mlp_{chosen_encoder}_train{''.join(TRAIN_SUBJECT_IDS)}_test{TEST_SUBJECT_ID}"
#     dec_model_path = OUTPUT_DIR / f"decoding_model{dec_model_suffix}.pt"
#     # torch.save(decoding_model.state_dict(), dec_model_path)
#     print(f"Saved Decoding model to {dec_model_path}")

#     dec_fig = plot_predictions(dec_targets, dec_predictions, n_samples=5, title=f"Decoding (MLP, {chosen_encoder}) Test Subj {TEST_SUBJECT_ID}")
#     dec_plot_path = OUTPUT_DIR / f"decoding_predictions{dec_model_suffix}.png"
#     # dec_fig.savefig(dec_plot_path)
#     plt.close(dec_fig)
#     print(f"Saved Decoding prediction plot to {dec_plot_path}")
# else:
#         print("Skipping decoding model evaluation as no test loader was created.")
