DeepFill v2 (gated convolution) model

Sources:
- [paper](https://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Free-Form_Image_Inpainting_With_Gated_Convolution_ICCV_2019_paper.pdf)
- [github](https://github.com/JiahuiYu/generative_inpainting)

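Issues:
- There are two possible implementations of gating: the paper's $O_{y,x}=\phi(\text{Feature}_{y,x}) \odot \sigma(\text{Gating}_{y,x})$ vs. $O_{y,x}=\phi(\text{Feature}_{y,x} \odot \sigma(\text{Gating}_{y,x}))$. The networks below apply `nn.ELU` after each `GatedConv`, i.e. they use the second form.
- The paper uses the ELU activation for $\phi$.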

In [None]:
import torch 
import torch.nn as nn
import torchvision
from torchsummary import summary

In [3]:
class GatedConv(nn.Module):
    """Gated 2-D convolution.

    Two parallel ``nn.Conv2d`` layers with identical hyper-parameters produce a
    feature map ``A`` and a gating map ``B``; the output is
    ``activation(A) * sigmoid(B)`` (``activation`` defaults to identity).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to both ``nn.Conv2d`` layers.
        activation: optional callable (e.g. ``nn.ELU()``) applied to the
            feature branch *before* gating, giving the paper's formulation
            phi(Feature) * sigmoid(Gating). With the default ``None`` the
            feature branch stays linear, so a caller that applies an
            activation afterwards gets phi(Feature * sigmoid(Gating)).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, activation=None):
        super().__init__()
        self.conv_A = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv_B = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.sigmoid = nn.Sigmoid()
        self.activation = activation

        # Kaiming-normal weights for both branches (feature first, then gate);
        # biases keep PyTorch's default initialisation. Iterating the two convs
        # explicitly avoids re-initialising any Conv2d inside `activation`.
        for conv in (self.conv_A, self.conv_B):
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, x):
        """Return ``activation(conv_A(x)) * sigmoid(conv_B(x))``."""
        features = self.conv_A(x)
        if self.activation is not None:
            features = self.activation(features)
        gate = self.sigmoid(self.conv_B(x))
        return features * gate

In [4]:
class GatedDeConv(nn.Module):
    """Upsampling gated convolution: resize the input, then apply a GatedConv.

    ``nn.functional.interpolate`` is called without a ``mode``, so it uses
    its default (nearest-neighbour) resampling.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to the inner ``GatedConv``.
        scale_factor: spatial upsampling factor applied before the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, scale_factor=2):
        super().__init__()
        self.conv = GatedConv(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.scale_factor = scale_factor

    def forward(self, x):
        upsampled = nn.functional.interpolate(x, scale_factor=self.scale_factor)
        return self.conv(upsampled)

In [5]:
class ContextualAttention(nn.Module):
    """Dense self-attention block with a learnable residual weight.

    NOTE(review): despite the name, this is a SAGAN-style non-local
    self-attention layer. Every spatial position attends to every position of
    the same feature map, and no inpainting mask is taken as input, so hole
    positions can also attend to other hole positions.

    Args:
        in_dim: number of input (and output) channels.
        reduction: channel reduction factor for the query/key projections;
            they use ``in_dim // reduction`` channels (default 8).

    Raises:
        ValueError: if ``in_dim // reduction`` is smaller than 1, which would
            otherwise build zero-channel query/key convolutions that only fail
            later, inside ``forward``.
    """

    def __init__(self, in_dim, reduction=8):
        super().__init__()
        qk_dim = in_dim // reduction
        if qk_dim < 1:
            raise ValueError(
                f"in_dim={in_dim} is too small for reduction={reduction}: "
                "query/key projections would have 0 channels"
            )
        self.in_dim = in_dim
        self.query_conv = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Residual weight starts at 0, so the block is initially the identity.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``.

        ``x`` is unpacked as (batch, channels, height, width). The attention
        matrix is dense (batch, H*W, H*W), so memory grows quadratically with
        the number of spatial positions.
        """
        batch_size, channels, height, width = x.size()
        n_pos = height * width
        # Queries (B, N, C') times keys (B, C', N) -> affinities (B, N, N).
        proj_query = self.query_conv(x).view(batch_size, -1, n_pos).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, n_pos)
        energy = torch.bmm(proj_query, proj_key)
        # Softmax over keys: row i holds the weights position i gives to every position.
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(batch_size, -1, n_pos)

        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)

        return self.gamma * out + x

In [49]:
class CoarseNet(nn.Module):
    """Coarse stage of the two-stage inpainting generator.

    A gated-convolution encoder/decoder:
    - two stride-2 gated convs bring the input to 1/4 resolution;
    - four dilated gated convs (rates 2, 4, 8, 16) enlarge the receptive field;
    - two ``GatedDeConv`` stages upsample back to full resolution.
    Every hidden layer is followed by ELU. The final 3-channel gated conv goes
    through tanh, so the output lies in (-1, 1).

    Args:
        cnum: base channel width; hidden layers use cnum, 2*cnum and 4*cnum.

    The first layer takes 4 input channels (``DeepFill.forward`` passes the
    3 image channels plus the mask). The output spatial size equals the
    input's when H and W are divisible by 4.

    NOTE(review): the reference TF implementation appears to have two extra
    3x3 gated convs between the dilated block and the first upsampling.
    Confirm before comparing results against it.
    """

    def __init__(self, cnum=48):
        super().__init__()
        stack = []

        def push(block, *args, **kwargs):
            # Append one gated layer followed by its ELU activation.
            stack.extend((block(*args, **kwargs), nn.ELU()))

        # Encoder: 4 -> 4*cnum channels, two stride-2 steps.
        push(GatedConv, 4, cnum, 5, 1, 2)
        push(GatedConv, cnum, 2 * cnum, 3, 2, 1)
        push(GatedConv, 2 * cnum, 2 * cnum, 3, 1, 1)
        push(GatedConv, 2 * cnum, 4 * cnum, 3, 2, 1)
        push(GatedConv, 4 * cnum, 4 * cnum, 3, 1, 1)
        push(GatedConv, 4 * cnum, 4 * cnum, 3, 1, 1)

        # Dilated bottleneck: padding == dilation keeps the size for 3x3, stride-1 convs.
        for rate in (2, 4, 8, 16):
            push(GatedConv, 4 * cnum, 4 * cnum, 3, 1, rate, dilation=rate)

        # Decoder: two upsample + conv stages back to full resolution.
        push(GatedDeConv, 4 * cnum, 2 * cnum, 3, 1, 1)
        push(GatedConv, 2 * cnum, 2 * cnum, 3, 1, 1)
        push(GatedDeConv, 2 * cnum, cnum, 3, 1, 1)
        push(GatedConv, cnum, cnum // 2, 3, 1, 1)

        # Output head: 3 channels squashed by tanh.
        stack.extend((GatedConv(cnum // 2, 3, 3, 1, 1), nn.Tanh()))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run the coarse encoder/decoder on ``x`` and return the 3-channel prediction."""
        return self.model(x)

In [21]:
class RefinementNet(nn.Module):
    """Refinement stage: two parallel encoders whose features are fused and decoded.

    - ``conv_branch``: gated stem (4x downsampling) followed by dilated gated
      convs with rates 2, 4, 8, 16.
    - ``attention_branch``: the same kind of stem, then ReLU,
      ``ContextualAttention`` and two more gated convs.
    - ``decoder``: takes the channel-wise concatenation (8*cnum channels),
      upsamples twice with ``GatedDeConv`` and ends in a 3-channel tanh output.

    Args:
        cnum: base channel width shared by all sub-networks.

    The input has 3 channels. The output spatial size equals the input's when
    H and W are divisible by 4.

    NOTE(review): the input carries no mask channel, so this stage is not told
    where the hole is. Confirm this is intended.
    """

    def __init__(self, cnum=48):
        super().__init__()
        c1, c2, c4 = cnum, 2 * cnum, 4 * cnum

        def gated(*args, **kwargs):
            # A GatedConv followed by ELU, as a two-element list.
            return [GatedConv(*args, **kwargs), nn.ELU()]

        def stem():
            # 4x downsampling stem (3 -> 4*cnum channels); the activation after
            # its last conv is left to the caller.
            return (
                gated(3, c1, 5, 1, 2)
                + gated(c1, c2, 3, 2, 1)
                + gated(c2, c2, 3, 1, 1)
                + gated(c2, c4, 3, 2, 1)
                + gated(c4, c4, 3, 1, 1)
                + [GatedConv(c4, c4, 3, 1, 1)]
            )

        conv_layers = stem() + [nn.ELU()]
        for rate in (2, 4, 8, 16):
            # Positional args: stride=1, padding=rate, dilation=rate.
            conv_layers += gated(c4, c4, 3, 1, rate, rate)
        self.conv_branch = nn.Sequential(*conv_layers)

        attn_layers = stem() + [nn.ReLU(), ContextualAttention(c4), nn.ELU()]
        attn_layers += gated(c4, c4, 3, 1, 1) + gated(c4, c4, 3, 1, 1)
        self.attention_branch = nn.Sequential(*attn_layers)

        dec_layers = (
            gated(2 * c4, c4, 3, 1, 1)
            + gated(c4, c4, 3, 1, 1)
            + [GatedDeConv(c4, c2, 3, 1, 1), nn.ELU()]
            + gated(c2, c2, 3, 1, 1)
            + [GatedDeConv(c2, c1, 3, 1, 1), nn.ELU()]
            + gated(c1, cnum // 2, 3, 1, 1)
            + [GatedConv(cnum // 2, 3, 3, 1, 1), nn.Tanh()]
        )
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Fuse both branches along channels and decode to a 3-channel image."""
        fused = torch.cat((self.conv_branch(x), self.attention_branch(x)), dim=1)
        return self.decoder(fused)

In [69]:
class DeepFill(nn.Module):
    """Two-stage inpainting generator: ``CoarseNet`` followed by ``RefinementNet``.

    Args:
        cnum: base channel width passed to both stages.
    """

    def __init__(self, cnum=48):
        super().__init__()
        self.coarse_net = CoarseNet(cnum)
        self.refinement_net = RefinementNet(cnum)

    def forward(self, x):
        """Run both stages.

        Args:
            x: 4-channel input. Channels 0-2 are the image and channel 3 is the
               mask. Where the mask is 1 the coarse prediction is used; where it
               is 0 the input image is kept (presumably a 0/1 hole mask; verify
               against the data pipeline).

        Returns:
            Tuple ``(coarse_output, refined_output)``; both are 3-channel images.
        """
        image = x[:, :3]
        mask = x[:, 3:4]  # keep the channel axis: (B, 1, H, W)
        coarse = self.coarse_net(x)
        # Paste the coarse prediction into the masked region only.
        blended = coarse * mask + image * (1 - mask)
        return coarse, self.refinement_net(blended)

In [83]:
# Random dummy batch: 1 sample, 3 image channels + 1 mask channel (the layout
# DeepFill.forward slices), 256x256 pixels.
x = torch.randn((1, 4, 256, 256))

In [77]:
%%timeit
# RefinementNet's first conv expects 3 channels; `x` also carries the mask channel.
RefinementNet(cnum=48)(x[:, :3]).shape

571 ms ± 7.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [80]:
%%timeit
# RefinementNet's first conv expects 3 channels; `x` also carries the mask channel.
RefinementNet(cnum=32)(x[:, :3]).shape

305 ms ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [86]:
%%timeit
# Times model construction *and* one autograd-enabled forward pass; [-1] is the refined output.
DeepFill(cnum=48)(x)[-1].shape

888 ms ± 28.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [87]:
%%timeit
# Times model construction *and* one autograd-enabled forward pass; [-1] is the refined output.
DeepFill(cnum=32)(x)[-1].shape

482 ms ± 28.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [70]:
# torchsummary.summary defaults to device="cuda"; fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
summary(DeepFill(cnum=32).to(device), (4, 256, 256), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]           3,232
            Conv2d-2         [-1, 32, 256, 256]           3,232
           Sigmoid-3         [-1, 32, 256, 256]               0
         GatedConv-4         [-1, 32, 256, 256]               0
               ELU-5         [-1, 32, 256, 256]               0
            Conv2d-6         [-1, 64, 128, 128]          18,496
            Conv2d-7         [-1, 64, 128, 128]          18,496
           Sigmoid-8         [-1, 64, 128, 128]               0
         GatedConv-9         [-1, 64, 128, 128]               0
              ELU-10         [-1, 64, 128, 128]               0
           Conv2d-11         [-1, 64, 128, 128]          36,928
           Conv2d-12         [-1, 64, 128, 128]          36,928
          Sigmoid-13         [-1, 64, 128, 128]               0
        GatedConv-14         [-1, 64, 1