In [1]:
#!/usr/bin/env python3
"""
Two-stage cvPCA pipeline with an initial PCA bottleneck:

1) ViT:
   - Softmax + CLR on logits for natural scenes
   - PCA â†’ keep PCs up to VAR_CUTOFF variance

2) Neural:
   (a) Preprocess:
       - Load deconvolved responses
       - Select AREA_NAME
       - Reshape to (neurons Ã— images Ã— trials Ã— time)
       - Compute stimulus-locked means
   (b) PCA on stimulus-locked means (all trials):
       - Xa = R_all.T  (images Ã— neurons)
       - PCA â†’ keep PCs up to VAR_CUTOFF variance
       - Get neural PC scores: Xa_pca (images Ã— n_neural_pca)
   (c) cvPCA in neural PC space:
       - Split trials into even/odd â†’ R_even, R_odd
       - Project to PCA space â†’ Xe_pca, Xo_pca
       - cvPCA (Pachitariu-style) on these PC scores
       - Get neural cvPCA basis inside PCA space + shared variances
       - Get stimulus-locked neural cvPCA scores Zb_cv (images Ã— n_cvpca)

3) Cross-covariance SVD:
   - Between ViT PC scores Zv and neural cvPCA scores Zb_cv
   - SVD of cross-covariance â†’ shared directions U_cv, V_cv

4) Variance explained in original neural PC data:
   - Map cross-covariance components back into the original neural PCA space
   - Compute fraction of variance in Xa_pca explained by those components
"""

import numpy as np
import pickle
from sklearn.decomposition import PCA
from scipy.special import softmax
from skbio.stats.composition import clr

# ---------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------
VIT_PATH    = '/home/maria/Documents/HuggingMouseData/MouseViTEmbeddings/google_vit-base-patch16-224_embeddings_logits.pkl'
NEURAL_PATH = '/home/maria/LuckyMouse/pixel_transformer_neuro/data/processed/hybrid_neural_responses.npy'
AREAS_PATH  = '/home/maria/MITNeuralComputation/visualization/brain_area.npy'

AREA_NAME   = 'VISp'
N_IMAGES    = 118
N_TRIALS    = 50
VAR_CUTOFF  = 0.90
RANDOM_SEED = 42

rng = np.random.default_rng(RANDOM_SEED)

# ===============================================================
# STEP 1: ViT preprocessing (softmax + CLR + PCA)
# ===============================================================
print("ðŸ”¹ Loading ViT embeddings...")
with open(VIT_PATH, 'rb') as f:
    vit_logits = pickle.load(f)['natural_scenes']  # (N_IMAGES Ã— D_vit)

Xv = np.asarray(vit_logits)
Xv = softmax(Xv, axis=1)          # probabilities over classes
Xv = clr(Xv + 1e-12)              # CLR transform

print("ðŸ”¹ PCA on ViT CLR embeddings...")
vit_pca_full = PCA(random_state=RANDOM_SEED).fit(Xv)
vit_cumvar = np.cumsum(vit_pca_full.explained_variance_ratio_)
vit_ncomp = np.searchsorted(vit_cumvar, VAR_CUTOFF) + 1

vit_pca = PCA(n_components=vit_ncomp, random_state=RANDOM_SEED)
Zv = vit_pca.fit_transform(Xv)    # (images Ã— vit_ncomp)

print(f"âœ… ViT PCs covering {VAR_CUTOFF*100:.0f}% variance: {vit_ncomp}")

# ===============================================================
# STEP 2: Neural preprocessing + FIRST PCA (to 90% var)
# ===============================================================
print(f"\nðŸ”¹ Loading neural responses, area = {AREA_NAME}...")
dat_all = np.load(NEURAL_PATH, mmap_mode='r')   # (n_neurons_total Ã— (images*trials*time))
areas = np.load(AREAS_PATH, allow_pickle=True)  # (n_neurons_total,)

area_mask = (areas == AREA_NAME)
dat = dat_all[area_mask]                        # (n_neurons_area Ã— n_total)
n_neurons, n_total = dat.shape

# Infer time bins
n_time = n_total // (N_IMAGES * N_TRIALS)
if N_IMAGES * N_TRIALS * n_time != n_total:
    raise ValueError(f"Inferred n_time={n_time} inconsistent with data shape.")

print(f"   â†’ {n_neurons} neurons, n_total={n_total}, inferred n_time={n_time}")

# Reshape to (neurons Ã— images Ã— trials Ã— time)
dat = dat.reshape(n_neurons, N_IMAGES, N_TRIALS, n_time)

# Stimulus-locked means:
# - R_even, R_odd: for cvPCA
# - R_all: all trials, for the initial PCA
even_idx = np.arange(0, N_TRIALS, 2)
odd_idx  = np.arange(1, N_TRIALS, 2)
print(f"   â†’ Even trials: {even_idx.size}, odd trials: {odd_idx.size}")

R_even = dat[:, :, even_idx, :].mean(axis=(2, 3))   # (neurons Ã— images)
R_odd  = dat[:, :, odd_idx,  :].mean(axis=(2, 3))   # (neurons Ã— images)
R_all  = dat.mean(axis=(2, 3))                      # (neurons Ã— images)

# For PCA: samples = images, features = neurons
Xa = R_all.T    # (images Ã— neurons)
Xe = R_even.T   # (images Ã— neurons)
Xo = R_odd.T    # (images Ã— neurons)

print("\nðŸ”¹ FIRST PCA on stimulus-locked neural responses (all trials)...")
neural_pca_full = PCA(random_state=RANDOM_SEED).fit(Xa)
neural_cumvar = np.cumsum(neural_pca_full.explained_variance_ratio_)
neural_ncomp = np.searchsorted(neural_cumvar, VAR_CUTOFF) + 1

neural_pca = PCA(n_components=neural_ncomp, random_state=RANDOM_SEED)
Xa_pca = neural_pca.fit_transform(Xa)   # (images Ã— neural_ncomp)
Xe_pca = neural_pca.transform(Xe)       # (images Ã— neural_ncomp)
Xo_pca = neural_pca.transform(Xo)       # (images Ã— neural_ncomp)

print(f"âœ… Neural PCs (first stage) covering {VAR_CUTOFF*100:.0f}% variance: {neural_ncomp}")

# PCA basis in neuron space (neurons Ã— neural_ncomp)
# sklearn: components_ is (n_components Ã— n_features)
W_pca_neurons = neural_pca.components_.T

# ===============================================================
# STEP 3: cvPCA in neural PCA space
# ===============================================================
print("\nðŸ”¹ Running cvPCA on neural PCA scores (stimulus-locked)...")

n_stim = N_IMAGES

# Xe_pca and Xo_pca are already centered using the same mean in neuron space;
# we can treat them as zero-mean PC scores.
Xe0 = Xe_pca
Xo0 = Xo_pca

# Covariance of train (even) half in PCA space
Cb_pca = Xe0.T @ Xe0 / (n_stim - 1)      # (neural_ncomp Ã— neural_ncomp)

# Eigen-decomposition in PCA space
eigvals, V_pca = np.linalg.eigh(Cb_pca)
idx = np.argsort(eigvals)[::-1]
eigvals = eigvals[idx]
V_pca = V_pca[:, idx]                    # (neural_ncomp Ã— neural_ncomp)

# Project both halves onto these components (still in PCA space)
S1 = Xe0 @ V_pca                         # (images Ã— neural_ncomp)
S2 = Xo0 @ V_pca                         # (images Ã— neural_ncomp)

# Cross-validated variance per component
lam_cv = np.sum(S1 * S2, axis=0) / (n_stim - 1)   # (neural_ncomp,)

# Clamp negatives (noise)
lam_cv_pos = np.maximum(lam_cv, 0.0)
total_shared = lam_cv_pos.sum()
if total_shared <= 0:
    raise RuntimeError("Total shared variance (cvPCA) non-positive. Check data/splits.")

shared_frac = lam_cv_pos / total_shared
cum_shared_frac = np.cumsum(shared_frac)

# Pick number of cvPCA components by shared variance cutoff
brain_ncomp = np.searchsorted(cum_shared_frac, VAR_CUTOFF) + 1
brain_ncomp = min(brain_ncomp, V_pca.shape[1])

print("===== Neural cvPCA (in PCA space) =====")
for i in range(brain_ncomp):
    print(f"cvPC {i+1:2d}: shared var = {lam_cv_pos[i]:.6f} | "
          f"fraction = {shared_frac[i]*100:5.2f}% | "
          f"cumulative = {cum_shared_frac[i]*100:5.2f}%")

print(f"\nâœ… Using first {brain_ncomp} cvPCA components as neural stimulus-locked basis "
      f"in PCA space (covering {cum_shared_frac[brain_ncomp-1]*100:.2f}% shared variance).")

# Neural cvPCA basis inside PCA space (neural_ncomp Ã— brain_ncomp)
W_cvpca_pca = V_pca[:, :brain_ncomp]

# Now get neural cvPCA scores for ALL trials (Xa_pca) in this basis:
Zb_cv = Xa_pca @ W_cvpca_pca           # (images Ã— brain_ncomp)

# Optionally, you can also get cvPCA basis back in neuron space:
# W_neural_cv = W_pca_neurons @ W_cvpca_pca   # (neurons Ã— brain_ncomp)


# ===============================================================
# STEP 4: Cross-covariance between ViT PCs and neural cvPCA scores
# ===============================================================
print("\nðŸ”¹ Cross-covariance SVD between ViT PCs and neural cvPCA scores...")

# Center across images
Zv0 = Zv - Zv.mean(axis=0, keepdims=True)      # (images Ã— vit_ncomp)
Zb0 = Zb_cv - Zb_cv.mean(axis=0, keepdims=True)  # (images Ã— brain_ncomp)

# Cross-covariance
C_vb = (Zv0.T @ Zb0) / (n_stim - 1)            # (vit_ncomp Ã— brain_ncomp)

# SVD of cross-covariance
U_cv, S_cv, Vt_cv = np.linalg.svd(C_vb, full_matrices=False)
V_cv = Vt_cv.T                                  # (brain_ncomp Ã— r)

shared_cov_frac = (S_cv**2) / np.sum(S_cv**2)

print("===== Cross-covariance components (ViT â†” neural cvPCA) =====")
for i, (s, frac) in enumerate(zip(S_cv, shared_cov_frac), start=1):
    print(f"Component {i:2d}: singular value = {s:.6f} | "
          f"cross-covariance fraction = {frac*100:5.2f}%")


# ===============================================================
# STEP 5: Variance explained in ORIGINAL neural PC data
#         by cross-covariance components
# ===============================================================
print("\nðŸ”¹ Variance explained in ORIGINAL neural PC data by cross-covariance components...")

# Original neural PC scores from FIRST PCA:
# Xa_pca: (images Ã— neural_ncomp)
Xa_pca0 = Xa_pca - Xa_pca.mean(axis=0, keepdims=True)

# Total variance in original neural PCA space
total_var_neural_pca = np.sum(np.var(Xa_pca0, axis=0, ddof=1))

# Map cross-cov components from cvPCA space back into original neural PCA space:
# W_cvpca_pca: (neural_ncomp Ã— brain_ncomp)
# V_cv:        (brain_ncomp Ã— r)
# â†’ B_pca:     (neural_ncomp Ã— r), directions in original neural PCA space
B_pca = W_cvpca_pca @ V_cv                   # (neural_ncomp Ã— r)

# Project original neural PC scores onto these cross-cov directions
Xb_cc_pca = Xa_pca0 @ B_pca                  # (images Ã— r)

# Variance per cross-cov component in original neural PCA space
var_cc = np.var(Xb_cc_pca, axis=0, ddof=1)   # (r,)
frac_cc = var_cc / total_var_neural_pca
cum_frac_cc = np.cumsum(frac_cc)

print("===== Neural PCA variance explained by cross-covariance components =====")
for i, (f, cf) in enumerate(zip(frac_cc, cum_frac_cc), start=1):
    print(f"Cross-cov comp {i:2d}: "
          f"explained var in neural PCA space = {f*100:5.2f}% | "
          f"cumulative = {cf*100:5.2f}%")

print(f"\nâœ… Total variance in ORIGINAL neural PC data accounted for "
      f"(sum over all cross-cov comps): {frac_cc.sum()*100:.2f}%")

# ===============================================================
# STEP 6: Save everything
# ===============================================================
out_name = f"vit_{AREA_NAME}_pca_cvpca_crosscov_results.npz"
np.savez(
    out_name,
    # ViT
    vit_scores=Zv,
    vit_basis=vit_pca.components_,
    vit_explained_var=vit_pca.explained_variance_ratio_,
    # Neural PCA
    neural_pca_scores=Xa_pca,
    neural_pca_basis=W_pca_neurons,
    neural_pca_explained_var=neural_pca.explained_variance_ratio_,
    # Neural cvPCA
    neural_cvpca_basis_pca=W_cvpca_pca,
    neural_cvpca_shared_var=lam_cv_pos[:brain_ncomp],
    neural_cvpca_shared_frac=shared_frac[:brain_ncomp],
    neural_cvpca_scores=Zb_cv,
    # Cross-covariance
    crosscov_matrix=C_vb,
    crosscov_singular_values=S_cv,
    crosscov_vit_basis=U_cv,
    crosscov_neural_cvpca_basis=V_cv,
    # Variance explained in ORIGINAL neural PCA space
    neural_pca_var_per_cc=var_cc,
    neural_pca_var_fraction_per_cc=frac_cc,
)

print(f"ðŸ’¾ Saved results to {out_name}")


ðŸ”¹ Loading ViT embeddings...
ðŸ”¹ PCA on ViT CLR embeddings...
âœ… ViT PCs covering 90% variance: 44

ðŸ”¹ Loading neural responses, area = VISp...
   â†’ 14382 neurons, n_total=5900, inferred n_time=1
   â†’ Even trials: 25, odd trials: 25

ðŸ”¹ FIRST PCA on stimulus-locked neural responses (all trials)...
âœ… Neural PCs (first stage) covering 90% variance: 87

ðŸ”¹ Running cvPCA on neural PCA scores (stimulus-locked)...
===== Neural cvPCA (in PCA space) =====
cvPC  1: shared var = 2.676840 | fraction =  5.94% | cumulative =  5.94%
cvPC  2: shared var = 1.709752 | fraction =  3.79% | cumulative =  9.73%
cvPC  3: shared var = 1.561735 | fraction =  3.46% | cumulative = 13.19%
cvPC  4: shared var = 1.341531 | fraction =  2.98% | cumulative = 16.17%
cvPC  5: shared var = 1.205600 | fraction =  2.67% | cumulative = 18.85%
cvPC  6: shared var = 1.145316 | fraction =  2.54% | cumulative = 21.39%
cvPC  7: shared var = 1.078619 | fraction =  2.39% | cumulative = 23.78%
cvPC  8: shared var =