In [1]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam, AdamW

# Weights & Biases settings. Keep a real API key that is already exported in the
# environment instead of clobbering it with the "KEY" placeholder.
os.environ.setdefault("WANDB_API_KEY", "KEY")
# Always run W&B in offline mode for this notebook.
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import tqdm

from einops.layers.torch import Rearrange, Reduce

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
import csv
from torch import Tensor
import itertools
import math
import re
import numpy as np
import argparse
import pickle


from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from sklearn.model_selection import train_test_split
import pandas as pd

In [2]:

def generate_binned_spiketrains(trigger_time_df, spike_inf_df, target_image, n_bins=100):
    """
    Build per-trial binned spike-count matrices for one image.

    Parameters
    ----------
    trigger_time_df : pd.DataFrame
        Trial triggers. Must contain ``start`` and ``end`` (in 0.1 ms ticks),
        ``image`` and ``order``; matching trials are processed in ascending
        ``order``.
    spike_inf_df : pd.DataFrame
        Spike events. Must contain ``time`` (in 0.1 ms ticks) and ``Neuron``
        (capital N). The neuron set -- and therefore the row layout of every
        returned matrix -- is taken from this whole frame.
    target_image : str or int
        Value matched against ``trigger_time_df['image']``.
    n_bins : int, default 100
        Number of equal-width time bins each trial window is split into.

    Returns
    -------
    list of np.ndarray
        One ``(n_neurons, n_bins)`` integer matrix per matching trial, in trial
        order. Row ``k`` belongs to the ``k``-th neuron ID in sorted order and
        entry ``[k, b]`` counts that neuron's spikes in bin ``b`` of the window
        ``[start, end)``.
    """
    ticks_to_seconds = 0.1e-3  # trigger and spike times are stored in 0.1 ms ticks

    # Step 1: select the trials for the target image, in presentation order.
    target_triggers = trigger_time_df[trigger_time_df['image'] == target_image].sort_values('order')
    trial_starts = target_triggers['start'].to_numpy() * ticks_to_seconds
    trial_ends = target_triggers['end'].to_numpy() * ticks_to_seconds

    # Step 2: pull spike times (seconds) and neuron IDs out once as plain arrays
    # instead of copying the whole frame for every call/trial.
    spike_times = spike_inf_df['time'].to_numpy() * ticks_to_seconds
    spike_neurons = spike_inf_df['Neuron'].to_numpy()
    # Neuron IDs come from the full frame so every trial shares one row layout.
    all_neuron_ids = sorted(spike_inf_df['Neuron'].unique()) if not spike_inf_df.empty else []

    # Step 3: histogram each trial's spikes, neuron by neuron.
    binned_data = []
    for trial_start, trial_end in zip(trial_starts, trial_ends):
        in_trial = (spike_times >= trial_start) & (spike_times < trial_end)
        rel_times = spike_times[in_trial] - trial_start
        trial_neurons = spike_neurons[in_trial]

        bin_matrix = np.zeros((len(all_neuron_ids), n_bins), dtype=int)
        for neuron_idx, neuron_id in enumerate(all_neuron_ids):
            neuron_times = rel_times[trial_neurons == neuron_id]
            if neuron_times.size:  # neurons silent in this trial keep an all-zero row
                counts, _ = np.histogram(neuron_times, bins=n_bins, range=(0, trial_end - trial_start))
                bin_matrix[neuron_idx] = counts

        binned_data.append(bin_matrix)

    return binned_data

In [3]:

class EPDataset(Dataset):
    """Dataset pairing binned neural responses with labels, image features and paths.

    Item ``idx`` is ``(EP_tensor, label, feature, img_path)``: the stored response
    transposed and cast to float32, the label as a long scalar tensor, and the
    feature / path entries passed through unchanged.
    """

    def __init__(self, EP_data, labels, features, img_paths):
        self.EP_data = EP_data
        self.labels = labels
        self.features = features
        self.img_paths = img_paths

    def __len__(self):
        # The response list defines the dataset size.
        return len(self.EP_data)

    def __getitem__(self, idx):
        response = self.EP_data[idx]
        # Transposed so the second axis of the stored matrix becomes the leading
        # one (presumably neurons x bins -> bins x neurons; confirm with caller).
        EP_tensor = torch.tensor(response.T, dtype=torch.float32)
        label = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        return EP_tensor, label, self.features[idx], self.img_paths[idx]

def extract_features(model, dataloader, device):
    """Run ``model`` in eval mode over ``dataloader`` and collect its embeddings.

    Each batch must unpack as ``(neuro, label, feature, paths)``; ``model(neuro)``
    must return ``(logits, features)``.

    Returns:
        (features concatenated along dim 0 on CPU, list of paths in batch order)
    """
    model.eval()
    feature_batches = []
    path_list = []

    with torch.no_grad():
        for neuro, _label, _feature, paths in dataloader:
            _logits, batch_features = model(neuro.to(device))
            feature_batches.append(batch_features.cpu())
            path_list += paths

    return torch.cat(feature_batches), path_list

In [4]:
class ModelConfig:
    """Hyper-parameters for ``TimeTransformerConvModel`` and its training loop.

    Component settings are grouped into dicts (``transformer``, ``convolution``,
    ``residual``, ``masking``); the remaining values are plain attributes.

    Args:
        input_neuron: number of input features per time step (stored as ``input_dim``).
        time_bins: number of time steps per sample (stored as ``time_steps``).
        d_model: transformer embedding width.
        nhead: number of attention heads (PyTorch requires it to divide ``d_model``).
        num_transformer_layers: number of stacked transformer encoder layers.
        conv_channels: output channels of every convolution block.
        num_conv_blocks: number of convolution blocks.
        num_classes: size of the classification output.
        residual_dims: widths of the successive residual linear blocks. The
            sequence is copied, so neither the shared default nor a caller's
            list is aliased by the config.
        use_positional_encoding: whether sinusoidal positional encodings are used.
        dim_feedforward_ratio: transformer feed-forward width as a multiple of ``d_model``.
        activation: transformer activation name.
        use_neuron_masking: masking on/off flag (stored as ``masking['enabled']``).
        mask_ratio: masking probability (stored as ``masking['ratio']``).
        mask_replacement: masking replacement mode (``'zero'`` or ``'random'``).
    """

    def __init__(self,
                 input_neuron=25,
                 time_bins=20,
                 d_model = 150,
                 nhead=10,
                 num_transformer_layers=1,
                 conv_channels=64,
                 num_conv_blocks=3,
                 num_classes=117,
                 residual_dims=[256, 512, 1024],
                 use_positional_encoding=True,
                 dim_feedforward_ratio=4,
                 activation='relu',
                 use_neuron_masking=True,
                 mask_ratio=0,
                 mask_replacement='zero'):

        # Transformer encoder settings.
        self.transformer = {
            'd_model': d_model,
            'nhead': nhead,
            'num_layers': num_transformer_layers,
            'dim_feedforward': d_model * dim_feedforward_ratio,
            'activation': activation
        }

        # Convolution stack settings.
        self.convolution = {
            'channels': conv_channels,
            'num_blocks': num_conv_blocks,
            'kernel_size': (3, 3),
            'pool_size': (2, 2)
        }

        # Residual embedding head. list(...) copies the dims so that mutating
        # this config never leaks into the mutable default or the caller's list.
        self.residual = {
            'dims': list(residual_dims),
            'skip_connection': True
        }

        # Input feature masking.
        self.masking = {
            'enabled': use_neuron_masking,
            'ratio': mask_ratio,
            'replacement': mask_replacement
        }

        self.input_dim = input_neuron
        self.time_steps = time_bins
        self.num_classes = num_classes
        self.positional_encoding = use_positional_encoding
        # Optimiser / schedule settings.
        self.lr = 2e-4
        self.epochs = 200

In [5]:

class NeuronMasker(nn.Module):
    """Training-time augmentation that blanks whole input feature columns.

    For every sample, each index of the last axis is selected independently with
    probability ``mask_ratio``; the selection is shared along axis 1 (assumed
    to be time, with input ``(batch, time, features)`` -- confirm with caller).
    Selected entries become 0 (``'zero'``) or small Gaussian noise with
    std 0.02 (``'random'``). In eval mode the input is returned unchanged.
    """

    def __init__(self, mask_ratio=0.15, replacement='random'):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.replacement = replacement

    def forward(self, x):
        # Identity outside training.
        if not self.training:
            return x
        if x is None:
            raise ValueError("Input tensor x is None")

        n_batch, _, n_feat = x.shape
        # One draw per (sample, feature), broadcast along axis 1.
        feature_mask = torch.rand(n_batch, 1, n_feat, device=x.device) < self.mask_ratio
        feature_mask = feature_mask.expand_as(x)

        if self.replacement == 'random':
            noise = torch.randn_like(x) * 0.02
            return x.masked_scatter(feature_mask, noise)
        if self.replacement == 'zero':
            return x.masked_fill(feature_mask, 0)
        raise ValueError(f"Invalid replacement: {self.replacement}")
        
class ResidualLinearBlock(nn.Module):
    """``GELU(LayerNorm(Linear(x))) + shortcut(x)``.

    The shortcut is a learned linear projection when ``input_dim != output_dim``
    and the identity otherwise. Submodule attribute names are part of the
    state_dict keys, so they must stay as they are for saved checkpoints.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.activation = nn.GELU()
        if input_dim == output_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        shortcut = self.downsample(x)
        return self.activation(self.norm(self.linear(x))) + shortcut

class TimeTransformerConvModel(nn.Module):
    """Transformer-over-time encoder followed by a 2-D CNN, with two output heads.

    ``forward`` takes ``x`` of shape ``(B, T, config.input_dim)`` and returns
    ``(logits, features)``:

    * ``logits``   -- ``(B, config.num_classes)`` from a linear classifier;
    * ``features`` -- ``(B, 1024)`` embedding from the residual MLP head.

    The transformer output ``(B, T, d_model)`` is treated as a one-channel image
    for the conv stack and then globally average-pooled.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        d_model = config.transformer['d_model']
        embed_dim = 1024  # width of the returned feature embedding

        self.input_proj = nn.Linear(config.input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model) if config.positional_encoding else nn.Identity()

        transformer_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config.transformer['nhead'],
            dim_feedforward=config.transformer['dim_feedforward'],
            # Honour the configured activation; the config default 'relu' is
            # also PyTorch's default, so existing configs behave the same.
            activation=config.transformer['activation'],
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, config.transformer['num_layers'])

        # Conv2d -> BatchNorm -> ELU -> MaxPool blocks over the (T, d_model) map.
        self.conv_blocks = nn.Sequential()
        in_channels = 1
        for _ in range(config.convolution['num_blocks']):
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.convolution['channels'],
                              kernel_size=config.convolution['kernel_size'], padding='same'),
                    nn.BatchNorm2d(config.convolution['channels']),
                    nn.ELU(),
                    nn.MaxPool2d(kernel_size=config.convolution['pool_size'])
                )
            )
            in_channels = config.convolution['channels']

        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(config.convolution['channels'], config.num_classes)

        # Embedding head: residual blocks, then a projection to embed_dim when
        # the last block is not already that wide.
        self.residual_layers = nn.Sequential()
        current_dim = config.convolution['channels']
        for dim in config.residual['dims']:
            self.residual_layers.append(ResidualLinearBlock(current_dim, dim))
            current_dim = dim
        if current_dim != embed_dim:
            self.residual_layers.append(nn.Linear(current_dim, embed_dim))
            self.residual_layers.append(nn.LayerNorm(embed_dim))

        # Honour masking['enabled'] (previously ignored). NeuronMasker registers
        # no parameters, so the state_dict is the same either way.
        if config.masking.get('enabled', True):
            self.masker = NeuronMasker(
                mask_ratio=config.masking['ratio'],
                replacement=config.masking['replacement']
            )
        else:
            self.masker = nn.Identity()

    def forward(self, x):
        x = self.masker(x)          # (B, T, input_dim)
        x = self.input_proj(x)      # (B, T, d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x)     # (B, T, d_model)

        x = x.unsqueeze(1)          # (B, 1, T, d_model)
        x = self.conv_blocks(x)
        x = self.adaptive_pool(x)   # (B, channels, 1, 1)
        x = x.flatten(1)            # (B, channels)

        logits = self.classifier(x)
        features = self.residual_layers(x)

        return logits, features

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence.

    ``pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))``
    ``pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))``

    The table is registered as buffer ``pe`` of shape ``(max_len, d_model)``
    (so it appears in the state_dict). Odd ``d_model`` is supported.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column;
        # slicing div_term avoids the shape-mismatch crash.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add ``pe[:seq_len]`` to ``x`` of shape ``(batch, seq_len, d_model)``."""
        return x + self.pe[:x.size(1)]

In [6]:
class MultitaskLoss(nn.Module):
    """Weighted sum of a classification loss and a symmetric contrastive loss.

    ``total = alpha * CE(logits, labels) + (1 - alpha) * contrastive(features, img_feature)``

    Args:
        alpha: weight of the cross-entropy classification term; the contrastive
            term gets ``1 - alpha``.
        temp: temperature dividing the cosine-similarity logits.
    """

    def __init__(self, alpha=0, temp=0.07):
        super().__init__()
        self.alpha = alpha
        self.temp = temp
        self.ce_loss = nn.CrossEntropyLoss()

    def contrastive_loss(self, h_neuro, h_img):
        """CLIP-style symmetric contrastive loss between paired rows.

        Row ``i`` of ``h_neuro`` and row ``i`` of ``h_img`` form the positive
        pair; all other rows in the batch act as negatives. The result is the
        mean of the neural->image and image->neural cross-entropies.
        """
        # F.normalize already clamps the norm with eps, so all-zero rows stay finite.
        h_neuro = F.normalize(h_neuro, dim=1)
        h_img = F.normalize(h_img, dim=1)

        logits_ab = torch.matmul(h_neuro, h_img.T) / self.temp
        targets = torch.arange(h_neuro.size(0), device=h_neuro.device)

        loss_ab = F.cross_entropy(logits_ab, targets)
        # The image->neural logits are the transpose of the same similarity matrix.
        loss_ba = F.cross_entropy(logits_ab.T, targets)

        return (loss_ab + loss_ba) / 2

    def forward(self, logits, labels, img_feature, features):
        loss_cls = self.ce_loss(logits, labels)
        loss_cont = self.contrastive_loss(features, img_feature)
        return self.alpha * loss_cls + (1 - self.alpha) * loss_cont

In [7]:
def train_model(model, dataloader, optimizer, device, criterion, config):
    """Run one training epoch.

    Each batch must unpack as ``(neuro, labels, img_feature, paths)``;
    ``model(neuro)`` must return ``(logits, features)`` and ``criterion`` is
    called as ``criterion(logits, labels, img_feature, features)``. Gradients
    are clipped to a global L2 norm of 3.0 before each optimizer step.
    ``config`` is accepted for interface symmetry and not read.

    Returns:
        (mean per-batch loss, top-1 accuracy over all samples seen)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for neuro, labels, img_feature, _paths in dataloader:
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)

        optimizer.zero_grad()
        logits, features = model(neuro)
        loss = criterion(logits, labels, img_feature, features)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=3.0, norm_type=2.0)
        optimizer.step()

        # Accuracy is measured on the logits computed before this step's update.
        loss_sum += loss.item()
        predicted = torch.max(logits, dim=1).indices
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / len(dataloader), n_correct / n_seen

@torch.no_grad()
def evaluate_model(model, dataloader, device, criterion, config, image_cluster):
    """Evaluate ``model`` without gradients.

    Besides top-1/top-5 accuracy, a prediction is also scored as correct at a
    coarser level when the predicted class falls in the same cluster as the
    true class, using the ``2_cluster``, ``4_cluster`` and ``10_cluster``
    columns of ``image_cluster`` indexed by class id (row ``i`` is assumed to
    describe class ``i`` -- confirm against the CSV). ``config`` is not read.

    Returns:
        (mean per-batch loss, top-1 acc, top-5 acc, 2-way acc, 4-way acc, 10-way acc)
    """
    model.eval()

    cluster_columns = ("2_cluster", "4_cluster", "10_cluster")
    cluster_maps = [torch.tensor(image_cluster[col].values, device=device) for col in cluster_columns]

    loss_sum = 0.0
    n_top1 = 0
    n_top5 = 0
    n_cluster = [0] * len(cluster_maps)
    n_seen = 0

    for neuro, labels, img_feature, _paths in dataloader:
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)

        logits, features = model(neuro)
        loss_sum += criterion(logits, labels, img_feature, features).item()

        top1 = torch.max(logits, 1).indices
        n_top1 += (top1 == labels).sum().item()
        top5 = logits.topk(5, dim=1).indices
        n_top5 += top5.eq(labels.view(-1, 1)).sum().item()

        for j, cmap in enumerate(cluster_maps):
            n_cluster[j] += (cmap[top1] == cmap[labels]).sum().item()

        n_seen += labels.size(0)

    return (
        loss_sum / len(dataloader),
        n_top1 / n_seen,
        n_top5 / n_seen,
        n_cluster[0] / n_seen,
        n_cluster[1] / n_seen,
        n_cluster[2] / n_seen,
    )

In [8]:
def _plot_training_curves(train_losses, test_losses, train_accs, test_accs, test_top5):
    """Return a figure with loss curves (left) and accuracy curves (right)."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(train_losses, label="Train")
    ax_loss.plot(test_losses, label="Test")
    ax_loss.set(title="Loss Curve", xlabel="Epoch", ylabel="Loss")
    ax_loss.legend()

    ax_acc.plot(train_accs, label="Train Acc")
    ax_acc.plot(test_accs, label="Test Acc")
    ax_acc.plot(test_top5, label="Test Top-5")
    ax_acc.set(title="Accuracy Curve", xlabel="Epoch", ylabel="Accuracy")
    ax_acc.legend()

    fig.tight_layout()
    return fig


def main_train_loop(config, model, train_loader, test_loader, device, image_cluster,
                    checkpoint_path="best_model.pth", plot_path=None, verbose=False):
    """Train ``model`` for ``config.epochs`` epochs, evaluating after each one.

    Uses AdamW (``lr=config.lr``) and ``MultitaskLoss(alpha=0.7, temp=0.07)``.
    Whenever test top-1 accuracy strictly improves, the model's state_dict is
    saved to ``checkpoint_path``.

    Args:
        checkpoint_path: file the best state_dict is written to.
        plot_path: if given, the loss/accuracy figure is saved there.
        verbose: print per-epoch metrics.

    Returns:
        dict with ``best_test_acc``, ``final_top5_acc`` (``None`` if no epoch
        ran), per-epoch ``train_history`` / ``test_history`` and ``best_epoch``
        (0-based; ``None`` if test accuracy never rose above 0).
    """
    optimizer = AdamW(model.parameters(), lr=config.lr)
    criterion = MultitaskLoss(alpha=0.7, temp=0.07)

    train_losses, train_accs = [], []
    test_losses, test_accs, test_top5, acc_2way, acc_4way, acc_10way = [], [], [], [], [], []
    best_acc = 0.0
    # Initialised so the return statement never hits an unbound name when
    # accuracy stays at 0 (or no epoch runs).
    best_epoch = None

    for epoch in range(config.epochs):
        train_loss, train_acc = train_model(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
            config=config
        )

        test_loss, test_acc, top5_acc, accuracy_2way, accuracy_4way, accuracy_10way = evaluate_model(
            model=model,
            dataloader=test_loader,
            device=device,
            criterion=criterion,
            config=config,
            image_cluster=image_cluster
        )

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        test_top5.append(top5_acc)
        acc_2way.append(accuracy_2way)
        acc_4way.append(accuracy_4way)
        acc_10way.append(accuracy_10way)

        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)

        if verbose:
            print(f"Epoch {epoch+1}/{config.epochs}")
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2%}")
            print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2%} | Top-5 Acc: {top5_acc:.2%}")
            print(f"2-way Acc: {accuracy_2way:.2%} | 4-way Acc: {accuracy_4way:.2%} | 10-way Acc: {accuracy_10way:.2%}")
            print("-" * 60)

    fig = _plot_training_curves(train_losses, test_losses, train_accs, test_accs, test_top5)
    if plot_path is not None:
        fig.savefig(plot_path)
    plt.close(fig)

    return {
        "best_test_acc": best_acc,
        "final_top5_acc": test_top5[-1] if test_top5 else None,
        "train_history": {
            "loss": train_losses,
            "accuracy": train_accs
        },
        "test_history": {
            "loss": test_losses,
            "accuracy": test_accs,
            "top5_accuracy": test_top5,
            "acc_2way": acc_2way,
            "acc_4way": acc_4way,
            "acc_10way": acc_10way
        },
        "best_epoch": best_epoch
    }

In [9]:
# Recording-session identifiers (they appear to be MMDDYY strings), in the order
# used throughout the notebook. '012123' was commented out of the list; the
# reason is not recorded here -- TODO document it.
date_order = [
    '021322', '022522', '031722', '042422', '052422', '062422',
    '072322', '082322', '092422', '102122', '112022', '122022',
    '022223', '032123', '042323',
]
# Integer form, matching the integer `date` columns the CSVs are filtered on.
date_order_num = list(map(int, date_order))

In [10]:
# Per-image feature list, indexed by 0-based image label.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open("/root/autodl-tmp/image_feature_list.pkl", 'rb') as f:
    image_feature_list = pickle.load(f)

# Trial trigger table (start/end/image/date/order columns are used below).
trigger_time = pd.read_csv("/root/autodl-tmp/trigger_time.csv")

In [15]:
result_dict = {}

# Hold out the last session in date_order as the test set.
test_month = [date_order[14]]
# NOTE(review): a single spike file (named for 042323) is filtered by every
# session date below -- confirm it really contains spikes for all sessions.
all_spike_inf = pd.read_csv(f"/root/autodl-tmp/closed_loop/spike_042323.tsv", sep="\t")
date_order_temp = date_order[:15]

# key "<date>_<image>_<trial#>" -> [binned matrix, date, 0-based label, image feature]
EP_data_dict = {}
for date in date_order_temp:
    date_triggers = trigger_time[trigger_time['date'] == int(date)]
    date_spikes = all_spike_inf[all_spike_inf['date'] == int(date)]

    for image in range(1, 118):  # image IDs 1..117
        label = image - 1  # 0-based class index; also indexes image_feature_list
        trials = generate_binned_spiketrains(date_triggers, date_spikes, image)
        for trial_no, matrix in enumerate(trials, start=1):
            EP_data_dict[f"{date}_{image}_{trial_no}"] = [matrix, date, label, image_feature_list[label]]

# Split by session: trials from test_month form the test set, the rest train.
EP_data_train_dict = {key: rec for key, rec in EP_data_dict.items() if rec[1] not in test_month}
EP_data_test_dict = {key: rec for key, rec in EP_data_dict.items() if rec[1] in test_month}

if not EP_data_train_dict:
    raise ValueError("No training trials were built; check date_order and the input files.")

# Model input width = number of neuron rows of a training matrix. Taken from
# the first training trial rather than the hard-coded key '021322_100_1',
# which raised KeyError whenever that particular trial was missing.
# NOTE(review): each session's neuron list comes from its own spikes, so
# sessions with different neuron sets yield matrices of different heights.
current_input_neuron = next(iter(EP_data_train_dict.values()))[0].shape[0]



In [16]:
def _unpack_split(split_dict):
    """Turn ``{key: [matrix, date, label, feature]}`` records into parallel lists.

    Returns ``(matrices, labels, features, image_paths)`` in insertion order.
    """
    records = list(split_dict.values())
    matrices = [rec[0] for rec in records]
    labels = [rec[2] for rec in records]
    feats = [rec[3] for rec in records]
    # NOTE(review): the 0-based label is used as the file stem -- confirm the
    # image files are named 0.jpg .. 116.jpg.
    paths = ['/root/visual_decode/NaturalImages_new_2/' + str(lbl) + '.jpg' for lbl in labels]
    return matrices, labels, feats, paths


(EP_data_train_EP_data, EP_data_train_image,
 EP_data_train_feature, EP_data_train_img_path) = _unpack_split(EP_data_train_dict)
(EP_data_test_EP_data, EP_data_test_image,
 EP_data_test_feature, EP_data_test_img_path) = _unpack_split(EP_data_test_dict)

train_dataset = EPDataset(
    EP_data_train_EP_data, EP_data_train_image, EP_data_train_feature, EP_data_train_img_path
)
test_dataset = EPDataset(
    EP_data_test_EP_data, EP_data_test_image, EP_data_test_feature, EP_data_test_img_path
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
# Evaluate every test trial in a fixed order. With shuffle=True/drop_last=True
# up to 127 test trials were skipped and a different subset was scored each epoch.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False)

config = ModelConfig(input_neuron=current_input_neuron)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TimeTransformerConvModel(config).to(device)

# NOTE(review): relative path (unlike the absolute paths elsewhere), so this
# depends on the kernel's working directory.
image_cluster = pd.read_csv("image_cluster.csv")
# results = main_train_loop(
#     config=config,
#     model=model,
#     train_loader=train_loader,
#     test_loader=test_loader,
#     device=device,
#     image_cluster = image_cluster
# )

In [17]:
# NOTE(review): input_neuron is hard-coded to match the checkpoint; it must equal
# the neuron count of the data fed below (current_input_neuron) -- confirm.
config = ModelConfig(input_neuron=18)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TimeTransformerConvModel(config).to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load("/root/visual_decode/best_model.pth", map_location=device))

# Use an ordered loader that keeps every trial: train_loader shuffles and drops
# the last partial batch, which silently lost samples from the saved features.
feature_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, drop_last=False)
features = extract_features(model, feature_loader, device)
torch.save(features, "features.pth")

In [18]:
class EEGAdapter(nn.Module):
    """Maps one conditioning vector to a sequence of token embeddings.

    A two-layer MLP (``input_dim -> 2048 -> num_tokens * output_dim`` with a
    ReLU in between) whose output is reshaped to
    ``(batch, num_tokens, output_dim)``. The 77 x 768 defaults presumably mirror
    Stable Diffusion v1.x text-encoder embeddings used as
    ``encoder_hidden_states``.
    """

    def __init__(self, input_dim=1024, output_dim=768, num_tokens=77):
        super().__init__()
        hidden_dim = 2048
        # Sequential layout (Linear at index 0 and 2) fixes the state_dict keys.
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_tokens * output_dim),
        )
        self.output_dim = output_dim
        self.num_tokens = num_tokens

    def forward(self, eeg_vector):
        flat = self.projection(eeg_vector)  # (batch, num_tokens * output_dim)
        return flat.view(-1, self.num_tokens, self.output_dim)

In [19]:
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
unet, tokenizer, text_encoder = pipe.unet, pipe.tokenizer, pipe.text_encoder

eeg_adapter = EEGAdapter()

# Train only UNet parameters whose name contains "attn2" (the cross-attention
# layers in diffusers' transformer blocks); freeze everything else.
for name, param in unet.named_parameters():
    param.requires_grad = "attn2" in name

ImportError: 
StableDiffusionPipeline requires the transformers library but it was not found in your environment. You can install it with pip: `pip install transformers`


In [None]:
import torch.optim as optim
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Optimise the adapter plus only the UNet parameters left trainable (attn2).
trainable_unet_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    list(eeg_adapter.parameters()) + trainable_unet_params,
    lr=1e-4
)

# Take these from the model/scheduler configs instead of magic numbers
# (0.18215 and 1000 for SD v1.5).
latent_scale = getattr(pipe.vae.config, "scaling_factor", 0.18215)
num_train_timesteps = noise_scheduler.config.num_train_timesteps

# NOTE(review): `epochs` and `dataloader` are not defined in any cell shown here;
# the dataloader must yield dicts with "image" and "eeg" entries.
for epoch in range(epochs):
    for batch in dataloader:
        images = batch["image"]  # assumed [batch, 3, 512, 512], scaled to [-1, 1] -- confirm
        eeg_data = batch["eeg"]  # assumed [batch, 1024]

        # Encode images to latents. The VAE is not trained, so skip autograd:
        # otherwise gradients accumulate in VAE params that no optimizer clears.
        with torch.no_grad():
            latents = pipe.vae.encode(images).latent_dist.sample() * latent_scale

        # Forward-diffuse to random timesteps (created on the latents' device).
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=latents.device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Conditioning tokens from the neural-response vector.
        cond_embeddings = eeg_adapter(eeg_data)

        # The UNet predicts the added noise given the neural conditioning.
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=cond_embeddings
        ).sample

        loss = nn.functional.mse_loss(noise_pred, noise)
        print(f"Loss: {loss.item():.4f}")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()