In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import sys

from collections import OrderedDict
from pathlib import Path
from torch import nn
from torchvision.models.inception import BasicConv2d, InceptionA, InceptionB, InceptionC

In [2]:
# Make the project's `src` directory importable. The membership check keeps the
# cell idempotent: re-running it does not pile up duplicate sys.path entries.
if '../src' not in sys.path:
    sys.path.append('../src')

In [3]:
from models.inception import Inception3_Encoder
from models.style_augment import StyleAugmentNet

In [4]:
# Prefer the second GPU (index 1) as before; fall back to the first GPU and then
# to the CPU so the notebook still runs on machines without two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
def port_transformer_weights(source):
    """Return a new dict holding the entries of `source` under renamed keys.

    `source` is any mapping indexed by parameter-name strings (in this notebook
    it is the transformer checkpoint's state dict). Three groups are translated:

    * encoder    -- `layers.{0..2}.conv.*`  ->  `encoder.{0..2}.conv.*`
    * bottleneck -- `layers.{3..7}.{conv,fc_beta,fc_gamma}{1,2}.*`
                    ->  `layers.{0..4}.conv{1,2}.{layers.conv,beta,gamma}.*`
    * decoder    -- `layers.{8..10}.{conv,fc_beta,fc_gamma}.*`
                    ->  `layers.{5..7}.{layers.conv,beta,gamma}.*`

    where `*` is `weight` or `bias`. Values are passed through untouched.

    Raises:
        KeyError: if `source` lacks any of the expected keys.
    """
    param_kinds = ('weight', 'bias')
    # (sub-module name in the source, sub-module name in the destination)
    layer_pairs = (('conv', 'layers.conv'), ('fc_beta', 'beta'), ('fc_gamma', 'gamma'))

    # Ordered (destination key, source key) pairs; looked up in one pass at the end.
    key_map = []

    # Encoder: three conv layers, only the prefix changes.
    key_map.extend(
        ('encoder.{}.conv.{}'.format(idx, kind), 'layers.{}.conv.{}'.format(idx, kind))
        for idx in range(3)
        for kind in param_kinds
    )

    # Bottleneck: five blocks with two numbered halves each. The source carries
    # the half number as a suffix on the layer name; the destination puts it on
    # a `conv{n}` container instead. Source indices are shifted by 3.
    key_map.extend(
        ('layers.{}.conv{}.{}.{}'.format(block, half, dst_layer, kind),
         'layers.{}.{}{}.{}'.format(block + 3, src_layer, half, kind))
        for block in range(5)
        for half in (1, 2)
        for src_layer, dst_layer in layer_pairs
        for kind in param_kinds
    )

    # Decoder: three layers, source indices start at 8, destination indices at 5.
    key_map.extend(
        ('layers.{}.{}.{}'.format(idx + 5, dst_layer, kind),
         'layers.{}.{}.{}'.format(idx + 8, src_layer, kind))
        for idx in range(3)
        for src_layer, dst_layer in layer_pairs
        for kind in param_kinds
    )

    return {dst_key: source[src_key] for dst_key, src_key in key_map}

In [6]:
model = StyleAugmentNet().to(device)
# Freeze every parameter; this notebook only loads weights and runs a forward pass.
model.requires_grad_(False)
# Switch to inference mode: the network contains BatchNorm layers with
# track_running_stats=True, which in training mode normalise with batch
# statistics and overwrite their running statistics on every forward pass.
# eval() returns the module, so the cell still displays the model summary.
model.eval()

StyleAugmentNet(
  (style_encoder): Inception3_Encoder(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2a_3x3): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2b_3x3): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3b_1x1): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (C

In [7]:
CP_DIR = Path('../../style-augmentation/styleaug/checkpoints')

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# map_location='cpu' stops the tensors from being restored onto whatever device
# they were saved from (which may be missing or busy on this machine);
# load_state_dict then copies the values into the model's own parameters.

# Style encoder: the checkpoint's state dict is loaded as-is.
cp = torch.load(CP_DIR/'checkpoint_stylepredictor.pth', map_location='cpu')
weights = cp['state_dict_stylepredictor']
model.style_encoder.load_state_dict(weights)

# Transformer: keys are renamed by port_transformer_weights before loading.
cp = torch.load(CP_DIR/'checkpoint_transformer.pth', map_location='cpu')
weights = port_transformer_weights(cp['state_dict_ghiasi'])
model.transform.load_state_dict(weights)

# Embedding statistics: only these two entries are supplied, so strict=False is
# required to tolerate every other key of the full model being absent.
cp = torch.load(CP_DIR/'checkpoint_embeddings.pth', map_location='cpu')
weights = dict(style_mean=cp['pbn_embedding_mean'],
               style_cov=cp['pbn_embedding_covariance'])
model.load_state_dict(weights, strict=False)

In [8]:
N = 8    # size of the leading (batch) dimension of x and y
S = 512  # height and width of x

# Fixed seed so that re-running this cell yields the same dummy inputs.
torch.manual_seed(0)

# Create the tensors directly on the target device instead of allocating them
# on the CPU and copying them over.
x = torch.randn(N, 3, S, S, device=device)
# NOTE(review): y is not used by any cell visible here -- confirm it is still needed.
y = torch.randn(N, 100, device=device)

In [9]:
# Uncomment to save the model's full state dict (with the weights loaded above) to disk:
# state = model.state_dict()
# torch.save(state, 'weights.pth')

In [10]:
# Run the model once on the random batch; the trailing expression displays the
# size of the result.
a = model(x)
a.size()

torch.Size([8, 3, 512, 512])