<a href="https://colab.research.google.com/github/lbenbaccar/Neural-Style-Transfer-using-CNN/blob/main/Neural_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Neural Algorithm of Artistic Style PyTorch Implementation
## **Libraries**

In [None]:
import os
from os import path
import copy
from sys import version_info
from collections import OrderedDict
import argparse
from PIL import Image
Image.MAX_IMAGE_PIXELS = 1000000000 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.model_zoo import load_url

# **Model downloads**

## Download the VGG-19 model and fix the layer names

In [None]:
sd = load_url("https://web.eecs.umich.edu/~justincj/models/vgg19-d01eb7cb.pth")
key_map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(key_map[k] if k in key_map else k,v) for k,v in sd.items()])
torch.save(sd, "vgg19-d01eb7cb.pth")
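
The renaming step above can be illustrated without downloading anything. The sketch below applies the same key map to a small stand-in dictionary; the string values are placeholders rather than real tensors, and `fake_sd` is an assumption made only for illustration.

```python
from collections import OrderedDict

# Same key map as in the cell above: the downloaded checkpoint numbers two of
# the fully connected layers 1 and 4, while the VGG class defined later in
# this notebook stores them at positions 0 and 3 of its classifier.
key_map = {
    'classifier.1.weight': 'classifier.0.weight',
    'classifier.1.bias': 'classifier.0.bias',
    'classifier.4.weight': 'classifier.3.weight',
    'classifier.4.bias': 'classifier.3.bias',
}

# Stand-in state dict with placeholder strings instead of tensors (assumption).
fake_sd = OrderedDict([
    ('features.0.weight', 'w0'),
    ('classifier.1.weight', 'w1'),
    ('classifier.1.bias', 'b1'),
    ('classifier.4.weight', 'w4'),
    ('classifier.4.bias', 'b4'),
    ('classifier.6.weight', 'w6'),
])

# Keys found in the map are renamed; all others pass through unchanged.
renamed = OrderedDict((key_map.get(k, k), v) for k, v in fake_sd.items())
print(list(renamed.keys()))
```

Only the four mapped keys change, and the entry order is preserved.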

## Download the VGG-16 model and fix the layer names

In [None]:
sd = load_url("https://web.eecs.umich.edu/~justincj/models/vgg16-00b39a1b.pth")
key_map = {'classifier.1.weight':u'classifier.0.weight', 'classifier.1.bias':u'classifier.0.bias', 'classifier.4.weight':u'classifier.3.weight', 'classifier.4.bias':u'classifier.3.bias'}
sd = OrderedDict([(key_map[k] if k in key_map else k,v) for k,v in sd.items()])
torch.save(sd, "vgg16-00b39a1b.pth")

## Download the NIN model

In [None]:
if version_info[0] < 3:
  import urllib
  urllib.URLopener().retrieve("https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth", "nin_imagenet.pth")
else: 
  import urllib.request
  urllib.request.urlretrieve("https://raw.githubusercontent.com/ProGamerGov/pytorch-nin/master/nin_imagenet.pth", "nin_imagenet.pth")

# **Options**


In [None]:
parser = argparse.ArgumentParser()

## Basic options

*   `-style_image`: path to the style image; several images can be given as a comma-separated list. (*by default:* `style6.jpg`)
*   `-content_image`: path to the content image. (*by default:* `ensae.jpg`)
*   `-style_blend_weights`: weights for blending the styles of multiple style images, given as a comma-separated list, for example `-style_blend_weights 2,6,7`. (*by default:* all style images are weighted equally)
*   `-image_size`: maximum side length, in pixels, of the generated image. (*by default:* 512)
*   `-gpu`: zero-indexed ID of the GPU to use; set `-gpu c` for CPU mode. (*by default:* 0)




In [None]:
parser.add_argument("-style_image", help="Style target image", default='style6.jpg')
parser.add_argument("-content_image", help="Content target image", default='ensae.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)


## Optimization options
For the optimization algorithm, L-BFGS tends to give better results but uses more memory, while ADAM uses less memory. With ADAM, other parameters usually need tuning to get good results, especially the style weight, content weight, and learning rate.

*   `-content_weight`: weight of the content reconstruction term. (*by default:* 5e0)
*   `-style_weight`: weight of the style reconstruction term. (*by default:* 1e2)
*   `-tv_weight`: weight of the total-variation (TV) regularization, which helps smooth the image; set to 0 to disable it. (*by default:* 1e-3)
*   `-num_iterations`: number of iterations. (*by default:* 1000)
*   `-init`: how the generated image is initialized, either `random` (noise) or `image` (the content image). (*by default:* `random`)
*   `-init_image`: user-specified image that replaces the initialization image.
*   `-optimizer`: optimization algorithm to use, `lbfgs` or `adam`. (*by default:* `lbfgs`)
*   `-learning_rate`: learning rate to use with the ADAM optimizer (*by default:* 1e1)
*   `-lbfgs_num_correction`: number of correction pairs stored by L-BFGS. (*by default:* 100)
*   `-normalize_weights`: if present, weights will be L1 normalized.
*   `-normalize_gradients`: if present, style and content gradients from each layer will be L1 normalized.

In [None]:
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e1)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')


## Output options

*   `-output_image`: name of the output image. (*by default:* `out5.png`)
*   `-print_iter`: print progress every `print_iter` iterations. Set to 0 to disable printing. (*by default:* 100)
*   `-save_iter`: save the image every `save_iter` iterations. Set to 0 to disable saving intermediate results.

In [None]:
parser.add_argument("-output_image", default='out5.png')
parser.add_argument("-print_iter", type=int, default=100)
parser.add_argument("-save_iter", type=int, default=0)


## Layer options
 
*   `-content_layers`: comma-separated list of layer names to use for content reconstruction (*by default:* `relu4_2`)
*   `-style_layers`: comma-separated list of layer names to use for style reconstruction (*by default:* `relu1_1`, `relu2_1`,`relu3_1`,`relu4_1`,`relu5_1`)

In [None]:
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')


## Other options
 
*   `-style_scale`: scale at which to extract features from the style image. (*by default:* 1.0)
*   `-original_colors`: if set to 1, then the output image will keep the colors of the content image.
*   `-model_file`: path to the `.pth` file of the model converted from Caffe. (*by default:* `vgg19-d01eb7cb.pth`, the original VGG-19 model)
*   `-pooling`: type of pooling layers to use, `max` or `avg`. (*by default:* `max`) VGG-19 uses max pooling layers, but the paper mentions that replacing them with average pooling layers can improve the results.
*   `-seed`: integer seed for repeatable results. (*by default:* -1, i.e. random for each run)
*   `-multidevice_strategy`: comma-separated list of layer indices at which to split the network when using multiple devices. (*by default:* `4,7,29`)
*   `-backend`: `nn`, `cudnn`, `openmp`, `mkl`, or `mkldnn`; `mkl` and `cudnn` can also be combined. (*by default:* `cudnn`)
*   `-cudnn_autotune`: with the cuDNN backend, pass this flag to let the built-in cuDNN autotuner select the best convolution algorithms for the architecture. The first iteration becomes a bit slower and may use a bit more memory, but the cuDNN backend may then run significantly faster.
*   `-disable_check`: if present, the model weights are loaded with `strict=False`, so mismatched keys in the state dict are not checked.

In [None]:
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-model_file", type=str, default='vgg19-d01eb7cb.pth')
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-multidevice_strategy", default='4,7,29')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='cudnn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-disable_check", action='store_true')


In [None]:
params, unknown = parser.parse_known_args()
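
`parse_known_args` is used instead of `parse_args` because a notebook kernel may be started with command-line arguments the parser does not know. The sketch below shows the behaviour on a simulated argument list; `demo_parser` and the `-f kernel-1234.json` argument are made up for illustration.

```python
import argparse

# Hypothetical mini-parser mirroring two of the options defined above.
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("-image_size", type=int, default=512)
demo_parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')

# Simulated argument list: one known option plus an argument the parser
# has never heard of (the file name is invented).
argv = ["-image_size", "256", "-f", "kernel-1234.json"]

# Known options are parsed; everything else is returned in a separate list
# instead of raising an error.
demo_params, demo_unknown = demo_parser.parse_known_args(argv)
print(demo_params.image_size, demo_params.optimizer)
print(demo_unknown)
```

Options that are not supplied keep their defaults, which is how the notebook runs with the values set in the cells above.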

# **Architecture classes**

## Architectures parameters

In [None]:
channel_list = {
'VGG-16p': [24, 22, 'P', 41, 51, 'P', 108, 89, 111, 'P', 184, 276, 228, 'P', 512, 512, 512, 'P'],
'VGG-16': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 'P', 512, 512, 512, 'P', 512, 512, 512, 'P'],
'VGG-19': [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 256, 'P', 512, 512, 512, 512, 'P', 512, 512, 512, 512, 'P'],
}

vgg16_dict = {
'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3'],
'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu4_1', 'relu4_2', 'relu4_3', 'relu5_1', 'relu5_2', 'relu5_3'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}

vgg19_dict = {
'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4', 'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'],
'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2', 'relu3_1', 'relu3_2', 'relu3_3', 'relu3_4', 'relu4_1', 'relu4_2', 'relu4_3', 'relu4_4', 'relu5_1', 'relu5_2', 'relu5_3', 'relu5_4'],
'P': ['pool1', 'pool2', 'pool3', 'pool4', 'pool5'],
}

nin_dict = {
'C': ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6', 'conv4-1024', 'cccp7-1024', 'cccp8-1024'],
'R': ['relu0', 'relu1', 'relu2', 'relu3', 'relu5', 'relu6', 'relu7', 'relu8', 'relu9', 'relu10', 'relu11', 'relu12'],
'P': ['pool1', 'pool2', 'pool3', 'pool4'],
'D': ['drop'],
}
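
To see how the channel lists and the name dictionaries fit together, the sketch below walks a shortened channel list and assigns names in network order. `name_layers` is a hypothetical helper, not part of the notebook, and the tables are truncated to the first two VGG blocks.

```python
# Shortened copies of the tables above (first two VGG blocks only).
channels = [64, 64, 'P', 128, 128, 'P']
names = {
    'C': ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2'],
    'R': ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2'],
    'P': ['pool1', 'pool2'],
}

def name_layers(channels, names):
    """Hypothetical helper: list layer names in network order."""
    ordered, c, r, p = [], 0, 0, 0
    for entry in channels:
        if entry == 'P':
            ordered.append(names['P'][p])
            p += 1
        else:
            # Every channel count stands for a convolution followed by a ReLU.
            ordered.append(names['C'][c])
            ordered.append(names['R'][r])
            c += 1
            r += 1
    return ordered

layer_names = name_layers(channels, names)
print(layer_names)
print(layer_names.index('relu2_1'))
```

The same counting scheme, one counter per layer type, is what lets a name such as `relu4_2` be matched to a position in the sequential network.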

## VGG-16 and VGG-19 architecture class

In [None]:
class VGG(nn.Module):

    def __init__(self, features, num_classes = 1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

## VGG-16 architecture class using the channel pruning model

In [None]:
class VGG_Pruned(nn.Module):
  
    def __init__(self, features, num_classes=1000):
        super(VGG_Pruned, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

## VGG-16 architecture class using the fcn32s-heavy-pascal model

In [None]:
class VGG_Fully_Convolutional_Network_32S(nn.Module):
  
    def __init__(self, features, num_classes=1000):
        super(VGG_Fully_Convolutional_Network_32S, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Conv2d(512,4096,(7, 7)),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Conv2d(4096,4096,(1, 1)),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

## VGG-16 architecture class using the SOD finetune model

In [None]:
class VGG_Salient_Object_Detection(nn.Module):
  
    def __init__(self, features, num_classes=100):
        super(VGG_Salient_Object_Detection, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 100),
        )

## NIN architecture class

In [None]:
class Network_in_Network(nn.Module):

    def __init__(self, pooling):
        super(Network_in_Network, self).__init__()
        if pooling == 'max':
            pool2d = nn.MaxPool2d((3, 3),(2, 2),(0, 0),ceil_mode=True)
        elif pooling == 'avg':
            pool2d = nn.AvgPool2d((3, 3),(2, 2),(0, 0),ceil_mode=True)

        self.features = nn.Sequential(
            nn.Conv2d(3,96,(11, 11),(4, 4)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(96,256,(5, 5),(1, 1),(2, 2)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Conv2d(256,384,(3, 3),(1, 1),(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,(1, 1)),
            nn.ReLU(inplace=True),
            pool2d,
            nn.Dropout(0.5),
            nn.Conv2d(384,1024,(3, 3),(1, 1),(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1024,(1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1000,(1, 1)),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((6, 6),(1, 1),(0, 0),ceil_mode=True),
            nn.Softmax(dim=1),
        )

In [None]:
class ModelParallel(nn.Module):
    def __init__(self, net, device_ids, device_splits):
        super(ModelParallel, self).__init__()
        self.device_list = self.name_devices(device_ids.split(','))
        self.chunks = self.chunks_to_devices(self.split_net(net, device_splits.split(',')))

    def name_devices(self, input_list):
        device_list = []
        for i, device in enumerate(input_list):
            if str(device).lower() != 'c':
                device_list.append("cuda:" + str(device))
            else:
                device_list.append("cpu")
        return device_list

    def split_net(self, net, device_splits):
        chunks, cur_chunk = [], nn.Sequential()
        for i, l in enumerate(net):
            cur_chunk.add_module(str(i), net[i])
            if str(i) in device_splits and device_splits != '':
                del device_splits[0]
                chunks.append(cur_chunk)
                cur_chunk = nn.Sequential()
        chunks.append(cur_chunk)
        return chunks

    def chunks_to_devices(self, chunks):
        for i, chunk in enumerate(chunks):
            chunk.to(self.device_list[i])
        return chunks

    def c(self, input, i):
        if input.type() == 'torch.FloatTensor' and 'cuda' in self.device_list[i]:
            input = input.type('torch.cuda.FloatTensor')
        elif input.type() == 'torch.cuda.FloatTensor' and 'cpu' in self.device_list[i]:
            input = input.type('torch.FloatTensor')
        return input

    def forward(self, input):
        for i, chunk in enumerate(self.chunks):
            if i < len(self.chunks) -1:
                input = self.c(chunk(self.c(input, i).to(self.device_list[i])), i+1).to(self.device_list[i+1])
            else:
                input = chunk(input)
        return input
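
The splitting rule in `split_net` can be sketched on a plain list: a chunk is closed right after each listed layer index, and whatever remains forms the last chunk. `split_layers` is a hypothetical stand-in that uses placeholder strings instead of `nn.Module` layers.

```python
def split_layers(layers, split_indices):
    """Hypothetical helper: close a chunk after each index in split_indices."""
    chunks, current = [], []
    for i, layer in enumerate(layers):
        current.append(layer)
        if i in split_indices:
            chunks.append(current)
            current = []
    chunks.append(current)
    return chunks

# 37 placeholder layers, the size of the bare VGG-19 feature stack
# (16 conv + 16 ReLU + 5 pooling); the real network also holds loss modules.
layers = ["layer%d" % i for i in range(37)]
chunks = split_layers(layers, {4, 7, 29})
print([len(chunk) for chunk in chunks])
```

Three split indices produce four chunks, so a strategy such as `4,7,29` needs four devices, one per chunk.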

In [None]:
def buildSequential(channel_list, pooling):

    layers = []
    in_channels = 3

    if pooling == 'max':
        pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

    elif pooling == 'avg':
        pool2d = nn.AvgPool2d(kernel_size=2, stride=2)

    else:
        raise ValueError("Unrecognized pooling parameter")

    for c in channel_list:
        if c == 'P':
            layers = layers + [pool2d]
        else:
            conv2d = nn.Conv2d(in_channels, c, kernel_size=3, padding=1)
            layers = layers + [conv2d, nn.ReLU(inplace=True)]
            in_channels = c

    return nn.Sequential(*layers)

## Function to select the model to use 

In [None]:
def modelSelector(model_file, pooling):

    vgg_list = ["vgg", "fcn32s", "pruning", "sod"]

    if any(name in model_file for name in vgg_list):
        if "pruning" in model_file:
            print("VGG-16 Architecture Detected")
            print("Using The Channel Pruning Model")
            cnn, layerList = VGG_Pruned(buildSequential(channel_list['VGG-16p'], pooling)), vgg16_dict

        elif "fcn32s" in model_file:
            print("VGG-16 Architecture Detected")
            print("Using the fcn32s-heavy-pascal Model")
            cnn, layerList = VGG_Fully_Convolutional_Network_32S(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict

        elif "sod" in model_file:
            print("VGG-16 Architecture Detected")
            print("Using The SOD Finetune Model")
            cnn, layerList = VGG_Salient_Object_Detection(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict

        elif "19" in model_file:
            print("VGG-19 Architecture Detected")
            cnn, layerList = VGG(buildSequential(channel_list['VGG-19'], pooling)), vgg19_dict

        elif "16" in model_file:
            print("VGG-16 Architecture Detected")
            cnn, layerList = VGG(buildSequential(channel_list['VGG-16'], pooling)), vgg16_dict

        else:
            raise ValueError("VGG architecture not recognized.")

    elif "nin" in model_file:
        print("NIN Architecture Detected")
        cnn, layerList = Network_in_Network(pooling), nin_dict

    else:
        raise ValueError("Model architecture not recognized.")

    return cnn, layerList

## Function to print the caffe model

In [None]:
def print_caffe(cnn, layerList):
    c = 0
    for l in list(cnn):
         if "Conv2d" in str(l):
             in_c, out_c, ks  = str(l.in_channels), str(l.out_channels), str(l.kernel_size)
             print(layerList['C'][c] +": " +  (out_c + " " + in_c + " " + ks).replace(")",'').replace("(",'').replace(",",'') )
             c+=1
         if c == len(layerList['C']):
             break

## Function to load the model and configure pooling layer type

In [None]:
def load_caffe_model(model_file, pooling, use_gpu, disable_check):
    cnn, layerList = modelSelector(str(model_file).lower(), pooling)

    cnn.load_state_dict(torch.load(model_file), strict=(not disable_check))
    print("Successfully loaded " + str(model_file))

    # Convert the model to cuda
    if "c" not in str(use_gpu).lower() or "c" not in str(use_gpu[0]).lower():
        cnn = cnn.cuda()
    cnn = cnn.features

    print_caffe(cnn, layerList)

    return cnn, layerList

# **Other classes**
## Class to scale gradients in the backward pass

In [None]:
class ScaleGradients(torch.autograd.Function):
    @staticmethod
    def forward(self, input_tensor, strength):
        self.strength = strength
        return input_tensor

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        grad_input = grad_input / (torch.norm(grad_input, keepdim=True) + 1e-8)
        return grad_input * self.strength * self.strength, None

## Class to compute content loss

In [None]:
class ContentLoss(nn.Module):

    def __init__(self, strength, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.normalize = normalize

    def forward(self, input):
        if self.mode == 'loss':
            loss = self.crit(input, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = loss * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input

## Class to compute the Gram matrix

In [None]:
class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
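
For intuition, the same product can be written with plain Python lists: each row is one flattened channel, and entry (i, j) of the Gram matrix is the dot product of channels i and j. The feature values below are made up, and lists replace tensors so the sketch runs without PyTorch.

```python
def gram(features):
    """Gram matrix of a C x N list of lists (rows are flattened channels)."""
    return [[sum(a * b for a, b in zip(row_i, row_j)) for row_j in features]
            for row_i in features]

# Two channels of a 1 x 3 feature map, already flattened (made-up values).
features = [[1.0, 2.0, 3.0],
            [0.0, 1.0, 0.0]]
G = gram(features)
print(G)
```

The result is a symmetric C x C matrix that no longer depends on where in the image each activation occurred.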

## Class to compute style loss

In [None]:
class StyleLoss(nn.Module):

    def __init__(self, strength, normalize):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None
        self.normalize = normalize

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                # Accumulate the blended Gram matrices: target += blend_weight * G
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            loss = self.crit(self.G, self.target)
            if self.normalize:
                loss = ScaleGradients.apply(loss, self.strength)
            self.loss = self.strength * loss
        return input

## Class to compute total-variation loss

In [None]:
class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:,:,1:,:] - input[:,:,:-1,:]
        self.y_diff = input[:,:,:,1:] - input[:,:,:,:-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
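
The total-variation term sums the absolute differences between neighbouring pixels along both axes. A minimal single-channel sketch with nested lists in place of tensors (the two test images are made up):

```python
def tv_loss(image, strength):
    """Total variation of a 2-D list: sum of absolute neighbour differences."""
    vertical = sum(abs(image[r + 1][c] - image[r][c])
                   for r in range(len(image) - 1)
                   for c in range(len(image[0])))
    horizontal = sum(abs(image[r][c + 1] - image[r][c])
                     for r in range(len(image))
                     for c in range(len(image[0]) - 1))
    return strength * (vertical + horizontal)

flat = [[5, 5], [5, 5]]   # constant image
edge = [[0, 9], [0, 9]]   # sharp vertical edge
print(tv_loss(flat, 1.0), tv_loss(edge, 1.0))
```

A constant image has zero loss, while sharp edges are penalized, which is why this term smooths the generated image.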

# **Setup**

In [None]:
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device

In [None]:
def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
      "The number of -multidevice_strategy layer indices must be equal to the number of -gpu devices minus 1."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net

# **Other functions**


## Preprocess an image before passing it to a model.
We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR, and subtract the mean pixel.

In [None]:
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 255)).unsqueeze(0)
    return tensor
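
When `image_size` is a single number, `preprocess` scales the image so that its longer side matches that number and the aspect ratio is kept. The arithmetic can be checked in isolation; `target_size` is a hypothetical helper that repeats the same computation.

```python
def target_size(height, width, max_side):
    """Hypothetical helper: same arithmetic as preprocess, returns (h, w)."""
    scale = float(max_side) / max(height, width)
    return (int(scale * height), int(scale * width))

print(target_size(768, 1024, 512))   # landscape image
print(target_size(1024, 768, 512))   # portrait image
```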

## Cancel the previous preprocessing

In [None]:
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 255
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
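
The preprocessing and its inverse can be checked on a single pixel. The sketch below uses plain lists rather than tensors; `preprocess_pixel` and `deprocess_pixel` are hypothetical helpers that follow the same steps (rescale, swap RGB and BGR, shift by the mean pixel).

```python
MEAN_BGR = [103.939, 116.779, 123.68]  # same constants as in the cells above

def preprocess_pixel(rgb):
    """RGB in [0, 1] -> mean-subtracted BGR on the [0, 255] scale."""
    bgr = [rgb[2] * 255, rgb[1] * 255, rgb[0] * 255]
    return [v - m for v, m in zip(bgr, MEAN_BGR)]

def deprocess_pixel(bgr):
    """Inverse of preprocess_pixel, clamped to [0, 1]."""
    restored = [(v + m) / 255 for v, m in zip(bgr, MEAN_BGR)]
    rgb = [restored[2], restored[1], restored[0]]
    return [min(max(v, 0.0), 1.0) for v in rgb]

pixel = [1.0, 0.5, 0.0]  # made-up RGB value
round_trip = deprocess_pixel(preprocess_pixel(pixel))
print(round_trip)
```

Up to floating-point rounding, the round trip returns the original pixel; the clamp only matters for optimized images whose values drift outside the valid range.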

## Combine the Y channel of the generated image and the UV/CbCr channels of the content image to perform color-independent style transfer.

In [None]:
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')

## Configure the optimizer

In [None]:
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr = params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal

## Print the network layers in Torch style

In [None]:
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
         if "2d" in str(l):
             ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
             if "Conv2d" in str(l):
                 ch = str(l.in_channels) + " -> " + str(l.out_channels)
                 print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
             elif "Pool2d" in str(l):
                 st = st.replace("  ",' ') + st.replace(", ",')')
                 print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
         else:
             print(n())
    print(")")

## Print progress every `print_iter` iterations if `print_iter` is not 0.

In [None]:
def print_iterations(t, loss):
    if params.print_iter > 0 and t % params.print_iter == 0:
        print("Iteration " + str(t) + " / "+ str(params.num_iterations))
        for i, loss_module in enumerate(content_losses):
            print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
        for i, loss_module in enumerate(style_losses):
            print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
        print("  Total loss: " + str(loss.item()))

## Save the image every `save_iter` iterations if `save_iter` is not 0.

In [None]:
def save_image(t):
    should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save:
        output_filename, file_extension = os.path.splitext(params.output_image)
        if t == params.num_iterations:
            filename = output_filename + str(file_extension)
        else:
            filename = str(output_filename) + "_" + str(t) + str(file_extension)
        disp = deprocess(img.clone())

        # Maybe perform postprocessing for color-independent style transfer
        if params.original_colors == 1:
            disp = original_colors(deprocess(content_image.clone()), disp)

        disp.save(str(filename))
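
The naming rule for intermediate results can be tried on its own. `intermediate_name` is a hypothetical helper that repeats the `os.path.splitext` logic of `save_image` without touching any image.

```python
import os

def intermediate_name(output_image, t, num_iterations):
    """Hypothetical helper reproducing the naming rule in save_image."""
    stem, ext = os.path.splitext(output_image)
    if t == num_iterations:
        return stem + ext                    # final image keeps the plain name
    return stem + "_" + str(t) + ext         # intermediate images get a suffix

print(intermediate_name("out5.png", 100, 1000))
print(intermediate_name("out5.png", 1000, 1000))
```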

## Function to evaluate loss and gradient 
We run the net forward and backward to get the gradient, and sum up losses from the loss modules.
`optim.LBFGS` handles iteration internally and calls this function many times, so we count the calls manually to handle printing and saving intermediate results.

In [None]:
num_calls = [0]
def feval():
    num_calls[0] += 1
    optimizer.zero_grad()
    net(img)
    loss = 0

    for mod in content_losses:
        loss += mod.loss.to(backward_device)
    for mod in style_losses:
        loss += mod.loss.to(backward_device)
    if params.tv_weight > 0:
        for mod in tv_losses:
            loss += mod.loss.to(backward_device)

    loss.backward()

    save_image(num_calls[0])
    print_iterations(num_calls[0], loss)

    return loss
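
The counting pattern can be seen in isolation with a toy optimizer. `toy_step` below is a hypothetical stand-in that simply evaluates the closure a fixed number of times; it is not how `optim.LBFGS` works internally, it only shows why the counter lives outside the closure.

```python
calls = [0]  # mutable counter shared with the closure, like num_calls above

def closure():
    # Stand-in for feval: count the call and return a made-up loss value.
    calls[0] += 1
    return 1.0 / calls[0]

def toy_step(closure, max_iter):
    """Hypothetical optimizer step that evaluates the closure max_iter times."""
    loss = None
    for _ in range(max_iter):
        loss = closure()
    return loss

final_loss = toy_step(closure, max_iter=5)
print(calls[0], final_loss)
```

A single call to the step function triggers several evaluations, so any per-iteration printing or saving has to be driven from inside the closure.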

## Divide weights by channel size

In [None]:
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())

# **Experimentation**

In [None]:
dtype, multidevice, backward_device = setup_gpu()

In [None]:
cnn, layerList = load_caffe_model(params.model_file, params.pooling, params.gpu, params.disable_check)

VGG-19 Architecture Detected
Successfully loaded vgg19-d01eb7cb.pth
conv1_1: 64 3 3 3
conv1_2: 64 64 3 3
conv2_1: 128 64 3 3
conv2_2: 128 128 3 3
conv3_1: 256 128 3 3
conv3_2: 256 256 3 3
conv3_3: 256 256 3 3
conv3_4: 256 256 3 3
conv4_1: 512 256 3 3
conv4_2: 512 512 3 3
conv4_3: 512 512 3 3
conv4_4: 512 512 3 3
conv5_1: 512 512 3 3
conv5_2: 512 512 3 3
conv5_3: 512 512 3 3
conv5_4: 512 512 3 3


In [None]:
content_image = preprocess(params.content_image, params.image_size).type(dtype)

In [None]:
style_image_input = params.style_image.split(',')
style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
for image in style_image_input:
    if os.path.isdir(image):
        images = (image + "/" + file for file in os.listdir(image)
        if os.path.splitext(file)[1].lower() in ext)
        style_image_list.extend(images)
    else:
        style_image_list.append(image)
style_images_caffe = []
for image in style_image_list:
    style_size = int(params.image_size * params.style_scale)
    img_caffe = preprocess(image, style_size).type(dtype)
    style_images_caffe.append(img_caffe)

In [None]:
if params.init_image is not None:
    image_size = (content_image.size(2), content_image.size(3))
    init_image = preprocess(params.init_image, image_size).type(dtype)

Handle style blending weights for multiple style inputs

In [None]:
style_blend_weights = []
if params.style_blend_weights is None:
    # Style blending not specified, so use equal weighting
    for i in style_image_list:
        style_blend_weights.append(1.0)
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = int(style_blend_weights[i])
else:
    style_blend_weights = params.style_blend_weights.split(',')
    assert len(style_blend_weights) == len(style_image_list), \
      "-style_blend_weights and -style_image must have the same number of elements!"

Normalize the style blending weights so they sum to 1

In [None]:
style_blend_sum = 0
for i, blend_weights in enumerate(style_blend_weights):
    style_blend_weights[i] = float(style_blend_weights[i])
    style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
for i, blend_weights in enumerate(style_blend_weights):
    style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
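
As a quick check of the normalization, the sketch below parses the example value `2,6,7` and rescales it so the weights sum to 1. `normalize_blend_weights` is a hypothetical helper that condenses the two cells above.

```python
def normalize_blend_weights(raw):
    """Hypothetical helper: parse a comma-separated string, rescale to sum 1."""
    weights = [float(w) for w in raw.split(',')]
    total = sum(weights)
    return [w / total for w in weights]

weights = normalize_blend_weights("2,6,7")
print(weights)
```

Only the ratios between the weights matter: `2,6,7` and `4,12,14` give the same blend.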

In [None]:
content_layers = params.content_layers.split(',')
style_layers = params.style_layers.split(',')

Set up the network, inserting style and content loss modules

In [None]:
cnn = copy.deepcopy(cnn)
content_losses, style_losses, tv_losses = [], [], []
next_content_idx, next_style_idx = 1, 1
net = nn.Sequential()
c, r = 0, 0

if params.tv_weight > 0:
    tv_mod = TVLoss(params.tv_weight).type(dtype)
    net.add_module(str(len(net)), tv_mod)
    tv_losses.append(tv_mod)

In [None]:
for i, layer in enumerate(list(cnn), 1):
    if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
        if isinstance(layer, nn.Conv2d):
            net.add_module(str(len(net)), layer)

            if layerList['C'][c] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                net.add_module(str(len(net)), loss_module)
                content_losses.append(loss_module)
                next_content_idx += 1

            if layerList['C'][c] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                net.add_module(str(len(net)), loss_module)
                style_losses.append(loss_module)
                next_style_idx += 1
            c+=1

        if isinstance(layer, nn.ReLU):
            net.add_module(str(len(net)), layer)

            if layerList['R'][r] in content_layers:
                print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
                net.add_module(str(len(net)), loss_module)
                content_losses.append(loss_module)
                next_content_idx += 1

            if layerList['R'][r] in style_layers:
                print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
                net.add_module(str(len(net)), loss_module)
                style_losses.append(loss_module)
                next_style_idx += 1
            r+=1

        if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
            net.add_module(str(len(net)), layer)

Setting up style layer 2: relu1_1
Setting up style layer 7: relu2_1
Setting up style layer 12: relu3_1
Setting up style layer 21: relu4_1
Setting up content layer 23: relu4_2
Setting up style layer 30: relu5_1
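
The layer numbers printed above are 1-based positions in the original `cnn` feature stack, counted before any loss modules are inserted. As a rough check, they can be reproduced from the standard VGG-19 configuration. The helper `layer_positions` and the constant `VGG19_CFG` below are hypothetical names used only for this sketch, and it assumes each convolution is immediately followed by a ReLU.

```python
# VGG-19 feature configuration: integers are conv output channels,
# 'P' marks a pooling layer that closes a block.
VGG19_CFG = [64, 64, 'P', 128, 128, 'P', 256, 256, 256, 256, 'P',
             512, 512, 512, 512, 'P', 512, 512, 512, 512, 'P']


def layer_positions(cfg):
    """Map names such as 'relu4_2' to 1-based positions in the feature stack."""
    positions = {}
    block, conv_in_block, pos = 1, 0, 0
    for entry in cfg:
        if entry == 'P':
            pos += 1
            positions["pool%d" % block] = pos
            block += 1
            conv_in_block = 0
        else:
            conv_in_block += 1
            pos += 1
            positions["conv%d_%d" % (block, conv_in_block)] = pos
            pos += 1  # each conv is followed by its ReLU
            positions["relu%d_%d" % (block, conv_in_block)] = pos
    return positions


positions = layer_positions(VGG19_CFG)
for name in ["relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu4_2", "relu5_1"]:
    print(name, positions[name])
```

These positions differ from the module numbers in the assembled `net`, because `net` also contains the inserted TV, content, and style loss modules.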


In [None]:
if multidevice:
    net = setup_multi_device(net)

Capture content targets

In [None]:
for i in content_losses:
    i.mode = 'capture'
print("Capturing content targets")
print_torch(net, multidevice)
net(content_image)

Capturing content targets
nn.Sequential ( 
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> (25) -> (26) -> (27) -> (28) -> (29) -> (30) -> (31) -> (32) -> (33) -> (34) -> (35) -> (36) -> (37) -> output]
  (1): nn.TVLoss
  (2): nn.Conv2d(3 -> 64, 3x3, 1,1, 1,1)
  (3): nn.ReLU
  (4): nn.StyleLoss
  (5): nn.Conv2d(64 -> 64, 3x3, 1,1, 1,1)
  (6): nn.ReLU
  (7): nn.MaxPool2d(2x2, 2,2)
  (8): nn.Conv2d(64 -> 128, 3x3, 1,1, 1,1)
  (9): nn.ReLU
  (10): nn.StyleLoss
  (11): nn.Conv2d(128 -> 128, 3x3, 1,1, 1,1)
  (12): nn.ReLU
  (13): nn.MaxPool2d(2x2, 2,2)
  (14): nn.Conv2d(128 -> 256, 3x3, 1,1, 1,1)
  (15): nn.ReLU
  (16): nn.StyleLoss
  (17): nn.Conv2d(256 -> 256, 3x3, 1,1, 1,1)
  (18): nn.ReLU
  (19): nn.Conv2d(256 -> 256, 3x3, 1,1, 1,1)
  (20): nn.ReLU
  (21): nn.Conv2d(256 -> 256, 3x3, 1,1, 1,1)
  (22): nn.ReLU
  (23): nn.MaxPool2d(2x2, 2,2)
  (


Capture style targets

In [None]:
for i in content_losses:
    i.mode = 'None'

for i, image in enumerate(style_images_caffe):
    print("Capturing style target " + str(i+1))
    for j in style_losses:
        j.mode = 'capture'
        j.blend_weight = style_blend_weights[i]
    net(image)

Capturing style target 1
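
As an illustration of what capturing several style targets amounts to, the sketch below blends Gram matrices using the normalized blend weights. It is written under assumptions: it supposes that each `StyleLoss` module accumulates a blend-weighted Gram matrix divided by the number of feature elements. The functions `gram_matrix` and `blended_target` are hypothetical stand-ins, not the notebook's own classes.

```python
import torch


def gram_matrix(features):
    """Gram matrix of a (1, C, H, W) feature map, with shape (C, C)."""
    _, c, h, w = features.size()
    flat = features.reshape(c, h * w)
    return torch.mm(flat, flat.t())


def blended_target(feature_maps, blend_weights):
    """Weighted sum of size-normalized Gram matrices (assumed capture rule)."""
    target = None
    for features, weight in zip(feature_maps, blend_weights):
        g = gram_matrix(features).div(features.nelement()).mul(weight)
        target = g if target is None else target + g
    return target


torch.manual_seed(0)
feats_a = torch.randn(1, 4, 5, 5)  # stand-in features of style image 1
feats_b = torch.randn(1, 4, 6, 6)  # stand-in features of style image 2
target = blended_target([feats_a, feats_b], [0.75, 0.25])
print(target.shape)
```

Because the Gram matrix has shape (C, C) regardless of spatial size, style images of different resolutions can be blended into one target.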


Set all loss modules to loss mode

In [None]:
for i in content_losses:
    i.mode = 'loss'
for i in style_losses:
    i.mode = 'loss'

Normalize content and style weights

In [None]:
if params.normalize_weights:
    normalize_weights(content_losses, style_losses)

Freeze the network weights so that gradients are computed only for the image being optimized

In [None]:
for param in net.parameters():
    param.requires_grad = False
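
A small stand-alone check of this behavior, using a single hypothetical convolution in place of the full network: with the weights frozen, `backward()` still fills in the gradient of the input image, while the layer's parameters receive none.

```python
import torch
import torch.nn as nn

# Stand-in for the frozen network: one conv layer with gradients disabled.
layer = nn.Conv2d(3, 4, kernel_size=3, padding=1)
for param in layer.parameters():
    param.requires_grad = False

# The image is the only tensor that requires a gradient.
img = nn.Parameter(torch.randn(1, 3, 8, 8))
loss = layer(img).pow(2).sum()
loss.backward()

print(img.grad is not None)
print(all(p.grad is None for p in layer.parameters()))
```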

Initialize the image

In [None]:
if params.seed >= 0:
    torch.manual_seed(params.seed)
    torch.cuda.manual_seed_all(params.seed)
    torch.backends.cudnn.deterministic = True
if params.init == 'random':
    B, C, H, W = content_image.size()
    img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
elif params.init == 'image':
    if params.init_image is not None:
        img = init_image.clone()
    else:
        img = content_image.clone()
img = nn.Parameter(img)

In [None]:
optimizer, loopVal = setup_optimizer(img)
while num_calls[0] <= loopVal:
    optimizer.step(feval)

Running optimization with L-BFGS
Iteration 100 / 1000
  Content 1 loss: 1512603.5
  Style 1 loss: 17190.15234375
  Style 2 loss: 81972.3359375
  Style 3 loss: 18371.2890625
  Style 4 loss: 257689.890625
  Style 5 loss: 516.2354736328125
  Total loss: 1895934.625
Iteration 200 / 1000
  Content 1 loss: 1390582.5
  Style 1 loss: 7014.34033203125
  Style 2 loss: 23515.283203125
  Style 3 loss: 8332.5927734375
  Style 4 loss: 223809.15625
  Style 5 loss: 587.9150390625
  Total loss: 1661314.125
Iteration 300 / 1000
  Content 1 loss: 1356689.375
  Style 1 loss: 2466.635986328125
  Style 2 loss: 9472.1416015625
  Style 3 loss: 6263.751953125
  Style 4 loss: 219881.34375
  Style 5 loss: 620.3319091796875
  Total loss: 1602279.0
Iteration 400 / 1000
  Content 1 loss: 1345573.25
  Style 1 loss: 1046.33837890625
  Style 2 loss: 5879.7841796875
  Style 3 loss: 5727.53466796875
  Style 4 loss: 218761.53125
  Style 5 loss: 632.0095825195312
  Total loss: 1583992.0
Iteration 500 / 1000
  Content 1 lo