In [1]:
""" Downloading Dataset"""
!wget http://images.cocodataset.org/zips/val2017.zip
!mkdir './Dataset_new'
!unzip -q ./val2017.zip -d './Dataset_new'

--2023-11-18 14:53:03--  http://images.cocodataset.org/zips/val2017.zip

Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.114.241, 52.216.51.89, 52.217.122.17, ...

Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.114.241|:80... connected.

HTTP request sent, awaiting response... 200 OK

Length: 815585330 (778M) [application/zip]

Saving to: ‘val2017.zip’






2023-11-18 14:53:54 (15.7 MB/s) - ‘val2017.zip’ saved [815585330/815585330]




In [3]:
""" Downloading content and style images """
!mkdir ./content
!mkdir ./style
!wget -q https://github.com/myelinfoundry-2019/challenge/raw/master/japanese_garden.jpg -P './content'
!wget -q https://github.com/myelinfoundry-2019/challenge/raw/master/picasso_selfportrait.jpg -P './style'

In [4]:

import torch
from torch.autograd import Variable
from collections import namedtuple
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import os
import shutil
from sklearn.model_selection import train_test_split
import sys
import random
from PIL import Image
import glob
from torch.optim import Adam
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import cv2
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only influences subprocesses started after this point; the running
    # interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly different
    # (non-deterministic) algorithms per run, defeating `deterministic` above.
    torch.backends.cudnn.benchmark = False

# Fix every RNG seed before anything stochastic runs.
seed_everything(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (R, G, B) normalisation statistics: the standard ImageNet values
# used by torchvision's pretrained models (VGG16 below).
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

In [6]:
# Split the COCO val2017 images 80 / 10 / 10 into train / test / validation.
# Each split gets a single dummy class sub-directory (e.g. train/train) because
# torchvision's ImageFolder (used by fast_trainer) expects <root>/<class>/<image>.
dataset_root = '/content/Dataset_new/val2017'
output_root = '/content/dataset_new'

train_dir = os.path.join(output_root, 'train', 'train')
test_dir = os.path.join(output_root, 'test', 'test')
validation_dir = os.path.join(output_root, 'validation', 'validation')

for split_dir in (train_dir, test_dir, validation_dir):
    os.makedirs(split_dir, exist_ok=True)  # also creates the parent split directory

# os.listdir order is filesystem-dependent; sort so the seeded split below is
# reproducible across machines. Only regular files are split.
all_files = sorted(
    name for name in os.listdir(dataset_root)
    if os.path.isfile(os.path.join(dataset_root, name))
)

train_files, test_validation_files = train_test_split(all_files, test_size=0.2, random_state=42)
test_files, validation_files = train_test_split(test_validation_files, test_size=0.5, random_state=42)

for split_files, split_dir in (
    (train_files, train_dir),
    (test_files, test_dir),
    (validation_files, validation_dir),
):
    for file in split_files:
        shutil.copy(os.path.join(dataset_root, file), os.path.join(split_dir, file))


## Models Used
We use the following two models:
1. **VGG16**: a pre-trained network used as a frozen feature extractor to compute the content and style (perceptual) losses.
2. **TransformerNet**: the main model, an encoder-decoder network that learns to convert any input image into one specific style.

In [7]:
""" Pretrained VGG16 Model """
class VGG16(torch.nn.Module):
    """VGG16 feature extractor used for the perceptual (content/style) losses.

    The pretrained `features` stack is split into four consecutive slices
    (layers [0:4], [4:9], [9:16], [16:23]); `forward` returns the output of
    each slice, named relu1_2, relu2_2, relu3_3 and relu4_3.

    Args:
        requires_grad: if False (default), all VGG parameters are frozen.
    """

    # Output container, created once here instead of on every forward call.
    VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])

    def __init__(self, requires_grad=False):
        super(VGG16, self).__init__()
        vgg_pretrained_features = self._load_pretrained_features()
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _load_pretrained_features():
        """Return the `features` module of an ImageNet-pretrained torchvision VGG16."""
        try:
            # torchvision >= 0.13: `pretrained=True` is deprecated in favour of
            # `weights=`; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
            weights = models.VGG16_Weights.IMAGENET1K_V1
        except AttributeError:
            # Older torchvision without the weights enums.
            return models.vgg16(pretrained=True).features
        return models.vgg16(weights=weights).features

    def forward(self, X):
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return self.VggOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)


""" Transformer Net """
class TransformerNet(torch.nn.Module):
    """Image transformation network: encoder -> residual trunk -> decoder.

    The encoder has one 9x9 stride-1 convolution followed by two 3x3 stride-2
    convolutions (3 -> 32 -> 64 -> 128 channels). Five residual blocks follow
    at 128 channels, then the decoder upsamples twice (128 -> 64 -> 32) and
    maps back to 3 channels with a 9x9 convolution. The last layer has no
    normalisation or ReLU, so its output is unbounded.
    """

    def __init__(self):
        super(TransformerNet, self).__init__()
        encoder = [
            ConvBlock(3, 32, kernel_size=9, stride=1),
            ConvBlock(32, 64, kernel_size=3, stride=2),
            ConvBlock(64, 128, kernel_size=3, stride=2),
        ]
        trunk = [ResidualBlock(128) for _ in range(5)]
        decoder = [
            ConvBlock(128, 64, kernel_size=3, upsample=True),
            ConvBlock(64, 32, kernel_size=3, upsample=True),
            ConvBlock(32, 3, kernel_size=9, stride=1, normalize=False, relu=False),
        ]
        # One flat Sequential keeps state_dict keys (model.0 ... model.10)
        # compatible with checkpoints saved/loaded elsewhere in the notebook.
        self.model = nn.Sequential(*encoder, *trunk, *decoder)

    def forward(self, x):
        return self.model(x)


""" Components of Transformer Net """
class ResidualBlock(torch.nn.Module):
    """Two 3x3 ConvBlocks (the second without ReLU) plus an identity skip connection.

    Args:
        channels: number of input and output channels (unchanged by the block).
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        conv_with_relu = ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=True)
        conv_linear = ConvBlock(channels, channels, kernel_size=3, stride=1, normalize=True, relu=False)
        self.block = nn.Sequential(conv_with_relu, conv_linear)

    def forward(self, x):
        residual = self.block(x)
        return residual + x


class ConvBlock(torch.nn.Module):
    """Reflection-padded convolution with optional 2x upsampling, instance norm and ReLU.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: square kernel size; input is reflection-padded by kernel_size // 2.
        stride: convolution stride.
        upsample: if True, upsample the input 2x with F.interpolate (default
            nearest mode) before the convolution.
        normalize: if True, apply InstanceNorm2d(affine=True) after the convolution.
        relu: if True, apply ReLU as the final step.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, upsample=False, normalize=True, relu=True):
        super(ConvBlock, self).__init__()
        self.upsample = upsample
        padding = nn.ReflectionPad2d(kernel_size // 2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.block = nn.Sequential(padding, conv)
        if normalize:
            self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            self.norm = None
        self.relu = relu

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2) if self.upsample else x
        out = self.block(out)
        if self.norm is not None:
            out = self.norm(out)
        return F.relu(out) if self.relu else out

In [8]:
""" Utility Functions """
def gram_matrix(y):
    """Return the size-normalised channel Gram matrices of a batch of feature maps.

    Args:
        y: tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c) whose entry [n, i, j] is the dot product of
        channels i and j of sample n, divided by c * h * w.
    """
    b, c, h, w = y.size()
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. transposed tensors); reshape copies only when it has to.
    features = y.reshape(b, c, h * w)
    return features.bmm(features.transpose(1, 2)) / (c * h * w)

def train_transform(image_size):
    """Training preprocessing pipeline.

    Resizes to a square of side int(1.15 * image_size), takes a random
    image_size crop, converts to a tensor and normalises with mean/std.
    """
    enlarged = int(image_size * 1.15)
    steps = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(steps)

def validation_transform(image_size):
    """Validation preprocessing pipeline.

    Resizes to a square of side int(1.15 * image_size), takes a *centre*
    image_size crop, converts to a tensor and normalises with mean/std.
    A deterministic crop (instead of RandomCrop) keeps validation losses and
    SSIM comparable across epochs and runs.
    """
    enlarged = int(image_size * 1.15)
    return transforms.Compose(
        [
            transforms.Resize((enlarged, enlarged)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )

def style_transform(image_size=None):
    """Style-image preprocessing.

    If image_size is truthy, resize to an (image_size, image_size) square;
    then convert to a tensor and normalise with mean/std.
    """
    steps = []
    if image_size:
        steps.append(transforms.Resize((image_size, image_size)))
    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)

def test_transform(image_size=None):
    """Test-time preprocessing.

    If image_size is truthy, apply transforms.Resize(image_size) (an int
    resizes the shorter side, keeping aspect ratio); then convert to a tensor
    and normalise with mean/std.
    """
    steps = []
    if image_size:
        steps.append(transforms.Resize(image_size))
    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)

def denormalize(tensors):
    """Undo the mean/std normalisation of an image batch, in place.

    Args:
        tensors: image batch indexed as tensors[:, channel]; channels 0-2 are
            rescaled as x * std[c] + mean[c].

    Returns:
        The same tensor object (it is modified in place).
    """
    for channel, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        tensors[:, channel].mul_(channel_std).add_(channel_mean)
    return tensors

def deprocess(image_tensor):
    """Convert the first image of a normalised batch to an (H, W, C) uint8 array.

    Note: denormalize works in place and the x255 scaling here is in place
    too, so `image_tensor` itself is modified.
    """
    first_image = denormalize(image_tensor)[0]
    first_image *= 255
    pixels = first_image.clamp(0, 255).cpu().numpy().astype(np.uint8)
    return np.transpose(pixels, (1, 2, 0))

In [9]:
pip install torchmetrics

Collecting torchmetrics

  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m



Collecting lightning-utilities>=0.8.0 (from torchmetrics)

  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)












Installing collected packages: lightning-utilities, torchmetrics

Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.0


In [10]:
from torchmetrics.image import StructuralSimilarityIndexMeasure
def fast_trainer(style_image,
                 style_name,
                 train_dataset_path,
                 validation_dataset_path,
                 image_size=256,
                 style_size=448,
                 batch_size = 8,
                 lr = 1e-5,
                 epochs = 1,
                 checkpoint_model = None,
                 checkpoint_interval=200,
                 sample_interval=200,
                 lambda_style=10e10,
                 lambda_content=10e5,):
    """Train a TransformerNet to render images in the style of `style_image`.

    After each epoch the model is evaluated on the validation set. At the end,
    the per-batch validation losses are plotted.

    Args:
        style_image: path to the style image.
        style_name: name of the style (not used by the training loop).
        train_dataset_path: ImageFolder root for training images (needs at
            least one class sub-directory).
        validation_dataset_path: ImageFolder root for validation images.
        image_size: crop size of training / validation images.
        style_size: side of the square the style image is resized to.
        batch_size: images per batch.
        lr: Adam learning rate.
        epochs: number of passes over the training set.
        checkpoint_model: optional state_dict path to resume from.
        checkpoint_interval: save ./checkpoints/last_checkpoint.pth every this
            many training batches (falsy disables periodic saves). A checkpoint
            is always written at the end of every epoch.
        sample_interval: currently unused.
        lambda_style: weight of the style (Gram matrix) loss.
        lambda_content: weight of the content (relu2_2 feature) loss.
    """
    checkpoint_dir = "./checkpoints"
    checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pth")
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_dataset = datasets.ImageFolder(train_dataset_path, train_transform(image_size))
    # Shuffle so each epoch sees the training images in a new order.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    validation_dataset = datasets.ImageFolder(validation_dataset_path, validation_transform(image_size))
    validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    vgg = VGG16(requires_grad=False).to(device)

    if checkpoint_model:
        transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))

    optimizer = Adam(transformer.parameters(), lr)
    l2_loss = torch.nn.MSELoss().to(device)

    # convert("RGB"): grayscale / RGBA style images would not match the
    # 3-channel mean/std normalisation.
    style = style_transform(style_size)(Image.open(style_image).convert("RGB"))
    style = style.repeat(batch_size, 1, 1, 1).to(device)
    with torch.no_grad():
        gram_style = [gram_matrix(y) for y in vgg(style)]

    def compute_losses(images_original, images_transformed):
        """Return (content_loss, style_loss, total_loss) for one batch."""
        features_original = vgg(images_original)
        features_transformed = vgg(images_transformed)
        content_loss = lambda_content * l2_loss(features_transformed.relu2_2, features_original.relu2_2)
        style_loss = 0
        for ft_y, gm_s in zip(features_transformed, gram_style):
            # The last batch may be smaller than batch_size: slice the style Grams to match.
            style_loss += l2_loss(gram_matrix(ft_y), gm_s[: images_original.size(0), :, :])
        style_loss *= lambda_style
        return content_loss, style_loss, content_loss + style_loss

    def record(metrics, content_loss, style_loss, total_loss, ssim_value):
        """Append one batch's scalar losses and SSIM to a metrics dict."""
        metrics["content"].append(content_loss.item())
        metrics["style"].append(style_loss.item())
        metrics["total"].append(total_loss.item())
        metrics["ssim"].append(ssim_value)

    ssim = StructuralSimilarityIndexMeasure().to(device)
    train_metrics = {"content": [], "style": [], "total": [], "ssim": []}
    val_metrics = {"content": [], "style": [], "total": [], "ssim": []}
    for epoch in range(epochs):
        transformer.train()
        epoch_metrics = {"content": [], "style": [], "total": [], "ssim": []}
        for batch_i, (images, _) in enumerate(train_dataloader):
            optimizer.zero_grad()

            images_original = images.to(device)
            images_transformed = transformer(images_original)

            content_loss, style_loss, total_loss = compute_losses(images_original, images_transformed)
            total_loss.backward()
            optimizer.step()

            # SSIM is only monitored: detach and convert to float so no
            # autograd graph is kept alive by the metrics lists.
            with torch.no_grad():
                train_r = ssim(images_original, images_transformed.detach()).item()

            record(epoch_metrics, content_loss, style_loss, total_loss, train_r)
            record(train_metrics, content_loss, style_loss, total_loss, train_r)

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [ssim: %.2f] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
                % (
                    epoch + 1, epochs, batch_i + 1, len(train_dataloader), train_r,
                    content_loss.item(), np.mean(epoch_metrics["content"]),
                    style_loss.item(), np.mean(epoch_metrics["style"]),
                    total_loss.item(), np.mean(epoch_metrics["total"]),
                )
            )
            batches_done = epoch * len(train_dataloader) + batch_i + 1
            if checkpoint_interval and batches_done % checkpoint_interval == 0:
                torch.save(transformer.state_dict(), checkpoint_path)

        torch.save(transformer.state_dict(), checkpoint_path)
        sys.stdout.write("\n")

        transformer.eval()
        with torch.no_grad():
            for images, _ in validation_dataloader:
                images_original = images.to(device)
                images_transformed = transformer(images_original)
                val_r = ssim(images_original, images_transformed).item()

                val_content, val_style, val_total = compute_losses(images_original, images_transformed)
                record(val_metrics, val_content, val_style, val_total, val_r)

                sys.stdout.write(
                    "\r[Validation] [ssim: %.2f] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
                    % (
                        val_r, val_content.item(), np.mean(val_metrics["content"]),
                        val_style.item(), np.mean(val_metrics["style"]),
                        val_total.item(), np.mean(val_metrics["total"]),
                    )
                )
        sys.stdout.write("\n")

    print("Training Completed!")
    fig, ax = plt.subplots()
    ax.plot(val_metrics["content"], label="Content Loss")
    ax.plot(val_metrics["style"], label="Style Loss")
    ax.plot(val_metrics["total"], label="Total Loss")
    ax.set(xlabel="Validation batch", ylabel="Loss", title="Validation Loss")
    ax.legend()
    plt.show()

In [None]:
""" Run this to train the model """
fast_trainer(
    style_image='/content/style/picasso_selfportrait.jpg',
    style_name='Picasso_Selfportrait',
    train_dataset_path='/content/dataset_new/train',
    validation_dataset_path='/content/dataset_new/validation',
    epochs=1,
)

In [None]:
def test_check(test_dataset_path,checkpoint_model,image_size):
    """Print the mean SSIM between test images and their stylised versions.

    Args:
        test_dataset_path: directory containing the test images (flat, no
            class sub-directories).
        checkpoint_model: path to a TransformerNet state_dict.
        image_size: accepted for interface compatibility but currently unused;
            images are evaluated at their original resolution.

    Images whose stylised output has a different spatial size than the input
    are skipped and counted. TransformerNet halves the resolution twice and
    doubles it twice, so this happens when a side is not a multiple of 4.
    """
    # NOTE(review): image_size is intentionally not applied. Resizing the shorter
    # side would make most images' longer side a non-multiple of 4, so those
    # images would be skipped.
    transform = test_transform()
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))
    transformer.eval()
    ssim = StructuralSimilarityIndexMeasure().to(device)

    scores = []
    skipped = 0
    # no_grad: evaluation only. Without it every full-size forward pass kept its
    # autograd graph alive.
    with torch.no_grad():
        for file in sorted(os.listdir(test_dataset_path)):
            image_path = os.path.join(test_dataset_path, file)
            if not os.path.isfile(image_path):
                continue
            # convert("RGB"): grayscale images would otherwise break the
            # 3-channel Normalize in test_transform.
            image_tensor = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
            stylized_image = transformer(image_tensor)
            if image_tensor.size() != stylized_image.size():
                skipped += 1
                continue
            # .item() keeps the fractional SSIM (the previous .int() truncated it).
            scores.append(ssim(image_tensor, stylized_image).item())

    if not scores:
        print(f"No comparable images found in {test_dataset_path} ({skipped} skipped).")
        return
    print(np.mean(scores))
    if skipped:
        print(f"({skipped} images skipped: output size differed from input size)")

In [None]:
def test_image(image_path,checkpoint_model,save_path):
    """Stylise a single image with a trained TransformerNet, save it and show it.

    The result is written to <save_path>/results/<checkpoint name>-output.jpg.
    The checkpoint name is the checkpoint file's basename up to its first '.'.

    Args:
        image_path: path of the image to stylise.
        checkpoint_model: path to a TransformerNet state_dict.
        save_path: directory under which the `results` folder is created.
    """
    results_dir = os.path.join(save_path, "results")
    os.makedirs(results_dir, exist_ok=True)

    transform = test_transform()
    transformer = TransformerNet().to(device)
    transformer.load_state_dict(torch.load(checkpoint_model, map_location=device))
    transformer.eval()

    # convert("RGB"): grayscale / RGBA inputs would not match the 3-channel normalisation.
    image_tensor = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        stylized_image = denormalize(transformer(image_tensor)).cpu()

    fn = os.path.basename(checkpoint_model).split('.')[0]
    output_path = os.path.join(results_dir, f"{fn}-output.jpg")
    save_image(stylized_image, output_path)
    print("Image Saved!")
    # PIL loads RGB directly, so no BGR->RGB round-trip through cv2 is needed.
    plt.imshow(Image.open(output_path))
    plt.axis("off")

In [None]:
""" Run this to test the model """
test_check(
    test_dataset_path='/content/dataset_new/test/test/',
    checkpoint_model='/content/checkpoints/last_checkpoint.pth',
    image_size=256,
)

In [None]:
""" Run this to visualize the styled images """
test_image(
    image_path='./content/japanese_garden.jpg',
    checkpoint_model='/content/checkpoints/last_checkpoint.pth',
    save_path='./',
)