In [58]:
### Python import ###
import torch
import torch.nn as nn
import torch.nn.functional as F

### Refiner import ###
from refiners.fluxion import layers as fl
from refiners.training_utils.trainer import seed_everything
### Local import ###
from models.architecture_utils import (
    LayerNorm2d, # function to rewrite using refiners
    Local_Base, # function to rewrite using refiners
    AdaptiveAvgPool2d,
    Dropout
)

seed = 42
seed_everything(seed)


In [59]:
class SimpleGate_base(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        # the .chunk() method splits a tensor into a specified number of chunks along a given dimension
        return x1 * x2

class SimplifiedChannelAttention(fl.Module):
    def __init__(self, c, DW_Expand = 2) -> None:
        super().__init__(
            AdaptiveAvgPool2d(1),
            fl.Conv2d(in_channels=(c*DW_Expand)//2, out_channels=(c*DW_Expand)//2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
        )

class MultiplyLayers(fl.Module):
    def forward(self, x, layer):
        new_x = x * layer(x)
        return new_x

In [60]:

class NAFBlock_debase(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        
        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )
        # SimpleGate
        self.sg = SimpleGate_base()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp
        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        y = inp + x * self.beta
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)
        return y + x * self.gamma

In [61]:
class SimpleGate(fl.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        # the .chunk() method splits a tensor into a specified number of chunks along a given dimension
        return x1 * x2

class CustomConditionedDropout(fl.Module):
    def __init__(self, drop_out_rate) -> None:
        super().__init__()
        self.drop_out_rate = drop_out_rate
    def forward(self, x):
        if self.drop_out_rate > 0.:
            x = Dropout(x)
        else :
            x = fl.Identity(x)
        return x

class NAFBlock(fl.Chain):
    def __init__(self, c, DW_Expand = 2, FFN_Expand = 2, drop_out_rate = 0.) -> None:
        super().__init__(
            # TODO : x = inp
            # x = self.norm1(x)
            # LayerNorm2d(c),

            fl.Conv2d(in_channels=c, out_channels=c*DW_Expand, kernel_size=1, padding=0, stride=1, groups=1, use_bias=True),
            fl.Conv2d(in_channels=c*DW_Expand, out_channels=c*DW_Expand, kernel_size=3, padding=1, stride=1, groups=c*DW_Expand, use_bias=True),
            SimpleGate(),

            # x = x * self.sca(x) (sca : simplified channel attention)
            # try with fl.Matmul() ? cf chain.py in refiners repo
            # MultiplyLayers(SimplifiedChannelAttention(c, DW_Expand)), # ??

            fl.Conv2d(in_channels=(c*DW_Expand)//2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, use_bias=True),
            CustomConditionedDropout(drop_out_rate),

            # TODO :  y = inp + x * self.beta
            #         x = self.conv4(self.norm2(y))
            # LayerNorm2d(inp + x * self.beta),

            SimpleGate(),
            fl.Conv2d(in_channels=FFN_Expand*c, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, use_bias=True),
            CustomConditionedDropout(drop_out_rate)

            # TODO : return y + x * self.gamma
        )

In [62]:
# test the NAFBlock
naf_base = NAFBlock_debase(3)
display(naf_base)
# print(naf_base.forward(torch.rand(1, 3, 32, 32)))

naf = NAFBlock(3, drop_out_rate=0.)
display(naf)

NAFBlock_debase(
  (conv1): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6)
  (conv3): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (sca): Sequential(
    (0): AdaptiveAvgPool2d(output_size=1)
    (1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  )
  (sg): SimpleGate_base()
  (conv4): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1))
  (conv5): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (norm1): LayerNorm2d()
  (norm2): LayerNorm2d()
  (dropout1): Identity()
  (dropout2): Identity()
)

(CHAIN) NAFBlock()
    ├── Conv2d(in_channels=3, out_channels=6, kernel_size=(1, 1)) #1
    ├── Conv2d(in_channels=6, out_channels=6, kernel_size=(3, 3), padding=(1, 1), groups=6) #2
    ├── SimpleGate() #1
    ├── Conv2d(in_channels=3, out_channels=3, kernel_size=(1, 1)) #3
    ├── CustomConditionedDropout(drop_out_rate=0.0) #1
    ├── SimpleGate() #2
    ├── Conv2d(in_channels=6, out_channels=3, kernel_size=(1, 1)) #4
    └── CustomConditionedDropout(drop_out_rate=0.0) #2