Model based on "Single Image Portrait Relighting". Its aim is to reproduce the two given input scenes with their illumination conditions (light color and direction) swapped.

In [2]:
import torch
import torch.nn as nn

# Model building blocks

In [3]:
class ImageEncoding(nn.Module):
    """Encode an image with a 7x7 convolution and append the raw image to the result.

    The convolution produces ``enc_channels - in_channels`` feature maps, so after
    the input channels are appended the output has exactly ``enc_channels`` channels.

    Args:
        in_channels: number of channels of the input image.
        enc_channels: total number of channels of the returned encoding.
    """

    def __init__(self, in_channels=3, enc_channels=32):
        super(ImageEncoding, self).__init__()
        # padding=3 is the "same" padding of a 7x7 kernel: the feature maps must keep
        # the spatial size of the image, otherwise they cannot be concatenated with it.
        self.convolve = nn.Conv2d(in_channels, enc_channels - in_channels, 7, padding=3)

    def forward(self, img):
        """Return the convolution features followed by the channels of ``img``."""
        conv_out = self.convolve(img)
        return torch.cat((conv_out, img), dim=1)  # append image channels to the convolution result

In [4]:
class DownDoubleConv(nn.Module):
    """Encoder stage: a size-preserving 3x3 convolution followed by a strided one.

    Each convolution is followed by GroupNorm and PReLU. The second convolution
    uses stride 2, which halves the spatial resolution (for even input sizes) and
    changes the channel count from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input activation.
        out_channels: channels of the returned activation.
        channels_per_group: number of channels in one GroupNorm group; both channel
            counts are integer-divided by it to obtain the number of groups.
    """

    def __init__(self, in_channels, out_channels, channels_per_group=16):
        super(DownDoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.GroupNorm(in_channels // channels_per_group, in_channels),
            nn.PReLU(),
            # kernel size 3 matches the first convolution; stride 2 does the down-sampling
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        """Apply both convolution blocks to ``x``."""
        return self.double_conv(x)

In [5]:
class DownBottleneckConv(nn.Module):
    def __init__(self, in_channels, channels_per_group=16, envmap_H=16, envmap_W=32, depth=4):
        expected_channels = envmap_H * envmap_W
        assert in_channels == expected_channels, f'DownBottleneck input has {in_channels} channels, expected {expected_channels}'
        super(DownBottleneckConv, self).__init__()
        self.convolutions = self._build_bottleneck_convolutions(in_channels, channels_per_group, depth)
    
    def _build_bottleneck_convolutions(self, in_channels, channels_per_group, depth):
        single_conv = [
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.GroupNorm(in_channels // channels_per_group, in_channels),
            nn.PReLU()
        ]
        convolutions = single_conv * (depth - 1)
        
        # final layer before weighted pooling
        out_channels = 4*in_channels
        convolutions += [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.Softplus()
        ]
        
        return nn.ModuleList(convolutions)
    
    def forward(self, x):
        for module in self.convolutions:
            x = module(x)
            
        # split x into environment map predictions and confidence
        channels = x.size()[1]
        split_point = 3 * (channels // 4)
        envmap_predictions, confidence = x[:, :split_point], x[:, split_point:]
        
        return envmap_predictions, confidence

In [6]:
class WeightedPooling(nn.Module):
    """Pool every channel of ``x`` into one value using per-pixel weights.

    ``weights`` is tiled three times along dimension 1 before the multiplication,
    so ``x`` is expected to carry three times as many channels as ``weights``
    (or ``weights`` must otherwise broadcast against ``x``).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, weights):
        """Return the weighted sum of ``x`` over dimensions 2 and 3."""
        # TODO: multiplication with sum can be probably implemented as convolution with groups (according to some posts)
        tiled_weights = weights.repeat(1, 3, 1, 1)
        weighted = x * tiled_weights
        return weighted.sum(dim=(2, 3))

In [7]:
class Tiling(nn.Module):
    """Tile a per-channel vector into a ``size`` x ``size`` map and encode it.

    The input is repeated ``size`` times along both spatial dimensions and then
    passed through a 3x3 convolution, GroupNorm and PReLU.

    Args:
        size: number of repetitions along each spatial dimension.
        in_channels: channels of the input.
        out_channels: channels of the returned activation.
        channels_per_group: number of channels in one GroupNorm group.
    """

    def __init__(self, size=16, in_channels=1536, out_channels=512, channels_per_group=16):
        super(Tiling, self).__init__()
        self.size = size
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        """Return the encoded tiling of ``x``.

        ``x`` may be ``(N, C, H, W)`` or a flat ``(N, C)`` tensor; the latter is
        treated as ``(N, C, 1, 1)``.
        """
        if x.dim() == 2:
            # A sum over both spatial dimensions (as WeightedPooling computes) drops
            # them; restore two singleton dimensions so that repeat() tiles spatially
            # instead of tiling the batch and channel axes.
            x = x[:, :, None, None]
        tiled = x.repeat((1, 1, self.size, self.size))
        return self.encode(tiled)

Articles on transposed convolutions:
* [Convolution types](https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d)
* [Upsampling with transposed convolution](https://medium.com/activating-robotic-minds/up-sampling-with-transposed-convolution-9ae4f2df52d0)

In [8]:
def channel_concat(x, y):
    """Join ``x`` and ``y`` along dimension 1, with the channels of ``x`` first."""
    pair = [x, y]
    return torch.cat(pair, 1)

class UpBottleneckConv(nn.Module):
    """Bottleneck of the decoder: an encoding convolution plus transposed convolutions.

    The input is first reduced to half of its channels. It then passes through
    ``depth - 1`` transposed convolution blocks (ConvTranspose2d, GroupNorm, PReLU),
    each of which receives the previous result concatenated channel-wise with one
    skip connection. Every block except the last keeps the spatial size; the last
    one uses stride 2 to double it and outputs ``in_channels // 2`` channels.

    Args:
        in_channels: channels of the input; must equal ``envmap_H * envmap_W``.
        channels_per_group: number of channels in one GroupNorm group.
        envmap_H: height of the environment map.
        envmap_W: width of the environment map.
        depth: number of blocks including the encoding convolution; at least 3.

    Raises:
        AssertionError: if ``depth < 3`` or ``in_channels != envmap_H * envmap_W``.
    """

    def __init__(self, in_channels, channels_per_group=16, envmap_H=16, envmap_W=32, depth=4):
        expected_channels = envmap_H * envmap_W
        # initial_conv and out_conv always consume one skip connection each, so together
        # with the encoding convolution the smallest workable depth is 3
        assert depth >= 3, 'Depth should be not smaller than 3'
        assert in_channels == expected_channels, f'UpBottleneck input has {in_channels} channels, expected {expected_channels}'
        super(UpBottleneckConv, self).__init__()
        self.depth = depth

        half_in_channels = in_channels // 2
        # padding=1 keeps the spatial size for every stride-1 3x3 layer below; the
        # activations have to stay concatenable with the skip connections
        # (assumes the skip connections share the spatial size of the input - TODO confirm)
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, half_in_channels, 3, padding=1),
            nn.GroupNorm(half_in_channels // channels_per_group, half_in_channels),
            nn.PReLU()
        )
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels + half_in_channels, in_channels, 3, padding=1),
            nn.GroupNorm(in_channels // channels_per_group, in_channels),
            nn.PReLU()
        )
        # One nn.Sequential per block, freshly created each time: forward() concatenates
        # a skip connection once per block, and the blocks must not share weights.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(2*in_channels, in_channels, 3, padding=1),
                nn.GroupNorm(in_channels // channels_per_group, in_channels),
                nn.PReLU()
            )
            for _ in range(depth - 3)
        ])
        # stride 2 with padding=1 and output_padding=1 exactly doubles height and width
        self.out_conv = nn.Sequential(
            nn.ConvTranspose2d(2*in_channels, half_in_channels, 3, stride=2, padding=1, output_padding=1),
            nn.GroupNorm(half_in_channels // channels_per_group, half_in_channels),
            nn.PReLU()
        )

    def forward(self, x, skip_connections):
        """Decode ``x`` using ``skip_connections``.

        Args:
            x: input activation with ``in_channels`` channels.
            skip_connections: list of ``depth - 1`` encoder activations, each with
                ``in_channels`` channels. They are consumed from the last element to
                the first; the list itself is left unmodified.

        Returns:
            Activation with ``in_channels // 2`` channels and doubled spatial size.

        Raises:
            AssertionError: if ``skip_connections`` is not a list of the expected length.
        """
        assert isinstance(skip_connections, list), 'skip_connections should be a list of corresponding encoder activations'
        expected_skip_connections = self.depth - 1
        assert len(skip_connections) == expected_skip_connections, f'skip_connections should contain {expected_skip_connections} activations'

        # walk the list backwards instead of pop()-ing, so the caller's list is not emptied
        skips = reversed(skip_connections)

        # encoding convolution
        x = self.encode(x)

        # transposed convolutions with skip connections
        x = self.initial_conv(channel_concat(x, next(skips)))
        for conv in self.convs:
            x = conv(channel_concat(x, next(skips)))
        return self.out_conv(channel_concat(x, next(skips)))

In [9]:
class UpDoubleConv(nn.Module):
    """Decoder stage: two transposed convolution blocks fed with skip connections.

    Each block (ConvTranspose2d, GroupNorm, PReLU) receives the previous activation
    concatenated channel-wise with one skip connection. The first block keeps the
    spatial size; the second uses stride 2 to double it and maps to ``out_channels``.

    Args:
        in_channels: channels of the input and of every skip connection.
        out_channels: channels of the returned activation.
        channels_per_group: number of channels in one GroupNorm group.
    """

    def __init__(self, in_channels, out_channels, channels_per_group=16):
        super(UpDoubleConv, self).__init__()
        # padding=1 keeps the spatial size so the result can be concatenated with the
        # second skip connection
        # (assumes both skip connections share the spatial size of the input - TODO confirm)
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(2*in_channels, in_channels, 3, padding=1),
            nn.GroupNorm(in_channels // channels_per_group, in_channels),
            nn.PReLU()
        )
        # stride 2 with padding=1 and output_padding=1 exactly doubles height and width
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(2*in_channels, out_channels, 3, stride=2, padding=1, output_padding=1),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.PReLU()
        )

    def forward(self, x, skip_connections):
        """Decode ``x`` using exactly two skip connections.

        The last element of ``skip_connections`` is used first. The list itself is
        left unmodified.

        Raises:
            AssertionError: if ``skip_connections`` is not a list of two elements.
        """
        assert isinstance(skip_connections, list)
        assert len(skip_connections) == 2
        # unpack instead of pop()-ing, so the caller's list is not emptied
        second_skip, first_skip = skip_connections
        x = self.conv1(channel_concat(x, first_skip))
        return self.conv2(channel_concat(x, second_skip))

In [10]:
class Output(nn.Module):
    """Final layer mapping decoder features plus the encoded image to the output image.

    The two inputs are concatenated channel-wise and passed through a 3x3 transposed
    convolution, GroupNorm with a single group, and Sigmoid, so every output value
    lies in the range (0, 1).

    Args:
        in_channels: total channels of the concatenated inputs.
        out_channels: channels of the produced image.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super(Output, self).__init__()
        self.block = nn.Sequential(
            # padding=1 keeps the spatial size; without it the result would be two
            # pixels larger than the input in each dimension
            nn.ConvTranspose2d(in_channels, out_channels, 3, padding=1),  # should it be conv or transposed conv?
            nn.GroupNorm(1, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x, encoded_img):
        """Return the output image computed from ``x`` and ``encoded_img``."""
        return self.block(channel_concat(x, encoded_img))

# Network architecture

In [None]:
class IlluminationSwapNet(nn.Module):
    """Skeleton of the illumination swapping network (work in progress).

    Only the image encoder, the down-sampling path and the tiling layer are
    instantiated so far; the forward pass is not implemented yet.
    """

    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoding()

        # Encoder (down-sampling) path.
        # NOTE(review): ImageEncoding() yields 32 channels by default while layer128
        # expects 64 input channels - confirm whether a 32 -> 64 stage is missing.
        self.layer128 = DownDoubleConv(64, 128)
        self.layer64 = DownDoubleConv(128, 256)
        self.layer32 = DownDoubleConv(256, 512)
        self.down_bottleneck = DownBottleneckConv(512)
        self.weighted_pool = WeightedPooling()

        # Decoder (up-sampling) path.
        self.tiling = Tiling()

    def forward(self, x):
        """Placeholder: the forward pass is not implemented and returns ``None``."""
        return None