Implement automatic ridge regression #124

Open · wants to merge 9 commits into base: main
65 changes: 49 additions & 16 deletions mne_connectivity/vector_ar/var.py
@@ -1,6 +1,7 @@
import numpy as np
import scipy
from scipy.linalg import sqrtm
from sklearn.linear_model import RidgeCV
from tqdm import tqdm
from mne import BaseEpochs

@@ -13,7 +14,7 @@
@verbose
@fill_doc
def vector_auto_regression(
data, times=None, names=None, lags=1, l2_reg=0.0,
data, times=None, names=None, lags=1, l2_reg=0.0, auto_reg=False,
compute_fb_operator=False, model='dynamic', n_jobs=1, verbose=None):
"""Compute vector auto-regresssive (VAR) model.

@@ -31,6 +32,10 @@ def vector_auto_regression(
Autoregressive model order, by default 1.
l2_reg : float, optional
Ridge penalty (l2-regularization) parameter, by default 0.0.
Review comment (Member): Suggested change:
    Ridge penalty (l2-regularization) parameter, by default 0.0.
    If ``auto_reg`` is `True`, then this must be set to 0.0.

auto_reg : bool, optional
Whether to perform automatic regularization of the X matrix using RidgeCV,
by default False. If matrix is not full rank, this will be adjusted to
Review comment (Member): Suggested change:
    by default False. If the data matrix has condition number greater than 1e6, then this will be adjusted to

True.
Review comment (Member): Suggested change:
    True. If ``l2_reg`` is non-zero, then this must be set to `False`.

compute_fb_operator : bool
Whether to compute the backwards operator and average with
the forward operator. Addresses bias in the least-square
@@ -121,6 +126,8 @@ def vector_auto_regression(
if model not in ['avg-epochs', 'dynamic']:
raise ValueError(f'"model" parameter must be one of '
f'(avg-epochs, dynamic), not {model}.')
elif auto_reg and l2_reg:
raise ValueError("If l2_reg is set, then auto_reg must be set to False")
Review comment (Member): Can you add a unit test for this case? Let me know if you need help with that.

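A minimal sketch of such a test (an assumption, not part of this PR: it presumes pytest as the test runner and the package's public vector_auto_regression entry point; the data array and test name are hypothetical):

import numpy as np
import pytest

from mne_connectivity import vector_auto_regression


def test_l2_reg_auto_reg_mutually_exclusive():
    """Setting both l2_reg and auto_reg should raise a ValueError."""
    rng = np.random.default_rng(0)
    # (n_epochs, n_nodes, n_times), matching the shape this function expects
    data = rng.standard_normal((5, 4, 100))
    with pytest.raises(ValueError, match="auto_reg must be set to False"):
        vector_auto_regression(data, l2_reg=1e-3, auto_reg=True)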

events = None
event_id = None
@@ -151,9 +158,17 @@
# 1. determine shape of the window of data
n_epochs, n_nodes, _ = data.shape

# determine the condition number of the data matrix for each epoch
conds = np.linalg.cond(data)

if np.any(conds > 1e6):
# at least one epoch is ill-conditioned, so regularization must be used
auto_reg = True
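# Illustrative aside (hypothetical values): np.linalg.cond broadcasts over
# stacked matrices, so the call above yields one condition number per epoch:
#   demo = np.stack([np.eye(3), np.ones((3, 3)) + 1e-9 * np.eye(3)])
#   np.linalg.cond(demo)  # ~array([1e0, 3e9]); the 2nd epoch would trigger auto_reg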

model_params = {
'lags': lags,
'l2_reg': l2_reg,
'auto_reg': auto_reg
}

if verbose:
@@ -165,12 +180,15 @@
# sample of the multivariate time-series of interest
# ordinary least squares or regularized least squares
# (ridge regression)
X, Y = _construct_var_eqns(data, **model_params)

b, res, rank, s = scipy.linalg.lstsq(X, Y)
X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)

# get the coefficients
coef = b.transpose()
if auto_reg:
# use ridge regression with built-in cross-validation of alpha values
reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(X, Y)
coef = reg.coef_
else:
b, res, rank, s = scipy.linalg.lstsq(X, Y)
coef = b.transpose()

# create connectivity
coef = coef.flatten()
@@ -187,8 +205,9 @@
# linear system
A_mats = _system_identification(
data=data, lags=lags,
l2_reg=l2_reg, n_jobs=n_jobs,
compute_fb_operator=compute_fb_operator)
l2_reg=l2_reg, auto_reg=auto_reg,
n_jobs=n_jobs, compute_fb_operator=compute_fb_operator
)
# create connectivity
if lags > 1:
conn = EpochTemporalConnectivity(data=A_mats,
@@ -261,7 +280,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
X[:n, i * lags + k -
1] = np.reshape(data[:, i, lags - k:-k].T, n)

if l2_reg is not None:
if l2_reg:
np.fill_diagonal(X[n:, :], l2_reg)

# Construct vectors yi (response variables for each channel i)
@@ -272,7 +291,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
return X, Y


def _system_identification(data, lags, l2_reg=0,
def _system_identification(data, lags, l2_reg=0, auto_reg=False,
n_jobs=-1, compute_fb_operator=False):
"""Solve system identification using least-squares over all epochs.

@@ -289,6 +308,7 @@ def _system_identification(data, lags, l2_reg=0,

model_params = {
'l2_reg': l2_reg,
'auto_reg': auto_reg,
'lags': lags,
'compute_fb_operator': compute_fb_operator
}
@@ -346,7 +366,7 @@ def _system_identification(data, lags, l2_reg=0,
return A_mats


def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
def _compute_lds_func(data, lags, l2_reg, auto_reg, compute_fb_operator):
"""Compute linear system using VAR model.

Allows for parallelization over epochs.
@@ -372,20 +392,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
# get time-shifted versions
X = data[:, :]
A, resid, omega = _estimate_var(X, lags=lags, offset=0,
l2_reg=l2_reg)
l2_reg=l2_reg, auto_reg=auto_reg)

if compute_fb_operator:
# compute backward linear operator
# original method
back_A, back_resid, back_omega = _estimate_var(
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg)
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, auto_reg=auto_reg
)
A = sqrtm(A.dot(np.linalg.inv(back_A)))
A = A.real # remove numerical noise

return A, resid, omega


def _estimate_var(X, lags, offset=0, l2_reg=0):
def _estimate_var(X, lags, offset=0, l2_reg=0, auto_reg=False):
"""Estimate a VAR model.

Parameters
@@ -399,6 +420,9 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
Used for order selection, so it's an apples-to-apples comparison
l2_reg : float
The amount of l2-regularization to use. Default of 0.
auto_reg : bool
Whether or not to use automatic regularization with RidgeCV. Defaults
to False.

Returns
-------
@@ -433,9 +457,18 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
del endog, X
# Lütkepohl p75, about 5x faster than stated formula
if l2_reg != 0:
params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample, rcond=1e-15)[0]
# use pre-specified l2 regularization value
params = np.linalg.lstsq(
z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample,
rcond=1e-15
)[0]
elif auto_reg:
# use ridge regression with built-in cross-validation of alpha values
reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(z, y_sample)
params = reg.coef_.T
else:
# use OLS regression
params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0]

# (n_samples - lags, n_channels)
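For context, a minimal standalone sketch of the cross-validated ridge solve this PR adopts (the data here is hypothetical; the alpha grid mirrors the one used in the diff above):

import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(42)
z = rng.standard_normal((200, 10))       # lagged regressors, (n_samples, n_features)
y = z @ rng.standard_normal((10, 3))     # multi-output targets, (n_samples, n_targets)
y += 0.1 * rng.standard_normal(y.shape)  # additive observation noise

# RidgeCV scans the alpha grid with 5-fold cross-validation and keeps the best penalty
reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(z, y)
print(reg.alpha_)         # selected regularization strength
print(reg.coef_.T.shape)  # (n_features, n_targets), analogous to the lstsq solution

Note on the design choice: with an integer cv, RidgeCV fits each alpha on explicit train/validation splits; the faster leave-one-out (generalized cross-validation) path is only used when cv=None.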