-
Notifications
You must be signed in to change notification settings - Fork 31
/
statefunctions.py
29 lines (24 loc) · 1.08 KB
/
statefunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Functions that respect typical key conventions
#
# i.e. f(**s) usually works, where s is the state dict of a covariance skater
import numpy as np
def oas(n_samples:int, pcov=None, scov=None, **ignore):
    """ Compute the Oracle Approximating Shrinkage (OAS) estimate of a covariance matrix.

    The empirical covariance is shrunk towards ``mu * I``, where ``mu`` is the
    average variance (trace / dimension). The shrinkage intensity follows
    scikit-learn's ``oas`` implementation (Chen et al.).

    :param n_samples: Number of observations the covariance was estimated from; must be > 1
    :param pcov: Covariance matrix normalized by n (square, array-like). Takes precedence over scov.
    :param scov: Covariance matrix normalized by n-1 (array-like); used only when pcov is None
    :param ignore: Any other state keys, accepted so that oas(**s) works; otherwise unused
    :return: Shrunk covariance matrix as a new ndarray (the input is not modified)
    """
    # In sklearn this is bundled with emp cov estimation, so we have to cut and paste a few lines. See
    # https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/covariance/_shrunk_covariance.py#L347
    assert n_samples > 1, 'Need n_samples>1'
    if pcov is None:
        assert scov is not None, 'Need pcov or scov to be supplied'
        # Rescale the (n-1)-normalized covariance to the n-normalized form used below
        pcov = (n_samples - 1) / n_samples * np.asarray(scov)
    else:
        # Coerce lists / np.matrix / frames to a plain ndarray so that ** is element-wise
        pcov = np.asarray(pcov)
    n_dim = pcov.shape[0]
    mu = np.trace(pcov) / n_dim  # average variance; the shrinkage target is mu * I
    alpha = np.mean(pcov ** 2)
    num = alpha + mu ** 2
    den = (n_samples + 1.0) * (alpha - (mu ** 2) / n_dim)
    # Mathematically den >= 0, with equality iff pcov is a multiple of the identity.
    # Rounding can push it slightly negative, which would yield a huge negative
    # shrinkage, so any non-positive den is treated as "fully shrink".
    shrinkage = 1.0 if den <= 0 else min(num / den, 1.0)
    b = (1.0 - shrinkage) * pcov  # new array; the caller's matrix is untouched
    b.flat[:: n_dim + 1] += shrinkage * mu  # stride n_dim+1 over the flat array hits the diagonal
    return b