# In [None]:

  
from torch import nn

class cinvolution(nn.Module):
    """Involution-style layer whose kernels are generated from a guide tensor.

    A single ``kernel_size x kernel_size`` kernel per group is produced from the
    centre pixel of the guide tensor ``xw`` by a 1x1 convolution (``conv2``).
    That kernel is shared by every spatial position of ``x`` and by all
    ``channels // groups`` channels inside the group. Each output value is the
    kernel-weighted sum of the corresponding zero-padded neighbourhood of ``x``.

    Args:
        inplanes: number of channels of the guide tensor ``xw``.
        channels: number of channels of ``x`` and of the output; must be
            divisible by ``groups``.
        kernel_size: spatial size of the generated kernel.
        stride: stride of the sliding window over ``x``.
        groups: number of kernel groups.

    Raises:
        ValueError: if ``channels`` is not divisible by ``groups``.
    """

    def __init__(self, inplanes,
                 channels,
                 kernel_size,
                 stride,
                 groups=16):
        super().__init__()
        if channels % groups != 0:
            # Fail early instead of with an opaque view error inside forward().
            raise ValueError(
                f"channels ({channels}) must be divisible by groups ({groups})"
            )
        self.inplanes = inplanes
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        self.groups = groups
        self.group_channels = self.channels // self.groups

        # 1x1 conv mapping the guide pixel (inplanes channels) to one flattened
        # kernel_size**2 kernel per group.
        self.conv2 = nn.Conv2d(inplanes, kernel_size**2 * self.groups, 1)
        # Sliding-window patch extractor; padding keeps the spatial size for odd
        # kernels at stride 1.
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

    def forward(self, x, xw):
        """Apply the kernels generated from ``xw`` to ``x``.

        Args:
            x: input of shape ``(b, channels, h, w)``.
            xw: guide tensor of shape ``(b, inplanes, H, W)``. Only its centre
                pixel ``(H // 2, W // 2)`` is used to generate the kernels.

        Returns:
            Tensor of shape ``(b, channels, h_out, w_out)`` with
            ``h_out = (h + 2 * p - kernel_size) // stride + 1`` (likewise for
            ``w_out``) and ``p = (kernel_size - 1) // 2``.
        """
        b, _, h, w = x.shape
        k2 = self.kernel_size ** 2
        pad = (self.kernel_size - 1) // 2

        # Kernel generation from the guide's centre pixel, kept as a 1x1 map so
        # the 1x1 conv applies directly.
        ch, cw = xw.shape[-2] // 2, xw.shape[-1] // 2
        weight = self.conv2(xw[:, :, ch:ch + 1, cw:cw + 1])  # (b, groups*k2, 1, 1)
        # The kernel is identical at every output position: broadcast it rather
        # than materialising a per-pixel copy (this also makes stride > 1 work,
        # since the output grid is smaller than the input).
        weight = weight.view(b, self.groups, 1, k2, 1, 1)

        h_out = (h + 2 * pad - self.kernel_size) // self.stride + 1
        w_out = (w + 2 * pad - self.kernel_size) // self.stride + 1
        patches = self.unfold(x).view(
            b, self.groups, self.group_channels, k2, h_out, w_out
        )
        out = (weight * patches).sum(dim=3)
        return out.view(b, self.channels, h_out, w_out)


class CIVoubottleneck1(nn.Module):
    """Pre-activation bottleneck residual block built around ``cinvolution``.

    Main path: BN -> ReLU -> 1x1 conv (inplanes -> planes) -> BN -> ReLU ->
    cinvolution (kernels generated from the block input) -> BN -> ReLU ->
    1x1 conv (planes -> inplanes). The block input, passed through dropout,
    is added to the main path as the shortcut.

    Args:
        inplanes: channels of the block input and output.
        planes: bottleneck width (channels seen by the cinvolution).
        iksize: kernel size of the cinvolution.
        groups: number of cinvolution kernel groups.
    """

    def __init__(self, inplanes, planes, iksize=5, groups=16):
        super().__init__()
        # Keep this registration order: it fixes the parameter-init RNG
        # sequence and the state_dict key layout.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1,
                               padding=0, bias=False)

        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2_0 = cinvolution(inplanes, planes, iksize, 1, groups=groups)

        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return ``main_path(x) + dropout(x)``, with the same shape as ``x``."""
        shortcut = x

        y = self.conv1(self.relu(self.bn1(x)))
        y = self.relu(self.bn2(y))
        # The cinvolution kernels are generated from the raw block input.
        y = self.conv2_0(y, shortcut)
        y = self.conv3(self.relu(self.bn3(y)))

        # NOTE(review): dropout is applied to the identity shortcut itself, so
        # in training mode skip-path elements are randomly zeroed (p=0.5).
        # Confirm this is intended.
        y += self.dropout(shortcut)
        return y



class DRCIN(nn.Module):
    """Classifier built from a stack of ``CIVoubottleneck1`` residual blocks.

    Layout: 1x1 stem conv (channels -> inplanes), ``numblocks`` bottleneck
    blocks, BN + ReLU, global average pooling, and a linear head producing
    ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        channels: number of channels of the input ``x``.
        groups: cinvolution kernel groups in every block.
        iksize: cinvolution kernel size in every block.
        numblocks: number of residual blocks.
        inplanes: width of the residual stream.
        midinplanes: bottleneck width inside each block.
    """

    def __init__(self, num_classes, channels, groups=8, iksize=3, numblocks=5, inplanes=16, midinplanes=8):
        super().__init__()
        self.groups = groups
        self.iksize = iksize
        self.numblocks = numblocks
        self.inplanes = inplanes
        self.midinplanes = midinplanes
        # Submodule registration order is kept: it drives both the RNG sequence
        # of default initialisation and the order of self.modules() below.
        self.conv1 = nn.Conv2d(channels, inplanes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.block = nn.Sequential(*[
            CIVoubottleneck1(inplanes, midinplanes, iksize=iksize, groups=groups)
            for _ in range(numblocks)
        ])

        self.bn = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(inplanes, num_classes)

        # Re-initialise every conv (He normal, fan_out) and every BatchNorm
        # (scale 1, shift 0), including those nested inside the blocks.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map ``x`` of shape ``(b, channels, h, w)`` to logits ``(b, num_classes)``."""
        features = self.block(self.conv1(x))
        features = self.relu(self.bn(features))
        pooled = self.avgpool(features)
        return self.linear(pooled.view(pooled.size(0), -1))

