Skip to content

Commit

Permalink
Merge pull request #40 from felixriese/fix-quantization-error
Browse files Browse the repository at this point in the history
FIX quantization error calculation
  • Loading branch information
felixriese committed Jul 8, 2023
2 parents 005288d + 91b3ccd commit 7d090af
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 24 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Expand Up @@ -3,6 +3,7 @@

# for formatting
black==22.10.0
isort==5.12.0

# for the documentation
sphinx>=4.5.0
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Expand Up @@ -21,3 +21,10 @@ exclude = '''
| .pytest_cache
)/
'''

[tool.isort]
profile = "black"

[tool.pydocstyle]
convention = "numpy"
match = 'susi/*'
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -4,7 +4,7 @@ numpy>=1.18.5
scikit-learn>=0.21.1
scipy>=1.3.1
tqdm>=4.45.0
matplotlib>=3.3.0
matplotlib==3.7.2

# for the examples
notebook>=6.0.0
Expand Down
25 changes: 21 additions & 4 deletions susi/SOMClustering.py
Expand Up @@ -525,6 +525,7 @@ def _get_node_distance_matrix(
np.divide(
np.dot(som_array[node[0], node[1]], datapoint),
np.multiply(
# TODO check if an axis needs to be set here
np.linalg.norm(som_array),
np.linalg.norm(datapoint),
),
Expand Down Expand Up @@ -927,6 +928,24 @@ def _get_node_neighbors(
)
return list(itertools.product(row_range, column_range))

def _get_weights_per_datapoint(self, datapoints: Sequence) -> list:
"""Get SOM weights per datapoint.

Looks up, for every position returned by `self.get_bmus(datapoints)`,
the corresponding entry of `self.unsuper_som_`.

Parameters
----------
datapoints : array-like matrix
Samples of shape = [n_samples, n_features].

Returns
-------
list
One entry of `self.unsuper_som_` per position returned by
`self.get_bmus(datapoints)`, in the same order.
"""
return [
self.unsuper_som_[bmu[0], bmu[1]]
for bmu in self.get_bmus(datapoints)
]

def get_quantization_error(self, X: Optional[Sequence] = None) -> float:
"""Get quantization error for `X` (or the training data).
Expand All @@ -953,12 +972,10 @@ def get_quantization_error(self, X: Optional[Sequence] = None) -> float:
if X is None:
X = self.X_

weights_per_datapoint = [
self.unsuper_som_[bmu[0], bmu[1]] for bmu in self.get_bmus(X)
]
weights_per_datapoint = self._get_weights_per_datapoint(X)

quantization_errors = np.linalg.norm(
np.subtract(weights_per_datapoint, X)
np.subtract(weights_per_datapoint, X), axis=1
)

return np.mean(quantization_errors)
6 changes: 3 additions & 3 deletions susi/SOMPlots.py
Expand Up @@ -2,8 +2,8 @@

from typing import List, Tuple

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


Expand Down Expand Up @@ -47,7 +47,7 @@ def plot_estimation_map(
for label in cbar.ax.xaxis.get_ticklabels()[::2]:
label.set_visible(False)

plt.grid(b=False)
plt.grid(visible=False)

return ax

Expand Down Expand Up @@ -130,7 +130,7 @@ def plot_som_histogram(
# to be compatible with plt.imshow:
ax.invert_yaxis()

plt.grid(b=False)
plt.grid(visible=False)

return ax

Expand Down
3 changes: 2 additions & 1 deletion test-requirements.txt
Expand Up @@ -2,5 +2,6 @@ codecov>=2.1.10
coverage>=5.3
flake8==5.0.4
nbval>=0.9.5
pytest>=6.0.1
pytest>=7.2.0
pytest-cov>=2.10.1
pytest_mock==3.11.1
64 changes: 49 additions & 15 deletions tests/test_SOMClustering.py
Expand Up @@ -4,18 +4,19 @@
python -m pytest tests/test_SOMClustering.py
"""
import pytest
import os
import sys

import numpy as np
import pytest
from sklearn.datasets import make_biclusters

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
)
import susi
import susi # noqa

X, _, _ = make_biclusters((100, 10), 3)
test_data, _, _ = make_biclusters((100, 10), 3)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -665,27 +666,60 @@ def test_get_datapoints_from_node(
)
def test_get_u_matrix(n_rows, n_columns, mode):
som = susi.SOMClustering(n_rows=n_rows, n_columns=n_columns)
som.fit(X)
som.fit(test_data)
u_matrix = som.get_u_matrix(mode=mode)
assert isinstance(u_matrix, np.ndarray)
assert u_matrix.shape == (n_rows * 2 - 1, n_columns * 2 - 1, 1)


def test_get_clusters():
som = susi.SOMClustering()
som.fit(X)
clusters = som.get_clusters(X)
assert len(clusters) == len(X)
som.fit(test_data)
clusters = som.get_clusters(test_data)
assert len(clusters) == len(test_data)
assert len(clusters[0]) == 2


def test_get_quantization_error():
# given
som = susi.SOMClustering()
som.fit(X)
class TestGetQuantizationError:
def test_with_default_data(self) -> None:
# given
som = susi.SOMClustering()
som.fit(test_data)

# when
qerror = som.get_quantization_error()

# then
assert qerror < 0.05

def test_with_explicit_data(self, mocker) -> None:
# given
som = susi.SOMClustering()
X = np.array(
[
[8, 9, 7, 8, 6, 3, 0, 4, 1, 6],
[3, 2, 1, 2, 5, 7, 0, 0, 3, 3],
[0, 7, 5, 1, 4, 5, 7, 5, 8, 5],
[9, 1, 1, 7, 5, 9, 8, 9, 3, 3],
[2, 7, 7, 2, 3, 3, 3, 3, 3, 3],
]
)
som.fit(X)
weights_per_datapoint = [
[8, 5, 0, 4, 5, 7, 0, 2, 0, 2],
[4, 0, 7, 4, 8, 0, 4, 2, 2, 8],
[5, 3, 1, 2, 7, 9, 9, 8, 8, 7],
[8, 6, 8, 7, 7, 5, 4, 7, 1, 1],
[6, 2, 3, 5, 7, 5, 9, 5, 9, 0],
]
mocker.patch.object(
som,
"_get_weights_per_datapoint",
return_value=weights_per_datapoint,
)

# when
qerror = som.get_quantization_error()
# when
qerror = som.get_quantization_error(X)

# then
assert qerror < 0.05
# then
assert qerror == 11.45650021348017

0 comments on commit 7d090af

Please sign in to comment.