In [None]:
# Download the COCO train2014 image archive into Google Drive (Colab only).
from google.colab import drive
# Mount Google Drive so the project folders below are reachable from the notebook.
drive.mount('/content/drive')
# Work inside the project's train folder; the archive is fetched and unpacked here.
%cd  /content/drive/'My Drive'/CV_project/train

# Fetch the train2014 zip from the COCO image server.
!wget http://images.cocodataset.org/zips/train2014.zip

# Unpack the archive into the current directory.
!unzip train2014.zip

In [None]:
# Return to the project root so the relative paths configured later
# (./style_image, ./models, ./output, ...) resolve against CV_project.
%cd  /content/drive/'My Drive'/CV_project

Network:

In [None]:
import torch
import torch.nn as nn
from torch.nn.modules.instancenorm import InstanceNorm2d

class TransformerNet(nn.Module):
    """Image transformation network built from plain residual layers.

    Pipeline: a three-layer convolutional encoder (3 -> 32 -> 64 -> 128
    channels, the last two convolutions using stride 2), five
    ``ResidualLayer`` modules on 128 channels, and a decoder made of two
    ``UpsampleConvLayer`` stages with upsample factor 2 (each followed by an
    affine instance norm) plus a final un-normalized 9x9 convolution that
    produces 3 output channels.
    """

    def __init__(self):
        super().__init__()

        # Encoder: each ConvLayer is followed by a ReLU.
        encoder_stages = [
            ConvLayer(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
            nn.ReLU(),
        ]
        self.ConvBlock = nn.Sequential(*encoder_stages)

        # Five identical residual layers on the 128-channel feature map.
        self.ResidualBlock = nn.Sequential(
            *[ResidualLayer(128, 3) for _ in range(5)]
        )

        # Decoder: upsample-then-convolve twice, then map back to 3 channels.
        decoder_stages = [
            UpsampleConvLayer(128, 64, 3, 1, 2),
            nn.InstanceNorm2d(64, affine=True),
            UpsampleConvLayer(64, 32, 3, 1, 2),
            nn.InstanceNorm2d(32, affine=True),
            ConvLayer(32, 3, 9, 1, norm='None'),
        ]
        self.UpsampleBlock = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Run ``x`` through encoder, residual stack and decoder, in that order."""
        encoded = self.ConvBlock(x)
        refined = self.ResidualBlock(encoded)
        return self.UpsampleBlock(refined)

class TransformerResNextNetwork(nn.Module):
    """Feed-forward transformation network using ``ResNextLayer`` bottlenecks.

    Pipeline: the same three-layer convolutional encoder as the other
    transformer networks in this file (3 -> 32 -> 64 -> 128 channels), five
    ``ResNextLayer`` modules with channel plan ``[64, 64, 128]``, and a
    decoder of two stride-2 ``DeconvLayer`` stages followed by an
    un-normalized 9x9 convolution back to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Encoder: each ConvLayer is followed by a ReLU.
        encoder_stages = [
            ConvLayer(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
            nn.ReLU(),
        ]
        self.ConvBlock = nn.Sequential(*encoder_stages)

        # Five bottleneck layers; every layer maps 128 channels back to 128.
        self.ResidualBlock = nn.Sequential(
            *[ResNextLayer(128, [64, 64, 128], kernel_size=3) for _ in range(5)]
        )

        # Decoder: two transposed convolutions, then the output convolution.
        decoder_stages = [
            DeconvLayer(128, 64, 3, 2, 1),
            nn.ReLU(),
            DeconvLayer(64, 32, 3, 2, 1),
            nn.ReLU(),
            ConvLayer(32, 3, 9, 1, norm='None'),
        ]
        self.DeconvBlock = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Run ``x`` through encoder, bottleneck stack and decoder, in that order."""
        encoded = self.ConvBlock(x)
        refined = self.ResidualBlock(encoded)
        return self.DeconvBlock(refined)

class TransformerNetworkDenseNet(nn.Module):
    """Feed-forward transformation network with a dense block in the middle.

    Pipeline: a bias-free convolutional encoder (``ConvLayerNB``,
    3 -> 32 -> 64 -> 128 channels), a dense block that first squeezes to 64
    channels with a 1x1 ``NormReluConv`` and then applies four
    ``DenseLayerBottleNeck`` layers with growth rate 16 (64 -> 80 -> 96 ->
    112 -> 128 channels), and a decoder of two stride-2 ``DeconvLayer``
    stages followed by an un-normalized 9x9 convolution back to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Encoder: each bias-free conv layer is followed by a ReLU.
        encoder_stages = [
            ConvLayerNB(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayerNB(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayerNB(64, 128, 3, 2),
            nn.ReLU(),
        ]
        self.ConvBlock = nn.Sequential(*encoder_stages)

        # Dense block: every bottleneck concatenates 16 new channels onto its input,
        # so the i-th bottleneck receives 64 + 16 * i channels.
        growth_rate = 16
        self.DenseBlock = nn.Sequential(
            NormReluConv(128, 64, 1, 1),
            *[DenseLayerBottleNeck(64 + growth_rate * i, growth_rate) for i in range(4)]
        )

        # Decoder: two transposed convolutions, then the output convolution.
        decoder_stages = [
            DeconvLayer(128, 64, 3, 2, 1),
            nn.ReLU(),
            DeconvLayer(64, 32, 3, 2, 1),
            nn.ReLU(),
            ConvLayer(32, 3, 9, 1, norm='None'),
        ]
        self.DeconvBlock = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Run ``x`` through encoder, dense block and decoder, in that order."""
        encoded = self.ConvBlock(x)
        densified = self.DenseBlock(encoded)
        return self.DeconvBlock(densified)

class ConvLayer(nn.Module):
    """Reflection-padded 2-D convolution with an optional normalization layer.

    The input is reflection-padded by ``kernel_size // 2`` on every side,
    convolved, and then normalized according to ``norm``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: side length of the square convolution kernel.
        stride: stride of the convolution.
        norm: ``'instance'`` (default) for affine ``InstanceNorm2d``,
            ``'batch'`` for affine ``BatchNorm2d``, or ``'None'`` / ``None``
            to skip normalization.

    Raises:
        ValueError: if ``norm`` is not one of the supported values. Rejecting
            it here avoids a late ``AttributeError`` in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm='instance'):
        super(ConvLayer, self).__init__()
        # padding layer
        padding_size = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding_size)

        # convolution layer
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        # normalization layer; the Python ``None`` is treated like the 'None' string
        self.norm_type = 'None' if norm is None else norm
        if self.norm_type == 'instance':
            self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif self.norm_type == 'batch':
            self.norm_layer = nn.BatchNorm2d(out_channels, affine=True)
        elif self.norm_type != 'None':
            raise ValueError(
                "norm must be 'instance', 'batch' or 'None', got {!r}".format(norm)
            )

    def forward(self, x):
        """Pad, convolve and (unless normalization is disabled) normalize ``x``."""
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        if self.norm_type == 'None':
            return x
        return self.norm_layer(x)

class ConvLayerNB(nn.Module):
    """Reflection-padded, bias-free 2-D convolution with optional normalization.

    Same structure as a padded conv + norm layer, except that the convolution
    is created with ``bias=False``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: side length of the square convolution kernel.
        stride: stride of the convolution.
        norm: ``'instance'`` (default) for affine ``InstanceNorm2d``,
            ``'batch'`` for affine ``BatchNorm2d``, or ``'None'`` / ``None``
            to skip normalization.

    Raises:
        ValueError: if ``norm`` is not one of the supported values. Rejecting
            it here avoids a late ``AttributeError`` in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm='instance'):
        super(ConvLayerNB, self).__init__()
        # Padding Layers
        padding_size = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding_size)

        # Convolution Layer (no bias term)
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=False)

        # Normalization Layers; the Python ``None`` is treated like the 'None' string
        self.norm_type = 'None' if norm is None else norm
        if self.norm_type == 'instance':
            self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif self.norm_type == 'batch':
            self.norm_layer = nn.BatchNorm2d(out_channels, affine=True)
        elif self.norm_type != 'None':
            raise ValueError(
                "norm must be 'instance', 'batch' or 'None', got {!r}".format(norm)
            )

    def forward(self, x):
        """Pad, convolve and (unless normalization is disabled) normalize ``x``."""
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        if self.norm_type == 'None':
            return x
        return self.norm_layer(x)

class NormLReluConv(nn.Module):
    """Pre-activation block: instance norm -> ReLU -> reflection pad -> bias-free conv.

    Note that, despite the "LRelu" in the name, the activation created here
    is a plain ``nn.ReLU``.

    Args:
        in_channels: channels of the input (also the size of the norm layer).
        out_channels: channels produced by the convolution.
        kernel_size: side length of the square kernel; padding is ``kernel_size // 2``.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        # Normalization is applied to the *input*, hence in_channels.
        self.norm_layer = nn.InstanceNorm2d(in_channels, affine=True)
        self.relu_layer = nn.ReLU()

        # Reflection padding of half the kernel size on every side.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)

        # Convolution without a bias term.
        self.conv_layer = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, bias=False
        )

    def forward(self, x):
        """Apply norm, ReLU, padding and convolution to ``x`` in that order."""
        activated = self.relu_layer(self.norm_layer(x))
        return self.conv_layer(self.reflection_pad(activated))

class NormReluConv(nn.Module):
    """Pre-activation block: (optional) norm -> ReLU -> reflection pad -> conv.

    Args:
        in_channels: channels of the input (also the size of the norm layer).
        out_channels: channels produced by the convolution.
        kernel_size: side length of the square kernel; padding is ``kernel_size // 2``.
        stride: stride of the convolution.
        norm: ``'instance'`` (default) for affine ``InstanceNorm2d``,
            ``'batch'`` for affine ``BatchNorm2d``, or ``'None'`` / ``None``
            to skip normalization.

    Raises:
        ValueError: if ``norm`` is not one of the supported values. Previously
            such a value left ``norm_layer`` undefined and ``forward`` failed
            with ``AttributeError``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm='instance'):
        super(NormReluConv, self).__init__()

        # Normalization Layers; the Python ``None`` is treated like the 'None' string
        self.norm_type = 'None' if norm is None else norm
        if self.norm_type == 'instance':
            self.norm_layer = nn.InstanceNorm2d(in_channels, affine=True)
        elif self.norm_type == 'batch':
            self.norm_layer = nn.BatchNorm2d(in_channels, affine=True)
        elif self.norm_type != 'None':
            raise ValueError(
                "norm must be 'instance', 'batch' or 'None', got {!r}".format(norm)
            )

        # ReLU Layer
        self.relu_layer = nn.ReLU()

        # Padding Layers
        padding_size = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding_size)

        # Convolution Layer
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Apply (optional) norm, ReLU, padding and convolution to ``x``."""
        if self.norm_type != 'None':
            x = self.norm_layer(x)
        x = self.relu_layer(x)
        x = self.reflection_pad(x)
        x = self.conv_layer(x)
        return x

class ResidualLayer(nn.Module):
    """Two-convolution residual unit with an identity shortcut.

    Computes ``conv2(relu(conv1(x))) + x``; both convolutions keep the channel
    count and use stride 1.

    Reference: Deep Residual Learning for Image Recognition
    https://arxiv.org/abs/1512.03385

    Args:
        channels: number of input and output channels.
        kernel_size: kernel size of both convolutions.
    """

    def __init__(self, channels=128, kernel_size=3):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1)

    def forward(self, x):
        """Return the residual branch's output added to the unchanged input."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        # Shortcut connection: no activation is applied after the sum.
        return branch + x

class ResNextLayer(nn.Module):
    """Bottleneck residual unit: 1x1 conv -> kxk conv -> 1x1 conv, plus shortcut.

    Inspired by "Aggregated Residual Transformations for Deep Neural Networks"
    (https://arxiv.org/abs/1611.05431), which reports equal or better
    performance with far fewer parameters.

    NOTE: the three convolutions are built with ``ConvLayer`` and no ``groups``
    argument, so this is a plain bottleneck rather than a grouped convolution.

    Args:
        in_ch: number of input channels.
        channels: three output channel counts, one per convolution. The last
            entry must be compatible with ``in_ch`` because the input is added
            to the output.
        kernel_size: kernel size of the middle convolution.
    """

    # The default is a tuple rather than a list so that no mutable object is
    # shared between calls; callers may still pass a list.
    def __init__(self, in_ch=128, channels=(64, 64, 128), kernel_size=3):
        super(ResNextLayer, self).__init__()
        ch1, ch2, ch3 = channels
        self.conv1 = ConvLayer(in_ch, ch1, kernel_size=1, stride=1)
        self.relu1 = nn.ReLU()
        self.conv2 = ConvLayer(ch1, ch2, kernel_size=kernel_size, stride=1)
        self.relu2 = nn.ReLU()
        self.conv3 = ConvLayer(ch2, ch3, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the bottleneck branch's output added to the unchanged input."""
        identity = x
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.conv3(out)
        # Shortcut connection: no activation is applied after the sum.
        out = out + identity
        return out

class UpsampleConvLayer(nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    Upsampling and then convolving gives better results than
    ``ConvTranspose2d``.
    Reference: http://distill.pub/2016/deconv-checkerboard/

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size: side length of the square kernel; padding is ``kernel_size // 2``.
        stride: stride of the convolution.
        upsample: scale factor handed to ``interpolate``; a falsy value
            (the default ``None``) disables upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then pad and convolve it."""
        if self.upsample:
            x = nn.functional.interpolate(
                x, mode='nearest', scale_factor=self.upsample
            )
        return self.conv_layer(self.reflection_pad(x))

class DeconvLayer(nn.Module):
    """Transposed 2-D convolution with an optional normalization layer.

    The transposed convolution uses ``padding = kernel_size // 2``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size: side length of the square kernel.
        stride: stride of the transposed convolution.
        output_padding: extra size added to one side of the output, passed
            through to ``nn.ConvTranspose2d``.
        norm: ``'instance'`` (default) for affine ``InstanceNorm2d``,
            ``'batch'`` for affine ``BatchNorm2d``, or ``'None'`` / ``None``
            to skip normalization.

    Raises:
        ValueError: if ``norm`` is not one of the supported values. Rejecting
            it here avoids a late ``AttributeError`` in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm='instance'):
        super(DeconvLayer, self).__init__()

        # Transposed Convolution
        padding_size = kernel_size // 2
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding_size, output_padding)

        # Normalization Layers; the Python ``None`` is treated like the 'None' string
        self.norm_type = 'None' if norm is None else norm
        if self.norm_type == 'instance':
            self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif self.norm_type == 'batch':
            self.norm_layer = nn.BatchNorm2d(out_channels, affine=True)
        elif self.norm_type != 'None':
            raise ValueError(
                "norm must be 'instance', 'batch' or 'None', got {!r}".format(norm)
            )

    def forward(self, x):
        """Apply the transposed convolution and, unless disabled, the norm layer."""
        x = self.conv_transpose(x)
        if self.norm_type == 'None':
            return x
        return self.norm_layer(x)

class DenseLayerBottleNeck(nn.Module):
    """Dense-block bottleneck layer.

    NORM - RELU - CONV1 -> NORM - RELU - CONV3, whose output is concatenated
    onto the input along the channel dimension.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: growth rate, i.e. the number of channels this layer
            appends. The intermediate 1x1 convolution produces
            ``4 * out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        bottleneck_width = 4 * out_channels
        self.conv1 = NormLReluConv(in_channels, bottleneck_width, 1, 1)
        self.conv3 = NormLReluConv(bottleneck_width, out_channels, 3, 1)

    def forward(self, x):
        """Return ``x`` with the newly computed features appended on dim 1."""
        squeezed = self.conv1(x)
        new_features = self.conv3(squeezed)
        return torch.cat((x, new_features), 1)

VGG:

In [None]:
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models, transforms

class VGG16(nn.Module):
    """Frozen VGG-16 feature extractor exposing four intermediate activations.

    The first 23 modules of torchvision's pretrained ``vgg16().features`` are
    split into four consecutive stages (indices 0-3, 4-8, 9-15 and 16-22).
    Every parameter has ``requires_grad`` set to ``False``, so the network is
    only used to compute features, never trained.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True).features

        # (attribute name, first feature index, one-past-last feature index)
        stage_spans = (
            ('layer1', 0, 4),
            ('layer2', 4, 9),
            ('layer3', 9, 16),
            ('layer4', 16, 23),
        )
        for attr_name, start, stop in stage_spans:
            stage = nn.Sequential()
            # Sub-modules keep their original index as their name.
            for idx in range(start, stop):
                stage.add_module(str(idx), backbone[idx])
            setattr(self, attr_name, stage)

        # Disable gradient history
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Return the output of each stage as a ``VggOutputs`` named tuple.

        The fields are ``relu1_2``, ``relu2_2``, ``relu3_3`` and ``relu4_3``,
        holding the outputs of ``layer1`` .. ``layer4`` respectively.
        """
        stage_outputs = []
        hidden = X
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
            stage_outputs.append(hidden)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return vgg_outputs(*stage_outputs)

Util:

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

def load_image(filename):
    """Open the image at ``filename`` with PIL and return it converted to RGB."""
    return Image.open(filename).convert('RGB')

def show_image(img):
    """Draw ``img`` with matplotlib in a new 10x5 figure."""
    figure_size = (10, 5)
    plt.figure(figsize=figure_size)
    plt.imshow(img)

def save_image(filename, data):
    """Write a channel-first image tensor to ``filename``.

    Args:
        filename: destination path; the format is chosen by PIL from it.
        data: tensor indexed as (channel, height, width) with values on a
            0-255 scale. Values outside that range are clamped.

    The tensor is detached and moved to the CPU first, because
    ``Tensor.numpy()`` refuses tensors that require grad or live on another
    device. The caller's tensor is not modified (``clamp`` is out-of-place).
    """
    img = data.detach().cpu().clamp(0, 255).numpy()
    # PIL expects (height, width, channel) with 8-bit values.
    img = img.transpose(1, 2, 0).astype("uint8")
    img = Image.fromarray(img)
    img.save(filename)

def gram_matrix(tensor):
    """Return the normalized Gram matrix of each feature map in a batch.

    Args:
        tensor: 4-D tensor indexed as (batch, channels, height, width).

    Returns:
        Tensor of shape (batch, channels, channels) holding, per sample, the
        inner products between flattened channels divided by
        ``channels * height * width``.
    """
    b, ch, h, w = tensor.shape
    # reshape (rather than view) also accepts non-contiguous inputs
    x = tensor.reshape(b, ch, h * w)
    x_t = x.transpose(1, 2)
    return torch.bmm(x, x_t) / (ch * h * w)

def normalize_batch(batch):
    """Scale a 0-255 image batch to 0-1 and standardize each channel.

    Args:
        batch: tensor whose channel dimension is third from last and has
            three entries, e.g. (batch, 3, height, width), on a 0-255 scale.

    Returns:
        A new tensor ``(batch / 255 - mean) / std`` using the per-channel
        mean/std constants below (the values commonly used with torchvision's
        pretrained models).

    The input is left untouched: the division is done out-of-place so the
    caller's tensor is not silently rescaled, and leaf tensors that require
    grad are accepted.
    """
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    scaled = batch / 255.0
    return (scaled - mean) / std

Style and Main:

Style Image:
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --cuda 0

Train Model:
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --cuda 1



In [None]:
# Input/output locations, all relative to the current working directory.
# style_image_path: style reference image read by train().
style_image_path = './style_image/mosaic.jpg'
# content_image_path: image that stylize() transforms.
content_image_path = './content_image/amber.jpg'
# model_path: file the trained weights are saved to by train() and loaded from by stylize().
model_path = './models/model_mosaic.pth'
# train_image_path: root folder handed to torchvision's ImageFolder in train()
# (presumably it must contain the images inside at least one sub-folder -- confirm layout).
train_image_path = './coco_dataset'

# output_image_path: file stylize() writes the stylized image to.
output_image_path = './output/mosaic_amber.jpg'

In [None]:
# Number of epochs: how many passes train() makes over its training subset.
num_epochs = 2

In [None]:
# Batch size: used by train() for the DataLoader and for how many times the
# style image is repeated before its features are computed.
batch_size = 4

In [None]:
# Learning rate passed to the Adam optimizer in train().
learning_rate = 1e-3

In [None]:
# Log interval: number of batches (not images) after which train() prints the
# running average losses; it is compared against the 1-based batch index.
log_interval = 500

In [None]:
# Checkpoint interval: intended number of batches between model checkpoints.
# NOTE(review): train() as written in this notebook never reads this value, so
# no intermediate checkpoints are produced -- only the final model is saved.
checkpoint_interval = 2000

In [None]:
# Action performed by main(): "train" runs train(); any other value runs stylize().
command = "train"
#command = "stylize"

In [None]:
# Transformer architecture built by train():
#   "transformer_dense"   -> TransformerNetworkDenseNet
#   "transformer_resnext" -> TransformerResNextNetwork
#   anything else         -> TransformerNet
#network = "transformer_net"
#network = "transformer_resnext"
network = "transformer_dense"

In [None]:
import time
import re

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

def train():
    """Train the configured transformer network on a style image and save it.

    Reads the module-level settings ``train_image_path``, ``style_image_path``,
    ``model_path``, ``network``, ``batch_size``, ``learning_rate``,
    ``num_epochs`` and ``log_interval``.

    Steps:
      1. Load the training images (resized and center-cropped to 256x256,
         scaled to 0-255) and keep at most the first 2000 of them.
      2. Build the transformer selected by ``network`` and an Adam optimizer.
      3. Pre-compute the Gram matrices of the style image's VGG features.
      4. For every batch, minimize ``1e5 * content_loss + 1e10 * style_loss``,
         where the content loss compares ``relu2_2`` features of the input and
         the stylized output and the style loss compares Gram matrices of all
         four VGG outputs.
      5. Save the trained weights (``state_dict``) to ``model_path``.

    NOTE: ``checkpoint_interval`` is not used here; only the final weights are
    written. Everything runs on the default (CPU) device.
    """
    np.random.seed(42)
    torch.manual_seed(42)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(train_image_path, transform)
    # Train on the first 2000 images only. The size is clamped so that a
    # smaller dataset does not produce out-of-range indices. To train on the
    # full dataset, build the DataLoader from train_dataset instead.
    subset_size = min(2000, len(train_dataset))
    train_dataset_2k = torch.utils.data.Subset(train_dataset, range(subset_size))
    train_loader_2k = DataLoader(train_dataset_2k, batch_size=batch_size)

    if network == "transformer_dense":
        transformer = TransformerNetworkDenseNet()
    elif network == "transformer_resnext":
        transformer = TransformerResNextNetwork()
    else:
        transformer = TransformerNet()

    optimizer = Adam(transformer.parameters(), learning_rate)
    mse_loss = torch.nn.MSELoss()

    vgg = VGG16()

    # get style features
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = load_image(style_image_path)
    style = style_transform(style)
    # One copy of the style image per batch element so shapes line up below.
    style = style.repeat(batch_size, 1, 1, 1)

    features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style]

    for e in range(num_epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader_2k):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            y = transformer(x)

            y = normalize_batch(y)
            x = normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = 1e5 * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                # The last batch may be smaller than batch_size, so trim the
                # pre-computed style Gram matrices to match.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= 1e10

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                # Report running per-batch averages for the current epoch.
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset_2k),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

    # save model
    transformer.eval().cpu()
    torch.save(transformer.state_dict(), model_path)

    print("\nDone, trained model saved at", model_path)


def stylize():
    """Stylize ``content_image_path`` with the saved model and write the result.

    Reads the module-level settings ``content_image_path``, ``model_path``,
    ``network`` and ``output_image_path``. The architecture is chosen from
    ``network`` with the same mapping that train() uses, so the saved
    ``state_dict`` matches the model it is loaded into (a fixed
    ``TransformerNet`` cannot load weights trained with another architecture).
    """
    content_image = load_image(content_image_path)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    # Add the batch dimension the network expects.
    content_image = content_image.unsqueeze(0)

    # Same network selection as in train().
    if network == "transformer_dense":
        style_model = TransformerNetworkDenseNet()
    elif network == "transformer_resnext":
        style_model = TransformerResNextNetwork()
    else:
        style_model = TransformerNet()

    with torch.no_grad():
        state_dict = torch.load(model_path)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        # Inference mode for any layers that behave differently while training.
        style_model.eval()
        output = style_model(content_image).cpu()

    save_image(output_image_path, output[0])

def main():
    """Run train() when ``command`` is "train"; otherwise run stylize()."""
    action = train if command == "train" else stylize
    action()


# Run the configured command when this code is executed as the main module.
if __name__ == "__main__":
    main()