In [0]:
from PIL import Image
import numpy as np
from collections import namedtuple

import torch
from torchvision import models
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.onnx

#import argparse
import os
import sys
import time
import re

# Image Preprocess

In [0]:
def load_image(filename, size=None, scale=None):
    """Open an image file and optionally resize it.

    Args:
        filename: Path (or file object) handed to ``PIL.Image.open``.
        size: If given, the image is resized to a ``size`` x ``size`` square
            (the aspect ratio is not preserved). Takes precedence over ``scale``.
        scale: If given and ``size`` is None, both dimensions are divided by
            this factor and truncated to integers.

    Returns:
        The loaded (and possibly resized) ``PIL.Image.Image`` in RGB mode.
    """
    # Force three channels so grayscale / RGBA / palette files still match the
    # 3-channel ImageNet statistics and the 3-input-channel first convolution.
    img = Image.open(filename).convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it aliased.
    resample = Image.LANCZOS
    if size is not None:
        img = img.resize((size, size), resample)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), resample)
    return img


def save_image(filename, data):
    """Clamp a tensor to [0, 255] and write it to ``filename`` as an image.

    Args:
        filename: Destination passed to ``PIL.Image.Image.save``.
        data: Tensor whose axes are reordered (0, 1, 2) -> (1, 2, 0) before
            saving, i.e. presumably channel-first (C, H, W). It is converted
            with ``.numpy()``, so it is expected to live on the CPU.
    """
    # Work on a copy so the caller's tensor is left untouched by the clamp.
    pixels = torch.clamp(data.clone(), 0, 255).numpy()
    hwc = np.transpose(pixels, (1, 2, 0)).astype("uint8")
    Image.fromarray(hwc).save(filename)


def gram_matrix(y):
    """Compute the normalized Gram matrix of each item in a 4-D batch.

    Args:
        y: Tensor of size (b, ch, h, w).

    Returns:
        Tensor of size (b, ch, ch): the product of the flattened features with
        their transpose, divided by ``ch * h * w``.
    """
    batch, channels, height, width = y.size()
    # Collapse the two trailing axes so every channel becomes one row vector.
    flat = y.view(batch, channels, width * height)
    return torch.bmm(flat, flat.transpose(1, 2)) / (channels * height * width)


def normalize_batch(batch):
    """Rescale a 0-255 batch to 0-1 and normalize it with ImageNet statistics.

    Args:
        batch: Tensor whose channel axis (3 channels) broadcasts against a
            (3, 1, 1) mean/std, e.g. (N, 3, H, W).

    Returns:
        A new tensor ``(batch / 255 - mean) / std``. The input tensor is left
        unmodified.
    """
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    # Out-of-place division: the previous in-place ``div_`` also rescaled the
    # caller's tensor as a hidden side effect.
    scaled = batch / 255.0
    return (scaled - mean) / std


# VGG Module

In [0]:
class Vgg16(torch.nn.Module):
    """VGG-16 feature extractor that exposes four intermediate activations.

    The ``features`` stack of ``models.vgg16(pretrained=True)`` is split into
    four consecutive slices (layer indices 0-3, 4-8, 9-15 and 16-22). ``forward``
    runs them in sequence and returns the output of every slice.

    Args:
        requires_grad: When False (default) every parameter is frozen.
    """

    # Built once at class creation; previously this namedtuple class was
    # re-created on every forward call.
    VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])

    # Half-open [start, stop) layer-index ranges of slice1 .. slice4.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        for idx, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the sub-module name.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, "slice{}".format(idx), block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return a ``VggOutputs`` namedtuple holding the output of each slice."""
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)


# TransformerNet

In [0]:
class TransformerNet(torch.nn.Module):
    """Image-to-image network: three conv stages, five residual blocks, three output stages.

    Every convolution except the last is followed by an affine instance norm
    and a ReLU; the last convolution's output is returned as-is.
    """

    def __init__(self):
        super().__init__()
        norm = torch.nn.InstanceNorm2d
        # Input stage: 3 -> 32 -> 64 -> 128 channels, the last two with stride 2.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = norm(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = norm(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = norm(128, affine=True)
        # Five residual blocks at 128 channels.
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Output stage: two x2 upsample+conv layers, then a 9x9 conv back to 3 channels.
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = norm(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = norm(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Shared activation.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run ``X`` through all stages and return the final convolution's output."""
        out = X
        for conv, norm in ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3)):
            out = self.relu(norm(conv(out)))
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            out = block(out)
        for conv, norm in ((self.deconv1, self.in4), (self.deconv2, self.in5)):
            out = self.relu(norm(conv(out)))
        return self.deconv3(out)


class ConvLayer(torch.nn.Module):
    """Reflection-pad the input by ``kernel_size // 2`` on each side, then convolve it.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size; also determines the padding.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        """Return ``conv2d(reflection_pad(x))``."""
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(torch.nn.Module):
    """Residual block: conv -> norm -> ReLU -> conv -> norm, plus the input.

    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()
        # Both convolutions keep the channel count: 3x3 kernel, stride 1.
        conv_args = dict(kernel_size=3, stride=1)
        self.conv1 = ConvLayer(channels, channels, **conv_args)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, **conv_args)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the transformed branch added to the unmodified input ``x``."""
        branch = self.in1(self.conv1(x))
        branch = self.relu(branch)
        branch = self.in2(self.conv2(branch))
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optionally upsample (nearest neighbour), reflection-pad, then convolve.

    Upsampling followed by a convolution gives better results than
    ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        stride: Convolution stride.
        upsample: Scale factor for the interpolation; skipped when falsy.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        """Return the convolution of the (optionally upsampled) padded input."""
        resized = (
            torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
            if self.upsample
            else x
        )
        return self.conv2d(self.reflection_pad(resized))


# Main function()

In [0]:
def check_paths(save_model_dir, checkpoint_model_dir = None):
    """Make sure the output directories exist, creating them when missing.

    Args:
        save_model_dir: Directory to create (including missing parents).
        checkpoint_model_dir: Optional second directory to create; skipped
            when None.

    On an ``OSError`` (for example a permission problem, or the path already
    existing as a regular file) the error is printed and the process exits
    with status 1 via ``sys.exit``.
    """
    try:
        # exist_ok=True avoids the race between a separate existence check
        # and the directory creation.
        os.makedirs(save_model_dir, exist_ok=True)
        if checkpoint_model_dir is not None:
            os.makedirs(checkpoint_model_dir, exist_ok=True)
    except OSError as e:
        print(e)
        sys.exit(1)


def train(dataset, style_image, save_model_dir, cuda,
          seed=42, image_size=256, style_size=None, batch_size=4, lr=1e-3,
          epochs=2, content_weight=1e5, style_weight=1e10, log_interval=500,
          checkpoint_model_dir=None, checkpoint_interval=2000):
    """Train a ``TransformerNet`` against one style image and save its weights.

    The loss is the sum of a content term (MSE between the ``relu2_2`` VGG
    features of the network output and of the input batch) and a style term
    (MSE between Gram matrices of the output's VGG features and those of the
    style image), each multiplied by its weight.

    Args:
        dataset: Root directory passed to ``datasets.ImageFolder``.
        style_image: Path of the style image, opened with ``load_image``.
        save_model_dir: Directory in which the final ``.model`` file is written.
        cuda: Truthy to run on the "cuda" device, falsy for "cpu".
        seed: Seed for both NumPy and torch random generators.
        image_size: Training images are resized and center-cropped to this size.
        style_size: Optional square size for the style image (None keeps it as is).
        batch_size: Number of images per training batch.
        lr: Adam learning rate.
        epochs: Number of passes over the dataset.
        content_weight: Multiplier of the content loss.
        style_weight: Multiplier of the style loss.
        log_interval: Print running average losses every this many batches.
        checkpoint_model_dir: Directory for intermediate checkpoints; None
            (default) disables checkpointing. The directory is not created here.
        checkpoint_interval: Save a checkpoint every this many batches.

    Returns:
        None. The trained state dict is saved with ``torch.save`` under
        ``save_model_dir`` using a ``time.ctime()``-based file name.
    """
    device = torch.device("cuda" if cuda else "cpu")

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Training images are kept in the 0-255 range; normalize_batch rescales them.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = load_image(style_image, size=style_size)
    style = style_transform(style)
    # One copy of the style image per batch item so the Gram matrices line up.
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style]

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = normalize_batch(y)
            x = normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                # The last batch may be smaller than batch_size, hence the slice.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
                # Save from the CPU in eval mode, then restore the training setup.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = str(time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)


def stylize(content_image, output_image, model, cuda, content_scale = None):
    """Apply a trained ``TransformerNet`` to one image file and save the result.

    Args:
        content_image: Path of the input image, opened with ``load_image``.
        output_image: Path the stylized image is written to via ``save_image``.
        model: Path of a saved state dict (``.model`` / ``.pth``) for
            ``TransformerNet``.
        cuda: Truthy to run on the "cuda" device, falsy for "cpu".
        content_scale: Optional factor by which the input image's dimensions
            are divided before stylizing.

    Note:
        The checkpoint is read from disk on every call.
    """
    device = torch.device("cuda" if cuda else "cpu")
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_tensor = content_transform(load_image(content_image, scale=content_scale))
    # Add the batch dimension the network expects.
    content_tensor = content_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only machines; the model is moved to `device` afterwards.
        state_dict = torch.load(model, map_location="cpu")
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        # Inference only: switch to eval mode, as train() does before saving.
        style_model.to(device).eval()
        output = style_model(content_tensor).cpu()

    save_image(output_image, output[0])


# def stylize_onnx_caffe2(content_image, model, cuda):
#     """
#     Read ONNX model and run it using Caffe2
#     """

#     #assert not args.export_onnx
#     !pip install onnx
#     !pip install caffe2.python.onnx
#     import onnx
#     import caffe2.python.onnx.backend as backend

#     model = onnx.load(model)

#     prepared_backend = backend.prepare(model, device='CUDA' if cuda else 'CPU')
#     inp = {model.graph.input[0].name: content_image.numpy()}
#     c2_out = prepared_backend.run(inp)[0]

#     return torch.from_numpy(c2_out)


# Train Custom Transfer Model

In [0]:
# Inputs for one training run: image-folder root, style image and output directory.
dataset = 'drive/My Drive/TEST/val2017'        # COCO 2017 dataset mounted on My Google Drive
style_image = 'picasso.jpg'
save_model_dir = 'Model'
cuda = 1  # truthy -> train() selects the "cuda" device

# Create the output directory first, then train and save the model into it.
check_paths(save_model_dir)
train(dataset=dataset, style_image=style_image, save_model_dir=save_model_dir, cuda=cuda)

# Stylize Image or Video
## Model Configuration

In [0]:
# Saved TransformerNet weights used by the stylize cells below.
model = 'Composition-VII.model'      #support .model, .pth
# Drop any directory part and only the final extension, so model paths such as
# 'Model/my.style.pth' still give a usable style name.
style_name = os.path.splitext(os.path.basename(model))[0]
cuda = 1  # truthy -> stylize() selects the "cuda" device

## Image Test

In [0]:
Content_IMG = 'dancing.jpg'
# splitext removes only the final extension, so paths like './dancing.jpg' or
# 'photo.v2.jpg' keep their full stem (split('.')[0] truncated them).
Output_IMG = '{}_styled.jpg'.format(os.path.splitext(Content_IMG)[0])
stylize(Content_IMG, Output_IMG, model, cuda)

## Video Style Transfer

### Style Transfer

In [13]:
import cv2

video_path = 'Clip2.mp4'
# splitext removes only the final extension, so './Clip2.mp4' or dotted names work.
video_name = os.path.splitext(video_path)[0]
original_folder = '{}_{}/Original'.format(video_name,style_name)
styled_folder = '{}_{}/Styled'.format(video_name,style_name)
check_paths(original_folder)
check_paths(styled_folder)

cap = cv2.VideoCapture(video_path)
count = 0
new_frames = []  # styled frames, in order, consumed by the "Combine to video" cell
try:
    while True:
        ok, frame = cap.read()
        if not ok or frame is None:
            break
        frame_file = 'Frame_{}.jpg'.format(count)
        original_path = os.path.join(original_folder, frame_file)
        styled_path = os.path.join(styled_folder, frame_file)

        # stylize() works on files, so each frame is written to disk, styled,
        # and the styled result is read back.
        cv2.imwrite(original_path, frame)
        stylize(original_path, styled_path, model, cuda)
        new_frames.append(cv2.imread(styled_path))

        count += 1
        if count%10 == 0:
            print('Frame:{}'.format(count))
finally:
    # Release the capture even if styling a frame raises.
    cap.release()
print(count)



Frame:10
Frame:20
Frame:30
Frame:40
Frame:50
Frame:60
Frame:70
Frame:80
Frame:90
Frame:100
Frame:110
Frame:120
Frame:130
Frame:140
Frame:150
Frame:160
Frame:170
Frame:180
Frame:190
Frame:200
Frame:210
Frame:220
Frame:230
Frame:240
Frame:250
Frame:260
Frame:270
Frame:280
Frame:290
Frame:300
Frame:310
Frame:320
Frame:330
Frame:340
Frame:350
Frame:360
Frame:370
Frame:380
Frame:390
Frame:400
Frame:410
414


### Combine to video

In [0]:
fps = 30  # NOTE(review): hard-coded; presumably should match the source clip's frame rate - confirm
if not new_frames:
    raise RuntimeError('No styled frames to combine; run the style-transfer cell first.')
shape = new_frames[0].shape
resolution = (shape[1],shape[0])  # frameSize given to the writer: (shape[1], shape[0])
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
video_writer = cv2.VideoWriter(filename='{}_{}/styledVideo.mp4'.format(video_name, style_name),fourcc=fourcc, fps=fps,frameSize=resolution)
try:
    # Fail loudly instead of silently producing no video when the writer cannot open.
    if not video_writer.isOpened():
        raise RuntimeError('Could not open the video writer for {}_{}/styledVideo.mp4'.format(video_name, style_name))
    for frame in new_frames:
        video_writer.write(frame)
finally:
    # Always finalize the file, even if writing a frame raises.
    video_writer.release()


### Resolution

In [0]:
# Display the frameSize tuple (shape[1], shape[0]) that was passed to the video writer.
resolution

## Zip Result Files to Download
(Need to manually modify the name every time)

In [0]:
import shutil

# Archive the result folder written by the video cells. The name is derived from
# video_name / style_name, so it always matches the folder that was actually
# created and no longer has to be edited by hand.
result_dir = '{}_{}'.format(video_name, style_name)
shutil.make_archive(result_dir, 'zip', root_dir='.', base_dir=result_dir)
print('Done')