### Real-Time Style Transfer Using ResNet

In [None]:
import os
import sys
import time
import re
import pickle
import torchvision as tv
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.onnx
from matplotlib import pyplot as plt
import utils
from vgg import Vgg16
import nntools as nt
from torch import nn
import torchvision.models as models

#### Transformation Net Class Definition - Baseline Implementation

In [None]:
class transformation_net_batch(torch.nn.Module):
    """Image transformation network that normalises with BatchNorm2d.

    Layout: three convolutions (9x9 stride 1, then two 3x3 stride 2), five
    128-channel residual blocks, two 3x3 convolutions that each upsample
    their input by 2 first, and a final 9x9 convolution back to 3 channels.
    Every convolution except the last is followed by batch norm and ReLU.
    """

    def __init__(self):
        super().__init__()

        # Encoder: channels grow 3 -> 32 -> 64 -> 128; the two stride-2
        # convolutions reduce the spatial resolution.
        self.conv1 = Conv_Layer(in_channels=3, out_channels=32, kernel_size=9, stride=1)
        self.bn1 = torch.nn.BatchNorm2d(num_features=32, affine=True)
        self.conv2 = Conv_Layer(in_channels=32, out_channels=64, kernel_size=3, stride=2)
        self.bn2 = torch.nn.BatchNorm2d(num_features=64, affine=True)
        self.conv3 = Conv_Layer(in_channels=64, out_channels=128, kernel_size=3, stride=2)
        self.bn3 = torch.nn.BatchNorm2d(num_features=128, affine=True)

        # Bottleneck: res1 ... res5, registered under those exact names so
        # that checkpoint keys stay the same.
        for block_no in range(1, 6):
            setattr(self, "res{}".format(block_no), Res_block_Batch(128))

        # Decoder: upsample-then-convolve twice, then project to 3 channels.
        self.deconv1 = Conv_Layer(in_channels=128, out_channels=64, kernel_size=3, stride=1, upsample=2)
        self.bn4 = torch.nn.BatchNorm2d(num_features=64, affine=True)
        self.deconv2 = Conv_Layer(in_channels=64, out_channels=32, kernel_size=3, stride=1, upsample=2)
        self.bn5 = torch.nn.BatchNorm2d(num_features=32, affine=True)
        self.deconv3 = Conv_Layer(in_channels=32, out_channels=3, kernel_size=9, stride=1)

        # Shared activation used after every normalisation layer.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run the network on ``X`` and return the 3-channel output."""
        out = self.conv1(X)
        out = self.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.relu(self.bn3(out))

        for block_no in range(1, 6):
            out = getattr(self, "res{}".format(block_no))(out)

        out = self.deconv1(out)
        out = self.relu(self.bn4(out))
        out = self.deconv2(out)
        out = self.relu(self.bn5(out))
        # No normalisation or activation on the output layer.
        return self.deconv3(out)
    
class Res_block_Batch(torch.nn.Module):
    """Residual block: conv -> BatchNorm -> ReLU -> conv -> BatchNorm, plus skip.

    Both convolutions are 3x3, stride 1, and keep the channel count, so the
    block's output can be added element-wise to its input.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv_Layer(in_channels=channels, out_channels=channels, kernel_size=3, stride=1)
        self.bn1 = torch.nn.BatchNorm2d(num_features=channels, affine=True)
        self.conv2 = Conv_Layer(in_channels=channels, out_channels=channels, kernel_size=3, stride=1)
        self.bn2 = torch.nn.BatchNorm2d(num_features=channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the transformed features added to the identity shortcut ``x``."""
        shortcut = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # No activation after the addition.
        return out + shortcut


#### Transformation Net Class Definition - Instance Normalization Implementation (Replaces BatchNorm)

In [None]:
class transformation_net_instance(torch.nn.Module):
    """Image transformation network that normalises with InstanceNorm2d.

    Layout: three convolutions (9x9 stride 1, then two 3x3 stride 2), five
    128-channel residual blocks, two 3x3 convolutions that each upsample
    their input by 2 first, and a final 9x9 convolution back to 3 channels.
    Every convolution except the last is followed by instance norm and ReLU.
    """

    def __init__(self):
        super().__init__()

        # Encoder: channels grow 3 -> 32 -> 64 -> 128; the two stride-2
        # convolutions reduce the spatial resolution.
        self.conv1 = Conv_Layer(in_channels=3, out_channels=32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(num_features=32, affine=True)
        self.conv2 = Conv_Layer(in_channels=32, out_channels=64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(num_features=64, affine=True)
        self.conv3 = Conv_Layer(in_channels=64, out_channels=128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(num_features=128, affine=True)

        # Bottleneck: res1 ... res5, registered under those exact names so
        # that checkpoint keys stay the same.
        for block_no in range(1, 6):
            setattr(self, "res{}".format(block_no), Res_block_Instance(128))

        # Decoder: upsample-then-convolve twice, then project to 3 channels.
        self.deconv1 = Conv_Layer(in_channels=128, out_channels=64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(num_features=64, affine=True)
        self.deconv2 = Conv_Layer(in_channels=64, out_channels=32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(num_features=32, affine=True)
        self.deconv3 = Conv_Layer(in_channels=32, out_channels=3, kernel_size=9, stride=1)

        # Shared activation used after every normalisation layer.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run the network on ``X`` and return the 3-channel output."""
        out = self.conv1(X)
        out = self.relu(self.in1(out))
        out = self.conv2(out)
        out = self.relu(self.in2(out))
        out = self.conv3(out)
        out = self.relu(self.in3(out))

        for block_no in range(1, 6):
            out = getattr(self, "res{}".format(block_no))(out)

        out = self.deconv1(out)
        out = self.relu(self.in4(out))
        out = self.deconv2(out)
        out = self.relu(self.in5(out))
        # No normalisation or activation on the output layer.
        return self.deconv3(out)

class Res_block_Instance(torch.nn.Module):
    """Residual block: conv -> InstanceNorm -> ReLU -> conv -> InstanceNorm, plus skip.

    Both convolutions are 3x3, stride 1, and keep the channel count, so the
    block's output can be added element-wise to its input.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv_Layer(in_channels=channels, out_channels=channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(num_features=channels, affine=True)
        self.conv2 = Conv_Layer(in_channels=channels, out_channels=channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(num_features=channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the transformed features added to the identity shortcut ``x``."""
        shortcut = x
        out = self.conv1(x)
        out = self.in1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.in2(out)
        # No activation after the addition.
        return out + shortcut

#### Convolutional Layer for Transformation Net

In [None]:
class Conv_Layer(torch.nn.Module):
    """Reflection-padded 2-D convolution with optional nearest-neighbour upsampling.

    The input is padded by ``kernel_size // 2`` on every side using
    reflection, so a stride-1 layer with an odd kernel keeps the spatial size.
    When ``upsample`` is truthy the input is first enlarged by that scale
    factor (nearest-neighbour interpolation) before padding and convolving.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        upsample: Optional scale factor applied to the input before the
            convolution; ``None`` (the default) disables upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Optionally upsample ``x``, then reflection-pad and convolve it."""
        if self.upsample:
            x = torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
        return self.conv(self.reflection_pad(x))

#### Training the Transformation Net based on the Arguments

In [None]:
def train(args):
    """Train a transformation network for one style image and save its weights.

    The network chosen by ``args.Normalization`` is optimised with Adam.
    A frozen, pretrained ResNet-18 acts as the loss network:

    * content loss: MSE between the ``'layer4'`` features of the stylized
      output and of the content batch, scaled by ``args.content_weight``;
    * style loss: sum over all extracted feature layers of the MSE between
      the Gram matrix of the stylized output and that of the style image,
      scaled by ``args.style_weight``.

    Args:
        args: Configuration object providing ``cuda``, ``image_size``,
            ``dataset_path``, ``batch_size``, ``Normalization``, ``lr``,
            ``style_image``, ``style_size``, ``epochs``, ``content_weight``,
            ``style_weight``, ``log_interval``, ``checkpoint_dir``,
            ``checkpoint_interval`` and ``save_model_dir``.

    Side effects:
        * Prints a progress line every ``args.log_interval`` batches.
        * Writes a checkpoint into ``args.checkpoint_dir`` every
          ``args.checkpoint_interval`` batches (skipped when the directory
          is ``None``).
        * Writes the final weights to ``<args.save_model_dir>/model_1.pth``.
        * Pickles the per-batch step indices and running-average total loss
          to ``x_index_model1.pkl`` / ``y_index_model1.pkl`` in the current
          working directory.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    # Content images: square resize/crop, then rescale the tensor by 255.
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    if args.Normalization == 'instance':
        transformer = transformation_net_instance().to(device)
    else:
        transformer = transformation_net_batch().to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse = torch.nn.MSELoss()

    # Loss network. It is a fixed feature extractor, so it must run in eval
    # mode: otherwise its BatchNorm layers normalise with per-batch
    # statistics and keep updating their running averages during training.
    resnet = tv.models.resnet18(pretrained=True).to(device)
    resnet.eval()
    for param in resnet.parameters():
        param.requires_grad_(False)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    # The style image is repeated to a full batch so its Gram matrices can be
    # compared against a batch of outputs (sliced below for a short batch).
    style_img = utils.imread(args.style_image, size=args.style_size)
    style_img = style_transform(style_img)
    style_img = style_img.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = utils.get_resnet_features(resnet, style_img)
    gram_style = [utils.gram_matrix(features_style[y]) for y in features_style]

    plot_index = 0  # global batch counter across epochs (x axis of the loss curve)
    x_index = []
    y_index = []

    for e in range(args.epochs):
        transformer.train()
        total_content_loss = 0.
        total_style_loss = 0.
        count = 0

        for batch_index, (content_batch, _) in enumerate(train_loader):

            plot_index += 1
            n_batch = len(content_batch)
            count += n_batch
            optimizer.zero_grad()

            content_batch = content_batch.to(device)

            y = transformer(content_batch)

            y = utils.normalize_imageset(y)
            content_batch = utils.normalize_imageset(content_batch)

            # features_y must come from the stylized output ``y``: it is the
            # only path through which gradients reach the transformer.
            features_y = utils.get_resnet_features(resnet, y)
            features_x = utils.get_resnet_features(resnet, content_batch)

            content_loss = args.content_weight * mse(features_y['layer4'], features_x['layer4'])

            style_loss = 0.

            for f_y, g_s in zip(features_y, gram_style):
                g_y = utils.gram_matrix(features_y[f_y])
                # The last batch of an epoch may be smaller than batch_size.
                style_loss += mse(g_y, g_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss

            total_loss.backward()
            optimizer.step()

            total_content_loss += content_loss.item()
            total_style_loss += style_loss.item()

            # Running average of the total loss over this epoch's batches.
            avg_total_loss = (total_content_loss + total_style_loss) / (batch_index + 1)

            if (batch_index + 1) % args.log_interval == 0:
                message = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  total_content_loss / (batch_index + 1),
                                  total_style_loss / (batch_index + 1), avg_total_loss
                )
                print(message)

            y_index.append(avg_total_loss)
            x_index.append(plot_index)

            if args.checkpoint_dir is not None and (batch_index + 1) % args.checkpoint_interval == 0:
                # Save from CPU in eval mode, then restore the training state.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_index_" + str(batch_index + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "model_1.pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)

    # Persist the loss curve for later plotting.
    with open('x_index_model1.pkl', 'wb') as f:
        pickle.dump(x_index, f)
    with open('y_index_model1.pkl', 'wb') as f:
        pickle.dump(y_index, f)


#### Testing Function Definition

In [None]:
def test_stylize(args):
    """Stylize a single content image with a trained transformation network.

    Args:
        args: Configuration object providing ``cuda``, ``content_image``,
            ``content_scale``, ``Normalization``, ``model`` (path to a saved
            state dict) and ``output_image`` (destination path).

    Side effects:
        Prints progress messages and writes the stylized image to
        ``args.output_image`` via ``utils.save_image``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.imread(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    # Add the batch dimension the network expects.
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        if args.Normalization == 'instance':
            style_model = transformation_net_instance()
        else:
            style_model = transformation_net_batch()
        style_model.eval()

        # Map tensors to CPU on load so checkpoints written from a GPU can
        # still be opened on a CPU-only host; the model is moved to the
        # target device after its weights are populated.
        state_dict = torch.load(args.model, map_location="cpu")

        print("Found the model!")
        if args.Normalization == 'instance':
            # Adapt checkpoints whose keys differ from this notebook's
            # modules (presumably the upstream fast_neural_style weights --
            # confirm against the checkpoint in use).
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    # Running statistics on InstanceNorm layers are dropped
                    # before load_state_dict.
                    del state_dict[k]
                elif re.search(r'conv\d+\.conv2d\.(weight|bias)$', k):
                    # Checkpoint names the inner convolution ``conv2d``;
                    # Conv_Layer here names it ``conv``.
                    state_dict[k.replace("2d", "")] = state_dict.pop(k)

        style_model.load_state_dict(state_dict)
        style_model.to(device)
        print("Loaded the model")

        output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])

    print("Saved image")

#### Training Arguments

In [None]:
class Args_train:
    """Hard-coded configuration consumed by ``train``."""

    # --- Runtime -----------------------------------------------------------
    cuda = 1                    # truthy -> train on "cuda", falsy -> "cpu"
    Normalization = 'instance'  # 'instance' selects the InstanceNorm net, anything else BatchNorm

    # --- Input data --------------------------------------------------------
    dataset_path = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/train2014_new/'
    style_image = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/images/style-images/starry_night.jpg'
    image_size = 256            # content images are resized and centre-cropped to this size
    style_size = None           # forwarded to utils.imread as ``size``

    # --- Optimisation ------------------------------------------------------
    epochs = 2
    batch_size = 4
    lr = 1e-3
    content_weight = 1e5
    style_weight = 1e10

    # --- Logging and outputs -----------------------------------------------
    log_interval = 500          # print progress every N batches
    checkpoint_interval = 2000  # write a checkpoint every N batches
    checkpoint_dir = '../saved_models'
    save_model_dir = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/saved_models'

#### Training the model with the arguments 

In [None]:
# Instantiate the training configuration and run the full training loop.
# train() prints progress, writes checkpoints and the final model, and
# pickles the loss curve.
args = Args_train()
train(args)

#### Plotting the Total Loss Function 

In [None]:
# Reload the loss curve that train() pickled into the working directory:
# x_axis holds the global batch indices, y_axis the running-average total loss.
# These files are produced locally by train(); do not point this cell at
# pickle files from an untrusted source.
with open('x_index_model1.pkl', 'rb') as f:
    x_axis = pickle.load(f)

with open('y_index_model1.pkl', 'rb') as f:
    y_axis = pickle.load(f)

In [None]:
# Plot the running-average total loss against the training step, using the
# x_axis / y_axis sequences loaded from the pickled loss curve.
plt.plot(x_axis, y_axis)
plt.xlabel('Training Samples ----->')
plt.ylabel('Total Loss----->')

#### Testing Arguments

In [None]:

class Args_eval:
    """Hard-coded configuration consumed by ``test_stylize``."""

    # --- Runtime -----------------------------------------------------------
    cuda = 1                    # truthy -> run on "cuda", falsy -> "cpu"
    Normalization = 'instance'  # 'instance' selects the InstanceNorm net, anything else BatchNorm

    # --- Model weights (state dict read with torch.load) -------------------
    model = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/saved_models/candy.pth'

    # --- Input image -------------------------------------------------------
    content_image = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/images/style-images/rain-princess.jpg'
    content_scale = None        # forwarded to utils.imread as ``scale``

    # --- Output image ------------------------------------------------------
    output_image = '/datasets/home/65/465/ssreekri/examples/fast_neural_style/images/output-images/output_autumn_starry_in.jpg'


#### Testing the Style Transfer

In [None]:
# Instantiate the evaluation configuration and stylize the configured content
# image; test_stylize() writes the result to args.output_image.
args=Args_eval()
test_stylize(args)