# MIMO

The changes are concentrated in layer, EBlock, and DBlock, so only that code is shown.

**Dataset: GoPro**

<div align=center>
<img src=./img/Architecture.jpg />
</div>

In [None]:
import torch
import torch.nn as nn


class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 -1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

In [None]:
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        return self.main(x) + x


class EBlock(nn.Module):
    def __init__(self, out_channel, num_res=8):
        super(EBlock, self).__init__()

        layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class DBlock(nn.Module):
    def __init__(self, channel, num_res=8):
        super(DBlock, self).__init__()

        layers = [ResBlock(channel, channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

## result
<div align=center>
<img src=./img/model_0.png />
</div>

**params:6.807M**

**flops:67.094G**

**Dataset: GoPro**


# model_1
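This model only replaces the standard convolutions in layer with depthwise-separable convolutions. Nothing else is changed.

Purpose: measure model quality and compute cost after replacing standard convolution with depthwise-separable convolution.

**Dataset: GoPro**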

In [None]:
class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 -1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
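            # Depthwise convolution: one k x k filter per input channel (groups=in_channel).
            layers.append(
                nn.Conv2d(in_channel, in_channel, kernel_size, padding=padding, groups=in_channel, stride=stride, bias=bias))
            # Pointwise 1x1 convolution mixes channels and maps them to out_channel.
            layers.append(nn.Conv2d(in_channel, out_channel, kernel_size=1))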
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

## result
<div align=center>
<img src=./img/model_1.png width="1200px"/>
</div>

**params:1.048M**

**flops:13.786G**

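**Dataset: GoPro**

Result: model quality drops, and so does the compute cost (params 6.807M to 1.048M, FLOPs 67.094G to 13.786G relative to the baseline).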
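The drop in parameters can be sanity-checked for a single layer. The sketch below counts parameters for a standard convolution and for the depthwise plus pointwise pair used in `BasicConv`. The helper names and the 64-channel, 3x3 layer are assumptions chosen for illustration, not values taken from the model.

```python
def conv_params(c_in, c_out, k, bias=True):
    """Parameter count of a standard k x k convolution."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def separable_conv_params(c_in, c_out, k, bias=True):
    """Parameter count of a depthwise k x k conv followed by a pointwise 1x1 conv."""
    depthwise = c_in * k * k + (c_in if bias else 0)
    pointwise = c_in * c_out + (c_out if bias else 0)
    return depthwise + pointwise


# Assumed layer shape for illustration: 64 -> 64 channels, 3x3 kernel.
standard = conv_params(64, 64, 3)
separable = separable_conv_params(64, 64, 3)
print(standard, separable, round(standard / separable, 2))
```

The whole-model ratio need not match a single layer's ratio: the transposed convolutions in `BasicConv` are left as standard convolutions, and channel counts differ between stages.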

# model_2
Model structure: an FFT branch is added to ResBlock.

Purpose: explore how much the FFT branch improves the model.

**Dataset: GoPro**

In [None]:
class ResBlock(nn.Module):
    # norm='backward': the forward FFT is unscaled and the inverse is scaled by 1/n.
    # norm='ortho': both the forward and inverse transforms are scaled by 1/sqrt(n).
    def __init__(self, n_feat, norm='backward'):  # or 'ortho'
        super(ResBlock, self).__init__()
        kernel_size=3
        self.main = nn.Sequential(
            BasicConv(n_feat, n_feat, kernel_size=kernel_size, stride=1, relu=True),
            BasicConv(n_feat, n_feat, kernel_size=kernel_size, stride=1, relu=False)
        )
        self.main_fft = nn.Sequential(
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True),
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False)
        )
        self.dim = n_feat
        self.norm = norm
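
    def forward(self, x):
        _, _, H, W = x.shape
        dim = 1  # channel dimension
        # Real 2D FFT over the last two dims, using the configured normalization.
        y = torch.fft.rfft2(x, norm=self.norm)
        # Stack real and imaginary parts along channels: n_feat -> 2 * n_feat.
        y_f = torch.cat([y.real, y.imag], dim=dim)
        y = self.main_fft(y_f)
        y_real, y_imag = torch.chunk(y, 2, dim=dim)
        y = torch.complex(y_real, y_imag)
        # Inverse transform back to H x W with the same normalization.
        y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
        # Spatial branch + identity + frequency branch.
        return self.main(x) + x + y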
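The `norm` setting only gives a clean round trip when the forward and inverse transforms use the same mode. The sketch below illustrates this with a small hand-written 1D DFT in plain Python. The `dft` helper is hypothetical and written only for this illustration; it follows the same `'backward'` and `'ortho'` scaling conventions as the FFT calls in the block, but it is not the library implementation.

```python
import cmath
import math


def dft(x, norm="backward", inverse=False):
    """Naive 1D DFT with 'backward' or 'ortho' scaling (illustration only)."""
    n = len(x)
    sign = 1 if inverse else -1
    if norm == "ortho":
        scale = 1 / math.sqrt(n)           # both directions scaled by 1/sqrt(n)
    elif norm == "backward":
        scale = 1 / n if inverse else 1.0  # only the inverse is scaled, by 1/n
    else:
        raise ValueError("norm must be 'backward' or 'ortho'")
    return [
        scale * sum(x[t] * cmath.exp(sign * 2j * math.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


signal = [1.0, 2.0, 0.0, -1.0]

# Matching modes: the round trip recovers the input.
for mode in ("backward", "ortho"):
    restored = dft(dft(signal, mode), mode, inverse=True)
    assert all(abs(r - s) < 1e-9 for r, s in zip(restored, signal))

# Mismatched modes: the result is off by a factor of sqrt(n).
mixed = dft(dft(signal, "backward"), "ortho", inverse=True)
print([round(v.real, 6) for v in mixed])
```

With mismatched modes the frequency branch would be rescaled relative to the other two terms in the residual sum, so a single setting should drive both calls.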

## result
<div align=center>
<img src=./img/model_2.png width="1200px"/>
</div>

**params:3.843M**

**flops:27.263G**

**Dataset: GoPro**

Result: relative to model_1, PSNR rises by 0.4 and FLOPs roughly double (13.786G to 27.263G).

# model_3
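Model structure: relative to model_2, the kernel sizes of EBlock, DBlock, and feat_extract are increased.

Purpose: inspired by the paper "Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs", verify how depthwise-separable convolution performs when paired with large kernels.

**Dataset: GoPro**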

In [None]:
class ResBlock(nn.Module):
    # norm='backward': the forward FFT is unscaled and the inverse is scaled by 1/n.
    # norm='ortho': both the forward and inverse transforms are scaled by 1/sqrt(n).
    def __init__(self, n_feat, norm='backward'):  # or 'ortho'
        super(ResBlock, self).__init__()
        kernel_size=7
        self.main = nn.Sequential(
            BasicConv(n_feat, n_feat, kernel_size=kernel_size, stride=1, relu=True),
            BasicConv(n_feat, n_feat, kernel_size=kernel_size, stride=1, relu=False)
        )
        self.main_fft = nn.Sequential(
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True),
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False)
        )
        self.dim = n_feat
        self.norm = norm
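
    def forward(self, x):
        _, _, H, W = x.shape
        dim = 1  # channel dimension
        # Real 2D FFT over the last two dims, using the configured normalization.
        y = torch.fft.rfft2(x, norm=self.norm)
        # Stack real and imaginary parts along channels: n_feat -> 2 * n_feat.
        y_f = torch.cat([y.real, y.imag], dim=dim)
        y = self.main_fft(y_f)
        y_real, y_imag = torch.chunk(y, 2, dim=dim)
        y = torch.complex(y_real, y_imag)
        # Inverse transform back to H x W with the same normalization.
        y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
        # Spatial branch + identity + frequency branch.
        return self.main(x) + x + y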

kernel_size=7
base_channel=32
feat_extract = nn.ModuleList([
        BasicConv(3, base_channel, kernel_size=kernel_size, relu=True, stride=1),
        BasicConv(base_channel, base_channel * 2, kernel_size=kernel_size, relu=True, stride=2),
        BasicConv(base_channel * 2, base_channel * 4, kernel_size=kernel_size, relu=True, stride=2),
        BasicConv(base_channel * 4, base_channel * 2, kernel_size=kernel_size+1, relu=True, stride=2, transpose=True),
        BasicConv(base_channel * 2, base_channel, kernel_size=kernel_size+1, relu=True, stride=2, transpose=True),
        BasicConv(base_channel, 3, kernel_size=kernel_size, relu=False, stride=1)
        ])

## result
<div align=center>
<img src=./img/model_3.png width="1200px"/>
</div>

**params:4.627M**

**flops:44.968G**


Result: relative to model_2, PSNR rises by 0.2 and FLOPs rise by about 17.7G (27.263G to 44.968G).
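A rough per-layer estimate shows why a large kernel is affordable once the convolution is depthwise-separable. The sketch below counts multiply-accumulate operations (MACs) for one layer. The 128x128 feature map, the 64 channels, and the helper names are assumptions for illustration; the FLOPs figures reported above come from the full model and will not match these per-layer numbers.

```python
def conv_macs(h_out, w_out, c_in, c_out, k, groups=1):
    """Multiply-accumulates of a k x k convolution (bias ignored)."""
    return h_out * w_out * k * k * (c_in // groups) * c_out


def separable_macs(h_out, w_out, c_in, c_out, k):
    """Depthwise k x k conv (groups = c_in) followed by a pointwise 1x1 conv."""
    depthwise = conv_macs(h_out, w_out, c_in, c_in, k, groups=c_in)
    pointwise = conv_macs(h_out, w_out, c_in, c_out, 1)
    return depthwise + pointwise


# Assumed layer shape: 128 x 128 feature map, 64 -> 64 channels.
h = w = 128
c = 64

sep3 = separable_macs(h, w, c, c, 3)
sep7 = separable_macs(h, w, c, c, 7)
std3 = conv_macs(h, w, c, c, 3)
std7 = conv_macs(h, w, c, c, 7)
print(f"separable 3x3: {sep3:,}  separable 7x7: {sep7:,}")
print(f"standard  3x3: {std3:,}  standard  7x7: {std7:,}")
```

Under these assumptions the separable 7x7 layer still costs less than a standard 3x3 layer, because only the depthwise part scales with the kernel area while the pointwise 1x1 cost is unchanged.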

# model_10
Model structure: ResBlock, Wavelet_transform, DB_kernel=7, EB_kernel=7, feat_extract_kernel=3
<div align=center>
<img src=./img/model_10_DB.png />
</div>

**Dataset: REDS**

In [None]:
# Wavelet transform (Haar DWT and its inverse)
class DWT(nn.Module):
    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False  

    def forward(self, x):
        return self.dwt_init(x)

    def dwt_init(self,x):
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 0)


class IWT(nn.Module):
    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        return self.iwt_init(x)

    def iwt_init(self,x):
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        #print([in_batch, in_channel, in_height, in_width])
        out_batch, out_channel, out_height, out_width = int(in_batch / r ** 2), int(
            in_channel), r * in_height, r * in_width
        x1 = x[0:out_batch, :, :, :] / 2
        x2 = x[out_batch:out_batch * 2, :, :, :] / 2
        x3 = x[out_batch * 2:out_batch * 3, :, :, :] / 2
        x4 = x[out_batch * 3:out_batch * 4, :, :, :] / 2

        # Allocate on the input's device and dtype instead of hard-coding .cuda().
        h = torch.zeros([out_batch, out_channel, out_height, out_width], dtype=x.dtype, device=x.device)

        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
        return h

class Wavelet_transform(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(Wavelet_transform,self).__init__()
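        # BasicConv requires stride; stride=1 keeps the spatial size of the sub-bands.
        self.conv1 = BasicConv(in_channel, 64, kernel_size=5, stride=1)
        self.conv2 = BasicConv(64, 32, kernel_size=3, stride=1)
        self.conv3 = BasicConv(32, out_channel, kernel_size=3, stride=1)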
        self.relu = nn.ReLU(inplace=True)
        self.DWT = DWT()
        self.IDWT = IWT()

    def forward(self, x):
        x = self.DWT(x)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        x = self.IDWT(x)
        return x
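The `DWT` and `IWT` modules form an exact inverse pair, which can be checked without a GPU. The sketch below repeats the same sub-band arithmetic on a plain 2D list. The function names and the 4x4 test image are assumptions for illustration; unlike the modules above, which stack the four sub-bands along the batch dimension, this version returns them as a tuple.

```python
def haar_dwt(img):
    """Split a 2D list (even height and width) into LL, HL, LH, HH sub-bands."""
    h, w = len(img) // 2, len(img[0]) // 2
    ll = [[0.0] * w for _ in range(h)]
    hl = [[0.0] * w for _ in range(h)]
    lh = [[0.0] * w for _ in range(h)]
    hh = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            x1 = img[2 * i][2 * j] / 2          # even row, even column
            x2 = img[2 * i + 1][2 * j] / 2      # odd row, even column
            x3 = img[2 * i][2 * j + 1] / 2      # even row, odd column
            x4 = img[2 * i + 1][2 * j + 1] / 2  # odd row, odd column
            ll[i][j] = x1 + x2 + x3 + x4
            hl[i][j] = -x1 - x2 + x3 + x4
            lh[i][j] = -x1 + x2 - x3 + x4
            hh[i][j] = x1 - x2 - x3 + x4
    return ll, hl, lh, hh


def haar_iwt(ll, hl, lh, hh):
    """Rebuild the full-resolution image from the four sub-bands."""
    h, w = len(ll), len(ll[0])
    out = [[0.0] * (2 * w) for _ in range(2 * h)]
    for i in range(h):
        for j in range(w):
            x1, x2, x3, x4 = ll[i][j] / 2, hl[i][j] / 2, lh[i][j] / 2, hh[i][j] / 2
            out[2 * i][2 * j] = x1 - x2 - x3 + x4
            out[2 * i + 1][2 * j] = x1 - x2 + x3 - x4
            out[2 * i][2 * j + 1] = x1 + x2 - x3 - x4
            out[2 * i + 1][2 * j + 1] = x1 + x2 + x3 + x4
    return out


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
bands = haar_dwt(image)
restored = haar_iwt(*bands)
print(restored == image)
```

Each sub-band has half the height and width of the input, so the convolutions in `Wavelet_transform` run on a quarter of the pixels per sub-band before the inverse transform restores the original resolution.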

In [None]:
class EBlock(nn.Module):
    def __init__(self, out_channel, num_res=8):
        super(EBlock, self).__init__()

        layers = [ResBlock(out_channel,out_channel,kernel_size=7) for _ in range(num_res)]
        self.wavelet=Wavelet_transform(out_channel,out_channel)

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        #return self.layers(x)
        return self.layers(x)+self.wavelet(x)


class DBlock(nn.Module):
    def __init__(self, channel, num_res=8):
        super(DBlock, self).__init__()

        layers = [ResBlock(channel,channel,kernel_size=7) for _ in range(num_res)]
        self.wavelet=Wavelet_transform(channel,channel)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        #return self.layers(x)
        return self.layers(x)+self.wavelet(x)

kernel_size=3
base_channel=32
feat_extract = nn.ModuleList([
        BasicConv(3, base_channel, kernel_size=kernel_size, relu=True, stride=1),
        BasicConv(base_channel, base_channel * 2, kernel_size=kernel_size, relu=True, stride=2),
        BasicConv(base_channel * 2, base_channel * 4, kernel_size=kernel_size, relu=True, stride=2),
        BasicConv(base_channel * 4, base_channel * 2, kernel_size=kernel_size+1, relu=True, stride=2, transpose=True),
        BasicConv(base_channel * 2, base_channel, kernel_size=kernel_size+1, relu=True, stride=2, transpose=True),
        BasicConv(base_channel, 3, kernel_size=kernel_size, relu=False, stride=1)
        ])

## result

<div align=center>
<img src=./img/model_10_REDS.png width="1200px"/>
</div>

**params:1.408M**

**flops:19.920G**

<div>
<img src=./img/model_10/90/patch_0_0.png width="600px" />
<img src=./img/model_10/120/patch_0_0.png width="600px" />
</div>

Defect: artifacts appear in the images.

<div>
<img src=./img/model_10/90/patch_8_5.png width="600px" />
<img src=./img/model_10/120/patch_8_5.png width="600px" />
</div>

## model_10, second training run
Purpose: switch datasets and check the results on the RealBlur dataset.

Dataset: RealBlur

<div align=center>
<img src=./img/model_10_realBlur.png width="1200px"/>
</div>

<img src=./img/model_10_realBlur/350/patch_0_0.png width="600px" />

# model_8
Dataset: REDS_JPEG

Dropout is added to MIMO. Reason: the gap between the VIVO dataset and the REDS dataset causes artifacts in some images.


## result

<div align=center>
<img src=./img/model_8.png width="1200px"/>
</div>

<div>
<img src=./img/model_8/140/patch_0_0.png width="600px" />
<img src=./img/model_8/140/patch_8_5.png width="600px" />
</div>



# Summary

So far the best results are the images generated by model_10 trained on REDS, at epoch 90.

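Restormer: replace MDTA with GDFN; measure FLOPs first.

MIMO AFF: replace the 1x1 and 3x3 convolutions with GDFN and measure FLOPs; no wavelet transform; depthwise-separable convolution.

The dataset is REDS_JPEG.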