diff --git a/mne_connectivity/vector_ar/tests/test_var.py b/mne_connectivity/vector_ar/tests/test_var.py
index 836d77a1..9a37f41d 100644
--- a/mne_connectivity/vector_ar/tests/test_var.py
+++ b/mne_connectivity/vector_ar/tests/test_var.py
@@ -25,6 +25,42 @@ def bivariate_var_data():
     return y
 
 
+def illconditioned_data(
+        n=12,
+        m=100,
+        add_noise=False,
+        sigma=1e-4,
+        random_state=12345,
+):
+
+    rng = np.random.RandomState(random_state)
+
+    if add_noise:
+        mu = 0.0
+        noise = rng.normal(mu, sigma, m)  # Gaussian noise
+
+    # create an upper-triangular A matrix
+    A = np.triu(rng.uniform(0, 1, (n, n)))
+    A[-1, -1] = 1e-6  # matrix is ill-conditioned
+
+    # compute true eigenvalues
+    true_eigvals = np.linalg.eigvals(A)
+
+    X = np.zeros((n, m))
+    X[:, 0] = rng.uniform(0, 1, n)
+    # evolve the system and perturb the data with noise
+    for k in range(1, m):
+        X[:, k] = A.dot(X[:, k - 1])
+
+        if add_noise:
+            X[:, k - 1] += noise[k - 1]
+
+    # data must be ill-conditioned
+    assert (np.linalg.cond(X) > 1e6)
+
+    return X, true_eigvals, A
+
+
 def create_noisy_data(
         add_noise,
         sigma=1e-4,
@@ -316,3 +352,41 @@ def test_vector_auto_regression():
     big_epoch_data = rng.randn(n_times * 2, n_signals, n_times)
     parr_conn = vector_auto_regression(big_epoch_data, times=times, n_jobs=-1)
     parr_conn.predict(big_epoch_data)
+
+
+def test_auto_l2reg():
+    """Test automatic l2 regularization of ill-conditioned data."""
+
+    sample_data, sample_eigs, sample_A = illconditioned_data(
+        add_noise=True
+    )
+
+    # create 3D array input
+    sample_data = sample_data[np.newaxis, ...]
+
+    # compute the model
+    model = vector_auto_regression(sample_data, l2_reg='auto')
+
+    # test that Ridge regression was used for ill-conditioned data
+    assert model.xarray.attrs['use_ridge']
+
+    # test the recovered model
+    assert_array_almost_equal(
+        model.get_data(output='dense').squeeze(), sample_A,
+        decimal=1
+    )
+
+    # compute the model without regularization
+    noreg_model = vector_auto_regression(sample_data, l2_reg=None)
+    assert noreg_model.xarray.attrs['use_ridge'] is False
+
+    # test that the regularized model is better
+    eigs = np.linalg.eigvals(model.get_data(output='dense').squeeze())
+    noreg_eigs = np.linalg.eigvals(
+        noreg_model.get_data(output='dense').squeeze()
+    )
+
+    reg_diff = np.linalg.norm(eigs - sample_eigs)
+    noreg_diff = np.linalg.norm(noreg_eigs - sample_eigs)
+
+    assert reg_diff < noreg_diff
diff --git a/mne_connectivity/vector_ar/var.py b/mne_connectivity/vector_ar/var.py
index d6598311..de520e01 100644
--- a/mne_connectivity/vector_ar/var.py
+++ b/mne_connectivity/vector_ar/var.py
@@ -1,10 +1,13 @@
+import warnings
+
 import numpy as np
 import scipy
 from scipy.linalg import sqrtm
+from sklearn.linear_model import RidgeCV
 from tqdm import tqdm
 
 from mne import BaseEpochs
-from mne.utils import logger, verbose
+from mne.utils import logger, verbose, warn
 
 from ..utils import fill_doc
 from ..base import Connectivity, EpochConnectivity, EpochTemporalConnectivity
@@ -13,7 +16,7 @@
 @verbose
 @fill_doc
 def vector_auto_regression(
-        data, times=None, names=None, lags=1, l2_reg=0.0,
+        data, times=None, names=None, lags=1, l2_reg='auto',
         compute_fb_operator=False, model='dynamic', n_jobs=1,
         verbose=None):
     """Compute vector auto-regresssive (VAR) model.
@@ -29,8 +32,14 @@ def vector_auto_regression(
     %(names)s
     lags : int, optional
         Autoregressive model order, by default 1.
-    l2_reg : float, optional
-        Ridge penalty (l2-regularization) parameter, by default 0.0.
+    l2_reg : str | array-like, shape=(n_alphas,) | float | None, optional
+        Ridge penalty (l2-regularization) parameter, by default 'auto'. If
+        ``data`` has a condition number greater than 1e6, ``data`` will
+        undergo automatic regularization using RidgeCV with a pre-defined
+        array of alphas: np.logspace(-15, 5, 11). Alternatively, a
+        user-defined array of alphas (positive floats) can be passed, or a
+        float value can be given to fix the Ridge penalty. If ``l2_reg`` is
+        set to 0 or None, no regularization is performed.
     compute_fb_operator : bool
         Whether to compute the backwards operator and average with
         the forward operator. Addresses bias in the least-square
@@ -151,9 +160,32 @@ def vector_auto_regression(
     # 1. determine shape of the window of data
     n_epochs, n_nodes, _ = data.shape
 
+    cv_alphas = None
+    if isinstance(l2_reg, str) and l2_reg == 'auto':
+        # reset l2_reg for downstream functions
+        l2_reg = 0
+        # determine the condition of the matrix across all epochs
+        conds = np.linalg.cond(data)
+        if np.any(conds > 1e6):
+            # matrix is ill-conditioned, so regularization must be used with
+            # cross-validated alpha values
+            cv_alphas = np.logspace(-15, 5, 11)
+            warn('Input data matrix exceeds condition threshold of 1e6. '
+                 'Automatic regularization will be performed.')
+    elif isinstance(l2_reg, (list, tuple, set, np.ndarray)):
+        cv_alphas = l2_reg
+        l2_reg = 0
+
+    # cases where OLS is used
+    if (l2_reg in [0, None]) and (cv_alphas is None):
+        use_ridge = False
+    else:
+        use_ridge = True
+
     model_params = {
         'lags': lags,
-        'l2_reg': l2_reg,
+        'use_ridge': use_ridge,
+        'cv_alphas': cv_alphas
     }
 
     if verbose:
@@ -165,12 +197,20 @@
         # sample of the multivariate time-series of interest
         # ordinary least squares or regularized least squares
        # (ridge regression)
-        X, Y = _construct_var_eqns(data, **model_params)
-        b, res, rank, s = scipy.linalg.lstsq(X, Y)
+        X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)
 
-        # get the coefficients
-        coef = b.transpose()
+        if cv_alphas is not None:
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    action='ignore',
+                    message="Ill-conditioned matrix"
+                )
+                reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y)
+            coef = reg.coef_
+        else:
+            b, res, rank, s = scipy.linalg.lstsq(X, Y)
+            coef = b.transpose()
 
         # create connectivity
         coef = coef.flatten()
@@ -187,8 +227,9 @@
         # linear system
         A_mats = _system_identification(
             data=data, lags=lags,
-            l2_reg=l2_reg, n_jobs=n_jobs,
-            compute_fb_operator=compute_fb_operator)
+            l2_reg=l2_reg, cv_alphas=cv_alphas,
+            n_jobs=n_jobs, compute_fb_operator=compute_fb_operator
+        )
     # create connectivity
     if lags > 1:
         conn = EpochTemporalConnectivity(data=A_mats,
@@ -261,7 +302,7 @@
             X[:n, i * lags + k - 1] = np.reshape(data[:, i, lags - k:-k].T, n)
 
-    if l2_reg is not None:
+    if l2_reg:
         np.fill_diagonal(X[n:, :], l2_reg)
 
     # Construct vectors yi (response variables for each channel i)
@@ -272,7 +313,7 @@
     return X, Y
 
 
-def _system_identification(data, lags, l2_reg=0,
+def _system_identification(data, lags, l2_reg=0, cv_alphas=None,
                            n_jobs=-1, compute_fb_operator=False):
     """Solve system identification using least-squares over all epochs.
@@ -290,6 +331,7 @@ def _system_identification(data, lags, l2_reg=0,
     model_params = {
         'l2_reg': l2_reg,
         'lags': lags,
+        'cv_alphas': cv_alphas,
         'compute_fb_operator': compute_fb_operator
     }
 
@@ -346,7 +388,7 @@ def _system_identification(data, lags, l2_reg=0,
     return A_mats
 
 
-def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
+def _compute_lds_func(data, lags, l2_reg, cv_alphas, compute_fb_operator):
     """Compute linear system using VAR model.
 
     Allows for parallelization over epochs.
@@ -372,20 +414,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
     # get time-shifted versions
     X = data[:, :]
     A, resid, omega = _estimate_var(X, lags=lags, offset=0,
-                                    l2_reg=l2_reg)
+                                    l2_reg=l2_reg, cv_alphas=cv_alphas)
 
     if compute_fb_operator:
         # compute backward linear operator
         # original method
         back_A, back_resid, back_omega = _estimate_var(
-            X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg)
+            X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, cv_alphas=cv_alphas
+        )
         A = sqrtm(A.dot(np.linalg.inv(back_A)))
         A = A.real  # remove numerical noise
 
     return A, resid, omega
 
 
-def _estimate_var(X, lags, offset=0, l2_reg=0):
+def _estimate_var(X, lags, offset=0, l2_reg=0, cv_alphas=None):
     """Estimate a VAR model.
 
     Parameters
     ----------
@@ -397,8 +440,10 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
     offset : int, optional
         Periods to drop from the beginning of the time-series, by default 0.
         Used for order selection, so it's an apples-to-apples comparison
-    l2_reg : int
+    l2_reg : int, optional
         The amount of l2-regularization to use. Default of 0.
+    cv_alphas : array-like | None, optional
+        RidgeCV regularization cross-validation alpha values. Defaults to None.
 
     Returns
     -------
@@ -432,10 +477,25 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
     y_sample = endog[lags:]
     del endog, X
     # Lütkepohl p75, about 5x faster than stated formula
-    if l2_reg != 0:
-        params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags),
-                                 z.T @ y_sample, rcond=1e-15)[0]
+
+    if (l2_reg is not None) and (l2_reg != 0):
+        # use the pre-specified l2-regularization value
+        params = np.linalg.lstsq(
+            z.T @ z + l2_reg * np.eye(n_equations * lags),
+            z.T @ y_sample,
+            rcond=1e-15
+        )[0]
+    elif cv_alphas is not None:
+        # use ridge regression with built-in cross-validation of alpha values
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                action='ignore',
+                message="Ill-conditioned matrix"
+            )
+            reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample)
+        params = reg.coef_.T
     else:
+        # use OLS regression
         params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0]
 
     # (n_samples - lags, n_channels)
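
Below is a minimal usage sketch of the new ``l2_reg`` options introduced above. It is illustrative only and not part of the changeset: it assumes ``vector_auto_regression`` is imported from ``mne_connectivity`` as in the test file, and the toy ill-conditioned array and variable names are made up for the example.

import numpy as np
from mne_connectivity import vector_auto_regression

rng = np.random.RandomState(0)

# toy (n_epochs, n_signals, n_times) array; the last channel is a tiny scaled
# copy of the first, so the per-epoch data matrix is ill-conditioned
data = rng.standard_normal((1, 8, 200))
data[0, -1] = 1e-8 * data[0, 0]
assert np.all(np.linalg.cond(data) > 1e6)

# default 'auto': RidgeCV over np.logspace(-15, 5, 11) is only used when the
# condition threshold of 1e6 is exceeded; the choice is recorded in the attrs
conn = vector_auto_regression(data, l2_reg='auto')
print(conn.xarray.attrs['use_ridge'])  # True here, since cond(data) > 1e6

# explicit alternatives
conn_cv = vector_auto_regression(data, l2_reg=np.logspace(-6, 2, 9))  # user alphas
conn_fixed = vector_auto_regression(data, l2_reg=1e-3)  # fixed ridge penalty
conn_ols = vector_auto_regression(data, l2_reg=None)    # ordinary least squares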