In [1]:
import torch
import torch.nn as nn

import numpy as np

In [2]:
# When True, PartialConv2d.forward prints its intermediate tensors.
DEBUG: bool = True

In [3]:
class PartialConv2d(nn.Module):
    """2-D partial convolution (Liu et al., arxiv.org/abs/1804.07723).

    Only valid input pixels (mask == 1) contribute to the convolution. Each
    output is re-weighted by ``window_size / (number of valid pixels in its
    window)`` before the bias is added, as in the paper. Output locations
    whose window contains no valid pixel are set to 0 and marked invalid in
    the returned mask.

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or tuple of int, forwarded to ``nn.Conv2d``
    stride : int or tuple of int, forwarded to ``nn.Conv2d`` (default 1)
    debug : bool or None, optional
        Print intermediate tensors in ``forward``. ``None`` (default) defers
        to the module-level ``DEBUG`` flag at call time.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, debug=None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.debug = debug

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)

        # Fixed all-ones kernel: convolving the mask with it counts the valid
        # pixels (over all input channels) in every sliding window. Its weight
        # has requires_grad=False, so autograd never populates a gradient for
        # it and it is not learned.
        self.sum_conv = nn.Conv2d(in_channels, 1, kernel_size, stride=stride, bias=False)
        nn.init.ones_(self.sum_conv.weight)
        self.sum_conv.weight.requires_grad_(False)

        # sum(1) in the paper: number of elements in one window (C_in * kH * kW).
        self.window_size = self.sum_conv.weight[0].numel()

    @staticmethod
    def _print_debug(named_tensors):
        """Print each (label, tensor) pair as: label, value, blank line."""
        for label, value in named_tensors:
            print(label)
            print(value)
            print()

    def forward(self, x, mask):
        """
        Forward pass of Partial Convolution (arxiv.org/abs/1804.07723)

        Parameters
        ----------
        x : FloatTensor, input feature tensor of shape (b, c, h, w)
        mask : FloatTensor, binary mask tensor of shape (b, c, h, w);
            1 marks valid pixels, 0 marks holes

        Returns
        -------
        x_after_conv_normed : FloatTensor of shape (b, out_channels, h', w')
            ``conv(x * mask) * window_size / sum(mask) + bias`` where the
            window holds at least one valid pixel, 0 elsewhere
        updated_mask : FloatTensor of shape (b, out_channels, h', w')
            1 where the window held at least one valid pixel, 0 elsewhere
        """
        assert x.shape == mask.shape, 'x and mask shapes must be equal'

        # W^T (X * M) without the bias: the bias is added after re-weighting,
        # so it is not scaled together with the masked response.
        x_after_conv = nn.functional.conv2d(
            x * mask, self.conv.weight, None,
            self.conv.stride, self.conv.padding, self.conv.dilation, self.conv.groups,
        )

        with torch.no_grad():
            mask_norm = self.sum_conv(mask)  # valid-pixel count per window
            # Stay on x's device/dtype instead of forcing a CPU FloatTensor.
            updated_mask_single = (mask_norm > 0).to(x_after_conv.dtype)
            # Use 1 as the denominator for empty windows: their output is zeroed
            # anyway, and a finite ratio avoids the NaN gradients that an x / 0
            # inside torch.where produces.
            safe_norm = torch.where(mask_norm > 0, mask_norm, torch.ones_like(mask_norm))
            ratio = self.window_size / safe_norm * updated_mask_single

        x_after_conv_normed = x_after_conv * ratio
        if self.conv.bias is not None:
            x_after_conv_normed = (
                x_after_conv_normed
                + self.conv.bias.view(1, -1, 1, 1) * updated_mask_single
            )

        updated_mask = updated_mask_single.repeat(1, self.out_channels, 1, 1)

        # Only consult the notebook-level DEBUG flag when no per-instance choice was made.
        debug = DEBUG if self.debug is None else self.debug
        if debug:
            self._print_debug([
                ('x', x),
                ('mask', mask),
                ('x_after_conv', x_after_conv),
                ('mask_norm', mask_norm),
                ('x_after_conv_normed', x_after_conv_normed),
                ('updated_mask', updated_mask),
            ])

        return x_after_conv_normed, updated_mask

Testing: run `PartialConv2d` on a small random input and mask, and inspect the debug output.

In [4]:
torch.manual_seed(0)  # reproducible input and mask

b, c, h, w = 2, 2, 4, 4

# torch.randint returns int64 in current PyTorch; conv2d needs a floating-point input.
x = torch.randint(0, 5, (b, c, h, w)).float()

# One random spatial mask per sample (~20% valid pixels), shared across all channels.
mask_single = (torch.rand((b, h, w)) > 0.8).unsqueeze(1).float()
print(mask_single.shape)
mask = mask_single.repeat(1, c, 1, 1)
print(mask.shape)

pconv = PartialConv2d(c, 1, 3, stride=1)

x_new, mask_new = pconv(x, mask)

torch.Size([2, 1, 4, 4])
torch.Size([2, 2, 4, 4])
x
tensor([[[[ 4.,  2.,  4.,  0.],
          [ 4.,  0.,  4.,  4.],
          [ 0.,  2.,  2.,  1.],
          [ 1.,  2.,  1.,  3.]],

         [[ 3.,  1.,  0.,  4.],
          [ 3.,  4.,  0.,  0.],
          [ 4.,  1.,  0.,  4.],
          [ 4.,  0.,  2.,  4.]]],


        [[[ 4.,  1.,  1.,  1.],
          [ 0.,  1.,  4.,  2.],
          [ 4.,  0.,  1.,  3.],
          [ 0.,  0.,  1.,  4.]],

         [[ 0.,  0.,  3.,  0.],
          [ 3.,  3.,  2.,  2.],
          [ 2.,  1.,  2.,  3.],
          [ 1.,  1.,  0.,  4.]]]])

mask
tensor([[[[ 0.,  0.,  0.,  0.],
          [ 0.,  1.,  0.,  0.],
          [ 0.,  1.,  0.,  1.],
          [ 0.,  0.,  0.,  1.]],

         [[ 0.,  0.,  0.,  0.],
          [ 0.,  1.,  0.,  0.],
          [ 0.,  1.,  0.,  1.],
          [ 0.,  0.,  0.,  1.]]],


        [[[ 0.,  0.,  1.,  0.],
          [ 1.,  0.,  1.,  1.],
          [ 0.,  1.,  0.,  1.],
          [ 0.,  0.,  1.,  0.]],

         [[ 0.,  0.,  1.,  

In [None]:
# NOTE(review): this cell was the incomplete expression `mask.un`, which raises
# AttributeError. Completed as unsqueeze(1), the (b, 1, h, w) layout used
# elsewhere in this notebook -- confirm the intended method.
mask.unsqueeze(1).shape

In [None]:
# Scratch arithmetic: 0.3288 divided by 8. Multiplying by 0.125 (= 2**-3) is
# exact, so this equals 0.3288 / 8 bit-for-bit. Presumably a manual check of a
# normalized value from the debug output -- TODO confirm where 0.3288 came from.
0.3288 * 0.125

In [None]:
# torch.where(condition, input, other) picks elementwise from `input` where
# `condition` is True and from `other` elsewhere; calling it with no arguments
# raises an error. Small demo of the guarded division pattern:
where_num = torch.tensor([1.0, 2.0, 3.0])
where_den = torch.tensor([2.0, 0.0, 4.0])
where_demo = torch.where(where_den != 0, where_num / where_den, torch.zeros_like(where_num))
where_demo

In [None]:
b, c, h, w = 2, 3, 4, 4
x = torch.rand((b, c, h, w))
# PartialConv2d.forward asserts that mask has exactly x's shape and multiplies
# it into x, so use a float mask of shape (b, c, h, w) rather than an int64
# (b, h, w) one. Draw one binary mask per sample and share it across channels.
mask = torch.randint(0, 2, (b, 1, h, w)).float().repeat(1, c, 1, 1)

In [None]:
b, c, h, w = 2, 4, 5, 5
mask = (torch.rand((b, h, w)) > 0.5).float()
mask_unsqueezed = mask.unsqueeze(1)  # (b, 1, h, w): single-channel input for norm_conv

# All-ones 3x3 kernel: mask_norm[i, 0, y, x] is the number of valid pixels in
# the 3x3 window at (y, x). The weight is frozen (like PartialConv2d.sum_conv),
# so mask_norm carries no autograd graph.
norm_conv = nn.Conv2d(1, 1, 3, stride=1, bias=False)
nn.init.ones_(norm_conv.weight)
norm_conv.weight.requires_grad_(False)

mask_norm = norm_conv(mask_unsqueezed)

In [None]:
# Shape, values, and a 5-channel copy of mask_norm (the same torch.cat
# replication PartialConv2d uses to build updated_mask).
for shown in (mask_norm.shape, mask_norm, torch.cat((mask_norm,) * 5, dim=1)):
    print(shown)

In [None]:
mask.size()  # torch.Size, identical to mask.shape

In [None]:
mask_unsqueezed  # mask with a singleton channel dimension inserted at dim 1

In [None]:
mask_norm  # per-window count of valid mask pixels (all-ones 3x3 norm_conv)

In [None]:
# mask is (b, h, w) with h = 5, so mask.squeeze(1) was a no-op. Squeezing the
# singleton channel dim of mask_unsqueezed checks the unsqueeze round-trip
# instead -- NOTE(review): confirm this was the intent.
mask_unsqueezed.squeeze(1).shape

In [None]:
mask_unsqueezed.select(0, 0)  # first sample of the batch

In [None]:
mask_norm.select(0, 0)  # per-window counts for the first sample

In [None]:
mask.dtype  # expected torch.float32: the cell that builds mask casts it to float

In [None]:
# norm_conv is built with bias=False, so norm_conv.bias is None and
# `norm_conv.bias.data.zero_()` raised AttributeError. Zero the bias only if one exists.
if norm_conv.bias is not None:
    nn.init.zeros_(norm_conv.bias)
norm_conv.bias