In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchsummary import summary

import numpy as np
import matplotlib.pyplot as plt
import cv2

import copy
import os

from fastprogress.fastprogress import progress_bar

### Set up configuration

In [2]:
# --- Optimization schedule ---
OPTIMIZER = 'adam'
iter_cnt = 8000          # total number of optimization iterations
mstone_every = 800       # display an intermediate result every `mstone_every` iterations
lr = 1                   # learning rate: a small LR needs many more iterations; a large LR may not converge well

# --- Loss weights ---
alpha = 1e-7             # weight of the content term
beta = 1e-5              # weight of the style term

# --- Input image files ---
content_img_path = 'images/content_s.png'
style_img_path = 'images/style_01.png'

# --- VGG19 model file ---
VGG19_path = 'models/vgg19-model.pth'

### Image loading, display, saving, and tensor conversion helpers

In [5]:
# Load image file
def load_image(path):
    """Read an image file from disk with OpenCV.

    Args:
        path: Path to the image file.

    Returns:
        The image as a NumPy array in OpenCV's BGR channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file,
            unreadable file, or unsupported format).
    """
    # Images loaded as BGR
    img = cv2.imread(path)
    # cv2.imread does not raise on failure; it returns None, which would
    # otherwise only blow up later inside cvtColor / tensor conversion.
    if img is None:
        raise FileNotFoundError(f"Could not read image file: {path!r}")
    return img

# Show image
def show_image(img, title=None):
    """Display a BGR image inline with matplotlib.

    Args:
        img: Image array in BGR channel order on a 0-255 scale (e.g. from
            ``load_image`` or ``tensor_to_img``).
        title: Optional title drawn above the image.
    """
    # matplotlib expects RGB, OpenCV arrays are BGR
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # imshow() only accepts float [0, 1] or int [0, 255]; scaling to float and
    # clipping also handles values outside 0-255 (tensor_to_img does not clip).
    rgb = (rgb / 255).clip(0, 1)

    _, ax = plt.subplots(figsize=(8, 4))
    ax.imshow(rgb)
    if title is not None:
        ax.set_title(title)
    ax.axis('off')  # pixel-index ticks carry no information for a photo
    plt.show()

# Save Image
def save_image(img, name, out_dir='images/out'):
    """Write an image as ``<out_dir>/saved_<name>.png``.

    Args:
        img: Image array in BGR channel order on a 0-255 scale; may be a float
            array (e.g. from ``tensor_to_img``).
        name: File-name suffix; converted with ``str()``.
        out_dir: Destination directory; created if it does not exist.

    Returns:
        The path the image was written to.

    Raises:
        OSError: if OpenCV reports that the file could not be written.
    """
    # Clip to the 8-bit range and round explicitly so the PNG receives
    # well-defined uint8 pixel values.
    img = np.rint(img.clip(0, 255)).astype(np.uint8)

    os.makedirs(out_dir, exist_ok=True)
    # os.path.join keeps the path portable (no hard-coded separators).
    path = os.path.join(out_dir, 'saved_' + str(name) + '.png')

    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(path, img):
        raise OSError(f"Failed to write image to {path!r}")
    return path

# Preprocessing
def img_to_tensor(img, mean=(103.939, 116.779, 123.68)):
    """Convert an (H, W, C) image into a mean-subtracted (1, C, H, W) tensor.

    The channel order is kept as-is (BGR for images from ``load_image``) and
    values stay on the 0-255 scale; only the per-channel ``mean`` is
    subtracted.

    Args:
        img: (H, W, C) array with values on a 0-255 scale.
        mean: Per-channel values to subtract, in the image's channel order.

    Returns:
        A contiguous torch.float32 tensor of shape (1, C, H, W).

    Raises:
        ValueError: if ``img`` is not an (H, W, len(mean)) array.
    """
    arr = np.asarray(img)
    if arr.ndim != 3 or arr.shape[2] != len(mean):
        raise ValueError(
            f"expected an (H, W, {len(mean)}) image, got shape {arr.shape}")

    # HWC -> CHW directly in torch; no PIL round-trip and no /255 then *255
    # rescaling, so integer pixel values are preserved exactly.
    tensor = torch.as_tensor(arr, dtype=torch.float32)
    tensor = tensor.permute(2, 0, 1).contiguous()

    # Normalize: subtract the per-channel mean (broadcast over H and W)
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    tensor = tensor - mean_t

    return tensor.unsqueeze(dim=0)

def tensor_to_img(tensor, mean=(103.939, 116.779, 123.68)):
    """Undo ``img_to_tensor``: add the channel means back and return an HWC array.

    Args:
        tensor: (1, C, H, W) or (C, H, W) float tensor on any device; it may
            require grad (e.g. an image being optimized).
        mean: Per-channel values to add back; should match the values that
            were subtracted when the tensor was created.

    Returns:
        An (H, W, C) NumPy array with the same dtype and channel order as the
        tensor. Values are not clipped.
    """
    # numpy() refuses tensors attached to an autograd graph, so detach first.
    tensor = tensor.detach()

    # Drop only the batch axis; a bare squeeze() would also collapse an H or W
    # of size 1 and break the channel transpose below.
    tensor = tensor.squeeze(0)

    # Inverse norm: add the per-channel mean back (broadcast over H and W)
    mean_t = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    img = (tensor + mean_t).cpu().numpy()

    # CHW -> HWC
    return img.transpose(1, 2, 0)

In [None]:
# Read the content and style images (BGR arrays) from the configured paths
content_img = load_image(content_img_path)
style_img = load_image(style_img_path)

# Preview the inputs: content first, then style
for preview in (content_img, style_img):
    show_image(preview)