In [1]:
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch
# from torchvision.models import vgg19
import math

class Generator(nn.Module):
    r"""The main architecture of the generator."""

    def __init__(self, coarse_dim, fine_dim, nc, n_predictands):
        r"""This is an esrgan model defined by the author himself."""
        super(Generator, self).__init__()
        # First layer.
        self.coarse_dim = coarse_dim
        self.fine_dim = fine_dim
        self.nc = nc
        self.n_predictands = n_predictands
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.nc, self.coarse_dim, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Residual blocks.
        residual_blocks = []
        for _ in range(16):
            residual_blocks.append(ResidualBlock(self.coarse_dim))
        self.Trunk = nn.Sequential(*residual_blocks)

        # Second conv layer post residual blocks.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                self.coarse_dim,
                self.coarse_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.coarse_dim),
        )

        # 2 Upsampling layers.
        upsampling = []
        for _ in range(3):
            upsampling.append(UpsampleBlock(self.coarse_dim))

        self.upsampling = nn.Sequential(*upsampling)
        
#         Final output layer.
        self.convi = nn.Conv2d(
            self.coarse_dim, self.n_predictands, kernel_size=9, stride=1, padding=4
        )
#         self.convi = nn.Sequential(
#             nn.Conv2d(self.n_predictands, self.fine_dim, kernel_size=9, stride=1, padding=4),
#             nn.PReLU(),
#         )

        
        # Residual blocks.
        post_up_residual_blocks = []
        for _ in range(3):
            post_up_residual_blocks.append(ResidualBlock(self.n_predictands))
        self.post_up_Trunk = nn.Sequential(*post_up_residual_blocks)
        
#             Second conv layer post residual blocks.
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                self.n_predictands,
                self.n_predictands,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.n_predictands),
        )
        
#         Final output layer.
        self.conv4 = nn.Conv2d(
            self.n_predictands, self.n_predictands, kernel_size=9, stride=1, padding=4
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out1 = self.conv1(input)
        out = self.Trunk(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out1 = self.convi(out)
        out = self.post_up_Trunk(out1)
        out2 = self.conv3(out)
        out = torch.add(out1, out2)
        out = self.conv4(out)
        # out = self.sig(out)
        return out

    def sample_coarse(self, input: torch.Tensor) -> torch.Tensor:
        return input


class UpsampleBlock(nn.Module):
    r"""Main upsample block structure"""

    def __init__(self, channels):
        r"""Initializes internal Module state, shared by both nn.Module and ScriptModule.
        Args:
            channels (int): Number of channels in the input image.
        """
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(
            channels,
            channels * 4,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
        self.prelu = nn.PReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.conv(input)
        out = self.pixel_shuffle(out)
        out = self.prelu(out)

        return out


class ResidualBlock(nn.Module):
    r"""Main residual block structure"""

    def __init__(self, channels):
        r"""Initializes internal Module state, shared by both nn.Module and ScriptModule.
        Args:
            channels (int): Number of channels in the input image.
        """
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.conv1(input)
        out = self.prelu(out)
        out = self.conv2(out)

        return out + input


In [2]:
x = torch.randn(100, 7, 16, 16).cuda()

In [3]:
G = Generator(16, 128, 7, 2).cuda()

In [4]:
G(x)

tensor([[[[ 3.4311e-02,  1.4785e-01, -1.4974e-01,  ...,  2.2867e-03,
            1.3488e-02, -2.0908e-02],
          [ 5.8613e-02,  2.0475e-01,  5.2772e-02,  ...,  1.3846e-01,
            9.1460e-02,  1.5502e-01],
          [-4.7215e-02, -2.5638e-01, -3.9353e-01,  ..., -2.5608e-01,
           -2.0563e-01,  1.6038e-01],
          ...,
          [ 2.0541e-01,  4.5574e-01,  5.2230e-01,  ...,  1.7188e-03,
           -6.3540e-02,  4.4931e-02],
          [ 3.9826e-03, -1.0721e-01,  7.9571e-02,  ..., -1.7100e-01,
           -1.2555e-01, -5.5412e-02],
          [-1.0697e-01, -1.7847e-01,  3.1286e-02,  ..., -1.8031e-01,
           -1.8446e-01, -1.7592e-01]],

         [[-6.1838e-02, -4.4313e-02,  3.2636e-02,  ...,  1.2875e-01,
            1.1893e-01, -1.9040e-01],
          [ 7.1081e-02,  8.0100e-02,  1.5945e-01,  ..., -7.1080e-02,
            1.3849e-01, -8.5565e-02],
          [ 3.7613e-02, -1.5659e-02,  1.7950e-01,  ...,  2.3053e-01,
            2.0042e-01, -5.6914e-02],
          ...,
     