Implement automatic ridge regression #124
base: main
@@ -1,6 +1,7 @@
 import numpy as np
 import scipy
 from scipy.linalg import sqrtm
+from sklearn.linear_model import RidgeCV
 from tqdm import tqdm
 from mne import BaseEpochs
@@ -13,7 +14,7 @@
 @verbose
 @fill_doc
 def vector_auto_regression(
-        data, times=None, names=None, lags=1, l2_reg=0.0,
+        data, times=None, names=None, lags=1, l2_reg=0.0, auto_reg=False,
         compute_fb_operator=False, model='dynamic', n_jobs=1, verbose=None):
     """Compute vector auto-regresssive (VAR) model.
@@ -31,6 +32,10 @@ def vector_auto_regression(
         Autoregressive model order, by default 1.
     l2_reg : float, optional
         Ridge penalty (l2-regularization) parameter, by default 0.0.
+    auto_reg : bool, optional
+        Whether to perform automatic regularization of the X matrix using
+        RidgeCV, by default False. If the matrix is not full rank, this will
+        be adjusted to True.
     compute_fb_operator : bool
         Whether to compute the backwards operator and average with
         the forward operator. Addresses bias in the least-square
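A minimal usage sketch of the new flag (not part of the diff); the mne_connectivity import path and the random (n_epochs, n_channels, n_times) data are assumptions made purely for illustration:

import numpy as np
from mne_connectivity import vector_auto_regression

rng = np.random.default_rng(0)
data = rng.standard_normal((10, 4, 100))  # (n_epochs, n_channels, n_times)

# existing behaviour: a fixed ridge penalty
conn_fixed = vector_auto_regression(data, lags=1, l2_reg=1e-3)

# new behaviour: let RidgeCV pick the penalty by cross-validation
conn_auto = vector_auto_regression(data, lags=1, auto_reg=True)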
@@ -121,6 +126,8 @@ def vector_auto_regression(
     if model not in ['avg-epochs', 'dynamic']:
         raise ValueError(f'"model" parameter must be one of '
                          f'(avg-epochs, dynamic), not {model}.')
+    elif auto_reg and l2_reg:
+        raise ValueError("If l2_reg is set, then auto_reg must be set to False")

     events = None
     event_id = None

Reviewer comment on the new check: Can you add a unit-test for this case? Lmk if you need help w/ that.
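A hedged sketch of the unit test requested above, assuming pytest and the mne_connectivity import path; the data shape is arbitrary:

import numpy as np
import pytest

from mne_connectivity import vector_auto_regression


def test_auto_reg_rejects_explicit_l2_reg():
    """Passing a fixed l2_reg together with auto_reg=True should raise."""
    rng = np.random.default_rng(0)
    data = rng.standard_normal((5, 3, 50))  # (n_epochs, n_channels, n_times)

    with pytest.raises(ValueError, match="auto_reg must be set to False"):
        vector_auto_regression(data, lags=1, l2_reg=0.5, auto_reg=True)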
@@ -151,9 +158,17 @@ def vector_auto_regression(
     # 1. determine shape of the window of data
     n_epochs, n_nodes, _ = data.shape

+    # determine condition of matrix across all epochs
+    conds = np.linalg.cond(data)
+
+    if np.any(conds > 1e6):
+        # matrix is rank-deficient, so regularization must be used
+        auto_reg = True
+
     model_params = {
         'lags': lags,
         'l2_reg': l2_reg,
+        'auto_reg': auto_reg
     }

     if verbose:
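As a quick illustration of the check above (not from the PR): np.linalg.cond broadcasts over the epoch axis of an (n_epochs, n_channels, n_times) array and returns one condition number per epoch, so a single ill-conditioned epoch is enough to switch auto_reg on. The data and the 1e6 threshold below are illustrative:

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((3, 4, 100))
# make epoch 0 nearly rank-deficient by almost duplicating a channel
data[0, 1] = data[0, 0] + 1e-10 * rng.standard_normal(100)

conds = np.linalg.cond(data)          # shape (3,): one value per epoch
auto_reg = bool(np.any(conds > 1e6))  # epoch 0's condition number is huge
print(conds, auto_reg)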
@@ -165,12 +180,15 @@ def vector_auto_regression(
         # sample of the multivariate time-series of interest
         # ordinary least squares or regularized least squares
         # (ridge regression)
-        X, Y = _construct_var_eqns(data, **model_params)
-
-        b, res, rank, s = scipy.linalg.lstsq(X, Y)
-
-        # get the coefficients
-        coef = b.transpose()
+        X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)
+
+        if auto_reg:
+            # use ridge regression with built-in cross validation of alpha values
+            reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(X, Y)
+            coef = reg.coef_
+        else:
+            b, res, rank, s = scipy.linalg.lstsq(X, Y)
+            coef = b.transpose()

         # create connectivity
         coef = coef.flatten()
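A side note on shapes (illustrative only, arbitrary sizes): for multi-output targets, RidgeCV.coef_ comes back as (n_targets, n_features), which is the same layout as b.transpose() from scipy.linalg.lstsq, so both branches in the hunk above hand the same layout to coef.flatten():

import numpy as np
import scipy.linalg
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6))  # (n_samples, n_features)
Y = rng.standard_normal((200, 3))  # (n_samples, n_targets)

b, *_ = scipy.linalg.lstsq(X, Y)   # b has shape (n_features, n_targets)
reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(X, Y)

print(b.T.shape, reg.coef_.shape)  # both (3, 6)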
@@ -187,8 +205,9 @@ def vector_auto_regression(
         # linear system
         A_mats = _system_identification(
             data=data, lags=lags,
-            l2_reg=l2_reg, n_jobs=n_jobs,
-            compute_fb_operator=compute_fb_operator)
+            l2_reg=l2_reg, auto_reg=auto_reg,
+            n_jobs=n_jobs, compute_fb_operator=compute_fb_operator
+        )
         # create connectivity
         if lags > 1:
             conn = EpochTemporalConnectivity(data=A_mats,
@@ -261,7 +280,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
             X[:n, i * lags + k -
               1] = np.reshape(data[:, i, lags - k:-k].T, n)

-    if l2_reg is not None:
+    if l2_reg:
         np.fill_diagonal(X[n:, :], l2_reg)

     # Construct vectors yi (response variables for each channel i)
@@ -272,7 +291,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
     return X, Y


-def _system_identification(data, lags, l2_reg=0,
+def _system_identification(data, lags, l2_reg=0, auto_reg=False,
                            n_jobs=-1, compute_fb_operator=False):
     """Solve system identification using least-squares over all epochs.

@@ -289,6 +308,7 @@ def _system_identification(data, lags, l2_reg=0,

     model_params = {
         'l2_reg': l2_reg,
+        'auto_reg': auto_reg,
         'lags': lags,
         'compute_fb_operator': compute_fb_operator
     }
@@ -346,7 +366,7 @@ def _system_identification(data, lags, l2_reg=0,
     return A_mats


-def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
+def _compute_lds_func(data, lags, l2_reg, auto_reg, compute_fb_operator):
     """Compute linear system using VAR model.

     Allows for parallelization over epochs.
@@ -372,20 +392,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
     # get time-shifted versions
     X = data[:, :]
     A, resid, omega = _estimate_var(X, lags=lags, offset=0,
-                                    l2_reg=l2_reg)
+                                    l2_reg=l2_reg, auto_reg=auto_reg)

     if compute_fb_operator:
         # compute backward linear operator
         # original method
         back_A, back_resid, back_omega = _estimate_var(
-            X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg)
+            X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, auto_reg=auto_reg
+        )
         A = sqrtm(A.dot(np.linalg.inv(back_A)))
         A = A.real  # remove numerical noise

     return A, resid, omega


-def _estimate_var(X, lags, offset=0, l2_reg=0):
+def _estimate_var(X, lags, offset=0, l2_reg=0, auto_reg=False):
     """Estimate a VAR model.

     Parameters
@@ -399,6 +420,9 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
         Used for order selection, so it's an apples-to-apples comparison
     l2_reg : int
         The amount of l2-regularization to use. Default of 0.
+    auto_reg : bool
+        Whether or not to use automatic regularization with RidgeCV. Defaults
+        to False.

     Returns
     -------
@@ -433,9 +457,18 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
     del endog, X
     # Lütkepohl p75, about 5x faster than stated formula
     if l2_reg != 0:
-        params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags),
-                                 z.T @ y_sample, rcond=1e-15)[0]
+        # use pre-specified l2 regularization value
+        params = np.linalg.lstsq(
+            z.T @ z + l2_reg * np.eye(n_equations * lags),
+            z.T @ y_sample,
+            rcond=1e-15
+        )[0]
+    elif auto_reg:
+        # use ridge regression with built-in cross validation of alpha values
+        reg = RidgeCV(alphas=np.logspace(-15, 0, 16), cv=5).fit(z, y_sample)
+        params = reg.coef_.T
     else:
+        # use OLS regression
         params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0]

     # (n_samples - lags, n_channels)
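A hedged sanity check (not part of the PR) that the l2_reg branch solves the ridge normal equations: the closed-form solution of (z.T z + l2_reg * I) b = z.T y should agree with sklearn's Ridge when the intercept is disabled. The names z and y_sample and the shapes below are illustrative:

import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
z = rng.standard_normal((300, 8))         # lagged design matrix
y_sample = rng.standard_normal((300, 4))  # responses
l2_reg = 0.5

# closed-form ridge solution, as in the l2_reg != 0 branch
params = np.linalg.lstsq(
    z.T @ z + l2_reg * np.eye(z.shape[1]), z.T @ y_sample, rcond=1e-15)[0]

# sklearn's penalized least squares with no intercept
reg = Ridge(alpha=l2_reg, fit_intercept=False).fit(z, y_sample)

print(np.allclose(params, reg.coef_.T))   # True up to numerical tolerance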