In [2]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import torchvision.models
from options.train_options import TrainOptions
import sys; sys.argv=['']; del sys

In [None]:
# %tb 
# Build the training options object. `sys.argv` was reset to [''] in the import
# cell, presumably so the option parser does not pick up the notebook kernel's
# own command-line arguments -- TODO confirm against TrainOptions.parse().
opt = TrainOptions().parse()

In [None]:
# Display the `init_type` option; presumably the same value that is passed as
# `init_type` to the modules defined below -- verify against the training script.
opt.init_type

In [3]:
def get_pretrained_vgg():
    """Build the convolutional trunk of VGG16-BN and load pretrained weights into it.

    The returned network is an ``nn.Sequential`` of Conv2d -> BatchNorm2d -> ReLU
    groups separated by 2x2 max-pooling, laid out according to ``cfg`` below.
    All learnable parameters *and* all BatchNorm buffers (running mean/variance)
    are copied from ``torchvision.models.vgg16_bn(pretrained=True).features``.

    NOTE(review): this function is defined a second time in a later cell of the
    notebook; the later definition shadows this one once it is executed.

    Returns:
        nn.Sequential: the feature extractor with pretrained weights loaded.

    Raises:
        RuntimeError: if the pretrained feature stack does not have exactly the
            same state-dict keys/shapes as the network built here.
    """
    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_partial_vgg16():
        """Assemble the layer stack described by ``cfg`` ('M' = max-pool, int = conv width)."""
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                    nn.BatchNorm2d(v),
                    nn.ReLU(inplace=True),
                ]
                in_channels = v
        return nn.Sequential(*layers)

    def init_vgg16(vgg):
        """Copy every parameter and buffer of the pretrained feature stack into ``vgg``."""
        pretrained = torchvision.models.vgg16_bn(pretrained=True)
        # The earlier implementation paired state-dict entries by position but
        # iterated only len(parameters()) times. A state dict also contains the
        # BatchNorm buffers, so it has more entries than there are parameters and
        # the deeper layers were silently left at their random initialisation.
        # Loading the `features` sub-module's state dict copies everything, and
        # the default strict key check fails loudly on any layout mismatch.
        vgg.load_state_dict(pretrained.features.state_dict())
        print('VGG parameters loaded.')

    vgg16 = make_partial_vgg16()
    init_vgg16(vgg16)
    return vgg16


class BaseModule(nn.Module):
    """Base class that provides a configurable weight-initialisation helper."""

    def __init__(self):
        super(BaseModule, self).__init__()

    def weights_init_func(self, m, init_type, gain):
        """Initialise a single sub-module in place; intended for ``module.apply``.

        Args:
            m: the module to initialise. Only ``nn.Linear``, ``nn.Conv2d`` and
                ``nn.BatchNorm2d`` are touched; anything else is left unchanged.
            init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
                ``'orthogonal'`` (only consulted for Linear/Conv2d modules).
            gain: standard deviation for ``'normal'`` and for BatchNorm weights,
                gain factor for ``'xavier'``/``'orthogonal'``; not used by
                ``'kaiming'``.

        Raises:
            NotImplementedError: if ``init_type`` is unknown and ``m`` is a
                Linear/Conv2d module.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight = bias = None; skip those
            # instead of failing on `None.data`.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)


class StyleExtractor(BaseModule):
    """Predict four 3x3 convolution kernels from the VGG features of an image.

    A VGG network is used as a fixed feature extractor (no gradients, always in
    eval mode). Its flattened output goes through a small fully connected
    "representor" and then through four linear heads, each producing one
    ``(n_kernel_channels, n_kernel_channels, 3, 3)`` kernel.

    Args:
        vgg: feature extractor to use. When ``None`` (default) a fresh one is
            built with ``get_pretrained_vgg()`` for this instance.
        n_kernel_channels: in/out channel count of every generated kernel.
        init_type: initialisation scheme forwarded to ``weights_init_func``.
        n_hidden: width of the hidden representation.
    """

    def __init__(self, vgg=None, n_kernel_channels=64,
                 init_type='xavier', n_hidden=1024):
        super(StyleExtractor, self).__init__()

        self.nkc = n_kernel_channels
        # Build the default lazily: a `vgg=get_pretrained_vgg()` default is
        # evaluated once at class-definition time, so every instance would share
        # (and e.g. `.cuda()` would move) one and the same module.
        self.vgg16 = get_pretrained_vgg() if vgg is None else vgg
        # 512 * 7 * 7 is the flattened VGG output size; this assumes the VGG
        # yields a 512x7x7 map (224x224 inputs for the default network).
        self.representor = nn.Sequential(
            nn.Linear(512 * 7 * 7, n_hidden),
            nn.ReLU(True),
            nn.Dropout(),
        )
        kernel_size = n_kernel_channels * n_kernel_channels * 3 * 3
        self.conv_kernel_gen_1 = nn.Linear(n_hidden, kernel_size)
        self.conv_kernel_gen_2 = nn.Linear(n_hidden, kernel_size)
        self.conv_kernel_gen_3 = nn.Linear(n_hidden, kernel_size)
        self.conv_kernel_gen_4 = nn.Linear(n_hidden, kernel_size)

        # Initialise everything except the pretrained VGG.
        weights_init_func = lambda m: self.weights_init_func(m, init_type, gain=0.02)
        for module in self.children():
            if module is not self.vgg16:
                module.apply(weights_init_func)
        print('StyleExtractor weights initialized using %s.' % init_type)

    def forward(self, x):
        """Return a list of four kernels of shape ``(nkc, nkc, 3, 3)``.

        The final ``view`` only succeeds when the batch contains a single image,
        because each head emits ``nkc * nkc * 9`` values per sample.
        """
        # The VGG is a fixed feature extractor; skip building its autograd graph
        # rather than building it and detaching afterwards.
        with torch.no_grad():
            features = self.vgg16(x)
        # Flatten (N, C, H, W) -> (N, C*H*W). Without this the first Linear layer
        # is applied to the last spatial axis and fails with a shape mismatch.
        features = features.view(features.size(0), -1)
        deep_features = self.representor(features)

        kernel_heads = (
            self.conv_kernel_gen_1,
            self.conv_kernel_gen_2,
            self.conv_kernel_gen_3,
            self.conv_kernel_gen_4,
        )
        return [
            head(deep_features).view(self.nkc, self.nkc, 3, 3)
            for head in kernel_heads
        ]

    def train(self, mode=True):
        r"""
        Override the train method inherited from nn.Module so that the VGG
        feature extractor always stays in eval mode (its BatchNorm layers keep
        using the pretrained running statistics); all other children follow
        ``mode``.
        """
        self.training = mode
        for module in self.children():
            if module is self.vgg16:
                module.train(False)
            else:
                module.train(mode)
        return self

VGG parameters loaded.


In [4]:
# Instantiate with the constructor defaults (xavier init, 64 kernel channels,
# 1024 hidden units) as a quick construction check.
tnet = StyleExtractor()

StyleExtractor weights initialized using xavier.


In [5]:
# Show the first VGG parameter tensor (the first conv layer's weight) so it can
# be compared by eye with the torchvision reference model.
next(iter(tnet.vgg16.parameters()))

Parameter containing:
tensor([[[[ 8.2833e-02,  2.7968e-02,  7.7096e-02],
          [ 4.9341e-02, -3.3441e-02,  1.9572e-02],
          [ 8.0300e-02,  7.7076e-02,  8.3349e-02]],

         [[-4.4296e-02, -1.7748e-01, -4.8706e-02],
          [-1.1003e-01, -2.7530e-01, -1.3474e-01],
          [-5.9982e-03, -6.1375e-02,  1.6822e-02]],

         [[ 2.7480e-02, -6.6769e-02,  4.3955e-02],
          [-2.6662e-02, -1.4995e-01, -3.3615e-02],
          [ 5.2778e-02,  1.7143e-02,  8.6744e-02]]],


        [[[-1.2628e-02,  3.0218e-02, -2.6930e-02],
          [-1.3764e-02,  1.1993e-01, -6.6263e-03],
          [-2.6019e-02, -8.3535e-03, -3.9197e-02]],

         [[-4.0557e-02,  1.3983e-02, -5.4278e-02],
          [ 1.5412e-02,  1.8198e-01,  1.7598e-02],
          [-1.7032e-02,  1.1284e-02, -2.4226e-02]],

         [[-6.5683e-02,  5.9252e-02, -5.3020e-02],
          [ 3.8278e-02,  2.7292e-01,  5.9491e-02],
          [-4.1218e-02,  3.6159e-02, -3.0478e-02]]],


        [[[ 1.4962e-06, -1.1430e-06,  1.2536

In [11]:
# NOTE(review): `Generator` is not defined in any cell visible in this notebook;
# it presumably came from a cell that was deleted or an import that is missing,
# so this cell would fail on a fresh kernel -- confirm where it should come from.
tnet = Generator()

Encoder weights initialized using xavier.
Decoder weights initialized using xavier.
VGG parameters loaded.
StyleExtractor weights initialized using xavier.
StyleWhitener weights initialized using xavier.
Generator build success!


In [17]:
# Check that the first registered child of the style extractor is the very same
# object as its `vgg16` attribute (the identity test used when skipping VGG
# during weight initialisation).
next(iter(tnet.style_extractor.children())) is tnet.style_extractor.vgg16

True

In [2]:
def get_pretrained_vgg():
    """Build the convolutional trunk of VGG16-BN and load pretrained weights into it.

    The returned network is an ``nn.Sequential`` of Conv2d -> BatchNorm2d -> ReLU
    groups separated by 2x2 max-pooling, laid out according to ``cfg`` below.
    All learnable parameters *and* all BatchNorm buffers (running mean/variance)
    are copied from ``torchvision.models.vgg16_bn(pretrained=True).features``.

    NOTE(review): an earlier cell of this notebook already defines a function
    with this name; running this cell replaces that definition.

    Returns:
        nn.Sequential: the feature extractor with pretrained weights loaded.

    Raises:
        RuntimeError: if the pretrained feature stack does not have exactly the
            same state-dict keys/shapes as the network built here.
    """
    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_partial_vgg16():
        """Assemble the layer stack described by ``cfg`` ('M' = max-pool, int = conv width)."""
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                    nn.BatchNorm2d(v),
                    nn.ReLU(inplace=True),
                ]
                in_channels = v
        return nn.Sequential(*layers)

    def init_vgg16(vgg):
        """Copy every parameter and buffer of the pretrained feature stack into ``vgg``."""
        pretrained = torchvision.models.vgg16_bn(pretrained=True)
        # Pairing state-dict entries by position while looping only
        # len(parameters()) times stops early: the state dict also contains the
        # BatchNorm buffers, so the deeper layers never received pretrained
        # values. Loading the `features` state dict copies everything, and the
        # default strict key check fails loudly on any layout mismatch.
        vgg.load_state_dict(pretrained.features.state_dict())
        print('VGG parameters loaded.')

    vgg16 = make_partial_vgg16()
    init_vgg16(vgg16)
    return vgg16

In [3]:
# Build the partial VGG16-BN and load the pretrained weights into it; the result
# is inspected and compared with the torchvision model in the following cells.
vgg16 = get_pretrained_vgg()

VGG parameters loaded.


In [4]:
# Print the layer listing to check the Conv2d / BatchNorm2d / ReLU / MaxPool2d
# layout produced by `cfg`.
print(vgg16)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace)
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace)
  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU(inplace)
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): BatchNor

In [6]:
# Load the complete torchvision VGG16-BN as the reference whose weights the
# hand-built copy is compared against below.
vgg_real = torchvision.models.vgg16_bn(pretrained=True)

In [8]:
# First parameter tensor of the reference model (first conv layer's weight).
next(iter(vgg_real.parameters()))

Parameter containing:
tensor([[[[ 8.2833e-02,  2.7968e-02,  7.7096e-02],
          [ 4.9341e-02, -3.3441e-02,  1.9572e-02],
          [ 8.0300e-02,  7.7076e-02,  8.3349e-02]],

         [[-4.4296e-02, -1.7748e-01, -4.8706e-02],
          [-1.1003e-01, -2.7530e-01, -1.3474e-01],
          [-5.9982e-03, -6.1375e-02,  1.6822e-02]],

         [[ 2.7480e-02, -6.6769e-02,  4.3955e-02],
          [-2.6662e-02, -1.4995e-01, -3.3615e-02],
          [ 5.2778e-02,  1.7143e-02,  8.6744e-02]]],


        [[[-1.2628e-02,  3.0218e-02, -2.6930e-02],
          [-1.3764e-02,  1.1993e-01, -6.6263e-03],
          [-2.6019e-02, -8.3535e-03, -3.9197e-02]],

         [[-4.0557e-02,  1.3983e-02, -5.4278e-02],
          [ 1.5412e-02,  1.8198e-01,  1.7598e-02],
          [-1.7032e-02,  1.1284e-02, -2.4226e-02]],

         [[-6.5683e-02,  5.9252e-02, -5.3020e-02],
          [ 3.8278e-02,  2.7292e-01,  5.9491e-02],
          [-4.1218e-02,  3.6159e-02, -3.0478e-02]]],


        [[[ 1.4962e-06, -1.1430e-06,  1.2536

In [9]:
# First parameter tensor of the hand-built copy; it should show the same values
# as the reference model's first parameter displayed in the previous cell.
# NOTE(review): this only checks the very first tensor, not the deeper layers.
next(iter(vgg16.parameters()))

Parameter containing:
tensor([[[[ 8.2833e-02,  2.7968e-02,  7.7096e-02],
          [ 4.9341e-02, -3.3441e-02,  1.9572e-02],
          [ 8.0300e-02,  7.7076e-02,  8.3349e-02]],

         [[-4.4296e-02, -1.7748e-01, -4.8706e-02],
          [-1.1003e-01, -2.7530e-01, -1.3474e-01],
          [-5.9982e-03, -6.1375e-02,  1.6822e-02]],

         [[ 2.7480e-02, -6.6769e-02,  4.3955e-02],
          [-2.6662e-02, -1.4995e-01, -3.3615e-02],
          [ 5.2778e-02,  1.7143e-02,  8.6744e-02]]],


        [[[-1.2628e-02,  3.0218e-02, -2.6930e-02],
          [-1.3764e-02,  1.1993e-01, -6.6263e-03],
          [-2.6019e-02, -8.3535e-03, -3.9197e-02]],

         [[-4.0557e-02,  1.3983e-02, -5.4278e-02],
          [ 1.5412e-02,  1.8198e-01,  1.7598e-02],
          [-1.7032e-02,  1.1284e-02, -2.4226e-02]],

         [[-6.5683e-02,  5.9252e-02, -5.3020e-02],
          [ 3.8278e-02,  2.7292e-01,  5.9491e-02],
          [-4.1218e-02,  3.6159e-02, -3.0478e-02]]],


        [[[ 1.4962e-06, -1.1430e-06,  1.2536

In [6]:
# Scratch tensor for shape experiments. `torch.Tensor(sizes...)` allocates
# without initialising, so the values are arbitrary (see the garbage entry in
# the displayed output below); only the shape matters here.
a = torch.Tensor(2,3,4)

In [7]:
# Display the scratch tensor; its contents are uninitialised memory.
a

tensor([[[-226979659555413959137820672.0000,                            0.0000,
                                     0.0000,                            0.0000],
         [                           0.0000,                            0.0000,
                                     0.0000,                            0.0000],
         [                           0.0000,                            0.0000,
                                     0.0000,                            0.0000]],

        [[                           0.0000,                            0.0000,
                                     0.0000,                            0.0000],
         [                           0.0000,                            0.0000,
                                     0.0000,                            0.0000],
         [                           0.0000,                            0.0000,
                                     0.0000,                            0.0000]]])

In [9]:
# Flattening the (2, 3, 4) tensor with view(-1) yields 2 * 3 * 4 = 24 elements.
a.view(-1).shape

torch.Size([24])

In [10]:
# Concatenating along dim 1 keeps the batch dimension and adds the feature
# widths: (2, 6) + (2, 7) -> (2, 13). Tensor contents are uninitialised.
a, b = torch.Tensor(2, 6), torch.Tensor(2, 7)
torch.cat([a, b], dim=1).shape

torch.Size([2, 13])

In [12]:
class Discriminator(BaseModule):
    """Score a pair of images from their concatenated VGG features.

    Both images go through the same fixed VGG feature extractor (no gradients,
    always in eval mode). The flattened features are concatenated and passed to
    a fully connected classifier that ends in a sigmoid.

    Args:
        vgg: feature extractor to use. When ``None`` (default) a fresh one is
            built with ``get_pretrained_vgg()`` for this instance.
        init_type: initialisation scheme forwarded to ``weights_init_func``.
        n_hidden: width of the two hidden layers of the classifier.
    """

    def __init__(self, vgg=None,
                 init_type='xavier', n_hidden=1024):
        super(Discriminator, self).__init__()

        # Build the default lazily: a `vgg=get_pretrained_vgg()` default is
        # evaluated once at class-definition time, so every instance would share
        # one and the same module.
        self.vgg16 = get_pretrained_vgg() if vgg is None else vgg
        # Input width is two flattened VGG outputs; this assumes each one is
        # 512x7x7 (224x224 inputs for the default network).
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7 * 2, n_hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(n_hidden, 1),
            nn.Sigmoid()
        )

        # Initialise everything except the pretrained VGG.
        weights_init_func = lambda m: self.weights_init_func(m, init_type, gain=0.02)
        for module in self.children():
            if module is not self.vgg16:
                module.apply(weights_init_func)
        print('Discriminator weights initialized using %s.' % init_type)

    def _vgg_features(self, img):
        """Return the VGG features of ``img`` flattened to ``(batch, -1)``, without a graph."""
        # Gradients never flow through the VGG (the original detached its
        # output), so skip building the autograd graph altogether.
        with torch.no_grad():
            features = self.vgg16(img)
        # Use this image's own batch size; the original reshaped the second
        # image's features with the first image's batch size.
        return features.view(img.size(0), -1)

    def forward(self, img1, img2):
        """Return the classifier output, shape ``(batch, 1)``, for the image pair."""
        feature_cat = torch.cat((self._vgg_features(img1), self._vgg_features(img2)), 1)
        return self.classifier(feature_cat)

    def train(self, mode=True):
        """Set the training mode while keeping the VGG feature extractor in eval mode.

        Without this the pretrained BatchNorm running statistics would be
        overwritten by every forward pass made in training mode.
        """
        self.training = mode
        for module in self.children():
            if module is self.vgg16:
                module.train(False)
            else:
                module.train(mode)
        return self

VGG parameters loaded.


In [13]:
# Construct a Discriminator with the default arguments as a construction check.
# Note that this rebinds `a`, which held a scratch tensor in the earlier cells.
a = Discriminator()

Discriminator weights initialized using xavier.
