In [1]:
import torch
import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from torchvision.utils import save_image
import os
from tqdm import tqdm
import numpy as np

In [2]:
# Reference: https://towardsdatascience.com/implementing-neural-style-transfer-using-pytorch-fd8d43fb7bfa

In [3]:
# Side length in pixels: image_loader resizes every input image to
# (image_size, image_size) before converting it to a tensor.
image_size = 64

In [4]:
# Feature extractor built on torchvision's pretrained VGG-19.
# Currently taps indices 0, 5 and 10 of `vgg19().features`
# (the original five-layer choice was 0, 5, 10, 19, 28).
# Each element of the returned list corresponds to one tapped layer.
class VGG(nn.Module):
    """Frozen, truncated VGG-19 that returns intermediate activations.

    Attributes:
        layers: indices into ``self.model`` whose outputs are collected.
        model:  the leading slice of ``vgg19().features`` needed to reach the
                last tapped index (plus the activation right after it).
    """

    def __init__(self):
        super(VGG, self).__init__()
        # self.layers = [0, 5, 10, 19, 28]
        self.layers = [0, 5, 10]
        backbone = models.vgg19(pretrained=True).features

        # Drop the layers that can never contribute to the output instead of
        # running them on every forward pass.
        # NOTE: torchvision builds VGG with nn.ReLU(inplace=True), so a tensor
        # captured at a conv index is overwritten by the ReLU that follows it.
        # The slice keeps that trailing layer (hence "+ 2") so the tapped
        # features stay numerically identical to running the full network.
        self.model = backbone[: max(self.layers) + 2]

        # Only the input image is optimised; freezing the weights avoids
        # computing and accumulating gradients that nothing ever reads.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the network and collect the tapped activations.

        Args:
            x: input tensor accepted by the first layer of ``self.model``.

        Returns:
            list of tensors, one per index in ``self.layers``, in network order.
        """
        features = []
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i in self.layers:
                features.append(x)

        return features

In [5]:
def calculate_content_loss(generated_features, content_features):
    """Mean squared difference between two feature tensors.

    Args:
        generated_features: activations of the image being optimised.
        content_features:   activations of the content image (broadcastable
                            against ``generated_features``).

    Returns:
        Scalar tensor holding the mean of the element-wise squared difference.
    """
    difference = generated_features - content_features
    squared = difference ** 2
    return torch.mean(squared)

In [6]:
def calculate_style_loss(generated, style):
    """Mean squared difference between the Gram matrices of two feature maps.

    The Gram matrix of each sample is ``F @ F.T`` where ``F`` is the feature map
    flattened to ``(channel, height * width)``. Gram matrices are computed per
    sample, so any batch size is accepted; for a batch of one the result is the
    same as the single-matrix formulation.

    Args:
        generated: 4-D tensor ``(batch, channel, height, width)``.
        style:     tensor with the same number of elements as ``generated``;
                   it is reshaped using ``generated``'s dimensions.

    Returns:
        Scalar tensor: mean over all Gram entries (and samples) of the squared
        difference. The Gram matrices are not normalised by the feature size.
    """
    batch_size, channel, height, width = generated.shape
    spatial = height * width

    # reshape (rather than view) also handles non-contiguous inputs.
    generated_flat = generated.reshape(batch_size, channel, spatial)
    style_flat = style.reshape(batch_size, channel, spatial)

    gram = torch.bmm(generated_flat, generated_flat.transpose(1, 2))
    A = torch.bmm(style_flat, style_flat.transpose(1, 2))

    style_l = torch.mean((gram - A) ** 2)
    return style_l

In [7]:
def calculate_loss(generated_features, content_features, style_features, alpha, beta):
    """Combine per-layer content and style losses into one weighted total.

    Args:
        generated_features: per-layer activations of the image being optimised.
        content_features:   per-layer activations of the content image.
        style_features:     per-layer activations of the style image.
        alpha:              weight applied to the summed content loss.
        beta:               weight applied to the summed style loss.

    Returns:
        ``alpha * sum(content losses) + beta * sum(style losses)``; plain ``0``
        when the feature sequences are empty.
    """
    content_total = 0
    style_total = 0

    # Walk the three feature lists in lock-step, one layer at a time.
    per_layer = zip(generated_features, content_features, style_features)
    for gen_feat, content_feat, style_feat in per_layer:
        content_total = content_total + calculate_content_loss(gen_feat, content_feat)
        style_total = style_total + calculate_style_loss(gen_feat, style_feat)

    return alpha * content_total + beta * style_total

In [8]:
def image_loader(path, device):
    """Load an image file as a ``(1, 3, image_size, image_size)`` float tensor.

    Args:
        path:   path of the image file to read.
        device: torch device the resulting tensor is moved to.

    Returns:
        Float tensor on ``device`` with a leading batch dimension of 1.
    """
    # Image.open is lazy and keeps the file open; the context manager closes
    # it once convert() has produced an in-memory copy. Forcing RGB keeps
    # RGBA / grayscale / palette PNGs from producing a channel count that the
    # 3-channel network input cannot accept.
    with Image.open(path) as raw_image:
        image = raw_image.convert('RGB')

    loader = transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor()])
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

In [9]:
def create_model():
    """Build the VGG feature extractor on the best available device.

    Returns:
        ``(device, model)`` where ``device`` is CUDA when available and CPU
        otherwise, and ``model`` is a ``VGG`` instance moved to that device
        and switched to evaluation mode.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = 'cpu'
    device = torch.device(backend)

    extractor = VGG()
    extractor = extractor.to(device)
    model = extractor.eval()

    return device, model

In [10]:
def perform_style_transfer(
    model,
    device,
    content_img,
    style_img,
    image_num,
    epochs=800,
    alpha=5,
    beta=100,
    optimizer_betas=(0.9, 0.999),
    output_dir='images/t1/wgan_gp_style/',
):
    """Optimise a copy of the content image towards the style image and save it.

    The generated image starts as a clone of the content image; its pixels are
    updated with Adam to minimise ``calculate_loss``. The result is written to
    ``<output_dir>/<image_num>.png`` and the PSNR of the style image against
    the content image ("Old") and against the generated image ("New") is
    reported through ``tqdm.write``.

    Args:
        model:           feature extractor returning a list of activations.
        device:          torch device the images are loaded onto.
        content_img:     path of the content image.
        style_img:       path of the style image.
        image_num:       identifier used as the output file name.
        epochs:          number of optimisation steps.
        alpha:           content-loss weight.
        beta:            style-loss weight.
        optimizer_betas: ``betas`` passed to Adam.
        output_dir:      directory the result is saved in; created if missing.

    Returns:
        None.
    """
    style_image = image_loader(style_img, device)
    content_image = image_loader(content_img, device)
    generated_image = content_image.clone().requires_grad_(True)

    # Updates the pixels of the generated image not the model parameter
    optimizer = optim.Adam([generated_image], betas=optimizer_betas)  # (0.5, 0.999)

    # The content and style targets never change during optimisation, so
    # compute them once and without building an autograd graph.
    with torch.no_grad():
        content_features = model(content_image)
        style_features = model(style_image)

    for _ in tqdm(range(1, epochs + 1)):
        generated_features = model(generated_image)

        # iterating over the activation of each layer and calculate the loss and
        # add it to the content and the style loss
        total_loss = calculate_loss(generated_features, content_features, style_features, alpha, beta)

        # optimize the pixel values of the generated image and backpropagate the loss
        optimizer.zero_grad()
        total_loss.backward()  # Backpropagate the total loss
        optimizer.step()  # Update the pixel values of the generated image

    # save_image does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    save_image(generated_image.detach(), os.path.join(output_dir, str(image_num) + '.png'))

    style_array = style_image.cpu().detach().numpy()
    content_array = content_image.cpu().detach().numpy()
    generated_array = generated_image.cpu().detach().numpy()

    # Images come from ToTensor and live in [0, 1]; state the range explicitly
    # instead of relying on skimage's dtype-based inference.
    old_psnr = psnr(style_array, content_array, data_range=1.0)
    new_psnr = psnr(style_array, generated_array, data_range=1.0)

    tqdm.write(f"Old PSNR: {old_psnr:.4f}")
    tqdm.write(f"New PSNR: {new_psnr:.4f}")

In [11]:
# Driver cell: run style transfer for every numbered image pair
# (<real_folder>/<n>.png as content, <fake_folder>/<n>.png as style).
device, model = create_model()

num_epoch = 500
alpha = 5
beta = 100

real_folder = "images/t1/real/"
fake_folder = "images/t1/wgan_gp/"

# Count only PNG files: a raw len(os.listdir(...)) also counts stray entries
# such as `.ipynb_checkpoints` or `.DS_Store`, which would make the loop ask
# for image numbers that do not exist.
samples_num = sum(1 for entry in os.listdir(real_folder) if entry.endswith('.png'))

for img_num in range(0, samples_num):
    content_img_filename = real_folder + str(img_num) + '.png'
    style_img_filename = fake_folder + str(img_num) + '.png'
    perform_style_transfer(
        model,
        device,
        content_img_filename,
        style_img_filename,
        img_num,
        epochs=num_epoch,
        alpha=alpha,
        beta=beta
    )

100%|██████████| 500/500 [00:06<00:00, 77.82it/s] 


Old PSNR: 21.6774
New PSNR: 22.9408


100%|██████████| 500/500 [00:04<00:00, 108.22it/s]


Old PSNR: 19.4812
New PSNR: 19.6188


100%|██████████| 500/500 [00:04<00:00, 110.13it/s]


Old PSNR: 22.5921
New PSNR: 23.7398


100%|██████████| 500/500 [00:04<00:00, 109.74it/s]


Old PSNR: 21.1178
New PSNR: 22.6336


100%|██████████| 500/500 [00:04<00:00, 109.51it/s]


Old PSNR: 22.1213
New PSNR: 23.7125


100%|██████████| 500/500 [00:04<00:00, 111.14it/s]


Old PSNR: 20.3147
New PSNR: 21.0194


100%|██████████| 500/500 [00:04<00:00, 108.57it/s]


Old PSNR: 23.6077
New PSNR: 25.5022


100%|██████████| 500/500 [00:04<00:00, 110.10it/s]


Old PSNR: 21.9920
New PSNR: 23.5079


100%|██████████| 500/500 [00:04<00:00, 109.37it/s]


Old PSNR: 21.9971
New PSNR: 23.2601


100%|██████████| 500/500 [00:04<00:00, 110.94it/s]


Old PSNR: 22.9207
New PSNR: 24.5805


100%|██████████| 500/500 [00:04<00:00, 109.64it/s]


Old PSNR: 20.3381
New PSNR: 21.1830


100%|██████████| 500/500 [00:04<00:00, 109.04it/s]


Old PSNR: 22.4662
New PSNR: 24.3064


100%|██████████| 500/500 [00:04<00:00, 107.77it/s]


Old PSNR: 19.8645
New PSNR: 21.0164


100%|██████████| 500/500 [00:04<00:00, 110.68it/s]


Old PSNR: 20.8610
New PSNR: 21.7753


100%|██████████| 500/500 [00:04<00:00, 108.45it/s]


Old PSNR: 20.8656
New PSNR: 22.3707


100%|██████████| 500/500 [00:04<00:00, 107.40it/s]


Old PSNR: 20.6389
New PSNR: 22.0586


100%|██████████| 500/500 [00:04<00:00, 108.83it/s]


Old PSNR: 22.6854
New PSNR: 24.9601


100%|██████████| 500/500 [00:04<00:00, 110.00it/s]


Old PSNR: 20.6433
New PSNR: 21.4007


100%|██████████| 500/500 [00:04<00:00, 109.26it/s]


Old PSNR: 18.5786
New PSNR: 19.2396


100%|██████████| 500/500 [00:04<00:00, 107.27it/s]


Old PSNR: 21.8347
New PSNR: 22.8740


100%|██████████| 500/500 [00:04<00:00, 107.08it/s]


Old PSNR: 21.7896
New PSNR: 24.2746


100%|██████████| 500/500 [00:04<00:00, 107.69it/s]


Old PSNR: 21.4010
New PSNR: 22.5904


100%|██████████| 500/500 [00:04<00:00, 106.92it/s]


Old PSNR: 21.7639
New PSNR: 23.3194


100%|██████████| 500/500 [00:04<00:00, 105.97it/s]


Old PSNR: 21.1942
New PSNR: 22.6658


100%|██████████| 500/500 [00:04<00:00, 107.79it/s]


Old PSNR: 21.5654
New PSNR: 22.9638


100%|██████████| 500/500 [00:04<00:00, 107.04it/s]


Old PSNR: 21.0902
New PSNR: 22.5459


100%|██████████| 500/500 [00:04<00:00, 108.85it/s]


Old PSNR: 21.7957
New PSNR: 23.1859


100%|██████████| 500/500 [00:04<00:00, 107.74it/s]


Old PSNR: 21.9462
New PSNR: 23.7519


100%|██████████| 500/500 [00:04<00:00, 109.16it/s]


Old PSNR: 21.7369
New PSNR: 22.8827


100%|██████████| 500/500 [00:04<00:00, 106.96it/s]


Old PSNR: 22.4151
New PSNR: 24.3471


100%|██████████| 500/500 [00:04<00:00, 105.82it/s]


Old PSNR: 20.8565
New PSNR: 22.3874


100%|██████████| 500/500 [00:04<00:00, 104.74it/s]


Old PSNR: 21.6583
New PSNR: 22.1882


100%|██████████| 500/500 [00:04<00:00, 108.14it/s]


Old PSNR: 20.6073
New PSNR: 20.5443


100%|██████████| 500/500 [00:04<00:00, 109.00it/s]


Old PSNR: 22.2486
New PSNR: 23.6132


100%|██████████| 500/500 [00:04<00:00, 105.81it/s]


Old PSNR: 21.7227
New PSNR: 22.7770


 39%|███▊      | 193/500 [00:01<00:02, 103.42it/s]