In [0]:
!pip install torch==1.4.0 torchvision==0.5.0

Collecting torch==1.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4MB)
[K     |█████▉                          | 138.3MB 1.5MB/s eta 0:06:39

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import torchvision.models as models

import numpy as np
import cv2
import os
import time

In [42]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = xm.xla_device()
print(device)

# import my Google Drive for saved models
from google.colab import drive
drive.mount('/content/drive')

base_path = '/content/drive/My Drive/Colab Notebooks/mrf-cnn'
content_path = base_path + '/content-images'
style_path = base_path + '/style-images'
output_path = base_path + '/output'

cuda:0
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# based on https://github.com/jonzhaocn/cnnmrf-pytorch/blob/master/mylibs.py but logic is my own.

# The Content Loss class just does an MSE loss against the content image (Formula 4 from paper)
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()
        self.loss = None

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

    def update(self, target):
        self.target = target.detach()

In [0]:
# based on https://github.com/jonzhaocn/cnnmrf-pytorch/blob/master/mylibs.py

# The Style Loss class handles the MRF loss function against the style image (Formula 2,3 from paper)
class StyleLoss(nn.Module):
    def __init__(self, target):
        super(StyleLoss, self).__init__()
        self.patch_size = (3,3)
        self.stride = 1 #set to 2 because 1 crashed system.
        self.gpu_chunk_size = 512
        self.loss = None

        # this can be pre-computed as it doesn't change during the forward pass
        self.style_patches = self.patches_sampling(target.detach(), patch_size=self.patch_size, stride=self.stride)
        self.style_patches_norm = self.cal_patches_norm()
        self.style_patches_norm = self.style_patches_norm.view(-1, 1, 1)

    def update(self, target):
        # we do however have to update it when the image resolution changes
        self.style_patches = self.patches_sampling(target.detach(), patch_size=self.patch_size,stride=self.stride)
        self.style_patches_norm = self.cal_patches_norm()
        self.style_patches_norm = self.style_patches_norm.view(-1, 1, 1)

    def forward(self, input):
        synthesis_patches = self.patches_sampling(input, patch_size=self.patch_size, stride=self.stride)
        max_response = []
        for i in range(0, self.style_patches.shape[0], self.gpu_chunk_size):
            i_start = i
            i_end = min(i+self.gpu_chunk_size, self.style_patches.shape[0])
            weight = self.style_patches[i_start:i_end, :, :, :]
            response = F.conv2d(input, weight, stride=self.stride)
            max_response.append(response.squeeze(dim=0))
        max_response = torch.cat(max_response, dim=0)

        max_response = max_response.div(self.style_patches_norm)
        max_response = torch.argmax(max_response, dim=0)
        max_response = torch.reshape(max_response, (1, -1)).squeeze()
        # loss
        loss = 0
        for i in range(0, len(max_response), self.gpu_chunk_size):
            i_start = i
            i_end = min(i+self.gpu_chunk_size, len(max_response))
            tp_ind = tuple(range(i_start, i_end))
            sp_ind = max_response[i_start:i_end]
            loss += torch.sum(torch.mean(torch.pow(synthesis_patches[tp_ind, :, :, :]-self.style_patches[sp_ind, :, :, :], 2), dim=[1, 2, 3]))
        self.loss = loss/len(max_response)
        return input

    def patches_sampling(self, image, patch_size, stride):
        h, w = image.shape[2:4]
        patches = []
        for i in range(0, h - patch_size[0] + 1, stride):
            for j in range(0, w - patch_size[1] + 1, stride):
                patches.append(image[:, :, i:i + patch_size[0], j:j + patch_size[1]])
        patches = torch.cat(patches, dim=0).to(device)
        return patches

    def cal_patches_norm(self):
        # norm of style image patches
        norm_array = torch.zeros(self.style_patches.shape[0])
        for i in range(self.style_patches.shape[0]):
            norm_array[i] = torch.pow(torch.sum(torch.pow(self.style_patches[i], 2)), 0.5)
        return norm_array.to(device)

In [0]:
# This is the regularizer from the paper.  It's the squared gradient norm (Formula 5 from paper) which is used to smooth the image
class Regularizer(nn.Module):
    def __init__(self):
        super(Regularizer, self).__init__()
        self.loss = None
        self.unnormalize = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    def forward(self, input):
        image = self.unnormalize(input.clone().squeeze()).permute([1, 2, 0])
        xij_1 = torch.cat((image[1:, :, :], image[-1, :, :].unsqueeze(0)), dim=0)
        xij_1 = xij_1 - image
        xi_1j = torch.cat((image[:, 1:, :], image[:, -1, :].unsqueeze(1)), dim=1)
        xi_1j = xi_1j - image

        self.loss = torch.sum(torch.pow(xij_1, 2) + torch.pow(xi_1j, 2))
        return input

In [0]:
class MRF_CNN(nn.Module):
    def __init__(self, style_image, content_image):
        super(MRF_CNN, self).__init__()
        self.alpha1 = .5
        self.alpha2 = 0.01
        self.style_layers = [13, 22] # 13 is relu3_1, 22 is relu4_1 as per https://www.mathworks.com/help/deeplearning/ref/vgg19.html
        self.content_layers = [24] # 24 is relu4_2

        # build the model based on vgg19 and insert the custom content and style loss layers
        vgg = models.vgg19(pretrained=True).to(device)
        model = nn.Sequential()
        content_losses = []
        style_losses = []
        regularizer = Regularizer()
        model.add_module('regularizer', regularizer)

        for i in range(len(vgg.features)):
            # add layer of vgg19
            layer = vgg.features[i]
            name = str(i)
            model.add_module(name, layer)

            # add content loss layer
            if i in self.content_layers:
                target = model(content_image).detach()
                content_loss = ContentLoss(target)
                model.add_module("content_loss_" + name, content_loss)
                content_losses.append(content_loss)

            # add style loss layer
            if i in self.style_layers:
                target_feature = model(style_image).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module("style_loss_" + name, style_loss)
                style_losses.append(style_loss)

        self.model = model
        self.content_losses = content_losses
        self.style_losses = style_losses
        self.regularizer = regularizer

    def forward(self, image):
        self.model(image)
        style_loss = 0
        content_loss = 0

        # calculate losses
        for x in self.style_losses:
            style_loss += x.loss
        for x in self.content_losses:
            content_loss += x.loss
        loss = style_loss + (self.alpha1 * content_loss) + (self.alpha2 * self.regularizer.loss)
        return loss

    def update(self, style_image, content_image):
        # update the target of style loss layer
        x = style_image.clone()
        next_style_idx = 0
        i = 0
        for layer in self.model:
            if isinstance(layer, Regularizer) or isinstance(layer, ContentLoss) or isinstance(layer, StyleLoss):
                continue
            x = layer(x)
            if i in self.style_layers:
                # extract feature of style image in vgg19 as style loss target
                self.style_losses[next_style_idx].update(x)
                next_style_idx += 1
            i += 1

        # update the target of content loss layer
        x = content_image.clone()
        next_content_idx = 0
        i = 0
        for layer in self.model:
            if isinstance(layer, Regularizer) or isinstance(layer, ContentLoss) or isinstance(layer, StyleLoss):
                continue
            x = layer(x)
            if i in self.content_layers:
                # extract feature of content image in vgg19 as content loss target
                self.content_losses[next_content_idx].update(x)
                next_content_idx += 1
            i += 1

In [0]:
# from https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor

In [0]:
# this function finds all the images in content and style path and then encodes them
def get_and_transform_images(content_path, style_path):

    # make sure paths exist first and they contain files
    if not os.path.isdir(content_path):
        raise ValueError('directory %s does not exist.' % content_path)
    if not os.path.isdir(style_path):
        raise ValueError('directory %s does not exist.' % style_path)
    if len(os.listdir(content_path) ) == 0:
        raise ValueError('directory %s is empty.' % content_path)
    if len(os.listdir(style_path) ) == 0: 
        raise ValueError('directory %s is empty.' % style_path)

    # Pretrained models must have images at least 3x224x224 in size and must be 
    # normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # Source: https://pytorch.org/docs/stable/torchvision/models.html
    transform = T.Compose([T.ToTensor(),T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

    # read all the images in the folder and then convert to RGB and apply the transform
    content_images = []
    for cimg in os.listdir(content_path):
      if os.path.isdir(cimg):
        continue
      print(cimg)
      img = cv2.imread(content_path + "/" + cimg)
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      img = transform(img).unsqueeze(0).to(device)
      content_images.append(img)

    style_images = []
    for simg in os.listdir(style_path):
      if os.path.isdir(simg):
        continue
      print(simg)
      img = cv2.imread(style_path + "/" + simg)
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      img = transform(img).unsqueeze(0).to(device)
      style_images.append(img)

    # need to build an image pyramid using scaling factor of two and stop when longest dimension is less than 64 pixels
    images_pyramid_content = []
    for img in content_images:
      newimg = img
      img_pyramid = []
      img_pyramid.append(newimg)
      images_pyramid_content.append(img_pyramid)

    images_pyramid_style = []
    for img in style_images:
      newimg = img
      img_pyramid = []
      img_pyramid.append(newimg)
      images_pyramid_style.append(img_pyramid)

    return images_pyramid_content, images_pyramid_style

In [49]:
# the main function that runs the algorithm

global iteration
max_iterations = 200
content_images, style_images = get_and_transform_images(content_path, style_path)

unnormalize = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

for x, content_image in enumerate(content_images):
  for y, style_image in enumerate(style_images):

    timer = float(time.time())

    iteration = 0
    synthesized_image = None

    # create the model
    model = MRF_CNN(style_image=style_image[0], content_image=content_image[0]).to(device)
    model.train()

    synthesized_image = content_image[0].clone().requires_grad_(True).to(device)

    def closure():
      global iteration
      optimizer.zero_grad()
      loss = model(synthesized_image)
      loss.backward()

      # save image every after each set of iterations
      if (iteration + 1) % 50 == 0:
          image = unnormalize(synthesized_image.clone().squeeze())
          image = F.interpolate(image.unsqueeze(0), size=content_image[0].shape[2:4], mode='bilinear', align_corners=True)
          torchvision.utils.save_image(image.squeeze(), output_path + '/image-c%d-s%d-it%d.jpg' % (x, y, iteration + 1))
          print('save image: image-c%d-s%d-it%d.jpg loss: %f' % (x, y, iteration + 1, loss.item()))

      iteration += 1

      if iteration == max_iterations:
          iteration = 0
      return loss

    optimizer = optim.LBFGS([synthesized_image], max_iter=max_iterations)
    optimizer.step(closure)

    print("Execution Time: %6.3f" % (float(time.time()) - timer))

content.jpg
style.jpg
save image: image-c0-s0-it50.jpg loss: 28.839828
save image: image-c0-s0-it100.jpg loss: 26.996178
save image: image-c0-s0-it150.jpg loss: 25.985758
save image: image-c0-s0-it200.jpg loss: 25.251431
Execution Time: 332.495
