In [1]:
# %load main.py
import argparse
import time
import os
import os.path

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import torchvision
from torchvision import transforms

import PIL
from PIL import Image
import matplotlib.pyplot as plt
from collections import OrderedDict
from utility.utility import postp, GramMatrix, GramMSELoss, load_images, save_images, make_folders
from utility.loss_fns import get_style_patch_weights, patch_difference, mrf_loss_fn, weight_maker
from utility.vgg_network import VGG

In [2]:
# Configuration cell: loss weights, input image paths, device selection,
# tunable parameters and the output directory for this run.

# Style weights: one multiplier per style layer (consumed when the
# per-layer style weights are built).
sw1, sw2, sw3, sw4, sw5 = 1, 1, 1, 1, 1
# Content weights
cw1, cw2, cw3, cw4, cw5 = 0, 0, 0, 1, 0
#############################################################################
# Get image paths and names. Each path is written once and split into
# (directory, file name) so the two halves cannot drift apart.
# Style 1
style_path1 = '../input/font_contents/serif_removal/NotoSans-Regular.png'
style_dir1, style_name1 = os.path.split(style_path1)
# Style 2
style_path2 = '../input/font_contents/serif_removal/NotoSerif-Regular.png'
style_dir2, style_name2 = os.path.split(style_path2)
# Content
content_path = '../input/font_contents/serif_removal/PT_Serif-Caption-Web-Italic.png'
content_dir, content_name = os.path.split(content_path)

# Cuda device.
# is_available must be *called*: the bare function object is always truthy,
# which would select 'cuda:0' even on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print("Using device: ", device)

# Parameters
alpha = 1
beta = 1
patch_size = 5
image_size = 256
content_invert = 1
style_invert = 1
# The result uses the same inversion setting as the content image.
result_invert = content_invert

# Get output path
output_root = '../output_style_difference/direct/'
try:
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(output_root, exist_ok=True)
except OSError as err:
    # Stay best-effort like the original cell, but report the failure
    # instead of swallowing it silently.
    print("Warning: could not create {}: {}".format(output_root, err))


def _stem(file_name):
    """Return file_name without its extension (any extension length)."""
    return os.path.splitext(file_name)[0]


# Run directory is named <content>_<style1>_<style2>/ under the output root.
output_path = output_root + '_'.join(
    [_stem(content_name), _stem(style_name1), _stem(style_name2)]
) + '/'

Using device:  cuda:0


In [3]:
# Get network: build the VGG feature extractor, load its weights and freeze
# it so that only the optimized image receives gradients.
vgg = VGG()
# map_location remaps the stored tensors onto the selected device, so the
# checkpoint still loads when it was saved from a device that is not
# available here (e.g. a GPU checkpoint on a CPU-only machine).
vgg.load_state_dict(torch.load('../Models/vgg_conv.pth', map_location=device))
for param in vgg.parameters():
    param.requires_grad = False
vgg.to(device)

# Load images (content and the two style references) at image_size on device
content_image = load_images(os.path.join(content_dir, content_name), image_size, device, content_invert)
style_image1  = load_images(os.path.join(style_dir1,style_name1), image_size, device, style_invert)
style_image2  = load_images(os.path.join(style_dir2,style_name2), image_size, device, style_invert)

# Random input (alternative initialisation, kept for reference)
# opt_img = torch.randn_like(content_image).requires_grad_(True)
# Content input: start the optimization from a copy of the content image.
# detach().clone() yields an independent leaf tensor; this replaces the
# deprecated Variable(...) wrapper with the same result.
opt_img = content_image.detach().clone().requires_grad_(True)

#### Define layers, loss functions and weights

In [11]:
# Layers to read from the network, with the criterion and weight for each.

# --- Style layers -----------------------------------------------------------
style_layers = ['r12','r22','r34','r44','r54']
style_scales = [sw1, sw2, sw3, sw4, sw5]
# Each weight is scale * 1e3 / divisor**2.
# NOTE(review): the divisors look like per-layer channel counts -- confirm.
style_divisors = [64, 128, 256, 512, 512]
style_weights = [
    scale * 1e3 / divisor**2
    for scale, divisor in zip(style_scales, style_divisors)
]
# Alternative kept for reference:
# style_weights = [1,1,1,1,1]

# --- Content layers ---------------------------------------------------------
# Alternatives kept for reference:
# content_layers = ['r12','r22','r32','r42','r52']
# content_layers = ['r31','r32','r33','r34','r41']
content_layers = ['r41']
content_weights = [1e4]
# content_weights = [cw1*1e4,cw2*1e4,cw3*1e4,cw4*1e4,cw5*1e4]

# --- Combined lists: style entries first, then content ----------------------
loss_layers = style_layers + content_layers
# List multiplication repeats one shared criterion instance per group.
style_criteria = [GramMSELoss()] * len(style_layers)
content_criteria = [nn.MSELoss()] * len(content_layers)
loss_functions = [
    criterion.to(device) for criterion in style_criteria + content_criteria
]
weights = style_weights + content_weights

#### Compute optimization targets

In [12]:
# Compute optimization targets
### Gram matrix targets


def _style_features(image):
    """Return the style-layer feature maps of `image`, detached from autograd."""
    return [feature_map.detach() for feature_map in vgg(image, style_layers)]


def _gram_matrices(feature_maps):
    """Apply GramMatrix to every feature map in `feature_maps`."""
    return [GramMatrix()(feature_map) for feature_map in feature_maps]


# Style-layer feature maps and Gram matrices of the two style images
style1_fms_style = _style_features(style_image1)
style2_fms_style = _style_features(style_image2)
style1_gramm = _gram_matrices(style1_fms_style)
style2_gramm = _gram_matrices(style2_fms_style)

# Per-layer element-wise squared difference between the two styles' Gram matrices
gramm_style = [
    (style1_gramm[layer_idx] - style2_gramm[layer_idx])**2
    for layer_idx, _ in enumerate(style_layers)
]

# Same quantities for the content image
content_fms_style = _style_features(content_image)
content_gramm = _gram_matrices(content_fms_style)

In [13]:
### Content targets


def _content_features(image):
    """Return the content-layer feature maps of `image`, detached from autograd."""
    return [feature_map.detach() for feature_map in vgg(image, content_layers)]


# Content-layer feature maps of the two style images
style1_fms_content = _content_features(style_image1)
style2_fms_content = _content_features(style_image2)

# Per-layer element-wise squared difference between those feature maps
style_fms_content = [
    (style1_fms_content[layer_idx] - style2_fms_content[layer_idx])**2
    for layer_idx, _ in enumerate(content_layers)
]

# Content-layer feature maps of the content image itself
content_fm_content = _content_features(content_image)

In [14]:
# Run style transfer: prepare the output folders, the optimizer and the logs.
make_folders(output_path)

# max_iter bounds the optimization loop; show_iter is the reporting/saving period.
max_iter, show_iter = 1000, 50

# L-BFGS optimizes the image itself (the only tensor handed to it).
optimizer = optim.LBFGS([opt_img])

# Counter kept in a one-element list so the closure can update it in place.
n_iter = [0]

# Loss histories: total, content and style.
loss_list, c_loss, s_loss = [], [], []

In [15]:
def closure():
    """Evaluate the objective once for L-BFGS and return the total loss.

    Runs the current image through vgg, builds the weighted style and content
    losses against the precomputed targets, back-propagates, and appends the
    scalar loss values to the log lists. Every `show_iter` evaluations it
    prints the losses and saves the loss graph and the current image.

    n_iter[0] counts closure evaluations; L-BFGS may call the closure several
    times within a single optimizer.step().
    """
    optimizer.zero_grad()
    out = list(vgg(opt_img, loss_layers))

    # loss_layers is style_layers + content_layers, so the leading outputs
    # are the style feature maps and the rest are the content feature maps.
    n_style = len(style_layers)
    opt_fms_style = out[:n_style]
    opt_fms_content = out[n_style:]

    mse = nn.MSELoss()

    ## Style: change in Gram matrix (optimized - content) vs. the style target
    # NOTE(review): gramm_diff is a plain difference while gramm_style holds
    # squared differences -- confirm this asymmetry is intended.
    opt_gramm = [GramMatrix()(A) for A in opt_fms_style]
    gramm_diff = [opt_gramm[i] - content_gramm[i] for i in range(n_style)]
    style_layer_losses = [
        style_weights[i] * mse(gramm_diff[i], gramm_style[i])
        for i in range(n_style)
    ]

    ## Content: change in feature maps (optimized - content) vs. the content target
    fms_diff = [
        opt_fms_content[i] - content_fm_content[i]
        for i in range(len(content_layers))
    ]
    content_layer_losses = [
        content_weights[i] * mse(fms_diff[i], style_fms_content[i])
        for i in range(len(content_layers))
    ]

    # losses
    content_loss = sum(content_layer_losses)
    style_loss   = sum(style_layer_losses)

    # total loss
    loss = sum(content_layer_losses + style_layer_losses)

    # for log: store plain floats, not tensors. Appending the loss tensors
    # would keep their autograd history alive for the whole run and hand
    # device tensors to matplotlib. float() also copes with the int 0 that
    # sum() returns for an empty layer list.
    content_value = float(content_loss)
    style_value = float(style_loss)
    total_value = float(loss)
    c_loss.append(content_value)
    s_loss.append(style_value)
    loss_list.append(total_value)

    # backward calculation
    loss.backward()

    #print loss
    if n_iter[0]%show_iter == 0:
        print('Iteration: {}'.format(n_iter[0]))
        print('Content loss: {}'.format(content_value))
        print('Style loss  : {}'.format(style_value))
        print('Total loss  : {}'.format(total_value))

        # Save loss graph
        plt.plot(loss_list, label='Total loss')
        plt.plot(c_loss, label='Content loss')
        plt.plot(s_loss, label='Style loss')
        plt.legend()
        plt.savefig(output_path + 'loss_graph.jpg')
        plt.close()
        # Save optimized image
        out_img = postp(opt_img.data[0].cpu().squeeze(), image_size, result_invert)
        out_img.save(output_path + 'outputs/{}.bmp'.format(n_iter[0]))

    n_iter[0] += 1
    return loss


# The closure is defined once above instead of being rebuilt on every pass.
while n_iter[0] <= max_iter:
    optimizer.step(closure)


Iteration: 0
Content loss: 2.2899782074761216e+16
Style loss  : 1.1602842338983936e+16
Total loss  : 3.450262334000333e+16
Iteration: 50
Content loss: 2.289742843268301e+16
Style loss  : 1.1596442837712896e+16
Total loss  : 3.449387234413773e+16
Iteration: 100
Content loss: 2.2895446305275904e+16
Style loss  : 1.1592011505205248e+16
Total loss  : 3.448745781048115e+16
Iteration: 150
Content loss: 2.2893264461889536e+16
Style loss  : 1.1587656408367104e+16
Total loss  : 3.448092087025664e+16
Iteration: 200
Content loss: 2.289069392396288e+16
Style loss  : 1.1583723292065792e+16
Total loss  : 3.4474418289770496e+16
Iteration: 250
Content loss: 2.288724721270784e+16
Style loss  : 1.1579184585375744e+16
Total loss  : 3.4466429650599936e+16
Iteration: 300
Content loss: 2.288334952988672e+16
Style loss  : 1.1573563546927104e+16
Total loss  : 3.445690985558835e+16
Iteration: 350
Content loss: 2.2878953630859264e+16
Style loss  : 1.15674743570432e+16
Total loss  : 3.4446427987902464e+16
Iterat

In [16]:
def _first_image_on_cpu(batch):
    """Take the first item of `batch`, move it to the CPU and squeeze it."""
    return batch.data[0].cpu().squeeze()


# Write out the content, result and both style images for this run.
save_images(
    _first_image_on_cpu(content_image),
    _first_image_on_cpu(opt_img),
    _first_image_on_cpu(style_image1),
    _first_image_on_cpu(style_image2),
    image_size,
    output_path,
    n_iter,
    content_invert,
    style_invert,
    result_invert,
)