https://github.com/megvii-research/NBNet/blob/73112b185e022d0920f2f45c34c5bcf7c581d983/model.py#L71

In [2]:
!pip install megengine

Collecting megengine
  Downloading MegEngine-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl (850.7 MB)
[K     |███████████████████████████████▍| 834.1 MB 1.4 MB/s eta 0:00:12tcmalloc: large alloc 1147494400 bytes == 0x55e7d2604000 @  0x7f7cb1c12615 0x55e7989e33bc 0x55e798ac418a 0x55e7989e61cd 0x55e798ad8b3d 0x55e798a5a458 0x55e798a5502f 0x55e7989e7aba 0x55e798a5a2c0 0x55e798a5502f 0x55e7989e7aba 0x55e798a56cd4 0x55e798ad9986 0x55e798a56350 0x55e798ad9986 0x55e798a56350 0x55e798ad9986 0x55e798a56350 0x55e7989e7f19 0x55e798a2ba79 0x55e7989e6b32 0x55e798a5a1dd 0x55e798a5502f 0x55e7989e7aba 0x55e798a56cd4 0x55e798a5502f 0x55e7989e7aba 0x55e798a55eae 0x55e7989e79da 0x55e798a56108 0x55e798a5502f
[K     |████████████████████████████████| 850.7 MB 11 kB/s 
Collecting redispy
  Downloading redispy-3.0.0-py2.py3-none-any.whl (64 kB)
[K     |████████████████████████████████| 64 kB 2.6 MB/s 
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Collecting mprop
  Downlo

### NBNet

![image](https://user-images.githubusercontent.com/44194558/152297625-a09af994-fb54-4f39-a130-fbae75e78f63.png)


### UNet Convolutional Block

(a) conv-block

![image](https://user-images.githubusercontent.com/44194558/152296600-a14d8633-1ec4-4020-ab68-c51687ffec91.png)

In [3]:
import megengine as mge
import megengine.module as nn
import megengine.functional as F

In [4]:
class UNetConvBlock(nn.Module):
    """Residual convolution block shared by the encoder, decoder and skip paths.

    The input goes through two 3x3 convolutions, each followed by LeakyReLU,
    and the result is summed with a 1x1 convolution of the input. When
    ``downsample`` is true, a strided convolution additionally produces a
    reduced-resolution copy of that sum.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        downsample: whether ``forward`` also returns a downsampled tensor.
        relu_slope: negative slope used by both LeakyReLU layers.
    """

    def __init__(self, in_size, out_size, downsample, relu_slope):
        super(UNetConvBlock, self).__init__()

        # 1. Convolutional block: (3x3 Conv + LeakyReLU) x 2
        self.block = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope),
            nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope))

        # 2. Skip connection: 1x1 conv so the input matches out_size channels
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

        # 3. Optional downsampling layer.
        # It is applied to the block output, so it must accept out_size
        # channels (not in_size). Kernel 4 / stride 2 / padding 1 halves an
        # even height and width, which pairs with the stride-2 transposed
        # convolution used for upsampling in the decoder.
        self.downsample = downsample
        if downsample:
            self.downsample = nn.Conv2d(
                out_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, x):
        """Run the block.

        Returns:
            ``(out_down, out)`` when the block was built with a truthy
            ``downsample``; otherwise just ``out``. ``out`` is the
            full-resolution result (used as the skip-connection input) and
            ``out_down`` is what the next encoder stage consumes.
        """
        # 1. Conv block plus 1x1 shortcut
        out = self.block(x)
        sc = self.shortcut(x)
        out = out + sc

        # 2. Downsample if configured
        if self.downsample:
            out_down = self.downsample(out)
            return out_down, out
        else:
            return out

### Basis Generation

합성곱 연산을 통해 Subspace를 span하는 기저 벡터 생성

![image](https://user-images.githubusercontent.com/44194558/152303075-d0ab2d33-6403-4e9e-a06b-4d81a17c792d.png)

<br/>

![image](https://user-images.githubusercontent.com/44194558/152309965-0d82cec3-7014-4a1f-8f54-0ddcf513d693.png)

SSA 모듈의 conv-block에 해당

In [5]:
# Used by UNetUpBlock as: self.subnet = Subspace(in_size, self.num_subspace), where num_subspace defaults to 16

class Subspace(nn.Module):
    """Basis-generation network: maps a feature map to ``out_size`` channels.

    A single ``UNetConvBlock`` (no downsampling, LeakyReLU slope 0.2) is
    applied to the input and summed with a 1x1 convolution of the input.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
    """

    def __init__(self, in_size, out_size):
        super(Subspace, self).__init__()
        self.blocks = []
        self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2))
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

    def forward(self, x):
        """Return the conv-block output plus the 1x1 shortcut of ``x``."""
        # The shortcut must be taken from the untouched input, before the
        # loop below rebinds x. (The attribute is `shortcut`; the previous
        # spelling `shorcut` raised AttributeError on every call.)
        sc = self.shortcut(x)
        for block in self.blocks:
            x = block(x)

        return x + sc

### Skip-connections

![image](https://user-images.githubusercontent.com/44194558/152300776-905ecc6b-dabc-4242-90a3-9d5f8dfa12c2.png)

In [6]:
class skip_blocks(nn.Module):
    """Stack of conv blocks applied to an encoder output on the skip path.

    The stack always starts with an ``in_size -> 128`` block and ends with a
    ``128 -> out_size`` block; ``repeat_num - 2`` additional ``128 -> 128``
    blocks are inserted between them (none when ``repeat_num`` is 1 or 2).
    The stack output is summed with a 1x1 convolution of the input.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        repeat_num: controls the number of middle blocks, see above.
    """

    def __init__(self, in_size, out_size, repeat_num=1):
        super(skip_blocks, self).__init__()
        self.re_num = repeat_num  # differs per decoder stage
        hidden = 128

        # Channel widths at every boundary of the stack; consecutive pairs
        # give the (input, output) channels of each conv block.
        widths = [in_size, hidden]
        widths += [hidden for _ in range(self.re_num - 2)]
        widths.append(out_size)

        self.blocks = []
        for src, dst in zip(widths[:-1], widths[1:]):
            self.blocks.append(UNetConvBlock(src, dst, False, 0.2))

        # 1x1 conv shortcut around the whole stack
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

    def forward(self, x):
        """Return the stacked conv-block output plus the shortcut of ``x``."""
        residual = self.shortcut(x)

        features = x
        for block in self.blocks:
            features = block(features)

        return features + residual

### UpBlock

![image](https://user-images.githubusercontent.com/44194558/152300923-0a113ba6-5582-40fc-b62f-587b20602b02.png)

In [7]:
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, project the skip features, then fuse.

    Args:
        in_size: channels of the incoming low-resolution feature map; also
            the channel count of the concatenated tensor fed to ``subnet``
            and ``conv_block``.
        out_size: channels produced by this stage.
        relu_slope: LeakyReLU slope for the final conv block.
        subnet_repeat_num: ``repeat_num`` forwarded to ``skip_blocks``.
        subspace_dim: number of basis vectors generated by ``Subspace``.
    """

    def __init__(self, in_size, out_size, relu_slope, subnet_repeat_num, subspace_dim=16):
        super(UNetUpBlock, self).__init__()

        # Transposed convolution with stride 2 for upsampling
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)

        # Final conv block applied to the concatenated features
        self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)
        self.num_subspace = subspace_dim
        print(self.num_subspace, subnet_repeat_num)

        # Basis generation network and the skip-path block that feeds it
        self.subnet = Subspace(in_size, self.num_subspace)
        self.skip_m = skip_blocks(out_size, out_size, subnet_repeat_num)

    def forward(self, x, bridge):
        """Fuse ``x`` (from the previous stage) with ``bridge`` (encoder skip).

        Returns the output of the final conv block, which is passed on to
        the next decoder stage.
        """
        upsampled = self.up(x)
        skip = self.skip_m(bridge)

        # The concatenation is the input for basis generation.
        fused = F.concat([upsampled, skip], 1)
        if self.subnet:
            skip = self._project_skip(fused, skip)
            # Re-concatenate, now with the projected skip features.
            fused = F.concat([upsampled, skip], 1)

        return self.conv_block(fused)

    def _project_skip(self, fused, skip):
        """Project ``skip`` onto the subspace spanned by bases made from ``fused``.

        With K = ``num_subspace`` and N = h * w, the basis matrix V has shape
        (b, N, K) and the result is V (V^T V)^-1 V^T applied to the flattened
        skip features, reshaped back to (b, c, h, w).
        """
        batch, channels, height, width = skip.shape
        basis_map = self.subnet(fused)

        # V^T: (b, K, N), each row scaled by its L1 norm (plus 1e-6)
        basis_t = basis_map.reshape(batch, self.num_subspace, height * width)
        basis_t = basis_t / (1e-6 + F.abs(basis_t).sum(axis=2, keepdims=True))
        basis = basis_t.transpose(0, 2, 1)

        # (V^T V)^-1 V^T: (b, K, N)
        gram = F.matmul(basis_t, basis)
        gram_inv = F.matinv(gram)
        projector = F.matmul(gram_inv, basis_t)

        # Coefficients of the skip features in the basis: (b, K, c)
        flat_skip = skip.reshape(batch, channels, height * width)
        coeffs = F.matmul(projector, flat_skip.transpose(0, 2, 1))

        # Back to feature space: (b, N, c) -> (b, c, h, w)
        projected = F.matmul(basis, coeffs)
        return projected.transpose(0, 2, 1).reshape(batch, channels, height, width)

### UNetD

In [8]:
def conv3x3(in_chn, out_chn, bias=True):
    """Build a 3x3, stride-1, padding-1 ``nn.Conv2d`` from ``in_chn`` to ``out_chn``."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)

def conv_down(in_chn, out_chn, bias=False):
    """Build a 4x4, stride-2, padding-1 ``nn.Conv2d`` from ``in_chn`` to ``out_chn``."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)

In [9]:
class UNetD(nn.Module):
    """U-Net style network built from ``UNetConvBlock`` and ``UNetUpBlock``.

    Args:
        in_chn: channels of the input; the final layer maps back to this.
        wf: channel width of the first encoder stage; stage ``i`` uses
            ``(2**i) * wf`` channels.
        depth: number of encoder stages (the decoder has ``depth - 1``).
        relu_slope: LeakyReLU slope passed to every stage.
        subspace_dim: forwarded to each ``UNetUpBlock``.
    """

    def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.2, subspace_dim=16):
        super(UNetD, self).__init__()
        self.depth = depth
        self.down_path = []
        channels = self.get_input_chn(in_chn)

        # Encoder: every stage except the last also emits a downsampled output.
        for level in range(depth):
            width = (2 ** level) * wf
            has_down = (level + 1) < depth
            self.down_path.append(UNetConvBlock(channels, width, has_down, relu_slope))
            channels = width

        # Decoder: widths mirror the encoder; the skip-path repeat count
        # starts at 1 and grows by one per stage.
        self.up_path = []
        for repeat, level in enumerate(reversed(range(depth - 1)), start=1):
            width = (2 ** level) * wf
            self.up_path.append(UNetUpBlock(channels, width, relu_slope, repeat, subspace_dim))
            channels = width

        # Final 3x3 conv back to the input channel count.
        # Note: _initialize() is not invoked here.
        self.last = conv3x3(channels, in_chn, bias=True)

    def forward(self, x1):
        """Run encoder then decoder and return the final conv output."""
        # Encoder: keep each full-resolution output for the skip paths.
        skips = []
        for idx, stage in enumerate(self.down_path):
            if (idx + 1) < self.depth:
                x1, full_res = stage(x1)
                skips.append(full_res)
            else:
                x1 = stage(x1)

        # Decoder: consume the saved outputs from deepest to shallowest.
        for stage in self.up_path:
            x1 = stage(x1, skips.pop())

        return self.last(x1)

    def get_input_chn(self, in_chn):
        """Return the channel count fed to the first encoder stage."""
        return in_chn

    def _initialize(self):
        """Xavier-initialise every ``nn.Conv2d`` weight and zero its bias."""
        # NOTE(review): gain is computed but never passed to xavier_uniform_.
        gain = nn.init.calculate_gain('leaky_relu', 0.20)
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            print("weight")
            nn.init.xavier_uniform_(conv.weight)
            if conv.bias is None:
                continue
            print("bias")
            nn.init.zeros_(conv.bias)

In [None]:
import numpy as np

# Build the network for 3-channel input with the default width, depth and
# subspace size.
NBNet = UNetD(in_chn=3)

# Push one random float32 batch of shape (1, 3, 128, 128) through the model.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells may refer to it.
noise = np.random.randn(1, 3, 128, 128)
input = mge.tensor(noise.astype(np.float32))
pred = NBNet(input)  # author's note: expected shape (1, 3, 128, 128)