In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable

In [None]:
# some functions to deal with image
def imload(image_name,**kwargs):
    """Load an image file as a normalized, batched PyTorch Variable.

    Args:
        image_name: path of the image file, opened with ``PIL.Image.open``.
        **kwargs: optional ``resize`` -- size handed to torchvision's resize
            transform before tensor conversion. Omitted or ``None`` keeps
            the original size.

    Returns:
        Variable of shape (1, 3, H, W); 8-bit pixel values are mapped to
        the range [-1.0, 1.0].
    """
    # Force three channels: grayscale, palette or RGBA files would otherwise
    # not match the 3-channel Normalize below. The `with` block releases the
    # file handle once the pixel data has been copied by convert().
    with Image.open(image_name) as opened:
        image = opened.convert('RGB')

    size = kwargs.get('resize')
    if size is not None:
        # `Scale` was renamed to `Resize` in torchvision; prefer the new name
        # and only fall back to the old one on installs that lack it.
        resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
        image = resize_cls(size)(image)

    transform = transforms.Compose([
        transforms.ToTensor(),  # (H x W x C) in [0, 255] -> (C x H x W) in [0.0, 1.0]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0.0, 1.0] -> [-1.0, 1.0]
    ])
    # NOTE(review): `volatile=True` is the legacy inference-only flag; newer
    # PyTorch releases appear to ignore it with a warning -- confirm before
    # dropping it, since callers here re-wrap the results they keep.
    image = Variable(transform(image),volatile=True)
    return image.unsqueeze(0)  # add the batch dimension

def imshow(img):
    """Display the first image of a normalized batch inline with matplotlib.

    Args:
        img: batch whose first element is an image tensor normalized with
            mean 0.5 / std 0.5 (the format produced by ``imload``).
    """
    # Undo the normalization, then clamp to [0, 1]: an image that has just
    # been updated by the optimizer can leave the valid range, and such
    # values would overflow when converted to 8-bit pixels.
    pixels = (img[0].data*0.5+0.5).clamp(0, 1)
    plt.imshow(toImage(pixels),aspect=None)
    plt.axis('off')
    plt.gcf().set_size_inches(8, 8)
    plt.show()

def imsave(img,path):
    """Write the first image of a normalized batch to ``path``.

    Args:
        img: batch whose first element is an image tensor normalized with
            mean 0.5 / std 0.5 (the format produced by ``imload``).
        path: destination file name passed to ``PIL.Image.save``.
    """
    # Undo the normalization, then clamp to [0, 1] so that out-of-range
    # values cannot overflow when converted to 8-bit pixels.
    pixels = (img[0].data*0.5+0.5).clamp(0, 1)
    toImage(pixels).save(path)

In [None]:
class FeatureExtracter(nn.Module):
    """Collect intermediate activations from an indexable stack of layers."""

    def __init__(self,submodule):
        """Store ``submodule``, which must support ``submodule[i](x)``."""
        super().__init__()
        self.submodule = submodule

    def forward(self,image,layers):
        """Run ``image`` through the stack and keep the requested outputs.

        Args:
            image: input fed to ``submodule[0]``.
            layers: indices of the layers whose outputs are wanted. Order
                does not matter and the collection may be empty.

        Returns:
            List of activations in ascending layer-index order, one per
            distinct requested index.
        """
        wanted = set(layers)
        if not wanted:
            return []
        features = []
        # Stop at the deepest requested layer; using max() instead of the
        # last list element keeps unsorted requests from being truncated.
        for i in range(max(wanted)+1):
            image = self.submodule[i](image)
            if i in wanted:
                features.append(image)
        return features

In [None]:
class GramMatrix(nn.Module):
    """Turn feature maps into scaled Gram matrices used as style features."""

    def forward(self,style_features):
        """Compute one Gram matrix per feature map.

        Args:
            style_features: iterable of 4-D tensors of size (n, f, h, w).

        Returns:
            List of (n*f, n*f) matrices ``F @ F.t() / (2*n*f*h*w)``, where
            ``F`` is the feature map flattened to (n*f, h*w).
        """
        gram_features=[]
        for activation in style_features:
            n,f,h,w = activation.size()
            # contiguous().view() replaces the deprecated non-inplace
            # Tensor.resize(), which performed the same reshape.
            flat = activation.contiguous().view(n*f,h*w)
            gram_features.append((flat@flat.t()).div_(2*n*f*h*w))
        return gram_features

In [None]:
class Stylize(nn.Module):
    """Map an image to its style and content features.

    Relies on the module-level ``feature`` extractor, ``gram`` module and the
    ``STYLE_LAYER`` / ``CONTENT_LAYER`` index lists.
    """

    def forward(self,x):
        """Return ``(style Gram matrices, content activations)`` for ``x``."""
        style_grams = gram(feature(x,STYLE_LAYER))
        content_acts = feature(x,CONTENT_LAYER)
        return style_grams,content_acts

In [None]:
def totalloss(style_refs,content_refs,style_features,content_features,style_weight,content_weight):
    """Combine the balanced style loss and the content loss into one value.

    Each per-layer style loss ``l_i`` is rescaled by ``mean(l) / l_i`` (a
    plain number, so no gradient flows through the factor), multiplied by
    ``STYLE_LAYER_WEIGHTS[i]`` and averaged over the style layers. The
    content loss is the average of the per-layer losses. Uses the
    module-level ``l2loss`` criterion.

    Args:
        style_refs: target Gram matrices, one per style layer.
        content_refs: target activations, one per content layer.
        style_features: Gram matrices of the image being optimized.
        content_features: activations of the image being optimized.
        style_weight: multiplier of the style term.
        content_weight: multiplier of the content term.

    Returns:
        ``style_weight * style_loss + content_weight * content_loss``.
    """
    def scalar(value):
        # One-element tensors expose .item() on PyTorch >= 0.4; older
        # releases only support indexing, which newer ones reject for
        # 0-dim tensors.
        data = value.data
        return data.item() if hasattr(data, 'item') else data[0]

    n_style = len(style_features)
    style_losses = [l2loss(style_features[i],style_refs[i]) for i in range(n_style)]

    # a small trick to balance the influence of different style layers
    mean_loss = scalar(sum(style_losses))/n_style
    balanced = []
    for i,layer_loss in enumerate(style_losses):
        value = scalar(layer_loss)
        # A layer that already matches its target has zero loss; give it a
        # zero factor instead of dividing by zero.
        factor = mean_loss/value if value != 0 else 0.0
        balanced.append(factor*layer_loss*STYLE_LAYER_WEIGHTS[i])
    style_loss = sum(balanced)/n_style

    content_loss = sum([l2loss(content_features[i],content_refs[i])
                    for i in range(len(content_refs))])/len(content_refs)
    total_loss = style_weight*style_loss+content_weight*content_loss
    return total_loss

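The correction factor is defined below, where $l_i$ is the style loss of layer $i$, $w_i$ is its entry in `STYLE_LAYER_WEIGHTS`, and $n$ is the number of style layers. The factor $f_i$ is computed from the current loss values and treated as a constant (no gradient flows through it):
\begin{equation*}
l_{mean} = \frac{1}{n}\sum_{i=1}^{n} l_i \\
f_{i} = \frac{l_{mean}}{l_i} \\
loss = \frac{1}{n}\sum_{i=1}^{n} w_i \, f_i \, l_i
\end{equation*}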

In [None]:
def reference(style_img,content_img):
    """Precompute the fixed style and content targets used by the loss.

    Args:
        style_img: image batch whose Gram matrices become the style targets.
        content_img: image batch whose activations become the content targets.

    Returns:
        ``(style_refs, content_refs)`` -- two lists of Variables re-wrapped
        from the raw data of the extracted features.
    """
    def rewrap(tensors):
        # Build fresh Variables from the underlying data of each feature.
        return [Variable(t.data) for t in tensors]

    style_grams = gram(feature(style_img,STYLE_LAYER))
    content_acts = feature(content_img,CONTENT_LAYER)
    return rewrap(style_grams),rewrap(content_acts)

In [None]:
# init parameters
learning_rate = 1e-1
style_weight = 1
content_weight = 1e-3
num_iters = 500
report_intvl = 20

# load pretrained SqueezeNet and use its first child (the first sequential)
model = models.squeezenet1_1(pretrained=True)
# Only the image is optimized, so the network weights need no gradients;
# otherwise every backward pass would compute and accumulate them unused.
for param in model.parameters():
    param.requires_grad = False
submodel = next(model.children())

# load images
style_img = imload("img_data/style/starry_night.jpg",resize = 224)
content_img = imload("img_data/img/hohnsensee2.jpg")

# set net parameters
STYLE_LAYER =[1,2,3,4,6,7,9]# could add more,maximal to 12
STYLE_LAYER_WEIGHTS = [21,21,1,1,1,7,7]# must have the same length as STYLE_LAYER
CONTENT_LAYER = [1,2,3]
assert len(STYLE_LAYER_WEIGHTS) == len(STYLE_LAYER), "one weight per style layer is required"

# build net components
gram = GramMatrix()
feature = FeatureExtracter(submodel)
l2loss = nn.MSELoss(size_average=False)  # summed (not averaged) squared error
stylize = Stylize()
toImage = transforms.ToPILImage()

# init a trainable image; seeded so that Restart & Run All reproduces the result
torch.manual_seed(0)
train_img = Variable(torch.randn(content_img.size()),requires_grad = True)

# optimizer updates the image itself
optimizer = optim.Adam([train_img], lr = learning_rate)

# tracers
loss_history = []
min_loss = float("inf")
best_img = Variable(train_img.data.clone())  # snapshot of the lowest-loss image so far

# fixed targets for the loss
style_refs,content_refs = reference(style_img,content_img)

Start = datetime.now()
for i in range(num_iters):

    optimizer.zero_grad()

    train_img.data.clamp_(-1,1)  # keep the image in range; useful in the first several steps

    style_features,content_features = stylize(train_img)

    loss = totalloss(style_refs,content_refs,style_features,content_features,style_weight,content_weight)

    loss.backward()

    # .item() exists on PyTorch >= 0.4; older Variables need .data[0]
    current_loss = loss.item() if hasattr(loss, "item") else loss.data[0]
    loss_history.append(current_loss)

    # save best result before update train_img; clone the data because the
    # optimizer modifies train_img in place and would overwrite an alias
    if min_loss > current_loss:
        min_loss = current_loss
        best_img = Variable(train_img.data.clone())

    optimizer.step()

    # report loss and image
    if i % report_intvl == 0:
        print("step: %d loss: %f,time per example:%s s"
              %(i,loss_history[-1],(datetime.now()-Start)/report_intvl))
        Start = datetime.now()
        imshow(train_img)

plt.plot(loss_history)
plt.xlabel("iteration")
plt.ylabel("total loss")
plt.show()

In [None]:
# train another time with smaller learning rate, continuing from the current
# train_img and appending to the same loss_history
optimizer = optim.Adam([train_img], lr = learning_rate/10)
Start = datetime.now()
for i in range(num_iters):

    optimizer.zero_grad()

    train_img.data.clamp_(-1,1)  # keep the image in range; useful in the first several steps

    style_features,content_features = stylize(train_img)

    loss = totalloss(style_refs,content_refs,style_features,content_features,style_weight,content_weight)

    loss.backward()

    # .item() exists on PyTorch >= 0.4; older Variables need .data[0]
    current_loss = loss.item() if hasattr(loss, "item") else loss.data[0]
    loss_history.append(current_loss)

    # save best result before update train_img; clone the data because the
    # optimizer modifies train_img in place and would overwrite an alias
    if min_loss > current_loss:
        min_loss = current_loss
        best_img = Variable(train_img.data.clone())

    optimizer.step()

    # report loss and image
    if i % report_intvl == 0:
        print("step: %d loss: %f,time per example:%s s"
              %(i,loss_history[-1],(datetime.now()-Start)/report_intvl))
        Start = datetime.now()
        imshow(train_img)

plt.plot(loss_history)
plt.xlabel("iteration")
plt.ylabel("total loss")
plt.show()

In [None]:
# Show and save the image tracked as `best_img` by the training cells
# (the previous name `Best_Img` was never defined and raised NameError).
imshow(best_img)
imsave(best_img,"img_data/output/s_ms03.jpg")

In [None]:
# Load another content photo with resize=600 and display the size of the
# resulting PIL image. Note that this rebinds `content_img`.
content_img = imload("img_data/img/shanghai2.jpg", resize=600)
toImage(content_img[0].data.mul(0.5).add(0.5)).size