Skip to content

Commit

Permalink
[FIX] Whitening with low amplitude data (#64)
Browse files Browse the repository at this point in the history
* minor docfixes

* fix whitening when data is very small

* Update test_cca.py

* pep

Co-authored-by: Nicolas Barascud <nbarascud@snapchat.com>
  • Loading branch information
nbara and nbarascud-sc authored Nov 24, 2022
1 parent 154cfc5 commit 5e95ee5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
23 changes: 13 additions & 10 deletions meegkit/cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
n_trials).
shifts : array, shape=(n_shifts,)
Array of shifts to apply to `y` relative to `x` (can be negative).
sfreq : float
Sampling frequency. If not 1, lags are assumed to be given in seconds.
surrogate : bool
If True, estimate SD of correlation over non-matching pairs.
plot : bool
Expand Down Expand Up @@ -133,16 +135,16 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,

# Calculate leave-one-out CCAs
print('Calculate CCAs...')
AA = list()
BB = list()
AA = []
BB = []
for t in tqdm(np.arange(n_trials)):
# covariance of all trials except t
CC = np.sum(C[..., np.arange(n_trials) != t], axis=-1, keepdims=True)
if CC.ndim == 4:
CC = np.squeeze(CC, 3)

# corresponding CCA
[A, B, R] = nt_cca(None, None, None, CC, xx[0].shape[1])
A, B, _ = nt_cca(None, None, None, CC, xx[0].shape[1])
AA.append(A)
BB.append(B)
del A, B
Expand Down Expand Up @@ -227,17 +229,17 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
independently from each page.
m : int
Number of channels of X.
thresh: float
thresh : float
Discard principal components below this value.
sfreq : float
Sampling frequency. If not 1, lags are assumed to be given in seconds.
Returns
-------
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y))
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)[, n_lags])
Transform matrix mapping `X` to canonical space, where `n_comps` is
equal to `min(n_chans_X, n_chans_Y)`.
B : array, shape=(n_chans_Y, n_comps)
B : array, shape=(n_chans_Y, n_comps[, n_lags])
Transform matrix mapping `Y` to canonical space, where `n_comps` is
equal to `min(n_chans_X, n_chans_Y)`.
R : array, shape=(n_comps, n_lags)
Expand All @@ -246,16 +248,16 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
Notes
-----
Usage 1: CCA of X, Y
>> [A, B, R] = nt_cca(X, Y) # noqa
>> A, B, R = nt_cca(X, Y) # noqa
Usage 2: CCA of X, Y for each value of lags.
>> [A, B, R] = nt_cca(X, Y, lags) # noqa
>> A, B, R = nt_cca(X, Y, lags) # noqa
A positive lag indicates that Y is delayed relative to X.
Usage 3: CCA from covariance matrix
>> C = [X, Y].T * [X, Y] # noqa
>> [A, B, R] = nt_cca([], [], [], C, X.shape[1]) # noqa
>> A, B, R = nt_cca(None, None, None, C=C, m=X.shape[1]) # noqa
Use the third form to handle multiple files or large data (covariance C can
be calculated chunk-by-chunk).
Expand Down Expand Up @@ -381,9 +383,10 @@ def whiten_nt(C, thresh=1e-12, keep=False):
# break symmetry when x and y perfectly correlated (otherwise cols of x*A
# and y*B are not orthogonal)
d = d ** (1 - thresh)
d_norm = d / np.max(d)

dd = np.zeros_like(d)
dd[d > thresh] = (1. / d[d > thresh])
dd[d_norm > thresh] = (1. / d[d_norm > thresh])

D = np.diag(np.sqrt(dd))
W = np.dot(V, D)
Expand Down
Binary file added tests/data/ccadata_meg_2trials.npz
Binary file not shown.
18 changes: 18 additions & 0 deletions tests/test_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ def test_cca2():
# plt.show()


def test_cca_scaling():
    """Check that canonical correlations do not depend on the data scale.

    The same MEG recording is passed to `nt_cca` twice: once multiplied by
    1e15 (unit: fT) and once as stored (unit: T). The correlations returned
    by both runs must agree.
    """
    # np.load on an .npz archive returns an NpzFile that keeps the file open
    # until it is closed; read the arrays inside a context manager so the
    # handle is released deterministically.
    with np.load('./tests/data/ccadata_meg_2trials.npz') as data:
        raw = data['arr_0']
        env = data['arr_1']

    # With scaling (unit: fT)
    _, _, R0 = nt_cca(raw * 1e15, env)

    # Without scaling (unit: T)
    _, _, R1 = nt_cca(raw, env)

    np.testing.assert_almost_equal(R0, R1)


def test_canoncorr():
"""Compare with Matlab's canoncorr."""
x = np.array([[16, 2, 3, 13],
Expand Down Expand Up @@ -130,6 +145,9 @@ def test_cca_lags():
lags = np.arange(-10, 11, 1)
A1, B1, R1 = nt_cca(x, y, lags)

assert A1.ndim == B1.ndim == 3
assert A1.shape[-1] == B1.shape[-1] == lags.size

# import matplotlib.pyplot as plt
# f, ax1 = plt.subplots(1, 1)
# ax1.plot(lags, R1.T)
Expand Down

0 comments on commit 5e95ee5

Please sign in to comment.