In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import imutils
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.models import vgg16, vgg16_bn
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image

from tensorboardX import SummaryWriter

from models import *
from func import *

In [2]:
# --- Run configuration -------------------------------------------------------
device = "cpu"
batch_size = 1

# Per-channel statistics handed to Normalize in both pipelines below.
# NOTE(review): these look like the usual ImageNet mean/std used with
# torchvision's pretrained models — confirm against Vgg16Wrapper.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Content images: resize the shorter edge to 256 (keeps the aspect ratio),
# centre-crop to a 256x256 square, then convert to a normalised tensor.
_content_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
transform = transforms.Compose(_content_steps)

# Style image: this transform does not resize, since the Gram matrices have
# the same shape regardless of the spatial size of the input.
_style_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
style_transform = transforms.Compose(_style_steps)

# Content images are read from the sub-folders of "data" and served shuffled.
dataset = ImageFolder("data", transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [3]:
# Feature extractor used to compute the content and style losses.
# NOTE(review): requires_grad=False presumably freezes the VGG weights so that
# only the transformer network is trained — confirm in models.Vgg16Wrapper.
vgg = Vgg16Wrapper(requires_grad=False)
# The trailing ';' suppresses the notebook's echo of the value returned by .to().
vgg.to(device);

In [4]:
# Load the style reference image. convert("RGB") guarantees the three channels
# that the 3-element mean/std of style_transform's Normalize step require
# (a grayscale, CMYK or RGBA file would otherwise fail there).
img = Image.open("data/starry_night.jpg").convert("RGB")
# Shrink in place so that neither side exceeds 256 px (aspect ratio is kept).
# Image.LANCZOS is the filter formerly also exposed as Image.ANTIALIAS; that
# alias was removed in Pillow 10, so the old spelling raises AttributeError.
img.thumbnail((256, 256), Image.LANCZOS)
style_img = style_transform(img).to(device)
style_img = style_img.unsqueeze(0)  # add a leading batch dimension

# Extract style information from the style image. The targets are constants
# for the whole training run, so no autograd graph is needed for them.
with torch.no_grad():
    _, style_acts_true = vgg(style_img,
                             style_layer_idxs=[5, 12, 22, 32])
    grams_true = [gram_matrix(act) for act in style_acts_true]
    # Tile each target along the first dimension so it lines up with a full
    # training batch of `batch_size` generated images.
    grams_true = [gram.repeat(batch_size, 1, 1) for gram in grams_true]

In [6]:
# Relative weights of the two loss terms.
content_weight = 1e5
style_weight = 1e10

num_epochs = 2
# NOTE(review): an interval of 1 writes a checkpoint and an image on every
# iteration; the original trailing "#200" suggests 200 was the intended value.
ckpt_interval = 1  # 200
ckpt_dir = "checkpoints"
result_dir = "gen_imgs"
Path(ckpt_dir).mkdir(exist_ok=True)
Path(result_dir).mkdir(exist_ok=True)

# Image transformation network being trained.
net = TransformerNet()
net.to(device)

opt = optim.Adam(net.parameters(), lr=1e-3)
mse = nn.MSELoss()

writer = SummaryWriter('logdir')
i = 1  # global iteration counter, shared across epochs

try:
    for epoch in range(num_epochs):
        for X_train, _ in loader:
            opt.zero_grad()

            X_train = X_train.to(device)
            gen_imgs = net(X_train)

            # Extract content & style features from the generated images
            norm_gen_imgs = normalize_images(gen_imgs)
            content_pred, style_acts = vgg(norm_gen_imgs, content_layer_idxs=[12],
                                           style_layer_idxs=[5, 12, 22, 32])
            grams_pred = [gram_matrix(act) for act in style_acts]
            # Extract content features from the content images (the network inputs)
            content_true, _ = vgg(X_train, content_layer_idxs=[12])

            content_loss = 0.
            for c_pred, c_true in zip(content_pred, content_true):
                content_loss += mse(c_pred, c_true)
            content_loss *= content_weight

            style_loss = 0.
            for g_pred, g_true in zip(grams_pred, grams_true):
                # The targets were tiled to `batch_size`; trim them so a smaller
                # final batch still lines up (no-op for full batches).
                # NOTE(review): assumes gram matrices are batch-first — confirm.
                style_loss += mse(g_pred, g_true[:g_pred.size(0)])
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            opt.step()

            # Pass the iteration as global_step; without it every point is
            # logged without a step index and the curves cannot be plotted
            # against training progress.
            writer.add_scalar('content_loss', content_loss.item(), i)
            writer.add_scalar('style_loss', style_loss.item(), i)
            writer.add_scalar('total_loss', total_loss.item(), i)

            if i % ckpt_interval == 0:
                # Save model
                torch.save(net.state_dict(), f"{ckpt_dir}/net_{i}.pth")
                # Save the generated images. The path is handed to save_image
                # directly: JPEG data is binary and must not go through a file
                # handle opened in text mode.
                save_image(gen_imgs, f"{result_dir}/generated_{i}.jpg")

            print(f"Iter {i}: content={content_loss.item():.4f}, style={style_loss.item():.4f}, total={total_loss.item():.4f}")
            i += 1
finally:
    # Flush pending events to disk even if training is interrupted.
    writer.close()

Iter 1: content=5599.5352, style=1228.4926, total=6828.0278
Iter 2: content=5219.8799, style=4426.1997, total=9646.0801
Iter 3: content=5077.5112, style=12195.6289, total=17273.1406
Iter 4: content=3574.9048, style=3431.7566, total=7006.6611
Iter 5: content=3384.4060, style=3488.7773, total=6873.1836
Iter 6: content=4312.9668, style=3386.7637, total=7699.7305
Iter 7: content=4938.4268, style=3640.4358, total=8578.8623
Iter 8: content=4637.3560, style=3175.1948, total=7812.5508
Iter 9: content=3014.4910, style=3417.9067, total=6432.3975
Iter 10: content=3741.4185, style=3246.7822, total=6988.2007
Iter 11: content=5102.7153, style=3090.0229, total=8192.7383
Iter 12: content=5209.4907, style=3342.4006, total=8551.8916
Iter 13: content=3857.4368, style=3192.2092, total=7049.6460
Iter 14: content=4800.1089, style=2972.0105, total=7772.1191
