In [2]:
import pickle
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score

# ─── Load Data ───────────────────────────────────────────────────────────────
dataset_path_dict = {
    "embeddings": "/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl",
    "neural": "/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/hybrid_neural_responses_reduced.npy"
}

with open(dataset_path_dict['embeddings'], "rb") as f:
    embeddings_raw = pickle.load(f)
embeddings = embeddings_raw['natural_scenes']  # shape: (118, 1000)

neural_data = np.load(dataset_path_dict["neural"])  # shape: (neurons, 118)

# ─── PCA ─────────────────────────────────────────────────────────────────────
pca = PCA(n_components=50, whiten=True)
X_all = pca.fit_transform(embeddings)  # shape: (118, 50)

# ─── Train/Test Split ────────────────────────────────────────────────────────
X_train = X_all[:-18]
X_test = X_all[-18:]

n_neurons = neural_data.shape[0]
n_trials = 50
r2_test_scores = []

# ─── MLP with ReLU Output ────────────────────────────────────────────────────
class ReLURegressor(nn.Module):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.ReLU()  # Final output is non-negative
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

# ─── Train and Evaluate ──────────────────────────────────────────────────────
for i in range(n_neurons):
    y_all = np.clip(np.round(neural_data[i]), 0, n_trials)
    y_train = y_all[:-18]
    y_test = y_all[-18:]

    model = ReLURegressor(input_dim=50)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    X_train_torch = torch.tensor(X_train, dtype=torch.float32)
    y_train_torch = torch.tensor(y_train, dtype=torch.float32)
    X_test_torch = torch.tensor(X_test, dtype=torch.float32)

    for epoch in range(300):
        model.train()
        y_pred = model(X_train_torch)
        loss = nn.MSELoss()(y_pred, y_train_torch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        y_pred_test = model(X_test_torch).numpy()
        r2 = r2_score(y_test, y_pred_test)
        r2_test_scores.append(r2)

        print(f"Neuron {i}: R² (counts) = {r2:.4f}")
        print("True counts:", y_test)
        print("Predicted counts:", np.round(y_pred_test, 2))

# Optional: summary
print("\nAverage R² across neurons:", np.nanmean(r2_test_scores))


Neuron 0: R² (counts) = -1.5419
True counts: [1. 5. 2. 1. 2. 2. 3. 2. 0. 3. 3. 2. 5. 5. 0. 2. 1. 0.]
Predicted counts: [0.   0.   1.35 0.12 0.   0.   0.   0.   1.5  0.   1.16 0.   0.72 0.21
 0.   0.   0.   0.  ]
Neuron 1: R² (counts) = -6.6608
True counts: [ 6.  4. 13. 10.  4.  7. 10. 10.  9. 11. 17.  8. 10.  5.  8.  7.  7.  7.]
Predicted counts: [0.   0.   0.44 0.   0.   0.   0.28 0.   0.85 1.24 0.82 0.   0.   2.43
 0.36 0.   1.28 0.  ]
Neuron 2: R² (counts) = -2.6201
True counts: [1. 1. 2. 2. 3. 0. 2. 4. 1. 2. 3. 3. 4. 1. 2. 1. 1. 1.]
Predicted counts: [1.22 1.34 0.   1.5  0.   0.   0.   0.07 1.34 0.   0.   0.   0.   0.
 0.   0.07 0.   0.  ]
Neuron 3: R² (counts) = -0.8440
True counts: [0. 3. 4. 1. 0. 1. 1. 3. 4. 0. 2. 0. 0. 1. 3. 7. 2. 0.]
Predicted counts: [0.   0.   0.2  0.   0.   0.   0.   1.7  0.   1.75 0.   1.48 0.   0.
 0.   0.   0.39 1.18]
Neuron 4: R² (counts) = -0.7444
True counts: [4. 1. 0. 0. 1. 1. 5. 0. 0. 4. 8. 1. 5. 0. 1. 1. 1. 3.]
Predicted counts: [0.   0.42 0.   0. 

KeyboardInterrupt: 