# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MDTA(nn.Module):
    '''Multi-Dconv Head Transposed Attention.

    Self-attention computed across channels instead of across pixels: each head
    builds a (C/num_heads x C/num_heads) attention map rather than an (HW x HW) one.

    ***IMPORTANT*** - `channels` must be divisible by `num_heads`
    (channels % num_heads == 0).

    Args:
        channels: number of input/output channels C.
        num_heads: number of attention heads; each head attends over
            C / num_heads channels.

    Raises:
        ValueError: if num_heads is not positive or does not divide channels.
    '''
    def __init__(self, channels, num_heads):
        super(MDTA, self).__init__()
        # Fail at construction time; otherwise the reshape in forward() fails later
        # with a shape error that does not mention the real cause.
        if num_heads <= 0 or channels % num_heads != 0:
            raise ValueError(
                f'channels ({channels}) must be divisible by num_heads ({num_heads})')
        self.num_heads = num_heads
        # Learnable per-head scale applied to the q.k^T logits before the softmax.
        self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        # *3 for q, k, v & chunk(3, dim=1)
        # 1x1 Conv to aggregate pixel-wise cross-channel context
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        # 3x3 DWConv to encode channel-wise spatial context
        self.qkv_conv = nn.Conv2d(channels * 3, channels * 3, kernel_size=3, padding=1, groups=channels * 3, bias=False)
        # 1x1 Point-wise Conv
        self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        '''(N, C, H, W) -> (N, C, H, W)

        The output is a residual branch: the caller adds it to the input feature x.'''
        b, c, h, w = x.shape
        q, k, v = self.qkv_conv(self.qkv(x)).chunk(3, dim=1)  # each (N, C, H, W)

        # Split channels into heads; each head learns its own attention map.
        q = q.reshape(b, self.num_heads, -1, h * w)  # (N, num_heads, C/num_heads, HW)
        k = k.reshape(b, self.num_heads, -1, h * w)
        v = v.reshape(b, self.num_heads, -1, h * w)
        # L2-normalize along the spatial axis so q.k^T is a cosine similarity;
        # the learnable temperature then sets the softmax sharpness.
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)

        # Channel x channel attention map instead of HW x HW
        attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1)  # (N, num_heads, C/num_heads, C/num_heads)
        # attn @ v: (N, num_heads, C/num_heads, HW) -> merge heads back to (N, C, H, W)
        out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w))

        return out


class GDFN(nn.Module):
    '''Gated-Dconv Feed-Forward Network.

    Expands the input to 2 * hidden channels, where
    hidden = int(channels * expansion_factor). A depthwise 3x3 conv mixes spatial
    context. The result is split into two halves: GELU(first half) gates the
    second half element-wise. Finally the width is projected back to `channels`.
    '''

    def __init__(self, channels, expansion_factor):
        super().__init__()
        hidden = int(channels * expansion_factor)
        doubled = 2 * hidden  # both gating branches are produced by one conv

        # 1x1 conv: channels -> 2 * hidden
        self.project_in = nn.Conv2d(channels, doubled, kernel_size=1, bias=False)
        # Depthwise 3x3 conv: every channel is filtered independently (groups == channels)
        self.conv = nn.Conv2d(doubled, doubled, kernel_size=3, padding=1,
                              groups=doubled, bias=False)
        # 1x1 conv: hidden -> channels, applied after the gate halves the width
        self.project_out = nn.Conv2d(hidden, channels, kernel_size=1, bias=False)

    def forward(self, x):
        '''(N, C, H, W) -> (N, C, H, W)

        The output is a residual branch: the caller adds it to the input feature x.'''
        expanded = self.conv(self.project_in(x))
        gate, value = torch.chunk(expanded, 2, dim=1)
        return self.project_out(F.gelu(gate) * value)


class TransformerBlock(nn.Module):
    '''Pre-norm transformer block: x + MDTA(LN(x)), then x + GDFN(LN(x)).

    LayerNorm is applied over the channel dimension of the (N, C, H, W) feature.

    ***IMPORTANT*** - `channels` must be divisible by `num_heads`
    (channels % num_heads == 0).

    Raises:
        ValueError: if num_heads is not positive or does not divide channels.
    '''
    def __init__(self, channels, num_heads, expansion_factor):
        super(TransformerBlock, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_heads <= 0 or channels % num_heads != 0:
            raise ValueError(
                f'channels ({channels}) must be divisible by num_heads ({num_heads})')
        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    @staticmethod
    def _channel_norm(norm, x):
        '''Apply `norm` (a LayerNorm over C) to an (N, C, H, W) tensor; same shape out.'''
        b, c, h, w = x.shape
        # nn.LayerNorm normalizes the last dim, so move C last: (N, HW, C)
        tokens = x.reshape(b, c, -1).transpose(-2, -1).contiguous()
        return norm(tokens).transpose(-2, -1).contiguous().reshape(b, c, h, w)

    def forward(self, x):
        '''(N, C, H, W) -> (N, C, H, W)'''
        # Add MDTA output feature
        x = x + self.attn(self._channel_norm(self.norm1, x))
        # Add GDFN output feature
        x = x + self.ffn(self._channel_norm(self.norm2, x))
        return x


class DownSample(nn.Module):
    '''Halve the spatial resolution and double the channel count.

    A 3x3 conv squeezes C -> C/2 channels. PixelUnshuffle(2) then folds every 2x2
    spatial patch into channels, giving 4 * (C/2) = 2C channels (for even C).
    '''

    def __init__(self, channels):
        super().__init__()
        squeeze = nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False)
        self.body = nn.Sequential(squeeze, nn.PixelUnshuffle(2))

    def forward(self, x):
        '''(N, C, H, W) -> (N, 2C, H/2, W/2)'''
        return self.body(x)


class UpSample(nn.Module):
    '''Double the spatial resolution and halve the channel count.

    A 3x3 conv expands C -> 2C channels. PixelShuffle(2) then unfolds every group
    of four channels into a 2x2 spatial patch, giving 2C / 4 = C/2 channels
    (C must be even).
    '''

    def __init__(self, channels):
        super().__init__()
        expand = nn.Conv2d(channels, 2 * channels, kernel_size=3, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        '''(N, C, H, W) -> (N, C/2, 2H, 2W)'''
        return self.body(x)


class Restormer(nn.Module):
    '''Restormer: U-shaped encoder-decoder of TransformerBlocks for image restoration.

    The network predicts a residual image that is added to the 3-channel input.

    Args:
        num_blocks: TransformerBlocks per encoder level (shallow -> deep). Decoder
            levels reuse these counts; the full-resolution decoder uses num_blocks[0].
        num_heads: attention heads per level; each must divide that level's width
            (the full-resolution decoder/refinement use num_heads[0] at width channels[1]).
        channels: feature width per level. channels[0] must be even and every level
            must have exactly twice the previous width, because DownSample doubles C.
        num_refinement: TransformerBlocks in the full-resolution refinement stage.
        expansion_factor: hidden-channel expansion of each GDFN.

    Input height and width must be divisible by 2 ** (len(channels) - 1).

    Raises:
        ValueError: if the per-level configuration is inconsistent (see above).
    '''
    def __init__(self, num_blocks=(4, 6, 6, 8), num_heads=(1, 2, 4, 8), channels=(48, 96, 192, 384), num_refinement=4,
                 expansion_factor=2.66):
        super(Restormer, self).__init__()

        num_levels = len(channels)
        if num_levels < 2 or len(num_blocks) != num_levels or len(num_heads) != num_levels:
            raise ValueError(
                'num_blocks, num_heads and channels must have the same length (>= 2), got '
                f'{len(num_blocks)}, {len(num_heads)} and {num_levels}')
        if channels[0] % 2 != 0 or any(nxt != 2 * cur for cur, nxt in zip(channels, channels[1:])):
            raise ValueError(f'channels must start even and double at every level, got {list(channels)}')

        # low-level embedding: RGB -> channels[0]
        self.embed_conv = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False)

        self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(
            num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in
                                       zip(num_blocks, num_heads, channels)])
        # default config:
        # TransformerBlock(48, 1, r) x 4
        # TransformerBlock(96, 2, r) x 6
        # TransformerBlock(192, 4, r) x 6
        # TransformerBlock(384, 8, r) x 8

        # the number of down samples / up samples == the number of encoders - 1
        self.downs = nn.ModuleList([DownSample(num_ch) for num_ch in channels[:-1]])
        # DownSample(48), DownSample(96), DownSample(192)

        self.ups = nn.ModuleList([UpSample(num_ch) for num_ch in list(reversed(channels))[:-1]])
        # UpSample(384), UpSample(192), UpSample(96)

        # Every decoder level except the full-resolution one reduces
        # (upsampled + skip) = 2 * channels[i] back to channels[i] with a 1x1 conv.
        self.reduces = nn.ModuleList([nn.Conv2d(channels[i], channels[i - 1], kernel_size=1, bias=False)
                                      for i in reversed(range(2, num_levels))])
        # Conv2d(384, 192), Conv2d(192, 96)

        # the number of decoders == the number of encoders - 1
        # Decoders for encoder levels L-2 .. 1 (deepest first) keep that level's width/heads.
        decoders = [nn.Sequential(*[TransformerBlock(channels[i], num_heads[i], expansion_factor)
                                    for _ in range(num_blocks[i])])
                    for i in reversed(range(1, num_levels - 1))]
        # The full-resolution decoder is not reduced: it keeps the concatenated width
        # channels[1] (upsampled channels[0] + skip channels[0]) but uses num_heads[0].
        decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                        for _ in range(num_blocks[0])]))
        self.decoders = nn.ModuleList(decoders)
        # TransformerBlock(192, 4, r) x 6
        # TransformerBlock(96, 2, r) x 6
        # TransformerBlock(96, 1, r) x 4

        self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                          for _ in range(num_refinement)])
        # TransformerBlock(96, 1, r) x 4

        # deep feature -> RGB residual
        self.output = nn.Conv2d(channels[1], 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        '''(N, 3, H, W) -> (N, 3, H, W): restored image = predicted residual + input.'''
        # low-level feature
        feat = self.embed_conv(x)  # (N, channels[0], H, W)

        # Encoder: level 0 runs at full resolution; each deeper level is entered
        # through a DownSample. Keep every level's output as a skip connection.
        skips = []
        for level, encoder in enumerate(self.encoders):
            if level > 0:
                feat = self.downs[level - 1](feat)  # (N, channels[level], H / 2**level, W / 2**level)
            feat = encoder(feat)
            skips.append(feat)

        # Decoder: upsample, concatenate the matching encoder skip (low-level feature
        # with high-level feature), reduce the width except at full resolution, then refine.
        feat = skips.pop()  # deepest encoder output
        for level, decoder in enumerate(self.decoders):
            fused = torch.cat([self.ups[level](feat), skips.pop()], dim=1)
            if level < len(self.reduces):
                fused = self.reduces[level](fused)
            feat = decoder(fused)
        # feat is now the deep feature: (N, channels[1], H, W)

        # refinement at high resolution
        fr = self.refinement(feat)  # (N, channels[1], H, W)

        # restored_img = residual_img + degraded_img
        out = self.output(fr) + x  # (N, 3, H, W)

        return out

# In [6]:
# Smoke test: one forward pass on a random batch; the output must keep the input shape.
# Spatial size must be divisible by 2 ** (number of levels - 1) = 8 for the default model.
# Fall back to CPU so this also runs on machines without a GPU; the input is still
# sampled on CPU and then moved, exactly like the original `.cuda()` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn([2, 3, 64, 64]).to(device)
model = Restormer().to(device)
out = model(x)
print(out.shape)

# Output: torch.Size([2, 3, 64, 64])
