Skip to content

Commit

Permalink
Merge pull request #36 from felixriese/fix-type-hints
Browse files Browse the repository at this point in the history
FIX type hints
  • Loading branch information
felixriese committed Nov 19, 2022
2 parents 4db65c9 + c7d5b76 commit a4564ee
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 19 deletions.
2 changes: 1 addition & 1 deletion susi/SOMClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def predict_proba(
def _modify_weight_matrix_supervised(
self,
dist_weight_matrix: np.ndarray,
true_vector: Optional[np.array] = None,
true_vector: Optional[np.ndarray] = None,
learning_rate: Optional[float] = None,
) -> np.ndarray:
"""Modify weight matrix of the SOM.
Expand Down
7 changes: 4 additions & 3 deletions susi/SOMClustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def _init_unsuper_som(self) -> None:

else:
raise ValueError(
f"Invalid init_mode_unsupervised: {self.init_mode_unsupervised}."
f"Invalid init_mode_unsupervised: "
f"{self.init_mode_unsupervised}."
)

self.unsuper_som_ = som
Expand Down Expand Up @@ -372,7 +373,7 @@ def get_bmu(
return np.argwhere(a == np.min(a))[0]

def get_bmus(
self, X: np.ndarray, som_array: Optional[np.array] = None
self, X: np.ndarray, som_array: Optional[np.ndarray] = None
) -> Optional[List[Tuple[int, int]]]:
"""Get Best Matching Units for big datalist.
Expand Down Expand Up @@ -448,7 +449,7 @@ def _partition_bmus(
return n_jobs, n_datapoints_per_job.tolist(), [0] + starts.tolist()

def _set_bmus(
self, X: np.ndarray, som_array: Optional[np.array] = None
self, X: np.ndarray, som_array: Optional[np.ndarray] = None
) -> None:
"""Set BMUs in the current SOM object.
Expand Down
25 changes: 11 additions & 14 deletions susi/SOMEstimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import numpy as np
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -252,9 +252,7 @@ def _fit_estimator(self):

return self

def predict(
self, X: Sequence, y: Optional[Sequence] = None
) -> List[float]:
def predict(self, X: Sequence, y: Optional[Sequence] = None) -> np.ndarray:
"""Predict output of data X.
Parameters
Expand Down Expand Up @@ -346,7 +344,7 @@ def _calc_proba(self, bmu_pos: Tuple[int, int]) -> np.ndarray:
def _modify_weight_matrix_supervised(
self,
dist_weight_matrix: np.ndarray,
true_vector: Optional[np.array] = None,
true_vector: Optional[np.ndarray] = None,
learning_rate: Optional[float] = None,
) -> np.ndarray:
"""Modify weights of the supervised SOM, either online or batch.
Expand Down Expand Up @@ -406,8 +404,8 @@ def _train_supervised_som(self):
):

# select one input vector & calculate best matching unit (BMU)
dp = self._get_random_datapoint()
bmu_pos = self.bmus_[dp]
dp_index = self._get_random_datapoint_index()
bmu_pos = self.bmus_[dp_index]

# calculate learning rate and neighborhood function
learning_rate = self._calc_learning_rate(
Expand All @@ -423,7 +421,7 @@ def _train_supervised_som(self):
)
self.super_som_ = self._modify_weight_matrix_supervised(
dist_weight_matrix=dist_weight_matrix,
true_vector=self.y_[self.labeled_indices_][dp],
true_vector=self.y_[self.labeled_indices_][dp_index],
learning_rate=learning_rate,
)

Expand Down Expand Up @@ -504,18 +502,17 @@ def get_estimation_map(self) -> np.ndarray:
"""
return self.super_som_

def _get_random_datapoint(self) -> np.ndarray:
"""Find and return random datapoint from labeled dataset.
def _get_random_datapoint_index(self) -> int:
"""Find and return random datapoint index from labeled dataset.
Returns
-------
random_datapoint : np.ndarray
Random datapoint from labeled dataset
int
Random datapoint index from labeled dataset
"""
random_datapoint = None
if self.missing_label_placeholder is not None:
random_datapoint = np.random.choice(
random_datapoint: int = np.random.choice(
len(self.y_[self.labeled_indices_])
)
else:
Expand Down
2 changes: 1 addition & 1 deletion susi/SOMUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def decreasing_rate(

def check_estimation_input(
X: Sequence, y: Sequence, *, is_classification: bool = False
) -> Tuple[np.array, np.array]:
) -> Tuple[np.ndarray, np.ndarray]:
"""Check input arrays.
This function is adapted from sklearn.utils.validation.
Expand Down

0 comments on commit a4564ee

Please sign in to comment.