In [120]:
import torch
from torch import nn
import matplotlib.pyplot as plt

def adaptive_instance_normalization(content_feat, style_feat):
    """Give ``content_feat`` the per-channel statistics of ``style_feat``.

    The content features are standardised with their own per-sample,
    per-channel mean and standard deviation, then rescaled by the style
    standard deviation and shifted by the style mean. The statistics come
    from ``calc_mean_std``.

    Args:
        content_feat: tensor whose statistics are replaced.
        style_feat: tensor supplying the target statistics; its first two
            dimensions must equal those of ``content_feat``.

    Returns:
        A tensor with the same size as ``content_feat``.

    Raises:
        AssertionError: if the first two dimensions of the inputs differ.
    """
    # NOTE(review): an identical definition of this function appears again
    # further down the notebook, next to calc_mean_std.
    assert (content_feat.size()[:2] == style_feat.size()[:2])

    # Style statistics are computed first, then content statistics.
    mu_style, sigma_style = calc_mean_std(style_feat)
    mu_content, sigma_content = calc_mean_std(content_feat)

    # The statistics are shaped (N, C, 1, 1), so plain broadcasting covers
    # the spatial dimensions of content_feat.
    standardised = (content_feat - mu_content) / sigma_content
    return standardised * sigma_style + mu_style

In [123]:
# Scratch check of slice semantics: [1:4] keeps indices 1, 2 and 3 (the end index is exclusive).
[0,1,2,3,4][1:4]

[1, 2, 3]

In [68]:
class Seq2VecModel(nn.Module):
    """Encode a sequence of feature maps into a single vector per sequence.

    Every time step's feature map is passed through a small CNN (four
    Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2) stages ending in 16
    channels) and flattened. The resulting sequence of vectors is fed to a
    single-layer, batch-first LSTM, and the LSTM output at the last time
    step is returned.

    Args:
        features_dim: number of channels in each input feature map.
        hidden_dim: size of the LSTM hidden state, and therefore of the
            returned vector.
        lstm_input_dim: length of the flattened CNN output for one time
            step. The default of 400 equals 16 * 5 * 5, which is what the
            CNN yields for 96x96 feature maps; pass a different value for
            other spatial sizes.
    """

    def __init__(self, features_dim, hidden_dim = 512, lstm_input_dim = 400):
        super(Seq2VecModel, self).__init__()
        self.features_dim = features_dim
        self.hidden_dim = hidden_dim
        self.lstm_input_dim = lstm_input_dim
        self.feature2vector = nn.Sequential(nn.Conv2d(self.features_dim,128,3),
                               nn.BatchNorm2d(128),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(128,64,3),
                               nn.BatchNorm2d(64),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(64,32,1),
                               nn.BatchNorm2d(32),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(32,16,1),
                               nn.BatchNorm2d(16),
                               nn.ReLU(),
                               nn.MaxPool2d(2))
        # The LSTM must be sized from hidden_dim: forward() builds its initial
        # state with self.hidden_dim, so a hard-coded 512 here would make every
        # non-default hidden_dim fail with a state-shape mismatch.
        self.lstm = nn.LSTM(self.lstm_input_dim, self.hidden_dim, batch_first=True)

    def forward(self, features):
        """Return the last-step LSTM output for each sequence.

        Args:
            features: tensor laid out as [batch size, dt, C, H, W], where
                ``dt`` is the number of time steps.

        Returns:
            Tensor of shape [batch size, hidden_dim].
        """
        batch_size, dt = features.shape[:2]

        # Fold the time axis into the batch axis so the 2-D CNN sees one
        # feature map per row, then unfold it again for the LSTM.
        frames = features.reshape(-1, *features.shape[2:])
        vectors = self.feature2vector(frames).reshape(batch_size, dt, -1)

        # The initial state is created on the same device and in the same
        # dtype as the LSTM input so the module keeps working after
        # .to(device) / .half() etc.
        # NOTE(review): the state is sampled randomly on every call, so two
        # forward passes on the same input give different outputs even in
        # eval mode. Confirm this is intended; omitting the state makes
        # nn.LSTM start from zeros instead.
        state_shape = (1, batch_size, self.hidden_dim)
        h0 = torch.randn(state_shape, device=vectors.device, dtype=vectors.dtype)
        c0 = torch.randn(state_shape, device=vectors.device, dtype=vectors.dtype)

        output, _ = self.lstm(vectors, (h0, c0))
        return output[:, -1, ...]
        
        

In [69]:
# Smoke test input: 2 sequences of 3 time steps, each a 256-channel 96x96
# feature map, i.e. the [batch size, dt, C, H, W] layout Seq2VecModel.forward
# expects. `features` is also read by later cells.
features = torch.randn(2,3,256, 96, 96)

# One output row per sequence; the recorded output is torch.Size([2, 512]).
m = Seq2VecModel(256)
m(features).shape

torch.Size([2, 512])

In [76]:
# Stand-alone copy of the conv stack used as Seq2VecModel.feature2vector
# (with features_dim fixed to 256), kept here to inspect the flattened
# output size of the CNN on its own.
feature2vector = nn.Sequential(nn.Conv2d(256,128,3),
                               nn.BatchNorm2d(128),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(128,64,3),
                               nn.BatchNorm2d(64),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(64,32,1),
                               nn.BatchNorm2d(32),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               nn.Conv2d(32,16,1),
                               nn.BatchNorm2d(16),
                               nn.ReLU(),
                               nn.MaxPool2d(2))

# `features` is not defined in this cell; it comes from whatever an earlier
# cell left in the kernel. The time axis is folded into the batch axis before
# the CNN, and view(6, -1) hard-codes the number of folded rows to 6.
# NOTE(review): the recorded output has 1936 values per row, not the 400
# inputs Seq2VecModel's LSTM takes - presumably `features` had a different
# shape when this cell last ran; confirm by re-running on a fresh kernel.
features_shape = features.shape[2:]
feature2vector(features.view(-1, *features_shape)).view(6,-1).shape

torch.Size([6, 1936])

In [None]:
def calc_mean_std(feat, eps=1e-5):
    """Compute per-sample, per-channel mean and standard deviation.

    Args:
        feat: 4-D tensor; statistics are taken over everything after the
            first two dimensions (N, C).
        eps: small value added to the variance before the square root so
            the returned standard deviation is never zero and is safe to
            divide by.

    Returns:
        Tuple ``(feat_mean, feat_std)``, each of shape (N, C, 1, 1).

    Raises:
        AssertionError: if ``feat`` is not 4-dimensional.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]

    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or permuted tensors; for contiguous inputs it is the same
    # zero-copy view as before.
    flat = feat.reshape(N, C, -1)

    # Tensor.var uses the unbiased estimator by default, so a feature map
    # with a single spatial element still yields NaN here.
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Transfer the channel-wise mean and std of ``style_feat`` onto ``content_feat``.

    ``content_feat`` is first standardised using its own statistics from
    ``calc_mean_std`` and then scaled and shifted with the statistics of
    ``style_feat``.

    Args:
        content_feat: features to be re-normalised.
        style_feat: features providing the target mean and std; must agree
            with ``content_feat`` in its first two dimensions.

    Returns:
        Tensor of the same size as ``content_feat``.

    Raises:
        AssertionError: if the leading two dimensions of the inputs differ.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])

    # Keep the original evaluation order: style statistics, then content.
    target_mean, target_std = calc_mean_std(style_feat)
    source_mean, source_std = calc_mean_std(content_feat)

    # Both statistics are (N, C, 1, 1) and broadcast across the remaining
    # dimensions of content_feat.
    centred = content_feat - source_mean
    whitened = centred / source_std
    return whitened * target_std + target_mean

In [118]:
class AdaIN(nn.Module):
    """Instance-normalise features, then apply an external scale and shift.

    Each (sample, channel) slice of the input is standardised with its own
    mean and standard deviation, computed over all remaining dimensions.
    The result is multiplied by ``params[0]`` and shifted by ``params[1]``.

    Args:
        eps: small value added to the variance before the square root, so
            a constant channel produces zeros instead of a 0/0 NaN.
    """

    def __init__(self, eps=1e-5):
        super(AdaIN, self).__init__()
        self.eps = eps

    def forward(self, features, params):
        """Normalise ``features`` and apply ``params``.

        Args:
            features: tensor laid out as [batch_size, C, D1, ...] with any
                number of trailing dimensions.
            params: indexable pair ``(scale, shift)``; each element must be
                broadcastable against ``features`` (plain numbers work).

        Returns:
            Tensor with the same shape as ``features`` when ``params``
            broadcast to that shape.
        """
        batch_size, C = features.shape[:2]

        # One statistic per (sample, channel), shaped with as many trailing
        # singleton axes as the input has, so it broadcasts for any rank
        # instead of only 5-D inputs.
        stat_shape = (batch_size, C) + (1,) * (features.dim() - 2)

        # reshape also handles non-contiguous inputs, where view would raise.
        flat = features.reshape(batch_size, C, -1)
        features_mean = flat.mean(dim=-1).view(stat_shape)
        features_std = (flat.var(dim=-1) + self.eps).sqrt().view(stat_shape)

        norm_features = (features - features_mean) / features_std
        return norm_features * params[0] + params[1]

In [119]:
# Shape check for AdaIN with a scale and shift of 1. `features` is not defined
# in this cell; it is whatever an earlier cell left in the kernel.
# NOTE(review): the recorded output torch.Size([2, 256, 64, 64, 64]) does not
# match the torch.randn(2,3,256, 96, 96) assigned to `features` elsewhere in
# this notebook - presumably it was redefined in a cell that is no longer
# shown; confirm on a fresh kernel.
a = AdaIN()
a(features, [1,1]).shape

torch.Size([2, 256, 64, 64, 64])