In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Predictive-cvPCA (ridge prediction ‚Üí cross-validated SVD on test stimuli)
Author: Maria + Pl√§ku üêæ
"""

import os, pickle
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import KFold
from scipy.special import softmax
from skbio.stats.composition import clr
import matplotlib.pyplot as plt

# -----------------------------
# CONFIG
# -----------------------------
VIT_PATH    = '/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl'
NEURAL_PATH = '/home/maria/LuckyMouse/pixel_transformer_neuro/data/processed/hybrid_neural_responses.npy'
AREAS_PATH  = '/home/maria/MITNeuralComputation/visualization/brain_area.npy'

AREAS       = ["VISp","VISl","VISrl","VISal","VISam","VISpm"]
N_IMAGES, N_TRIALS = 118, 50
VAR_VIT, VAR_BRAIN = 0.90, 0.90
K_OUTER     = 5           # stimulus folds
N_BOOT, N_NULL = 500, 500
ALPHAS = np.logspace(-4,3,20)
RNG_SEED    = 42
OUTDIR      = "results_predictive_cvpca"
os.makedirs(OUTDIR, exist_ok=True)
rng = np.random.default_rng(RNG_SEED)

# -----------------------------
# Helpers
# -----------------------------
def vit_pcs(vit_logits, var=0.9):
    Xv = clr(softmax(np.asarray(vit_logits), axis=1) + 1e-12)
    pfull = PCA().fit(Xv)
    n = np.searchsorted(np.cumsum(pfull.explained_variance_ratio_), var)+1
    p = PCA(n_components=n).fit(Xv)
    Zv = p.transform(Xv)
    Zv -= Zv.mean(0)
    Zv /= Zv.std(0)+1e-8
    return Zv, p, n

def repeat_split(dat):
    idx = np.arange(N_TRIALS)
    tr, te = idx[::2], idx[1::2]
    Xe = dat[:, :, tr, :].mean(axis=(2,3))
    Xo = dat[:, :, te, :].mean(axis=(2,3))
    return Xe, Xo

def pca_train_proj(Xe, img_tr, Xo, var=0.9):
    pfull = PCA().fit(Xe[:, img_tr].T)
    n = np.searchsorted(np.cumsum(pfull.explained_variance_ratio_), var)+1
    p = PCA(n_components=n).fit(Xe[:, img_tr].T)
    Ze = (Xe.T - p.mean_) @ p.components_.T
    Zo = (Xo.T - p.mean_) @ p.components_.T
    Ze -= Ze.mean(0); Zo -= Zo.mean(0)
    return Ze, Zo, p, n

def cvpca(U_pred, U_true):
    C = (U_pred.T @ U_true)/U_pred.shape[0]
    U,S,Vt = np.linalg.svd(C, full_matrices=False)
    return U,S,Vt.T

# -----------------------------
# Load data
# -----------------------------
with open(VIT_PATH,'rb') as f: vit_logits = pickle.load(f)['natural_scenes']
Zv, vit_p, vit_n = vit_pcs(vit_logits, VAR_VIT)
print(f"ViT PCs covering {VAR_VIT*100:.0f}% variance: {vit_n}")

dat = np.load(NEURAL_PATH, mmap_mode='r')
areas = np.load(AREAS_PATH, allow_pickle=True)

# -----------------------------
# Main loop over brain areas
# -----------------------------
for area in AREAS:
    mask = (areas==area)
    if not np.any(mask): 
        print(f"[WARN] no data for {area}")
        continue

    print(f"\n=== Area: {area} ===")
    dA = dat[mask]
    n_neu, n_total = dA.shape
    n_time = n_total // (N_IMAGES*N_TRIALS)
    dA = dA.reshape(n_neu, N_IMAGES, N_TRIALS, n_time)
    Xe, Xo = repeat_split(dA)

    kf = KFold(n_splits=K_OUTER, shuffle=True, random_state=RNG_SEED)
    stim_idx = np.arange(N_IMAGES)

    fold_Spred, fold_R2, fold_pperm = [], [], []

    for fold,(i_tr,i_te) in enumerate(kf.split(stim_idx),1):
        img_tr, img_te = stim_idx[i_tr], stim_idx[i_te]

        # Brain PCA on train stimuli (fit)
        Ze_all, Zo_all, p, nbrain = pca_train_proj(Xe, img_tr, Xo, VAR_BRAIN)

        # Ridge regression from ViT‚Üíbrain on TRAIN stimuli
        ridge = RidgeCV(alphas=ALPHAS, fit_intercept=False)
        ridge.fit(Zv[img_tr], Ze_all[img_tr])
        Zhat_test = ridge.predict(Zv[img_te])
        Ztrue_test = Zo_all[img_te]

        # predictive cvPCA: SVD of predicted vs true
        _, S_pred, _ = cvpca(Zhat_test, Ztrue_test)
        fold_Spred.append(S_pred)

        # classic R^2 baseline
        ss_res = np.sum((Ztrue_test-Zhat_test)**2,axis=0)
        ss_tot = np.sum((Ztrue_test-Ztrue_test.mean(0))**2,axis=0)
        R2 = 1-ss_res/ss_tot
        fold_R2.append(R2.mean())

        # permutation null (ViT‚Üíbrain mapping shuffled)
        null_S = np.zeros((N_NULL, len(S_pred)))
        for b in range(N_NULL):
            perm = rng.permutation(len(img_te))
            _, S_n, _ = cvpca(Zhat_test[perm], Ztrue_test)
            null_S[b,:len(S_n)] = S_n
        null_mean = null_S.mean(0)
        p_perm = np.mean(S_pred[0] <= null_S[:,0])  # crude p for first component
        fold_pperm.append(p_perm)

        # save fold plot
        x=np.arange(1,len(S_pred)+1)
        plt.figure(figsize=(6,4))
        plt.plot(x,S_pred,'o-',label='Predictive œÉ')
        plt.plot(x,null_mean,'--',label='Null mean')
        lo,hi=np.percentile(null_S,[2.5,97.5],axis=0)
        plt.fill_between(x,lo,hi,alpha=0.2,color='gray')
        plt.title(f"{area} fold {fold}  alpha={ridge.alpha_:.3g}  R¬≤={R2.mean():.3f}")
        plt.xlabel("Component"); plt.ylabel("œÉ_pred")
        plt.legend(); plt.tight_layout()
        plt.savefig(os.path.join(OUTDIR,f"{area}_fold{fold}_predcvpca.png"),dpi=150)
        plt.close()

        np.savez(os.path.join(OUTDIR,f"{area}_fold{fold}_predcvpca.npz"),
                 S_pred=S_pred, null_S=null_S, alpha=ridge.alpha_, R2=R2, img_tr=img_tr, img_te=img_te)

    # ---- Across-fold summary ----
    L = min(len(s) for s in fold_Spred)
    S_stack = np.stack([s[:L] for s in fold_Spred],axis=0)
    S_mean = S_stack.mean(0)
    print(f"[Area {area}] mean R¬≤={np.mean(fold_R2):.3f}  first œÉ_pred={S_mean[0]:.3f}")
    print(f"[Area {area}] permutation p (median)‚âà{np.median(fold_pperm):.3g}")

    np.savez(os.path.join(OUTDIR,f"{area}_summary.npz"),
             S_mean=S_mean, R2_mean=np.mean(fold_R2), pperm=np.median(fold_pperm))


ViT PCs covering 90% variance: 44

=== Area: VISp ===
[Area VISp] mean R¬≤=-0.060  first œÉ_pred=0.117
[Area VISp] permutation p (median)‚âà0.132

=== Area: VISl ===
[Area VISl] mean R¬≤=-0.084  first œÉ_pred=0.115
[Area VISl] permutation p (median)‚âà0.084

=== Area: VISrl ===
[Area VISrl] mean R¬≤=-0.084  first œÉ_pred=0.015
[Area VISrl] permutation p (median)‚âà0.61

=== Area: VISal ===
[Area VISal] mean R¬≤=-0.066  first œÉ_pred=0.049
[Area VISal] permutation p (median)‚âà0.116

=== Area: VISam ===
[Area VISam] mean R¬≤=-0.068  first œÉ_pred=0.024
[Area VISam] permutation p (median)‚âà0.078

=== Area: VISpm ===
[Area VISpm] mean R¬≤=-0.055  first œÉ_pred=0.044
[Area VISpm] permutation p (median)‚âà0.056
