In [None]:
import sys
import os
import random
import re
import argparse
import warnings

# Environment configuration
# Weights & Biases is forced into offline, silent mode for this notebook, so no
# API key is needed. Deliberately do NOT assign WANDB_API_KEY here: writing a
# placeholder would overwrite a valid key already exported in the shell, and a
# real key must never be committed inside a notebook. Export it in the
# environment instead if online logging is ever required.
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_SILENT"] = "true"

# Path configuration
current_dir = os.getcwd()
# Project root is taken to be the parent of the current working directory
# (assuming parent_dir and model are at the same level).
project_root = os.path.abspath(os.path.join(current_dir, ".."))

# Add project root to sys.path (once) so absolute imports of project packages
# resolve regardless of where the notebook is launched from.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Model imports
from model.unified_encoder_multi_tower import UnifiedEncoder

# Deep learning framework imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset

# Experiment tracking
import wandb
wandb.init(mode="disabled")

# Dataset imports
from data_preparing.eegdatasets import EEGDataset
from data_preparing.megdatasets_averaged import MEGDataset
from data_preparing.fmri_datasets_joint_subjects import fMRIDataset
from data_preparing.datasets_mixer import (
    MetaEEGDataset,
    MetaMEGDataset,
    MetafMRIDataset,
    MetaDataLoader
)

# Evaluation metrics
from sklearn.metrics import confusion_matrix

# Custom loss functions
from loss import ClipLoss

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)


5
5


In [None]:
import re
import random
import torch


def extract_id_from_string(s):
    """Parse the run of digits that terminates a string.

    Args:
        s (str): Text expected to end with a numeric ID (e.g. ``"sub-02"``).

    Returns:
        int or None: The trailing digits converted to an integer, or ``None``
        when the string does not end with a digit.
    """
    trailing_digits = re.search(r'\d+$', s)
    return int(trailing_digits.group()) if trailing_digits else None


def get_eegfeatures(unified_model, dataloader, device, text_features_all, 
                   img_features_all, k, eval_modality, test_classes):
    """Run k-way retrieval evaluation of neural embeddings against image features.

    For every sample, ``k - 1`` distractor classes are drawn at random (with the
    ``random`` module, so seed it for reproducible results) and scored together
    with the true class by the scaled dot product between the sample's neural
    embedding and the candidate image features. The best-scoring candidate is
    the prediction.

    Args:
        unified_model: Model called as ``unified_model(data, subject_ids,
            modal=eval_modality)`` and exposing a ``logit_scale`` attribute.
        dataloader: Iterable yielding 10-tuples ``(modal, data, labels, text,
            text_features, img, img_features, _, _, sub_ids)``.
        device: Device that inputs and feature tables are moved to.
        text_features_all: Mapping from modality name to the text feature
            table. Only its first dimension is used here, to determine the
            number of candidate classes.
        img_features_all: Mapping from modality name to the image feature
            table that candidates are scored against.
        k: Number of candidates per sample (true class + ``k - 1`` distractors).
        eval_modality: Modality being evaluated ('eeg', 'meg', 'fmri').
        test_classes: Total number of test classes. Top-5 accuracy is only
            computed when ``k == test_classes``; otherwise it is reported as 0.

    Returns:
        tuple: ``(average_loss, accuracy, top5_accuracy, labels, features_tensor)``.
        ``average_loss`` is the mean per-batch loss. ``labels`` is the label
        tensor of the LAST batch only (kept as-is for existing callers).
        ``features_tensor`` is an empty ``(0, 0)`` tensor unless the internal
        ``save_features`` flag is enabled. All metrics are 0.0 and ``labels``
        is empty when the dataloader yields nothing.
    """
    unified_model.eval()
    
    # Prepare features based on modality
    text_features_all = text_features_all[eval_modality].to(device).float()
    
    if eval_modality in ['eeg', 'fmri']:
        img_features_all = img_features_all[eval_modality].to(device).float()
    elif eval_modality == 'meg':
        # Keep every 12th row of the MEG image-feature table.
        # NOTE(review): presumably the table stores 12 entries per class -- confirm.
        img_features_all = img_features_all[eval_modality][::12].to(device).float()
    
    # Initialize evaluation metrics
    total_loss = 0
    correct = 0
    top5_correct_count = 0
    total = 0
    num_batches = 0
    
    loss_func = ClipLoss()
    all_labels = set(range(text_features_all.size(0)))
    
    # Feature saving configuration
    save_features = False
    features_list = []
    features_tensor = torch.zeros(0, 0)
    sub = 'sub-02'  # Subject identifier for saving
    
    # Defined up front so an empty dataloader still returns a valid value.
    labels = torch.empty(0, dtype=torch.long)
    
    with torch.no_grad():
        for (modal, data, labels, text, text_features,
             img, img_features, _, _, sub_ids) in dataloader:
            num_batches += 1
            
            # Move the tensors that are actually used onto the target device
            data = data.to(device)
            labels = labels.to(device)
            img_features = img_features.to(device).float()
            
            # Extract numeric subject IDs from strings such as "sub-02"
            subject_ids = [extract_id_from_string(sub_id) for sub_id in sub_ids]
            subject_ids = torch.tensor(subject_ids, dtype=torch.long).to(device)
            
            # Get neural features from model
            neural_features = unified_model(data, subject_ids, modal=eval_modality)
            logit_scale = unified_model.logit_scale.float()
            
            # Only retain embeddings when they will be saved; keeping every
            # batch alive otherwise just pins device memory for no reason.
            if save_features:
                features_list.append(neural_features)
            
            # Contrastive loss between neural and image features of the batch
            total_loss += loss_func(neural_features, img_features, logit_scale).item()
            
            # Evaluate each sample in the batch
            for idx, label in enumerate(labels):
                true_label = label.item()
                
                # Sample k-1 random distractors; the true class goes last
                possible_classes = list(all_labels - {true_label})
                selected_classes = random.sample(possible_classes, k-1) + [true_label]
                selected_img_features = img_features_all[selected_classes]
                
                # Similarity of this sample to each candidate image feature
                logits_single = logit_scale * neural_features[idx] @ selected_img_features.T
                
                # Top-1 prediction, mapped back to a class index
                predicted_label = selected_classes[torch.argmax(logits_single).item()]
                if predicted_label == true_label:
                    correct += 1
                
                # Top-5 accuracy (only when every test class is a candidate)
                if k == test_classes:
                    # Clamp so fewer than 5 candidates does not make topk raise
                    top_n = min(5, logits_single.size(-1))
                    _, top5_indices = torch.topk(logits_single, top_n, largest=True)
                    if true_label in [selected_classes[i] for i in top5_indices.tolist()]:
                        top5_correct_count += 1
                
                total += 1
        
        # Save features if enabled
        if save_features and features_list:
            features_tensor = torch.cat(features_list, dim=0)
            print(f"Features tensor shape: {features_tensor.shape}")
            torch.save(features_tensor.cpu(), f"ATM_S_neural_features_{sub}_train.pt")
    
    # Calculate final metrics, guarding against an empty dataloader
    average_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    top5_acc = top5_correct_count / total if total else 0.0
    
    return average_loss, accuracy, top5_acc, labels, features_tensor.cpu()


# ============================================================================
# Configuration Parameters
# ============================================================================

# Per-modality encoder checkpoints handed to the unified model.
ENCODER_PATHS = dict(
    eeg='/mnt/dataset1/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-06_01-46/150.pth',
    meg='/mnt/dataset1/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-11_14-50/150.pth',
    fmri='/mnt/dataset1/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-18_01-35/150.pth',
)

# Which modality this run evaluates.
EVAL_MODALITY = 'eeg'  # Options: 'eeg', 'meg', 'fmri'

# Subjects evaluated for each modality.
SUBJECTS_CONFIG = dict(
    eeg=[
        'sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05',
        'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10',
    ],
    meg=['sub-01', 'sub-02', 'sub-03', 'sub-04'],
    fmri=['sub-01', 'sub-02', 'sub-03'],
)

# Number of test classes for each modality.
CLASSES_CONFIG = dict(eeg=200, meg=200, fmri=100)

# Modalities this notebook knows how to evaluate.
MODALITIES = ['eeg', 'meg', 'fmri']

# Fail fast on a typo in EVAL_MODALITY before anything is derived from it.
if not (EVAL_MODALITY in MODALITIES):
    raise ValueError(f"Unsupported modality: {EVAL_MODALITY}")

# Settings derived from the selected modality.
test_subjects, test_classes = (
    SUBJECTS_CONFIG[EVAL_MODALITY],
    CLASSES_CONFIG[EVAL_MODALITY],
)

# Root directory of the preprocessed data for each modality.
DATASET_PATHS = dict(
    eeg="/mnt/dataset1/ldy/datasets/THINGS_EEG1/processed_250Hz",
    meg="/home/ldy/THINGS-MEG/preprocessed_newsplit",
    fmri="/home/ldy/fmri_dataset/Preprocessed",
)

# Output configuration (for logging and saving results).
OUTPUT_CONFIG = dict(
    output_dir='./outputs/contrast',
    project="train_pos_img_text_rep",
    entity="sustech_rethinkingbci",
    name="lr=3e-4_img_pos_pro_eeg",
)

# Device configuration.
DEVICE_CONFIG = dict(
    device_preference='cuda:3',  # Options: 'cuda:0', 'cuda:1', 'cpu'
    device_type='gpu',  # Options: 'cpu', 'gpu'
)

# Echo the effective configuration so the run is self-describing.
_summary_rule = "=" * 60
for _summary_line in (
    _summary_rule,
    "EVALUATION CONFIGURATION",
    _summary_rule,
    f"Evaluation Modality: {EVAL_MODALITY}",
    f"Test Subjects: {test_subjects}",
    f"Number of Test Classes: {test_classes}",
    f"Device: {DEVICE_CONFIG['device_preference']}",
    _summary_rule,
):
    print(_summary_line)

Evaluation Modality: eeg
Test Subjects: ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']
Number of Test Classes: 200


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader


def format_num(num):
    """Format a number with two decimals and a magnitude suffix.

    The value is divided by 1000 until its rounded magnitude drops below
    1000, picking the suffix '' / K / M / B / T accordingly; anything still
    larger after 'T' is reported with the suffix 'P'.

    The threshold is checked on the value as it will be printed (two
    decimals), so 999_999 is shown as "1.00M" rather than "1000.00K", and on
    its magnitude, so negative numbers are scaled like positive ones.

    Args:
        num (int or float): Number to format.

    Returns:
        str: Formatted number with unit, e.g. ``"161.96M"``.
    """
    value = float(num)
    for unit in ['', 'K', 'M', 'B', 'T']:
        text = f"{value:.2f}"
        # Compare what will actually be displayed, not the unrounded value.
        if abs(float(text)) < 1000:
            return f"{text}{unit}"
        value /= 1000
    return f"{value:.2f}P"


def print_model_info(model):
    """Report how many parameters a model has and how many are trainable.

    Prints the total count, the count of parameters with ``requires_grad``
    set, and the trainable share as a percentage (skipped with a notice when
    the model has no parameters).

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad`` (e.g. a PyTorch module).
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"Total parameters: {format_num(total_params)}")
    print(f"Trainable parameters: {format_num(trainable_params)}")

    # Guard the division: an empty model has no meaningful percentage.
    if not total_params > 0:
        print("Total parameters count is zero, cannot compute percentage.")
        return

    trainable_percentage = (trainable_params / total_params) * 100
    print(f"Trainable parameters percentage: {trainable_percentage:.2f}%")


# ============================================================================
# Model Initialization and Setup
# ============================================================================

# Work on a shallow copy of the configured encoder checkpoint paths.
encoder_paths = dict(ENCODER_PATHS)

# Use the preferred CUDA device only when a GPU is requested and available;
# otherwise fall back to the CPU.
use_gpu = DEVICE_CONFIG['device_type'] == 'gpu' and torch.cuda.is_available()
device = torch.device(DEVICE_CONFIG['device_preference'] if use_gpu else 'cpu')
print(f"Using device: {device}")

# Per-modality feature tables, filled in for each subject during evaluation.
text_features_test_all, img_features_test_all = {}, {}

# Build the unified encoder from the per-modality checkpoints.
print("Initializing Unified Encoder Model...")
unified_model = UnifiedEncoder(encoder_paths, device)

# Restore the trained unified-model weights onto the selected device.
MODEL_PATH = "/mnt/dataset1/ldy/Workspace/FLORA/models/contrast/across/Unified_EEG+MEG+fMRI_EEG/01-27_02-32/60.pth"
state_dict = torch.load(MODEL_PATH, map_location=device)
unified_model.load_state_dict(state_dict)
unified_model.to(device)
unified_model.eval()  # inference only: switch the model to evaluation mode

# Summarize total vs. trainable parameter counts.
print_model_info(unified_model)

# ============================================================================
# Model Evaluation Loop
# ============================================================================

# get_eegfeatures draws its distractor classes with the `random` module. Seed
# it so the k-way accuracies are reproducible across Restart & Run All.
EVAL_SEED = 42
random.seed(EVAL_SEED)

print(f"\nStarting evaluation on {EVAL_MODALITY} modality...")
print(f"Testing on subjects: {test_subjects}")
print("=" * 60)

# Initialize accuracy tracking lists (one entry per subject)
test_accuracies = []
test_accuracies_top5 = []
v2_accuracies = []
v4_accuracies = []
v10_accuracies = []

# Evaluate each subject
for sub in test_subjects:
    print(f"\nEvaluating subject: {sub}")
    print("-" * 30)
    
    # Prepare test dataset based on evaluation modality
    if EVAL_MODALITY == 'eeg':
        test_dataset = EEGDataset(
            DATASET_PATHS['eeg'], 
            subjects=[sub], 
            train=False
        )
    elif EVAL_MODALITY == 'meg':
        test_dataset = MEGDataset(
            DATASET_PATHS['meg'], 
            subjects=[sub], 
            train=False
        )
    elif EVAL_MODALITY == 'fmri':
        test_dataset = fMRIDataset(
            DATASET_PATHS['fmri'], 
            adap_subject=sub, 
            subjects=[sub], 
            train=False
        )
    else:
        # Without this branch a bad modality would silently reuse a stale
        # test_dataset (or fail later with a confusing NameError).
        raise ValueError(f"Unsupported modality: {EVAL_MODALITY}")
    
    # Expose this subject's feature tables under the evaluated modality
    text_features_test_all[EVAL_MODALITY] = test_dataset.text_features
    img_features_test_all[EVAL_MODALITY] = test_dataset.img_features
    
    # Create data loader (one sample per batch, fixed order)
    test_loader = DataLoader(
        test_dataset, 
        batch_size=1, 
        shuffle=False, 
        num_workers=0, 
        drop_last=False
    )
    
    # Full evaluation: every test class is a candidate (also yields top-5)
    test_loss, test_accuracy, top5_acc, labels, eeg_features_test = get_eegfeatures(
        unified_model, test_loader, device, text_features_test_all, 
        img_features_test_all, k=test_classes, eval_modality=EVAL_MODALITY, 
        test_classes=test_classes
    )
    
    # Reduced k-way evaluations (k=2, 4, 10), run in that order; only the
    # accuracy (second element of the result tuple) is kept from each run.
    kway_acc = {}
    for k_way in (2, 4, 10):
        kway_acc[k_way] = get_eegfeatures(
            unified_model, test_loader, device, text_features_test_all, 
            img_features_test_all, k=k_way, eval_modality=EVAL_MODALITY, 
            test_classes=test_classes
        )[1]
    v2_acc, v4_acc, v10_acc = kway_acc[2], kway_acc[4], kway_acc[10]
    
    # Store accuracies
    test_accuracies.append(test_accuracy)
    test_accuracies_top5.append(top5_acc)
    v2_accuracies.append(v2_acc)
    v4_accuracies.append(v4_acc)
    v10_accuracies.append(v10_acc)
    
    # Print subject results
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy (k={test_classes}): {test_accuracy:.4f}")
    print(f"Top-5 Accuracy: {top5_acc:.4f}")
    print(f"Binary Accuracy (k=2): {v2_acc:.4f}")
    print(f"4-way Accuracy (k=4): {v4_acc:.4f}")
    print(f"10-way Accuracy (k=10): {v10_acc:.4f}")

# ============================================================================
# Results Summary
# ============================================================================

# Mean of each metric over all evaluated subjects.
average_test_accuracy = np.mean(test_accuracies)
average_test_accuracy_top5 = np.mean(test_accuracies_top5)
average_v2_acc = np.mean(v2_accuracies)
average_v4_acc = np.mean(v4_accuracies)
average_v10_acc = np.mean(v10_accuracies)

summary_rule = "=" * 60
print("\n" + summary_rule)
print("FINAL RESULTS - AVERAGE ACROSS ALL SUBJECTS")
print(summary_rule)
print(f"Average Test Accuracy (k={test_classes}): {average_test_accuracy:.4f}")
print(f"Average Top-5 Accuracy: {average_test_accuracy_top5:.4f}")
print(f"Average Binary Accuracy (k=2): {average_v2_acc:.4f}")
print(f"Average 4-way Accuracy (k=4): {average_v4_acc:.4f}")
print(f"Average 10-way Accuracy (k=10): {average_v10_acc:.4f}")
print(summary_rule)

# Per-subject breakdown, one line per subject, for reference.
subject_rule = "-" * 40
print("\nINDIVIDUAL SUBJECT RESULTS:")
print(subject_rule)
for i in range(len(test_subjects)):
    sub = test_subjects[i]
    print(
        f"{sub}: Acc={test_accuracies[i]:.4f}, "
        f"Top5={test_accuracies_top5[i]:.4f}, "
        f"k=2={v2_accuracies[i]:.4f}, "
        f"k=4={v4_accuracies[i]:.4f}, "
        f"k=10={v10_accuracies[i]:.4f}"
    )
print(subject_rule)

Using device: cuda:3
Total parameters: 161.96M
Trainable parameters: 7.36M
Trainable parameters percentage: 4.54%


AssertionError: 