MaxPool2d backward pass how to? #1

Open
moisestohias opened this issue Feb 26, 2023 · 0 comments

I've implemented the forward pass of max pooling, but I haven't been able to use the returned indices to perform the backward pass. Your version uses the img_to_row trick, but I want to implement this without too much indexing and looping. Any idea how to do the backward pass with this implementation?

import numpy as np
from math import ceil, floor  # used for the padding split in pool2d
from numpy.lib.stride_tricks import as_strided

def pool2d(Z, K:tuple=(2,2)):
    """ performs the windowing, and padding if needed"""
    KH, KW = K  # Kernel Height & Width
    N, C, ZH, ZW = Z.shape # Input: NCHW Batch, Channels, Height, Width
    Ns, Cs, Hs, Ws = Z.strides
    EdgeH, EdgeW = ZH%KH, ZW%KW # How many pixels left on the edge
    if (EdgeH!=0 or EdgeW!=0): # If there are pixels left we need to pad
        PadH, PadW = (KH-EdgeH)%KH, (KW-EdgeW)%KW # pad only the axis that has leftover pixels
        PadTop, PadBottom = ceil(PadH/2), floor(PadH/2)
        PadLeft, PadRight = ceil(PadW/2), floor(PadW/2)
        Z = np.pad(Z, ((0,0),(0,0), (PadTop, PadBottom), (PadLeft, PadRight)))
        N, C, ZH, ZW = Z.shape #NCHW
        Ns, Cs, Hs, Ws = Z.strides
    Zstrided = np.lib.stride_tricks.as_strided(Z, shape=(N,C,ZH//KH, ZW//KW, KH, KW), strides=(Ns, Cs, Hs*KH, Ws*KW,Hs, Ws))
    return Zstrided.reshape(N,C,ZH//KH, ZW//KW, KH*KW) # reshape to flatten pooling windows to be 1D-vector

def maxpool2d(Z, K:tuple=(2,2)):
    ZP = pool2d(Z, K)
    MxP = np.max(ZP, axis=(-1))
    Inx = np.argmax(ZP, axis=-1)
    return MxP, Inx

def unmaxpool2d(ZP, Indx, K:tuple=(2,2)):
    ZN,ZC,ZH,ZW = ZP.shape
    KH, KW = K
    Z = np.zeros((ZN,ZC,ZH*KH,ZW*KW))
    _ZP = pool2d(Z, K)
    # ... Where to go next
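
One possible way to finish the backward pass without loops is to scatter the upstream gradient into zero-filled flattened windows with `np.put_along_axis`, then fold the windows back into image layout. The sketch below is self-contained and assumes the input height and width are exact multiples of the kernel (no padding). The names `maxpool2d_noloop` and `maxpool2d_backward` are hypothetical and not part of the code above. The window layout it uses (row-major `KH*KW` vectors) should match what `pool2d` produces, so the `Inx` returned by `maxpool2d` ought to be usable with it, but that is an assumption worth checking numerically.

```python
import numpy as np

def maxpool2d_noloop(Z, K=(2, 2)):
    """Forward pass; assumes H and W are divisible by the kernel size."""
    KH, KW = K
    N, C, H, W = Z.shape
    assert H % KH == 0 and W % KW == 0, "this sketch does not handle padding"
    # Split H into (H//KH, KH) and W into (W//KW, KW), then group kernel axes last.
    windows = Z.reshape(N, C, H // KH, KH, W // KW, KW).transpose(0, 1, 2, 4, 3, 5)
    windows = windows.reshape(N, C, H // KH, W // KW, KH * KW)
    idx = windows.argmax(axis=-1)  # position of the max inside each window
    out = np.take_along_axis(windows, idx[..., None], axis=-1)[..., 0]
    return out, idx

def maxpool2d_backward(dOut, idx, K=(2, 2)):
    """Route each upstream gradient to the position that held the max."""
    KH, KW = K
    N, C, OH, OW = dOut.shape
    # Zero-filled flattened windows; only the argmax slot receives gradient.
    dwindows = np.zeros((N, C, OH, OW, KH * KW), dtype=dOut.dtype)
    np.put_along_axis(dwindows, idx[..., None], dOut[..., None], axis=-1)
    # Undo the windowing: (N, C, OH, OW, KH, KW) -> (N, C, OH, KH, OW, KW) -> image.
    dZ = dwindows.reshape(N, C, OH, OW, KH, KW).transpose(0, 1, 2, 4, 3, 5)
    return dZ.reshape(N, C, OH * KH, OW * KW)

Z = np.arange(16, dtype=float).reshape(1, 1, 4, 4)
out, idx = maxpool2d_noloop(Z)
dZ = maxpool2d_backward(np.ones_like(out), idx)
```

If the forward pass padded the input, the gradient returned here would have the padded shape, so it would presumably need to be cropped by the same `PadTop`/`PadLeft` offsets before being passed further back. When a window contains ties, `argmax` picks the first maximum, so only that position receives gradient.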