Skip to content

Commit

Permalink
feat: accept pca/kmeans kwargs in train/test split
Browse files Browse the repository at this point in the history
  • Loading branch information
martibosch committed Mar 28, 2024
1 parent 0c9da36 commit 0729a3d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 11 deletions.
44 changes: 33 additions & 11 deletions detectree/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ def train_test_split(
self,
*,
method="cluster-II",
num_components=12,
n_components=12,
num_img_clusters=4,
train_prop=0.01,
return_evr=False,
pca_kwargs=None,
kmeans_kwargs=None,
):
"""
Select the image/tiles to be used for training.
Expand All @@ -167,7 +169,7 @@ def train_test_split(
----------
method : {'cluster-I', 'cluster-II'}, optional (default 'cluster-II')
Method used in the train/test split.
num_components : int, optional (default 12)
n_components : int, default 12
Number of principal components into which the image descriptors should be
represented when applying the *k*-means clustering.
num_img_clusters : int, optional (default 4)
Expand All @@ -178,6 +180,12 @@ def train_test_split(
return_evr : bool, optional (default False)
Whether the explained variance ratio of the principal component
analysis should be returned.
pca_kwargs : dict, optional
Keyword arguments to be passed to the `sklearn.decomposition.PCA` class
constructor (except for `n_components`).
kmeans_kwargs : dict, optional
Keyword arguments to be passed to the `sklearn.cluster.KMeans` class
constructor (except for `n_clusters`).
Returns
-------
Expand All @@ -187,10 +195,16 @@ def train_test_split(
Explained variance ratio of the principal component analysis.
"""
X = self.descr_feature_matrix
pca = decomposition.PCA(n_components=num_components).fit(X)
if pca_kwargs is None:
_pca_kwargs = {}
else:
_pca_kwargs = pca_kwargs.copy()
# if `n_components` is provided in `pca_kwargs`, it will be ignored
_ = _pca_kwargs.pop("n_components", None)
pca = decomposition.PCA(n_components=n_components, **_pca_kwargs).fit(X)

X_pca = pca.transform(X)
X_cols = range(num_components)
X_cols = range(n_components)
df = pd.concat(
(
pd.Series(self.img_filepaths, name="img_filepath"),
Expand All @@ -199,10 +213,16 @@ def train_test_split(
axis=1,
)

if kmeans_kwargs is None:
_kmeans_kwargs = {}
else:
_kmeans_kwargs = kmeans_kwargs.copy()
# if `n_clusters` is provided in `kmeans_kwargs`, it will be ignored
_ = _kmeans_kwargs.pop("n_clusters", None)
if method == "cluster-I":
km = cluster.KMeans(n_clusters=int(np.ceil(train_prop * len(df)))).fit(
X_pca
)
km = cluster.KMeans(
n_clusters=int(np.ceil(train_prop * len(df))), **_kmeans_kwargs
).fit(X_pca)
closest, _ = metrics.pairwise_distances_argmin_min(
km.cluster_centers_, df[X_cols]
)
Expand All @@ -216,16 +236,18 @@ def cluster_train_test_split(img_cluster_ser):
# use `ceil` to avoid zeros, which might completely ignore a significant
# image cluster
num_train = int(np.ceil(train_prop * len(X_cluster_df)))
cluster_km = cluster.KMeans(n_clusters=num_train).fit(X_cluster_df)
cluster_km = cluster.KMeans(n_clusters=num_train, **_kmeans_kwargs).fit(
X_cluster_df
)
closest, _ = metrics.pairwise_distances_argmin_min(
cluster_km.cluster_centers_, X_cluster_df
)
train_idx = X_cluster_df.iloc[closest].index
return [True if i in train_idx else False for i in X_cluster_df.index]

df["img_cluster"] = cluster.KMeans(n_clusters=num_img_clusters).fit_predict(
X_pca
)
df["img_cluster"] = cluster.KMeans(
n_clusters=num_img_clusters, **_kmeans_kwargs
).fit_predict(X_pca)
df["train"] = df.groupby("img_cluster")["img_cluster"].transform(
cluster_train_test_split
)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_detectree.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,31 @@ def test_train_test_split(self):
self.assertIsInstance(split_df, pd.DataFrame)
self.assertIsInstance(evr, float)

# test pca n_components and kwargs
# evr is greater or equal with more components (given the same seed)
random_state = 42
self.assertGreaterEqual(
*(
ts.train_test_split(
return_evr=True,
pca_kwargs={
"n_components": n_components,
"random_state": random_state,
},
)[1]
for n_components in (4, 2)
)
)
# test kwargs for kmeans too
# the result should be the same with the same seed
tts_kwargs = dict(
pca_kwargs={"random_state": random_state},
kmeans_kwargs={"random_state": random_state},
)
self.assertTrue(
ts.train_test_split(**tts_kwargs).equals(ts.train_test_split(**tts_kwargs))
)


class TestImageDescriptor(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 0729a3d

Please sign in to comment.