In [1]:
import torch

# Example input: a 4-D tensor laid out as (mini-batch, channels, height, width).
x = torch.randn(1, 3, 16, 16)

# Resample x to an explicit 32x32 spatial size with bilinear interpolation;
# the batch and channel dimensions are left as they are.
upsampled_x = torch.nn.functional.interpolate(
    input=x,
    size=(32, 32),
    mode='bilinear',
    align_corners=False,
)
upsampled_x.shape


torch.Size([1, 3, 32, 32])

In [3]:
from torch import nn


class Upsample(nn.Module):
    """Double the spatial resolution with a depthwise/pointwise conv pair and PixelShuffle.

    The pointwise convolution expands to ``out_feat * 4`` channels so that
    ``nn.PixelShuffle(2)`` can rearrange them into ``out_feat`` channels at
    twice the height and width.

    Args:
        input_feat: Number of channels of the input feature map.
        out_feat: Number of channels of the output feature map.
    """

    def __init__(self, input_feat, out_feat):
        super().__init__()

        # Depthwise 3x3 convolution: one group per input channel, size-preserving.
        depthwise = nn.Conv2d(
            input_feat,
            input_feat,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=input_feat,
            bias=False,
        )
        # Pointwise 1x1 convolution with no activation after it.
        pointwise = nn.Conv2d(
            input_feat,
            out_feat * 4,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.body = nn.Sequential(depthwise, pointwise, nn.PixelShuffle(2))

    def forward(self, x):
        """Return ``x`` with ``out_feat`` channels and doubled height and width."""
        return self.body(x)


In [7]:
# Random 4-D tensor of shape (1, 16, 8, 8); the Upsample and Downsample modules
# in this notebook are both applied to it.
y = torch.randn(1, 16, 8, 8)


In [8]:
# 16 -> 32 channels; PixelShuffle(2) doubles the height and width of y.
upsample = Upsample(input_feat=16, out_feat=32)
result = upsample(y)
result.shape

torch.Size([1, 32, 16, 16])

In [9]:
class Downsample(nn.Module):
    """Halve the spatial resolution with a depthwise/pointwise conv pair and PixelUnshuffle.

    The pointwise convolution emits ``out_feat // 4`` channels;
    ``nn.PixelUnshuffle(2)`` then multiplies the channel count by 4 while
    halving height and width. The output therefore has exactly ``out_feat``
    channels only when ``out_feat`` is a multiple of 4.

    Args:
        input_feat: Number of channels of the input feature map.
        out_feat: Number of channels of the output feature map. Must be a
            positive multiple of 4.

    Raises:
        ValueError: If ``out_feat`` is not a positive multiple of 4. Without
            this check the floor division silently yields a different channel
            count than requested (for example 10 -> 8).
    """

    def __init__(self, input_feat, out_feat):
        super(Downsample, self).__init__()

        if out_feat <= 0 or out_feat % 4 != 0:
            raise ValueError(
                f"out_feat must be a positive multiple of 4 for PixelUnshuffle(2), got {out_feat}"
            )

        self.body = nn.Sequential(
            # Depthwise 3x3 convolution: one group per input channel, size-preserving.
            nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False),
            # Pointwise 1x1 convolution with no activation after it; PixelUnshuffle
            # restores the channel count to out_feat.
            nn.Conv2d(input_feat, out_feat // 4, 1, 1, 0, bias=False),
            nn.PixelUnshuffle(2),
        )

    def forward(self, x):
        """Return ``x`` with ``out_feat`` channels and halved height and width."""
        return self.body(x)


In [10]:
# 16 -> 8 channels; PixelUnshuffle(2) halves the height and width of y.
downsample = Downsample(input_feat=16, out_feat=8)
result_2 = downsample(y)
result_2.shape

torch.Size([1, 8, 4, 4])

In [11]:
# Per-stage counts (presumably the number of blocks in each stage); their sum
# fixes how many values are generated below.
depths = [2, 2, 18, 2]
# sum(depths) evenly spaced Python floats running from 0 up to 0.5 inclusive.
dpr = torch.linspace(0, 0.5, steps=sum(depths)).tolist()
dpr

[0.0,
 0.021739130839705467,
 0.043478261679410934,
 0.06521739065647125,
 0.08695652335882187,
 0.10869565606117249,
 0.1304347813129425,
 0.15217392146587372,
 0.17391304671764374,
 0.19565217196941376,
 0.21739131212234497,
 0.239130437374115,
 0.260869562625885,
 0.28260868787765503,
 0.30434781312942505,
 0.32608693838119507,
 0.3478260636329651,
 0.3695652186870575,
 0.3913043439388275,
 0.41304346919059753,
 0.43478262424468994,
 0.45652174949645996,
 0.47826087474823,
 0.5]

In [12]:
# Sum over the empty prefix of depths, which evaluates to 0.
sum(depths[:0])

0

In [13]:
# Sum over the first entry of depths only.
sum(depths[:1])

2

In [14]:
# The first two entries of dpr (indices 0 and 1).
dpr[0:2]

[0.0, 0.021739130839705467]

In [15]:
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation, with no activation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size.
        stride: Convolution stride. Defaults to 1.
        padding: Convolution padding. Defaults to 0.
        dilation: Convolution dilation. Defaults to 1.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,  # BatchNorm directly after supplies its own shift term
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Apply the convolution, then batch normalisation."""
        return self.bn(self.conv(x))

In [83]:
# Project a random 1024-channel map down to 256 channels with a 1x1
# BasicConv2d (stride, padding and dilation stay at their defaults);
# x is rebound to the projected result.
x = torch.randn(1, 1024, 64, 64)
conv1 = BasicConv2d(in_planes=1024, out_planes=256, kernel_size=1)
x = conv1(x)
x.shape

torch.Size([1, 256, 64, 64])

In [68]:
# Bilinear 2x upsampling layer (align_corners=True), applied to whatever `x`
# currently holds in the kernel.
# NOTE(review): the recorded output below is (1, 256, 64, 64), which is not
# twice the (1, 256, 64, 64) shape that the preceding cell leaves in `x`. The
# execution counts are also out of order (In [68] here vs In [83] before), so
# this output presumably came from an earlier `x` -- re-run top to bottom to
# confirm.
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

y = up(x)
y.shape


torch.Size([1, 256, 64, 64])

In [60]:
# Display the whole tensor y (large output; `y.shape` is the compact alternative).
y

tensor([[[[ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          ...,
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172]],

         [[-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          ...,
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098]],

         [[ 0.0138,  0.0138,  0.0138,  ...,  0.0138,  0.0138,  0.0138],
          [ 0.0138,  0.0138,  

In [79]:
# Four independent random tensors, each of shape (1, 64, 64, 256), drawn in
# the order x, x_2, x_5, x_6.
x, x_2, x_5, x_6 = (torch.randn(1, 64, 64, 256) for _ in range(4))
# Join them end to end along axis 3, giving 4 * 256 = 1024 along that axis.
result = torch.cat((x, x_2, x_5, x_6), dim=3)
result.shape

torch.Size([1, 64, 64, 1024])

In [86]:
from torchsummary import summary
import torch.nn.functional as F

class Enhancer(nn.Module):
    """Refinement head that fuses features with multi-scale average-pooled context.

    Two 3x3 convolutions lift the input to 20 feature channels. Those features
    are average-pooled with windows of 32, 16, 8 and 4, each pooled map is
    reduced to one channel by its own 1x1 convolution, resized back to the
    feature resolution, and concatenated with the 20 features. A final 3x3
    convolution followed by tanh produces the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.

    Note:
        The largest pooling window is 32, so the input needs a height and
        width of at least 32.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Activations shared by all layers; the LeakyReLU operates in place.
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.tanh = nn.Tanh()

        # Size-preserving 3x3 convolutions producing the 20-channel features.
        self.refine1 = nn.Conv2d(in_channels, 20, kernel_size=3, stride=1, padding=1)
        self.refine2 = nn.Conv2d(20, 20, kernel_size=3, stride=1, padding=1)

        # One 1x1 convolution per pooling scale, each collapsing 20 channels to 1.
        self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)
        self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)
        self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)
        self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)

        # Fuses the 20 features with the 4 single-channel context maps.
        self.refine3 = nn.Conv2d(20 + 4, out_channels, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.functional.interpolate

        # Registered but never called in forward(); it still contributes
        # parameters to the module's state_dict.
        self.batch1 = nn.InstanceNorm2d(100, affine=True)

    def forward(self, x):
        """Return the refined tensor; values lie in [-1, 1] because of the final tanh."""
        features = self.relu(self.refine1(x))
        features = self.relu(self.refine2(features))
        spatial_size = features.shape[2:4]

        # (pooling window, 1x1 conv) pairs, coarsest scale first.
        branches = (
            (32, self.conv1010),
            (16, self.conv1020),
            (8, self.conv1030),
            (4, self.conv1040),
        )
        pooled = [F.avg_pool2d(features, window) for window, _ in branches]
        context = [
            self.upsample(self.relu(conv(pooled_map)), size=spatial_size)
            for (_, conv), pooled_map in zip(branches, pooled)
        ]

        fused = torch.cat(context + [features], 1)
        return self.tanh(self.refine3(fused))
    

In [88]:
# Print a layer-by-layer summary for a single 3x256x256 input on the CPU.
enhancer = Enhancer(in_channels=3, out_channels=3)
summary(enhancer, input_size=(3, 256, 256), batch_size=1, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 20, 256, 256]             560
         LeakyReLU-2          [1, 20, 256, 256]               0
            Conv2d-3          [1, 20, 256, 256]           3,620
         LeakyReLU-4          [1, 20, 256, 256]               0
            Conv2d-5               [1, 1, 8, 8]              21
         LeakyReLU-6               [1, 1, 8, 8]               0
            Conv2d-7             [1, 1, 16, 16]              21
         LeakyReLU-8             [1, 1, 16, 16]               0
            Conv2d-9             [1, 1, 32, 32]              21
        LeakyReLU-10             [1, 1, 32, 32]               0
           Conv2d-11             [1, 1, 64, 64]              21
        LeakyReLU-12             [1, 1, 64, 64]               0
           Conv2d-13           [1, 3, 256, 256]             651
             Tanh-14           [1, 3, 2

In [96]:
# A `!set ...` shell escape runs in a separate, short-lived subshell, so the
# variable never reaches this kernel's own environment. Assign it through
# os.environ instead so that code imported afterwards in this process can see it.
import os

os.environ["BASICSR_JIT"] = "True"
from basicsr.archs.module import ResBlock, WarpBlock


class RIDCPDecoder(nn.Module):
    """Chain of ``max_depth`` 2x upsampling stages.

    Each stage is ``nn.Upsample(scale_factor=2)`` followed by a 3x3 convolution
    and two ``ResBlock`` modules. Channel counts come from
    ``channel_query_dict``, starting at resolution key
    ``input_res // 2 ** max_depth`` and doubling the key after every stage.

    Args:
        in_channel: Accepted but never read; every stage takes its channel
            counts from ``channel_query_dict``.
        max_depth: Number of upsampling stages to build.
        input_res: Resolution used to derive the first lookup key.
        channel_query_dict: Mapping from resolution key to channel count. It
            is indexed directly, so the default ``None`` cannot be used when
            ``max_depth`` is greater than 0.
        norm_type: Passed through to every ``ResBlock``.
        act_type: Passed through to every ``ResBlock``.
        only_residual: Stored on the instance; not consulted by ``forward``.
        use_warp: Stored on the instance; not consulted by ``forward``.
    """

    def __init__(self,
                 in_channel,
                 max_depth,
                 input_res=256,
                 channel_query_dict=None,
                 norm_type='gn',
                 act_type='leakyrelu',
                 only_residual=False,
                 use_warp=True
                 ):
        super().__init__()
        self.only_residual = only_residual
        self.use_warp = use_warp
        self.upsampler = nn.ModuleList()
        # One WarpBlock per stage is registered here, but forward() does not call them.
        self.warp = nn.ModuleList()

        stage_res = input_res // (2 ** max_depth)
        for _ in range(max_depth):
            stage_in = channel_query_dict[stage_res]
            stage_out = channel_query_dict[stage_res * 2]
            stage = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(stage_in, stage_out, 3, stride=1, padding=1),
                ResBlock(stage_out, stage_out, norm_type, act_type),
                ResBlock(stage_out, stage_out, norm_type, act_type),
            )
            self.upsampler.append(stage)
            self.warp.append(WarpBlock(stage_out))
            stage_res *= 2

    def forward(self, x):
        """Run ``x`` through every upsampling stage in order and return the result.

        cuDNN is disabled while each stage executes.
        """
        for stage in self.upsampler:
            with torch.backends.cudnn.flags(enabled=False):
                x = stage(x)
        return x


In [115]:
# Two upsampling stages built from the resolution -> channel table below,
# moved to the GPU and summarised for a single 256x128x128 input.
ridcp = RIDCPDecoder(
    in_channel=3,
    max_depth=2,
    input_res=256,
    channel_query_dict={
        8: 256, 16: 256, 32: 256, 64: 256,
        128: 128, 256: 64, 512: 32,
    },
    norm_type='gn',
    act_type='silu',
    only_residual=False,
    use_warp=True,
).cuda()

summary(ridcp, input_size=(256, 128, 128), batch_size=1, device="cuda")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1         [1, 256, 256, 256]               0
            Conv2d-2         [1, 128, 256, 256]         295,040
         GroupNorm-3         [1, 128, 256, 256]             256
         NormLayer-4         [1, 128, 256, 256]               0
              SiLU-5         [1, 128, 256, 256]               0
          ActLayer-6         [1, 128, 256, 256]               0
            Conv2d-7         [1, 128, 256, 256]         147,584
         GroupNorm-8         [1, 128, 256, 256]             256
         NormLayer-9         [1, 128, 256, 256]               0
             SiLU-10         [1, 128, 256, 256]               0
         ActLayer-11         [1, 128, 256, 256]               0
           Conv2d-12         [1, 128, 256, 256]         147,584
         ResBlock-13         [1, 128, 256, 256]               0
        GroupNorm-14         [1, 128, 2

In [122]:
class PALayer(nn.Module):
    """Pixel attention: scale every spatial position by a learned gate in (0, 1).

    A 1x1 convolution squeezes ``channel`` to ``channel // 8`` channels, a
    second 1x1 convolution reduces that to a single channel, and a sigmoid
    turns it into a per-position gate that is broadcast over all channels.

    Args:
        channel: Number of channels of the input feature map.
    """

    def __init__(self, channel):
        super().__init__()
        squeezed = channel // 8
        self.pa = nn.Sequential(
            nn.Conv2d(channel, squeezed, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, 1, kernel_size=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its single-channel attention map."""
        gate = self.pa(x)
        return x * gate


class CALayer(nn.Module):
    """Channel attention: scale every channel by a learned gate in (0, 1).

    The input is globally average-pooled to 1x1, squeezed to ``channel // 8``
    channels and expanded back to ``channel`` by two 1x1 convolutions, and
    passed through a sigmoid; the resulting per-channel gate is broadcast over
    all spatial positions.

    Args:
        channel: Number of channels of the input feature map.
    """

    def __init__(self, channel):
        super().__init__()
        squeezed = channel // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, squeezed, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, kernel_size=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention weights."""
        gate = self.ca(self.avg_pool(x))
        return x * gate


class DehazeBlock(nn.Module):
    """Residual block: conv -> ReLU -> skip -> conv -> channel attention -> pixel attention -> skip.

    Args:
        conv: Callable used to build both convolutions; it is invoked as
            ``conv(dim, dim, kernel_size, bias=True)``.
        dim: Number of channels, kept constant through the block.
        kernel_size: Kernel size handed to ``conv``.

    Note:
        Both skip connections add the block input to an intermediate result,
        so ``conv`` must preserve the spatial size.
        NOTE(review): this notebook passes ``nn.Conv2d`` with ``kernel_size=1``;
        plain ``nn.Conv2d`` is unpadded by default, so a larger kernel would
        shrink the map and break the additions -- presumably a padded conv
        factory is expected for other kernel sizes; confirm.
    """

    def __init__(self, conv, dim, kernel_size):
        super().__init__()

        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        """Return the attended features plus the block input ``x``."""
        # Inner skip connection around the first conv + ReLU.
        inner = self.act1(self.conv1(x)) + x
        # Second conv, then channel attention followed by pixel attention.
        attended = self.palayer(self.calayer(self.conv2(inner)))
        # Outer skip connection, accumulated in place on the attention output.
        attended += x
        return attended


In [123]:
# 256-channel DehazeBlock built from 1x1 nn.Conv2d layers, moved to the GPU and
# summarised for a single 256x64x64 input.
dehaze_block = DehazeBlock(conv=nn.Conv2d, dim=256, kernel_size=1).cuda()
summary(dehaze_block, input_size=(256, 64, 64), batch_size=1, device="cuda")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [1, 256, 64, 64]          65,792
              ReLU-2           [1, 256, 64, 64]               0
            Conv2d-3           [1, 256, 64, 64]          65,792
 AdaptiveAvgPool2d-4             [1, 256, 1, 1]               0
            Conv2d-5              [1, 32, 1, 1]           8,224
              ReLU-6              [1, 32, 1, 1]               0
            Conv2d-7             [1, 256, 1, 1]           8,448
           Sigmoid-8             [1, 256, 1, 1]               0
           CALayer-9           [1, 256, 64, 64]               0
           Conv2d-10            [1, 32, 64, 64]           8,224
             ReLU-11            [1, 32, 64, 64]               0
           Conv2d-12             [1, 1, 64, 64]              33
          Sigmoid-13             [1, 1, 64, 64]               0
          PALayer-14           [1, 256,

In [116]:
# Lookup table from resolution key to channel count.
channel_query_dict = {
    8: 256, 16: 256, 32: 256, 64: 256,
    128: 128, 256: 64, 512: 32,
}
# Starting resolution key for input_res=256 and max_depth=2.
res = 256 // 2 ** 2
res

64

In [117]:
# Channel counts looked up at the current `res` and at twice that resolution.
in_channel, out_channel = channel_query_dict[res], channel_query_dict[res * 2]

In [118]:
# Display the current (in_channel, out_channel) pair.
in_channel, out_channel

(256, 128)

In [119]:
# Repeat the lookup with the resolution key set to 128.
res = 128
in_channel, out_channel = channel_query_dict[res], channel_query_dict[res * 2]

In [120]:
# Display the (in_channel, out_channel) pair for res = 128.
in_channel, out_channel

(128, 64)