In [None]:
import time
from io import BytesIO
from urllib.request import urlopen

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from PIL import Image
from torch import optim
from torchvision import transforms, models
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
# pre_process = transforms.Compose([
#     transforms.Resize((512, 712)),
#     transforms.ToTensor(),
    
# ])

In [None]:
class vgg_modified(nn.Module):
    """Frozen VGG-19 feature extractor for style transfer.

    Wraps layers 0..28 of torchvision's pretrained ``vgg19().features`` and,
    on each forward pass, returns the outputs of the layers whose indices are
    listed in ``chosen_features``.
    """

    def __init__(self):
        super(vgg_modified, self).__init__()
        # Layer indices (as strings) of ``self.model`` whose outputs are collected.
        self.chosen_features = ['0', '5', '10', '19', '28']
        # Nothing after layer 28 is read, so the rest of the network is sliced off.
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is for compatibility with the
        # environment this notebook was written for.
        self.model = models.vgg19(pretrained=True).features[:29]

        # The network is used purely as a fixed feature extractor: the
        # optimiser in this notebook only updates the generated image. Freezing
        # the weights stops backward() from computing and storing a gradient
        # for every convolution parameter on each step; gradients still flow
        # through the layers to the input image.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated network.

        Args:
            x: Input tensor accepted by the first layer of ``self.model``.

        Returns:
            list: Outputs of the layers named in ``chosen_features``, in
            network order.
        """
        # NOTE(review): torchvision's VGG ReLUs are believed to be in-place, in
        # which case a captured conv output is rectified by the following
        # layer -- confirm if pre-activation features are required.
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)

        return features

In [None]:
# def load_image(image_name):
#     image = Image.open(image_name)
#     image = pre_process(image).unsqueeze(0)
#     return image.to(device)


In [None]:
# According to the paper, images larger than 500 x 500 take more than an hour to process.

def load_image(img_path, max_size = 612, shape = None):
    '''Load an image from disk or a URL as a batched RGB tensor.

    Args:
        img_path: Local file path, or an ``http://`` / ``https://`` URL.
        max_size: Accepted for backward compatibility only. It has never
            influenced the output size (the resize target is fixed so that
            content and style images share the same dimensions).
        shape: Optional target size handed to ``transforms.Resize``
            (``(height, width)`` or an int). Defaults to ``(712, 512)``,
            the size this function has always produced.

    Returns:
        Tensor of shape ``(1, 3, H, W)`` placed on the notebook-level
        ``device``. No mean/std normalisation is applied.
    '''
    # Local imports keep the function usable even if the notebook's import
    # cell does not provide these names.
    from io import BytesIO
    from urllib.request import urlopen

    source = str(img_path)
    # Match on the URL scheme rather than on any occurrence of "http", so a
    # local path that merely contains that substring is still read from disk.
    if source.startswith(('http://', 'https://')):
        # A timeout prevents a stalled download from hanging the notebook.
        with urlopen(source, timeout=30) as response:
            payload = response.read()
        image = Image.open(BytesIO(payload)).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')

    # Large images slow the optimisation down considerably, hence the resize.
    target_size = (712, 512) if shape is None else shape

    in_transform = transforms.Compose([
                    transforms.Resize(target_size),
                    transforms.ToTensor()])

    # convert('RGB') has already dropped any alpha channel; only the batch
    # dimension needs to be added.
    image = in_transform(image).unsqueeze(0)

    return image.to(device)

In [None]:
# !wget https://www.dropbox.com/s/z1y0fy2r6z6m6py/60.jpg
# !wget https://www.dropbox.com/s/1svdliljyo0a98v/style_image.png
# !wget https://raw.githubusercontent.com/bensains1/fast-style-transfer-master/master/examples/content/chicago.jpg
# !wget https://raw.githubusercontent.com/bensains1/fast-style-transfer-master/master/examples/style/the_shipwreck_of_the_minotaur.jpg
# !wget https://raw.githubusercontent.com/bensains1/fast-style-transfer-master/master/examples/style/udnie.jpg
# !wget https://raw.githubusercontent.com/bensains1/fast-style-transfer-master/master/examples/style/the_scream.jpg

In [None]:
# helper function for un-normalizing an image 
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
    """Turn a normalised image tensor into a displayable NumPy array.

    The tensor is copied to the CPU, its singleton dimensions are dropped,
    the channel axis is moved last, the per-channel std/mean scaling is
    undone and the result is clipped to the [0, 1] range.
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    # Work on a detached CPU copy so the caller's tensor is left untouched.
    channels_first = tensor.to("cpu").clone().detach().numpy().squeeze()
    channels_last = np.transpose(channels_first, (1, 2, 0))

    # Undo the normalisation, then keep values in the displayable range.
    restored = channels_last * channel_std + channel_mean
    return np.clip(restored, 0, 1)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Content and style inputs; load_image places both on `device`.
content_img = load_image('../input/imagedataset/fub2.jpg')
style_img = load_image('../input/imagedataset/hockney.jpg')

# Feature extractor, moved to the same device as the images.
model = vgg_modified().to(device)

In [None]:
# display the images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
# content and style ims side-by-side
ax1.imshow(content_img.cpu().detach().numpy().squeeze().transpose(1,2,0))  # Run without normalize 
ax2.imshow(style_img.cpu().detach().numpy().squeeze().transpose(1,2,0))    # Run without normalize
# ax1.imshow(im_convert(content_img))   # Run with normalize
# ax2.imshow(im_convert(style_img))     # Run with normalize

In [None]:
# Start the optimisation from a copy of the content image and track gradients
# on it, since this tensor is what the optimiser updates.
# (Alternative starting point: random noise, e.g.
#  torch.randn(content_img.shape, device = device, requires_grad=True).)
generated_img = content_img.clone().requires_grad_(True)

In [None]:
# TODO: tune these hyperparameters
total_steps = 15_000      # number of optimisation steps
learning_rate = 1e-3      # Adam step size
alpha = 1                 # weight of the content loss in the total loss
beta = 1e-3               # weight of the style loss in the total loss

# Only the generated image is optimised.
optimizer = optim.Adam([generated_img], lr=learning_rate)

In [None]:
def _gram_matrix(feature):
    """Return the (C, C) Gram matrix of a single-image feature map.

    The feature map is flattened to (C, H*W) with ``view``, which only works
    for a batch size of 1.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


start = time.time()

# The content and style images never change and the optimiser only updates
# `generated_img`, so their features (and the style Gram matrices) are computed
# once, without building an autograd graph, instead of on every step.
with torch.no_grad():
    content_features = model(content_img)
    style_features = model(style_img)
    style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in tqdm(range(total_steps)):
    generated_features = model(generated_img)

    style_loss = content_loss = 0

    for gen_feature, content_feature, style_gram in zip(generated_features, content_features, style_grams):
        # Content loss: mean squared difference between feature maps.
        content_loss += torch.mean((gen_feature - content_feature) ** 2)
        # Style loss: mean squared difference between Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 400 == 0:
        elapsed = time.time() - start
        print("EPOCH: {}/{} \tTotal Loss: {:.2f} \tTime Elapsed: {:.4f}".format(step, total_steps, total_loss.item(), elapsed))
        # Clamp to [0, 1] for display only, so imshow does not have to clip
        # out-of-range float RGB values itself.
        preview = generated_img.detach().cpu().clamp(0, 1)
        plt.imshow(preview.numpy().squeeze().transpose(1,2,0))  # image display
        save_image(generated_img, 'generated.png')
        plt.show()

# The last periodic checkpoint can be up to 399 steps old; save the final result.
save_image(generated_img, 'generated.png')
        

In [None]:
# plt.imshow(generated_img.cpu().detach().numpy().squeeze().transpose(1,2,0))

In [None]:
# Plot the image with the combination of content and style images:
# with torch.no_grad():
#     out_img = post_process(generated_img[0]).permute(1,2,0)
# show(out_img)