Model based on "Single Image Portrait Relighting". Its aim is to reproduce the two given scenes with their illumination conditions (light color and light direction) swapped between the two input images.

In [1]:
import torch
import torch.nn as nn

# Model building blocks

In [3]:
class DoubleConv(nn.Module):
    """Two 3x3 convolution stages, the second of which halves the spatial size.

    Each stage is Conv2d -> GroupNorm -> PReLU. The first stage keeps
    ``in_channels`` channels. The second maps to ``out_channels`` with
    stride 2, so an ``(N, in_channels, H, W)`` input becomes
    ``(N, out_channels, ceil(H / 2), ceil(W / 2))``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        channels_per_group: target number of channels per GroupNorm group.
            At least one group is always used. GroupNorm still requires the
            channel count to be divisible by the resulting group count.
    """

    def __init__(self, in_channels, out_channels, channels_per_group=16):
        super().__init__()
        # Clamp to one group so small channel counts (e.g. < 16) don't
        # produce GroupNorm(0, C), which PyTorch rejects.
        in_groups = max(1, in_channels // channels_per_group)
        out_groups = max(1, out_channels // channels_per_group)
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.GroupNorm(in_groups, in_channels),
            nn.PReLU(),
            # 3x3 kernel, stride 2, padding 1: downsamples H and W by 2 (ceil).
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.GroupNorm(out_groups, out_channels),
            nn.PReLU(),
        )

    def forward(self, x):
        """Apply both stages to ``x`` and return the downsampled features."""
        return self.double_conv(x)

In [4]:
class BottleneckConv(nn.Module):
    """Bottleneck conv stack that outputs environment-map predictions and confidence.

    The input must have ``envmap_H * envmap_W`` channels. It passes through
    ``depth - 1`` shape-preserving Conv2d -> GroupNorm -> PReLU stages. A
    final Conv2d -> GroupNorm -> Softplus stage then expands it to
    ``4 * in_channels`` channels. The result is split along the channel
    axis: the first three quarters are the environment-map predictions and
    the last quarter is the confidence.

    Args:
        in_channels: number of input channels; must equal envmap_H * envmap_W.
        channels_per_group: target channels per GroupNorm group (at least
            one group is always used).
        envmap_H: environment-map height, used only to validate in_channels.
        envmap_W: environment-map width, used only to validate in_channels.
        depth: total number of conv stages, including the final expanding one.

    Raises:
        AssertionError: if in_channels != envmap_H * envmap_W.
    """

    def __init__(self, in_channels, channels_per_group=16, envmap_H=16, envmap_W=32, depth=4):
        expected_channels = envmap_H * envmap_W
        # NOTE(review): `assert` is skipped under `python -O`; kept so the
        # raised exception type stays AssertionError for existing callers.
        assert in_channels == expected_channels, f'Bottleneck input has {in_channels} channels, expected {expected_channels}'
        super().__init__()
        self.convolutions = self._build_bottleneck_convolutions(in_channels, channels_per_group, depth)

    def _build_bottleneck_convolutions(self, in_channels, channels_per_group, depth):
        """Return a ModuleList of ``depth - 1`` hidden stages plus the final stage."""
        in_groups = max(1, in_channels // channels_per_group)
        convolutions = []
        # Create new modules for every stage. Repeating a single list with `*`
        # would reuse the same Conv2d/GroupNorm/PReLU objects, silently tying
        # the weights of all hidden stages together.
        for _ in range(depth - 1):
            convolutions += [
                nn.Conv2d(in_channels, in_channels, 3, padding=1),
                nn.GroupNorm(in_groups, in_channels),
                nn.PReLU(),
            ]
        # final layer before weighted pooling: 3 * in_channels prediction
        # channels plus in_channels confidence channels
        out_channels = 4 * in_channels
        convolutions += [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(max(1, out_channels // channels_per_group), out_channels),
            nn.Softplus(),
        ]
        return nn.ModuleList(convolutions)

    def forward(self, x):
        """Run the stack and return ``(envmap_predictions, confidence)``.

        envmap_predictions holds the first ``3 * (C // 4)`` output channels
        and confidence holds the remaining ones. Both are non-negative because
        the stack ends with Softplus.
        """
        for module in self.convolutions:
            x = module(x)
        # split x into environment map predictions (3 channels per env-map
        # pixel, presumably RGB) and confidence (1 channel per pixel)
        channels = x.size(1)
        split_point = 3 * (channels // 4)
        envmap_predictions, confidence = x[:, :split_point], x[:, split_point:]
        return envmap_predictions, confidence

In [None]:
class WeightedPooling(nn.Module):