In [None]:
# Домашнее задание выполнил Ионкин К.А.
# тема - Neural Style Transfer

In [None]:
import copy

import numpy as np
import scipy as sp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import PIL
from PIL import Image
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [None]:
# Перенесём необходимый функционал из семинарского ноутбука

In [None]:
# Естественно, для ускорения процесса обучения необходимо работать на GPU (если он доступен)
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Перед подачей изображения в модель необходимо произвести предобработку
imsize = 128  

loader = transforms.Compose([
    transforms.Resize(imsize), 
    transforms.CenterCrop(imsize),
    transforms.ToTensor()])

unloader = transforms.ToPILImage()

def image_loader(image_name):
    """Load an image file as a batched tensor ready for the network.

    The image is converted to 3-channel RGB, so grayscale, palette and RGBA
    files (e.g. PNGs with transparency) also work. It is then passed through
    the module-level ``loader`` transform and given a leading batch dimension.

    Args:
        image_name: path to the image file.

    Returns:
        float tensor of shape (1, 3, imsize, imsize) on ``device``.
    """
    # `with` closes the file handle. convert("RGB") reads the pixel data
    # before the file is closed and guarantees exactly 3 channels, which the
    # 3-value mean/std normalization of the network requires.
    with Image.open(image_name) as img:
        rgb = img.convert("RGB")
    image = loader(rgb).unsqueeze(0)
    return image.to(device, torch.float)

# функция отрисовки изображений
def imshow(tensor, title=None):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        tensor: image tensor; a leading dimension of size 1 (batch) is dropped.
        title: optional title for the axes.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    picture = unloader(tensor.cpu().clone().squeeze(0))
    plt.imshow(picture)
    if title is None:
        return
    plt.title(title)
    

In [None]:
# Load the two style images and the content image.
# TODO: replace these placeholder paths with the actual image files.
STYLE1_PATH = "images/style1.jpg"
STYLE2_PATH = "images/style2.jpg"
CONTENT_PATH = "images/content.jpg"

style1_img = image_loader(STYLE1_PATH)
style2_img = image_loader(STYLE2_PATH)
content_img = image_loader(CONTENT_PATH)

# Display them right away, side by side.
plt.figure(figsize=(12, 4))
for position, (img, title) in enumerate(
        [(style1_img, "Style 1"), (style2_img, "Style 2"), (content_img, "Content")],
        start=1):
    plt.subplot(1, 3, position)
    imshow(img, title=title)
plt.show()


In [None]:
# Создаем классы StyleLoss и ContentLoss

# класс ContentLoss остается без изменений
class ContentLoss(nn.Module):
    """Transparent layer that records the content loss of its input.

    ``forward`` stores the MSE between its input and a fixed target in
    ``self.loss`` and returns the input unchanged. This lets the module sit
    inside an ``nn.Sequential`` without altering the activations.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a constant: detach it from the autograd graph so no
        # gradient is ever propagated into it.
        constant = target.detach()
        self.target = constant
        # Initial loss value (MSE of the target with itself) until the first forward.
        self.loss = F.mse_loss(constant, constant)

    def forward(self, input_image):
        current = F.mse_loss(input_image, self.target)
        self.loss = current
        return input_image

In [None]:
# функция получения матрицы Грама
def gram_matrix(input_image):
    """Normalised Gram matrix of a feature map.

    Every (batch, channel) feature map is flattened into one row. The rows are
    multiplied by their transpose, and the product is divided by the total
    number of elements of ``input_image``.

    Args:
        input_image: 4-D tensor (batch, channels, height, width); the batch
            size is expected to be 1.

    Returns:
        square tensor with side batch * channels.
    """
    b, c, h, w = input_image.size()
    n_rows = b * c
    rows = input_image.view(n_rows, h * w)
    return (rows @ rows.t()) / (n_rows * h * w)

In [None]:
class StyleLoss(nn.Module):
    """Style loss computed on one half of the generated image.

    The Gram matrix of the input features, restricted to one half of the
    image, is compared (MSE) with the Gram matrix of the whole style-image
    feature map. The result is stored in ``self.loss``. ``forward`` returns
    its input unchanged, so the module can be inserted into an
    ``nn.Sequential``.

    Args:
        target_feature: feature map of the style image at this layer,
            (batch, channels, height, width); batch size expected to be 1.
        mode_split: "h" - horizontal split (top / bottom halves),
                    "v" - vertical split (left / right halves).
        style: "style1" uses the first half (top for "h", left for "v"),
               "style2" uses the second half (bottom for "h", right for "v").

    Raises:
        ValueError: if ``mode_split`` or ``style`` is not one of the values above.
    """

    _styles = ("style1", "style2")

    def __init__(self, target_feature, mode_split="h", style="style1"):
        super(StyleLoss, self).__init__()
        # Fail fast: a bad mode_split used to surface as a KeyError in forward,
        # and a bad style silently produced an all-zero mask.
        if mode_split not in StyleLoss.split_dict:
            raise ValueError("mode_split must be 'h' or 'v', got {!r}".format(mode_split))
        if style not in StyleLoss._styles:
            raise ValueError("style must be 'style1' or 'style2', got {!r}".format(style))
        self.target = gram_matrix(target_feature).detach()
        self.loss = F.mse_loss(self.target, self.target)
        self.mode_split = mode_split
        self.style = style

    def forward(self, input_image):
        mask = self._half_mask(input_image, StyleLoss._split_dims[self.mode_split])
        # gram_matrix divides by the element count of the WHOLE map, but only
        # the selected half is non-zero. Without rescaling, the masked Gram
        # matrix would be about half the scale of the full-image target.
        # Rescale so it is averaged over the selected positions only.
        n_active = mask.sum().clamp(min=1)
        G = gram_matrix(input_image * mask) * (mask.numel() / n_active)
        self.loss = F.mse_loss(G, self.target)
        return input_image

    def _half_mask(self, input_image, dim):
        """0/1 tensor shaped like ``input_image`` selecting this loss's half along ``dim``."""
        mask = torch.zeros_like(input_image)
        half = input_image.size(dim) // 2
        index = [slice(None)] * input_image.dim()
        index[dim] = slice(None, half) if self.style == "style1" else slice(half, None)
        mask[tuple(index)] = 1
        return mask

    def horizontal_split(self, input_image):
        """Zero everything except this loss's half along the height axis."""
        return input_image * self._half_mask(input_image, 2)

    def vertical_split(self, input_image):
        """Zero everything except this loss's half along the width axis."""
        return input_image * self._half_mask(input_image, 3)

    # Mode -> split function (kept as a public dispatch table for compatibility).
    split_dict = {
        "h": horizontal_split,
        "v": vertical_split,
    }
    # Mode -> tensor dimension being split (2 = height, 3 = width of a 4-D input).
    _split_dims = {"h": 2, "v": 3}

In [None]:
# Layer whose activations feed the content loss.
content_layers_default = ['conv_4']
# Layers after which the (per-style) style losses are inserted: conv_1 ... conv_5.
style_layers_default = ['conv_{}'.format(k) for k in range(1, 6)]

In [None]:
# Per-channel mean / std that Normalization uses to standardise images before VGG19.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device)

In [None]:
class Normalization(nn.Module):
    """Standardise an image batch channel-wise: (img - mean) / std.

    Args:
        mean: 1-D tensor with one value per channel.
        std: 1-D tensor with one value per channel.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Shape (C, 1, 1) broadcasts against a (B, C, H, W) image batch.
        per_channel = (-1, 1, 1)
        self.mean = mean.view(*per_channel)
        self.std = std.view(*per_channel)

    def forward(self, img):
        centered = img - self.mean
        return centered / self.std

In [None]:
# Pretrained VGG19: only its convolutional part (`features`) is used, in eval mode.
if hasattr(models, "VGG19_Weights"):
    # torchvision >= 0.13: weights enum (`pretrained=True` is deprecated there).
    _vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
else:
    # Older torchvision has no weights enum.
    _vgg19 = models.vgg19(pretrained=True)
# Only the input image is optimised, never the network: disabling weight
# gradients avoids computing and accumulating useless .grad tensors.
cnn = _vgg19.features.to(device).eval().requires_grad_(False)
del _vgg19  # the classifier head is not needed; free its memory

In [None]:
# Теперь в нашей модели после каждого стилевого слоя подряд идут 2 слоя StyleLoss — по одному для каждого из стилей
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style1_img, style2_img, content_img, mode_split="h",
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Build the loss-instrumented model used for two-style transfer.

    A copy of ``cnn`` is rebuilt layer by layer behind a ``Normalization``
    module. After every layer named in ``content_layers``, a ``ContentLoss``
    (target: ``content_img`` features) is inserted. After every layer named in
    ``style_layers``, two ``StyleLoss`` modules are inserted, one per style
    image, each acting on its half of the image according to ``mode_split``.
    Layers after the last loss module are dropped.

    Layers are named ``conv_k``, ``relu_k``, ``pool_k``, ``bn_k``, where ``k``
    is the number of convolutions seen so far.

    Returns:
        (model, style1_losses, style2_losses, content_losses)

    Raises:
        RuntimeError: if ``cnn`` contains a layer type other than Conv2d,
            ReLU, MaxPool2d or BatchNorm2d.
    """
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses, style1_losses, style2_losses = [], [], []
    # VGG19 expects normalised input, so normalisation is the first module.
    model = nn.Sequential(normalization)

    # Checked in this order; the first matching type gives the name prefix.
    layer_kinds = (
        (nn.Conv2d, 'conv'),
        (nn.ReLU, 'relu'),
        (nn.MaxPool2d, 'pool'),
        (nn.BatchNorm2d, 'bn'),
    )
    style_sources = (
        ("style1", style1_img, style1_losses),
        ("style2", style2_img, style2_losses),
    )

    n_conv = 0
    for layer in cnn.children():
        kind = next((prefix for layer_type, prefix in layer_kinds
                     if isinstance(layer, layer_type)), None)
        if kind is None:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
        if kind == 'conv':
            n_conv += 1
        elif kind == 'relu':
            # The in-place ReLU does not play well with the loss modules
            # inserted after the layers, so swap in an out-of-place one.
            layer = nn.ReLU(inplace=False)
        name = '{}_{}'.format(kind, n_conv)
        model.add_module(name, layer)

        if name in content_layers:
            content_loss = ContentLoss(model(content_img).detach())
            model.add_module("content_loss_{}".format(n_conv), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # style1 first, then style2; the style2 target is computed through
            # a model that already contains the style1 loss module.
            for tag, style_img, bucket in style_sources:
                style_loss = StyleLoss(model(style_img).detach(), mode_split, style=tag)
                model.add_module("{}_loss_{}".format(tag, n_conv), style_loss)
                bucket.append(style_loss)

    # Keep everything up to and including the last loss module
    # (only the normalisation module if no loss was inserted).
    loss_positions = [idx for idx, module in enumerate(model)
                      if isinstance(module, (ContentLoss, StyleLoss))]
    last = max(loss_positions, default=0)
    return model[:last + 1], style1_losses, style2_losses, content_losses

In [None]:
def get_input_optimizer(input_img):
    """Create an L-BFGS optimizer whose only parameter is the image itself.

    ``input_img`` is switched to ``requires_grad=True`` in place, so the
    optimizer updates the image pixels and nothing else.
    """
    params = [input_img.requires_grad_()]
    return optim.LBFGS(params)

In [None]:
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       input_img, style1_img, style2_img, content_img, mode_split="h", num_steps=500,
                       style1_weight=1e6, style2_weight=1e6, content_weight=1):
    """Run two-style transfer and return the optimised image.

    ``input_img`` is optimised in place with L-BFGS so that its content matches
    ``content_img``. Meanwhile one half of it (chosen by ``mode_split``) takes
    the style of ``style1_img`` and the other half that of ``style2_img``.

    Args:
        cnn: feature extractor used to build the loss model.
        normalization_mean, normalization_std: per-channel statistics for Normalization.
        input_img: starting image; it is modified in place and returned.
        style1_img, style2_img, content_img: target images.
        mode_split: "h" (top/bottom halves) or "v" (left/right halves).
        num_steps: number of loss evaluations to run; since one L-BFGS step
            evaluates the closure several times, the total may overshoot slightly.
        style1_weight, style2_weight, content_weight: loss weights.

    Returns:
        ``input_img`` with values clamped to [0, 1].
    """
    print('Building the style transfer model..')
    model, style1_losses, style2_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std,
        style1_img, style2_img, content_img, mode_split)
    # Only the image is optimised: freeze the network so backward neither
    # computes nor accumulates gradients for its weights (the optimizer's
    # zero_grad only clears the image's gradient).
    model.requires_grad_(False)
    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]  # a list so the closure can update the counter
    while run[0] <= num_steps:

        def closure():
            # Keep pixel values inside [0, 1] before each evaluation.
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)

            # Weighted loss terms (each loss module stored its value during the forward).
            style1_score = sum(sl.loss for sl in style1_losses) * style1_weight
            style2_score = sum(sl.loss for sl in style2_losses) * style2_weight
            content_score = sum(cl.loss for cl in content_losses) * content_weight

            loss = style1_score + style2_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run[0]))
                print('Style1 Loss : {:4f} Style2 Loss : {:4f} Content Loss: {:4f}'.format(
                    float(style1_score), float(style2_score), float(content_score)))
                print()

            return loss

        optimizer.step(closure)

    # A last correction: the final L-BFGS update may leave values outside [0, 1].
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img