Model based on "Single Image Portrait Relighting". Its aim is to reproduce the given scenes with their illumination conditions (light color and direction) swapped between the two input images.

In [104]:
import torch
import torch.nn as nn

# Model building blocks

In [105]:
class ImageEncoding(nn.Module):
    """Initial image encoder: wide convolution, input re-attachment, strided convolution.

    The first convolution emits ``enc_channels - in_channels`` feature maps so
    that, once the raw input channels are concatenated back on, the tensor has
    exactly ``enc_channels`` channels. That concatenated tensor is cached on the
    module and can be fetched with get_encoded_image().
    """

    def __init__(self, in_channels=3, enc_channels=32, out_channels=64):
        super().__init__()
        learned_channels = enc_channels - in_channels
        self.conv1 = nn.Conv2d(in_channels, learned_channels, kernel_size=7, padding=3)
        self.conv2 = nn.Conv2d(enc_channels, out_channels, kernel_size=3, stride=2, padding=1)
        # Filled in by forward(); stays None until the first call.
        self.encoded_image = None

    def forward(self, image):
        """Encode ``image`` and return the stride-2 convolution of the encoding."""
        features = self.conv1(image)
        # Learned features come first, the raw image channels are appended last.
        self.encoded_image = torch.cat([features, image], dim=1)
        return self.conv2(self.encoded_image)

    def get_encoded_image(self):
        """Return the tensor cached by the latest forward() call.

        The stored tensor itself is handed out, so no copy is made
        (pattern based on https://github.com/milesial/Pytorch-UNet).
        """
        return self.encoded_image

In [106]:
class DownDoubleConv(nn.Module):
    """Two conv + GroupNorm + PReLU stages, the second one with stride 2.

    Every forward() call records the block input and the output of the first
    stage in ``skip_connections``. The list is only emptied by
    reset_skip_connections().
    """

    def __init__(self, in_channels, out_channels, channels_per_group=16):
        super().__init__()
        first_stage_groups = in_channels // channels_per_group
        second_stage_groups = out_channels // channels_per_group
        # Stage 1 keeps the channel count (and, with padding=1, the spatial size).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.GroupNorm(first_stage_groups, in_channels),
            nn.PReLU(),
        )
        # Stage 2 changes the channel count and applies stride 2.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(second_stage_groups, out_channels),
            nn.PReLU(),
        )
        self.skip_connections = []

    def forward(self, x):
        """Run both stages, remembering ``x`` and the stage-1 output as skips."""
        stage_one = self.conv1(x)
        self.skip_connections.extend((x, stage_one))
        return self.conv2(stage_one)

    def get_skip_connections(self):
        """Return the list of tensors recorded since the last reset."""
        return self.skip_connections

    def reset_skip_connections(self):
        """Replace the recorded list with a fresh empty one.

        The attribute is rebound rather than cleared, so a list obtained
        earlier through get_skip_connections() keeps its contents.
        """
        self.skip_connections = []

In [107]:
class DownBottleneckConv(nn.Module):
    """Encoder bottleneck producing environment-map predictions and a confidence map.

    ``depth - 1`` channel-preserving conv + GroupNorm + PReLU layers are
    followed by a final conv + GroupNorm + Softplus layer that outputs
    ``4 * in_channels`` channels. The first three quarters of those channels
    are returned as the environment-map predictions, the last quarter as the
    confidence.

    Args:
        in_channels: channel count of the input; must equal ``envmap_H * envmap_W``.
        channels_per_group: number of channels per GroupNorm group.
        envmap_H: environment-map height, used only to validate ``in_channels``.
        envmap_W: environment-map width, used only to validate ``in_channels``.
        depth: total number of layers, including the final one.

    Raises:
        AssertionError: if ``in_channels != envmap_H * envmap_W``.
    """

    def __init__(self, in_channels, channels_per_group=16, envmap_H=16, envmap_W=32, depth=4):
        expected_channels = envmap_H * envmap_W
        assert in_channels == expected_channels, f'DownBottleneck input has {in_channels} channels, expected {expected_channels}'
        super().__init__()
        self.convolutions = self._build_bottleneck_convolutions(in_channels, channels_per_group, depth)
        self.skip_connections = []

    def _build_bottleneck_convolutions(self, in_channels, channels_per_group, depth):
        """Build the ``depth`` bottleneck layers as an ``nn.ModuleList``."""
        # Each layer must be its own module instance. Repeating a single
        # instance with ``[layer] * n`` would register the same object n times
        # and make all of those layers share one set of weights.
        convolutions = [
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 3, padding=1),
                nn.GroupNorm(in_channels // channels_per_group, in_channels),
                nn.PReLU()
            )
            for _ in range(depth - 1)
        ]

        # final layer before weighted pooling
        out_channels = 4 * in_channels
        convolutions.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.Softplus()
        ))

        return nn.ModuleList(convolutions)

    def forward(self, x):
        """Return ``(envmap_predictions, confidence)`` for the input ``x``.

        The input of every layer except the final one is appended to
        ``skip_connections``.
        """
        for module in self.convolutions:
            self.skip_connections.append(x)
            x = module(x)

        # The input of the final layer is not used as a skip connection:
        # drop the entry that was recorded for it in the loop above.
        self.skip_connections.pop()

        # split x into environment map predictions (first 3/4 of the channels)
        # and confidence (remaining 1/4)
        channels = x.size()[1]
        split_point = 3 * (channels // 4)
        envmap_predictions, confidence = x[:, :split_point], x[:, split_point:]

        return envmap_predictions, confidence

    def get_skip_connections(self):
        """Return the list of tensors recorded since the last reset."""
        return self.skip_connections

    def reset_skip_connections(self):
        """Rebind ``skip_connections`` to a new empty list."""
        self.skip_connections = []

In [108]:
class WeightedPooling(nn.Module):
    """Parameter-free pooling: weighted sum of ``x`` over dimensions 2 and 3."""

    def __init__(self):
        super().__init__()

    def forward(self, x, weights):
        """Multiply ``x`` by the channel-tiled ``weights`` and sum over dims 2 and 3.

        The reduced dimensions are kept with size 1.
        """
        # TODO: the multiply-and-sum can probably be implemented as a grouped
        # convolution (according to some posts)
        # Repeat the weights three times along dim 1 so that they can be
        # multiplied element-wise with x.
        tiled_weights = weights.repeat(1, 3, 1, 1)
        weighted = x * tiled_weights
        return weighted.sum(dim=(2, 3), keepdim=True)

In [109]:
class Tiling(nn.Module):
    """Repeat the input ``size`` times along dims 2 and 3, then encode it.

    The encoding is a 3x3 convolution (``in_channels`` -> ``out_channels``)
    followed by GroupNorm and PReLU.
    """

    def __init__(self, size=16, in_channels=1536, out_channels=512, channels_per_group=16):
        super().__init__()
        self.size = size
        n_groups = out_channels // channels_per_group
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(n_groups, out_channels),
            nn.PReLU(),
        )

    def forward(self, x):
        """Tile ``x`` spatially and pass the tiled tensor through the encoder."""
        # Batch and channel dims are left alone; dims 2 and 3 are repeated.
        repeats = (1, 1, self.size, self.size)
        return self.encode(x.repeat(repeats))

Articles on transposed convolutions:
* [Convolution types](https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d)
* [Upsampling with transposed convolution](https://medium.com/activating-robotic-minds/up-sampling-with-transposed-convolution-9ae4f2df52d0)

In [110]:
def channel_concat(x, y):
    """Concatenate ``x`` and ``y`` along dimension 1, with ``x`` first."""
    return torch.cat([x, y], 1)

class UpBottleneckConv(nn.Module):
    """Decoder bottleneck: an encoding conv followed by skip-connected transposed convs.

    forward() consumes ``depth - 1`` skip connections in total: one for
    ``initial_conv``, one for each of the ``depth - 3`` layers in ``convs``
    and one for ``out_conv``.

    Args:
        in_channels: channel count of the input; must equal ``envmap_H * envmap_W``.
        channels_per_group: number of channels per GroupNorm group.
        envmap_H: environment-map height, used only to validate ``in_channels``.
        envmap_W: environment-map width, used only to validate ``in_channels``.
        depth: bottleneck depth; must be at least 3.

    Raises:
        AssertionError: if ``depth < 3`` or ``in_channels != envmap_H * envmap_W``.
    """

    def __init__(self, in_channels, channels_per_group=16, envmap_H=16, envmap_W=32, depth=4):
        expected_channels = envmap_H * envmap_W
        # initial_conv and out_conv always exist and each consume one skip
        # connection, so fewer than 3 levels cannot be represented
        # (depth - 3 would be negative).
        assert depth >= 3, 'Depth should be not smaller than 3'
        assert in_channels == expected_channels, f'UpBottleneck input has {in_channels} channels, expected {expected_channels}'
        super().__init__()
        self.depth = depth

        half_in_channels = in_channels // 2
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, half_in_channels, 3, padding=1),
            nn.GroupNorm(half_in_channels // channels_per_group, half_in_channels),
            nn.PReLU()
        )
        # With a 3x3 kernel and stride 1, padding=1 makes the transposed
        # convolution keep the spatial size: (H - 1) - 2*1 + (3 - 1) + 1 = H.
        self.initial_conv = self._transposed_block(
            in_channels + half_in_channels, in_channels, channels_per_group, padding=1
        )
        # Build a separate module per layer. ``[layer] * n`` would register one
        # object n times and all layers would share the same weights.
        self.convs = nn.ModuleList([
            self._transposed_block(2 * in_channels, in_channels, channels_per_group, padding=1)
            for _ in range(depth - 3)
        ])
        # With stride 2 the output size is (H - 1)*2 - 2*1 + (3 - 1) + output_padding + 1;
        # output_padding=1 makes that exactly 2*H.
        self.out_conv = self._transposed_block(
            2 * in_channels, half_in_channels, channels_per_group,
            stride=2, padding=1, output_padding=1
        )

    @staticmethod
    def _transposed_block(in_channels, out_channels, channels_per_group, **conv_kwargs):
        """Return a fresh 3x3 ConvTranspose2d + GroupNorm + PReLU block."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, **conv_kwargs),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.PReLU()
        )

    def forward(self, x, skip_connections):
        """Decode ``x``, popping skip tensors from the end of ``skip_connections``.

        ``skip_connections`` is modified in place: ``depth - 1`` entries are
        removed from its end.
        """
        # encoding convolution
        x = self.encode(x)

        # transposed convolutions with skip connections
        x = self.initial_conv(channel_concat(x, skip_connections.pop()))
        for conv in self.convs:
            x = conv(channel_concat(x, skip_connections.pop()))
        return self.out_conv(channel_concat(x, skip_connections.pop()))

In [111]:
class UpDoubleConv(nn.Module):
    """Two transposed-conv + GroupNorm + PReLU stages, each fed a skip connection.

    Both stages take ``2 * in_channels`` input channels: the running tensor
    concatenated with a skip tensor popped from the caller's list.
    """

    def __init__(self, in_channels, out_channels, channels_per_group=16):
        super().__init__()
        doubled_channels = 2 * in_channels
        # TODO: why are these paddings necessary
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(doubled_channels, in_channels, kernel_size=3, padding=1),
            nn.GroupNorm(in_channels // channels_per_group, in_channels),
            nn.PReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(
                doubled_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.GroupNorm(out_channels // channels_per_group, out_channels),
            nn.PReLU(),
        )

    def forward(self, x, skip_connections):
        """Apply both stages; two entries are popped off the end of ``skip_connections``."""
        first_skip = skip_connections.pop()
        hidden = self.conv1(channel_concat(x, first_skip))
        second_skip = skip_connections.pop()
        return self.conv2(channel_concat(hidden, second_skip))

In [112]:
class Output(nn.Module):
    """Final layer: transposed conv + GroupNorm + Sigmoid producing the output image.

    Args:
        in_channels: channel count of the concatenated input
            (``x`` channels plus ``encoded_img`` channels).
        out_channels: channel count of the produced image.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super().__init__()
        self.block = nn.Sequential(
            # padding=1 keeps the spatial size for a 3x3 kernel with stride 1:
            # (H - 1) - 2*1 + (3 - 1) + 1 = H. Without it the output would be
            # 2 pixels larger than the input in each spatial dimension.
            nn.ConvTranspose2d(in_channels, out_channels, 3, padding=1),  # should it be conv or transposed conv?
            nn.GroupNorm(1, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x, encoded_img):
        """Concatenate ``x`` with ``encoded_img`` along dim 1 and map to the output image.

        The result has the same spatial size as the inputs and, because of the
        final Sigmoid, values in [0, 1].
        """
        return self.block(channel_concat(x, encoded_img))

# Aggregated modules for skip connections management

In [113]:
class Down(nn.Module):
    """Encoder half of the network: image -> predicted environment map.

    forward() also collects, in order, the tensors that can be used as skip
    connections: the encoded image, the skips of each double-conv block and
    the skips of the bottleneck. They are available through
    get_skip_connections() after the call.
    """

    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoding()

        # down
        n_double_conv = 3
        self.down_double_convs = self._build_down_double_convs(n_double_conv)
        self.down_bottleneck = DownBottleneckConv(512)
        self.weighted_pool = WeightedPooling()

        self.skip_connections = []

    def _build_down_double_convs(self, n):
        """Build ``n`` double-conv blocks going 64 -> 128 -> 256 -> ... channels."""
        return nn.ModuleList([DownDoubleConv(64*(2**i), 64*(2**(i+1))) for i in range(n)])

    def forward(self, image):
        """Encode ``image`` and return the predicted environment map."""
        # Start every pass with empty skip lists. Otherwise a second call made
        # without an explicit reset would pile new tensors on top of stale ones
        # (and keep the stale tensors alive). The reset rebinds the lists
        # instead of clearing them, so a list handed out earlier by
        # get_skip_connections() is left untouched.
        self.reset_skip_connections()

        # initial image encoding
        x = self.image_encoder(image)
        self.skip_connections.append(self.image_encoder.get_encoded_image())

        # double convolution layers
        for down_double_conv in self.down_double_convs:
            x = down_double_conv(x)
            self.skip_connections += down_double_conv.get_skip_connections()

        # pre-bottleneck layer
        env_map, weights = self.down_bottleneck(x)
        self.skip_connections += self.down_bottleneck.get_skip_connections()

        # predict environment map
        pred_env_map = self.weighted_pool(env_map, weights)

        return pred_env_map

    def get_skip_connections(self):
        """Return the skip tensors collected by the latest forward() call."""
        return self.skip_connections

    def reset_skip_connections(self):
        """Rebind this module's skip list and reset those of all sub-blocks."""
        self.skip_connections = []
        for down_double_conv in self.down_double_convs:
            down_double_conv.reset_skip_connections()
        self.down_bottleneck.reset_skip_connections()

In [114]:
class Up(nn.Module):
    """Decoder half of the network: latent + skip connections -> relit image."""

    def __init__(self):
        super().__init__()

        # up
        self.tiling = Tiling()
        self.up_bottleneck = UpBottleneckConv(512)
        self.up_double_convs = self._build_up_double_convs(3)

        self.output = Output()

    def _build_up_double_convs(self, n):
        """Build ``n`` double-conv blocks going 256 -> 128 -> 64 -> ... channels."""
        blocks = []
        for level in range(n):
            blocks.append(UpDoubleConv(256 // 2 ** level, 256 // 2 ** (level + 1)))
        return nn.ModuleList(blocks)

    def forward(self, latent, skip_connections):
        """Decode ``latent``; skip tensors are popped off the end of ``skip_connections``."""
        # tiling and channel reduction, then the post-bottleneck layer
        features = self.up_bottleneck(self.tiling(latent), skip_connections)

        # double convolution layers
        for block in self.up_double_convs:
            features = block(features, skip_connections)

        # final layer constructing the image from the features and one more skip
        last_skip = skip_connections.pop()
        return self.output(features, last_skip)

# Network architecture

In [115]:
class IlluminationSwapNet(nn.Module):
    """Encoder/decoder pair that decodes ``image`` with the environment map of ``target``."""

    def __init__(self):
        super().__init__()
        self.down = Down()
        self.up = Up()

    def _encode(self, picture):
        """Run the encoder once; return ``(env_map, skip_connections)`` and reset the encoder."""
        env_map = self.down(picture)
        skips = self.down.get_skip_connections()
        # Resetting leaves `skips` intact and gives the encoder fresh lists.
        self.down.reset_skip_connections()
        return env_map, skips

    def forward(self, image, target):
        """Return ``image`` decoded with the environment map predicted from ``target``."""
        # pass image through encoder; only its skip connections are used below
        _image_env_map, image_skips = self._encode(image)

        # pass target through encoder; only its environment map is used below
        target_env_map, _target_skips = self._encode(target)

        # decode image with target env map
        # (target relighting could also be added here)
        return self.up(target_env_map, image_skips)

# Test for size mismatch errors

In [116]:
# Build the full network with its default configuration.
model = IlluminationSwapNet()

In [119]:
# Two independent random batches (4 samples, 3 channels, 256x256), drawn in
# the order image, target.
image, target = (torch.randn(4, 3, 256, 256) for _ in range(2))

In [120]:
# Smoke test: run one forward pass and discard the result; the point is only
# that no size-mismatch error is raised.
_ = model(image, target)