In [1]:
import torch
from torch import nn

In [2]:
class ResnetBlock(nn.Module):
    """Residual block: ``x + F(x)`` where ``F`` is two 3x3 reflection-pad -> conv -> instance-norm stages."""

    def __init__(self, dim):
        """Build the residual branch ``F``.

        Parameters:
            dim (int): number of channels used by both 3x3 convolutions.

        The block only learns the residual ``F(x)``; its output is added to the
        input, so the input tensor must have exactly ``dim`` channels.
        """
        super().__init__()
        layers = []
        for stage in range(2):
            # 1-pixel reflection padding + 3x3 conv (stride 1) keeps H and W unchanged
            layers += [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3, 1), nn.InstanceNorm2d(dim)]
            if stage == 0:
                # only the first stage is followed by a non-linearity
                layers.append(nn.ReLU(inplace=True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return residual + x
    
# Sanity check: the skip connection adds the branch output to the input,
# so a residual block must return a tensor with exactly the input's shape.
sample = torch.rand(2, 3, 256, 256)
block = ResnetBlock(3)
with torch.no_grad():
    block_out = block(sample)
assert block_out.shape == sample.shape, block_out.shape
print(block_out.shape)

torch.Size([2, 3, 256, 256])


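## ConvBlock padding trick
Using padding $p = (k - 1)/2$ for an odd kernel size $k$ makes a strided convolution shrink the spatial size by exactly the stride factor $s$ and nothing else: $H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} \right\rfloor + 1 = \left\lceil \frac{H_{in}}{s} \right\rceil$, which is exactly $H_{in}/s$ when $H_{in}$ is divisible by $s$. The matching transposed convolution undoes this when `output_padding` $= s - 1$, giving $H_{out} = s \cdot H_{in}$.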

In [3]:
# class ConvBlock(nn.Module):
#     """Contains an convolution-InstanceNorm-Relu layer"""
#     def __init__(self, channels_out):
#         super(ConvBlock, self).__init__()
#         channels_in = channels_out // 2
#         self.block = nn.Sequential(
#             nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1),
#             nn.InstanceNorm2d(channels_out),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.block(x)

In [4]:
# class ConvTranposeBlock(nn.Module):
#     def __init__(self, channels_out):
#         super(ConvTranposeBlock, self).__init__()
#         channels_in = channels_out * 2
#         self.block = nn.Sequential(
#             nn.ConvTranspose2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1, output_padding=1),
#             nn.InstanceNorm2d(channels_out),
#             nn.ReLU(True)            
#         )
#     def forward(self, x):
#         return self.block(x)
    
# x = torch.rand(2, 256, 64, 64)
# f = ConvTranposeBlock(128)
# with torch.no_grad():
#     print(f(x).shape)

In [5]:
def ConvBlock(channels_out, channels_in=None):
    """Return the layers of one downsampling stage: Conv2d -> InstanceNorm2d -> ReLU.

    The tuple is meant to be unpacked into an ``nn.Sequential``
    (e.g. ``*ConvBlock(128)``).

    Parameters:
        channels_out (int): number of output channels of the convolution.
        channels_in (int, optional): number of input channels. Defaults to
            ``channels_out // 2``, i.e. each stage doubles the channel count.

    Returns:
        tuple[nn.Module, nn.Module, nn.Module]: (Conv2d, InstanceNorm2d, ReLU).

    With kernel_size=3, stride=2, padding=1 the spatial size becomes
    ceil(H / 2), i.e. exactly H / 2 for even H.
    """
    if channels_in is None:
        channels_in = channels_out // 2
    return (
        nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1),
        nn.InstanceNorm2d(channels_out),
        nn.ReLU(inplace=True),
    )

def ConvTranposeBlock(channels_out, channels_in=None):
    """Return the layers of one upsampling stage: ConvTranspose2d -> InstanceNorm2d -> ReLU.

    The tuple is meant to be unpacked into an ``nn.Sequential``
    (e.g. ``*ConvTranposeBlock(64)``). The misspelled name ("Tranpose") is
    kept so that existing calls keep working.

    Parameters:
        channels_out (int): number of output channels.
        channels_in (int, optional): number of input channels. Defaults to
            ``channels_out * 2``, i.e. each stage halves the channel count.

    Returns:
        tuple[nn.Module, nn.Module, nn.Module]: (ConvTranspose2d, InstanceNorm2d, ReLU).

    kernel_size=3, stride=2, padding=1 together with output_padding=1 doubles
    the spatial size exactly (H -> 2H).
    """
    if channels_in is None:
        channels_in = channels_out * 2
    return (
        nn.ConvTranspose2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.InstanceNorm2d(channels_out),
        nn.ReLU(inplace=True),
    )

In [6]:
# Hand-built ResNet generator (6 residual blocks) assembled from the helpers above.
x = torch.rand(2, 3, 256, 256)
model = nn.Sequential(
    # first block uses reflection padding and instance norm;
    # pad 3 + 7x7 conv (stride 1) keeps the spatial size
    nn.ReflectionPad2d(3),
    nn.Conv2d(3, 64, kernel_size=7, stride=1),
    nn.InstanceNorm2d(64),
    nn.ReLU(True),

    # two stride-2 downsampling stages: 64 -> 128 -> 256 channels, H and W divided by 4
    *ConvBlock(128),
    *ConvBlock(256),

    # six residual blocks at 256 channels
    *[ResnetBlock(256) for _ in range(6)],

    # two stride-2 upsampling stages: 256 -> 128 -> 64 channels, H and W multiplied by 4
    *ConvTranposeBlock(128),
    *ConvTranposeBlock(64),

    # last block uses reflection padding but no normalization, and tanh
    nn.ReflectionPad2d(3),
    nn.Conv2d(64, 3, kernel_size=7, stride=1),
    nn.Tanh()
)

with torch.no_grad():
    out = model(x)
# an image-to-image generator must return an image of the input's size
assert out.shape == x.shape, out.shape
print(out.shape)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,}")

torch.Size([2, 3, 256, 256])
7,837,699


In [7]:
# Layer-by-layer structure, for side-by-side comparison with the reference generator.
print(f"{model}")

Sequential(
  (0): ReflectionPad2d((3, 3, 3, 3))
  (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
  (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (3): ReLU(inplace=True)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (9): ReLU(inplace=True)
  (10): ResnetBlock(
    (conv_block): Sequential(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace=True)
      (4): ReflectionPad2d((1, 1, 1, 1))
      (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    

In [8]:
# model = model.cuda()
# batch = torch.rand(8, 3, 128, 128).cuda()
# out = model(batch)

In [9]:
from original_gen import ResnetGenerator

In [10]:
# Reference implementation configured like `model` above: instance norm, no dropout,
# 6 residual blocks. NOTE(review): the positional args (3, 3, 64) are presumably
# (input channels, output channels, base filters) — confirm against original_gen.
test = ResnetGenerator(3, 3, 64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=6)
with torch.no_grad():
    ref_out = test(x)
print(ref_out.shape)
ref_params = sum(p.numel() for p in test.parameters())
print(f"{ref_params:,}")

# The hand-built `model` must agree with the reference in output shape and size.
assert ref_out.shape == x.shape, ref_out.shape
assert ref_params == sum(p.numel() for p in model.parameters()), ref_params

torch.Size([2, 3, 256, 256])
7,837,699


In [11]:
# Layer-by-layer structure of the reference generator, to compare with `model`.
print(f"{test}")

ResnetGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
     