In [2]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import pickle
import os
from torch import nn
import torch.nn.functional as F

# === Load embeddings and events ===
with open('/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_softmax.pkl', 'rb') as f:
    image_embeddings_dict = pickle.load(f)

image_embeddings = image_embeddings_dict['natural_scenes']  # (118, D)
image_embeddings_repeated = np.repeat(image_embeddings, 50, axis=0)  # (5900, D)
assert image_embeddings_repeated.shape[0] == 5900

events = np.load('/home/maria/Documents/AllenBrainObservatory/neural_activity_matrices/500860585_neural_responses.npy')  # (N_neurons, 5900)

# === Dataset ===
class NeuronVisionDataset(torch.utils.data.Dataset):
    def __init__(self, image_embeddings, neural_events):
        assert image_embeddings.shape[0] == neural_events.shape[1]
        self.image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32)
        self.neural_events = torch.tensor(neural_events.T, dtype=torch.float32)  # (5900, N_neurons)

    def __len__(self):
        return self.image_embeddings.shape[0]

    def __getitem__(self, idx):
        return {
            "image_embedding": self.image_embeddings[idx],   # (D,)
            "neural_activity": self.neural_events[idx]       # (N_neurons,)
        }

full_dataset = NeuronVisionDataset(image_embeddings_repeated, events)
full_loader = DataLoader(full_dataset, batch_size=4, shuffle=False)  # Small batch size

# === Model ===
class PixelAttentionModel(nn.Module):
    def __init__(self, image_dim, neuron_dim, num_neurons, attention_dim=32):
        super().__init__()
        self.image_dim = image_dim
        self.neuron_dim = neuron_dim
        self.num_neurons = num_neurons
        self.feature_dim = image_dim + neuron_dim

        self.neuron_embeddings = nn.Parameter(torch.randn(num_neurons, neuron_dim))
        self.to_q = nn.Linear(1, attention_dim)
        self.to_k = nn.Linear(1, attention_dim)
        self.to_v = nn.Linear(1, attention_dim)
        self.output_proj = nn.Linear(self.feature_dim * attention_dim, 1)

    def forward(self, image_embedding, neuron_idx):
        B, D = image_embedding.shape
        neuron_idx = neuron_idx.to(torch.long)

        if neuron_idx.ndim == 1:
            K = 1
            neuron_emb = self.neuron_embeddings[neuron_idx]  # (B, D')
        elif neuron_idx.ndim == 2:
            K = neuron_idx.shape[1]
            neuron_emb = self.neuron_embeddings[neuron_idx]  # (B, K, D')
        else:
            raise ValueError("neuron_idx must be shape (B,) or (B, K)")

        if K == 1:
            image_exp = image_embedding  # (B, D)
        else:
            image_exp = image_embedding.unsqueeze(1).expand(-1, K, -1)  # (B, K, D)

        combined = torch.cat([image_exp, neuron_emb], dim=-1)  # (B, K, D + D')
        B_flat = B * K if K > 1 else B
        F_feat = self.feature_dim
        x = combined.view(B_flat, F_feat, 1)  # (B*K, F, 1)

        Q = self.to_q(x)
        K_ = self.to_k(x)
        V = self.to_v(x)

        attn_scores = torch.matmul(Q, K_.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_weights, V)  # (B*K, F, A)

        attn_out_flat = attn_out.view(B_flat, -1)
        output = self.output_proj(attn_out_flat).squeeze(-1)
        return output.view(B, K) if K > 1 else output.view(B)

# === Load model ===
image_dim = image_embeddings_repeated.shape[1]
num_neurons = events.shape[0]
neuron_dim = 32
attention_dim = 16

device = torch.device("cpu")
model = PixelAttentionModel(image_dim, neuron_dim, num_neurons, attention_dim)
model.load_state_dict(torch.load("/home/maria/LuckyMouse/notebooks/models/fold_1.pt", map_location=device))
model.to(device)
model.eval()

# === Per-neuron evaluation loop ===
per_neuron_correct = torch.zeros(num_neurons)
per_neuron_total = torch.zeros(num_neurons)

print("🔍 Evaluating each neuron individually...")
with torch.no_grad():
    for neuron_id in tqdm(range(num_neurons)):
        correct = 0
        total = 0

        for batch in full_loader:
            image_emb = batch["image_embedding"]  # (B, D)
            labels = batch["neural_activity"][:, neuron_id]  # (B,)

            neuron_idx = torch.full((image_emb.size(0),), neuron_id, dtype=torch.long)

            logits = model(image_emb, neuron_idx)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            correct += (preds == labels).sum().item()
            total += len(labels)

        per_neuron_correct[neuron_id] = correct
        per_neuron_total[neuron_id] = total

# === Report results ===
overall_accuracy = per_neuron_correct.sum().item() / per_neuron_total.sum().item()
per_neuron_accuracy = (per_neuron_correct / per_neuron_total).numpy()

print(f"\n✅ Overall accuracy: {overall_accuracy:.4f}")

for i, acc in enumerate(per_neuron_accuracy):
    print(f"Neuron {i:3d} accuracy: {acc:.4f}")


🔍 Evaluating each neuron individually...


  1%|▏         | 2/155 [00:57<1:13:29, 28.82s/it]


KeyboardInterrupt: 