In [1]:
import os
import sys
import re
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
import sys
# 获取当前工作目录（假设 Notebook 位于 parent_dir）
current_dir = os.getcwd()

# 构建项目根目录的路径（假设 parent_dir 和 model 同级）
project_root = os.path.abspath(os.path.join(current_dir, '..'))

# 将项目根目录添加到 sys.path
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# 现在可以使用绝对导入
from data_preparing.megdatasets_averaged import MEGDataset
# 现在可以使用绝对导入
from model.MEG_MedformerTS import meg_encoder
from loss import ClipLoss

5


  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:

def extract_id_from_string(s):
    match = re.search(r'\d+$', s)
    if match:
        return int(match.group())
    return None

def get_megfeatures(sub, meg_model, dataloader, device, text_features_all, img_features_all, k, eval_modality, test_classes):
    meg_model.eval()
    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all[::12].to(device).float()
    total_loss = 0
    correct = 0
    top5_correct_count=0
    total = 0
    loss_func = ClipLoss() 
    all_labels = set(range(text_features_all.size(0)))
    save_features = False
    features_list = []  # List to store features    
    features_tensor = torch.zeros(0, 0)
    
    with torch.no_grad():
        for batch_idx, (_, data, labels, text, text_features, img, img_features, index, img_index, subject_id) in enumerate(dataloader):
            data = data.to(device)
            text_features = text_features.to(device).float()
            labels = labels.to(device)
            img_features = img_features.to(device).float()
            
            batch_size = data.size(0) 
            subject_id = extract_id_from_string(subject_id[0])
            # data = data.permute(0, 2, 1)
            subject_ids = torch.full((batch_size,), subject_id, dtype=torch.long).to(device)
            neural_features = meg_model(data, subject_ids)
            
            logit_scale = meg_model.logit_scale.float()            
            features_list.append(neural_features)
               
            img_loss = loss_func(neural_features, img_features, logit_scale)
            loss = img_loss        
            total_loss += loss.item()
            
            for idx, label in enumerate(labels):

                possible_classes = list(all_labels - {label.item()})
                selected_classes = random.sample(possible_classes, k-1) + [label.item()]
                selected_img_features = img_features_all[selected_classes]
                

                logits_img = logit_scale * neural_features[idx] @ selected_img_features.T
                # logits_text = logit_scale * neural_features[idx] @ selected_text_features.T
                # logits_single = (logits_text + logits_img) / 2.0
                logits_single = logits_img
                # print("logits_single", logits_single.shape)

                # predicted_label = selected_classes[torch.argmax(logits_single).item()]
                predicted_label = selected_classes[torch.argmax(logits_single).item()] # (n_batch, ) \in {0, 1, ..., n_cls-1}
                if predicted_label == label.item():
                    correct += 1       
                     
                if k==test_classes:
                    _, top5_indices = torch.topk(logits_single, 5, largest =True)
                                                            
                    # Check if the ground truth label is among the top-5 predictions
                    if label.item() in [selected_classes[i] for i in top5_indices.tolist()]:                
                        top5_correct_count+=1                                 
                total += 1                    


        if save_features:
            features_tensor = torch.cat(features_list, dim=0)
            print("features_tensor", features_tensor.shape)
            torch.save(features_tensor.cpu(), f"ATM_S_neural_features_{sub}_train.pt")  # Save features as .pt file
    average_loss = total_loss / (batch_idx+1)
    accuracy = correct / total    
    top5_acc = top5_correct_count / total    
    return average_loss, accuracy, top5_acc, labels, features_tensor.cpu()

In [3]:
test_subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04']
# Inference Parameters
device_preference = 'cuda:4'  # e.g., 'cuda:0' or 'cpu'
device_type = 'gpu'  # 'cpu' or 'gpu'
data_path = "/home/ldy/THINGS-MEG/preprocessed_newsplit"
# Set device based on the argument
device = torch.device(device_preference if device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

test_classes = 200
eval_modality = 'meg'  # Modality to evaluate on
# encoder_path='/mnt/dataset0/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-11_16-25/150.pth'


meg_model = meg_encoder()
meg_model.load_state_dict(torch.load("/mnt/dataset0/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-11_14-50/150.pth"))
meg_model.to(device)
meg_model.eval()  # Set model to evaluation mode



Using device: cuda:4


  meg_model.load_state_dict(torch.load("/mnt/dataset0/ldy/Workspace/EEG_Image_decode/Retrieval/models/contrast/across/ATMS/01-11_14-50/150.pth"))


meg_encoder(
  (encoder): Medformer(
    (enc_embedding): ListPatchEmbedding(
      (value_embeddings): ModuleList(
        (0): CrossChannelTokenEmbedding(
          (tokenConv): Conv2d(1, 250, kernel_size=(271, 2), stride=(1, 2), bias=False, padding_mode=circular)
        )
        (1): CrossChannelTokenEmbedding(
          (tokenConv): Conv2d(1, 250, kernel_size=(271, 4), stride=(1, 4), bias=False, padding_mode=circular)
        )
        (2): CrossChannelTokenEmbedding(
          (tokenConv): Conv2d(1, 250, kernel_size=(271, 8), stride=(1, 8), bias=False, padding_mode=circular)
        )
      )
      (position_embedding): PositionalEmbedding()
      (dropout): Dropout(p=0.25, inplace=False)
      (augmentation): ModuleList(
        (0): Flip()
        (1): Shuffle()
        (2): FrequencyMask()
        (3): Jitter()
        (4): TemporalMask()
        (5): Dropout(p=0.1, inplace=False)
      )
      (learnable_embeddings): ParameterList(
          (0): Parameter containing: [torch

In [4]:
#####################################################################################
import numpy as np  # 导入numpy用于计算平均值
test_accuracies = []
test_accuracies_top5 = []
v2_accuracies = []
v4_accuracies = []
v10_accuracies = []

for sub in test_subjects:
    test_dataset = MEGDataset(data_path, adap_subject=sub, subjects=test_subjects, train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
    
    text_features_test_all = test_dataset.text_features    
    img_features_test_all = test_dataset.img_features
    
    test_loss, test_accuracy, top5_acc, labels, meg_features_test = get_megfeatures(
        sub, meg_model, test_loader, device, text_features_test_all, img_features_test_all, k=test_classes, eval_modality=eval_modality, test_classes=test_classes
    )
    _, v2_acc, _, _, _ = get_megfeatures(
        sub, meg_model, test_loader, device, text_features_test_all, img_features_test_all, k=2, eval_modality=eval_modality, test_classes=test_classes
    )
    _, v4_acc, _, _, _ = get_megfeatures(
        sub, meg_model, test_loader, device, text_features_test_all, img_features_test_all, k=4, eval_modality=eval_modality, test_classes=test_classes
    )
    _, v10_acc, _, _, _ = get_megfeatures(
        sub, meg_model, test_loader, device, text_features_test_all, img_features_test_all, k=10, eval_modality=eval_modality, test_classes=test_classes
    )    
    
    test_accuracies.append(test_accuracy)
    test_accuracies_top5.append(top5_acc)
    v2_accuracies.append(v2_acc)
    v4_accuracies.append(v4_acc)
    v10_accuracies.append(v10_acc)
    
    print(f" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Top5 Accuracy: {top5_acc:.4f}")    
    print(f" - Test Loss: {test_loss:.4f}, v2_acc Accuracy: {v2_acc:.4f}")
    print(f" - Test Loss: {test_loss:.4f}, v4_acc Accuracy: {v4_acc:.4f}")
    print(f" - Test Loss: {test_loss:.4f}, v10_acc Accuracy: {v10_acc:.4f}")

# 计算各项指标的平均准确率
average_test_accuracy = np.mean(test_accuracies)
average_test_accuracy_top5 = np.mean(test_accuracies_top5)
average_v2_acc = np.mean(v2_accuracies)
average_v4_acc = np.mean(v4_accuracies)
average_v10_acc = np.mean(v10_accuracies)

print(f"\nAverage Test Accuracy across all subjects: {average_test_accuracy:.4f}")
print(f"\nAverage Test Top5 Accuracy across all subjects: {average_test_accuracy_top5:.4f}")
print(f"Average v2_acc Accuracy across all subjects: {average_v2_acc:.4f}")
print(f"Average v4_acc Accuracy across all subjects: {average_v4_acc:.4f}")
print(f"Average v10_acc Accuracy across all subjects: {average_v10_acc:.4f}")


self.subjects ['sub-01', 'sub-02', 'sub-03', 'sub-04']
adap_subject sub-01
preprocessed_eeg_data torch.Size([2400, 271, 201])
data_tensor torch.Size([200, 271, 201])
Data tensor shape: torch.Size([200, 271, 201]), label tensor shape: torch.Size([200]), text length: 200, image length: 2400


  saved_features = torch.load(features_filename)


 - Test Loss: 0.0000, Test Accuracy: 0.0650, Top5 Accuracy: 0.1450
 - Test Loss: 0.0000, v2_acc Accuracy: 0.8050
 - Test Loss: 0.0000, v4_acc Accuracy: 0.5750
 - Test Loss: 0.0000, v10_acc Accuracy: 0.3950
self.subjects ['sub-01', 'sub-02', 'sub-03', 'sub-04']
adap_subject sub-02
preprocessed_eeg_data torch.Size([2400, 271, 201])
data_tensor torch.Size([200, 271, 201])
Data tensor shape: torch.Size([200, 271, 201]), label tensor shape: torch.Size([200]), text length: 200, image length: 2400
 - Test Loss: 0.0000, Test Accuracy: 0.1350, Top5 Accuracy: 0.4000
 - Test Loss: 0.0000, v2_acc Accuracy: 0.8800
 - Test Loss: 0.0000, v4_acc Accuracy: 0.7750
 - Test Loss: 0.0000, v10_acc Accuracy: 0.6050
self.subjects ['sub-01', 'sub-02', 'sub-03', 'sub-04']
adap_subject sub-03
preprocessed_eeg_data torch.Size([2400, 271, 201])
data_tensor torch.Size([200, 271, 201])
Data tensor shape: torch.Size([200, 271, 201]), label tensor shape: torch.Size([200]), text length: 200, image length: 2400
 - Test 