# Imports

In [None]:
%pylab inline
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms

from PIL import Image
from collections import OrderedDict
from resnet import resnet50
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

# Utilities

In [None]:
# Locations of the input images and of the saved model weights, all relative
# to the directory the notebook is run from.
input_dir = Path("./input")
content_dir = input_dir / "content"
style_dir = input_dir / "style"
model_dir = Path("./models")

In [None]:
# gram matrix and loss
def gram_matrix(x):
    """Return the batched Gram matrix of a 4-D feature map.

    Args:
        x: tensor whose size unpacks as ``(b, c, h, w)``.

    Returns:
        Tensor of size ``(b, c, c)``: for every batch item, the inner
        products between the flattened channel maps, divided by ``h * w``.
    """
    b, c, h, w = x.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or sliced activations; contiguous inputs are still not copied.
    # The local is deliberately not called ``F``: in this file that name is
    # the torch.nn.functional alias.
    feats = x.reshape(b, c, h * w)
    gram = torch.bmm(feats, feats.transpose(1, 2))
    return gram / (h * w)

def gram_mse(x, gt):
    """Mean squared error between the Gram matrix of ``x`` and ``gt``.

    ``gt`` is compared as given, so it has to be a Gram matrix already;
    only ``x`` is passed through ``gram_matrix``.
    """
    gram = gram_matrix(x)
    return F.mse_loss(gram, gt)

In [None]:
# pre(post)processing
img_size = 512

# Per-channel statistics used by Normalize in transforms_fw and undone again
# in transforms_bw; kept in one place so the two directions cannot drift apart.
norm_mean = [0.48501961, 0.45795686, 0.40760392]
norm_std = [0.2290, 0.2240, 0.2250]


def _denormalize(x):
    """Undo the Normalize step of transforms_fw and clamp the result to [0, 1].

    Works out of place, so the tensor passed in is left untouched (the
    in-place variant would silently modify the caller's tensor whenever a
    preceding ``.cpu()`` was a no-op). The constants are created with the
    dtype and device of ``x`` and broadcast as ``(3, 1, 1)``.
    """
    mean = torch.tensor(norm_mean, dtype=x.dtype, device=x.device).view(3, 1, 1)
    std = torch.tensor(norm_std, dtype=x.dtype, device=x.device).view(3, 1, 1)
    return torch.clamp(x * std + mean, 0, 1)


# image file -> normalised tensor
transforms_fw = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std)
])

# normalised tensor -> PIL image
transforms_bw = transforms.Compose([
    transforms.Lambda(_denormalize),
    transforms.ToPILImage()
])

In [None]:
# image and style reconstruction functions
def reconstruct(img_path, model, layer, max_iter=5000, show_iter=100):
    """Reconstruct an image from the features one layer produces for it.

    Starting from Gaussian noise, an image is optimised with Adam until
    ``model.forward_layers(image, [layer])[0]`` matches the features computed
    for the image stored at ``img_path``.

    Args:
        img_path: path of the image whose features are the target.
        model: network exposing ``forward_layers(batch, layer_names)``; the
            inputs are moved to the GPU with ``.cuda()``, so the model is
            expected to live there as well.
        layer: name of the layer whose features are matched.
        max_iter: upper bound on the number of Adam steps.
        show_iter: the loss is printed every ``show_iter`` iterations
            (iteration 0 is skipped).

    Returns:
        The reconstruction as a PIL image; it is also drawn in a new
        matplotlib figure.
    """
    # transforms_fw normalises with three per-channel statistics, so force
    # three channels: greyscale, RGBA or palette files are not compatible with
    # it otherwise. The context manager closes the file once it is decoded.
    with Image.open(img_path) as raw:
        img = transforms_fw(raw.convert("RGB")).unsqueeze(0).cuda()
    # The target is a constant of the optimisation: detach it (as
    # style_transfer does) so no graph is kept even for a trainable model.
    target = model.forward_layers(img, [layer])[0].detach()
    opt_img = torch.randn(img.size()).type_as(img).requires_grad_(True) #random init
    optimizer = optim.Adam([opt_img], lr=5e-2)
    print("Reconstruction optimization process started...")
    loss_history = []
    for i in range(max_iter):
        optimizer.zero_grad()
        feats = model.forward_layers(opt_img, [layer])[0]
        loss = F.mse_loss(feats, target)
        loss.backward()
        loss_value = loss.item()
        if i % show_iter == 0 and i != 0:
            print('Iteration: %d, loss: %f' % (i, loss_value))
        optimizer.step()
        # Early stop: more than 25 losses recorded and the last 25 of them are
        # all numerically close to the current one, i.e. the loss has gone flat.
        if len(loss_history) > 25 and np.isclose(np.array(loss_history[-25:]), loss_value).all():
            break
        loss_history.append(loss_value)

    #display result
    out_img = transforms_bw(opt_img.detach()[0].cpu())
    plt.figure(figsize=(10, 10))
    plt.imshow(out_img)
    return out_img

def style_reconstruct(img_path, model, layer, max_iter=5000, show_iter=100):
    """Reconstruct the style (Gram statistics) one layer captures of an image.

    Starting from Gaussian noise, an image is optimised with Adam until the
    Gram matrix of ``model.forward_layers(image, [layer])[0]`` matches the
    one computed for the image stored at ``img_path``. The loss is the MSE
    between the two Gram matrices, scaled by 1000.

    Args:
        img_path: path of the image whose Gram matrix is the target.
        model: network exposing ``forward_layers(batch, layer_names)``; the
            inputs are moved to the GPU with ``.cuda()``, so the model is
            expected to live there as well.
        layer: name of the layer whose features are used.
        max_iter: upper bound on the number of Adam steps.
        show_iter: the loss is printed every ``show_iter`` iterations
            (iteration 0 is skipped).

    Returns:
        The reconstruction as a PIL image; it is also drawn in a new
        matplotlib figure.
    """
    # transforms_fw normalises with three per-channel statistics, so force
    # three channels: greyscale, RGBA or palette files are not compatible with
    # it otherwise. The context manager closes the file once it is decoded.
    with Image.open(img_path) as raw:
        img = transforms_fw(raw.convert("RGB")).unsqueeze(0).cuda()
    # The target is a constant of the optimisation: detach it (as
    # style_transfer does) so no graph is kept even for a trainable model.
    target = gram_matrix(model.forward_layers(img, [layer])[0]).detach()
    opt_img = torch.randn(img.size()).type_as(img).requires_grad_(True) #random init
    optimizer = optim.Adam([opt_img], lr=5e-2)
    loss_history = []
    print("Style reconstruction optimization process started...")
    for i in range(max_iter):
        optimizer.zero_grad()
        gram = gram_matrix(model.forward_layers(opt_img, [layer])[0])
        loss = 1000 * F.mse_loss(gram, target)
        loss.backward()
        loss_value = loss.item()
        if i % show_iter == 0 and i != 0:
            print('Iteration: %d, loss: %f' % (i, loss_value))
        optimizer.step()
        # Early stop: more than 25 losses recorded and the last 25 of them are
        # all numerically close to the current one, i.e. the loss has gone flat.
        if len(loss_history) > 25 and np.isclose(np.array(loss_history[-25:]), loss_value).all():
            break
        loss_history.append(loss_value)

    #display result
    out_img = transforms_bw(opt_img.detach()[0].cpu())
    plt.figure(figsize=(10, 10))
    plt.imshow(out_img)
    return out_img

In [None]:
def style_transfer(
        content_img_path,
        style_img_path,
        model,
        content_layers=["layer3"],
        style_layers=["layer1", "layer2", "layer3", "layer4"],
        content_weights=[1],
        style_weights=[1e4, 1e3, 1e2, 1e1],
        max_iter=600,
        show_iter=100):
    """Render the content image in the style of the style image.

    The content image itself is the starting point and is optimised with
    L-BFGS. The objective is a weighted sum of Gram-matrix MSE terms (one per
    style layer, against the style image) and feature MSE terms (one per
    content layer, against the content image).

    Args:
        content_img_path: path of the image providing the content.
        style_img_path: path of the image providing the style.
        model: network exposing ``forward_layers(batch, layer_names)``; the
            inputs are moved to the GPU with ``.cuda()``, so the model is
            expected to live there as well.
            NOTE(review): the outputs are assumed to come back one per
            requested name, in request order -- confirm against the model.
        content_layers: layer names used for the content terms.
        style_layers: layer names used for the style terms.
        content_weights: one weight per entry of ``content_layers``.
        style_weights: one weight per entry of ``style_layers``.
        max_iter: optimisation continues while the number of loss
            evaluations is <= ``max_iter``. One ``optimizer.step`` may
            evaluate the loss several times, so the final count can exceed it.
        show_iter: the loss is printed every ``show_iter`` evaluations.

    Returns:
        The stylised result as a PIL image; it is also drawn in a new
        matplotlib figure.

    Raises:
        ValueError: if a weight list and its layer list differ in length, or
            if no layer is given at all.
    """
    # The list defaults are shared between calls; they are only ever read
    # here (and copied), never mutated.
    style_layers, content_layers = list(style_layers), list(content_layers)
    style_weights, content_weights = list(style_weights), list(content_weights)
    # Weights and losses are paired by position, so a length mismatch would
    # silently attach the wrong weight to a layer instead of failing.
    if len(style_layers) != len(style_weights):
        raise ValueError("style_layers and style_weights must have the same length "
                         "(got %d and %d)" % (len(style_layers), len(style_weights)))
    if len(content_layers) != len(content_weights):
        raise ValueError("content_layers and content_weights must have the same length "
                         "(got %d and %d)" % (len(content_layers), len(content_weights)))
    if not style_layers and not content_layers:
        raise ValueError("at least one style or content layer is required")

    # transforms_fw normalises with three per-channel statistics, so force
    # three channels; the context managers close the files once decoded.
    with Image.open(style_img_path) as raw:
        style_img = transforms_fw(raw.convert("RGB")).unsqueeze(0).cuda()
    with Image.open(content_img_path) as raw:
        content_img = transforms_fw(raw.convert("RGB")).unsqueeze(0).cuda()
    opt_img = content_img.clone().detach().requires_grad_(True)

    loss_layers = style_layers + content_layers
    loss_fns = [gram_mse] * len(style_layers) + [F.mse_loss] * len(content_layers)

    #compute optimization targets
    style_targets = [gram_matrix(x).detach() for x in model.forward_layers(style_img, style_layers)]
    content_targets = [x.detach() for x in model.forward_layers(content_img, content_layers)]

    targets = style_targets + content_targets
    weights = style_weights + content_weights

    optimizer = optim.LBFGS([opt_img])
    n_evals = 0  # number of loss evaluations performed so far

    def closure():
        nonlocal n_evals
        optimizer.zero_grad()
        out = model.forward_layers(opt_img, loss_layers)
        layer_losses = [weights[a] * loss_fns[a](A, targets[a]) for a, A in enumerate(out)]
        # Accumulate explicitly instead of calling sum(): `%pylab` star-imports
        # numpy, so a bare `sum` in this notebook may resolve to numpy.sum,
        # which is not meant for a list of autograd tensors.
        loss = layer_losses[0]
        for extra in layer_losses[1:]:
            loss = loss + extra
        loss.backward()
        n_evals += 1
        #print loss, labelled with the evaluation it actually belongs to
        if n_evals % show_iter == 0:
            print('Iteration: %d, loss: %f' % (n_evals, loss.item()))
        return loss

    print("Style transfer initialized...")
    while n_evals <= max_iter:
        optimizer.step(closure)
    out_img = transforms_bw(opt_img.detach()[0].cpu())
    plt.figure(figsize=(10, 10))
    plt.imshow(out_img)

    return out_img

# Models initialization

In [None]:
def _freeze_for_inference(model):
    """Move ``model`` to the GPU, put it in eval mode and turn off gradients
    for all of its parameters; returns the model."""
    model = model.cuda().eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


# NOTE(review): the positional flag of resnet50 is presumably `pretrained`
# (False -> random weights, True -> ImageNet weights) -- confirm in resnet.py.
random_model = _freeze_for_inference(resnet50(False))

imagenet_model = _freeze_for_inference(resnet50(True))

robust_model = resnet50()
# map_location="cpu": deserialise onto the CPU whatever device the checkpoint
# was saved from; the model is moved to the GPU afterwards.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
robust_state_dict = torch.load(model_dir / "resnet50_robust.pth", map_location="cpu")
# The classifier head is not loaded. pop() instead of del, so a checkpoint
# that already lacks these entries does not raise KeyError.
for head_key in ("fc.weight", "fc.bias"):
    robust_state_dict.pop(head_key, None)
# strict=False: keys missing from the checkpoint (the removed head) are tolerated.
robust_model.load_state_dict(robust_state_dict, strict=False)
robust_model = _freeze_for_inference(robust_model)

# Content Reconstructions

## Random Initialization

In [None]:
# Content reconstruction with `random_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    reconstruct(content_dir / "kosci-kupres.jpg", random_model, layer_name, max_iter=500, show_iter=1000)

## Regular ImageNet Initialization

In [None]:
# Content reconstruction with `imagenet_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    reconstruct(content_dir / "kosci-kupres.jpg", imagenet_model, layer_name, max_iter=500, show_iter=1000)

## Robust ImageNet Initialization

In [None]:
# Content reconstruction with `robust_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    reconstruct(content_dir / "kosci-kupres.jpg", robust_model, layer_name, max_iter=500, show_iter=1000)

# Style Reconstructions

## Random Initialization

In [None]:
# Style reconstruction with `random_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    style_reconstruct(style_dir / "scream.jpg", random_model, layer_name, max_iter=5000, show_iter=6000)


## Regular ImageNet Initialization

In [None]:
# Style reconstruction with `imagenet_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    style_reconstruct(style_dir / "scream.jpg", imagenet_model, layer_name, max_iter=5000, show_iter=6000)

## Robust ImageNet Initialization

In [None]:
# Style reconstruction with `robust_model`, one run per feature stage.
# show_iter is larger than max_iter, so no intermediate losses are printed.
layer_names = ("conv1", "layer1", "layer2", "layer3", "layer4")
for layer_name in layer_names:
    print(f"Reconstruction based on features from layer: {layer_name}")
    style_reconstruct(style_dir / "scream.jpg", robust_model, layer_name, max_iter=5000, show_iter=6000)

# Style Transfers

In [None]:
# One content photo and the four style images it is combined with.
content_image_path = content_dir / "kosci-kupres.jpg"
style_names = (
    "scene_de_rue.jpg",
    "picasso_seated_nude_hr.jpg",
    "scream.jpg",
    "vangogh_starry_night.jpg",
)
style_images = [style_dir / name for name in style_names]

## Regular ImageNet Initialization

In [None]:
# Style transfer with `imagenet_model`: the same content image against every
# style image. show_iter is above max_iter to keep the loss log quiet.
for style_img in style_images:
    print(f"Content image: {content_image_path}, Style image: {style_img}")
    style_transfer(
        content_image_path,
        style_img,
        imagenet_model,
        max_iter=600,
        show_iter=700,
    )

## Robust ImageNet Initialization

In [None]:
# Style transfer with `robust_model`: the same content image against every
# style image. show_iter is above max_iter to keep the loss log quiet.
for style_img in style_images:
    print(f"Content image: {content_image_path}, Style image: {style_img}")
    style_transfer(
        content_image_path,
        style_img,
        robust_model,
        max_iter=600,
        show_iter=700,
    )

## More examples

In [None]:
# Four further content/style pairs, matched by position and rendered with
# `robust_model`. Note that this rebinds `style_images` and
# `content_image_path`, which earlier cells also use.
content_names = ("drazen-petrovic.jpg", "dubrovnik.jpg", "rimac.jpeg", "fer.jpg")
style_names = ("contrast_of_forms.jpg", "scream.jpg", "goeritz.jpg", "mondrian_cropped.jpg")
content_images = [content_dir / name for name in content_names]
style_images = [style_dir / name for name in style_names]

for content_image_path, style_image_path in zip(content_images, style_images):
    print(f"Content image: {content_image_path}, Style image: {style_image_path}")
    style_transfer(
        content_image_path,
        style_image_path,
        robust_model,
        max_iter=600,
        show_iter=700,
    )

## Hyper-dependency on hyper-parameters

In [None]:
style_image_path = style_dir / "scream.jpg"
content_image_path = content_dir / "kosci-kupres.jpg"
# Sweep the leading style weight over six decades (1e2 ... 1e7) while keeping
# the 10x fall-off between consecutive style layers.
for exponent in range(2, 8):
    scale = 10 ** exponent
    style_weights = [scale, scale * 1e-1, scale * 1e-2, scale * 1e-3]
    print(f"Style weights: {style_weights}")
    style_transfer(
        content_image_path,
        style_image_path,
        robust_model,
        max_iter=600,
        show_iter=700,
        style_weights=style_weights,
    )