Update the conv_transpose2d usage? #5

Open
sdw95927 opened this issue Apr 30, 2021 · 3 comments

@sdw95927

I tried these excellent scripts and found that conv_transpose2d does not work properly for my use case, so I updated line 27 of functional.conv as follows:

    # Original call, which fixes padding=1 and uses the default stride of 1:
    # relevance_input = F.conv_transpose2d(relevance_output, weight, None, padding=1)
    # Heuristic: add one row/column of output padding whenever the stride is >= 2.
    output_padding = 1 if ctx.stride[0] >= 2 else 0
    relevance_input = F.conv_transpose2d(relevance_output, weight, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)

and also here:

        def f(X1, X2, W1, W2, ctx):

            # Z1 = F.conv2d(X1, W1, bias=None, stride=1, padding=1)
            # Z2 = F.conv2d(X2, W2, bias=None, stride=1, padding=1)
            # Use the layer's own settings instead of a fixed stride and padding.
            Z1 = F.conv2d(X1, W1, None, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
            Z2 = F.conv2d(X2, W2, None, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
            Z = Z1 + Z2

            # Guard against division by zero where Z is exactly 0.
            rel_out = relevance_output / (Z + (Z == 0).float() * 1e-6)

            # t1 = F.conv_transpose2d(rel_out, W1, bias=None, padding=1)
            # t2 = F.conv_transpose2d(rel_out, W2, bias=None, padding=1)
            # Heuristic: add one row/column of output padding whenever the stride is >= 2.
            # Note: dilation and groups are not forwarded to conv_transpose2d here.
            output_padding = 1 if ctx.stride[0] >= 2 else 0
            t1 = F.conv_transpose2d(rel_out, W1, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)
            t2 = F.conv_transpose2d(rel_out, W2, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)

            r1 = t1 * X1
            r2 = t2 * X2

            return r1 + r2

Not sure if this is my own issue, but the above change fixed my problem.
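
A note on the `output_padding` choice above: setting it to 1 whenever the stride is at least 2 matches the common case of an even input size with a 3x3 kernel, padding 1, and stride 2, but it is not exact in general. The sketch below uses the output-size formulas from the PyTorch documentation for `conv2d` and `conv_transpose2d` to compute the value that restores the original size along one dimension. The helper names are hypothetical and are not part of this repository.

```python
def conv_output_size(n_in, kernel, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one spatial dimension."""
    return (n_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def transpose_output_padding(n_in, kernel, stride=1, padding=0, dilation=1):
    """output_padding that makes the transposed convolution return n_in.

    It is the remainder discarded by the floor division in the forward
    convolution, so it is always smaller than the stride.
    """
    return (n_in + 2 * padding - dilation * (kernel - 1) - 1) % stride


def conv_transpose_output_size(n_out, kernel, stride=1, padding=0,
                               dilation=1, output_padding=0):
    """Output length of a transposed convolution along one dimension."""
    return ((n_out - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


# Even input: the stride >= 2 heuristic and the formula agree.
print(transpose_output_padding(8, kernel=3, stride=2, padding=1))  # 1
# Odd input: the formula gives 0, whereas the heuristic would still use 1.
print(transpose_output_padding(7, kernel=3, stride=2, padding=1))  # 0
```

With an odd input size such as 7, the heuristic would produce a tensor of size 8, which would no longer match the shape of the layer input in the element-wise product. If needed, the value can be computed per dimension from `X.shape[2:]` and passed as a tuple.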

@miladsikaroudi

Thank you @sdw95927 for posting this.
I just ran into the error 'Conv2DAlpha1Beta0Backward' object has no attribute 'stride'.
Where should I set the stride on ctx after making the changes mentioned above?
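
For context on this error: in a custom `torch.autograd.Function`, the `ctx` object passed to `backward` only carries what `forward` stored on it, so an error like this usually means `forward` never assigned `ctx.stride`. The thread does not show how it was resolved here. The sketch below illustrates the general pattern with a hypothetical class, `Conv2DWithCtx`, which stands in for the repository's function and is not its actual code. It assumes a dilation of 1 and a single group.

```python
import torch
import torch.nn.functional as F


class Conv2DWithCtx(torch.autograd.Function):
    """Hypothetical stand-in; not the repository's Conv2DAlpha1Beta0."""

    @staticmethod
    def forward(ctx, x, weight, stride, padding):
        # Anything backward() reads from ctx has to be stored here first.
        ctx.stride = stride
        ctx.padding = padding
        ctx.save_for_backward(x, weight)
        return F.conv2d(x, weight, None, stride, padding)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # Remainder lost to floor division in the forward pass (dilation 1 assumed).
        output_padding = tuple(
            (x.shape[2 + i] + 2 * ctx.padding[i] - (weight.shape[2 + i] - 1) - 1)
            % ctx.stride[i]
            for i in range(2)
        )
        grad_input = F.conv_transpose2d(
            grad_output, weight, None,
            stride=ctx.stride, padding=ctx.padding,
            output_padding=output_padding,
        )
        # One return value per forward() argument: x, weight, stride, padding.
        return grad_input, None, None, None


x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3)
y = Conv2DWithCtx.apply(x, w, (2, 2), (1, 1))
y.sum().backward()
print(x.grad.shape)  # torch.Size([1, 3, 8, 8])
```

If `forward` receives a module rather than raw arguments, the same idea applies: copy `stride`, `padding`, `dilation`, and `groups` from the module onto `ctx` before returning.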

@sdw95927
Author

> Thank you @sdw95927 for posting this. I just caught an error 'Conv2DAlpha1Beta0Backward' object has no attribute 'stride'. I am wondering where should I set the stride for the ctx after making mentioned changes.

Could you post a screenshot of both your script and the error?

@miladsikaroudi

miladsikaroudi commented Apr 25, 2022

The problem is fixed. Thank you so much.
