In [1]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset

# Load ViT image embeddings
with open('/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_softmax.pkl', 'rb') as f:
    image_embeddings_dict = pickle.load(f)

# Extract and repeat embeddings
image_embeddings = image_embeddings_dict['natural_scenes']  # shape (118, D)
image_embeddings_repeated = np.repeat(image_embeddings, 50, axis=0)  # (5900, D)
assert image_embeddings_repeated.shape[0] == 5900

# Load calcium event data
events = np.load('/home/maria/Documents/AllenBrainObservatory/neural_activity_matrices/500860585_neural_responses.npy')  # (N_neurons, 5900)

class NeuronVisionDataset(Dataset):
    def __init__(self, image_embeddings, neural_events):
        assert image_embeddings.shape[0] == neural_events.shape[1]
        self.image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32)
        self.neural_events = torch.tensor(neural_events.T, dtype=torch.float32)  # (5900, N_neurons)

    def __len__(self):
        return self.image_embeddings.shape[0]

    def __getitem__(self, idx):
        return {
            "image_embedding": self.image_embeddings[idx],   # (D,)
            "neural_activity": self.neural_events[idx]       # (N_neurons,)
        }


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PixelAttentionModel(nn.Module):
    def __init__(self, image_dim, neuron_dim, num_neurons, attention_dim=32):
        super().__init__()
        self.image_dim = image_dim
        self.neuron_dim = neuron_dim
        self.num_neurons = num_neurons
        self.feature_dim = image_dim + neuron_dim

        self.neuron_embeddings = nn.Parameter(torch.randn(num_neurons, neuron_dim))

        self.to_q = nn.Linear(1, attention_dim)
        self.to_k = nn.Linear(1, attention_dim)
        self.to_v = nn.Linear(1, attention_dim)

        self.output_proj = nn.Linear(self.feature_dim * attention_dim, 1)

    def forward(self, image_embedding, neuron_idx):
        B, D = image_embedding.shape
        neuron_idx = neuron_idx.to(torch.long)

        if neuron_idx.ndim == 1:
            K = 1
            neuron_emb = self.neuron_embeddings[neuron_idx]  # (B, D')
        elif neuron_idx.ndim == 2:
            K = neuron_idx.shape[1]
            neuron_emb = self.neuron_embeddings[neuron_idx]  # (B, K, D')
        else:
            raise ValueError("neuron_idx must be shape (B,) or (B, K)")

        # Expand image embedding
        if K == 1:
            image_exp = image_embedding  # (B, D)
        else:
            image_exp = image_embedding.unsqueeze(1).expand(-1, K, -1)  # (B, K, D)

        # Concatenate → (B, K, D + D')
        combined = torch.cat([image_exp, neuron_emb], dim=-1)  # (B, K, D + D')

        # Reshape to treat feature indices as sequence
        B_flat = B * K if K > 1 else B
        feature_len = self.feature_dim
        x = combined.view(B_flat, feature_len, 1)  # (B*K, F, 1)

        # Compute QKV
        Q = self.to_q(x)  # (B*K, F, A)
        K_ = self.to_k(x)
        V = self.to_v(x)

        attn_scores = torch.matmul(Q, K_.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_weights, V)  # (B*K, F, A)

        # Project output
        attn_out_flat = attn_out.view(B_flat, -1)
        output = self.output_proj(attn_out_flat).squeeze(-1)  # (B*K,)

        return output.view(B, K) if K > 1 else output.view(B)


In [4]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

# Parameters
batch_size = 16
neuron_batch_size = 32
epochs = 10
learning_rate = 1e-3

# Dataset and DataLoader
dataset = NeuronVisionDataset(image_embeddings_repeated, events)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model
image_dim = image_embeddings_repeated.shape[1]
num_neurons = events.shape[0]
neuron_dim = 32
attention_dim = 16

model = PixelAttentionModel(image_dim, neuron_dim, num_neurons, attention_dim)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training
for epoch in range(epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_examples = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        image_emb = batch["image_embedding"].to(device)  # (B, D)
        neural_data = batch["neural_activity"].to(device)  # (B, N)

        B, N = neural_data.shape

        neuron_idx = torch.randint(0, N, (B, neuron_batch_size), device=device)  # (B, K)
        target = torch.gather(neural_data, dim=1, index=neuron_idx)  # (B, K)

        preds = model(image_emb, neuron_idx)  # (B, K)
        loss = F.binary_cross_entropy_with_logits(preds, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * B * neuron_batch_size
        preds_binary = (torch.sigmoid(preds) > 0.5).float()
        correct = (preds_binary == target).sum().item()
        total_correct += correct
        total_examples += B * neuron_batch_size

    avg_loss = total_loss / (len(dataset) * neuron_batch_size)
    accuracy = total_correct / total_examples
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={accuracy:.4f}")


Epoch 1/10: 100%|██████████| 369/369 [00:55<00:00,  6.63it/s]


Epoch 1: Loss=0.2923, Accuracy=0.9276


Epoch 2/10: 100%|██████████| 369/369 [00:57<00:00,  6.45it/s]


Epoch 2: Loss=0.2450, Accuracy=0.9327


Epoch 3/10: 100%|██████████| 369/369 [00:58<00:00,  6.36it/s]


Epoch 3: Loss=0.2408, Accuracy=0.9332


Epoch 4/10: 100%|██████████| 369/369 [00:59<00:00,  6.15it/s]


Epoch 4: Loss=0.2368, Accuracy=0.9335


Epoch 5/10: 100%|██████████| 369/369 [01:04<00:00,  5.75it/s]


Epoch 5: Loss=0.2354, Accuracy=0.9334


Epoch 6/10: 100%|██████████| 369/369 [01:07<00:00,  5.43it/s]


Epoch 6: Loss=0.2342, Accuracy=0.9331


Epoch 7/10: 100%|██████████| 369/369 [01:06<00:00,  5.56it/s]


Epoch 7: Loss=0.2338, Accuracy=0.9328


Epoch 8/10: 100%|██████████| 369/369 [01:07<00:00,  5.46it/s]


Epoch 8: Loss=0.2321, Accuracy=0.9338


Epoch 9/10: 100%|██████████| 369/369 [01:11<00:00,  5.18it/s]


Epoch 9: Loss=0.2346, Accuracy=0.9322


Epoch 10/10: 100%|██████████| 369/369 [01:14<00:00,  4.92it/s]

Epoch 10: Loss=0.2323, Accuracy=0.9332



