Permalink
Browse files

add support for color images in conv kmeans feature extraction

  • Loading branch information...
1 parent 79d9d7b commit 38d1ac1d99e048d645bb5eeb1e15662529df3588 @ogrisel committed Jan 7, 2011
@@ -94,9 +94,9 @@
X_train = X_train.reshape((X_train.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)
X_test = X_test.reshape((X_test.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)
-# convert to graylevel images for now
-X_train = X_train.mean(axis=-1)
-X_test = X_test.mean(axis=-1)
+## convert to graylevel images for now
+#X_train = X_train.mean(axis=-1)
+#X_test = X_test.mean(axis=-1)
#pl.imshow(X_train[0], interpolation='nearest'); pl.show()
# scale dataset
@@ -145,7 +145,7 @@
pl.figure()
for i in range(n_row * n_col):
pl.subplot(n_row, n_col, i + 1)
- pl.imshow(filters[i].reshape((patch_size, patch_size)),
+ pl.imshow(filters[i].reshape((patch_size, patch_size, 3)),
cmap=pl.cm.gray, interpolation="nearest")
pl.xticks(())
pl.yticks(())
@@ -117,9 +117,6 @@ def extract_patches2d(images, image_size, patch_size, offsets=(0, 0)):
The extracted patches are not overlapping to avoid having to copy any
memory.
- TODO: right now images are graylevel only: add channels to the right
- position according to the natural memory layout from PIL or scikits.image
-
Parameters
----------
images: array with shape (n_images, i_h, i_w) or (n_images, i_h * i_w)
@@ -180,12 +177,13 @@ def extract_patches2d(images, image_size, patch_size, offsets=(0, 0)):
array([[ 9, 10],
[13, 14]])
"""
- i_h, i_w = image_size
+ i_h, i_w = image_size[:2]
p_h, p_w = patch_size
images = np.atleast_2d(images)
n_images = images.shape[0]
- images = images.reshape((n_images, i_h, i_w))
+ images = images.reshape((n_images, i_h, i_w, -1))
+ n_colors = images.shape[-1]
# handle offsets and compute remainder to find total number of patches
o_h, o_w = offsets
@@ -196,16 +194,27 @@ def extract_patches2d(images, image_size, patch_size, offsets=(0, 0)):
# extract the image areas that can be sliced into whole patches
max_h = -r_h or None
max_w = -r_w or None
- images = images[:, o_h:max_h, o_w:max_w]
+ images = images[:, o_h:max_h, o_w:max_w, :]
+
+ # put the color dim before the sliceable dims
+ images = images.transpose((0, 3, 1, 2))
- # slice the images into patches
- patches = images.reshape((n_images, n_h, p_h, n_w, p_w))
+ # slice the last two dims of the images into patches
+ patches = images.reshape((n_images, n_colors, n_h, p_h, n_w, p_w))
- # reorganize the patches into the expected shape
- patches = patches.transpose((2, 4, 0, 1, 3)).reshape((p_h, p_w, n_patches))
+ # reorganize the dims to put n_images, n_h, and n_w at the end so that
+ # reshape will combine them all in n_patches
+ patches = patches.transpose((3, 5, 1, 0, 2, 4))
+ patches = patches.reshape((p_h, p_w, n_colors, n_patches))
# one more transpose to put the n_patches as the first dim
- return patches.transpose((2, 0, 1))
+ patches = patches.transpose((3, 0, 1, 2))
+
+ # remove the color dimension if useless
+ if patches.shape[-1] == 1:
+ return patches.reshape((n_patches, p_h, p_w))
+ else:
+ return patches
class ConvolutionalKMeansEncoder(BaseEstimator):
@@ -283,10 +292,8 @@ def _check_images(self, X):
n_samples = X.shape[0]
if self.image_size is None:
- if len(X.shape) == 3:
- self.image_size = X.shape[1:]
- elif len(X.shape) > 3:
- raise ValueError("%r is not a valid images shape" % (X.shape,))
+ if len(X.shape) > 3:
+ self.image_size = X.shape[1:3]
else:
# assume square images
_, n_features = X.shape

0 comments on commit 38d1ac1

Please sign in to comment.