median pooling #3752

Closed
haixpham opened this issue Sep 24, 2019 · 1 comment

@haixpham

I'd like to implement median pooling (i.e. median filtering) to smooth out noise in feature maps. I can find the median with `C.ops.top_k()`, but how do I extract the local region under a sliding window?

Thanks
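
One way to extract the sliding windows without an explicit gather is to loop over the offsets inside the window rather than over output pixels. Each offset (r, c) contributes one shifted slice of the input, and stacking those slices along a new trailing axis puts every window's values side by side. Below is a minimal NumPy sketch of that idea, assuming a (C, H, W) array, stride 1 and no padding. NumPy and the helper name `stack_windows` are illustrative assumptions, not CNTK API.

```python
import numpy as np

def stack_windows(x, window):
    """Gather every (wh, ww) window of a (C, H, W) array into a trailing axis.

    Loops over the wh*ww offsets inside the window (not over output pixels)
    and takes one shifted slice per offset. No padding, stride 1.
    """
    wh, ww = window
    _, h, w = x.shape
    out_h, out_w = h - wh + 1, w - ww + 1
    slices = [x[:, r:r + out_h, s:s + out_w]   # each slice: (C, out_h, out_w)
              for r in range(wh) for s in range(ww)]
    return np.stack(slices, axis=-1)           # (C, out_h, out_w, wh*ww)

x = np.arange(16, dtype=float).reshape(1, 4, 4)
windows = stack_windows(x, (3, 3))
print(windows.shape)  # (1, 2, 2, 9)
```

The same slicing should map onto `C.slice` plus `C.splice` in a graph. Any order statistic along the last axis, such as the median, then gives the pooled value for each position.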

@haixpham
Author

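I came up with an implementation. It stacks every shifted slice of the (optionally padded) input along a new last axis, sorts that axis with `top_k`, and takes the middle element: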

```python
import cntk as C

def median_filter(image, window, pad=True):
    """Stride-1 median filter over the spatial axes of a (C, H, W) variable.

    window: int or (height, width). With pad=True the input is padded
    symmetrically so the output keeps the input's spatial size.
    """
    assert len(image.shape) == 3, "image must have 3 static axes (C, H, W)"
    if isinstance(window, int):
        window = (window, window)
    elif isinstance(window, (list, tuple)):
        if len(window) != 2:
            raise ValueError("sliding window must have exactly 2 dimensions")
        if not all(isinstance(w, int) for w in window):
            raise ValueError("window sizes must be integers")
    else:
        raise ValueError("window argument is invalid")

    # Pad H and W by half the window on each side; channels are not padded.
    if pad:
        pad_h, pad_w = window[0] // 2, window[1] // 2
        x = C.pad(image, pattern=[(0, 0), (pad_h, pad_h), (pad_w, pad_w)], mode=C.SYMMETRIC_PAD)
        output_shape = image.shape
    else:
        x = image
        output_shape = (image.shape[0], image.shape[1] - window[0] + 1, image.shape[2] - window[1] + 1)

    # One shifted slice per offset (r, c) inside the window, each of shape
    # (C, H_out, W_out), stacked along a new trailing axis.
    output = None
    for r in range(window[0]):
        r_end = r + output_shape[1]
        for c in range(window[1]):
            c_end = c + output_shape[2]
            temp_data = C.slice(x, axis=[1, 2], begin_index=[r, c], end_index=[r_end, c_end])
            temp_data = C.expand_dims(temp_data, -1)
            if output is None:
                output = temp_data
            else:
                output = C.splice(output, temp_data, axis=3)
    # output now has 4 static axes (C, H, W, k); sort along the last one.
    k = window[0] * window[1]
    # top_k returns (values, indices); the values are sorted in descending order.
    sorted_output = C.top_k(output, k).outputs[0]
    # 0-based middle index: the exact median for odd k, the lower middle value for even k.
    mid_index = k // 2
    median = C.slice(sorted_output, axis=-1, begin_index=mid_index, end_index=mid_index + 1)
    # Drop only the trailing singleton axis (a bare squeeze would also drop C == 1).
    median = C.reshape(median, output_shape)
    return median
```
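
To check the graph's output, a NumPy reference of the same computation may help. The sketch below is an illustration under assumptions, not part of the implementation above. It uses `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20 or newer) and assumes NumPy's `mode="symmetric"` padding matches `C.SYMMETRIC_PAD`. It mirrors the `top_k` step by sorting in descending order and taking index `k // 2`. The name `median_filter_reference` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def median_filter_reference(image, window, pad=True):
    """Hypothetical NumPy reference: stride-1 median filter on a (C, H, W) array."""
    wh, ww = window
    x = image
    if pad:
        ph, pw = wh // 2, ww // 2
        # "symmetric" mirrors the edge pixel too; assumed to match C.SYMMETRIC_PAD
        x = np.pad(image, ((0, 0), (ph, ph), (pw, pw)), mode="symmetric")
    # View of every window: (C, H', W', wh, ww)
    win = sliding_window_view(x, (wh, ww), axis=(1, 2))
    if pad:
        # Even window sizes yield one extra row/column; keep the first H x W
        win = win[:, :image.shape[1], :image.shape[2]]
    k = wh * ww
    win = win.reshape(win.shape[:3] + (k,))     # flatten each window: (C, H', W', k)
    sorted_desc = -np.sort(-win, axis=-1)       # descending order, like top_k
    return sorted_desc[..., k // 2]             # middle element (lower middle if k is even)
```

One way to catch indexing mistakes would be to feed the same array through the CNTK graph (e.g. via `eval`) and compare with `np.allclose`. For even window sizes this reference returns the lower of the two middle values rather than their mean, as `np.median` would.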

At the moment it only supports stride-1 pooling.
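
A stride greater than 1 can, in principle, be derived from the stride-1 result. With no padding, the window at output position (i, j) of a stride-(sh, sw) pooling starts at input position (i*sh, j*sw). That is exactly entry (i*sh, j*sw) of the stride-1 map. Below is a NumPy sketch of that subsampling, assuming a (C, H, W) array. `median_pool` is a hypothetical name, and because the sketch computes the full stride-1 map first, it does more work than a native strided op would.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def median_pool(x, window, stride):
    """Hypothetical strided median pooling of a (C, H, W) array, no padding.

    Computes the stride-1 median map, then keeps every stride-th row/column.
    """
    wh, ww = window
    sh, sw = stride
    win = sliding_window_view(x, (wh, ww), axis=(1, 2))  # (C, H-wh+1, W-ww+1, wh, ww)
    med = np.median(win, axis=(-2, -1))                  # stride-1 median map
    return med[:, ::sh, ::sw]

x = np.arange(16, dtype=float).reshape(1, 4, 4)
print(median_pool(x, (2, 2), (2, 2)).shape)  # (1, 2, 2)
```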
