In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image

from torchvision import transforms
from torchvision import models

import matplotlib.pyplot as plt
from torch.optim import Adam



In [None]:
# --- Input / output files ---------------------------------------------------
content_image_name = "face1.jpg"
style_image_name = "scr.jpg"
# File names used when saving the generated images.
save_image_as = "created_image.jpg"
color_preserved_image = "color_preserved_image.jpg"

# Start the optimisation from random noise (True) or from the content image (False).
random_start_image = False

# --- Optimisation settings --------------------------------------------------
iterations = 300
learning_rate = 0.1
# Print the losses every `show_loss_rate` iterations.
show_loss_rate = 10

# Use Adam (True) or L-BFGS (False); the two values below apply to Adam only.
Adam_on = False
adam_lr = 0.005
adam_wd = 0.1

# --- Loss weighting ---------------------------------------------------------
# Weight of the content loss.
alpha = 1
# Weight of the style loss.
beta = 1000

# --- VGG19 layers to read activations from ----------------------------------
# The per-layer weights are relative; they are normalised to sum to 1 later on.
# The trailing comments record the baseline values used while experimenting.
content_layers = [21]                             # baseline: [21]
content_layers_weight = np.array([1])             # baseline: np.array([1])
style_layers = [0, 5, 10, 19, 28]                 # baseline: [0,5,10,19,28]
style_layers_weight = np.array([1, 1, 1, 1, 1])   # baseline: np.array([1,1,1,1,1])

# Size (height, width) every image is resized to before entering the network.
vgg_image_size = (512, 512)

In [None]:
# Pick the GPU when one is usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every loaded image: resize to the network input
# size, then convert the PIL image to a tensor.
tensor_transform = transforms.Compose(
    [
        transforms.Resize(vgg_image_size),
        transforms.ToTensor(),
    ]
)
                                        
                                          
def load_image(image_path):
    """Load an image file as a preprocessed tensor on `device`.

    The file is opened with PIL, converted to 3-channel RGB (so grayscale,
    palette or RGBA files still yield three channels), passed through
    `tensor_transform` and moved to `device`.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image tensor, located on `device`.
    """
    # The context manager closes the underlying file; `convert` has already
    # read the pixel data into a new image by the time the block exits.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")
    return tensor_transform(rgb_image).to(device)


content_image = load_image(content_image_name)
# PIL reports size as (width, height); store the original resolution as
# (height, width), which is the order `restore_image` hands to `transforms.Resize`.
# Only the size is needed here, so close the file as soon as it has been read.
with Image.open(content_image_name) as original_content:
    original_width, original_height = original_content.size
content_image_size = (original_height, original_width)


style_image = load_image(style_image_name)


In [None]:
# Keep only the convolutional feature extractor of the pretrained VGG19,
# move it to the selected device and switch it to evaluation mode.
vgg19_features = models.vgg19(pretrained=True).features
vgg19_model = vgg19_features.to(device).eval()
print(vgg19_model)


In [None]:
# Rescale the relative layer weights so that each group sums to 1.
content_layers_weight = content_layers_weight / content_layers_weight.sum()
style_layers_weight = style_layers_weight / style_layers_weight.sum()

# Every layer index whose activation is needed, without duplicates, ascending.
output_layers = sorted(set(content_layers + style_layers))

def extract_features(module, x, y):
    """Forward hook that records a layer's output.

    Appends the output `y` to the module-level `features` list; the hooked
    module and its input `x` are ignored. Returns nothing.
    """
    features.append(y)

# Iterate over a snapshot of the layers because the model is modified in the loop.
for i, layer in enumerate(list(vgg19_model)):

    # "A Neural Algorithm of Artistic Style" by Gatys et al. recommends using
    # average pooling layers instead of max pooling. Swap the layer first so
    # that a hook requested for this index lands on the module that stays in
    # the model, not on the max-pooling layer that is being discarded.
    if isinstance(layer, nn.MaxPool2d):
        vgg19_model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False)

    # Record the activations of every requested layer during forward passes.
    if i in output_layers:
        vgg19_model[i].register_forward_hook(extract_features)

# The network weights stay fixed; only the image is optimised.
for weight in vgg19_model.parameters():
    weight.requires_grad_(False)

print(vgg19_model)

In [None]:
def style_content_features(features,output_layers,content_layers,style_layers):
    """Split recorded activations into content and style feature lists.

    `features[k]` is taken to be the activation of layer `output_layers[k]`.
    Each returned list follows the order of the corresponding layer list
    (`content_layers` / `style_layers`), so position `j` in a result lines up
    with position `j` of that layer list and of its per-layer weight array,
    even when the layer lists are not given in ascending order.

    Args:
        features: Activations, one per entry of `output_layers`.
        output_layers: Layer indices in the order their activations were recorded.
        content_layers: Layer indices used for the content loss.
        style_layers: Layer indices used for the style loss.

    Returns:
        Tuple `(content_features, style_features)`. Layers that do not appear
        in `output_layers` are skipped.
    """
    position_of = {layer: position for position, layer in enumerate(output_layers)}

    def pick(wanted_layers):
        # Look the activation up by layer index so the caller's order is kept.
        return [features[position_of[layer]] for layer in wanted_layers if layer in position_of]

    return pick(content_layers), pick(style_layers)

In [None]:
# Start from an empty list; the forward hooks fill it during the pass below.
features = []

# Forward pass of the content image, run only for the activations the hooks record.
vgg19_model(content_image.unsqueeze(0))

# Keep the content-layer activations, detached, as fixed optimisation targets.
content_target = [
    activation.detach()
    for activation in style_content_features(features,output_layers,content_layers,style_layers)[0]
]

def gram_matrix(x):
    """Compute the Gram matrix of a batch of feature maps.

    Args:
        x: Tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c): the channel-by-channel inner products over
        all spatial positions, divided by the number of positions (h * w).
    """
    b, c, h, w = x.size()
    # `reshape` (unlike `view`) also accepts non-contiguous inputs and is
    # equivalent to `view` for contiguous ones.
    V = x.reshape(b, c, h * w)
    gram_mat = torch.bmm(V, V.transpose(1, 2)) / (h * w)
    return gram_mat

# Reset the recorded activations before the style image goes through the network.
features = []
vgg19_model(style_image.unsqueeze(0))

# Detached style-layer activations and their Gram matrices act as fixed targets.
style_target = [
    activation.detach()
    for activation in style_content_features(features,output_layers,content_layers,style_layers)[1]
]
gram_target = [gram_matrix(activation).detach() for activation in style_target]


In [None]:
# The image being optimised: random noise or a copy of the content image.
if random_start_image:
    # Create the noise directly on `device` so this also works without a GPU.
    image = torch.randn(3, vgg_image_size[0], vgg_image_size[1], device=device).requires_grad_()
else:
    image = content_image.clone().requires_grad_()

# Mean-squared error is used for both the content and the style terms.
mse_loss = nn.MSELoss(reduction='mean')

# The image itself is the only optimised "parameter".
optimizer = (
    Adam([image], lr=adam_lr, weight_decay=adam_wd)
    if Adam_on
    else torch.optim.LBFGS([image], lr=learning_rate)
)

In [None]:
def restore_image(image):
    """Convert an image tensor into a displayable numpy array.

    The tensor is squeezed, moved to the CPU and detached, its values are
    clamped to [0, 1], it is resized to `content_image_size`, and the result
    is converted to a PIL image and then to a numpy array.

    Args:
        image: Image tensor to convert.

    Returns:
        numpy array built from the resulting PIL image.
    """
    cpu_tensor = image.squeeze().cpu().detach()
    clamp_and_resize = transforms.Compose(
        [
            transforms.Lambda(lambda x: x.clamp(0,1)),
            transforms.Resize(content_image_size),
        ]
    )
    pil_picture = transforms.ToPILImage()(clamp_and_resize(cpu_tensor))
    return np.array(pil_picture)

def save_image(image_name,image):
    """Write `image` to the file `image_name`.

    The tensor is first converted with `restore_image`, then saved with
    `plt.imsave`.
    """
    pixels = restore_image(image)
    plt.imsave(image_name, pixels)


In [None]:
# Loss history: one Python float per closure evaluation.
content_loss_memory = []
style_loss_memory = []


def closure():
    """Evaluate the total loss for the current `image` and back-propagate it.

    Passed to `optimizer.step`, which may call it more than once per step
    (L-BFGS does). Each call appends the content and style loss to
    `content_loss_memory` / `style_loss_memory` as plain floats.

    Returns:
        The total loss (weighted content loss plus weighted style loss).
    """
    # The forward hooks append to the module-level `features` list. Empty it
    # in place so activations (and their autograd graphs) from earlier
    # evaluations are released instead of accumulating.
    features.clear()
    optimizer.zero_grad()
    vgg19_model(image.unsqueeze(0))
    content_features, style_features = style_content_features(features,output_layers,content_layers,style_layers)
    gram_styles = [gram_matrix(feature) for feature in style_features]

    content_loss = 0
    for k, content_feature in enumerate(content_features):
        content_loss += alpha * content_layers_weight[k] * mse_loss(content_feature, content_target[k])/2
    style_loss = 0
    for k, gram_style in enumerate(gram_styles):
        style_loss += beta * style_layers_weight[k] * mse_loss(gram_style, gram_target[k])/4
    total_loss = content_loss + style_loss

    # Store plain floats: keeping the loss tensors would also keep their
    # autograd history alive for the whole run.
    content_loss_memory.append(float(content_loss))
    style_loss_memory.append(float(style_loss))

    total_loss.backward()

    return total_loss


for i in range(iterations):
    optimizer.step(closure)

    # Periodically report progress and save a snapshot of the current image.
    if (i % show_loss_rate == 0):
        print(f"content loss:{content_loss_memory[-1]:.6f}, style loss:{style_loss_memory[-1]:.6f}")
        save_image("SoFar.jpg",image)

In [None]:
# Write the final optimised image to disk under the configured output name.
save_image(save_image_as,image)

In [None]:
import cv2

def restore_color(color_source, color_destination):
    """Combine the colours of one image with the luminance of another.

    Both inputs are channel-first arrays of the same spatial size. They are
    assumed to be in RGB channel order (the callers in this file pass arrays
    derived from PIL-loaded images) with a dtype `cv2.cvtColor` accepts.

    Args:
        color_source: Image whose chroma (Cr, Cb) is kept.
        color_destination: Image whose luminance is kept.

    Returns:
        Channel-last RGB array with the luminance of `color_destination` and
        the chroma of `color_source`.
    """
    # cv2 works on channel-last arrays; `moveaxis` returns a strided view, so
    # make contiguous copies before handing them to OpenCV.
    c_source = np.ascontiguousarray(np.moveaxis(color_source,0,2))
    c_destination = np.ascontiguousarray(np.moveaxis(color_destination,0,2))

    # Luminance of the destination. The data is RGB, so use the RGB conversion
    # codes; the BGR ones would swap the red and blue luminance weights.
    gray_destination = cv2.cvtColor(c_destination, cv2.COLOR_RGB2GRAY)
    # Source converted from RGB to YCrCb (channel 0 is the luminance Y).
    ycrcb_source = cv2.cvtColor(c_source, cv2.COLOR_RGB2YCrCb)
    # Replace the source luminance with the destination luminance, keeping Cr/Cb.
    ycrcb_source[...,0] = gray_destination
    # Convert the combined image back from YCrCb to RGB.
    return cv2.cvtColor(ycrcb_source, cv2.COLOR_YCrCb2RGB)

# Reload the saved result through `load_image`, so it is resized by
# `tensor_transform` in the same way as `content_image`.
# (The previous `restore_image(image)` assignment was overwritten right away
# and has been removed.)
color_image = load_image(save_image_as)
color_image_np = color_image.cpu().numpy()
content_image_np = content_image.cpu().numpy()
# `restore_color` returns a channel-last array; move the channels to the front
# again before turning it into a tensor for `save_image`.
restored_image = torch.tensor(np.moveaxis(restore_color(content_image_np,color_image_np),2,0))
save_image(color_preserved_image,restored_image)
