Skip to content

Commit

Permalink
add stein kernel and some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
gabelstein committed Apr 15, 2024
1 parent e7de8dc commit 0608787
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 110 deletions.
188 changes: 87 additions & 101 deletions pyriemann/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,16 @@ def kernel_canonical(X, Y=None, *,
A. Barachant, S. Bonnet, M. Congedo and C. Jutten. Neurocomputing,
Elsevier, 2013, 112, pp.172-178.
"""
try:
return globals()[f'kernel_{metric}'](X, Y, Cref=Cref, reg=reg)
except KeyError:
raise ValueError(
"Kernel metric must be 'euclid', 'logeuclid', or 'riemann' for"
f" canonical kernel. Got {metric}.")
kfunc = check_function(metric, _canonical_kernels)
return kfunc(X, Y, Cref=Cref, reg=reg)


_canonical_kernels = {
'euclid': kernel_euclid,
'logeuclid': kernel_logeuclid,
'riemann': kernel_riemann
}

###############################################################################
'''Distance Kernels.'''

Expand All @@ -216,7 +218,7 @@ def wrapper(X, Y=None, *, metric='riemann', reg=1e-10, **kwargs):
return wrapper


def kernel_gaussian(X, Y=None, *, metric='riemann', reg=0, gamma=1):
def kernel_gaussian(X, Y=None, *, metric='riemann', reg=0, gamma=1, **kwargs):
r"""Gaussian kernel between two sets of SPD matrices.
Calculates the Gaussian kernel matrix :math:`\mathbf{K}` of inner products
Expand Down Expand Up @@ -254,12 +256,14 @@ def kernel_gaussian(X, Y=None, *, metric='riemann', reg=0, gamma=1):
--------
kernel
"""
K = _distance_kernel(_exponential, squared=True)(X, Y, metric=metric,
gamma=-gamma, reg=reg)
K = _distance_kernel(_exponential, squared=True)(X, Y,
metric=metric,
gamma=-gamma,
reg=reg)
return K


def kernel_laplacian(X, Y=None, *, metric='riemann', gamma=1, reg=0):
def kernel_laplacian(X, Y=None, *, metric='riemann', gamma=1, reg=0, **kwargs):
"""
Laplacian kernel between two sets of SPD matrices.
Expand Down Expand Up @@ -297,13 +301,15 @@ def kernel_laplacian(X, Y=None, *, metric='riemann', gamma=1, reg=0):
--------
kernel
"""
K = _distance_kernel(_exponential, squared=False)(X, Y, metric=metric,
gamma=-gamma, reg=reg)
K = _distance_kernel(_exponential, squared=False)(X, Y,
metric=metric,
gamma=-gamma,
reg=reg)
return K


def kernel_rational_quadratic(X, Y=None, *, metric='riemann', alpha=1, l=1,
reg=0):
def kernel_rational_quadratic(X, Y=None, *,
metric='riemann', alpha=1, l=1, reg=0, **kwargs):
"""
Rational quadratic kernel between two sets of SPD matrices.
Expand Down Expand Up @@ -346,15 +352,16 @@ def kernel_rational_quadratic(X, Y=None, *, metric='riemann', alpha=1, l=1,
--------
kernel
"""
K = _distance_kernel(_rational_quadratic, squared=True)(X, Y, metric=metric,
K = _distance_kernel(_rational_quadratic, squared=True)(X, Y,
metric=metric,
alpha=alpha,
reg=reg,
l=l)
return K


def kernel_multiquadratic(X, Y=None, *,
metric='riemann', beta=1, sigma=1, reg=0):
metric='riemann', beta=1, sigma=1, reg=0, **kwargs):
"""
Multiquadratic kernel between two sets of SPD matrices.
Expand Down Expand Up @@ -408,7 +415,11 @@ def kernel_multiquadratic(X, Y=None, *,


def kernel_inverse_multiquadratic(X, Y=None, *,
metric='riemann', beta=1, sigma=1, reg=0):
metric='riemann',
beta=1,
sigma=1,
reg=0,
**kwargs):
"""
Inverse multiquadratic kernel between two sets of SPD matrices.
Expand Down Expand Up @@ -466,7 +477,7 @@ def kernel_inverse_multiquadratic(X, Y=None, *,

def _inner_product_kernel(func):
def wrapper(X, Y=None, *, Cref=None, reg=1e-10, metric='riemann', **kwargs):
feature_map = globals()[f'_{metric}']
feature_map = check_function(metric, _feature_maps)
K = _apply_matrix_kernel(feature_map, X, Y,
Cref=Cref, reg=0, metric=metric)
K = func(K, **kwargs)
Expand All @@ -476,7 +487,12 @@ def wrapper(X, Y=None, *, Cref=None, reg=1e-10, metric='riemann', **kwargs):


def kernel_polynomial(X, Y=None, *,
Cref=None, reg=10e-10, metric='riemann', r=0, s=1):
Cref=None,
reg=10e-10,
metric='riemann',
r=0,
s=1,
**kwargs):
"""Polynomial kernel between two sets of SPD matrices.
Calculates the polynomial kernel matrix :math:`\mathbf{K}` of inner products
Expand Down Expand Up @@ -527,7 +543,7 @@ def kernel_polynomial(X, Y=None, *,


def kernel_exponential(X, Y=None, *,
Cref=None, reg=0, metric='riemann', gamma=1):
Cref=None, reg=0, metric='riemann', gamma=1, **kwargs):
"""Exponential kernel between two sets of SPD matrices.
Calculates the exponential kernel matrix :math:`\mathbf{K}` of inner
Expand Down Expand Up @@ -575,8 +591,9 @@ def kernel_exponential(X, Y=None, *,
return K


def kernel_sigmoid(X, Y=None, *, Cref=None, reg=10e-10, metric='riemann',
gamma=1, r=0):
def kernel_sigmoid(X, Y=None, *,
Cref=None, reg=10e-10, metric='riemann', gamma=1, r=0,
**kwargs):
"""Sigmoid kernel between two sets of SPD matrices.
Calculates the sigmoid kernel matrix :math:`\mathbf{K}` of inner products
Expand Down Expand Up @@ -671,7 +688,7 @@ def kernel_frobenius(X, Y=None, *, reg=1e-10, **kwargs):
def kernel_logfrobenius(X, Y=None, *, reg=1e-10, **kwargs):
r"""Log-Frobenius kernel between two sets of SPD matrices.
Calculates the Log-Euclidean kernel matrix :math:`\mathbf{K}` of inner
Calculates the Log-Frobenius kernel matrix :math:`\mathbf{K}` of inner
products of two sets :math:`\mathbf{X}` and :math:`\mathbf{Y}` of SPD
matrices in :math:`\mathbb{R}^{n \times n}` by calculating pairwise
products [1]_:
Expand All @@ -692,11 +709,11 @@ def kernel_logfrobenius(X, Y=None, *, reg=1e-10, **kwargs):
Returns
-------
K : ndarray, shape (n_matrices_X, n_matrices_Y)
The Log-Euclidean kernel matrix between X and Y.
The Log-Frobenius kernel matrix between X and Y.
Notes
-----
.. versionadded:: 0.3
.. versionadded:: 0.7
See Also
--------
Expand All @@ -715,16 +732,17 @@ def kernel_logfrobenius(X, Y=None, *, reg=1e-10, **kwargs):
return K


def kernel_determinant(X, Y=None, *, reg=1e-10, **kwargs):
r"""Determinant kernel between two sets of SPD matrices.
def kernel_stein(X, Y=None, *, reg=1e-10, beta=1, c=1, **kwargs):
r"""Stein kernel between two sets of SPD matrices.
Calculates the determinant kernel matrix :math:`\mathbf{K}` of inner
products of two sets :math:`\mathbf{X}` and :math:`\mathbf{Y}` of SPD
matrices in :math:`\mathbb{R}^{n \times n}` by calculating pairwise
products:
Calculates the Stein kernel matrix :math:`\mathbf{K}` of inner products of
two sets :math:`\mathbf{X}` and :math:`\mathbf{Y}` of SPD matrices in
:math:`\mathbb{R}^{n \times n}` by calculating pairwise products [1]_:
.. math::
\mathbf{K}_{i,j} = \text{det}(\mathbf{X}_i \mathbf{Y}_j)
\mathbf{K}_{i,j} = 2^{c \beta} \left( \frac{\det(\mathbf{X}_i)^{\beta}
\det(\mathbf{Y}_j)^{\beta}}{\det(\mathbf{X}_i + \mathbf{Y}_j)^{\beta}}
\right)^{1/2}
Parameters
----------
Expand All @@ -735,67 +753,43 @@ def kernel_determinant(X, Y=None, *, reg=1e-10, **kwargs):
reg : float, default=1e-10
Regularization parameter to mitigate numerical errors in kernel
matrix estimation.
beta : float, default=1
Kernel parameter.
c : float, default=1
Kernel parameter.
Returns
-------
K : ndarray, shape (n_matrices_X, n_matrices_Y)
The determinant kernel matrix
The Stein kernel matrix between X and Y.
Notes
-----
.. versionadded:: 0.6
.. versionadded:: 0.7
See Also
--------
kernel
"""

K = _apply_matrix_kernel(_det, X, Y, reg=reg)
return K


def kernel_row_feature(X, Y=None, *,
Cref=None,
kernel_fct=np.exp,
kernel_parameters=None,
**kwargs):
"""Row feature kernel between two sets of SPD matrices.
Calculates the row feature kernel matrix :math:`\mathbf{K}` of inner
products of two sets :math:`\mathbf{X}` and :math:`\mathbf{Y}` of SPD
matrices in :math:`\mathbb{R}^{n \times n}` by calculating pairwise
products [1]_:
.. math::
\mathbf{K}_{i,j} = \sum_{k=1}^n \exp(-\gamma \text{dist}(\mathbf{X}_{i,k},
\mathbf{Y}_{j,k})^2)
Parameters
References
----------
X : ndarray, shape (n_matrices_X, n, n)
.. [1] `Sparse coding and dictionary learning for symmetric positive definite
matrices: A kernel approach
<https://link.springer.com/chapter/10.1007/978-3-642-33709-3_16>`_
M. T. Harandi, C. Sanderson, R. Hartley, and B. C. Lovell, ECCV, 2012,
pp. 216–229
"""

n_matrices_X, n, n = X.shape
C12inv = invsqrtm(Cref)
X = C12inv @ X @ C12inv
if Y is None:
Y = X
if Y is None or np.array_equal(X, Y):
X_ = _det(X)
Y_ = X_
else:
Y = C12inv @ Y @ C12inv

n_matrices_Y, n, n = Y.shape
X_, Y_ = _det(X), _det(Y)

full_res = np.zeros((n_matrices_X, n_matrices_Y))

for i, dat_ in enumerate(X):
res = Y - dat_
res = np.linalg.norm(res, axis=-1) ** 2
res = kernel_fct(res, **kernel_parameters)
res = np.sum(res, axis=-1)
full_res[i] = res

return full_res
frac = np.sqrt((X_[:, None]* Y_) ** beta ) / (X_[:, None] + Y_)**beta
K = 2**(c*beta) * frac
K = _regularize_kernel(K, reg=reg)
return K


###############################################################################
Expand Down Expand Up @@ -829,6 +823,14 @@ def _det(X, Cref=None):
return np.linalg.det(X)


_feature_maps = {
'log': _log,
'logeuclid': _logeuclid,
'riemann': _riemann,
'euclid': _euclid,
'det': _det
}

###############################################################################
'''Kernel functions.'''

Expand All @@ -848,12 +850,6 @@ def _sigmoid(K, gamma=1, r=0):
return np.tanh(gamma * K + r)


# might be wrong
def _periodic(K, gamma=1, l=1):
"""Periodic function."""
return np.exp(-2 * np.sin(np.pi * K / gamma) ** 2/l**2)


def _rational_quadratic(K, alpha=1, l=1):
"""Rational quadratic function."""
return (1 + K / (2 * alpha*l**2)) ** (-alpha)
Expand All @@ -875,16 +871,14 @@ def _inverse_multiquadratic(K, beta=1, sigma=1):
def _check_dimensions(X, Y, Cref):
"""Check for matching dimensions in X, Y and Cref."""
if not isinstance(Y, type(None)):
assert Y.shape[1:] == X.shape[1:], f"Dimension of matrices in Y must "\
f"match dimension of matrices in " \
f"X. Expected {X.shape[1:]}, got " \
f"{Y.shape[1:]}."
msg = f"Dimension of matrices in Y must match dimension of matrices "\
f"in X. Expected {X.shape[1:]}, got {Y.shape[1:]}."
assert Y.shape[1:] == X.shape[1:], msg

if not isinstance(Cref, type(None)):
assert Cref.shape == X.shape[1:], f"Dimension of Cref must match " \
f"dimension of matrices in X. " \
f"Expected {X.shape[1:]}, got " \
f"{Cref.shape}."
msg = f"Dimension of Cref must match dimension of matrices in X. "\
f"Expected {X.shape[1:]}, got {Cref.shape}."
assert Cref.shape == X.shape[1:], msg


def _regularize_kernel(K, reg=1e-10):
Expand Down Expand Up @@ -982,7 +976,7 @@ class Gram(BaseEstimator, TransformerMixin):
metric : str
The metric to use to compute the mean. See
:func:`pyriemann.utils.mean.mean_covariance` for available options.
kernel : str
kernel_fct : callable
The kernel to use to compute the gram matrix. See
:func:`pyriemann.utils.kernel.kernel` for available options.
Expand Down Expand Up @@ -1010,17 +1004,9 @@ def fit(self, X, y=None):
return self

def transform(self, X, y=None):
if not hasattr(self, 'data_'):
self.data_ = X
self.Cref = mean_covariance(self.data_, metric=self.metric)
gram = self.kernel_fct(X, self.data_, Cref=self.Cref)
return gram

def fit_transform(self, X, y=None):
gram = self.fit(X, y).transform(X, y)

return gram


kernel_types = {
'canonical': kernel_canonical,
Expand All @@ -1031,8 +1017,8 @@ def fit_transform(self, X, y=None):
'exponential': kernel_exponential,
'sigmoid': kernel_sigmoid,
'logfrobenius': kernel_logfrobenius,
'row_feature': kernel_row_feature,
'inverse_multiquadratic': kernel_inverse_multiquadratic,
'determinant': kernel_determinant,
'stein': kernel_stein,
'multiquadratic': kernel_multiquadratic
}

Loading

0 comments on commit 0608787

Please sign in to comment.