### Importing the required libraries

In [1]:
import logging
import time
import json # For reading TR
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn 
import pickle
import torch.optim as optim
import matplotlib.pyplot as plt
import logging
import os
import wandb
import random
import math 
from torch.utils.data import Dataset, TensorDataset, DataLoader 
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from typing import List

from transformers import AutoProcessor, AutoModel
from tqdm import tqdm

In [2]:
# Import project modules
# import config # Import base config
# from config import get_tr_from_json
from data_loader import  load_fmri_data, load_video_data
from evaluate import evaluate_model, plot_predictions
from models import EncodingModel, DecodingModel # Using MLP models
from models import LSTMEncodingModel, LSTMDecodingModel # Using LSTM models
from preprocessing import preprocess_fmri, preprocess_video_embeddings, align_data
from train import train_epoch # Import reverted train_epoch
from video_encoder import VideoFeatureExtractor

In [3]:
logger = logging.getLogger(__name__) # Get logger instance

### Setup directories for data

In [4]:
BASE_DIR = Path("/workspace/hardik/") # Adjusted path
DATA_DIR = BASE_DIR / "csai_data" # Adjusted path

OUTPUT_DIR = BASE_DIR / "output_csai"
CACHE_DIR = BASE_DIR / "cache_csai"

VIDEO_DIR = DATA_DIR / "stimuli"
VIDEO_FILE_TEMPLATE = "{stimulus_lower}.mp4"

OUTPUT_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

In [5]:
TRAIN_SUBJECT_IDS = ["NSD103","NSD104","NSD105","NSD106","NSD107","NSD113","NSD114", "NSD115","NSD116","NSD117","NSD119","NSD120","NSD122","NSD123","NSD124","NSD125","NSD126","NSD127","NSD128","NSD129","NSD130","NSD132","NSD134","NSD135","NSD136","NSD138","NSD140","NSD142","NSD145","NSD146","NSD147","NSD148","NSD149","NSD150","NSD151","NSD153","NSD155"] # List of subjects for training
TEST_SUBJECT_IDS = ["NSD108","NSD109","NSD110","NSD111"] # List of subjects for training
TEST_SUBJECT_ID = "NSD114"     # Single subject for testing
# STIMULI_NAMES = ["iteration", "defeat"]
# STIMULI_NAMES = ["iteration", "defeat", "growth", "lemonade"]
STIMULI_NAMES = ["defeat"]

In [6]:
# --- fMRI Parameters ---
PREPROC_DIR = DATA_DIR / "derivatives" / "preprocessed"

FMRI_VARIANT = "nocensor_srm-recon" 
FMRI_FILE_TEMPLATE = "sub-{subject_id}_task-{stimulus_lower}_{variant}.nii.gz"

# Raw data directory template (string for formatting later)
RAW_FUNC_DIR_PATH_TEMPLATE = DATA_DIR / "sub-{subject_id}" / "func"
# Raw JSON template (string)
RAW_JSON_FILENAME_TEMPLATE = "sub-{subject_id}_task-{stimulus_lower}_echo-1_bold.json"

In [7]:
TR = 1 # Will be read dynamically

TRS_DROP_MAP = {
    "growth": (2, 11),
    "lemonade": (2, 11),
    "iteration": (2, 11),
    "defeat": (2, 12),
}

# TRS_DROP_MAP = {
#     "iteration": (2, 11)
# }

### Choose the video encoder name

In [8]:
VIDEO_ENCODER_NAME = 'vitmae' # timesformer, videomae, vivit, vitmae
chosen_encoder = VIDEO_ENCODER_NAME

In [9]:
VIDEO_MODEL_IDENTIFIERS = {
    'timesformer': "facebook/timesformer-base-finetuned-k400",
    'videomae': "MCG-NJU/videomae-base-finetuned-kinetics",
    # 'xclip': "microsoft/xclip-base-patch32",
    'vivit' : "google/vivit-b-16x2",
    "vitmae" : "facebook/vit-mae-base"
}

VIDEO_MODEL_OUTPUT_DIMS = {
    'timesformer': 768,
    'videomae': 768,
    # 'xclip': 512,
    'vivit' : 768,
    'vitmae' :768 
}

In [10]:
# Placeholder - these get overwritten at runtime based on VIDEO_ENCODER_NAME
VIDEO_EMBEDDING_MODEL = VIDEO_MODEL_IDENTIFIERS[VIDEO_ENCODER_NAME]
DEC_OUTPUT_DIM = VIDEO_MODEL_OUTPUT_DIMS[VIDEO_ENCODER_NAME]        # Will be set from VIDEO_MODEL_OUTPUT_DIMS

# --- Video Processing Parameters  ---
VIDEO_CHUNK_SIZE = 8 # Frames per model input clip (TimeSformer/VideoMAE usually 8 or 16)
VIDEO_CHUNK_STRIDE = 4 # Overlap between chunks
VIDEO_EMBEDDING_BATCH_SIZE = 4 # Reduce if OOM during *embedding extraction*

# --- Preprocessing Parameters ---
ALIGN_METHOD = "convolve"
HRF_DELAY = 4.0

DO_FMRI_ZSCORE = True
DO_VIDEO_ZSCORE = True

In [11]:
# --- Model Parameters (MLP - Non-Temporal) ---
USE_TEMPORAL_MODELS = True # Ensure this is False to use MLP
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# --- MLP Hidden Dims ---
ENC_HIDDEN_DIM = 2048 # Hidden dim for MLP encoding model
DEC_HIDDEN_DIM = 2048 # Hidden dim for MLP decoding model (reduced from 2048)
# --- Input/Output Dims (set dynamically) ---
ENC_OUTPUT_DIM = -1   # Set dynamically (PCA components)
DEC_INPUT_DIM = -1    # Set dynamically (PCA components)

APPLY_PCA = True # Flag to enable/disable PCA
PCA_N_COMPONENTS = 1000 # Target number of fMRI features after PCA

# --- Training Parameters ---
LEARNING_RATE = 1e-5
BATCH_SIZE = 64
EPOCHS = 50

# --- Evaluation ---
EVAL_METRICS = ['mse', 'pearsonr', 'r2']

In [12]:
print(f"Selected Video Encoder: {chosen_encoder} (ID: {VIDEO_EMBEDDING_MODEL}, Dim: {DEC_OUTPUT_DIM})")

# --- Log Core Config ---
print(f"Train Subject(s): {TRAIN_SUBJECT_IDS}, Test Subject: {TEST_SUBJECT_ID}")
print(f"Device: {DEVICE}, fMRI Variant: {FMRI_VARIANT}")
print(f"Apply PCA: {APPLY_PCA}, PCA Components: {PCA_N_COMPONENTS if APPLY_PCA else 'N/A'}")
print(f"Using Temporal Models: {USE_TEMPORAL_MODELS}") # Should be False for MLP


Selected Video Encoder: vitmae (ID: facebook/vit-mae-base, Dim: 768)
Train Subject(s): ['NSD103', 'NSD104', 'NSD105', 'NSD106', 'NSD107', 'NSD113', 'NSD114', 'NSD115', 'NSD116', 'NSD117', 'NSD119', 'NSD120', 'NSD122', 'NSD123', 'NSD124', 'NSD125', 'NSD126', 'NSD127', 'NSD128', 'NSD129', 'NSD130', 'NSD132', 'NSD134', 'NSD135', 'NSD136', 'NSD138', 'NSD140', 'NSD142', 'NSD145', 'NSD146', 'NSD147', 'NSD148', 'NSD149', 'NSD150', 'NSD151', 'NSD153', 'NSD155'], Test Subject: NSD114
Device: cuda:0, fMRI Variant: nocensor_srm-recon
Apply PCA: True, PCA Components: 1000
Using Temporal Models: True


In [13]:
fmri_tr = 1

### Setup class to provide video embeddings

In [14]:
video_extractor = VideoFeatureExtractor(
    model_id=VIDEO_EMBEDDING_MODEL,
    fps=23.98,
    device=DEVICE
)

In [15]:
def save_data(data, filename: Path):
    logger.info(f"Saving data to {filename}")
    filename.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_data(filename: Path):
    logger.info(f"Loading data from {filename}")
    if not filename.exists():
        logger.error(f"Cache file not found: {filename}")
        return None
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except Exception as e:
        logger.error(f"Error loading cache file {filename}: {e}", exc_info=True)
        return None

In [16]:
def process_subject_data(subject_id, fmri_tr, chosen_encoder):
    """Loads, preprocesses, and aligns data for a single subject."""
    
    print(f"--- Processing data for Subject: {subject_id} ---")
    subject_fmri_aligned = []
    subject_video_aligned = []
    
    # find the model output dimensions from the dictionary defined above
    expected_vid_dim = VIDEO_MODEL_OUTPUT_DIMS[chosen_encoder]
    safe_encoder_name = chosen_encoder.replace('/','_') # Safe name for filenames

    video_embeddings_cache = {} # Cache loaded video embeddings per stimulus

    # Include PCA components in cache filename
    pca_suffix = f"_pca{PCA_N_COMPONENTS}" if APPLY_PCA else ""

    for stimulus_lower in STIMULI_NAMES:
        print(f"Processing {subject_id} / {stimulus_lower}...")
        # --- Define file paths ---
        fmri_filename = FMRI_FILE_TEMPLATE.format(
            subject_id=subject_id,
            stimulus_lower=stimulus_lower,
            variant=FMRI_VARIANT
        )
        fmri_path = PREPROC_DIR / f"sub-{subject_id}" / fmri_filename
        video_filename = VIDEO_FILE_TEMPLATE.format(stimulus_lower=stimulus_lower)
        video_path = VIDEO_DIR / video_filename

        # --- Cache Keys ---
        aligned_cache_key = f"sub-{subject_id}_{stimulus_lower}_{FMRI_VARIANT}_{safe_encoder_name}{pca_suffix}_aligned.pkl"
        aligned_cache_file = CACHE_DIR / aligned_cache_key
        
        # Video embeddings are stimulus-specific, cache key doesn't need subject_id
        embedding_cache_key = f"{stimulus_lower}_{safe_encoder_name}_embeddings.npy"
        embedding_cache_file = CACHE_DIR / embedding_cache_key

        # --- Check cache for ALIGNED data first ---
        if aligned_cache_file.exists():
            cached_data = load_data(aligned_cache_file)
            if cached_data:
                fmri_aligned, video_aligned = cached_data
                if fmri_aligned.ndim == 2 and video_aligned.ndim == 2 and \
                   fmri_aligned.shape[0] == video_aligned.shape[0] and \
                   video_aligned.shape[1] == expected_vid_dim:
                    subject_fmri_aligned.append(fmri_aligned)
                    subject_video_aligned.append(video_aligned)
                    print(f"Loaded cached aligned data for {subject_id}/{stimulus_lower}")
                    continue # Go to next stimulus if aligned cache is valid
                else:
                    print(f"-> Invalid cached aligned data shape/dim for {subject_id}/{stimulus_lower}. Recomputing.")


        # --- Load Raw fMRI Data ---
        try:
            start_load = time.time()
            print(f"Loading fMRI from: {fmri_path}")
            fmri_data = load_fmri_data(fmri_path, stimulus_lower, fmri_tr, FMRI_VARIANT, TRS_DROP_MAP)
            if fmri_data is None or fmri_data.size == 0:
                 print(f"-> fMRI data loading failed or resulted in empty array for {subject_id}/{stimulus_lower}.")
                 continue # Skip this stimulus
            print(f"fMRI loaded in {time.time() - start_load:.2f}s. Shape: {fmri_data.shape}")
        except FileNotFoundError:
            print(f"-> fMRI file not found for {subject_id}/{stimulus_lower}: {fmri_path}. Skipping stimulus.")
            continue


        # --- Load/Cache Video Embeddings (Stimulus Specific) ---
        video_embeddings = None
        if stimulus_lower in video_embeddings_cache:
            video_embeddings = video_embeddings_cache[stimulus_lower]
            print(f"Using pre-loaded video embeddings for {stimulus_lower}.")
        elif embedding_cache_file.exists():
            print(f"Loading cached {chosen_encoder} video embeddings from: {embedding_cache_file}")
            video_embeddings = np.load(embedding_cache_file)
            if video_embeddings.ndim != 2 or video_embeddings.shape[1] != expected_vid_dim:
                print(f"-> Cached embeddings {embedding_cache_file} have wrong shape/dim ({video_embeddings.shape}). Re-extracting.")
                embedding_cache_file.unlink()
                video_embeddings = None
            else:
                video_embeddings_cache[stimulus_lower] = video_embeddings # Store in memory cache
        else:
            video_embeddings = None # Needs extraction

        # --- Extract Video Embeddings if needed ---
        video_frames, video_fps_load_check = load_video_data(video_path) # Get FPS here too
        video_extractor = VideoFeatureExtractor(
        model_id=VIDEO_EMBEDDING_MODEL,
        fps=video_fps_load_check,
        device=DEVICE
        )
        if video_embeddings is None:
            try:
                # Load video frames (only if extraction is needed)
                print(f"Loading video frames from: {video_path}")
                start_vid_load = time.time()
                if not video_frames:
                     logger.error(f"Failed to load video frames for {stimulus_lower}. Skipping.")
                     continue
                print(f"Video frames loaded in {time.time() - start_vid_load:.2f}s.")

                print(f"Extracting {chosen_encoder} video embeddings for {stimulus_lower}...")
                start_embed = time.time()
                video_embeddings = video_extractor.extract_features(video_frames, batch_size=VIDEO_EMBEDDING_BATCH_SIZE)
                print(f"Video embedding shape : {video_embeddings.shape}")
                print(f"Video embedding extraction took {time.time() - start_embed:.2f}s")

                if video_embeddings is None or video_embeddings.size == 0: raise ValueError("Extractor returned empty embeddings")

                if video_embeddings.shape[1] != expected_vid_dim:
                     logger.error(f"Extracted {chosen_encoder} dim {video_embeddings.shape[1]} != expected {expected_vid_dim}.")
                     continue # Skip if extraction yields wrong dimension

                np.save(embedding_cache_file, video_embeddings)
                print(f"Saved {chosen_encoder} video embeddings to: {embedding_cache_file}")
                video_embeddings_cache[stimulus_lower] = video_embeddings # Store in memory cache
                # Need video_fps for alignment, get it from the load_video_data call
                video_fps = video_fps_load_check
                num_video_frames_original = len(video_frames)

            except Exception as e:
                logger.error(f"Video embedding extraction failed for {stimulus_lower} / {chosen_encoder}: {e}", exc_info=True)
                continue

        else:
             # If embeddings loaded from cache, we still need FPS and frame count for alignment
             # Re-load video just to get metadata if not already loaded 
             
             try:
                 # Quick load just for metadata
                 _ , video_fps_check = load_video_data(video_path)
                 # Assuming frame count isn't strictly needed if embeddings exist, but FPS is
                 video_fps = video_fps_check

                 # [TODO]
                 # We don't have exact num_frames_original if only cache was loaded,
                 # but maybe it's not strictly needed by align_data if using chunk times? Check align_data.
                 # For safety, let's pass 0 or estimate if needed by align_data.
                 # Revisit align_data's use of num_video_frames_original. If it's just for logging, it's ok.

                 num_video_frames_original = 0 # Placeholder if only cache loaded
                 
             except Exception as e:
                  logger.error(f"Could not load video metadata for {video_path} even though embeddings exist: {e}. Skipping alignment.")
                  continue


        # --- Preprocess & Align ---
        try:
            start_preprocess = time.time()
            fmri_processed = preprocess_fmri(DO_FMRI_ZSCORE, APPLY_PCA, PCA_N_COMPONENTS,  fmri_data)
            video_embeddings_processed = preprocess_video_embeddings(DO_VIDEO_ZSCORE, video_embeddings)
            fmri_aligned, video_aligned = align_data(
                fmri_processed,
                video_embeddings_processed,
                fmri_tr=fmri_tr,
                video_fps=video_fps, # Use FPS obtained above
                num_video_frames_original=num_video_frames_original, # Use value obtained above
                hrf_delay=HRF_DELAY,
                align_method=ALIGN_METHOD,
                video_chunk_size=video_extractor.num_frames_per_clip,
                video_chunk_stride=VIDEO_CHUNK_STRIDE
            )
            print(f"Preprocessing and alignment took {time.time() - start_preprocess:.2f}s")
            if fmri_aligned.shape[0] == 0 or video_aligned.shape[0] == 0:
                print(f"-> Alignment resulted in empty data for {subject_id}/{stimulus_lower}. Skipping.")
                continue

            save_data((fmri_aligned, video_aligned), aligned_cache_file) # Save to subject specific cache
            subject_fmri_aligned.append(fmri_aligned)
            subject_video_aligned.append(video_aligned)

        except Exception as e:
            logger.error(f"Error during preprocess/align for {subject_id}/{stimulus_lower}: {e}", exc_info=True)
            continue
        # --- End Stimulus Loop ---

    # --- Concatenate data for *this* subject ---
    if not subject_fmri_aligned:
        print(f"-> No data successfully processed for subject {subject_id}.")
        return None, None # Return None if no data for this subject

    final_fmri = np.concatenate(subject_fmri_aligned, axis=0)
    final_video = np.concatenate(subject_video_aligned, axis=0)
    print(f"--- Final Concatenated Data Shapes for Subject {subject_id} ---")
    print(f"Subject fMRI data: {final_fmri.shape}")
    print(f"Subject Video embeddings: {final_video.shape}")
    return final_fmri, final_video

In [17]:
all_train_fmri_raw = []
all_train_video = []
for train_subj_id in TRAIN_SUBJECT_IDS:
    fmri_subj_raw, video_subj = process_subject_data(train_subj_id, fmri_tr, chosen_encoder)
    
    if fmri_subj_raw is not None and video_subj is not None:
        all_train_fmri_raw.append(fmri_subj_raw)
        all_train_video.append(video_subj)
    else:
        logger.warning(f"Could not process data for training subject: {train_subj_id}")

--- Processing data for Subject: NSD103 ---
Processing NSD103 / defeat...
Loaded cached aligned data for NSD103/defeat
--- Final Concatenated Data Shapes for Subject NSD103 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD104 ---
Processing NSD104 / defeat...
Loaded cached aligned data for NSD104/defeat
--- Final Concatenated Data Shapes for Subject NSD104 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD105 ---
Processing NSD105 / defeat...
Loaded cached aligned data for NSD105/defeat
--- Final Concatenated Data Shapes for Subject NSD105 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD106 ---
Processing NSD106 / defeat...
Loaded cached aligned data for NSD106/defeat
--- Final Concatenated Data Shapes for Subject NSD106 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data 

In [18]:
final_train_fmri = np.concatenate(all_train_fmri_raw, axis=0)
final_train_video = np.concatenate(all_train_video, axis=0)
print(f"--- Combined Training Data Shapes (Before PCA) ---")
print(f"Train fMRI (Raw/SRM): {final_train_fmri.shape}")
print(f"Train Video ({chosen_encoder}): {final_train_video.shape}")


--- Combined Training Data Shapes (Before PCA) ---
Train fMRI (Raw/SRM): (17427, 1000)
Train Video (vitmae): (17427, 768)


### --- Prepare Training DataLoader ---

In [19]:
try:
    train_dataset = TensorDataset(torch.from_numpy(final_train_fmri).float(),
                                    torch.from_numpy(final_train_video).float())
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    print(f"Train DataLoader ready (Size: {len(train_loader)} samples).")
except Exception as e:
    print(f"Failed to create training DataLoader: {e}", exc_info=True)



Train DataLoader ready (Size: 273 samples).


### Processing the Test Data

In [20]:
# --- Process Test Data ---
final_test_fmri = None
video_test = None
test_loader = None
test_dataset = None

# if TEST_SUBJECT_ID:
#     print(f"--- Processing Test Data for Subject: {TEST_SUBJECT_ID} ---")
#     final_test_fmri, video_test = process_subject_data(TEST_SUBJECT_ID, fmri_tr, chosen_encoder)
#     print(final_test_fmri.shape, video_test.shape)

#     if final_test_fmri is None or video_test is None:
#         print(f"No test data could be processed for subject {TEST_SUBJECT_ID}. Cannot evaluate.")
#         # Decide whether to proceed with training only or exit
#     else:
#         print(f"--- Final Test Data Shapes (After PCA if Applied) ---")
#         print(f"Test fMRI data: {final_test_fmri.shape}")
#         print(f"Test Video embeddings: {video_test.shape}")

#         # --- Prepare Test DataLoader ---
#         try:
#             test_dataset = TensorDataset(torch.from_numpy(final_test_fmri).float(),
#                                             torch.from_numpy(video_test).float())
#             test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
#             print(f"Test Dataset ready (Size: {len(test_dataset)} samples).")
#             print(f"Test DataLoader ready (Size: {len(test_loader)} samples).")
#         except Exception as e:
#                 logger.error(f"Failed to create test DataLoader: {e}", exc_info=True)
#                 test_loader = None # Mark as None if failed
# else:
#         print("No TEST_SUBJECT_ID specified. Skipping test phase.")
#         test_loader = None


In [21]:
all_test_fmri_raw = []
all_test_video = []
for test_subj_id in TEST_SUBJECT_IDS:
    fmri_subj_raw, video_subj = process_subject_data(test_subj_id, fmri_tr, chosen_encoder)
    
    if fmri_subj_raw is not None and video_subj is not None:
        all_test_fmri_raw.append(fmri_subj_raw)
        all_test_video.append(video_subj)
    else:
        logger.warning(f"Could not process data for testing subject: {test_subj_id}")


--- Processing data for Subject: NSD108 ---
Processing NSD108 / defeat...
Loaded cached aligned data for NSD108/defeat
--- Final Concatenated Data Shapes for Subject NSD108 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD109 ---
Processing NSD109 / defeat...
Loaded cached aligned data for NSD109/defeat
--- Final Concatenated Data Shapes for Subject NSD109 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD110 ---
Processing NSD110 / defeat...
Loaded cached aligned data for NSD110/defeat
--- Final Concatenated Data Shapes for Subject NSD110 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)
--- Processing data for Subject: NSD111 ---
Processing NSD111 / defeat...
Loaded cached aligned data for NSD111/defeat
--- Final Concatenated Data Shapes for Subject NSD111 ---
Subject fMRI data: (471, 1000)
Subject Video embeddings: (471, 768)


In [22]:
final_test_fmri = np.concatenate(all_test_fmri_raw, axis=0)
final_test_video = np.concatenate(all_test_video, axis=0)
print(f"--- Combined Testing Data Shapes (Before PCA) ---")
print(f"Test fMRI (Raw/SRM): {final_test_fmri.shape}")
print(f"Test Video ({chosen_encoder}): {final_test_video.shape}")

--- Combined Testing Data Shapes (Before PCA) ---
Test fMRI (Raw/SRM): (1884, 1000)
Test Video (vitmae): (1884, 768)


In [23]:
try:
    test_dataset = TensorDataset(torch.from_numpy(final_test_fmri).float(),
                                    torch.from_numpy(final_test_video).float())
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    print(f"Test DataLoader ready (Size: {len(test_loader)} samples).")
except Exception as e:
    print(f"Failed to create test DataLoader: {e}", exc_info=True)



Test DataLoader ready (Size: 30 samples).


In [24]:
# --- Update config dims (Based on final *training* data) ---
ENC_OUTPUT_DIM = final_train_fmri.shape[1]
DEC_INPUT_DIM = final_train_fmri.shape[1]

# DEC_OUTPUT_DIM is already set based on video encoder choice

In [25]:
encoding_model = EncodingModel(
    video_embed_dim=DEC_OUTPUT_DIM,
    fmri_dim=ENC_OUTPUT_DIM,
    hidden_dim=ENC_HIDDEN_DIM
)
decoding_model = DecodingModel(
    fmri_dim=DEC_INPUT_DIM,
    video_embed_dim=DEC_OUTPUT_DIM,
    hidden_dim=DEC_HIDDEN_DIM
)

In [26]:
def run_training_no_val(model, train_loader, device, task='encoding', epochs=EPOCHS
                        , learning_rate=LEARNING_RATE):
    """ Simplified training loop without validation. """

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    model_type = "LSTM" if USE_TEMPORAL_MODELS else "MLP"
    
    if model_type == "LSTM" and task == 'encoding':
        model =  LSTMEncodingModel(
            video_embed_dim=DEC_OUTPUT_DIM,
            fmri_dim=ENC_OUTPUT_DIM,
            hidden_dim=ENC_HIDDEN_DIM
        )
    elif model_type == "LSTM" and task == 'decoding':
        model = LSTMDecodingModel(
            fmri_dim=DEC_INPUT_DIM,
            video_embed_dim=DEC_OUTPUT_DIM,
            hidden_dim=DEC_HIDDEN_DIM
        )
    elif model_type == "MLP" and task == 'encoding':
        model = EncodingModel(
            video_embed_dim=DEC_OUTPUT_DIM,
            fmri_dim=ENC_OUTPUT_DIM,
            hidden_dim=ENC_HIDDEN_DIM
        )
    else:
        model = DecodingModel(
            fmri_dim=DEC_INPUT_DIM,
            video_embed_dim=DEC_OUTPUT_DIM,
            hidden_dim=DEC_HIDDEN_DIM
        )
    

    print(f"Starting {task} model training ({model_type}, no validation) for {epochs} epochs on {device}")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    stimuli_str = "_".join(STIMULI_NAMES)

    run = f"{task}_{model_type}_{VIDEO_ENCODER_NAME}_{PCA_N_COMPONENTS if APPLY_PCA else 'nopca'}_{stimuli_str}"
    config = {
        "task": task,
        "model_type": model_type,
        "video_encoder": VIDEO_ENCODER_NAME,
        "pca_components": PCA_N_COMPONENTS if APPLY_PCA else "nopca",
        "stimuli": stimuli_str,
        "learning_rate": learning_rate,
        "batch_size": BATCH_SIZE,
        "epochs": epochs,
        "encoder_hidden_dim": ENC_HIDDEN_DIM,
        "decoder_hidden_dim": DEC_HIDDEN_DIM,
        "encoder_output_dim": ENC_OUTPUT_DIM,
        "decoder_input_dim": DEC_INPUT_DIM,
        "decoder_output_dim": DEC_OUTPUT_DIM,
        "train_subjects": TRAIN_SUBJECT_IDS,
        "test_subject": TEST_SUBJECT_ID,
    }
    wandb.init(project="csai_fmri_video_project", name=run, config=config, reinit=True) # Initialize wandb

    for epoch in range(epochs):
        # Pass use_temporal flag if train_epoch still expects it, otherwise remove
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, task)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f}")
        # No validation step, no saving best model based on validation
        print("> Evaluation on train")
        results, _, _ = evaluate_model(model, train_loader, nn.MSELoss(), device=DEVICE, task=task)
               
        print("\n> Evaluation on test")
        test_results, _, _ = evaluate_model(model, test_loader, nn.MSELoss(), device=DEVICE, task=task)
        wandb.log({
                "MSE": results['mse'], 
                "PearsonR Avg": results['pearsonr_avg'],
                "PearsonR All": np.median(results['pearsonr_all']),
                "R2 Uniform": results['r2_uniform'],
                "R2 Variance": results['r2_variance'],
                "Test MSE": test_results['mse'],
                "Test PearsonR Avg": test_results['pearsonr_avg'],
                "Test PearsonR All": np.median(test_results['pearsonr_all']),
                "Test R2 Uniform": test_results['r2_uniform'],
                "Test R2 Variance": test_results['r2_variance']
            })

    print(f"Training complete after {epochs} epochs. Using model from last epoch.")
    wandb.finish() # Finish wandb run
    return model # Return model from last epoch


In [27]:
encoding_model = run_training_no_val( # Use modified training loop
    encoding_model, train_loader, device=DEVICE, task='encoding',
    epochs=EPOCHS, learning_rate=LEARNING_RATE
)

Starting encoding model training (LSTM, no validation) for 50 epochs on cuda:0


[34m[1mwandb[0m: Currently logged in as: [33mmhardik003[0m ([33mwb-team-hardik[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


                                                                                 

Epoch 1/50 | Train Loss: 1.3319
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 1.0302
  Avg Pearson Correlation (across features/voxels): 0.0240
  Median Pearson Correlation: 0.0237
  R2 Score (uniform avg): -0.0336
  R2 Score (variance weighted): -0.0336

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0406
  Avg Pearson Correlation (across features/voxels): 0.0044
  Median Pearson Correlation: 0.0036
  R2 Score (uniform avg): -0.0416
  R2 Score (variance weighted): -0.0416


                                                                                 

Epoch 2/50 | Train Loss: 1.1909
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 1.0040
  Avg Pearson Correlation (across features/voxels): 0.0359
  Median Pearson Correlation: 0.0358
  R2 Score (uniform avg): -0.0073
  R2 Score (variance weighted): -0.0073

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0143
  Avg Pearson Correlation (across features/voxels): 0.0046
  Median Pearson Correlation: 0.0047
  R2 Score (uniform avg): -0.0153
  R2 Score (variance weighted): -0.0153


                                                                                 

Epoch 3/50 | Train Loss: 1.1544
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9974
  Avg Pearson Correlation (across features/voxels): 0.0459
  Median Pearson Correlation: 0.0452
  R2 Score (uniform avg): -0.0006
  R2 Score (variance weighted): -0.0006

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0075
  Avg Pearson Correlation (across features/voxels): 0.0057
  Median Pearson Correlation: 0.0048
  R2 Score (uniform avg): -0.0085
  R2 Score (variance weighted): -0.0085


                                                                                 

Epoch 4/50 | Train Loss: 1.1410
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9951
  Avg Pearson Correlation (across features/voxels): 0.0534
  Median Pearson Correlation: 0.0524
  R2 Score (uniform avg): 0.0016
  R2 Score (variance weighted): 0.0016

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0046
  Avg Pearson Correlation (across features/voxels): 0.0121
  Median Pearson Correlation: 0.0129
  R2 Score (uniform avg): -0.0056
  R2 Score (variance weighted): -0.0056


                                                                                 

Epoch 5/50 | Train Loss: 1.1333
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9933
  Avg Pearson Correlation (across features/voxels): 0.0619
  Median Pearson Correlation: 0.0612
  R2 Score (uniform avg): 0.0035
  R2 Score (variance weighted): 0.0035

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0033
  Avg Pearson Correlation (across features/voxels): 0.0120
  Median Pearson Correlation: 0.0114
  R2 Score (uniform avg): -0.0043
  R2 Score (variance weighted): -0.0043


                                                                                 

Epoch 6/50 | Train Loss: 1.1287
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9921
  Avg Pearson Correlation (across features/voxels): 0.0692
  Median Pearson Correlation: 0.0693
  R2 Score (uniform avg): 0.0046
  R2 Score (variance weighted): 0.0046

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0021
  Avg Pearson Correlation (across features/voxels): 0.0169
  Median Pearson Correlation: 0.0191
  R2 Score (uniform avg): -0.0031
  R2 Score (variance weighted): -0.0031


                                                                                 

Epoch 7/50 | Train Loss: 1.1239
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9910
  Avg Pearson Correlation (across features/voxels): 0.0765
  Median Pearson Correlation: 0.0766
  R2 Score (uniform avg): 0.0057
  R2 Score (variance weighted): 0.0057

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0022
  Avg Pearson Correlation (across features/voxels): 0.0192
  Median Pearson Correlation: 0.0195
  R2 Score (uniform avg): -0.0032
  R2 Score (variance weighted): -0.0032


                                                                                 

Epoch 8/50 | Train Loss: 1.1204
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9898
  Avg Pearson Correlation (across features/voxels): 0.0842
  Median Pearson Correlation: 0.0848
  R2 Score (uniform avg): 0.0070
  R2 Score (variance weighted): 0.0070

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0013
  Avg Pearson Correlation (across features/voxels): 0.0239
  Median Pearson Correlation: 0.0234
  R2 Score (uniform avg): -0.0023
  R2 Score (variance weighted): -0.0023


                                                                                 

Epoch 9/50 | Train Loss: 1.1164
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9894
  Avg Pearson Correlation (across features/voxels): 0.0870
  Median Pearson Correlation: 0.0875
  R2 Score (uniform avg): 0.0073
  R2 Score (variance weighted): 0.0073

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0023
  Avg Pearson Correlation (across features/voxels): 0.0254
  Median Pearson Correlation: 0.0252
  R2 Score (uniform avg): -0.0033
  R2 Score (variance weighted): -0.0033


                                                                                 

Epoch 10/50 | Train Loss: 1.1132
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9874
  Avg Pearson Correlation (across features/voxels): 0.0978
  Median Pearson Correlation: 0.0988
  R2 Score (uniform avg): 0.0094
  R2 Score (variance weighted): 0.0094

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0017
  Avg Pearson Correlation (across features/voxels): 0.0295
  Median Pearson Correlation: 0.0293
  R2 Score (uniform avg): -0.0027
  R2 Score (variance weighted): -0.0027


                                                                                 

Epoch 11/50 | Train Loss: 1.1097
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9862
  Avg Pearson Correlation (across features/voxels): 0.1033
  Median Pearson Correlation: 0.1043
  R2 Score (uniform avg): 0.0105
  R2 Score (variance weighted): 0.0105

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0018
  Avg Pearson Correlation (across features/voxels): 0.0333
  Median Pearson Correlation: 0.0332
  R2 Score (uniform avg): -0.0028
  R2 Score (variance weighted): -0.0028


                                                                                 

Epoch 12/50 | Train Loss: 1.1069
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9852
  Avg Pearson Correlation (across features/voxels): 0.1083
  Median Pearson Correlation: 0.1082
  R2 Score (uniform avg): 0.0116
  R2 Score (variance weighted): 0.0116

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0015
  Avg Pearson Correlation (across features/voxels): 0.0379
  Median Pearson Correlation: 0.0372
  R2 Score (uniform avg): -0.0025
  R2 Score (variance weighted): -0.0024


                                                                                  

Epoch 13/50 | Train Loss: 1.1038
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9843
  Avg Pearson Correlation (across features/voxels): 0.1124
  Median Pearson Correlation: 0.1126
  R2 Score (uniform avg): 0.0125
  R2 Score (variance weighted): 0.0125

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0002
  Avg Pearson Correlation (across features/voxels): 0.0452
  Median Pearson Correlation: 0.0435
  R2 Score (uniform avg): -0.0012
  R2 Score (variance weighted): -0.0012


                                                                                 

Epoch 14/50 | Train Loss: 1.1010
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9837
  Avg Pearson Correlation (across features/voxels): 0.1146
  Median Pearson Correlation: 0.1144
  R2 Score (uniform avg): 0.0131
  R2 Score (variance weighted): 0.0131

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0010
  Avg Pearson Correlation (across features/voxels): 0.0443
  Median Pearson Correlation: 0.0428
  R2 Score (uniform avg): -0.0020
  R2 Score (variance weighted): -0.0020


                                                                                  

Epoch 15/50 | Train Loss: 1.0980
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9830
  Avg Pearson Correlation (across features/voxels): 0.1176
  Median Pearson Correlation: 0.1178
  R2 Score (uniform avg): 0.0137
  R2 Score (variance weighted): 0.0137

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0013
  Avg Pearson Correlation (across features/voxels): 0.0433
  Median Pearson Correlation: 0.0430
  R2 Score (uniform avg): -0.0023
  R2 Score (variance weighted): -0.0023


                                                                                  

Epoch 16/50 | Train Loss: 1.0959
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9824
  Avg Pearson Correlation (across features/voxels): 0.1202
  Median Pearson Correlation: 0.1199
  R2 Score (uniform avg): 0.0144
  R2 Score (variance weighted): 0.0144

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0005
  Avg Pearson Correlation (across features/voxels): 0.0489
  Median Pearson Correlation: 0.0464
  R2 Score (uniform avg): -0.0015
  R2 Score (variance weighted): -0.0015


                                                                                  

Epoch 17/50 | Train Loss: 1.0927
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9820
  Avg Pearson Correlation (across features/voxels): 0.1222
  Median Pearson Correlation: 0.1221
  R2 Score (uniform avg): 0.0148
  R2 Score (variance weighted): 0.0148

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 0.9999
  Avg Pearson Correlation (across features/voxels): 0.0503
  Median Pearson Correlation: 0.0495
  R2 Score (uniform avg): -0.0009
  R2 Score (variance weighted): -0.0009


                                                                                  

Epoch 18/50 | Train Loss: 1.0911
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9813
  Avg Pearson Correlation (across features/voxels): 0.1243
  Median Pearson Correlation: 0.1238
  R2 Score (uniform avg): 0.0154
  R2 Score (variance weighted): 0.0154

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0007
  Avg Pearson Correlation (across features/voxels): 0.0501
  Median Pearson Correlation: 0.0487
  R2 Score (uniform avg): -0.0017
  R2 Score (variance weighted): -0.0017


                                                                                 

Epoch 19/50 | Train Loss: 1.0887
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9811
  Avg Pearson Correlation (across features/voxels): 0.1257
  Median Pearson Correlation: 0.1256
  R2 Score (uniform avg): 0.0157
  R2 Score (variance weighted): 0.0157

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 0.9997
  Avg Pearson Correlation (across features/voxels): 0.0530
  Median Pearson Correlation: 0.0524
  R2 Score (uniform avg): -0.0007
  R2 Score (variance weighted): -0.0007


                                                                                  

Epoch 20/50 | Train Loss: 1.0865
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9808
  Avg Pearson Correlation (across features/voxels): 0.1265
  Median Pearson Correlation: 0.1262
  R2 Score (uniform avg): 0.0159
  R2 Score (variance weighted): 0.0159

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 0.9997
  Avg Pearson Correlation (across features/voxels): 0.0551
  Median Pearson Correlation: 0.0558
  R2 Score (uniform avg): -0.0007
  R2 Score (variance weighted): -0.0007


                                                                                  

Epoch 21/50 | Train Loss: 1.0849
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9804
  Avg Pearson Correlation (across features/voxels): 0.1279
  Median Pearson Correlation: 0.1275
  R2 Score (uniform avg): 0.0163
  R2 Score (variance weighted): 0.0163

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0004
  Avg Pearson Correlation (across features/voxels): 0.0526
  Median Pearson Correlation: 0.0529
  R2 Score (uniform avg): -0.0014
  R2 Score (variance weighted): -0.0014


                                                                                  

Epoch 22/50 | Train Loss: 1.0831
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9802
  Avg Pearson Correlation (across features/voxels): 0.1293
  Median Pearson Correlation: 0.1287
  R2 Score (uniform avg): 0.0166
  R2 Score (variance weighted): 0.0166

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0008
  Avg Pearson Correlation (across features/voxels): 0.0499
  Median Pearson Correlation: 0.0478
  R2 Score (uniform avg): -0.0018
  R2 Score (variance weighted): -0.0018


                                                                                  

Epoch 23/50 | Train Loss: 1.0812
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9796
  Avg Pearson Correlation (across features/voxels): 0.1309
  Median Pearson Correlation: 0.1303
  R2 Score (uniform avg): 0.0172
  R2 Score (variance weighted): 0.0172

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0002
  Avg Pearson Correlation (across features/voxels): 0.0554
  Median Pearson Correlation: 0.0545
  R2 Score (uniform avg): -0.0012
  R2 Score (variance weighted): -0.0012


                                                                                  

Epoch 24/50 | Train Loss: 1.0792
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9795
  Avg Pearson Correlation (across features/voxels): 0.1314
  Median Pearson Correlation: 0.1309
  R2 Score (uniform avg): 0.0173
  R2 Score (variance weighted): 0.0173

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0021
  Avg Pearson Correlation (across features/voxels): 0.0496
  Median Pearson Correlation: 0.0483
  R2 Score (uniform avg): -0.0031
  R2 Score (variance weighted): -0.0031


                                                                                  

Epoch 25/50 | Train Loss: 1.0771
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9791
  Avg Pearson Correlation (across features/voxels): 0.1326
  Median Pearson Correlation: 0.1322
  R2 Score (uniform avg): 0.0177
  R2 Score (variance weighted): 0.0177

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0007
  Avg Pearson Correlation (across features/voxels): 0.0562
  Median Pearson Correlation: 0.0545
  R2 Score (uniform avg): -0.0017
  R2 Score (variance weighted): -0.0017


                                                                                  

Epoch 26/50 | Train Loss: 1.0754
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9788
  Avg Pearson Correlation (across features/voxels): 0.1339
  Median Pearson Correlation: 0.1336
  R2 Score (uniform avg): 0.0180
  R2 Score (variance weighted): 0.0180

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0014
  Avg Pearson Correlation (across features/voxels): 0.0551
  Median Pearson Correlation: 0.0532
  R2 Score (uniform avg): -0.0024
  R2 Score (variance weighted): -0.0024


                                                                                  

Epoch 27/50 | Train Loss: 1.0736
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9785
  Avg Pearson Correlation (across features/voxels): 0.1354
  Median Pearson Correlation: 0.1352
  R2 Score (uniform avg): 0.0183
  R2 Score (variance weighted): 0.0183

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 0.9997
  Avg Pearson Correlation (across features/voxels): 0.0577
  Median Pearson Correlation: 0.0561
  R2 Score (uniform avg): -0.0007
  R2 Score (variance weighted): -0.0007


                                                                                  

Epoch 28/50 | Train Loss: 1.0715
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9784
  Avg Pearson Correlation (across features/voxels): 0.1357
  Median Pearson Correlation: 0.1351
  R2 Score (uniform avg): 0.0184
  R2 Score (variance weighted): 0.0184

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0013
  Avg Pearson Correlation (across features/voxels): 0.0533
  Median Pearson Correlation: 0.0528
  R2 Score (uniform avg): -0.0023
  R2 Score (variance weighted): -0.0023


                                                                                  

Epoch 29/50 | Train Loss: 1.0698
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9779
  Avg Pearson Correlation (across features/voxels): 0.1374
  Median Pearson Correlation: 0.1369
  R2 Score (uniform avg): 0.0188
  R2 Score (variance weighted): 0.0189

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0015
  Avg Pearson Correlation (across features/voxels): 0.0544
  Median Pearson Correlation: 0.0555
  R2 Score (uniform avg): -0.0025
  R2 Score (variance weighted): -0.0025


                                                                                  

Epoch 30/50 | Train Loss: 1.0687
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9781
  Avg Pearson Correlation (across features/voxels): 0.1366
  Median Pearson Correlation: 0.1362
  R2 Score (uniform avg): 0.0187
  R2 Score (variance weighted): 0.0187

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0019
  Avg Pearson Correlation (across features/voxels): 0.0526
  Median Pearson Correlation: 0.0525
  R2 Score (uniform avg): -0.0029
  R2 Score (variance weighted): -0.0029


                                                                                  

Epoch 31/50 | Train Loss: 1.0673
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9772
  Avg Pearson Correlation (across features/voxels): 0.1397
  Median Pearson Correlation: 0.1389
  R2 Score (uniform avg): 0.0195
  R2 Score (variance weighted): 0.0195

> Evaluation on test
Evaluating encoding model...


                                                                       

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0007
  Avg Pearson Correlation (across features/voxels): 0.0578
  Median Pearson Correlation: 0.0574
  R2 Score (uniform avg): -0.0017
  R2 Score (variance weighted): -0.0017


                                                                                  

Epoch 32/50 | Train Loss: 1.0653
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9772
  Avg Pearson Correlation (across features/voxels): 0.1400
  Median Pearson Correlation: 0.1397
  R2 Score (uniform avg): 0.0196
  R2 Score (variance weighted): 0.0196

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0031
  Avg Pearson Correlation (across features/voxels): 0.0505
  Median Pearson Correlation: 0.0495
  R2 Score (uniform avg): -0.0041
  R2 Score (variance weighted): -0.0041


                                                                                  

Epoch 33/50 | Train Loss: 1.0637
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9769
  Avg Pearson Correlation (across features/voxels): 0.1410
  Median Pearson Correlation: 0.1400
  R2 Score (uniform avg): 0.0199
  R2 Score (variance weighted): 0.0199

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0018
  Avg Pearson Correlation (across features/voxels): 0.0556
  Median Pearson Correlation: 0.0537
  R2 Score (uniform avg): -0.0028
  R2 Score (variance weighted): -0.0028


                                                                                  

Epoch 34/50 | Train Loss: 1.0623
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9765
  Avg Pearson Correlation (across features/voxels): 0.1422
  Median Pearson Correlation: 0.1425
  R2 Score (uniform avg): 0.0203
  R2 Score (variance weighted): 0.0203

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0021
  Avg Pearson Correlation (across features/voxels): 0.0554
  Median Pearson Correlation: 0.0541
  R2 Score (uniform avg): -0.0031
  R2 Score (variance weighted): -0.0031


                                                                                  

Epoch 35/50 | Train Loss: 1.0603
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9762
  Avg Pearson Correlation (across features/voxels): 0.1438
  Median Pearson Correlation: 0.1427
  R2 Score (uniform avg): 0.0206
  R2 Score (variance weighted): 0.0206

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0017
  Avg Pearson Correlation (across features/voxels): 0.0543
  Median Pearson Correlation: 0.0534
  R2 Score (uniform avg): -0.0027
  R2 Score (variance weighted): -0.0027


                                                                                  

Epoch 36/50 | Train Loss: 1.0593
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9760
  Avg Pearson Correlation (across features/voxels): 0.1443
  Median Pearson Correlation: 0.1438
  R2 Score (uniform avg): 0.0208
  R2 Score (variance weighted): 0.0208

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0015
  Avg Pearson Correlation (across features/voxels): 0.0570
  Median Pearson Correlation: 0.0567
  R2 Score (uniform avg): -0.0025
  R2 Score (variance weighted): -0.0025


                                                                                  

Epoch 37/50 | Train Loss: 1.0574
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9758
  Avg Pearson Correlation (across features/voxels): 0.1448
  Median Pearson Correlation: 0.1433
  R2 Score (uniform avg): 0.0210
  R2 Score (variance weighted): 0.0210

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0033
  Avg Pearson Correlation (across features/voxels): 0.0511
  Median Pearson Correlation: 0.0513
  R2 Score (uniform avg): -0.0043
  R2 Score (variance weighted): -0.0043


                                                                                  

Epoch 38/50 | Train Loss: 1.0562
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9755
  Avg Pearson Correlation (across features/voxels): 0.1461
  Median Pearson Correlation: 0.1451
  R2 Score (uniform avg): 0.0213
  R2 Score (variance weighted): 0.0213

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0035
  Avg Pearson Correlation (across features/voxels): 0.0504
  Median Pearson Correlation: 0.0481
  R2 Score (uniform avg): -0.0045
  R2 Score (variance weighted): -0.0045


                                                                                  

Epoch 39/50 | Train Loss: 1.0546
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9751
  Avg Pearson Correlation (across features/voxels): 0.1471
  Median Pearson Correlation: 0.1464
  R2 Score (uniform avg): 0.0217
  R2 Score (variance weighted): 0.0217

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0029
  Avg Pearson Correlation (across features/voxels): 0.0545
  Median Pearson Correlation: 0.0551
  R2 Score (uniform avg): -0.0039
  R2 Score (variance weighted): -0.0039


                                                                                  

Epoch 40/50 | Train Loss: 1.0533
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9752
  Avg Pearson Correlation (across features/voxels): 0.1477
  Median Pearson Correlation: 0.1472
  R2 Score (uniform avg): 0.0216
  R2 Score (variance weighted): 0.0216

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0028
  Avg Pearson Correlation (across features/voxels): 0.0517
  Median Pearson Correlation: 0.0496
  R2 Score (uniform avg): -0.0038
  R2 Score (variance weighted): -0.0038


                                                                                  

Epoch 41/50 | Train Loss: 1.0517
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9746
  Avg Pearson Correlation (across features/voxels): 0.1491
  Median Pearson Correlation: 0.1483
  R2 Score (uniform avg): 0.0222
  R2 Score (variance weighted): 0.0222

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0044
  Avg Pearson Correlation (across features/voxels): 0.0493
  Median Pearson Correlation: 0.0473
  R2 Score (uniform avg): -0.0054
  R2 Score (variance weighted): -0.0054


                                                                                  

Epoch 42/50 | Train Loss: 1.0504
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9745
  Avg Pearson Correlation (across features/voxels): 0.1499
  Median Pearson Correlation: 0.1492
  R2 Score (uniform avg): 0.0223
  R2 Score (variance weighted): 0.0223

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0029
  Avg Pearson Correlation (across features/voxels): 0.0528
  Median Pearson Correlation: 0.0515
  R2 Score (uniform avg): -0.0039
  R2 Score (variance weighted): -0.0039


                                                                                  

Epoch 43/50 | Train Loss: 1.0489
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9740
  Avg Pearson Correlation (across features/voxels): 0.1508
  Median Pearson Correlation: 0.1502
  R2 Score (uniform avg): 0.0228
  R2 Score (variance weighted): 0.0228

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0045
  Avg Pearson Correlation (across features/voxels): 0.0517
  Median Pearson Correlation: 0.0514
  R2 Score (uniform avg): -0.0055
  R2 Score (variance weighted): -0.0055


                                                                                  

Epoch 44/50 | Train Loss: 1.0475
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9738
  Avg Pearson Correlation (across features/voxels): 0.1516
  Median Pearson Correlation: 0.1508
  R2 Score (uniform avg): 0.0230
  R2 Score (variance weighted): 0.0230

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0043
  Avg Pearson Correlation (across features/voxels): 0.0532
  Median Pearson Correlation: 0.0505
  R2 Score (uniform avg): -0.0053
  R2 Score (variance weighted): -0.0053


                                                                                  

Epoch 45/50 | Train Loss: 1.0462
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9737
  Avg Pearson Correlation (across features/voxels): 0.1523
  Median Pearson Correlation: 0.1512
  R2 Score (uniform avg): 0.0231
  R2 Score (variance weighted): 0.0231

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0041
  Avg Pearson Correlation (across features/voxels): 0.0509
  Median Pearson Correlation: 0.0486
  R2 Score (uniform avg): -0.0051
  R2 Score (variance weighted): -0.0051


                                                                                  

Epoch 46/50 | Train Loss: 1.0452
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9734
  Avg Pearson Correlation (across features/voxels): 0.1533
  Median Pearson Correlation: 0.1527
  R2 Score (uniform avg): 0.0234
  R2 Score (variance weighted): 0.0234

> Evaluation on test
Evaluating encoding model...


                                                                       

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0041
  Avg Pearson Correlation (across features/voxels): 0.0511
  Median Pearson Correlation: 0.0495
  R2 Score (uniform avg): -0.0051
  R2 Score (variance weighted): -0.0051


                                                                                  

Epoch 47/50 | Train Loss: 1.0438
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9731
  Avg Pearson Correlation (across features/voxels): 0.1540
  Median Pearson Correlation: 0.1536
  R2 Score (uniform avg): 0.0237
  R2 Score (variance weighted): 0.0237

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0048
  Avg Pearson Correlation (across features/voxels): 0.0511
  Median Pearson Correlation: 0.0495
  R2 Score (uniform avg): -0.0058
  R2 Score (variance weighted): -0.0058


                                                                                  

Epoch 48/50 | Train Loss: 1.0425
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9729
  Avg Pearson Correlation (across features/voxels): 0.1548
  Median Pearson Correlation: 0.1540
  R2 Score (uniform avg): 0.0239
  R2 Score (variance weighted): 0.0239

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0039
  Avg Pearson Correlation (across features/voxels): 0.0532
  Median Pearson Correlation: 0.0516
  R2 Score (uniform avg): -0.0049
  R2 Score (variance weighted): -0.0049


                                                                                  

Epoch 49/50 | Train Loss: 1.0406
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9727
  Avg Pearson Correlation (across features/voxels): 0.1555
  Median Pearson Correlation: 0.1546
  R2 Score (uniform avg): 0.0241
  R2 Score (variance weighted): 0.0241

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0041
  Avg Pearson Correlation (across features/voxels): 0.0528
  Median Pearson Correlation: 0.0506
  R2 Score (uniform avg): -0.0051
  R2 Score (variance weighted): -0.0051


                                                                                  

Epoch 50/50 | Train Loss: 1.0397
> Evaluation on train
Evaluating encoding model...


                                                                         

Evaluation complete. Target shape: (17427, 1000), Prediction shape: (17427, 1000)
  MSE: 0.9724
  Avg Pearson Correlation (across features/voxels): 0.1563
  Median Pearson Correlation: 0.1558
  R2 Score (uniform avg): 0.0244
  R2 Score (variance weighted): 0.0244

> Evaluation on test
Evaluating encoding model...


                                                                     

Evaluation complete. Target shape: (1884, 1000), Prediction shape: (1884, 1000)
  MSE: 1.0054
  Avg Pearson Correlation (across features/voxels): 0.0499
  Median Pearson Correlation: 0.0473
  R2 Score (uniform avg): -0.0064
  R2 Score (variance weighted): -0.0064
Training complete after 50 epochs. Using model from last epoch.


0,1
MSE,█▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
PearsonR All,▁▂▂▃▃▄▄▄▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████
PearsonR Avg,▁▂▃▃▃▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
R2 Uniform,▁▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████
R2 Variance,▁▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
Test MSE,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▂▂▂▂▂▂▂▂▂
Test PearsonR All,▁▁▁▂▂▃▄▄▄▅▆▆▆▇▇▇█▇▇██▇█▇██▇███▇█▇▇▇▇▇▇▇▇
Test PearsonR Avg,▁▁▂▂▃▄▄▄▅▅▆▆▇▇▇█▇█▇██▇█▇█████▇█▇▇▇▇▇▇▇▇▇
Test R2 Uniform,▁▆▇▇▇█████████████████████▇███▇▇▇▇▇▇▇▇▇▇
Test R2 Variance,▁▆▇▇▇████████████████████████▇▇▇▇▇▇▇▇▇▇▇

0,1
MSE,0.97236
PearsonR All,0.15585
PearsonR Avg,0.15633
R2 Uniform,0.02444
R2 Variance,0.02444
Test MSE,1.00544
Test PearsonR All,0.04729
Test PearsonR Avg,0.04986
Test R2 Uniform,-0.00645
Test R2 Variance,-0.00644


In [28]:
# if test_loader:
#     print(f"\n--- Evaluating Encoding Model on Subj {TEST_SUBJECT_ID} ---")
#     enc_results, enc_targets, enc_predictions = evaluate_model(
#             encoding_model, test_loader, nn.MSELoss(), device=DEVICE, task='encoding'
#     )
#     # --- Add encoder/subject info to output files ---
#     enc_model_suffix = f"_mlp_{chosen_encoder}_train{''.join(TRAIN_SUBJECT_IDS)}_test{TEST_SUBJECT_ID}"
#     enc_model_path = OUTPUT_DIR / f"encoding_model{enc_model_suffix}.pt"
#     torch.save(encoding_model.state_dict(), enc_model_path)
#     print(f"Saved Encoding model to {enc_model_path}")

#     enc_fig = plot_predictions(enc_targets, enc_predictions, n_samples=5, title=f"Encoding (MLP, {chosen_encoder}) Test Subj {TEST_SUBJECT_ID}")
#     enc_plot_path = OUTPUT_DIR / f"encoding_predictions{enc_model_suffix}.png"
#     enc_fig.savefig(enc_plot_path)
#     plt.close(enc_fig)
#     print(f"Saved Encoding prediction plot to {enc_plot_path}")
    
# else:
#         print("Skipping encoding model evaluation as no test loader was created.")


### --- Decoding Model Training & Evaluation ---


In [29]:
print(f"\n--- Training Decoding Model (MLP, Video: {chosen_encoder}) on Subj {TRAIN_SUBJECT_IDS} ---")
if DEVICE == 'cuda' or 'cuda:0' or 'cuda:2' or 'cuda:1':
    # Clear memory if possible
    del encoding_model
    if 'enc_targets' in locals(): del enc_targets
    if 'enc_predictions' in locals(): del enc_predictions
    torch.cuda.empty_cache()
    print("Cleared CUDA cache before starting decoding model training.")


--- Training Decoding Model (MLP, Video: vitmae) on Subj ['NSD103', 'NSD104', 'NSD105', 'NSD106', 'NSD107', 'NSD113', 'NSD114', 'NSD115', 'NSD116', 'NSD117', 'NSD119', 'NSD120', 'NSD122', 'NSD123', 'NSD124', 'NSD125', 'NSD126', 'NSD127', 'NSD128', 'NSD129', 'NSD130', 'NSD132', 'NSD134', 'NSD135', 'NSD136', 'NSD138', 'NSD140', 'NSD142', 'NSD145', 'NSD146', 'NSD147', 'NSD148', 'NSD149', 'NSD150', 'NSD151', 'NSD153', 'NSD155'] ---
Cleared CUDA cache before starting decoding model training.


In [30]:
decoding_model = run_training_no_val( # Use modified training loop
    decoding_model, train_loader, device=DEVICE, task='decoding',
    epochs=EPOCHS, learning_rate=LEARNING_RATE
)

if test_loader:
    print(f"\n--- Evaluating Decoding Model on Subj {TEST_SUBJECT_ID} ---")
    dec_results, dec_targets, dec_predictions = evaluate_model(
        decoding_model, test_loader, nn.MSELoss(), device=DEVICE, task='decoding'
    )
    dec_model_suffix = f"_mlp_{chosen_encoder}_train{''.join(TRAIN_SUBJECT_IDS)}_test{TEST_SUBJECT_ID}"
    dec_model_path = OUTPUT_DIR / f"decoding_model{dec_model_suffix}.pt"
    # torch.save(decoding_model.state_dict(), dec_model_path)
    print(f"Saved Decoding model to {dec_model_path}")

    dec_fig = plot_predictions(dec_targets, dec_predictions, n_samples=5, title=f"Decoding (MLP, {chosen_encoder}) Test Subj {TEST_SUBJECT_ID}")
    dec_plot_path = OUTPUT_DIR / f"decoding_predictions{dec_model_suffix}.png"
    # dec_fig.savefig(dec_plot_path)
    plt.close(dec_fig)
    print(f"Saved Decoding prediction plot to {dec_plot_path}")
else:
        print("Skipping decoding model evaluation as no test loader was created.")


Starting decoding model training (LSTM, no validation) for 50 epochs on cuda:0


                                                                                  

Epoch 1/50 | Train Loss: 1.1213
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 1.0132
  Avg Pearson Correlation (across features/voxels): 0.0455
  Median Pearson Correlation: 0.0445
  R2 Score (uniform avg): -0.0132
  R2 Score (variance weighted): -0.0132

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0297
  Avg Pearson Correlation (across features/voxels): -0.0013
  Median Pearson Correlation: -0.0007
  R2 Score (uniform avg): -0.0297
  R2 Score (variance weighted): -0.0297


                                                                                  

Epoch 2/50 | Train Loss: 1.0503
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.9950
  Avg Pearson Correlation (across features/voxels): 0.0724
  Median Pearson Correlation: 0.0691
  R2 Score (uniform avg): 0.0050
  R2 Score (variance weighted): 0.0050

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0063
  Avg Pearson Correlation (across features/voxels): -0.0039
  Median Pearson Correlation: -0.0044
  R2 Score (uniform avg): -0.0063
  R2 Score (variance weighted): -0.0063


                                                                                  

Epoch 3/50 | Train Loss: 1.0353
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.9896
  Avg Pearson Correlation (across features/voxels): 0.1093
  Median Pearson Correlation: 0.1018
  R2 Score (uniform avg): 0.0104
  R2 Score (variance weighted): 0.0104

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0042
  Avg Pearson Correlation (across features/voxels): 0.0054
  Median Pearson Correlation: 0.0057
  R2 Score (uniform avg): -0.0042
  R2 Score (variance weighted): -0.0042


                                                                                  

Epoch 4/50 | Train Loss: 1.0247
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.9726
  Avg Pearson Correlation (across features/voxels): 0.1713
  Median Pearson Correlation: 0.1607
  R2 Score (uniform avg): 0.0274
  R2 Score (variance weighted): 0.0274

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0081
  Avg Pearson Correlation (across features/voxels): 0.0069
  Median Pearson Correlation: 0.0075
  R2 Score (uniform avg): -0.0081
  R2 Score (variance weighted): -0.0081


                                                                                  

Epoch 5/50 | Train Loss: 1.0010
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.9442
  Avg Pearson Correlation (across features/voxels): 0.2398
  Median Pearson Correlation: 0.2338
  R2 Score (uniform avg): 0.0558
  R2 Score (variance weighted): 0.0558

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0178
  Avg Pearson Correlation (across features/voxels): 0.0079
  Median Pearson Correlation: 0.0059
  R2 Score (uniform avg): -0.0178
  R2 Score (variance weighted): -0.0178


                                                                                  

Epoch 6/50 | Train Loss: 0.9705
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.9096
  Avg Pearson Correlation (across features/voxels): 0.3031
  Median Pearson Correlation: 0.2940
  R2 Score (uniform avg): 0.0904
  R2 Score (variance weighted): 0.0904

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0302
  Avg Pearson Correlation (across features/voxels): 0.0026
  Median Pearson Correlation: 0.0026
  R2 Score (uniform avg): -0.0302
  R2 Score (variance weighted): -0.0302


                                                                                  

Epoch 7/50 | Train Loss: 0.9405
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.8740
  Avg Pearson Correlation (across features/voxels): 0.3551
  Median Pearson Correlation: 0.3463
  R2 Score (uniform avg): 0.1260
  R2 Score (variance weighted): 0.1260

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0416
  Avg Pearson Correlation (across features/voxels): 0.0038
  Median Pearson Correlation: 0.0047
  R2 Score (uniform avg): -0.0416
  R2 Score (variance weighted): -0.0416


                                                                                  

Epoch 8/50 | Train Loss: 0.9132
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.8427
  Avg Pearson Correlation (across features/voxels): 0.3928
  Median Pearson Correlation: 0.3862
  R2 Score (uniform avg): 0.1573
  R2 Score (variance weighted): 0.1573

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0597
  Avg Pearson Correlation (across features/voxels): -0.0036
  Median Pearson Correlation: -0.0038
  R2 Score (uniform avg): -0.0597
  R2 Score (variance weighted): -0.0597


                                                                                  

Epoch 9/50 | Train Loss: 0.8901
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.8166
  Avg Pearson Correlation (across features/voxels): 0.4252
  Median Pearson Correlation: 0.4193
  R2 Score (uniform avg): 0.1834
  R2 Score (variance weighted): 0.1834

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0686
  Avg Pearson Correlation (across features/voxels): -0.0067
  Median Pearson Correlation: -0.0057
  R2 Score (uniform avg): -0.0686
  R2 Score (variance weighted): -0.0686


                                                                                  

Epoch 10/50 | Train Loss: 0.8695
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7896
  Avg Pearson Correlation (across features/voxels): 0.4539
  Median Pearson Correlation: 0.4486
  R2 Score (uniform avg): 0.2104
  R2 Score (variance weighted): 0.2104

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0807
  Avg Pearson Correlation (across features/voxels): -0.0085
  Median Pearson Correlation: -0.0088
  R2 Score (uniform avg): -0.0807
  R2 Score (variance weighted): -0.0807


                                                                                  

Epoch 11/50 | Train Loss: 0.8504
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7735
  Avg Pearson Correlation (across features/voxels): 0.4750
  Median Pearson Correlation: 0.4718
  R2 Score (uniform avg): 0.2265
  R2 Score (variance weighted): 0.2265

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0789
  Avg Pearson Correlation (across features/voxels): -0.0085
  Median Pearson Correlation: -0.0081
  R2 Score (uniform avg): -0.0789
  R2 Score (variance weighted): -0.0789


                                                                                  

Epoch 12/50 | Train Loss: 0.8330
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7544
  Avg Pearson Correlation (across features/voxels): 0.4924
  Median Pearson Correlation: 0.4900
  R2 Score (uniform avg): 0.2456
  R2 Score (variance weighted): 0.2456

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0885
  Avg Pearson Correlation (across features/voxels): -0.0086
  Median Pearson Correlation: -0.0091
  R2 Score (uniform avg): -0.0885
  R2 Score (variance weighted): -0.0885


                                                                                  

Epoch 13/50 | Train Loss: 0.8189
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7392
  Avg Pearson Correlation (across features/voxels): 0.5073
  Median Pearson Correlation: 0.5029
  R2 Score (uniform avg): 0.2608
  R2 Score (variance weighted): 0.2608

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.0939
  Avg Pearson Correlation (across features/voxels): -0.0122
  Median Pearson Correlation: -0.0118
  R2 Score (uniform avg): -0.0939
  R2 Score (variance weighted): -0.0939


                                                                                  

Epoch 14/50 | Train Loss: 0.8054
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7219
  Avg Pearson Correlation (across features/voxels): 0.5237
  Median Pearson Correlation: 0.5211
  R2 Score (uniform avg): 0.2781
  R2 Score (variance weighted): 0.2781

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1007
  Avg Pearson Correlation (across features/voxels): -0.0115
  Median Pearson Correlation: -0.0118
  R2 Score (uniform avg): -0.1007
  R2 Score (variance weighted): -0.1007


                                                                                  

Epoch 15/50 | Train Loss: 0.7916
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.7071
  Avg Pearson Correlation (across features/voxels): 0.5377
  Median Pearson Correlation: 0.5340
  R2 Score (uniform avg): 0.2929
  R2 Score (variance weighted): 0.2929

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1015
  Avg Pearson Correlation (across features/voxels): -0.0074
  Median Pearson Correlation: -0.0074
  R2 Score (uniform avg): -0.1015
  R2 Score (variance weighted): -0.1015


                                                                                  

Epoch 16/50 | Train Loss: 0.7792
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6961
  Avg Pearson Correlation (across features/voxels): 0.5501
  Median Pearson Correlation: 0.5470
  R2 Score (uniform avg): 0.3039
  R2 Score (variance weighted): 0.3039

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1021
  Avg Pearson Correlation (across features/voxels): -0.0109
  Median Pearson Correlation: -0.0114
  R2 Score (uniform avg): -0.1021
  R2 Score (variance weighted): -0.1021


                                                                                  

Epoch 17/50 | Train Loss: 0.7685
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6835
  Avg Pearson Correlation (across features/voxels): 0.5595
  Median Pearson Correlation: 0.5569
  R2 Score (uniform avg): 0.3165
  R2 Score (variance weighted): 0.3165

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1049
  Avg Pearson Correlation (across features/voxels): -0.0109
  Median Pearson Correlation: -0.0099
  R2 Score (uniform avg): -0.1049
  R2 Score (variance weighted): -0.1049


                                                                                  

Epoch 18/50 | Train Loss: 0.7573
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6690
  Avg Pearson Correlation (across features/voxels): 0.5736
  Median Pearson Correlation: 0.5717
  R2 Score (uniform avg): 0.3310
  R2 Score (variance weighted): 0.3310

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1046
  Avg Pearson Correlation (across features/voxels): -0.0101
  Median Pearson Correlation: -0.0098
  R2 Score (uniform avg): -0.1046
  R2 Score (variance weighted): -0.1046


                                                                                  

Epoch 19/50 | Train Loss: 0.7467
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6537
  Avg Pearson Correlation (across features/voxels): 0.5851
  Median Pearson Correlation: 0.5827
  R2 Score (uniform avg): 0.3463
  R2 Score (variance weighted): 0.3463

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1147
  Avg Pearson Correlation (across features/voxels): -0.0121
  Median Pearson Correlation: -0.0107
  R2 Score (uniform avg): -0.1147
  R2 Score (variance weighted): -0.1147


                                                                                  

Epoch 20/50 | Train Loss: 0.7356
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6453
  Avg Pearson Correlation (across features/voxels): 0.5940
  Median Pearson Correlation: 0.5917
  R2 Score (uniform avg): 0.3547
  R2 Score (variance weighted): 0.3547

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1111
  Avg Pearson Correlation (across features/voxels): -0.0116
  Median Pearson Correlation: -0.0106
  R2 Score (uniform avg): -0.1111
  R2 Score (variance weighted): -0.1111


                                                                                  

Epoch 21/50 | Train Loss: 0.7254
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6373
  Avg Pearson Correlation (across features/voxels): 0.6032
  Median Pearson Correlation: 0.6002
  R2 Score (uniform avg): 0.3627
  R2 Score (variance weighted): 0.3627

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1021
  Avg Pearson Correlation (across features/voxels): -0.0107
  Median Pearson Correlation: -0.0107
  R2 Score (uniform avg): -0.1021
  R2 Score (variance weighted): -0.1021


                                                                                  

Epoch 22/50 | Train Loss: 0.7176
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6213
  Avg Pearson Correlation (across features/voxels): 0.6141
  Median Pearson Correlation: 0.6118
  R2 Score (uniform avg): 0.3787
  R2 Score (variance weighted): 0.3787

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1108
  Avg Pearson Correlation (across features/voxels): -0.0106
  Median Pearson Correlation: -0.0105
  R2 Score (uniform avg): -0.1108
  R2 Score (variance weighted): -0.1108


                                                                                  

Epoch 23/50 | Train Loss: 0.7061
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6091
  Avg Pearson Correlation (across features/voxels): 0.6230
  Median Pearson Correlation: 0.6208
  R2 Score (uniform avg): 0.3909
  R2 Score (variance weighted): 0.3909

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1171
  Avg Pearson Correlation (across features/voxels): -0.0079
  Median Pearson Correlation: -0.0082
  R2 Score (uniform avg): -0.1171
  R2 Score (variance weighted): -0.1171


                                                                                  

Epoch 24/50 | Train Loss: 0.6977
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.6006
  Avg Pearson Correlation (across features/voxels): 0.6308
  Median Pearson Correlation: 0.6295
  R2 Score (uniform avg): 0.3994
  R2 Score (variance weighted): 0.3994

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1152
  Avg Pearson Correlation (across features/voxels): -0.0119
  Median Pearson Correlation: -0.0120
  R2 Score (uniform avg): -0.1152
  R2 Score (variance weighted): -0.1152


                                                                                  

Epoch 25/50 | Train Loss: 0.6893
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5937
  Avg Pearson Correlation (across features/voxels): 0.6366
  Median Pearson Correlation: 0.6349
  R2 Score (uniform avg): 0.4063
  R2 Score (variance weighted): 0.4063

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1121
  Avg Pearson Correlation (across features/voxels): -0.0064
  Median Pearson Correlation: -0.0066
  R2 Score (uniform avg): -0.1121
  R2 Score (variance weighted): -0.1121


                                                                                  

Epoch 26/50 | Train Loss: 0.6799
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5814
  Avg Pearson Correlation (across features/voxels): 0.6461
  Median Pearson Correlation: 0.6447
  R2 Score (uniform avg): 0.4186
  R2 Score (variance weighted): 0.4186

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1132
  Avg Pearson Correlation (across features/voxels): -0.0073
  Median Pearson Correlation: -0.0068
  R2 Score (uniform avg): -0.1132
  R2 Score (variance weighted): -0.1132


                                                                                  

Epoch 27/50 | Train Loss: 0.6711
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5732
  Avg Pearson Correlation (across features/voxels): 0.6529
  Median Pearson Correlation: 0.6523
  R2 Score (uniform avg): 0.4268
  R2 Score (variance weighted): 0.4268

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1122
  Avg Pearson Correlation (across features/voxels): -0.0088
  Median Pearson Correlation: -0.0082
  R2 Score (uniform avg): -0.1122
  R2 Score (variance weighted): -0.1122


                                                                                  

Epoch 28/50 | Train Loss: 0.6641
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5654
  Avg Pearson Correlation (across features/voxels): 0.6593
  Median Pearson Correlation: 0.6589
  R2 Score (uniform avg): 0.4346
  R2 Score (variance weighted): 0.4346

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1126
  Avg Pearson Correlation (across features/voxels): -0.0090
  Median Pearson Correlation: -0.0083
  R2 Score (uniform avg): -0.1126
  R2 Score (variance weighted): -0.1126


                                                                                  

Epoch 29/50 | Train Loss: 0.6560
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5541
  Avg Pearson Correlation (across features/voxels): 0.6674
  Median Pearson Correlation: 0.6672
  R2 Score (uniform avg): 0.4459
  R2 Score (variance weighted): 0.4459

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1216
  Avg Pearson Correlation (across features/voxels): -0.0116
  Median Pearson Correlation: -0.0095
  R2 Score (uniform avg): -0.1216
  R2 Score (variance weighted): -0.1216


                                                                                  

Epoch 30/50 | Train Loss: 0.6478
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5470
  Avg Pearson Correlation (across features/voxels): 0.6729
  Median Pearson Correlation: 0.6732
  R2 Score (uniform avg): 0.4530
  R2 Score (variance weighted): 0.4530

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1150
  Avg Pearson Correlation (across features/voxels): -0.0107
  Median Pearson Correlation: -0.0099
  R2 Score (uniform avg): -0.1150
  R2 Score (variance weighted): -0.1150


                                                                                  

Epoch 31/50 | Train Loss: 0.6395
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5383
  Avg Pearson Correlation (across features/voxels): 0.6799
  Median Pearson Correlation: 0.6808
  R2 Score (uniform avg): 0.4617
  R2 Score (variance weighted): 0.4617

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1143
  Avg Pearson Correlation (across features/voxels): -0.0111
  Median Pearson Correlation: -0.0099
  R2 Score (uniform avg): -0.1143
  R2 Score (variance weighted): -0.1143


                                                                                  

Epoch 32/50 | Train Loss: 0.6318
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5291
  Avg Pearson Correlation (across features/voxels): 0.6849
  Median Pearson Correlation: 0.6858
  R2 Score (uniform avg): 0.4709
  R2 Score (variance weighted): 0.4709

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1216
  Avg Pearson Correlation (across features/voxels): -0.0122
  Median Pearson Correlation: -0.0107
  R2 Score (uniform avg): -0.1216
  R2 Score (variance weighted): -0.1216


                                                                                  

Epoch 33/50 | Train Loss: 0.6247
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5209
  Avg Pearson Correlation (across features/voxels): 0.6905
  Median Pearson Correlation: 0.6923
  R2 Score (uniform avg): 0.4791
  R2 Score (variance weighted): 0.4791

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1198
  Avg Pearson Correlation (across features/voxels): -0.0118
  Median Pearson Correlation: -0.0106
  R2 Score (uniform avg): -0.1198
  R2 Score (variance weighted): -0.1198


                                                                                  

Epoch 34/50 | Train Loss: 0.6167
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5169
  Avg Pearson Correlation (across features/voxels): 0.6952
  Median Pearson Correlation: 0.6972
  R2 Score (uniform avg): 0.4831
  R2 Score (variance weighted): 0.4831

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1120
  Avg Pearson Correlation (across features/voxels): -0.0116
  Median Pearson Correlation: -0.0108
  R2 Score (uniform avg): -0.1120
  R2 Score (variance weighted): -0.1120


                                                                                  

Epoch 35/50 | Train Loss: 0.6091
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5103
  Avg Pearson Correlation (across features/voxels): 0.7000
  Median Pearson Correlation: 0.7008
  R2 Score (uniform avg): 0.4897
  R2 Score (variance weighted): 0.4897

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1071
  Avg Pearson Correlation (across features/voxels): -0.0093
  Median Pearson Correlation: -0.0082
  R2 Score (uniform avg): -0.1071
  R2 Score (variance weighted): -0.1071


                                                                                  

Epoch 36/50 | Train Loss: 0.6030
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.5017
  Avg Pearson Correlation (across features/voxels): 0.7049
  Median Pearson Correlation: 0.7058
  R2 Score (uniform avg): 0.4983
  R2 Score (variance weighted): 0.4983

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1163
  Avg Pearson Correlation (across features/voxels): -0.0139
  Median Pearson Correlation: -0.0130
  R2 Score (uniform avg): -0.1163
  R2 Score (variance weighted): -0.1163


                                                                                  

Epoch 37/50 | Train Loss: 0.5975
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4932
  Avg Pearson Correlation (across features/voxels): 0.7107
  Median Pearson Correlation: 0.7115
  R2 Score (uniform avg): 0.5068
  R2 Score (variance weighted): 0.5068

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1185
  Avg Pearson Correlation (across features/voxels): -0.0110
  Median Pearson Correlation: -0.0093
  R2 Score (uniform avg): -0.1185
  R2 Score (variance weighted): -0.1185


                                                                                  

Epoch 38/50 | Train Loss: 0.5900
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4866
  Avg Pearson Correlation (across features/voxels): 0.7152
  Median Pearson Correlation: 0.7158
  R2 Score (uniform avg): 0.5134
  R2 Score (variance weighted): 0.5134

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1165
  Avg Pearson Correlation (across features/voxels): -0.0089
  Median Pearson Correlation: -0.0079
  R2 Score (uniform avg): -0.1165
  R2 Score (variance weighted): -0.1165


                                                                                  

Epoch 39/50 | Train Loss: 0.5826
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4799
  Avg Pearson Correlation (across features/voxels): 0.7210
  Median Pearson Correlation: 0.7216
  R2 Score (uniform avg): 0.5201
  R2 Score (variance weighted): 0.5201

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1115
  Avg Pearson Correlation (across features/voxels): -0.0122
  Median Pearson Correlation: -0.0119
  R2 Score (uniform avg): -0.1115
  R2 Score (variance weighted): -0.1115


                                                                                  

Epoch 40/50 | Train Loss: 0.5764
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4730
  Avg Pearson Correlation (across features/voxels): 0.7249
  Median Pearson Correlation: 0.7248
  R2 Score (uniform avg): 0.5270
  R2 Score (variance weighted): 0.5270

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1140
  Avg Pearson Correlation (across features/voxels): -0.0108
  Median Pearson Correlation: -0.0103
  R2 Score (uniform avg): -0.1140
  R2 Score (variance weighted): -0.1140


                                                                                  

Epoch 41/50 | Train Loss: 0.5715
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4686
  Avg Pearson Correlation (across features/voxels): 0.7281
  Median Pearson Correlation: 0.7279
  R2 Score (uniform avg): 0.5314
  R2 Score (variance weighted): 0.5314

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1154
  Avg Pearson Correlation (across features/voxels): -0.0122
  Median Pearson Correlation: -0.0116
  R2 Score (uniform avg): -0.1154
  R2 Score (variance weighted): -0.1154


                                                                                  

Epoch 42/50 | Train Loss: 0.5646
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4587
  Avg Pearson Correlation (across features/voxels): 0.7346
  Median Pearson Correlation: 0.7346
  R2 Score (uniform avg): 0.5413
  R2 Score (variance weighted): 0.5413

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1134
  Avg Pearson Correlation (across features/voxels): -0.0102
  Median Pearson Correlation: -0.0098
  R2 Score (uniform avg): -0.1134
  R2 Score (variance weighted): -0.1134


                                                                                  

Epoch 43/50 | Train Loss: 0.5587
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4548
  Avg Pearson Correlation (across features/voxels): 0.7372
  Median Pearson Correlation: 0.7376
  R2 Score (uniform avg): 0.5452
  R2 Score (variance weighted): 0.5452

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1147
  Avg Pearson Correlation (across features/voxels): -0.0132
  Median Pearson Correlation: -0.0126
  R2 Score (uniform avg): -0.1147
  R2 Score (variance weighted): -0.1147


                                                                                  

Epoch 44/50 | Train Loss: 0.5534
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4479
  Avg Pearson Correlation (across features/voxels): 0.7422
  Median Pearson Correlation: 0.7427
  R2 Score (uniform avg): 0.5521
  R2 Score (variance weighted): 0.5521

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1110
  Avg Pearson Correlation (across features/voxels): -0.0072
  Median Pearson Correlation: -0.0060
  R2 Score (uniform avg): -0.1110
  R2 Score (variance weighted): -0.1110


                                                                                  

Epoch 45/50 | Train Loss: 0.5469
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4460
  Avg Pearson Correlation (across features/voxels): 0.7437
  Median Pearson Correlation: 0.7438
  R2 Score (uniform avg): 0.5540
  R2 Score (variance weighted): 0.5540

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1069
  Avg Pearson Correlation (across features/voxels): -0.0092
  Median Pearson Correlation: -0.0089
  R2 Score (uniform avg): -0.1069
  R2 Score (variance weighted): -0.1069


                                                                                  

Epoch 46/50 | Train Loss: 0.5408
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4383
  Avg Pearson Correlation (across features/voxels): 0.7490
  Median Pearson Correlation: 0.7503
  R2 Score (uniform avg): 0.5617
  R2 Score (variance weighted): 0.5617

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1077
  Avg Pearson Correlation (across features/voxels): -0.0107
  Median Pearson Correlation: -0.0109
  R2 Score (uniform avg): -0.1077
  R2 Score (variance weighted): -0.1077


                                                                                  

Epoch 47/50 | Train Loss: 0.5356
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4312
  Avg Pearson Correlation (across features/voxels): 0.7533
  Median Pearson Correlation: 0.7549
  R2 Score (uniform avg): 0.5687
  R2 Score (variance weighted): 0.5688

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1094
  Avg Pearson Correlation (across features/voxels): -0.0104
  Median Pearson Correlation: -0.0101
  R2 Score (uniform avg): -0.1094
  R2 Score (variance weighted): -0.1094


                                                                                  

Epoch 48/50 | Train Loss: 0.5305
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4263
  Avg Pearson Correlation (across features/voxels): 0.7568
  Median Pearson Correlation: 0.7582
  R2 Score (uniform avg): 0.5737
  R2 Score (variance weighted): 0.5737

> Evaluation on test
Evaluating decoding model...


                                                                       

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1068
  Avg Pearson Correlation (across features/voxels): -0.0116
  Median Pearson Correlation: -0.0111
  R2 Score (uniform avg): -0.1068
  R2 Score (variance weighted): -0.1068


                                                                                  

Epoch 49/50 | Train Loss: 0.5241
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4213
  Avg Pearson Correlation (across features/voxels): 0.7600
  Median Pearson Correlation: 0.7615
  R2 Score (uniform avg): 0.5787
  R2 Score (variance weighted): 0.5787

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1052
  Avg Pearson Correlation (across features/voxels): -0.0094
  Median Pearson Correlation: -0.0097
  R2 Score (uniform avg): -0.1052
  R2 Score (variance weighted): -0.1052


                                                                                  

Epoch 50/50 | Train Loss: 0.5192
> Evaluation on train
Evaluating decoding model...


                                                                         

Evaluation complete. Target shape: (17427, 768), Prediction shape: (17427, 768)
  MSE: 0.4140
  Avg Pearson Correlation (across features/voxels): 0.7645
  Median Pearson Correlation: 0.7663
  R2 Score (uniform avg): 0.5860
  R2 Score (variance weighted): 0.5860

> Evaluation on test
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1039
  Avg Pearson Correlation (across features/voxels): -0.0069
  Median Pearson Correlation: -0.0056
  R2 Score (uniform avg): -0.1039
  R2 Score (variance weighted): -0.1039
Training complete after 50 epochs. Using model from last epoch.


0,1
MSE,███▇▇▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
PearsonR All,▁▁▂▂▃▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██████████
PearsonR Avg,▁▁▂▃▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
R2 Uniform,▁▁▁▁▂▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████
R2 Variance,▁▁▁▁▂▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████
Test MSE,▃▁▁▁▂▃▄▅▅▆▇▇▇▇▇▇▇▇██▇▇▇████▇██▇████▇▇▇▇▇
Test PearsonR All,▅▄▇█▇▄▄▂▃▂▁▃▂▂▂▂▂▂▁▃▃▃▂▂▂▂▂▃▁▂▁▂▁▂▁▂▂▂▂▄
Test PearsonR Avg,▅▄▇██▇▄▃▃▃▂▂▃▂▂▂▂▂▃▂▃▃▂▂▂▂▂▂▁▂▂▂▂▂▁▂▂▂▂▃
Test R2 Uniform,▆███▇▅▄▃▄▃▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▂▂▁▁▂▁▁▁▁▂▂▂▂▂
Test R2 Variance,▆███▆▅▄▃▄▃▂▂▂▂▂▂▂▂▁▁▂▂▂▁▁▁▂▂▁▁▂▁▁▁▁▂▂▂▂▂

0,1
MSE,0.41405
PearsonR All,0.76626
PearsonR Avg,0.76448
R2 Uniform,0.58595
R2 Variance,0.58595
Test MSE,1.10393
Test PearsonR All,-0.00556
Test PearsonR Avg,-0.00686
Test R2 Uniform,-0.10393
Test R2 Variance,-0.10393



--- Evaluating Decoding Model on Subj NSD114 ---
Evaluating decoding model...


                                                                     

Evaluation complete. Target shape: (1884, 768), Prediction shape: (1884, 768)
  MSE: 1.1039
  Avg Pearson Correlation (across features/voxels): -0.0069
  Median Pearson Correlation: -0.0056
  R2 Score (uniform avg): -0.1039
  R2 Score (variance weighted): -0.1039
Saved Decoding model to /workspace/hardik/output_csai/decoding_model_mlp_vitmae_trainNSD103NSD104NSD105NSD106NSD107NSD113NSD114NSD115NSD116NSD117NSD119NSD120NSD122NSD123NSD124NSD125NSD126NSD127NSD128NSD129NSD130NSD132NSD134NSD135NSD136NSD138NSD140NSD142NSD145NSD146NSD147NSD148NSD149NSD150NSD151NSD153NSD155_testNSD114.pt
Saved Decoding prediction plot to /workspace/hardik/output_csai/decoding_predictions_mlp_vitmae_trainNSD103NSD104NSD105NSD106NSD107NSD113NSD114NSD115NSD116NSD117NSD119NSD120NSD122NSD123NSD124NSD125NSD126NSD127NSD128NSD129NSD130NSD132NSD134NSD135NSD136NSD138NSD140NSD142NSD145NSD146NSD147NSD148NSD149NSD150NSD151NSD153NSD155_testNSD114.png
