In [1]:
import torch
import torch.nn as nn

In [2]:


class AdaptiveInstanceNorm(nn.Module):
    """Learnable blend of an input and its instance-normalised version.

    The layer computes ``w_0 * x + w_1 * InstanceNorm2d(x)``. Both mixing
    coefficients are single-element learnable parameters; ``w_0`` starts at
    1 and ``w_1`` at 0, so the layer starts out as an identity mapping.

    Args:
        n: Number of feature channels, passed to ``nn.InstanceNorm2d``.
    """

    def __init__(self, n):
        super().__init__()

        # Mixing coefficients, initialised so that the output equals the input.
        identity_gain = torch.Tensor([1.0])
        norm_gain = torch.Tensor([0.0])
        self.w_0 = nn.Parameter(identity_gain)
        self.w_1 = nn.Parameter(norm_gain)

        self.ins_norm = nn.InstanceNorm2d(
            num_features=n,
            eps=0.001,
            momentum=0.999,
            affine=True,
        )

    def forward(self, x):
        """Return ``w_0 * x + w_1 * ins_norm(x)``."""
        passthrough = self.w_0 * x
        normalised = self.ins_norm(x)
        return passthrough + self.w_1 * normalised


class PALayer(nn.Module):
    """Spatial gate that rescales every position of a feature map.

    Two 1x1 convolutions (with a LeakyReLU(0.2) in between) reduce the input
    to a single-channel map, a sigmoid squashes it into (0, 1), and the input
    is multiplied by that map (broadcast across channels). The name
    presumably stands for "pixel attention".

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor applied to ``channel`` to obtain the hidden width
            of the bottleneck. Defaults to 8, the value that was previously
            hard-coded.

    Raises:
        ValueError: If ``channel`` or ``reduction`` is smaller than 1.
    """

    def __init__(self, channel: int, reduction: int = 8):
        super(PALayer, self).__init__()
        if channel < 1 or reduction < 1:
            raise ValueError(
                f"channel and reduction must be >= 1, got channel={channel}, reduction={reduction}"
            )

        # Floor division gives 0 when channel < reduction, which would build a
        # zero-width bottleneck; keep at least one hidden channel.
        hidden = max(channel // reduction, 1)

        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its single-channel sigmoid gate."""
        y = self.pa(x)
        return x * y


class CALayer(nn.Module):
    """Per-channel gate that rescales every channel of a feature map.

    Each channel is average-pooled to a single value, passed through two 1x1
    convolutions (with a LeakyReLU(0.2) in between) that squeeze and restore
    the channel count, and squashed by a sigmoid into (0, 1). The input is
    then multiplied by the resulting per-channel weights (broadcast across
    spatial positions). The name presumably stands for "channel attention".

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor applied to ``channel`` to obtain the hidden width
            of the bottleneck. Defaults to 8, the value that was previously
            hard-coded.

    Raises:
        ValueError: If ``channel`` or ``reduction`` is smaller than 1.
    """

    def __init__(self, channel: int, reduction: int = 8):
        super(CALayer, self).__init__()
        if channel < 1 or reduction < 1:
            raise ValueError(
                f"channel and reduction must be >= 1, got channel={channel}, reduction={reduction}"
            )

        # Floor division gives 0 when channel < reduction, which would build a
        # zero-width bottleneck; keep at least one hidden channel.
        hidden = max(channel // reduction, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel sigmoid gate."""
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

In [3]:
class RDB(nn.Module):
    """Residual block made of four densely connected dilated 3x3 convolutions.

    The input is split along dim 1 into four equal channel groups. Each group
    is concatenated with the outputs of all previous stages and passed
    through a 3x3 convolution with dilation 1, 2, 4 and 8 respectively
    (padding equals dilation, so spatial size is preserved), followed by a
    ReLU. The four stage outputs are concatenated, fused by a 1x1
    convolution, gated by ``CALayer`` and ``PALayer``, and finally added to
    the original input.

    Args:
        in_channels: Number of channels of the input feature map. Must be a
            positive multiple of 4 so it can be split into four equal groups.
        num_dense_layer: Unused; kept so existing call sites keep working.
        growth_rate: Unused; kept so existing call sites keep working.

    Raises:
        ValueError: If ``in_channels`` is not a positive multiple of 4.
    """

    def __init__(self, in_channels, num_dense_layer, growth_rate):
        super(RDB, self).__init__()

        # forward() consumes exactly four equally sized channel groups; any
        # other channel count would only fail later with an opaque shape error.
        if in_channels < 4 or in_channels % 4 != 0:
            raise ValueError(
                f"in_channels must be a positive multiple of 4, got {in_channels}"
            )

        self.split_channel = in_channels // 4
        kernel_size = 3

        def dilated_conv(groups_in, dilation):
            # padding == dilation keeps height and width unchanged for a 3x3 kernel.
            return nn.Conv2d(
                self.split_channel * groups_in,
                self.split_channel,
                kernel_size=kernel_size,
                padding=dilation,
                dilation=dilation,
            )

        # Stage k receives one fresh input group plus the k-1 earlier outputs.
        self.conv1 = dilated_conv(1, 1)
        self.conv2 = dilated_conv(2, 2)
        self.conv3 = dilated_conv(3, 4)
        self.conv4 = dilated_conv(4, 8)

        self.calayer = CALayer(in_channels)
        self.palayer = PALayer(in_channels)
        self.conv_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape.

        ``x`` must carry its channels in dim 1 and have the channel count the
        block was built for.
        """
        splited = torch.split(x, self.split_channel, dim=1)

        # torch.relu is used instead of F.relu so the class does not depend on
        # an import that happens in a later notebook cell.
        x0 = torch.relu(self.conv1(splited[0]))
        x1 = torch.relu(self.conv2(torch.cat((splited[1], x0), 1)))
        x2 = torch.relu(self.conv3(torch.cat((splited[2], x0, x1), 1)))
        x3 = torch.relu(self.conv4(torch.cat((splited[3], x0, x1, x2), 1)))

        out = self.conv_1x1(torch.cat((x0, x1, x2, x3), 1))
        out = self.calayer(out)
        out = self.palayer(out)

        # Residual connection.
        return out + x

In [4]:
import torch.nn.functional as F

# Smoke test: push one random feature map through an RDB and look at the
# shape that comes back. Tensor layout is (batch_size, channels, height, width).
input_tensor = torch.randn((1, 64, 128, 128))

# Build the block for the same channel count as the input.
rdb = RDB(
    in_channels=64,
    num_dense_layer=4,
    growth_rate=32,
)

# Forward pass through the block.
output_tensor = rdb(input_tensor)

# Show the shape of the result.
print(output_tensor.shape)

torch.Size([1, 64, 128, 128])
