Implement automatic ridge regression #124

Open · wants to merge 9 commits into main
74 changes: 74 additions & 0 deletions mne_connectivity/vector_ar/tests/test_var.py
@@ -25,6 +25,42 @@ def bivariate_var_data():
return y


def illconditioned_data(
n=12,
m=100,
add_noise=False,
sigma=1e-4,
random_state=12345,
):
"""Create ill-conditioned test data from an upper-triangular system."""
rng = np.random.RandomState(random_state)

if add_noise:
mu = 0.0
noise = rng.normal(mu, sigma, m)  # Gaussian noise

# create upper-triangular A matrix
A = np.triu(rng.uniform(0, 1, (n, n)))
A[-1, -1] = 1e-6  # tiny eigenvalue makes the matrix ill-conditioned

# compute true eigenvalues
true_eigvals = np.linalg.eigvals(A)

X = np.zeros((n, m))
X[:, 0] = rng.uniform(0, 1, n)
# evolve the system and perturb the data with noise
for k in range(1, m):
X[:, k] = A.dot(X[:, k - 1])

if add_noise:
X[:, k - 1] += noise[k - 1]

# data must be ill-conditioned
assert np.linalg.cond(X) > 1e6

return X, true_eigvals, A
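
As a quick sanity check of this helper, one could confirm the returned data really are ill-conditioned (a minimal sketch; shapes follow the defaults above):

X, true_eigvals, A = illconditioned_data(add_noise=True)
print(X.shape)                  # (12, 100)
print(np.linalg.cond(X) > 1e6)  # True: far above the 1e6 threshold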


def create_noisy_data(
add_noise,
sigma=1e-4,
@@ -316,3 +352,41 @@ def test_vector_auto_regression():
big_epoch_data = rng.randn(n_times * 2, n_signals, n_times)
parr_conn = vector_auto_regression(big_epoch_data, times=times, n_jobs=-1)
parr_conn.predict(big_epoch_data)


def test_auto_l2reg():
"""Test automatic l2 regularization of ill-conditioned data."""

sample_data, sample_eigs, sample_A = illconditioned_data(
add_noise=True
)

# create 3D array input
sample_data = sample_data[np.newaxis, ...]

# compute the model
model = vector_auto_regression(sample_data, l2_reg='auto')

# test that Ridge regression was used for ill-conditioned data
assert model.xarray.attrs['use_ridge']

# test the recovered model
assert_array_almost_equal(
model.get_data(output='dense').squeeze(), sample_A,
decimal=1
)

# compute model without regularization
noreg_model = vector_auto_regression(sample_data, l2_reg=None)
assert noreg_model.xarray.attrs['use_ridge'] is False

# test that the regularized model is better
eigs = np.linalg.eigvals(model.get_data(output='dense').squeeze())
noreg_eigs = np.linalg.eigvals(
noreg_model.get_data(output='dense').squeeze()
)

reg_diff = np.linalg.norm(eigs - sample_eigs)
noreg_diff = np.linalg.norm(noreg_eigs - sample_eigs)

assert reg_diff < noreg_diff
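
For context on the condition check these tests exercise: ``np.linalg.cond`` broadcasts over the leading axis of a 3D array and returns one condition number per epoch, which is how the 'auto' mode decides whether to regularize. A minimal sketch (shapes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
data = rng.randn(5, 12, 100)  # (n_epochs, n_nodes, n_times)
data[0, -1, :] *= 1e-9        # make one epoch nearly rank-deficient

conds = np.linalg.cond(data)  # shape (5,): one value per epoch
print(np.any(conds > 1e6))    # True -> automatic regularization kicks in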
102 changes: 81 additions & 21 deletions mne_connectivity/vector_ar/var.py
@@ -1,10 +1,13 @@
import warnings

import numpy as np
import scipy
from scipy.linalg import sqrtm
from sklearn.linear_model import RidgeCV
from tqdm import tqdm
from mne import BaseEpochs

from mne.utils import logger, verbose
from mne.utils import logger, verbose, warn

from ..utils import fill_doc
from ..base import Connectivity, EpochConnectivity, EpochTemporalConnectivity
@@ -13,7 +16,7 @@
@verbose
@fill_doc
def vector_auto_regression(
data, times=None, names=None, lags=1, l2_reg=0.0,
data, times=None, names=None, lags=1, l2_reg='auto',
compute_fb_operator=False, model='dynamic', n_jobs=1, verbose=None):
"""Compute vector auto-regresssive (VAR) model.

@@ -29,8 +32,14 @@ def vector_auto_regression(
%(names)s
lags : int, optional
Autoregressive model order, by default 1.
l2_reg : float, optional
Ridge penalty (l2-regularization) parameter, by default 0.0.
l2_reg : str | array-like, shape=(n_alphas,) | float | None, optional
Ridge penalty (l2-regularization) parameter, by default 'auto'. If
``data`` has a condition number greater than 1e6, then ``data`` will
undergo automatic regularization using RidgeCV with a pre-defined array
of alphas: ``np.logspace(-15, 5, 11)``. Alternatively, a user-defined
array of alphas (which must be positive floats) can be passed, or a
float value to fix the Ridge penalty (l2-regularization) parameter. If
``l2_reg`` is set to 0 or None, no regularization is performed. See the
usage sketch below.
compute_fb_operator : bool
Whether to compute the backwards operator and average with
the forward operator. Addresses bias in the least-square
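
A usage sketch of the accepted ``l2_reg`` forms, assuming the signature introduced in this PR (the input shape is illustrative):

import numpy as np
from mne_connectivity import vector_auto_regression

data = np.random.randn(10, 4, 200)  # (n_epochs, n_signals, n_times)

conn_auto = vector_auto_regression(data, l2_reg='auto')  # RidgeCV if ill-conditioned
conn_fixed = vector_auto_regression(data, l2_reg=0.1)    # fixed ridge penalty
conn_cv = vector_auto_regression(data, l2_reg=np.logspace(-8, 2, 6))  # user alphas
conn_ols = vector_auto_regression(data, l2_reg=None)     # plain OLS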
@@ -151,9 +160,32 @@
# 1. determine shape of the window of data
n_epochs, n_nodes, _ = data.shape

cv_alphas = None
if isinstance(l2_reg, str) and l2_reg == 'auto':
# reset l2_reg for downstream functions
l2_reg = 0
# determine condition of matrix across all epochs
conds = np.linalg.cond(data)
if np.any(conds > 1e6):
# matrix is ill-conditioned, so regularization must be used with
# cross-validated alpha values
cv_alphas = np.logspace(-15, 5, 11)
warn('Input data matrix exceeds condition threshold of 1e6. '
'Automatic regularization will be performed.')
elif isinstance(l2_reg, (list, tuple, set, np.ndarray)):
cv_alphas = l2_reg
l2_reg = 0

# cases where OLS is used
if (l2_reg in [0, None]) and (cv_alphas is None):
use_ridge = False
else:
use_ridge = True

model_params = {
'lags': lags,
'l2_reg': l2_reg,
'use_ridge': use_ridge,
'cv_alphas': cv_alphas
}

if verbose:
@@ -165,12 +197,20 @@
# sample of the multivariate time-series of interest
# ordinary least squares or regularized least squares
# (ridge regression)
X, Y = _construct_var_eqns(data, **model_params)

b, res, rank, s = scipy.linalg.lstsq(X, Y)
X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)

# get the coefficients
coef = b.transpose()
if cv_alphas is not None:
with warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
message="Ill-conditioned matrix"
)
Comment on lines +204 to +208

Member:
Why do we need this?

witherscp (Author):
RidgeCV tests out an array of alpha values, and some of them do not regularize the matrix enough to avoid an ill-conditioned matrix warning. If users see many of these messages pop up, they may think that something is going wrong, when in fact the function is behaving as expected. RidgeCV will choose the best alpha value, and that will come from an instance where this warning was not raised.

reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y)
coef = reg.coef_
else:
b, res, rank, s = scipy.linalg.lstsq(X, Y)
coef = b.transpose()

# create connectivity
coef = coef.flatten()
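
A standalone sketch of the behavior discussed in the thread above: with very small alphas the ridge system stays ill-conditioned and SciPy emits a LinAlgWarning, so the warning is filtered while RidgeCV scans its grid (the data here are illustrative):

import warnings

import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.RandomState(0)
X = rng.randn(100, 10)
X[:, -1] = X[:, 0] * 1e-8  # nearly collinear column -> ill-conditioned
Y = rng.randn(100, 3)

with warnings.catch_warnings():
    # small alphas trigger "Ill-conditioned matrix" warnings; RidgeCV
    # still selects the best alpha over the full grid
    warnings.filterwarnings(action='ignore', message='Ill-conditioned matrix')
    reg = RidgeCV(alphas=np.logspace(-15, 5, 11), cv=5).fit(X, Y)

print(reg.alpha_)  # chosen penalty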
@@ -187,8 +227,9 @@
# linear system
A_mats = _system_identification(
data=data, lags=lags,
l2_reg=l2_reg, n_jobs=n_jobs,
compute_fb_operator=compute_fb_operator)
l2_reg=l2_reg, cv_alphas=cv_alphas,
n_jobs=n_jobs, compute_fb_operator=compute_fb_operator
)
# create connectivity
if lags > 1:
conn = EpochTemporalConnectivity(data=A_mats,
@@ -261,7 +302,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
X[:n, i * lags + k -
1] = np.reshape(data[:, i, lags - k:-k].T, n)

if l2_reg is not None:
if l2_reg:
np.fill_diagonal(X[n:, :], l2_reg)

# Construct vectors yi (response variables for each channel i)
@@ -272,7 +313,7 @@
return X, Y


def _system_identification(data, lags, l2_reg=0,
def _system_identification(data, lags, l2_reg=0, cv_alphas=None,
n_jobs=-1, compute_fb_operator=False):
"""Solve system identification using least-squares over all epochs.

@@ -290,6 +331,7 @@ def _system_identification(data, lags, l2_reg=0,
model_params = {
'l2_reg': l2_reg,
'lags': lags,
'cv_alphas': cv_alphas,
'compute_fb_operator': compute_fb_operator
}

@@ -346,7 +388,7 @@
return A_mats


def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
def _compute_lds_func(data, lags, l2_reg, cv_alphas, compute_fb_operator):
"""Compute linear system using VAR model.

Allows for parallelization over epochs.
@@ -372,20 +414,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
# get time-shifted versions
X = data[:, :]
A, resid, omega = _estimate_var(X, lags=lags, offset=0,
l2_reg=l2_reg)
l2_reg=l2_reg, cv_alphas=cv_alphas)

if compute_fb_operator:
# compute backward linear operator
# original method
back_A, back_resid, back_omega = _estimate_var(
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg)
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, cv_alphas=cv_alphas
)
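# forward-backward estimate: A_fb = sqrt(A_forward @ inv(A_backward)),
# a geometric-mean-style combination that reduces the bias of the
# forward-only least-squares estimate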
A = sqrtm(A.dot(np.linalg.inv(back_A)))
A = A.real # remove numerical noise

return A, resid, omega


def _estimate_var(X, lags, offset=0, l2_reg=0):
def _estimate_var(X, lags, offset=0, l2_reg=0, cv_alphas=None):
"""Estimate a VAR model.

Parameters
@@ -397,8 +440,10 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
offset : int, optional
Periods to drop from the beginning of the time-series, by default 0.
Used for order selection, so it's an apples-to-apples comparison.
l2_reg : int
l2_reg : int, optional
The amount of l2-regularization to use. Default of 0.
cv_alphas : array-like | None, optional
RidgeCV regularization cross-validation alpha values. Defaults to None.

Returns
-------
@@ -432,10 +477,25 @@
y_sample = endog[lags:]
del endog, X
# Lütkepohl p75, about 5x faster than stated formula
if l2_reg != 0:
params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample, rcond=1e-15)[0]

if (l2_reg is not None) and (l2_reg != 0):
# use pre-specified l2 regularization value
params = np.linalg.lstsq(
z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample,
rcond=1e-15
)[0]
elif cv_alphas is not None:
# use ridge regression with built-in cross validation of alpha values
with warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
message="Ill-conditioned matrix"
)
reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample)
params = reg.coef_.T
else:
# use OLS regression
params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0]

# (n_samples - lags, n_channels)
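
For reference, the regularized branch above solves the ridge normal equations (z'z + l2_reg * I) B = z'y. A minimal standalone sketch (names mirror the code; the data are illustrative):

import numpy as np

rng = np.random.RandomState(0)
z = rng.randn(200, 8)         # lagged regressors, (n_samples, n_equations * lags)
y_sample = rng.randn(200, 8)  # current samples
l2_reg = 0.1

# ridge solution of min ||z @ B - y||^2 + l2_reg * ||B||^2
params = np.linalg.lstsq(
    z.T @ z + l2_reg * np.eye(z.shape[1]),
    z.T @ y_sample,
    rcond=1e-15,
)[0]

# same result via a direct solve of the normal equations
assert np.allclose(params, np.linalg.solve(
    z.T @ z + l2_reg * np.eye(z.shape[1]), z.T @ y_sample))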