In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from einops import rearrange, reduce, repeat

In [4]:
class Conv_maxpool(nn.Module):
    """Downsampling stem: 3x3/2 conv -> BatchNorm -> ReLU, then 3x3/2 max-pool.

    Maps ``c1`` input channels to ``c2`` output channels. Both stages use
    stride 2 with padding 1, so H and W shrink twice (e.g. 8x8 -> 2x2).

    Args:
        c1: number of input channels.
        c2: number of output channels.
    """

    def __init__(self, c1, c2):
        super().__init__()
        # Built as a list so the Sequential indices (conv.0 / conv.1 / conv.2)
        # match the layer order below.
        stem_layers = [
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stem_layers)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                    dilation=1, ceil_mode=False)

    def forward(self, x):
        """Apply the conv stage, then the max-pool stage."""
        features = self.conv(x)
        return self.maxpool(features)

In [7]:
# Smoke-test the stem: each stride-2 stage halves H and W (8x8 -> 4x4 -> 2x2).
a = torch.randn(10, 6, 8, 8)
conv_mp = Conv_maxpool(6, 10)
stem_out = conv_mp(a)  # single forward pass (BatchNorm stats update once)
assert tuple(stem_out.shape) == (10, 10, 2, 2), stem_out.shape
stem_out.shape

torch.Size([10, 10, 2, 2])

In [None]:
class ShuffleNetV2Block(nn.Module):
    """ShuffleNet V2 unit (same layout as torchvision's ``InvertedResidual``).

    With ``stride == 1`` the input is split in half along channels. The first
    half passes through unchanged and the second half goes through
    ``branch2``. With ``stride > 1`` both branches see the full input and
    downsample it spatially. The two results are concatenated and then
    channel-shuffled with 2 groups.

    Args:
        inp: number of input channels.
        oup: requested output channels. The block produces ``2 * (oup // 2)``.
        stride: stride of the depthwise 3x3 convolutions (integer >= 1).
            When 1, ``inp`` must equal ``2 * (oup // 2)``.

    Raises:
        ValueError: if ``stride < 1``.
        AssertionError: if ``stride == 1`` and the channel counts don't match.
    """

    def __init__(self, inp, oup, stride):
        super(ShuffleNetV2Block, self).__init__()
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        self.stride = stride

        branch_features = oup // 2
        # In the stride-1 path the untouched half is concatenated with
        # branch2's output, so input channels must equal output channels.
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                # Depthwise 3x3 (groups=inp). Bias is redundant before BatchNorm.
                nn.Conv2d(inp, inp, kernel_size=3, stride=self.stride,
                          padding=1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1,
                          stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Not used by forward() when stride == 1 (x1 passes straight through).
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            # Downsampling blocks feed the full input; stride-1 blocks feed half.
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),

            # Depthwise 3x3 (groups=branch_features). Bias is redundant before BatchNorm.
            nn.Conv2d(branch_features, branch_features, kernel_size=3,
                      stride=self.stride, padding=1, groups=branch_features,
                      bias=False),
            nn.BatchNorm2d(branch_features),

            nn.Conv2d(branch_features, branch_features, kernel_size=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    def channel_shuffle(self, x, groups):
        """Interleave the channels of an (N, C, H, W) tensor across ``groups``.

        Channel ``g * (C // groups) + i`` moves to position ``i * groups + g``.
        ``C`` must be divisible by ``groups``.
        """
        N, C, H, W = x.size()
        out = x.view(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)
        return out

    def forward(self, x):
        """Run both branches, concatenate along channels, and shuffle.

        Returns a tensor with ``2 * (oup // 2)`` channels. Spatial size is
        reduced by the 3x3 depthwise stride.
        """
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Previously this was a bare `return`, so the block always yielded None.
        return self.channel_shuffle(out, 2)