In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import sys

from collections import OrderedDict
from torch import nn
from torchvision.models.inception import BasicConv2d, InceptionA, InceptionB, InceptionC

In [None]:
# Make the project's `src/` directory importable (for the `models.*` imports).
# Guarded so re-running this cell does not append duplicate entries to sys.path.
if '../src' not in sys.path:
    sys.path.append('../src')

In [None]:
from models.inception import Inception3_Encoder
from models.style_augment import TransformNet

In [None]:
# model = Inception3_Encoder()
# cp = torch.load('../../style-augmentation/styleaug/checkpoints/checkpoint_stylepredictor.pth')
# model.load_state_dict(cp['state_dict_stylepredictor'])

In [None]:
# Prefer the second GPU (the original setup), but fall back to the first GPU or
# the CPU so the notebook still runs on machines without two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# x = torch.rand(6, 3, 299, 299).to(device)
# model = model.to(device)

In [None]:
# Standalone TransformNet (default constructor arguments), used below to smoke-test
# the forward pass and to check that the ported checkpoint weights load.
net = TransformNet()

In [None]:
# Move the network's parameters and buffers to `device`.
net = net.to(device)

In [None]:
N = 8    # batch size of the random test input
S = 512  # height and width of the random test images

# Allocate directly on `device` instead of building on the CPU and copying over.
x = torch.randn(N, 3, S, S, device=device)
# 100 matches StyleAugmentNet's default style_dim; presumably the style size
# TransformNet expects — confirm against its definition.
y = torch.randn(N, 100, device=device)

In [None]:
# Smoke-test the forward pass. no_grad avoids building and holding an autograd
# graph for the whole random test batch, which is only inspected for its shape.
with torch.no_grad():
    out = net(x, y)
out.shape

In [None]:
# map_location='cpu' keeps torch.load from restoring tensors onto whatever GPU the
# checkpoint was saved from; load_state_dict later copies them into the model.
cp = torch.load('../../style-augmentation/styleaug/checkpoints/checkpoint_transformer.pth',
                map_location='cpu')

In [None]:
def port_transformer_weights(source, n_encoder=3, n_residual=5, n_decoder=3):
    """Rename the keys of the ``state_dict_ghiasi`` transformer checkpoint to TransformNet's layout.

    The source state dict stores every layer in one flat ``layers.<i>`` sequence:
    ``n_encoder`` layers with a single ``conv`` sub-module, then ``n_residual``
    layers whose ``conv``/``fc_beta``/``fc_gamma`` sub-modules are suffixed ``1``
    and ``2``, then ``n_decoder`` layers with unsuffixed ``conv``/``fc_beta``/
    ``fc_gamma`` sub-modules. The destination keeps the first stage under
    ``encoder.<i>`` and numbers the residual layers followed by the decoder
    layers from 0 under ``layers.<i>``.

    Args:
        source: mapping from source parameter name to tensor
            (e.g. ``checkpoint['state_dict_ghiasi']``).
        n_encoder: number of leading encoder layers (default matches the checkpoint).
        n_residual: number of residual layers (default matches the checkpoint).
        n_decoder: number of trailing decoder layers (default matches the checkpoint).

    Returns:
        dict mapping destination parameter names to the unchanged source values.

    Raises:
        KeyError: if an expected source key is missing.
    """
    dest = {}
    params = ('weight', 'bias')
    # (source sub-module name, destination sub-module name)
    layer_map = (('conv', 'layers.conv'), ('fc_beta', 'beta'), ('fc_gamma', 'gamma'))

    # Encoder: plain convolutions, same index on both sides.
    for i in range(n_encoder):
        for p in params:
            name_source = 'layers.{}.conv.{}'.format(i, p)
            name_dest = 'encoder.{}.conv.{}'.format(i, p)
            dest[name_dest] = source[name_source]

    # Residual layers: source index is offset past the encoder; destination restarts at 0.
    for i in range(n_residual):
        for j in (1, 2):
            for source_layer, dest_layer in layer_map:
                for p in params:
                    name_source = 'layers.{}.{}{}.{}'.format(i + n_encoder, source_layer, j, p)
                    name_dest = 'layers.{}.conv{}.{}.{}'.format(i, j, dest_layer, p)
                    dest[name_dest] = source[name_source]

    # Decoder: follows the residual layers in both numbering schemes.
    for i in range(n_decoder):
        for source_layer, dest_layer in layer_map:
            for p in params:
                name_source = 'layers.{}.{}.{}'.format(i + n_encoder + n_residual, source_layer, p)
                name_dest = 'layers.{}.{}.{}'.format(i + n_residual, dest_layer, p)
                dest[name_dest] = source[name_source]

    return dest

In [None]:
# Rename the checkpoint's 'state_dict_ghiasi' keys to TransformNet's parameter names.
weights = port_transformer_weights(cp['state_dict_ghiasi'])

In [None]:
# Strict load (the default): raises if any ported key is missing from, or unexpected by, TransformNet.
net.load_state_dict(weights)

In [None]:
# map_location='cpu' keeps torch.load from restoring tensors onto whatever GPU the
# checkpoint was saved from; the tensors are only inspected here.
cp = torch.load('../../style-augmentation/styleaug/checkpoints/checkpoint_embeddings.pth',
                map_location='cpu')

In [None]:
# List every tensor stored in the embeddings checkpoint together with its shape.
for key, tensor in cp.items():
    print(f"{key}: {tuple(tensor.shape)}")

In [None]:
# Two slightly noisy vectors, one near 1 and one near 0, so it is obvious which
# end an interpolation is weighted towards.
o = torch.randn(5).div(100).add(1)
z = torch.randn(5).div(100)

In [None]:
# Reference blend written out by hand: weight alpha on `o`, (1 - alpha) on `z`.
alpha = 0.2

(1 - alpha) * z + alpha * o

In [None]:
# Same blend via lerp: o + (1 - alpha) * (z - o).
o.lerp(z, 1 - alpha)

In [None]:
# In-place lerp_ variant, applied to a clone so `o` itself is not modified.
o.clone().lerp_(z, 1 - alpha)

In [None]:
class StyleAugmentNet(nn.Module):
    """Style encoder plus style-conditioned transformer, used for style augmentation.

    ``style_encoder`` (an ``Inception3_Encoder``) maps an image to a
    ``style_dim``-dimensional style embedding; ``transform`` (a ``TransformNet``)
    is called with an image and such an embedding. Random styles are drawn from a
    multivariate normal with mean ``style_mean`` and covariance ``style_cov``.
    Both are registered buffers, so they are saved in and restored from the state dict.

    Args:
        img_channels: number of image channels passed to ``TransformNet``.
        style_dim: size of the style embedding.
    """

    def __init__(self, img_channels=3, style_dim=100):
        super(StyleAugmentNet, self).__init__()
        self.style_dim = style_dim
        self.style_encoder = Inception3_Encoder(out_features=style_dim, transform_input=True)
        self.transform = TransformNet(img_channels)

        self.register_buffer('style_mean', torch.zeros(style_dim))
        # Identity default (standard normal). An all-ones matrix would be rank-1 and
        # collapse every sampled style onto one direction. Expected to be overwritten
        # via load_state_dict with a learned covariance.
        self.register_buffer('style_cov', torch.eye(style_dim))
        # Factor A with A.T @ A == style_cov; derived from style_cov by compute_style_std.
        self.register_buffer('style_std', torch.empty(style_dim, style_dim))
        self.compute_style_std()

    def compute_style_std(self):
        """Recompute the ``style_std`` buffer from ``style_cov``.

        With ``style_cov = U diag(s) V^T`` (SVD), sets
        ``style_std = (U diag(sqrt(s)))^T``. Then ``style_std.T @ style_std`` equals
        ``U diag(s) U^T``, which is ``style_cov`` for a symmetric positive
        semi-definite covariance.
        """
        u, s, _ = torch.svd(self.style_cov)
        self.style_std = (u @ torch.sqrt(s).diag()).T

    def sample_style(self, batch_size, device=None):
        """Draw ``batch_size`` random style embeddings from N(style_mean, style_cov).

        Args:
            batch_size: number of embeddings to sample.
            device: device of the result; defaults to the device of the module's
                style buffers, so the matmul below never mixes devices by default.

        Returns:
            Tensor of shape ``(batch_size, style_dim)``.
        """
        if device is None:
            device = self.style_std.device
        z = torch.randn(batch_size, self.style_dim, device=device, dtype=self.style_std.dtype)
        return torch.mm(z, self.style_std).add_(self.style_mean)

    def forward(self, x, style=None, alpha=0.5):
        """Restyle a batch of images.

        Args:
            x: 4-D image batch (bicubic interpolation below requires 4-D input);
                ``x.size(0)`` is taken as the batch size.
            style: optional style embeddings of shape ``(batch, style_dim)``. Random
                styles are sampled when omitted. The caller's tensor is never modified.
            alpha: style strength. For ``alpha < 1`` the target style becomes
                ``alpha * style + (1 - alpha) * style_encoder(x)``, which pulls it
                towards the input's own style. ``alpha >= 1`` uses ``style`` unchanged.

        Returns:
            The output of ``self.transform(x, style)``.
        """
        if style is None:
            style = self.sample_style(x.size(0), device=x.device)

        if alpha < 1:
            # The Inception-based style encoder is fed 299x299 inputs.
            x_small = F.interpolate(x, size=299, mode='bicubic', align_corners=False)
            image_style = self.style_encoder(x_small)
            # Out-of-place so a caller-supplied `style` tensor is not overwritten
            # (in-place lerp_ would also fail on a leaf tensor that requires grad).
            style = torch.lerp(style, image_style, 1 - alpha)

        return self.transform(x, style)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load parameters and buffers, then refresh ``style_std`` from ``style_cov``.

        Extra keyword arguments are forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The result of ``nn.Module.load_state_dict`` (missing/unexpected keys).
            This is useful when loading a partial state with ``strict=False``.
        """
        result = super(StyleAugmentNet, self).load_state_dict(state_dict, strict=strict, **kwargs)
        self.compute_style_std()
        return result

In [None]:
# Build the full augmentation model on `device`; pretrained weights are loaded in a later cell.
# NOTE(review): if this is meant to be a frozen augmenter, it presumably also needs
# .eval() and requires_grad_(False) (the latter is commented out below) — confirm the intended mode.
model = StyleAugmentNet().to(device)
# model.requires_grad_(False)

In [None]:
# Load the three pretrained pieces into `model`. map_location='cpu' keeps torch.load
# from restoring tensors onto the GPU they were saved from; load_state_dict then
# copies them into the model's existing parameters and buffers.
ckpt_dir = '../../style-augmentation/styleaug/checkpoints'

predictor_ckpt = torch.load(ckpt_dir + '/checkpoint_stylepredictor.pth', map_location='cpu')
model.style_encoder.load_state_dict(predictor_ckpt['state_dict_stylepredictor'])

transformer_ckpt = torch.load(ckpt_dir + '/checkpoint_transformer.pth', map_location='cpu')
model.transform.load_state_dict(port_transformer_weights(transformer_ckpt['state_dict_ghiasi']))

# Only the style statistics come from this checkpoint, hence strict=False.
# StyleAugmentNet.load_state_dict recomputes style_std from the new style_cov.
embedding_ckpt = torch.load(ckpt_dir + '/checkpoint_embeddings.pth', map_location='cpu')
model.load_state_dict(
    {'style_mean': embedding_ckpt['pbn_embedding_mean'],
     'style_cov': embedding_ckpt['pbn_embedding_covariance']},
    strict=False,
)

In [None]:
# state = model.state_dict()
# torch.save(state, 'weights.pth')

In [None]:
# a = model(x)
# a.shape