Involution convolution replacement question #110

Open
right135 opened this issue Oct 1, 2022 · 3 comments

right135 commented Oct 1, 2022

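up, I recently read a few papers reporting that involution works well, so I wanted to try swapping it in. However, the parameters of the official involution code don't quite match those of YOLOX, and after a lot of adjusting I still keep getting errors. Could you point out how the parameters should be changed? qwq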
import torch.nn as nn
from mmcv.cnn import ConvModule

class involution(nn.Module):

    def __init__(self,
                 channels,
                 kernel_size,
                 stride):
        super(involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        reduction_ratio = 4
        # Each group of 16 channels shares one generated kernel.
        self.group_channels = 16
        self.groups = self.channels // self.group_channels
        # Kernel-generation branch: 1x1 conv + BN + ReLU, then a plain 1x1 conv.
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=channels // reduction_ratio,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
        self.conv2 = ConvModule(
            in_channels=channels // reduction_ratio,
            out_channels=kernel_size**2 * self.groups,
            kernel_size=1,
            stride=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride, stride)
        # Arguments are (kernel_size, dilation, padding, stride).
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride)

    def forward(self, x):
        # Generate one kernel_size**2 kernel per group at every output position.
        weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
        b, c, h, w = weight.shape
        weight = weight.view(b, self.groups, self.kernel_size**2, h, w).unsqueeze(2)
        # Extract the input patches and apply the generated kernels to them.
        out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size**2, h, w)
        out = (weight * out).sum(dim=3).view(b, self.channels, h, w)
        return out
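
The shapes that flow through `forward` above can be traced by hand, which may help when hunting a channel mismatch. The following is a minimal sketch in plain Python with no torch dependency. The helper name `involution_shapes` and its dictionary keys are hypothetical and not part of the official code. It assumes the same constants as above (`group_channels = 16`, `reduction_ratio = 4`) and the standard output-size formulas of `nn.AvgPool2d` and `nn.Unfold`.

```python
def involution_shapes(channels, kernel_size, stride, height, width,
                      group_channels=16, reduction_ratio=4):
    """Trace tensor shapes (batch dimension omitted) through involution.forward."""
    if channels % group_channels != 0:
        # The view() on the unfolded input only works if this division is exact.
        raise ValueError("channels must be a multiple of group_channels")
    groups = channels // group_channels
    padding = (kernel_size - 1) // 2

    # Spatial size of the generated kernels: AvgPool2d(stride, stride)
    # floor-divides the input size when stride > 1.
    if stride > 1:
        h_w, w_w = height // stride, width // stride
    else:
        h_w, w_w = height, width

    # Spatial size of the unfolded patches (dilation = 1):
    # floor((size + 2 * padding - kernel_size) / stride) + 1
    h_u = (height + 2 * padding - kernel_size) // stride + 1
    w_u = (width + 2 * padding - kernel_size) // stride + 1

    return {
        "conv1_out": (channels // reduction_ratio, h_w, w_w),
        "weight": (groups, 1, kernel_size ** 2, h_w, w_w),
        "unfolded": (groups, group_channels, kernel_size ** 2, h_u, w_u),
        "output": (channels, h_w, w_w),
        # Both branches must agree spatially for the view() calls to succeed.
        "spatial_match": (h_w, w_w) == (h_u, w_u),
    }


shapes = involution_shapes(channels=64, kernel_size=3, stride=2, height=32, width=32)
print(shapes["output"])
```

Two things stand out from the trace: the output always has the same number of channels as the input, and that number has to be a multiple of 16 under the constants used here.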
Author

right135 commented Oct 1, 2022

image
Only one parameter differs: involution takes a single channels argument, whereas up's code has both in_channels and out_channels. When I tried it, I removed out_channels, and then it raised an error.
image

@bubbliiiing
Owner

You just got the channels wrong, that's all.

Author

right135 commented Oct 8, 2022

Hmm, so I got the channels wrong? I don't quite follow qwq. What should it be then, out_channels?
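
For readers with the same question, here is a minimal sketch of one way to reconcile the two signatures. The thread does not show the actual error or the fix, so this is an assumption rather than the repository's method: since involution keeps the channel count unchanged, it is built on `in_channels`, and a 1x1 convolution then projects to `out_channels`. The names `Involution`, `InvolutionBlock`, and `proj` are hypothetical, and plain torch layers stand in for mmcv's `ConvModule` so the example runs without mmcv.

```python
import torch
import torch.nn as nn


class Involution(nn.Module):
    """Involution built from plain torch layers instead of mmcv's ConvModule."""

    def __init__(self, channels, kernel_size, stride,
                 group_channels=16, reduction_ratio=4):
        super().__init__()
        if channels % group_channels != 0:
            raise ValueError("channels must be a multiple of group_channels")
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.group_channels = group_channels
        self.groups = channels // group_channels
        # 1x1 conv + BN + ReLU (no conv bias because BN follows)
        self.conv1 = nn.Sequential(
            nn.Conv2d(channels, channels // reduction_ratio, 1, bias=False),
            nn.BatchNorm2d(channels // reduction_ratio),
            nn.ReLU(inplace=True),
        )
        # Plain 1x1 conv that emits kernel_size**2 weights per group
        self.conv2 = nn.Conv2d(channels // reduction_ratio,
                               kernel_size ** 2 * self.groups, 1)
        self.avgpool = nn.AvgPool2d(stride, stride) if stride > 1 else nn.Identity()
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

    def forward(self, x):
        weight = self.conv2(self.conv1(self.avgpool(x)))
        b, _, h, w = weight.shape
        k2 = self.kernel_size ** 2
        weight = weight.view(b, self.groups, k2, h, w).unsqueeze(2)
        out = self.unfold(x).view(b, self.groups, self.group_channels, k2, h, w)
        return (weight * out).sum(dim=3).view(b, self.channels, h, w)


class InvolutionBlock(nn.Module):
    """Hypothetical wrapper exposing an (in_channels, out_channels) interface."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        # Involution preserves the channel count, so it is sized by in_channels.
        self.inv = Involution(in_channels, kernel_size, stride)
        # The 1x1 projection is what actually changes the channel count.
        self.proj = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        return self.proj(self.inv(x))


x = torch.randn(2, 64, 32, 32)
block = InvolutionBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2)
print(block(x).shape)
```

Whether the projection should go before or after the involution, and whether it needs its own normalization and activation, depends on the block being replaced. The sketch only shows where each channel argument ends up.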
