Convolution layers slide a small filter over the input image, multiplying each filter weight by the pixel beneath it and summing the products to produce an output feature map. This makes them effective at capturing spatial relationships and local patterns. For inpainting, however, a standard convolution treats the placeholder values that fill the missing regions (for example, zeros or the mean pixel value) as if they were real pixels, which often leads to artifacts such as color discrepancy and blurriness.
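To make the sliding-window arithmetic concrete, here is a minimal pure-Python sketch of a standard convolution (single channel, stride 1, no padding). The helper name `conv2d_valid` and the toy image are illustrative assumptions, not part of any library; as in deep-learning frameworks, the kernel is not flipped. The image is meant to be all ones, but its central 2×2 hole has been filled with the placeholder value 0, and every 3×3 window mixes those zeros into its sum:

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (stride 1, no padding) and return the feature map."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    feature_map = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Element-wise multiply the window under the kernel, then add up the products
            total = 0.0
            for u in range(kh):
                for v in range(kw):
                    total += image[i + u][j + v] * kernel[u][v]
            row.append(total)
        feature_map.append(row)
    return feature_map


# The "true" image is all ones, but its 2x2 centre is missing and was filled with 0.
image = [
    [1.0, 1.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 1.0],
]
box_filter = [[1 / 9] * 3 for _ in range(3)]  # 3x3 averaging kernel

for row in conv2d_valid(image, box_filter):
    print([round(value, 3) for value in row])
```

Every output comes out as 5/9 ≈ 0.56 instead of 1, a small-scale version of the color discrepancy that placeholder values cause.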

Partial convolution layers, on the other hand, take the masked regions of the input image into account. In inpainting, the input image has missing pixels, and a partial convolution layer computes the convolution using only the available (valid) pixels, so the output is not affected by whatever values happen to fill the missing ones.
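A quick way to see why masking helps: within a single sliding window, multiplying the input by the mask before the weighted sum removes any dependence on the values stored in the holes. The pure-Python sketch below (the function `window_response` and all numbers are hypothetical) fills the same holes with two different placeholders and compares a masked response with a plain one:

```python
def window_response(window, mask, kernel):
    """Weighted sum of one 3x3 window: sum of kernel * window * mask."""
    return sum(
        k * x * m
        for k_row, x_row, m_row in zip(kernel, window, mask)
        for k, x, m in zip(k_row, x_row, m_row)
    )


kernel = [[0.1, 0.2, 0.1], [0.2, 0.4, 0.2], [0.1, 0.2, 0.1]]
mask = [[1, 1, 1], [1, 0, 0], [1, 0, 0]]  # 1 = valid pixel, 0 = missing pixel
no_mask = [[1, 1, 1]] * 3                 # standard convolution: every pixel counts

# Identical valid pixels; the holes hold two different placeholder values.
window_a = [[0.5, 0.5, 0.5], [0.5, 0.0, 0.0], [0.5, 0.0, 0.0]]
window_b = [[0.5, 0.5, 0.5], [0.5, 9.0, 9.0], [0.5, 9.0, 9.0]]

print(window_response(window_a, mask, kernel) == window_response(window_b, mask, kernel))        # True
print(window_response(window_a, no_mask, kernel) == window_response(window_b, no_mask, kernel))  # False
```

Masking alone still lets the response shrink when fewer pixels are valid (here only 5 of 9 contribute); the rescaling in the partial convolution layer addresses that.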

# Partial Convolution Layer

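The partial convolution layer computes the convolution only over the valid pixels and then rescales the result by $\operatorname{sum}(\mathbf{1})/\operatorname{sum}(M)$, the ratio of the window size to the number of valid pixels in that window. This scaling factor compensates for the varying number of valid inputs, so windows that overlap a hole are not biased toward smaller values simply because fewer pixels contributed. By using partial convolution layers in an inpainting network, the model can learn to fill in the missing regions while preserving the content and style of the original image.

For a single sliding window, with $W$ the filter weights, $b$ the bias, $X$ the input values under the window, $M$ the corresponding binary mask (1 = valid, 0 = missing), $\odot$ element-wise multiplication, and $\mathbf{1}$ an all-ones tensor with the same shape as $M$: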

$$x' = \begin{cases}
      W^{T}(X \odot M)\,\dfrac{\operatorname{sum}(\mathbf{1})}{\operatorname{sum}(M)} + b, & \text{if } \operatorname{sum}(M) > 0 \\
      0, & \text{otherwise}
   \end{cases}$$
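A worked single-window version of this equation, as a pure-Python sketch (the function `partial_conv_window` and the box-filter values are illustrative assumptions). A 3×3 box filter over a constant image of ones should output 1; with 4 of the 9 pixels missing, the masked sum is only 5/9, and the factor sum(1)/sum(M) = 9/5 restores it:

```python
def partial_conv_window(window, mask, kernel, bias=0.0):
    """x' for one sliding window, following the partial convolution equation."""
    valid = sum(m for row in mask for m in row)   # sum(M)
    window_size = sum(len(row) for row in mask)   # sum(1)
    if valid == 0:
        return 0.0                                # no valid input: x' = 0
    masked = sum(                                 # W^T (X .* M)
        k * x * m
        for k_row, x_row, m_row in zip(kernel, window, mask)
        for k, x, m in zip(k_row, x_row, m_row)
    )
    return masked * window_size / valid + bias    # rescale, then add b


box_filter = [[1 / 9] * 3 for _ in range(3)]
ones = [[1.0] * 3 for _ in range(3)]
full_mask = [[1] * 3 for _ in range(3)]
holey_mask = [[1, 1, 1], [1, 0, 0], [1, 0, 0]]   # 5 of 9 pixels valid
empty_mask = [[0] * 3 for _ in range(3)]

print(partial_conv_window(ones, full_mask, box_filter))   # approximately 1.0
print(partial_conv_window(ones, holey_mask, box_filter))  # approximately 1.0, not 5/9
print(partial_conv_window(ones, empty_mask, box_filter))  # 0.0
```

The bias is added after the rescaling, so it is not amplified in windows that contain only a few valid pixels.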

The partial convolution operation also updates its mask $M$: if the convolution was able to condition its output on at least one valid input value, the corresponding output location is marked as valid.

$$m' = \begin{cases}
      1, & \text{if } \operatorname{sum}(M) > 0 \\
      0, & \text{otherwise}
   \end{cases}$$
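Applying this rule layer after layer shrinks the hole from its border inward, so a deep enough stack eventually marks every location valid. The pure-Python sketch below assumes a 3×3 window with stride 1 and zero padding (the PyTorch layer that follows uses stride 2, which also downsamples); `update_mask` is a hypothetical helper:

```python
def update_mask(mask, k=3):
    """m' = 1 wherever a k x k window (stride 1, zero padding) holds any valid pixel."""
    h, w, r = len(mask), len(mask[0]), k // 2
    new_mask = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            # sum(M) over the window; positions outside the image contribute 0
            total = sum(
                mask[u][v]
                for u in range(i - r, i + r + 1)
                for v in range(j - r, j + r + 1)
                if 0 <= u < h and 0 <= v < w
            )
            new_mask[i][j] = 1 if total > 0 else 0
    return new_mask


# 7x7 mask with a 5x5 hole in the middle (0 = missing pixel)
mask = [[0 if 1 <= i <= 5 and 1 <= j <= 5 else 1 for j in range(7)] for i in range(7)]
for layer in range(3):
    missing = sum(row.count(0) for row in mask)
    print(f"before layer {layer + 1}: {missing} missing pixels")
    mask = update_mask(mask)
```

For this 7×7 mask it reports 25, 9 and 1 missing pixels: each update peels one pixel off every side of the hole.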

```python
import torch
import torch.nn as nn


class PartialConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False):
        super().__init__()
        # Learnable convolution: W^T (X .* M) + b
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        # Fixed all-ones convolution that counts the valid inputs in each window: sum(M)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False
        # sum(1): number of entries in one sliding window (channels x square kernel area)
        self.window_size = in_channels * kernel_size * kernel_size
        self.activation = nn.ReLU()

    def forward(self, input_x, mask):
        # mask has the same shape as input_x: 1 = valid pixel, 0 = missing pixel
        raw = self.input_conv(input_x * mask)  # W^T (X .* M) + b
        with torch.no_grad():
            mask_sum = self.mask_conv(mask)  # sum(M) at every window location
        mask_is_zero = mask_sum == 0
        # Avoid division by zero; these locations are set to 0 below anyway
        mask_sum = mask_sum.masked_fill(mask_is_zero, 1.0)
        # b = 0 when the layer was built with bias=False
        bias = self.input_conv.bias
        b = bias.view(1, -1, 1, 1) if bias is not None else 0.0
        # x' = W^T (X .* M) * sum(1) / sum(M) + b   if sum(M) > 0
        output = (raw - b) * (self.window_size / mask_sum) + b
        # x' = 0   otherwise
        output = output.masked_fill(mask_is_zero, 0.0)
        # m' = 1 if sum(M) > 0, else 0
        new_mask = (~mask_is_zero).to(output.dtype)
        output = self.activation(output)
        return output, new_mask
```