In [1]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam, AdamW

os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import tqdm

from einops.layers.torch import Rearrange, Reduce

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
import csv
from torch import Tensor
import itertools
import math
import re
import numpy as np
import argparse
import pickle


from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from sklearn.model_selection import train_test_split
import pandas as pd

In [2]:
class ModelConfig:
    def __init__(self,
                 input_neuron=25,        
                 time_bins=20,          
                 d_model = 150,          
                 nhead=10,                
                 num_transformer_layers=1, 
                 conv_channels=64,      
                 num_conv_blocks=3,      
                 num_classes=117,        
                 residual_dims=[256, 512, 1024], 
                 use_positional_encoding=True,  
                 dim_feedforward_ratio=4,      
                 activation='relu',
                 use_neuron_masking=True,  
                 mask_ratio=0,
                 mask_replacement='zero'):
        
        # Transformer 
        self.transformer = {
            'd_model': d_model,
            'nhead': nhead,
            'num_layers': num_transformer_layers,
            'dim_feedforward': d_model * dim_feedforward_ratio,
            'activation': activation
        }
        
        # cnn
        self.convolution = {
            'channels': conv_channels,
            'num_blocks': num_conv_blocks,
            'kernel_size': (3, 3),
            'pool_size': (2, 2)
        }
        
        # resnet
        self.residual = {
            'dims': residual_dims,
            'skip_connection': True
        }
        
        self.masking = {
            'enabled': use_neuron_masking,
            'ratio': mask_ratio,
            'replacement': mask_replacement
        }

        self.input_dim = input_neuron
        self.time_steps = time_bins
        self.num_classes = num_classes
        self.positional_encoding = use_positional_encoding
        self.lr = 2e-4
        self.epochs = 200

In [3]:

class NeuronMasker(nn.Module):
    def __init__(self, mask_ratio=0.15, replacement='random'):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.replacement = replacement
        
    def forward(self, x):
        if self.training:
            if x is None:
                raise ValueError("Input tensor x is None")
                
            batch_size, seq_len, feat_dim = x.shape
            mask = torch.rand(batch_size, 1, feat_dim, device=x.device) < self.mask_ratio
            mask = mask.expand_as(x)
            
            if self.replacement == 'zero':
                x_masked = x.masked_fill(mask, 0)
            elif self.replacement == 'random':
                random_values = torch.randn_like(x) * 0.02
                x_masked = x.masked_scatter(mask, random_values)
            else:
                raise ValueError(f"Invalid replacement: {self.replacement}")
            
            return x_masked 
        else:
            return x
        
class ResidualLinearBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.activation = nn.GELU()
        self.downsample = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()

    def forward(self, x):
        residual = self.downsample(x)
        x = self.linear(x)
        x = self.norm(x)
        x = self.activation(x)
        return x + residual

class TimeTransformerConvModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        self.input_proj = nn.Linear(config.input_dim, config.transformer['d_model'])
        self.pos_encoder = PositionalEncoding(config.transformer['d_model']) if config.positional_encoding else nn.Identity()
        
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=config.transformer['d_model'],
            nhead=config.transformer['nhead'],
            dim_feedforward=config.transformer['dim_feedforward'],
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, config.transformer['num_layers'])
        
        self.conv_blocks = nn.Sequential()
        in_channels = 1
        for _ in range(config.convolution['num_blocks']):
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.convolution['channels'], 
                            kernel_size=config.convolution['kernel_size'], padding='same'),
                    nn.BatchNorm2d(config.convolution['channels']),
                    nn.ELU(),
                    nn.MaxPool2d(kernel_size=config.convolution['pool_size'])
                )
            )
            in_channels = config.convolution['channels']
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(config.convolution['channels'], config.num_classes)
        
        self.residual_layers = nn.Sequential()
        current_dim = config.convolution['channels']
        for dim in config.residual['dims']:
            self.residual_layers.append(ResidualLinearBlock(current_dim, dim))
            current_dim = dim
        if current_dim != 1024:
            self.residual_layers.append(nn.Linear(current_dim, 1024))
            self.residual_layers.append(nn.LayerNorm(1024))


        self.masker = NeuronMasker(
            mask_ratio=self.config.masking['ratio'],
            replacement=self.config.masking['replacement']
        )

    def forward(self, x):
        x = self.masker(x)  # [B, T, D]
        #print(x)
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        
        x = x.unsqueeze(1)
        x = self.conv_blocks(x)
        x = self.adaptive_pool(x)
        x = x.flatten(1)
        
        logits = self.classifier(x)
        features = self.residual_layers(x)
        
        return logits, features

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(1)]
        return x

In [4]:
class MultitaskLoss(nn.Module):
    def __init__(self, alpha=0, temp=0.07):
        super().__init__()
        self.alpha = alpha      # 分类损失权重
        self.temp = temp
        self.ce_loss = nn.CrossEntropyLoss()
        self.temp = temp
        
        self.ce_loss = nn.CrossEntropyLoss()
    
    def contrastive_loss(self, h_neuro, h_img):
        h_neuro = F.normalize(h_neuro, dim=1) + 1e-10
        h_img = F.normalize(h_img, dim=1) + 1e-10
        
        logits_ab = torch.matmul(h_neuro, h_img.T) / self.temp
        logits_ba = torch.matmul(h_img, h_neuro.T) / self.temp
        
        labels = torch.arange(h_neuro.size(0), device=h_neuro.device)
        loss_ab = F.cross_entropy(logits_ab, labels)
        loss_ba = F.cross_entropy(logits_ba, labels)
        
        return (loss_ab + loss_ba) / 2
    
    def forward(self, logits, labels, img_feature, features):
        loss_cls = self.ce_loss(logits, labels)
        loss_cont = self.contrastive_loss(features, img_feature)
        total_loss = self.alpha * loss_cls + (1 - self.alpha) * loss_cont
        return total_loss

In [5]:
def train_model(model, dataloader, optimizer, device, criterion, config):
    model.train()
    total_loss = 0.0
    
    for batch_idx, (neuro, labels, img_feature) in enumerate(dataloader):
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)

        optimizer.zero_grad()
        
        logits, features = model(neuro)
        loss = criterion(logits, labels, img_feature, features)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), 
            max_norm=3.0,                   
            norm_type=2.0                   
        )
        optimizer.step()
        
        total_loss += loss.item()
        
    train_loss = total_loss / len(dataloader)
    return train_loss

@torch.no_grad()
def evaluate_model(model, dataloader, device, criterion, config):
    model.eval()
    total_loss = 0.0
    all_features = []
    all_img_features = []
    
    for neuro, labels, img_feature in dataloader:
        neuro = neuro.to(device)
        labels = labels.to(device)
        img_feature = img_feature.to(device)
        
        logits, features = model(neuro)
        loss = criterion(logits, labels, img_feature, features)
        total_loss += loss.item()
        
        all_features.append(F.normalize(features, dim=1))
        all_img_features.append(F.normalize(img_feature, dim=1))
    
    features_mat = torch.cat(all_features, dim=0)
    img_mat = torch.cat(all_img_features, dim=0)
    
    sims = torch.matmul(features_mat, img_mat.t())
    N = sims.size(0)
    indices = torch.arange(N, device=sims.device)
    
    # neuro -> image 排名
    ranks_ab = torch.argsort(sims, dim=1, descending=True)
    pos_ab = torch.argmax((ranks_ab == indices.view(-1, 1)).int(), dim=1) + 1
    
    # image -> neuro 排名
    ranks_ba = torch.argsort(sims.t(), dim=1, descending=True)
    pos_ba = torch.argmax((ranks_ba == indices.view(-1, 1)).int(), dim=1) + 1
    
    def recall_at(k, ranks):
        return (ranks <= k).float().mean().item()
    
    r1 = 0.5 * (recall_at(1, pos_ab) + recall_at(1, pos_ba))
    r5 = 0.5 * (recall_at(5, pos_ab) + recall_at(5, pos_ba))
    r10 = 0.5 * (recall_at(10, pos_ab) + recall_at(10, pos_ba))
    
    mrr = 0.5 * ((1.0 / pos_ab.float()).mean().item() + (1.0 / pos_ba.float()).mean().item())
    median_rank = 0.5 * (pos_ab.float().median().item() + pos_ba.float().median().item())
    
    val_loss = total_loss / len(dataloader)
    
    return {
        "val_loss": val_loss,
        "recall@1": r1,
        "recall@5": r5,
        "recall@10": r10,
        "mrr": mrr,
        "median_rank": median_rank
    }

In [6]:
def main_train_loop(config, model, train_loader, test_loader, device, image_cluster=None):
    optimizer = AdamW(model.parameters(), lr=config.lr)
    criterion = MultitaskLoss(alpha=0.0, temp=0.07)
    
    train_losses = []
    val_losses = []
    recalls1 = []
    recalls5 = []
    recalls10 = []
    mrrs = []
    med_ranks = []
    best_r1 = 0.0
    
    for epoch in range(config.epochs):
        train_loss = train_model(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            device=device,
            criterion=criterion,
            config=config
        )
        
        metrics = evaluate_model(
            model=model,
            dataloader=test_loader,
            device=device,
            criterion=criterion,
            config=config
        )
        
        train_losses.append(train_loss)
        val_losses.append(metrics["val_loss"])
        recalls1.append(metrics["recall@1"])
        recalls5.append(metrics["recall@5"])
        recalls10.append(metrics["recall@10"])
        mrrs.append(metrics["mrr"])
        med_ranks.append(metrics["median_rank"])
        
        if metrics["recall@1"] > best_r1:
            best_r1 = metrics["recall@1"]
            torch.save(model, "best_model_VISp.pth")
            best_epoch = epoch
        
        print(f"Epoch {epoch+1}/{config.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {metrics['val_loss']:.4f}")
        print(f"Recall@1: {metrics['recall@1']:.4f} | Recall@5: {metrics['recall@5']:.4f} | Recall@10: {metrics['recall@10']:.4f}")
        print(f"MRR: {metrics['mrr']:.4f} | Median Rank: {metrics['median_rank']:.2f}")
        print("-" * 60)
        
    
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(recalls1, label="Recall@1")
    plt.plot(recalls5, label="Recall@5")
    plt.plot(recalls10, label="Recall@10")
    plt.title("Retrieval Metrics")
    plt.xlabel("Epoch")
    plt.ylabel("Score")
    plt.legend()
    
    plt.tight_layout()
    plt.close()
    
    return {
        "best_recall@1": best_r1,
        "train_history": {
            "loss": train_losses
        },
        "val_history": {
            "loss": val_losses,
            "recall@1": recalls1,
            "recall@5": recalls5,
            "recall@10": recalls10,
            "mrr": mrrs,
            "median_rank": med_ranks
        },
        "best_epoch": best_epoch
    }

In [7]:
with open("activity_dict_region.pkl", 'rb') as f:
    activity_dict = pickle.load(f)

In [8]:
class EPDataset(Dataset):
    def __init__(self, EP_data, labels, features):  
        self.EP_data = EP_data
        self.labels = labels
        self.features = features
        #self.max_pool = nn.MaxPool1d(kernel_size=5, stride=5)

    def __len__(self):
        return len(self.EP_data)
        
    def __getitem__(self, idx):
        EP_tensor = torch.tensor(self.EP_data[idx], dtype=torch.float32) 
        label = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        feature = torch.tensor(self.features[idx], dtype=torch.float32)
        
        return EP_tensor.T, label, feature

In [13]:
results = {}
for region in [#'VISam', 
               'VISp'#, 'VISl', 'VISrl', 'VISpm', 'VISal'
               ]:
    print(f'Start region: {region}')

    session_list = list(activity_dict[region].keys())
    for session in session_list:
        if len(activity_dict[region][session].keys()) != 5850:
            session_list.remove(session)
        
    activity_dict_all = {}
    for image in activity_dict[region][session].keys():
        temp = pd.DataFrame()
        for session in session_list:
            temp = pd.concat((temp, pd.DataFrame(activity_dict[region][session][image])), axis=0)
        
        global_min = temp.min()
        global_max = temp.max()
        activity_dict_all[image] = (temp - global_min) / (global_max - global_min + 1e-8)
    image_feature = pd.read_csv("/media/ubuntu/sda/neuropixels/visual_decode/image_feature.csv", index_col=0)
    
    EP_data = []
    labels = []
    features = []

    for image in activity_dict_all.keys():
        if "117" not in image:
            EP_data.append(np.array(activity_dict_all[image]))
            image = int(image.split("_")[0])
            labels.append(image)
            features.append(np.array(image_feature.iloc[image, :]))

    from torch.utils.data import DataLoader, Dataset, random_split

    dataset = EPDataset(EP_data=EP_data, labels=labels, features=features)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    batch_size = 1024
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    config = ModelConfig(input_neuron=EP_data[0].shape[0])
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TimeTransformerConvModel(config).to(device)

    image_cluster = pd.read_csv("image_cluster.csv")
    print(f'Start train: {region}')
    results[region] = main_train_loop(
        config=config,
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
        image_cluster = image_cluster
    )


Start region: VISp
Start train: VISp
Epoch 1/200
Train Loss: 6.8286 | Val Loss: 5.9872
Recall@1: 0.0009 | Recall@5: 0.0043 | Recall@10: 0.0111
MRR: 0.0072 | Median Rank: 539.00
------------------------------------------------------------
Epoch 2/200
Train Loss: 6.8026 | Val Loss: 5.9735
Recall@1: 0.0030 | Recall@5: 0.0081 | Recall@10: 0.0137
MRR: 0.0104 | Median Rank: 473.00
------------------------------------------------------------
Epoch 3/200
Train Loss: 6.7211 | Val Loss: 5.9480
Recall@1: 0.0017 | Recall@5: 0.0077 | Recall@10: 0.0141
MRR: 0.0104 | Median Rank: 437.00
------------------------------------------------------------
Epoch 4/200
Train Loss: 6.5415 | Val Loss: 5.9545
Recall@1: 0.0009 | Recall@5: 0.0094 | Recall@10: 0.0171
MRR: 0.0104 | Median Rank: 405.00
------------------------------------------------------------
Epoch 5/200
Train Loss: 6.3350 | Val Loss: 5.9590
Recall@1: 0.0013 | Recall@5: 0.0064 | Recall@10: 0.0154
MRR: 0.0103 | Median Rank: 405.50
-------------------

In [10]:
# plt.figure(figsize=(4, 3), dpi=300)

# for region in results.keys():
#     sns.lineplot(
#         data=results[region]['test_history']['accuracy'],
#         label= region
#     )

# plt.ylabel('Accuracy') 
# plt.xlabel('Epoch')  
# plt.tight_layout()  


In [11]:
model.eval()
train_features = []
train_labels = []
for neuro, labels, img_feature in train_loader:
    neuro = neuro.to(device)
    labels = labels.to(device)
    img_feature = img_feature.to(device)
    
    logits, features = model(neuro)

    train_labels.extend(labels.cpu().detach().numpy())
    train_features.append(features.cpu().detach().numpy())

train_features = np.vstack(train_features)
train_labels = np.array(train_labels)

eval_features = []
eval_labels = []
for neuro, labels, img_feature in test_loader:
    neuro = neuro.to(device)
    labels = labels.to(device)
    img_feature = img_feature.to(device)
    
    logits, features = model(neuro)

    eval_labels.extend(labels.cpu().detach().numpy())
    eval_features.append(features.cpu().detach().numpy())

eval_features = np.vstack(eval_features)
eval_labels = np.array(eval_labels)

In [12]:
train_features = np.concat((train_features, train_labels.reshape(-1, 1)), axis=1)
eval_features = np.concat((eval_features, eval_labels.reshape(-1, 1)), axis = 1)

AttributeError: module 'numpy' has no attribute 'concat'

In [None]:
np.save('VISp_train_features.npy', train_features)
np.save('VISp_eval_features.npy', eval_features)