median pooling #3752
Comments
I came up with an implementation:

    import cntk as C

    def median_filter(image, window, pad=True):
        # image is expected to have 3 static axes: (C, H, W)
        assert len(image.shape) == 3
        if type(window) is int:
            window = (window, window)
        elif type(window) is list or type(window) is tuple:
            if len(window) != 2:
                raise ValueError("sliding window must have only 2 dimensions")
            if type(window[0]) is not int or type(window[1]) is not int:
                raise ValueError("window sizes are not integer")
        else:
            raise ValueError("window argument is invalid")

        # pad the input so that the output keeps the input shape
        if pad:
            pad_h, pad_w = window[0] // 2, window[1] // 2
            x = C.pad(image, pattern=[(0, 0), (pad_h, pad_h), (pad_w, pad_w)], mode=C.SYMMETRIC_PAD)
            output_shape = image.shape
        else:
            x = image
            output_shape = (image.shape[0], image.shape[-2] - window[-2] + 1, image.shape[-1] - window[-1] + 1)

        # build one shifted copy of the input per window offset
        # and stack the copies along a new trailing axis
        output = None
        for r in range(window[0]):
            R_END = r + output_shape[1]
            for c in range(window[1]):
                C_END = c + output_shape[2]
                temp_data = C.slice(x, axis=[1, 2], begin_index=[r, c], end_index=[R_END, C_END])
                temp_data = C.expand_dims(temp_data, -1)
                if output is None:
                    output = temp_data
                else:
                    output = C.splice(output, temp_data, axis=3)

        # output now has 4 static axes (C, H, W, k)
        # sort entries along the last dimension
        k = window[0] * window[1]
        sorted_output = C.top_k(output, k)
        sorted_output = sorted_output[0]

        # take the median: for odd k the middle element sits at 0-based index k // 2
        mid_index = k // 2
        median = C.slice(sorted_output, axis=-1, begin_index=mid_index, end_index=mid_index + 1)
        median = C.squeeze(median)
        return median

At the moment it only supports stride-1 pooling.
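To sanity-check the graph above, it may help to have a plain NumPy mirror of the same idea (shifted slices stacked on a trailing axis, sorted, middle element taken). This is only a sketch: `median_filter_reference` is a made-up name, and I am assuming that NumPy's `mode="symmetric"` padding matches `C.SYMMETRIC_PAD`, which I have not verified.

```python
import numpy as np

def median_filter_reference(image, window=(3, 3), pad=True):
    """NumPy mirror of the shifted-slice approach for a (C, H, W) array."""
    kh, kw = window
    if pad:
        ph, pw = kh // 2, kw // 2
        # Assumption: NumPy "symmetric" padding corresponds to C.SYMMETRIC_PAD.
        x = np.pad(image, [(0, 0), (ph, ph), (pw, pw)], mode="symmetric")
        out_h, out_w = image.shape[1], image.shape[2]
    else:
        x = image
        out_h, out_w = image.shape[1] - kh + 1, image.shape[2] - kw + 1

    # One shifted copy of the input per window offset, stacked on a new last axis.
    shifted = [x[:, r:r + out_h, c:c + out_w]
               for r in range(kh) for c in range(kw)]
    stacked = np.stack(shifted, axis=-1)      # shape (C, out_h, out_w, k)

    # Sort in descending order, like a top-k over all k entries would.
    ordered = -np.sort(-stacked, axis=-1)

    k = kh * kw
    return ordered[..., k // 2]               # middle element when k is odd
```

For an odd window size the result should agree with `np.median` taken over each window, so evaluating the CNTK function on a small random input and comparing against this reference is a quick way to confirm the median index.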
I'd like to implement median pooling (i.e. median filtering) to smooth out the noise in feature maps. I can find the median with C.ops.top_k(), but how do I extract a local region with a sliding window?
Thanks
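For illustration, this is roughly what "extract a local region with a sliding window, then take the median" looks like outside of CNTK. It is a NumPy-only sketch, not a CNTK op: `median_pool` and its `stride` parameter are hypothetical names, and `sliding_window_view` needs NumPy 1.20 or newer. No padding is applied here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def median_pool(image, window=(3, 3), stride=(1, 1)):
    """Median over each local window of a (C, H, W) array, without padding."""
    # windows has shape (C, H - kh + 1, W - kw + 1, kh, kw):
    # one (kh, kw) patch for every valid top-left position.
    windows = sliding_window_view(image, window, axis=(1, 2))
    # Keep every stride-th window position along height and width.
    windows = windows[:, ::stride[0], ::stride[1]]
    # Reduce each patch to its median.
    return np.median(windows, axis=(-2, -1))

feature_map = np.arange(25, dtype=np.float32).reshape(1, 5, 5)
pooled = median_pool(feature_map, window=(3, 3), stride=(2, 2))
print(pooled)  # [[[ 6.  8.] [16. 18.]]]
```

Each output value is the median of one 3x3 patch, and the stride simply drops window positions, so the same patch-then-reduce structure is what a graph-based version would need to reproduce.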