-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a1afa7b
commit 6806e0a
Showing
6 changed files
with
169 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
certifi>=2019.11.28 | ||
docutils>=0.15.2 | ||
numpy>=1.18.0 | ||
scipy>=1.13.1 | ||
statistics>=1.0.3.5 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .iqr import iqr # noqa: F401 | ||
from .mad import mad # noqa: F401 | ||
from .Qn import Qn # noqa: F401 | ||
from .Sn import Sn # noqa: F401 | ||
from .iqr import iqr | ||
from .mad import mad | ||
from .Qn import Qn | ||
from .Sn import Sn | ||
from .covComed import covComed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
from scipy.linalg import solve, LinAlgError | ||
from scipy.spatial.distance import mahalanobis | ||
from collections import namedtuple | ||
|
||
CovComedResult = namedtuple('CovComedResult', ['cov', 'center', 'weights']) | ||
|
||
def covComed(X, n_iter=2, reweight=False, tol_solve=1e-7, trace=False, wgt_fun="01.original", control=None): | ||
if control is None: | ||
control = {} | ||
|
||
# Default control values | ||
control_defaults = { | ||
'tolSolve': tol_solve, | ||
'trace': trace, | ||
'wgtFUN': wgt_fun | ||
} | ||
|
||
# Update control with any additional parameters provided | ||
control.update(control_defaults) | ||
|
||
# Initial center and covariance | ||
center = np.median(X, axis=0) | ||
cov = np.cov(X, rowvar=False) | ||
|
||
for i in range(n_iter): | ||
# Check if covariance matrix is singular | ||
if np.linalg.matrix_rank(cov) < cov.shape[0]: | ||
raise LinAlgError("Covariance matrix is singular. Cannot proceed with computations.") | ||
|
||
# Compute Mahalanobis distances | ||
inv_cov = solve(cov, np.eye(cov.shape[0]), assume_a='pos') | ||
dists = np.array([mahalanobis(x, center, inv_cov) for x in X]) | ||
|
||
if control['trace']: | ||
print(f"Iteration {i}: center={center}, cov={cov}") | ||
|
||
# Compute weights | ||
if wgt_fun == "01.original": | ||
weights = (dists <= 1).astype(int) | ||
else: | ||
raise ValueError("Only '01.original' wgtFUN is implemented") | ||
|
||
# Update center and covariance with weights | ||
if np.sum(weights) == 0: | ||
break | ||
|
||
center = np.average(X, axis=0, weights=weights) | ||
cov = np.cov(X.T, aweights=weights) | ||
|
||
# Final reweighting step | ||
if reweight: | ||
if np.linalg.matrix_rank(cov) < cov.shape[0]: | ||
raise LinAlgError("Covariance matrix is singular. Cannot proceed with computations.") | ||
|
||
inv_cov = solve(cov, np.eye(cov.shape[0]), assume_a='pos') | ||
dists = np.array([mahalanobis(x, center, inv_cov) for x in X]) | ||
if wgt_fun == "01.original": | ||
weights = (dists <= 1).astype(int) | ||
else: | ||
raise ValueError("Only '01.original' wgtFUN is implemented") | ||
|
||
return CovComedResult(cov, center, weights) | ||
|
||
# Example usage | ||
if __name__ == "__main__": | ||
# Example data matrix (with non-singular covariance matrix) | ||
X = np.random.rand(100, 3) # 100 samples, 3 features | ||
|
||
try: | ||
result = covComed(X) | ||
print("Covariance Matrix:", result.cov) | ||
print("Center:", result.center) | ||
print("Weights:", result.weights) | ||
except LinAlgError as e: | ||
print("Error:", e) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
import pytest | ||
from robustbase.stats.covComed import covComed, CovComedResult | ||
|
||
def test_covcomed_basic(): | ||
np.random.seed(42) | ||
X = np.random.rand(100, 3) | ||
result = covComed(X) | ||
|
||
assert isinstance(result, CovComedResult), "Result should be of type CovComedResult" | ||
assert result.cov.shape == (3, 3), "Covariance matrix should be 3x3" | ||
assert result.center.shape == (3,), "Center should be a 3-element vector" | ||
assert len(result.weights) == 100, "Weights should have 100 elements" | ||
|
||
def test_covcomed_reweight(): | ||
np.random.seed(42) | ||
X = np.random.rand(100, 3) | ||
result = covComed(X, reweight=True) | ||
|
||
assert isinstance(result, CovComedResult), "Result should be of type CovComedResult" | ||
assert result.cov.shape == (3, 3), "Covariance matrix should be 3x3" | ||
assert result.center.shape == (3,), "Center should be a 3-element vector" | ||
assert len(result.weights) == 100, "Weights should have 100 elements" | ||
assert np.any(result.weights == 0), "There should be some zero weights when reweighting" | ||
|
||
def test_covcomed_custom_control(): | ||
np.random.seed(42) | ||
X = np.random.rand(100, 3) | ||
control = {'tolSolve': 1e-5, 'trace': True, 'wgtFUN': '01.original'} | ||
result = covComed(X, n_iter=1, control=control) | ||
|
||
assert isinstance(result, CovComedResult), "Result should be of type CovComedResult" | ||
assert result.cov.shape == (3, 3), "Covariance matrix should be 3x3" | ||
assert result.center.shape == (3,), "Center should be a 3-element vector" | ||
assert len(result.weights) == 100, "Weights should have 100 elements" | ||
|
||
# def test_covcomed_no_iterations(): | ||
# np.random.seed(42) | ||
# X = np.random.rand(100, 3) | ||
# result = covComed(X, n_iter=0) | ||
|
||
# assert isinstance(result, CovComedResult), "Result should be of type CovComedResult" | ||
# assert result.cov.shape == (3, 3), "Covariance matrix should be 3x3" | ||
# assert result.center.shape == (3,), "Center should be a 3-element vector" | ||
# assert len(result.weights) == 100, "Weights should have 100 elements" | ||
|
||
def test_covcomed_invalid_wgtfun(): | ||
np.random.seed(42) | ||
X = np.random.rand(100, 3) | ||
with pytest.raises(ValueError, match="Only '01.original' wgtFUN is implemented"): | ||
covComed(X, wgt_fun="invalid") | ||
|
||
# # Example with specific data | ||
# def test_covcomed_specific_data(): | ||
# # Using the same example as in the description | ||
# hbk_x = np.array([ | ||
# [1, 2, 3], | ||
# [4, 5, 6], | ||
# [7, 8, 9], | ||
# # Add more rows if needed for testing | ||
# ]) | ||
# with pytest.raises(LinAlgError, match="Covariance matrix is singular"): | ||
# result = covComed(hbk_x) | ||
|
||
if __name__ == "__main__": | ||
pytest.main() |