In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score
from skbio.stats.composition import clr
# ─── Load embeddings ────────────────────────────────────────────────────────
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load embedding files produced by a trusted local pipeline.
with open("/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/google_vit-base-patch16-224_embeddings_softmax.pkl", "rb") as f:
    embeddings_raw = pickle.load(f)

embeddings = embeddings_raw["natural_scenes"]  # shape: (118, 1000)

# ─── Whitened PCA ───────────────────────────────────────────────────────────
# The embeddings are passed through the centred log-ratio transform before PCA.
# NOTE(review): clr takes logarithms, so this presumably needs strictly
# positive inputs — confirm the softmax outputs contain no exact zeros.
# NOTE(review): PCA is fitted on all stimuli, including the ones held out for
# evaluation further down; confirm this transductive fit is intended.
# random_state pins the decomposition: for an input of this size the default
# solver policy may select the randomized SVD, which is otherwise not
# repeatable from run to run.
pca = PCA(n_components=50, whiten=True, random_state=0)
X_all = pca.fit_transform(clr(embeddings))  # shape: (118, 50)

# ─── Load neural data ───────────────────────────────────────────────────────
neural_data = np.load("/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/hybrid_neural_responses_reduced.npy")
n_trials = 50  # used as the binomial "n" when fitting and predicting counts
n_neurons, n_stimuli = neural_data.shape
assert n_stimuli == 118, "Mismatch in stimulus count"
# Features and responses are paired row-by-row, so their lengths must agree.
assert X_all.shape[0] == n_stimuli, "Embedding and neural stimulus counts differ"

# ─── MLP Model Definition ───────────────────────────────────────────────────
class SpikeProbMLP(nn.Module):
    """One-hidden-layer perceptron that outputs a probability per input row.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the single hidden layer (default 64).
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # consumes the random stream in the same sequence as before.
        hidden_layer = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_dim, 1)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Return sigmoid(net(x)) with the trailing size-1 dimension removed."""
        return torch.sigmoid(self.net(x)).squeeze(-1)

# ─── Binomial Negative Log-Likelihood ───────────────────────────────────────
def binomial_nll(probs, counts, n_trials):
    """Mean negative log-likelihood of `counts` under Binomial(n_trials, probs).

    Args:
        probs: Tensor of success probabilities; clamped to [1e-5, 1 - 1e-5]
            so both logarithms stay finite.
        counts: Tensor of observed success counts, combined element-wise
            with `probs`.
        n_trials: Number of trials, applied to every element.

    Returns:
        Scalar tensor: the negated mean of the element-wise log-likelihood,
        including the log binomial coefficient.
    """
    eps = 1e-5
    p = probs.clamp(min=eps, max=1 - eps)
    # Broadcast the trial count to the shape/dtype of the observed counts.
    total = torch.full_like(counts, fill_value=n_trials)
    failures = total - counts

    log_coeff = torch.lgamma(total + 1) - torch.lgamma(counts + 1) - torch.lgamma(failures + 1)
    log_lik = log_coeff + counts * p.log() + failures * (1 - p).log()
    return -log_lik.mean()


# ─── Loop over neurons ──────────────────────────────────────────────────────
# Fits an independent SpikeProbMLP per neuron with the binomial NLL on the
# leading stimuli, then reports R² between true and predicted counts on the
# trailing held-out stimuli.
n_test = 18  # number of trailing stimuli held out for evaluation

# Take the feature width from the PCA output instead of hard-coding it: the
# previous literal (100) did not match the 50 PCA components and would fail
# with a shape error on the first forward pass.
input_dim = X_all.shape[1]

# Seed once so weight initialisation is repeatable across runs.
torch.manual_seed(0)

# The feature split is identical for every neuron, so build it once.
X_train = X_all[:-n_test]
X_test = X_all[-n_test:]
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
X_test_torch = torch.tensor(X_test, dtype=torch.float32)

for i in range(n_neurons):
    # Round responses to whole counts and clip them into [0, n_trials] so
    # they are valid binomial outcomes.
    counts = np.clip(np.round(neural_data[i]), 0, n_trials)
    y_train = counts[:-n_test]
    y_test = counts[-n_test:]
    y_train_torch = torch.tensor(y_train, dtype=torch.float32)

    # Fresh model and optimizer for every neuron
    model = SpikeProbMLP(input_dim=input_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Train (full-batch gradient steps)
    model.train()
    for epoch in range(300):
        probs = model(X_train_torch)
        loss = binomial_nll(probs, y_train_torch, n_trials=n_trials)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate: expected count is n_trials * predicted probability
    model.eval()
    with torch.no_grad():
        probs_test = model(X_test_torch)
        counts_pred = (n_trials * probs_test).numpy()
    r2 = r2_score(y_test, counts_pred)

    print(f"Neuron {i}: R² (counts) = {r2:.4f}")
    print("True counts:     ", np.round(y_test).astype(int))
    print("Predicted counts:", counts_pred.round(2))
    print("-" * 60)


Neuron 0: R² (counts) = -101.1920
True counts:      [1 5 2 1 2 2 3 2 0 3 3 2 5 5 0 2 1 0]
Predicted counts: [18.99 17.96 13.56 20.69 20.3  20.42 18.46 14.47 16.18 14.42 18.22 16.19
  0.76 20.64 20.83 23.78 17.49 14.06]
------------------------------------------------------------
Neuron 1: R² (counts) = -15.4011
True counts:      [ 6  4 13 10  4  7 10 10  9 11 17  8 10  5  8  7  7  7]
Predicted counts: [18.49 19.91 13.88 22.02 20.99 21.21 16.15 20.89 18.58 22.22 13.27 17.68
  7.38 23.59 23.27 21.55 23.57 24.21]
------------------------------------------------------------
Neuron 2: R² (counts) = -139.8453
True counts:      [1 1 2 2 3 0 2 4 1 2 3 3 4 1 2 1 1 1]
Predicted counts: [16.43 14.83 16.18 19.   20.71 16.74 11.44 12.53 13.67 15.28 13.64 18.54
  4.78 13.1  14.69 13.55  8.13 14.55]
------------------------------------------------------------
Neuron 3: R² (counts) = -59.0098
True counts:      [0 3 4 1 0 1 1 3 4 0 2 0 0 1 3 7 2 0]
Predicted counts: [11.81 13.61 11.99 11.16 13.66 13.62

KeyboardInterrupt: 

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score

# ─── Load embeddings ────────────────────────────────────────────────────────
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load embedding files produced by a trusted local pipeline.
with open("/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/google_vit-base-patch16-224_embeddings_softmax.pkl", "rb") as f:
    embeddings_raw = pickle.load(f)

embeddings = embeddings_raw["natural_scenes"]  # shape: (118, 1000)

# ─── Whitened PCA ───────────────────────────────────────────────────────────
# NOTE(review): PCA is fitted on all stimuli, including the ones held out for
# evaluation further down; confirm this transductive fit is intended.
# random_state pins the decomposition: for an input of this size the default
# solver policy may select the randomized SVD, which is otherwise not
# repeatable from run to run.
pca = PCA(n_components=50, whiten=True, random_state=0)
X_all = pca.fit_transform(embeddings)  # shape: (118, 50)

# ─── Load neural data ───────────────────────────────────────────────────────
neural_data = np.load("/home/maria/LuckyMouse2/pixel_transformer_neuro/data/processed/hybrid_neural_responses_reduced.npy")
n_trials = 50  # used as the binomial "n" when fitting and predicting counts
n_all_neurons, n_stimuli = neural_data.shape
assert n_stimuli == 118, "Mismatch in stimulus count"
# Features and responses are paired row-by-row, so their lengths must agree.
assert X_all.shape[0] == n_stimuli, "Embedding and neural stimulus counts differ"

# ─── Neuron indices to process ──────────────────────────────────────────────
selected_indices = [3548, 4467, 4709, 4823, 10970, 14365, 14511, 14610,
                    25070, 26150, 29911, 32163, 33756, 33897, 35317, 36756]

# Validate indices: keep only rows that exist in neural_data. Negative values
# are rejected as well, since they would silently wrap around to a different
# neuron when used as an array index.
valid_indices = [i for i in selected_indices if 0 <= i < n_all_neurons]
skipped_indices = [i for i in selected_indices if not 0 <= i < n_all_neurons]
if skipped_indices:
    # Surface dropped indices instead of discarding them silently.
    print(f"Skipping out-of-range neuron indices: {skipped_indices}")

# ─── MLP Model Definition ───────────────────────────────────────────────────
class SpikeProbMLP(nn.Module):
    """One-hidden-layer perceptron that outputs a probability per input row.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the single hidden layer (default 64).
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # consumes the random stream in the same sequence as before.
        hidden_layer = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_dim, 1)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Return sigmoid(net(x)) with the trailing size-1 dimension removed."""
        return torch.sigmoid(self.net(x)).squeeze(-1)

# ─── Binomial Negative Log-Likelihood ───────────────────────────────────────
def binomial_nll(probs, counts, n_trials):
    """Mean negative log-likelihood of `counts` under Binomial(n_trials, probs).

    Args:
        probs: Tensor of success probabilities; clamped to [1e-5, 1 - 1e-5]
            so both logarithms stay finite.
        counts: Tensor of observed success counts, combined element-wise
            with `probs`.
        n_trials: Number of trials, applied to every element.

    Returns:
        Scalar tensor: the negated mean of the element-wise log-likelihood,
        including the log binomial coefficient.
    """
    eps = 1e-5
    p = probs.clamp(min=eps, max=1 - eps)
    # Broadcast the trial count to the shape/dtype of the observed counts.
    total = torch.full_like(counts, fill_value=n_trials)
    failures = total - counts
    log_coeff = torch.lgamma(total + 1) - torch.lgamma(counts + 1) - torch.lgamma(failures + 1)
    log_lik = log_coeff + counts * p.log() + failures * (1 - p).log()
    return -log_lik.mean()

# ─── Loop over selected neurons ─────────────────────────────────────────────
# Fits an independent SpikeProbMLP per selected neuron with the binomial NLL
# on the leading stimuli, then reports R² between true and predicted counts on
# the trailing held-out stimuli.
n_test = 18  # number of trailing stimuli held out for evaluation

# Take the feature width from the PCA output so the model stays in sync with
# the number of components chosen above.
input_dim = X_all.shape[1]

# Seed once so weight initialisation is repeatable across runs.
torch.manual_seed(0)

# The feature split is identical for every neuron, so build it once.
X_train = X_all[:-n_test]
X_test = X_all[-n_test:]
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
X_test_torch = torch.tensor(X_test, dtype=torch.float32)

for i in valid_indices:
    # Round responses to whole counts and clip them into [0, n_trials] so
    # they are valid binomial outcomes.
    counts = np.clip(np.round(neural_data[i]), 0, n_trials)
    y_train = counts[:-n_test]
    y_test = counts[-n_test:]
    y_train_torch = torch.tensor(y_train, dtype=torch.float32)

    # Fresh model and optimizer for every neuron
    model = SpikeProbMLP(input_dim=input_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training (full-batch gradient steps)
    model.train()
    for epoch in range(300):
        probs = model(X_train_torch)
        loss = binomial_nll(probs, y_train_torch, n_trials=n_trials)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation: expected count is n_trials * predicted probability
    model.eval()
    with torch.no_grad():
        probs_test = model(X_test_torch)
        counts_pred = (n_trials * probs_test).numpy()
    r2 = r2_score(y_test, counts_pred)

    print(f"Neuron {i}: R² (counts) = {r2:.4f}")
    print("True counts:     ", np.round(y_test).astype(int))
    print("Predicted counts:", counts_pred.round(2))
    print("-" * 60)


Neuron 3548: R² (counts) = -4.5782
True counts:      [13  7  0  1  0  4  3  9  1  1  5  1  0  2  2  1 14  3]
Predicted counts: [ 1.94  4.25  2.84  9.48  1.57 11.09  0.11  2.81 19.08  5.25  3.6   3.72
  2.07 23.01 23.29  2.38  0.29  7.42]
------------------------------------------------------------
Neuron 4467: R² (counts) = -10.8536
True counts:      [3 1 4 2 4 1 1 3 3 3 2 2 1 4 7 1 0 7]
Predicted counts: [ 0.56  1.98  2.46  5.25  0.43  7.37  0.14  1.3  19.84  4.1   1.58  1.52
  2.17 18.27 20.81  0.74  0.05  2.53]
------------------------------------------------------------
Neuron 4709: R² (counts) = 0.1703
True counts:      [ 4  1  3  3  2 38  2  3  1  4  1  1  2  7  2  2  0  1]
Predicted counts: [ 0.97  2.86  1.31  7.27  0.84 24.35  0.15  1.54 21.36  5.15  2.23  2.24
  9.13 16.12 17.86  1.27  0.8   2.59]
------------------------------------------------------------
Neuron 4823: R² (counts) = -3.8791
True counts:      [ 0  3  0  2  1  2 18  1  2  1  1  2  4  2  1  2  1  5]
Predicted co