In [12]:
import numpy as np
from scipy.stats import pearsonr

In [69]:
def center_kernel(K: np.ndarray) -> np.ndarray:
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return H @ K @ H

def remove_diagonal(X: np.ndarray) -> np.ndarray:
    """Remove the diagonal of a square matrix."""
    X = X.copy()
    np.fill_diagonal(X, 0)
    return X

def row_center(X: np.ndarray) -> np.ndarray:
    """Row-centering of a matrix."""
    return X - X.mean(axis=1, keepdims=True)

def col_center(X: np.ndarray) -> np.ndarray:
    """Column-centering of a matrix."""
    return X - X.mean(axis=0, keepdims=True)

def corr_rsm(X: np.ndarray) -> np.ndarray:
    """Row-by-row Pearson correlation matrix (RSM)."""
    return np.corrcoef(X)

def rsa_full(S1: np.ndarray, S2: np.ndarray, center: bool = False, off_diagonal=False) -> float:
    """Pearson correlation between the *entire* (vectorised) similarity matrices."""
    if off_diagonal:
        S1 = remove_diagonal(S1)
        S2 = remove_diagonal(S2)
    if center:
        S1, S2 = center_kernel(S1), center_kernel(S2)
    v1, v2 = S1.ravel(), S2.ravel()
    return pearsonr(v1, v2)[0]

def rsa(S1: np.ndarray, S2: np.ndarray, center: bool = False) -> float:
    """Pearson correlation between the *triu* (vectorised) similarity matrices."""
    if center:
        S1, S2 = center_kernel(S1), center_kernel(S2)
    triu_indices = np.triu_indices(S1.shape[0], k=1)
    v1, v2 = S1[triu_indices], S2[triu_indices]
    return pearsonr(v1, v2)[0]


def cka(K: np.ndarray, L: np.ndarray, center_kernels: bool = True, off_diagonal: bool = False) -> float:
    """Centered Kernel Alignment."""
    if off_diagonal:
        K = remove_diagonal(K)
        L = remove_diagonal(L)
    if center_kernels:
        Kc, Lc = center_kernel(K), center_kernel(L)
    else:
        Kc, Lc = K, L
    hsic = np.sum(Kc * Lc)               # Frobenius inner product
    norm = np.linalg.norm(Kc, 'fro') * np.linalg.norm(Lc, 'fro')
    return hsic / norm




# --- Generate two random datasets ---
np.random.seed(2025)
N, d1, d2 = 3, 30, 35
X = np.random.randn(N, d1)
Y = np.random.randn(N, d2)

# --- Compute RSMs ---
Sx, Sy = corr_rsm(X), corr_rsm(Y)


print("RSA pearson-pearson centered (full):", a:=rsa_full(Sx, Sy, center=True))
print("CKA centred:", b:=cka(Sx, Sy, center_kernels=True))
print("Difference (CKA and full rsa centered):", (a - b).round(10))  # 0

# Removing diagonal
print()
print("RSA pearson-pearson centered (full, offdiagonal):", a:=rsa_full(Sx, Sy, center=True, off_diagonal=True))
print("CKA centred (off diagonal):", b:=cka(Sx, Sy, center_kernels=True, off_diagonal=True))
print("Difference (off diagonal):", (a - b).round(10)) # 0

print()
print("RSA pearson-pearson centered (full, offdiagonal):", a:=rsa(Sx, Sy, center=False))

# print("RSA pearson-pearson centered (full, offdiagonal):", a:=rsa_full(Sx, Sy, center=False, off_diagonal=True))
# print("RSA pearson-pearson centered (triu, offdiagonal):", b:=rsa(Sx, Sy, center=False))
# print("Difference (off diagonal):", (a - b).round(10)) # != 0 





RSA pearson-pearson centered (full): 0.9801577004595325
CKA centred: 0.9801577004595325
Difference (CKA and full rsa centered): 0.0

RSA pearson-pearson centered (full, offdiagonal): 0.6497537436205814
CKA centred (off diagonal): 0.6497537436205817
Difference (off diagonal): -0.0

RSA pearson-pearson centered (full, offdiagonal): 0.5736019368734254
