In [2]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset

# Load ViT image embeddings
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load embedding files that you produced yourself or otherwise trust.
with open('/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_softmax.pkl', 'rb') as f:
    image_embeddings_dict = pickle.load(f)

# Extract and repeat embeddings so that row t is the embedding paired with trial t.
# NOTE(review): np.repeat lays trials out image-by-image (50 consecutive rows per
# image). This presumes the columns of the events matrix follow the same
# ordering -- confirm against the code that produced the .npy file.
image_embeddings = image_embeddings_dict['natural_scenes']  # shape (118, D)
image_embeddings_repeated = np.repeat(image_embeddings, 50, axis=0)  # (5900, D)
assert image_embeddings_repeated.shape[0] == 5900

# Load calcium event data
events = np.load('/home/maria/Documents/AllenBrainObservatory/neural_activity_matrices/500860585_neural_responses.npy')  # (N_neurons, 5900)

# Fail fast if the trial axis does not line up with the repeated embeddings.
# Without this check a longer events matrix would be silently truncated by the
# per-fold fancy indexing further down, and a shorter one would only fail there.
assert events.ndim == 2 and events.shape[1] == image_embeddings_repeated.shape[0], (
    f"events has shape {events.shape}; expected (N_neurons, {image_embeddings_repeated.shape[0]})"
)

class NeuronVisionDataset(Dataset):
    """Trial-indexed pairs of an image embedding and the neural activity on that trial.

    Args:
        image_embeddings: array-like whose first axis indexes trials
            (presumably shaped (n_trials, D)).
        neural_events: array-like whose second axis indexes trials
            (presumably shaped (n_neurons, n_trials)); it is transposed on
            construction so that trials become the first axis.

    Raises:
        AssertionError: if the two inputs disagree on the number of trials.
    """

    def __init__(self, image_embeddings, neural_events):
        n_trials = image_embeddings.shape[0]
        assert n_trials == neural_events.shape[1]

        # Store both tensors trial-major as float32 so one index selects one sample.
        events_by_trial = neural_events.T
        self.image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32)
        self.neural_events = torch.tensor(events_by_trial, dtype=torch.float32)

    def __len__(self):
        # Number of trials, i.e. the size of the leading axis.
        return self.image_embeddings.size(0)

    def __getitem__(self, idx):
        embedding = self.image_embeddings[idx]
        activity = self.neural_events[idx]
        return {"image_embedding": embedding, "neural_activity": activity}

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PixelAttentionModel(nn.Module):
    """Predict one logit per (image, neuron) pair using attention over feature positions.

    The image embedding is concatenated with a learned per-neuron embedding into
    a vector of length ``image_dim + neuron_dim``. Every scalar entry of that
    vector is treated as one token of a sequence: it is linearly lifted to
    queries, keys and values of width ``attention_dim``, single-head scaled
    dot-product attention is run across the sequence, and the flattened result
    is projected to a single output value.

    Args:
        image_dim: length D of each image embedding.
        neuron_dim: length D' of each learned neuron embedding.
        num_neurons: number of rows in the neuron embedding table.
        attention_dim: width A of the query/key/value projections.
    """

    def __init__(self, image_dim, neuron_dim, num_neurons, attention_dim=32):
        super().__init__()
        self.image_dim = image_dim
        self.neuron_dim = neuron_dim
        self.num_neurons = num_neurons
        self.feature_dim = image_dim + neuron_dim

        # One learned embedding row per neuron, looked up by integer index.
        self.neuron_embeddings = nn.Parameter(torch.randn(num_neurons, neuron_dim))

        # Each scalar feature is its own token, hence in_features=1.
        self.to_q = nn.Linear(1, attention_dim)
        self.to_k = nn.Linear(1, attention_dim)
        self.to_v = nn.Linear(1, attention_dim)

        self.output_proj = nn.Linear(self.feature_dim * attention_dim, 1)

    def forward(self, image_embedding, neuron_idx):
        """Compute one output value for every (image, neuron) pair.

        Args:
            image_embedding: tensor of shape (B, D).
            neuron_idx: integer tensor of shape (B,) selecting one neuron per
                image, or (B, K) selecting K neurons per image (K may be 1).

        Returns:
            Tensor of shape (B,) when ``neuron_idx`` is 1-D, otherwise (B, K).

        Raises:
            ValueError: if ``neuron_idx`` is neither 1-D nor 2-D.
        """
        B, D = image_embedding.shape
        neuron_idx = neuron_idx.to(torch.long)

        if neuron_idx.ndim not in (1, 2):
            raise ValueError("neuron_idx must be shape (B,) or (B, K)")

        # Work in (B, K) form internally and remember the caller's rank. Keying
        # the branches on K == 1 instead (as before) mixed a 2-D image tensor
        # with a 3-D neuron tensor whenever neuron_idx had shape (B, 1).
        single_neuron = neuron_idx.ndim == 1
        if single_neuron:
            neuron_idx = neuron_idx.unsqueeze(1)  # (B, 1)
        K = neuron_idx.shape[1]

        neuron_emb = self.neuron_embeddings[neuron_idx]  # (B, K, D')
        image_exp = image_embedding.unsqueeze(1).expand(-1, K, -1)  # (B, K, D)

        # Concatenate -> (B, K, D + D')
        combined = torch.cat([image_exp, neuron_emb], dim=-1)

        # Treat the feature axis as the sequence axis: one scalar token per feature.
        x = combined.reshape(B * K, self.feature_dim, 1)  # (B*K, F, 1)

        # Compute QKV
        Q = self.to_q(x)  # (B*K, F, A)
        K_ = self.to_k(x)
        V = self.to_v(x)

        # Scaled dot-product attention; the score matrix is (B*K, F, F), so
        # memory grows quadratically with the feature length F.
        attn_scores = torch.matmul(Q, K_.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_weights, V)  # (B*K, F, A)

        # Project output
        attn_out_flat = attn_out.reshape(B * K, -1)
        output = self.output_proj(attn_out_flat).squeeze(-1)  # (B*K,)

        return output.view(B) if single_neuron else output.view(B, K)


In [None]:
from sklearn.model_selection import KFold
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

import os

# Make sure models directory exists
os.makedirs("models", exist_ok=True)

# Constants
num_folds = 10
batch_size = 16
neuron_batch_size = 32
epochs = 10
learning_rate = 1e-3
seed = 42  # drives both the fold split and the per-fold torch RNG
neuron_dim = 32
attention_dim = 16

# Step 1: image-level split
# Folds are formed over images, not trials, so every repeat of an image lands on
# the same side of the split.
trials_per_image = 50
num_trials = image_embeddings_repeated.shape[0]
assert num_trials % trials_per_image == 0, (
    f"{num_trials} trials is not a multiple of trials_per_image={trials_per_image}"
)
assert events.shape[1] == num_trials, (
    f"events has {events.shape[1]} trials but the embeddings have {num_trials}"
)
num_images = num_trials // trials_per_image
image_indices = np.arange(num_images)  # [0, 1, ..., num_images - 1]
kf = KFold(n_splits=num_folds, shuffle=True, random_state=seed)

# Fold-invariant setup, computed once instead of on every iteration.
image_dim = image_embeddings_repeated.shape[1]
num_neurons = events.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trial_offsets = np.arange(trials_per_image)

fold_results = []

for fold, (train_img_idx, test_img_idx) in enumerate(kf.split(image_indices)):
    print(f"\n🔁 Fold {fold + 1}/{num_folds}")

    # Seed per fold so weight init, batch shuffling and neuron sampling are
    # reproducible across runs.
    torch.manual_seed(seed + fold)

    # Map image indices to trial indices: image i owns trials
    # [i * trials_per_image, (i + 1) * trials_per_image).
    train_trial_idx = (train_img_idx[:, None] * trials_per_image + trial_offsets).ravel()
    test_trial_idx = (test_img_idx[:, None] * trials_per_image + trial_offsets).ravel()

    # Build datasets
    train_dataset = NeuronVisionDataset(
        image_embeddings_repeated[train_trial_idx],
        events[:, train_trial_idx]
    )
    test_dataset = NeuronVisionDataset(
        image_embeddings_repeated[test_trial_idx],
        events[:, test_trial_idx]
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize a fresh model and optimizer for this fold
    model = PixelAttentionModel(image_dim, neuron_dim, num_neurons, attention_dim)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training
    # NOTE(review): the BCE loss here and the equality-based accuracy below
    # presume the event values are binary (0/1) -- confirm for this dataset.
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            image_emb = batch["image_embedding"].to(device)
            neural_data = batch["neural_activity"].to(device)

            # Train on a random subset of neurons per trial to bound the cost of
            # the attention computation.
            B, N = neural_data.shape
            neuron_idx = torch.randint(0, N, (B, neuron_batch_size), device=device)
            target = torch.gather(neural_data, dim=1, index=neuron_idx)

            preds = model(image_emb, neuron_idx)
            loss = F.binary_cross_entropy_with_logits(preds, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluation
    # Score every neuron on every held-out trial, in fixed chunks of
    # neuron_batch_size. Sampling random neurons here would make the reported
    # accuracy a noisy estimate that changes from run to run.
    model.eval()
    total_correct = 0
    total_examples = 0
    with torch.no_grad():
        for batch in test_loader:
            image_emb = batch["image_embedding"].to(device)
            neural_data = batch["neural_activity"].to(device)

            B, N = neural_data.shape
            for start in range(0, N, neuron_batch_size):
                stop = min(start + neuron_batch_size, N)
                width = stop - start
                neuron_idx = torch.arange(start, stop, device=device).unsqueeze(0).expand(B, width)
                target = neural_data[:, start:stop]

                # A one-neuron chunk is sent as a 1-D (B,) index, the
                # single-neuron form the model accepts; the output is then
                # reshaped back to (B, width).
                model_idx = neuron_idx[:, 0] if width == 1 else neuron_idx
                preds = model(image_emb, model_idx).reshape(B, width)
                preds_binary = (torch.sigmoid(preds) > 0.5).float()

                total_correct += (preds_binary == target).sum().item()
                total_examples += target.numel()

    fold_accuracy = total_correct / total_examples
    print(f"✅ Fold {fold + 1} Accuracy: {fold_accuracy:.4f}")
    fold_results.append(fold_accuracy)

    model_path = f"models/fold_{fold + 1}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"💾 Saved model to {model_path}")

# Summary
mean_acc = np.mean(fold_results)
std_acc = np.std(fold_results)
print(f"\n🏁 {num_folds}-fold CV complete. Mean Accuracy: {mean_acc:.4f}, Std: {std_acc:.4f}")



🔁 Fold 1/10
