In [None]:
import torch
from torch import nn

from models.resnet import ResnetGenerator, ResnetBlock, InverseTanh

from tqdm import trange
from torch.optim import LBFGS

from itertools import product
import numpy as np
from skimage import data
import skimage
import matplotlib.pyplot as plot
import PIL.Image

In [None]:
class CustomModel(nn.Module):
    """Small encoder/decoder, an alternative to ``ResnetGenerator`` in this notebook.

    Layer stack: ``InverseTanh`` -> stride-2 ``Conv2d`` (3 -> 40 channels) -> one
    ``ResnetBlock`` on 40 channels -> stride-2 ``ConvTranspose2d`` (40 -> 3 channels)
    -> ``Tanh``.

    NOTE(review): if ``InverseTanh`` is an exact atanh, input values of exactly +/-1
    map to +/-inf -- confirm in ``models.resnet`` before feeding images in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            InverseTanh(),
            # kernel 3, stride 2, padding 1: halves H and W (for even sizes).
            nn.Conv2d(3, 40, 3, 2, 1),
            # Positional args presumably (dim, padding_type, norm_layer, use_dropout,
            # use_bias) -- TODO confirm against models.resnet.
            ResnetBlock(40, 'zero', nn.BatchNorm2d, False, False),
            # output_padding=1 makes the transposed conv exactly undo the size change of
            # the stride-2 conv above: output is 2*h instead of 2*h - 1, so an even-sized
            # input comes back at its original size instead of one pixel short.
            nn.ConvTranspose2d(40, 3, 3, 2, 1, output_padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` (3-channel images, presumably (N, 3, H, W)) through the stack."""
        return self.model(x)

In [None]:
# Network whose weights are hand-initialised below; the arguments are presumably
# (input channels, output channels) -- RGB in, RGB out.
# Alternative, smaller network: model = CustomModel()
# eval() returns the module itself, so it can be chained onto construction; it makes
# BatchNorm use its running statistics, which the next cells overwrite.
model = ResnetGenerator(3, 3).eval()

In [None]:
def compute_kernel(n_input, n_output, kernel, stride, n_features):
    """Build a Conv2d weight that performs a (noisy) space-to-depth rearrangement.

    For every one of the first ``n_features`` input channels and every offset
    ``(i, j)`` inside a ``stride`` cell, output channel
    ``feature * sh * sw + i * sw + j`` gets a unit tap at kernel position
    ``(ch + i, cw + j)``, where ``(ch, cw) = (kh // 2, kw // 2)``. Assuming the conv
    uses padding ``kernel // 2`` (TODO confirm against the layer), output pixel
    ``(y, x)`` of that channel then copies input pixel ``(sh*y + i, sw*x + j)``. For
    square strides this is the channel order of ``torch.nn.functional.pixel_unshuffle``.
    All other taps are Gaussian noise with std 0.01.

    Args:
        n_input: the conv's ``in_channels``.
        n_output: the conv's ``out_channels``.
        kernel: ``(kh, kw)`` kernel size.
        stride: ``(sh, sw)`` stride.
        n_features: number of leading input channels that carry information.

    Returns:
        Tensor of shape ``(n_output, n_input, kh, kw)`` (the ``Conv2d`` weight layout).

    Raises:
        ValueError: if the stride cell does not fit in the kernel around its centre,
            or there are too few input/output channels for ``n_features``.
    """
    kh, kw = kernel
    sh, sw = stride
    ch, cw = kh // 2, kw // 2
    cell = sh * sw

    # The unit taps occupy rows ch..ch+sh-1 and columns cw..cw+sw-1 of the kernel.
    if ch + sh > kh or cw + sw > kw:
        raise ValueError(
            f"stride {tuple(stride)} does not fit in kernel {tuple(kernel)} around its centre")
    if n_features > n_input or n_features * cell > n_output:
        raise ValueError(
            f"need n_input >= {n_features} and n_output >= {n_features * cell}, "
            f"got n_input={n_input}, n_output={n_output}")

    weights = torch.randn([n_output, n_input, kh, kw]) * 0.01
    for feature in range(n_features):
        for i, j in np.ndindex(sh, sw):
            weights[feature * cell + i * sw + j, feature, ch + i, cw + j] = 1

    return weights

def compute_kernel_transpose(n_input, n_output, kernel, stride, n_features):
    """Build a ConvTranspose2d weight that performs a (noisy) depth-to-space rearrangement.

    Input channel ``feature * sh * sw + i * sw + j`` gets a unit tap at kernel position
    ``(ch + i, cw + j)`` towards output channel ``feature``, where
    ``(ch, cw) = (kh // 2, kw // 2)``. The taps are placed for a layer with padding
    ``kernel // 2`` (TODO confirm against the layer). Under that assumption, input pixel
    ``(y, x)`` lands on output pixel ``(sh*y + i, sw*x + j)``, which undoes the
    space-to-depth layout of ``compute_kernel``. For square strides this is the channel
    order of ``torch.nn.functional.pixel_shuffle``. An exact size round trip also
    needs ``output_padding = stride - 1`` on the layer. All other taps are Gaussian noise
    with std 0.01.

    Args:
        n_input: the transposed conv's ``in_channels``.
        n_output: the transposed conv's ``out_channels``.
        kernel: ``(kh, kw)`` kernel size.
        stride: ``(sh, sw)`` stride.
        n_features: number of leading input channels that carry information; must be
            a multiple of ``sh * sw``.

    Returns:
        Tensor of shape ``(n_input, n_output, kh, kw)`` (the ``ConvTranspose2d``
        weight layout).

    Raises:
        ValueError: if the stride cell does not fit in the kernel around its centre,
            ``n_features`` is not a multiple of the stride cell, or there are too few
            input/output channels.
    """
    kh, kw = kernel
    sh, sw = stride
    ch, cw = kh // 2, kw // 2
    cell = sh * sw

    # The unit taps occupy rows ch..ch+sh-1 and columns cw..cw+sw-1 of the kernel.
    if ch + sh > kh or cw + sw > kw:
        raise ValueError(
            f"stride {tuple(stride)} does not fit in kernel {tuple(kernel)} around its centre")
    # A partial stride cell cannot be unpacked; flooring would silently drop channels.
    if n_features % cell:
        raise ValueError(
            f"n_features={n_features} is not a multiple of the stride cell size {cell}")
    n_out_features = n_features // cell
    if n_features > n_input or n_out_features > n_output:
        raise ValueError(
            f"need n_input >= {n_features} and n_output >= {n_out_features}, "
            f"got n_input={n_input}, n_output={n_output}")

    weights = torch.randn([n_input, n_output, kh, kw]) * 0.01
    for feature in range(n_out_features):
        for i, j in np.ndindex(sh, sw):
            weights[feature * cell + i * sw + j, feature, ch + i, cw + j] = 1

    return weights

In [None]:
def zero_resblock(layer):
    """Re-initialise a residual block so its ``conv_block`` branch outputs almost zero.

    Walks ``layer.conv_block`` and matches sub-layers by type rather than by position:

    * ``nn.Conv2d``: weight ~ N(0, 0.01), bias (if present) set to 0;
    * ``nn.BatchNorm2d`` / ``nn.InstanceNorm2d``: affine weight 1 and bias 0 (when the
      layer has them), and running statistics reset to mean 0 / var 1 (when it tracks
      them).

    Padding, activation and Dropout layers are left untouched. The block may therefore
    start with an explicit pad layer (reflect/replicate padding) or contain Dropout.
    If the block computes ``x + conv_block(x)`` (presumably, as in the CycleGAN
    ResnetBlock -- confirm in ``models.resnet``), it then behaves close to identity.

    Args:
        layer: module with a ``conv_block`` container, e.g. a ``ResnetBlock``.
    """
    for sub in layer.conv_block.children():
        if isinstance(sub, nn.Conv2d):
            nn.init.normal_(sub.weight, 0.0, 0.01)
            if sub.bias is not None:
                nn.init.zeros_(sub.bias)
        elif isinstance(sub, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            if sub.weight is not None:
                nn.init.ones_(sub.weight)
            if sub.bias is not None:
                nn.init.zeros_(sub.bias)
            # Reset the statistics on every norm layer: in eval() mode stale statistics
            # (e.g. loaded from a checkpoint) would rescale the branch.
            sub.reset_running_stats()

In [None]:
# Hand-initialise every top-level layer of `model.model`. The apparent intent (inferred
# from the kernel builders above -- confirm) is a near pass-through network:
# - strided convs pack each stride cell into extra channels;
# - transposed convs unpack them again;
# - residual blocks get a near-zero branch;
# - BatchNorm layers become identity affine maps with fresh running statistics.
# `n_features_info` counts the leading channels that currently carry image data.
SEED = 0
torch.manual_seed(SEED)  # kernels contain random 0.01-scale noise; make it reproducible

n_features_info = 3  # the input image has 3 channels

for layer in model.model.children():
    if isinstance(layer, nn.BatchNorm2d):
        nn.init.zeros_(layer.bias)
        nn.init.ones_(layer.weight)
        nn.init.zeros_(layer.running_mean)
        nn.init.ones_(layer.running_var)
    elif isinstance(layer, ResnetBlock):
        zero_resblock(layer)
    elif isinstance(layer, nn.Conv2d):
        kernel = compute_kernel(layer.in_channels, layer.out_channels,
                                layer.kernel_size, layer.stride, n_features_info)
        # copy_ under no_grad keeps the existing Parameter and raises on incompatible
        # shapes (e.g. groups > 1). Rebinding `.data` would silently install a tensor
        # of a different shape.
        with torch.no_grad():
            layer.weight.copy_(kernel)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        # Each output pixel now holds a whole stride cell, so more channels carry data.
        n_features_info *= layer.stride[0] * layer.stride[1]
    elif isinstance(layer, nn.ConvTranspose2d):
        kernel = compute_kernel_transpose(layer.in_channels, layer.out_channels,
                                          layer.kernel_size, layer.stride, n_features_info)
        with torch.no_grad():
            layer.weight.copy_(kernel)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        # Stride cells are unpacked back into space, so fewer channels carry data.
        n_features_info //= layer.stride[0] * layer.stride[1]

    print("{:<15} > {}".format(layer.__class__.__name__, n_features_info))

In [None]:
# Reference RGB test image from scikit-image, converted to floats in [0, 1], layout (H, W, C).
image = skimage.img_as_float(data.astronaut())

# HWC -> CHW, convert to torch's default float dtype, then prepend a batch axis.
chw = np.ascontiguousarray(image.transpose(2, 0, 1))
image_t = torch.from_numpy(chw).to(torch.get_default_dtype()).unsqueeze(0)

In [None]:
# Run the initialised network on the reference image. This is inference only, so
# no_grad skips recording an autograd graph. Calling `model(...)` rather than
# `model.forward(...)` also runs any registered module hooks.
with torch.no_grad():
    image_t1 = model(image_t)
# Drop the batch axis and go back to (H, W, C) for image libraries.
image_1 = image_t1[0].numpy().transpose(1, 2, 0)

In [None]:
# PIL.Image.fromarray(skimage.img_as_ubyte(image_1))

In [None]:
# PIL.Image.fromarray(skimage.img_as_ubyte(image))