In [1]:
import torch
import torch.nn as nn

import numpy as np

In [2]:
# Notebook-wide debug switch. NOTE(review): no cell visible in this notebook reads it.
DEBUG = False

## Partial Convolution

In [3]:
class PartialConv2d(nn.Module):
    """
    Simplified partial convolution layer (arxiv.org/abs/1804.07723).

    The input is multiplied by the mask, convolved without a built-in bias, and a
    learnable per-channel bias is added afterwards. The re-normalisation by the
    number of valid entries described in the paper is NOT applied ("no norm").
    A second, frozen all-ones convolution sums the mask over every window; windows
    with a positive sum are marked valid in the returned mask.

    Parameters
    ----------
    in_channels : int, number of channels of the input features and of the mask
    out_channels : int, number of channels of the output features and of the returned mask
    kernel_size : kernel size, forwarded to nn.Conv2d
    stride : stride, forwarded to nn.Conv2d
    padding : padding, forwarded to nn.Conv2d
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

        # Feature convolution. Its bias is disabled and kept as the separate
        # parameter `conv_bias`, which is added after the convolution.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
        self.conv_bias = nn.Parameter(torch.zeros(out_channels), requires_grad=True)

        # Mask convolution: a fixed all-ones kernel that sums the mask over each
        # window and over all input channels. It is frozen, so it is never trained.
        self.sum_conv = nn.Conv2d(in_channels, 1, kernel_size, stride=stride, padding=padding, bias=False)
        self.sum_conv.weight.requires_grad_(False)
        with torch.no_grad():
            self.sum_conv.weight.fill_(1)

    def forward(self, x, mask):
        """
        Forward pass of Partial Convolution (arxiv.org/abs/1804.07723)

        Parameters
        ----------
        x : FloatTensor, input feature tensor of shape (b, c, h, w)
        mask : binary mask tensor of shape (b, c, h, w); any dtype (float, integer
            or bool) is accepted and converted to the dtype of `x`

        Returns
        -------
        features : tensor with `out_channels` channels, conv(x * mask) + bias
        updated_mask : tensor of the same shape as `features` and the dtype of `x`,
            1 where the window contained at least one valid mask entry, else 0
        """
        assert x.shape == mask.shape, 'x and mask shapes must be equal'

        # Convolutions require the mask in the same floating dtype as the weights;
        # integer masks (e.g. from torch.randint) would otherwise raise.
        mask = mask.to(dtype=x.dtype)

        features = self.conv(x * mask) + self.conv_bias.view(1, -1, 1, 1)  # no norm

        # The mask update is not differentiable, so keep it out of the autograd graph.
        with torch.no_grad():
            valid = (self.sum_conv(mask) > 0).to(dtype=x.dtype)
            # One identical mask plane per output channel.
            updated_mask = valid.repeat(1, self.out_channels, 1, 1)

        return features, updated_mask

## Network

In [4]:
class InpaintDownBlock(nn.Module):
    """
    Encoder stage: partial convolution, optional batch normalisation, then ReLU.

    Parameters
    ----------
    in_channels : int, channels of the incoming features and mask
    out_channels : int, channels produced by the partial convolution
    kernel_size : int, kernel size; also used to derive the padding for padding='same'
    stride : stride of the partial convolution
    padding : 'same' selects (kernel_size - 1) // 2, any other value is passed through
    bn : truthy to apply nn.BatchNorm2d(out_channels) after the convolution
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=2, padding='same', bn=True):
        super().__init__()

        if padding == 'same':
            padding = (kernel_size - 1) // 2

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.pconv = PartialConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        # `bn` doubles as the on/off switch: it holds the BatchNorm2d module when
        # enabled and the caller's falsy flag otherwise.
        self.bn = nn.BatchNorm2d(out_channels) if bn else bn
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, mask):
        """Return (activated features, updated mask) for the given features and mask."""
        features, updated_mask = self.pconv(x, mask)
        if self.bn:
            features = self.bn(features)

        return self.relu(features), updated_mask

In [5]:
class InpaintUpBlock(nn.Module):
    """
    Decoder stage: 2x nearest-neighbour upsampling of features and mask,
    concatenation with the skip ("bridge") tensors along the channel dimension,
    partial convolution, optional batch normalisation, then LeakyReLU(0.2).

    Parameters
    ----------
    in_channels : int, channels of the features coming from the previous stage
    in_channels_bridge : int, channels of the bridge features
    out_channels : int, channels produced by the partial convolution
    kernel_size : int, kernel size; also used to derive the padding for padding='same'
    padding : 'same' selects (kernel_size - 1) // 2, any other value is passed through
    bn : truthy to apply nn.BatchNorm2d(out_channels) after the convolution
    """
    def __init__(self, in_channels, in_channels_bridge, out_channels, kernel_size, padding='same', bn=True):
        super().__init__()

        self.in_channels = in_channels
        self.in_channels_bridge = in_channels_bridge
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        padding = (kernel_size - 1) // 2 if padding == 'same' else padding
        self.padding = padding

        # `bn` holds the flag first and is replaced by the BatchNorm2d module below
        # when enabled; forward() relies on its truthiness.
        self.bn = bn

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')  # TODO: align corners!

        # The convolution sees the upsampled features and the bridge features stacked.
        self.pconv = PartialConv2d(in_channels + in_channels_bridge, out_channels, kernel_size, padding=padding)
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, mask, x_bridge, mask_bridge):
        """
        Parameters
        ----------
        x, mask : features and mask from the previous stage, same shape as each other
        x_bridge, mask_bridge : skip features and mask; they are concatenated with the
            upsampled x / mask along dim=1, so all other dimensions must match

        Returns
        -------
        (features, updated_mask) as produced by the partial convolution, with the
        features passed through the optional batch norm and the LeakyReLU
        """
        x, mask = self.upsample(x), self.upsample(mask)
        x = torch.cat([x, x_bridge], dim=1)
        mask = torch.cat([mask, mask_bridge], dim=1)

        x, mask = self.pconv(x, mask)

        if self.bn:
            x = self.bn(x)
        x = self.leaky_relu(x)

        return x, mask

In [6]:
class InpaintNet(nn.Module):
    """
    U-Net-style inpainting network built from partial-convolution blocks.

    Seven stride-2 down blocks are followed by seven up blocks. Each up block
    receives the input of the matching down block (features and mask) as its
    bridge, deepest first, so the last up block is bridged with the network input.

    Parameters
    ----------
    in_channels : int, channels of the input image and mask
    out_channels : int, channels of the output image and mask

    NOTE(review): every down block halves the resolution and every up block doubles
    it, so height and width presumably need to be divisible by 2 ** depth for the
    bridge concatenations to line up - confirm before feeding other sizes.
    """
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # down
        self.down_blocks = nn.ModuleList([
            InpaintDownBlock(in_channels, 64, 7, stride=2, padding='same', bn=False),
            InpaintDownBlock(64, 128, 5, stride=2, padding='same', bn=False),
            InpaintDownBlock(128, 256, 5, stride=2, padding='same', bn=False),
            InpaintDownBlock(256, 512, 3, stride=2, padding='same', bn=False),
            InpaintDownBlock(512, 512, 3, stride=2, padding='same', bn=False),
            InpaintDownBlock(512, 512, 3, stride=2, padding='same', bn=False),
            InpaintDownBlock(512, 512, 3, stride=2, padding='same', bn=False),
        ])

        # up
        self.up_blocks = nn.ModuleList([
            InpaintUpBlock(512, 512, 512, 3, padding='same', bn=False),
            InpaintUpBlock(512, 512, 512, 3, padding='same', bn=False),
            InpaintUpBlock(512, 512, 512, 3, padding='same', bn=False),
            InpaintUpBlock(512, 256, 256, 3, padding='same', bn=False),
            InpaintUpBlock(256, 128, 128, 3, padding='same', bn=False),
            InpaintUpBlock(128, 64, 64, 3, padding='same', bn=False),
            # The last bridge is the network input itself, so its channel count is
            # `in_channels`; the block emits the requested `out_channels`.
            InpaintUpBlock(64, in_channels, out_channels, 3, padding='same', bn=False),
        ])

        # Number of down/up stages, derived from the block list instead of hard-coded.
        self.depth = len(self.down_blocks)

    def forward(self, x, mask):
        """
        Parameters
        ----------
        x : input feature tensor of shape (b, in_channels, h, w)
        mask : binary mask tensor of the same shape as x

        Returns
        -------
        (x, mask) : output of the last up block and its updated mask
        """
        # Remember the input of every down block; these are the skip connections.
        bridges = []
        for down_block in self.down_blocks:
            bridges.append((x, mask))
            x, mask = down_block(x, mask)

        # Walk back up, pairing each up block with the most recent unused bridge.
        for up_block, (x_bridge, mask_bridge) in zip(self.up_blocks, reversed(bridges)):
            x, mask = up_block(x, mask, x_bridge, mask_bridge)

        return x, mask

In [7]:
# Build the inpainting network with its default arguments (in_channels=3, out_channels=3).
model = InpaintNet()

In [8]:
# Checkpoint location, relative to the notebook's working directory.
state_dict_path = '../model_no_sigmoid_lr_18.pth'

# map_location='cpu' places every tensor on the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted
# source (consider weights_only=True where the installed PyTorch supports it).
state_dict = torch.load(state_dict_path, map_location='cpu')
model.load_state_dict(state_dict)

## Example

In [9]:
# Smoke test: push a random image and a random binary mask through the network.
image = torch.rand((1, 3, 256, 256))
# torch.randint returns int64 by default; the mask convolution needs a float tensor.
mask = torch.randint(0, 2, (1, 3, 256, 256)).float()

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    image_result, mask_result = model(image, mask)
print(image_result.shape)

torch.Size([1, 3, 256, 256])
