In [1]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam

os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import tqdm

from einops.layers.torch import Rearrange, Reduce

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
import csv
from torch import Tensor
import itertools
import math
import re
import numpy as np
import argparse
from torch import nn
from torch.optim import AdamW


from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from sklearn.model_selection import train_test_split
import pandas as pd

In [2]:
class EPDataset(Dataset):
    def __init__(self, EP_data, labels, features):  
        self.EP_data = EP_data
        self.labels = labels
        self.features = features
        self.max_pool = nn.MaxPool1d(kernel_size=5, stride=5)

    def __len__(self):
        return len(self.EP_data)
        
    def __getitem__(self, idx):
        EP_tensor = torch.tensor(self.EP_data[idx].T, dtype=torch.float32) 
        label = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        feature = self.features[idx]
                
        return EP_tensor, label, feature

In [16]:
class ModelConfig:
    def __init__(self,
                 input_neuron=25,        
                 time_bins=20,          
                 d_model = 150,          
                 nhead=10,                
                 num_transformer_layers=1, 
                 conv_channels=64,      
                 num_conv_blocks=3,      
                 num_classes=117,        
                 residual_dims=[256, 512, 1024], 
                 use_positional_encoding=True,  
                 dim_feedforward_ratio=4,      
                 activation='relu',
                 use_neuron_masking=True,  
                 mask_ratio=0,
                 mask_replacement='zero',
                 epochs = 100):
        
        # Transformer 
        self.transformer = {
            'd_model': d_model,
            'nhead': nhead,
            'num_layers': num_transformer_layers,
            'dim_feedforward': d_model * dim_feedforward_ratio,
            'activation': activation
        }
        
        # cnn
        self.convolution = {
            'channels': conv_channels,
            'num_blocks': num_conv_blocks,
            'kernel_size': (3, 3),
            'pool_size': (2, 2)
        }
        
        # resnet
        self.residual = {
            'dims': residual_dims,
            'skip_connection': True
        }
        
        self.masking = {
            'enabled': use_neuron_masking,
            'ratio': mask_ratio,
            'replacement': mask_replacement
        }

        self.input_dim = input_neuron
        self.time_steps = time_bins
        self.num_classes = num_classes
        self.positional_encoding = use_positional_encoding
        self.lr = 2e-4
        self.epochs = epochs

In [17]:
import torch
import torch.nn as nn
import math

def detect_lost_neurons(ep_data_list):
    if not ep_data_list:
        return np.zeros(25, dtype=bool)
    
    all_data = np.stack(ep_data_list)
    neuron_activity = np.sum(np.abs(all_data), axis=(0, 2))
    return neuron_activity == 0
    
class NeuronMasker(nn.Module):
    def __init__(self, mask_ratio=0.15, replacement='random'):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.replacement = replacement
        
    def forward(self, x):
        if self.training:
            if x is None:
                raise ValueError("Input tensor x is None")
                
            batch_size, seq_len, feat_dim = x.shape
            mask = torch.rand(batch_size, 1, feat_dim, device=x.device) < self.mask_ratio
            mask = mask.expand_as(x)
            
            if self.replacement == 'zero':
                x_masked = x.masked_fill(mask, 0)
            elif self.replacement == 'random':
                random_values = torch.randn_like(x) * 0.02
                x_masked = x.masked_scatter(mask, random_values)
            else:
                raise ValueError(f"Invalid replacement: {self.replacement}")
            
            return x_masked 
        else:
            return x
        
class ResidualLinearBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.activation = nn.GELU()
        self.downsample = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()

    def forward(self, x):
        residual = self.downsample(x)
        x = self.linear(x)
        x = self.norm(x)
        x = self.activation(x)
        return x + residual

class TimeTransformerConvModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        self.input_proj = nn.Linear(config.input_dim, config.transformer['d_model'])
        self.pos_encoder = PositionalEncoding(config.transformer['d_model']) if config.positional_encoding else nn.Identity()
        
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=config.transformer['d_model'],
            nhead=config.transformer['nhead'],
            dim_feedforward=config.transformer['dim_feedforward'],
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, config.transformer['num_layers'])
        
        self.conv_blocks = nn.Sequential()
        in_channels = 1
        for _ in range(config.convolution['num_blocks']):
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.convolution['channels'], 
                            kernel_size=config.convolution['kernel_size'], padding='same'),
                    nn.BatchNorm2d(config.convolution['channels']),
                    nn.ELU(),
                    nn.MaxPool2d(kernel_size=config.convolution['pool_size'])
                )
            )
            in_channels = config.convolution['channels']
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(config.convolution['channels'], config.num_classes)
        
        self.residual_layers = nn.Sequential()
        current_dim = config.convolution['channels']
        for dim in config.residual['dims']:
            self.residual_layers.append(ResidualLinearBlock(current_dim, dim))
            current_dim = dim
        if current_dim != 1024:
            self.residual_layers.append(nn.Linear(current_dim, 1024))
            self.residual_layers.append(nn.LayerNorm(1024))


        self.masker = NeuronMasker(
            mask_ratio=self.config.masking['ratio'],
            replacement=self.config.masking['replacement']
        )

    def forward(self, x):
        x = self.masker(x)  # [B, T, D]
        #print(x)
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        
        x = x.unsqueeze(1)
        x = self.conv_blocks(x)
        x = self.adaptive_pool(x)
        x = x.flatten(1)
        
        logits = self.classifier(x)
        features = self.residual_layers(x)
        
        return logits, features

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(1)]
        return x

In [18]:
class MultitaskLoss(nn.Module):
    def __init__(self, alpha=0.3, temp=0.07):
        super().__init__()
        self.alpha = alpha      # 分类损失权重
        self.temp = temp
        self.ce_loss = nn.CrossEntropyLoss()
        self.temp = temp
        
        self.ce_loss = nn.CrossEntropyLoss()
    
    def contrastive_loss(self, h_neuro, h_img):
        h_neuro = F.normalize(h_neuro, dim=1) + 1e-10
        h_img = F.normalize(h_img, dim=1) + 1e-10
        
        logits_ab = torch.matmul(h_neuro, h_img.T) / self.temp
        logits_ba = torch.matmul(h_img, h_neuro.T) / self.temp
        
        labels = torch.arange(h_neuro.size(0), device=h_neuro.device)
        loss_ab = F.cross_entropy(logits_ab, labels)
        loss_ba = F.cross_entropy(logits_ba, labels)
        
        return (loss_ab + loss_ba) / 2
    
    def forward(self, logits, labels, img_feature, features):
        loss_cls = self.ce_loss(logits, labels)
        loss_cont = self.contrastive_loss(features, img_feature)
        total_loss = self.alpha * loss_cls + (1 - self.alpha) * loss_cont
        return total_loss

In [19]:
def train_model(model, dataloader, optimizer, device, criterion, config, use_distributed=False):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    # 如果是分布式训练，设置epoch
    if use_distributed and hasattr(dataloader.sampler, 'set_epoch'):
        dataloader.sampler.set_epoch(0)  # 这里可以根据epoch参数调整
    
    for batch_idx, (neuro, labels, img_feature) in enumerate(dataloader):
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)

        optimizer.zero_grad()
        
        # 如果是DDP模型，使用module属性
        if hasattr(model, 'module'):
            logits, features = model.module(neuro)
        else:
            logits, features = model(neuro)
            
        loss = criterion(logits, labels, img_feature, features)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), 
            max_norm=3.0,                   
            norm_type=2.0                   
        )
        optimizer.step()
        
        # 统计指标
        total_loss += loss.item()
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
    train_loss = total_loss / len(dataloader)
    train_accuracy = correct / total
    return train_loss, train_accuracy

@torch.no_grad()
def evaluate_model(model, dataloader, device, criterion, config, image_cluster, use_distributed=False):
    model.eval()
    total_loss = 0.0
    correct_top1 = 0
    correct_top5 = 0
    correct_2way = 0
    correct_4way = 0
    correct_10way = 0
    total = 0
    
    # 从 DataFrame 创建聚类映射字典
    cluster_2_map = torch.tensor(image_cluster["2_cluster"].values, device=device)
    cluster_4_map = torch.tensor(image_cluster["4_cluster"].values, device=device)
    cluster_10_map = torch.tensor(image_cluster["10_cluster"].values, device=device)
    
    for neuro, labels, img_feature in dataloader:
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)
        
        # 如果是DDP模型，使用module属性
        if hasattr(model, 'module'):
            logits, features = model.module(neuro)
        else:
            logits, features = model(neuro)
        
        loss = criterion(logits, labels, img_feature, features)
        total_loss += loss.item()
        
        # （Top-1 和 Top-5）
        _, predicted_top1 = torch.max(logits, 1)
        correct_top1 += (predicted_top1 == labels).sum().item()
        _, predicted_top5 = logits.topk(5, dim=1)
        correct_top5 += torch.sum(predicted_top5.eq(labels.view(-1, 1))).item()
        
        # (2-way, 4-way, 10-way)
        cluster_2_pred = cluster_2_map[predicted_top1]
        cluster_4_pred = cluster_4_map[predicted_top1]
        cluster_10_pred = cluster_10_map[predicted_top1]
        
        cluster_2_true = cluster_2_map[labels]
        cluster_4_true = cluster_4_map[labels]
        cluster_10_true = cluster_10_map[labels]
        
        correct_2way += (cluster_2_pred == cluster_2_true).sum().item()
        correct_4way += (cluster_4_pred == cluster_4_true).sum().item()
        correct_10way += (cluster_10_pred == cluster_10_true).sum().item()
        
        total += labels.size(0)
    
    # 如果是分布式训练，需要同步所有进程的结果
    if use_distributed:
        import torch.distributed as dist
        # 创建tensor来收集所有进程的结果
        metrics = torch.tensor([total_loss, correct_top1, correct_top5, correct_2way, correct_4way, correct_10way, total], device=device)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        total_loss, correct_top1, correct_top5, correct_2way, correct_4way, correct_10way, total = metrics.tolist()
    
    test_loss = total_loss / len(dataloader)
    test_accuracy = correct_top1 / total
    top5_accuracy = correct_top5 / total
    accuracy_2way = correct_2way / total
    accuracy_4way = correct_4way / total
    accuracy_10way = correct_10way / total
    
    return test_loss, test_accuracy, top5_accuracy, accuracy_2way, accuracy_4way, accuracy_10way

In [20]:
def main_train_loop(config, model, train_loader, test_loader, device, image_cluster):
    optimizer = AdamW(model.parameters(), lr=config.lr)
    criterion = MultitaskLoss(alpha=0.7, temp=0.07)
    
    train_losses, train_accs = [], []
    test_losses, test_accs, test_top5, acc_2way, acc_4way, acc_10way = [], [], [], [], [], []
    best_acc = 0.0
    
    for epoch in range(config.epochs):
        train_loss, train_acc = train_model(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
            config=config
        )
        
        test_loss, test_acc, top5_acc, accuracy_2way, accuracy_4way, accuracy_10way = evaluate_model(
            model=model,
            dataloader=test_loader,
            device=device,
            criterion=criterion,
            config=config,
            image_cluster = image_cluster
        )
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        test_top5.append(top5_acc)
        acc_2way.append(accuracy_2way)
        acc_4way.append(accuracy_4way)
        acc_10way.append(accuracy_10way)
        
        
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_model.pth")
            best_epoch = epoch
        
        # 打印日志
        #print(f"Epoch {epoch+1}/{config.epochs}")
        #print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2%}")
        #print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2%} | Top-5 Acc: {top5_acc:.2%}")
        #print(f"2-way Acc: {accuracy_2way:.2%} | 4-way Acc: {accuracy_4way:.2%} | 10-way Acc: {accuracy_10way:.2%}")
        #print("-" * 60)
    
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Train")
    plt.plot(test_losses, label="Test")
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label="Train Acc")
    plt.plot(test_accs, label="Test Acc")
    plt.plot(test_top5, label="Test Top-5")
    plt.title("Accuracy Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    
    plt.tight_layout()
    plt.close()
    
    return {
        "best_test_acc": best_acc,
        "final_top5_acc": test_top5[-1],
        "train_history": {
            "loss": train_losses,
            "accuracy": train_accs
        },
        "test_history": {
            "loss": test_losses,
            "accuracy": test_accs,
            "top5_accuracy": test_top5,
            "acc_2way": acc_2way,
            "acc_4way": acc_4way,
            "acc_10way": acc_10way
        },
        "best_epoch": best_epoch
    }

In [21]:
date_order = ['021322', '022522', '031722', '042422', 
              '052422', '062422', '072322', '082322', 
              '092422', '102122', '112022', '122022', 
              #'012123', 
              '022223', '032123', '042323']



In [22]:
# 新的数据加载代码 - 适配mouse6数据格式
import pickle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

def load_mouse6_data():
    """加载mouse6的数据"""
    # 加载spike rate数据
    with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_spike_rate_data.pkl", 'rb') as f:
        trial_spike_rate_data = pickle.load(f)
    
    # 加载日期和图像标签
    with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_date.pkl", 'rb') as f:
        trial_date = pickle.load(f)
    
    with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_image.pkl", 'rb') as f:
        trial_image = pickle.load(f)
    
    # 加载图像特征（如果存在）
    try:
        with open("image_feature.pkl", 'rb') as f:
            image_feature_list = pickle.load(f)
    except FileNotFoundError:
        print("image_feature.pkl not found, using dummy features")
        # 创建虚拟特征（117个图像，每个图像512维特征）
        image_feature_list = [np.random.randn(512) for _ in range(117)]
    
    return trial_spike_rate_data, trial_date, trial_image, image_feature_list

def prepare_train_test_data(trial_spike_rate_data, trial_date, trial_image, image_feature_list, test_dates):
    """准备训练和测试数据"""
    # 创建DataFrame
    data_df = pd.DataFrame({
        'spike_data': trial_spike_rate_data,
        'date': trial_date,
        'image': trial_image
    })
    
    # 添加特征
    data_df['feature'] = [image_feature_list[img] for img in data_df['image']]
    
    # 分割训练和测试数据
    train_data = data_df[~data_df['date'].isin(test_dates)]
    test_data = data_df[data_df['date'].isin(test_dates)]
    
    return train_data, test_data

# 定义日期顺序（从consistency notebook复制）
date_order = ['21322', '22522', '31722', '42422', 
              '52422', '062422', '72322', '82322', 
              '92422', '102122', '112022', '122022', 
              '22223', '32123', '42323']

# 加载数据
trial_spike_rate_data, trial_date, trial_image, image_feature_list = load_mouse6_data()

print(f"Loaded {len(trial_spike_rate_data)} trials")
print(f"Spike data shape: {trial_spike_rate_data[0].shape}")
print(f"Date range: {min(trial_date)} to {max(trial_date)}")
print(f"Image range: {min(trial_image)} to {max(trial_image)}")


image_feature.pkl not found, using dummy features
Loaded 18456 trials
Spike data shape: (18, 40)
Date range: 21322 to 122022
Image range: 1 to 117


In [27]:
# 主训练循环 - 适配新数据格式
def train_with_mouse6_data():
    """使用mouse6数据进行训练"""
    results_dict = {}
    
    # 遍历每个测试月份
    for slide in range(1, len(date_order)): 
        test_month = [int(date_order[slide])]  # 转换为整数以匹配数据格式
        
        print(f"Testing on month: {test_month}")
        
        # 准备训练和测试数据
        train_data, test_data = prepare_train_test_data(
            trial_spike_rate_data, trial_date, trial_image, image_feature_list, test_month
        )
        
        print(f"Train data: {len(train_data)} trials")
        print(f"Test data: {len(test_data)} trials")
        
        if len(train_data) == 0 or len(test_data) == 0:
            print(f"Skipping slide {slide} due to insufficient data")
            continue
        
        # 创建数据集
        current_input_neuron = train_data['spike_data'].iloc[0].shape[0]
        current_time_bins = train_data['spike_data'].iloc[0].shape[1]
        
        train_dataset = EPDataset(
            train_data['spike_data'].tolist(), 
            train_data['image'].tolist(), 
            train_data['feature'].tolist()
        )
        test_dataset = EPDataset(
            test_data['spike_data'].tolist(), 
            test_data['image'].tolist(), 
            test_data['feature'].tolist()
        )
        
        # 创建数据加载器
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)
        
        # 创建模型配置
        config = ModelConfig(
            input_neuron=current_input_neuron,
            time_bins=current_time_bins
        )
        
        # 设置设备
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TimeTransformerConvModel(config).to(device)
        
        # 加载图像聚类信息（如果存在）
        try:
            image_cluster = pd.read_csv("image_cluster.csv")
        except FileNotFoundError:
            print("image_cluster.csv not found, using default clustering")
            # 创建默认聚类（每个图像一个类别）
            image_cluster = pd.DataFrame({
                'image': range(117),
                'cluster': range(117)
            })
        
        print(f"Training with {current_input_neuron} neurons, {current_time_bins} time bins")
        
        # 开始训练
        results = main_train_loop(
            config=config,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            image_cluster=image_cluster
        )
        
        results_dict[slide] = results
        print(f"Completed training for slide {slide}")
        print("-" * 60)
    
    return results_dict

# 运行训练
results_dict = train_with_mouse6_data()

# # 保存结果
# import pickle
# with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/mouse6_training_results.pkl", "wb") as f:
#     pickle.dump(results_dict, f)

# print("Training completed and results saved!")


Testing on month: [22522]


IndexError: list index out of range

In [24]:
# 修复索引错误的调试代码
import pickle
import pandas as pd
import numpy as np

def debug_data_loading():
    """调试数据加载过程"""
    print("=== 调试数据加载 ===")
    
    # 加载数据
    try:
        with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_spike_rate_data.pkl", 'rb') as f:
            trial_spike_rate_data = pickle.load(f)
        print(f"✓ Loaded spike data: {len(trial_spike_rate_data)} trials")
    except Exception as e:
        print(f"✗ Error loading spike data: {e}")
        return
    
    try:
        with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_date.pkl", 'rb') as f:
            trial_date = pickle.load(f)
        print(f"✓ Loaded dates: {len(trial_date)} entries")
    except Exception as e:
        print(f"✗ Error loading dates: {e}")
        return
    
    try:
        with open("/media/ubuntu/sda/data/paper_architecture/02_consistency/data_for_train/mouse6/trial_image.pkl", 'rb') as f:
            trial_image = pickle.load(f)
        print(f"✓ Loaded images: {len(trial_image)} entries")
    except Exception as e:
        print(f"✗ Error loading images: {e}")
        return
    
    # 检查数据一致性
    if len(trial_spike_rate_data) != len(trial_date) or len(trial_date) != len(trial_image):
        print("⚠ Warning: Data lengths inconsistent!")
        min_len = min(len(trial_spike_rate_data), len(trial_date), len(trial_image))
        print(f"Truncating to minimum length: {min_len}")
        trial_spike_rate_data = trial_spike_rate_data[:min_len]
        trial_date = trial_date[:min_len]
        trial_image = trial_image[:min_len]
    
    # 检查数据内容
    print(f"Date range: {min(trial_date)} to {max(trial_date)}")
    print(f"Image range: {min(trial_image)} to {max(trial_image)}")
    print(f"Unique images: {sorted(set(trial_image))}")
    
    if len(trial_spike_rate_data) > 0:
        print(f"Spike data shape: {trial_spike_rate_data[0].shape}")
    
    # 检查图像特征
    try:
        with open("image_feature.pkl", 'rb') as f:
            image_feature_list = pickle.load(f)
        print(f"✓ Loaded image features: {len(image_feature_list)} features")
        if len(image_feature_list) > 0:
            print(f"Feature dimension: {len(image_feature_list[0])}")
    except FileNotFoundError:
        print("⚠ image_feature.pkl not found, creating dummy features")
        image_feature_list = [np.random.randn(512) for _ in range(117)]
    except Exception as e:
        print(f"✗ Error loading image features: {e}")
        image_feature_list = [np.random.randn(512) for _ in range(117)]
    
    # 测试特征索引
    print("\\n=== 测试特征索引 ===")
    for i, img_id in enumerate(trial_image[:5]):  # 测试前5个
        print(f"Trial {i}: Image ID = {img_id}")
        if img_id >= 0 and img_id < len(image_feature_list):
            print(f"  ✓ Valid index, feature shape: {image_feature_list[img_id].shape}")
        else:
            print(f"  ✗ Invalid index! Image ID {img_id} not in range [0, {len(image_feature_list)-1}]")
    
    return trial_spike_rate_data, trial_date, trial_image, image_feature_list

# 运行调试
trial_spike_rate_data, trial_date, trial_image, image_feature_list = debug_data_loading()


=== 调试数据加载 ===
✓ Loaded spike data: 18456 trials
✓ Loaded dates: 18456 entries
✓ Loaded images: 18456 entries
Date range: 21322 to 122022
Image range: 1 to 117
Unique images: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117]
Spike data shape: (18, 40)
⚠ image_feature.pkl not found, creating dummy features
\n=== 测试特征索引 ===
Trial 0: Image ID = 27
  ✓ Valid index, feature shape: (512,)
Trial 1: Image ID = 42
  ✓ Valid index, feature shape: (512,)
Trial 2: Image ID = 5
  ✓ Valid index, feature shape: (512,)
Trial 3: Image ID = 101
  ✓ Valid index, feature shape: (512,)


In [25]:
# 修复版本的数据准备函数
def safe_prepare_train_test_data(trial_spike_rate_data, trial_date, trial_image, image_feature_list, test_dates):
    """安全准备训练和测试数据，避免索引错误"""
    print(f"Preparing data with test dates: {test_dates}")
    
    # 创建DataFrame
    data_df = pd.DataFrame({
        'spike_data': trial_spike_rate_data,
        'date': trial_date,
        'image': trial_image
    })
    
    print(f"Created DataFrame with {len(data_df)} rows")
    print(f"Image ID range: {data_df['image'].min()} to {data_df['image'].max()}")
    print(f"Available features: {len(image_feature_list)}")
    
    # 安全添加特征
    features = []
    for i, img_id in enumerate(data_df['image']):
        # 处理不同的图像ID格式
        if isinstance(img_id, (int, float)):
            # 如果是1-117范围的ID，转换为0-116
            if 1 <= img_id <= 117:
                img_idx = int(img_id) - 1
            # 如果已经是0-116范围的ID
            elif 0 <= img_id < 117:
                img_idx = int(img_id)
            else:
                print(f"Warning: Invalid image ID {img_id} at index {i}, using index 0")
                img_idx = 0
        else:
            print(f"Warning: Non-numeric image ID {img_id} at index {i}, using index 0")
            img_idx = 0
        
        # 安全获取特征
        if 0 <= img_idx < len(image_feature_list):
            feature = image_feature_list[img_idx]
        else:
            print(f"Warning: Image index {img_idx} out of range, using dummy feature")
            feature = np.random.randn(512)
        
        features.append(feature)
    
    data_df['feature'] = features
    
    # 分割训练和测试数据
    train_data = data_df[~data_df['date'].isin(test_dates)]
    test_data = data_df[data_df['date'].isin(test_dates)]
    
    print(f"Train data: {len(train_data)} trials")
    print(f"Test data: {len(test_data)} trials")
    
    if len(train_data) > 0:
        print(f"Train data image range: {train_data['image'].min()} to {train_data['image'].max()}")
    if len(test_data) > 0:
        print(f"Test data image range: {test_data['image'].min()} to {test_data['image'].max()}")
    
    return train_data, test_data

# 定义日期顺序
date_order = ['021322', '022522', '031722', '042422', 
              '052422', '062422', '072322', '082322', 
              '092422', '102122', '112022', '122022', 
              '022223', '032123', '042323']

print("数据准备函数已定义完成！")


数据准备函数已定义完成！


In [29]:
# 修复版本的快速测试训练函数
def safe_quick_test_training():
    """安全快速测试训练流程"""
    print("=== 开始安全快速测试 ===")
    
    # 选择第一个测试月份
    test_month = [int(date_order[1])]  # 使用第二个月份作为测试
    
    print(f"Testing on month: {test_month}")
    
    # 准备数据
    try:
        train_data, test_data = safe_prepare_train_test_data(
            trial_spike_rate_data, trial_date, trial_image, image_feature_list, test_month
        )
    except Exception as e:
        print(f"Error preparing data: {e}")
        return None
    
    if len(train_data) == 0 or len(test_data) == 0:
        print("No data available for testing")
        return None
    
    # 创建数据集
    try:
        current_input_neuron = train_data['spike_data'].iloc[0].shape[0]
        current_time_bins = train_data['spike_data'].iloc[0].shape[1]
        
        print(f"Data shape: {current_input_neuron} neurons x {current_time_bins} time bins")
        
        train_dataset = EPDataset(
            train_data['spike_data'].tolist(), 
            train_data['image'].tolist(), 
            train_data['feature'].tolist()
        )
        test_dataset = EPDataset(
            test_data['spike_data'].tolist(), 
            test_data['image'].tolist(), 
            test_data['feature'].tolist()
        )
        
        print(f"Created datasets - Train: {len(train_dataset)}, Test: {len(test_dataset)}")
        
    except Exception as e:
        print(f"Error creating datasets: {e}")
        return None
    
    # 创建数据加载器
    try:
        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)  # 减小batch size
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, drop_last=True)
        
        print("Created data loaders")
        
    except Exception as e:
        print(f"Error creating data loaders: {e}")
        return None
    
    # 创建模型配置
    try:
        config = ModelConfig(
            input_neuron=current_input_neuron,
            time_bins=current_time_bins,
            epochs=5  # 进一步减少训练轮数
        )
        
        # 设置设备
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        
        model = TimeTransformerConvModel(config).to(device)
        
        print("Created model successfully")
        
    except Exception as e:
        print(f"Error creating model: {e}")
        return None
    
    # 创建默认图像聚类
    try:
        image_cluster = pd.DataFrame({
            'image': range(117),
            'cluster': range(117)
        })
        
        print("Starting training...")
        
        # 开始训练
        results = main_train_loop(
            config=config,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            image_cluster=image_cluster
        )
        
        print("Training completed successfully!")
        print(f"Final test accuracy: {results.get('test_accuracy', 'N/A')}")
        
        return results
        
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()
        return None

# 运行安全快速测试
print("准备运行安全快速测试...")
safe_results = safe_quick_test_training()


准备运行安全快速测试...
=== 开始安全快速测试 ===
Testing on month: [22522]
Preparing data with test dates: [22522]
Created DataFrame with 18456 rows
Image ID range: 1 to 117
Available features: 117
Train data: 17286 trials
Test data: 1170 trials
Train data image range: 1 to 117
Test data image range: 1 to 117
Data shape: 18 neurons x 40 time bins
Created datasets - Train: 17286, Test: 1170
Created data loaders
Using device: cuda
Created model successfully
Starting training...
Error during training: mat1 and mat2 shapes cannot be multiplied (16x1024 and 512x16)


Traceback (most recent call last):
  File "/tmp/ipykernel_3363688/3666048438.py", line 89, in safe_quick_test_training
    results = main_train_loop(
        config=config,
    ...<4 lines>...
        image_cluster=image_cluster
    )
  File "/tmp/ipykernel_3363688/2739153585.py", line 10, in main_train_loop
    train_loss, train_acc = train_model(
                            ~~~~~~~~~~~^
        model=model,
        ^^^^^^^^^^^^
    ...<4 lines>...
        config=config
        ^^^^^^^^^^^^^
    )
    ^
  File "/tmp/ipykernel_3363688/1792460012.py", line 24, in train_model
    loss = criterion(logits, labels, img_feature, features)
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl

/pytorch/aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [5,0,0] Assertion `t >= 0 && t < n_classes` failed.


In [30]:
# 调试模型维度问题
import torch
import torch.nn as nn

def debug_model_dimensions():
    """调试模型维度问题"""
    print("=== 调试模型维度 ===")
    
    # 检查数据形状
    if len(trial_spike_rate_data) > 0:
        sample_shape = trial_spike_rate_data[0].shape
        print(f"Sample spike data shape: {sample_shape}")
        
        # 模拟一个batch的数据
        batch_size = 16
        sample_data = torch.randn(batch_size, sample_shape[1], sample_shape[0])  # [B, T, D]
        print(f"Input tensor shape: {sample_data.shape}")
        
        # 创建模型配置
        config = ModelConfig(
            input_neuron=sample_shape[0],
            time_bins=sample_shape[1],
            d_model=150,
            conv_channels=64,
            num_classes=117
        )
        
        print(f"Model config:")
        print(f"  input_neuron: {config.input_dim}")
        print(f"  time_bins: {config.time_steps}")
        print(f"  d_model: {config.transformer['d_model']}")
        print(f"  conv_channels: {config.convolution['channels']}")
        print(f"  num_classes: {config.num_classes}")
        
        # 创建模型
        device = torch.device("cpu")  # 使用CPU进行调试
        model = TimeTransformerConvModel(config).to(device)
        
        print(f"\\n=== 模型结构调试 ===")
        
        # 逐步调试forward过程
        x = sample_data.to(device)
        print(f"1. Input: {x.shape}")
        
        # 通过masker
        x = model.masker(x)
        print(f"2. After masker: {x.shape}")
        
        # 通过input_proj
        x = model.input_proj(x)
        print(f"3. After input_proj: {x.shape}")
        
        # 通过positional encoding
        x = model.pos_encoder(x)
        print(f"4. After pos_encoder: {x.shape}")
        
        # 通过transformer
        x = model.transformer(x)
        print(f"5. After transformer: {x.shape}")
        
        # 添加channel维度
        x = x.unsqueeze(1)
        print(f"6. After unsqueeze: {x.shape}")
        
        # 通过conv blocks
        try:
            x = model.conv_blocks(x)
            print(f"7. After conv_blocks: {x.shape}")
        except Exception as e:
            print(f"7. Error in conv_blocks: {e}")
            return
        
        # 通过adaptive pool
        x = model.adaptive_pool(x)
        print(f"8. After adaptive_pool: {x.shape}")
        
        # flatten
        x = x.flatten(1)
        print(f"9. After flatten: {x.shape}")
        
        # 检查classifier
        try:
            logits = model.classifier(x)
            print(f"10. After classifier: {logits.shape}")
        except Exception as e:
            print(f"10. Error in classifier: {e}")
            print(f"    classifier input dim: {model.classifier.in_features}")
            print(f"    classifier output dim: {model.classifier.out_features}")
            print(f"    actual input shape: {x.shape}")
            return
        
        # 检查residual layers
        try:
            features = model.residual_layers(x)
            print(f"11. After residual_layers: {features.shape}")
        except Exception as e:
            print(f"11. Error in residual_layers: {e}")
            return
        
        print("\\n✓ 模型维度调试完成，没有发现错误")
        
        # 检查图像特征维度
        print(f"\\n=== 图像特征维度检查 ===")
        if len(image_feature_list) > 0:
            feature_dim = len(image_feature_list[0])
            print(f"Image feature dimension: {feature_dim}")
            
            # 检查对比学习中的矩阵乘法
            h_neuro = features  # [batch_size, 1024]
            h_img = torch.randn(batch_size, feature_dim)  # [batch_size, feature_dim]
            
            print(f"h_neuro shape: {h_neuro.shape}")
            print(f"h_img shape: {h_img.shape}")
            
            # 检查矩阵乘法
            try:
                logits_ab = torch.matmul(h_neuro, h_img.T) / 0.07
                print(f"✓ Contrastive logits shape: {logits_ab.shape}")
            except Exception as e:
                print(f"✗ Error in contrastive loss: {e}")
                print(f"    h_neuro: {h_neuro.shape}")
                print(f"    h_img.T: {h_img.T.shape}")
        
    else:
        print("No spike data available for debugging")

# 运行调试
debug_model_dimensions()


=== 调试模型维度 ===
Sample spike data shape: (18, 40)
Input tensor shape: torch.Size([16, 40, 18])
Model config:
  input_neuron: 18
  time_bins: 40
  d_model: 150
  conv_channels: 64
  num_classes: 117
\n=== 模型结构调试 ===
1. Input: torch.Size([16, 40, 18])
2. After masker: torch.Size([16, 40, 18])
3. After input_proj: torch.Size([16, 40, 150])
4. After pos_encoder: torch.Size([16, 40, 150])
5. After transformer: torch.Size([16, 40, 150])
6. After unsqueeze: torch.Size([16, 1, 40, 150])
7. After conv_blocks: torch.Size([16, 64, 5, 18])
8. After adaptive_pool: torch.Size([16, 64, 1, 1])
9. After flatten: torch.Size([16, 64])
10. After classifier: torch.Size([16, 117])
11. After residual_layers: torch.Size([16, 1024])
\n✓ 模型维度调试完成，没有发现错误
\n=== 图像特征维度检查 ===
Image feature dimension: 512
h_neuro shape: torch.Size([16, 1024])
h_img shape: torch.Size([16, 512])
✗ Error in contrastive loss: mat1 and mat2 shapes cannot be multiplied (16x1024 and 512x16)
    h_neuro: torch.Size([16, 1024])
    h_img.T: t

In [31]:
# 修复模型维度问题的训练函数
def create_fixed_model_config(trial_spike_rate_data):
    """创建修复了维度问题的模型配置"""
    if len(trial_spike_rate_data) == 0:
        raise ValueError("No spike data available")
    
    sample_shape = trial_spike_rate_data[0].shape
    n_neurons = sample_shape[0]
    n_time_bins = sample_shape[1]
    
    print(f"Data shape: {n_neurons} neurons x {n_time_bins} time bins")
    
    # 创建修复的配置
    config = ModelConfig(
        input_neuron=n_neurons,
        time_bins=n_time_bins,
        d_model=128,  # 减小d_model以避免维度问题
        nhead=8,      # 确保d_model能被nhead整除
        conv_channels=64,
        num_conv_blocks=2,  # 减少卷积层数
        num_classes=117,
        residual_dims=[256, 512],  # 调整residual维度
        epochs=5
    )
    
    # 验证配置
    assert config.transformer['d_model'] % config.transformer['nhead'] == 0, \
        f"d_model ({config.transformer['d_model']}) must be divisible by nhead ({config.transformer['nhead']})"
    
    print(f"Created config:")
    print(f"  input_dim: {config.input_dim}")
    print(f"  time_steps: {config.time_steps}")
    print(f"  d_model: {config.transformer['d_model']}")
    print(f"  nhead: {config.transformer['nhead']}")
    print(f"  conv_channels: {config.convolution['channels']}")
    
    return config

def safe_train_with_fixed_dimensions():
    """使用修复维度的安全训练"""
    print("=== 开始修复维度的安全训练 ===")
    
    try:
        # 创建修复的模型配置
        config = create_fixed_model_config(trial_spike_rate_data)
        
        # 选择测试月份
        test_month = [int(date_order[1])]
        print(f"Testing on month: {test_month}")
        
        # 准备数据
        train_data, test_data = safe_prepare_train_test_data(
            trial_spike_rate_data, trial_date, trial_image, image_feature_list, test_month
        )
        
        if len(train_data) == 0 or len(test_data) == 0:
            print("No data available for testing")
            return None
        
        # 创建数据集
        train_dataset = EPDataset(
            train_data['spike_data'].tolist(), 
            train_data['image'].tolist(), 
            train_data['feature'].tolist()
        )
        test_dataset = EPDataset(
            test_data['spike_data'].tolist(), 
            test_data['image'].tolist(), 
            test_data['feature'].tolist()
        )
        
        # 创建数据加载器
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)  # 减小batch size
        test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, drop_last=True)
        
        # 设置设备
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        
        # 创建模型
        model = TimeTransformerConvModel(config).to(device)
        
        # 测试模型前向传播
        print("Testing model forward pass...")
        with torch.no_grad():
            sample_batch = next(iter(train_loader))
            EP_tensor, label, feature = sample_batch
            EP_tensor = EP_tensor.to(device)
            feature = torch.stack(feature).to(device)
            
            print(f"Input shapes:")
            print(f"  EP_tensor: {EP_tensor.shape}")
            print(f"  feature: {feature.shape}")
            
            try:
                logits, features = model(EP_tensor)
                print(f"✓ Model forward pass successful!")
                print(f"  logits: {logits.shape}")
                print(f"  features: {features.shape}")
            except Exception as e:
                print(f"✗ Model forward pass failed: {e}")
                import traceback
                traceback.print_exc()
                return None
        
        # 创建图像聚类
        image_cluster = pd.DataFrame({
            'image': range(117),
            'cluster': range(117)
        })
        
        print("Starting training...")
        
        # 开始训练
        results = main_train_loop(
            config=config,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            image_cluster=image_cluster
        )
        
        print("Training completed successfully!")
        print(f"Final test accuracy: {results.get('test_accuracy', 'N/A')}")
        
        return results
        
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()
        return None

# 运行修复维度的训练
fixed_results = safe_train_with_fixed_dimensions()


=== 开始修复维度的安全训练 ===
Data shape: 18 neurons x 40 time bins
Created config:
  input_dim: 18
  time_steps: 40
  d_model: 128
  nhead: 8
  conv_channels: 64
Testing on month: [22522]
Preparing data with test dates: [22522]
Created DataFrame with 18456 rows
Image ID range: 1 to 117
Available features: 117
Train data: 17286 trials
Test data: 1170 trials
Train data image range: 1 to 117
Test data image range: 1 to 117
Using device: cuda
Error during training: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



Traceback (most recent call last):
  File "/tmp/ipykernel_3363688/1546890164.py", line 81, in safe_train_with_fixed_dimensions
    model = TimeTransformerConvModel(config).to(device)
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1343, in to
    return self._apply(convert)
           ~~~~~~~~~~~^^^^^^^^^
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 903, in _apply
    module._apply(fn)
    ~~~~~~~~~~~~~^^^^
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 930, in _apply
    param_applied = fn(param)
  File "/home/ubuntu/.conda/envs/visual_decoding/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1329, in convert
    return t.to(
           ~~~~^
        device,
        ^^^^^^^
        dtype if t.is_floating_point() or t.is_complex() else None,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^