In [6]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import imutils
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.models import vgg16, vgg16_bn
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image

from tensorboardX import SummaryWriter

from models import *
from func import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
batch_size = 4
# Use the GPU when present; every module/tensor below is moved with `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed torch's global RNG so DataLoader shuffling (and later torch-random ops) is reproducible.
torch.manual_seed(0)

# Standard ImageNet normalization statistics, as expected by torchvision's pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(256),      # resize only the shorter edge, keeping the aspect ratio
    transforms.CenterCrop(256),  # then crop to a square image
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# The style image is not resized/cropped by this transform: Gram matrices are (C, C)
# regardless of the spatial size, so the target shapes match anyway.
style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

dataset = ImageFolder("data", transform=transform)
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [3]:
# Frozen VGG16 feature extractor used as the loss network for content and style.
# requires_grad=False freezes its weights, but eval() is still needed so that any
# dropout / batch-norm layers (e.g. if the wrapper is built on vgg16_bn) neither
# behave stochastically nor update running statistics during training.
# NOTE(review): assumes Vgg16Wrapper is an nn.Module (it supports .to()) — confirm in models.py.
vgg = Vgg16Wrapper(requires_grad=False)
vgg.to(device)
vgg.eval();

In [4]:
# Load the style image as 3-channel RGB (matching ImageFolder's default loader, so the
# 3-channel Normalize in style_transform always applies), then shrink it in place so it
# fits within 256x256 while keeping its aspect ratio.
img = Image.open("data/starry_night.jpg").convert("RGB")
# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same resampling filter.
img.thumbnail((256, 256), Image.LANCZOS)
# ToTensor gives (C, H, W); unsqueeze adds the batch dimension -> (1, C, H, W).
style_img = style_transform(img).to(device).unsqueeze(0)

# Extract style information from the style image once; it is a fixed target,
# so no autograd graph is needed.
with torch.no_grad():
    _, style_acts_true = vgg(style_img,
                             style_layer_idxs=[5, 12, 22, 32])
    grams_true = [gram_matrix(act) for act in style_acts_true]
# Tile each target Gram matrix along the batch dimension to line up with a full batch.
grams_true = [gram.repeat(batch_size, 1, 1) for gram in grams_true]

In [5]:
# Shape of the preprocessed style image: (batch=1, channels, height, width).
style_img.size()

torch.Size([1, 3, 203, 256])

In [8]:
# Loss weights. The raw Gram-matrix MSEs are tiny (~5e-3 in the logged run),
# so the style term is scaled up heavily to be comparable to the content term.
content_weight = 1
style_weight = 1e6

# VGG layer indices; the style layers must match those used to build `grams_true`.
content_layer_idxs = [12]
style_layer_idxs = [5, 12, 22, 32]

num_epochs = 200
ckpt_interval = 200  # iterations between checkpoints / sample-image dumps
ckpt_dir = "checkpoints"
result_dir = "gen_imgs"
Path(ckpt_dir).mkdir(exist_ok=True)
Path(result_dir).mkdir(exist_ok=True)

net = TransformerNet()
net.to(device)

opt = optim.Adam(net.parameters(), lr=1e-3)
mse = nn.MSELoss()

writer = SummaryWriter('logdir')
i = 1  # 1-based global iteration counter, used for logging and checkpoint names


def _content_loss(preds, targets):
    """Sum of MSEs between paired content activations (unweighted)."""
    return sum(mse(pred, target) for pred, target in zip(preds, targets))


def _style_loss(grams_pred, grams_target):
    """Sum of MSEs between paired Gram matrices (unweighted).

    The targets were tiled to `batch_size` rows, but the DataLoader keeps the final
    partial batch of each epoch, so each target is trimmed to the prediction's batch size.
    Assumes Gram matrices are (B, C, C) — TODO confirm against gram_matrix.
    """
    return sum(mse(g_pred, g_target[: g_pred.size(0)])
               for g_pred, g_target in zip(grams_pred, grams_target))


try:
    for _epoch in range(num_epochs):
        for X_train, _labels in loader:
            opt.zero_grad()

            X_train = X_train.to(device)
            # Out-of-place scaling: an in-place `/=` on the network output can break
            # autograd if its last layer saved that output for the backward pass.
            # NOTE(review): assumes TransformerNet outputs values roughly in [0, 255].
            gen_imgs = net(X_train) / 255.

            # Content & style features of the generated images.
            norm_gen_imgs = normalize_images(gen_imgs)
            content_pred, style_acts = vgg(norm_gen_imgs, content_layer_idxs=content_layer_idxs,
                                           style_layer_idxs=style_layer_idxs)
            grams_pred = [gram_matrix(act) for act in style_acts]

            # Content targets come from the (already normalized) input images; they are
            # fixed targets, so no graph is built for them.
            with torch.no_grad():
                content_true, _ = vgg(X_train, content_layer_idxs=content_layer_idxs)

            content_loss = content_weight * _content_loss(content_pred, content_true)
            style_loss = style_weight * _style_loss(grams_pred, grams_true)
            total_loss = content_loss + style_loss

            total_loss.backward()
            opt.step()

            losses = {
                'content_loss': content_loss.item(),
                'style_loss': style_loss.item(),
                'total_loss': total_loss.item(),
            }
            # Pass the iteration as global_step so TensorBoard plots a curve over time.
            for tag, value in losses.items():
                writer.add_scalar(tag, value, i)

            if i % ckpt_interval == 0:
                torch.save(net.state_dict(), f"{ckpt_dir}/net_{i}.pth")
                # save_image takes a path (binary-safe); detach so it never touches autograd.
                save_image(gen_imgs.detach(), f"{result_dir}/generated_{i}.jpg")

            print(f"Iter {i}: content={losses['content_loss']:.4f}, "
                  f"style={losses['style_loss']:.4f}, total={losses['total_loss']:.4f}")
            i += 1
finally:
    # Flush pending TensorBoard events even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

Iter 1: content=2475.7485, style=5182.1899, total=7657.9385
Iter 2: content=2378.6646, style=5174.8301, total=7553.4946
Iter 3: content=2278.8516, style=5141.7793, total=7420.6309
Iter 4: content=2173.8816, style=5124.5293, total=7298.4111
Iter 5: content=2091.0352, style=5102.8608, total=7193.8960
Iter 6: content=2043.7766, style=5082.0649, total=7125.8418
Iter 7: content=2000.9401, style=5072.0552, total=7072.9951
Iter 8: content=1977.6740, style=5051.2544, total=7028.9282
Iter 9: content=1964.0236, style=5039.8271, total=7003.8506
Iter 10: content=1935.6667, style=4998.1646, total=6933.8311
Iter 11: content=1919.3312, style=4966.1914, total=6885.5225
Iter 12: content=1898.5607, style=4938.3936, total=6836.9541
Iter 13: content=1878.1234, style=4907.1196, total=6785.2432
Iter 14: content=1867.9537, style=4863.7495, total=6731.7031
Iter 15: content=1857.4147, style=4839.3809, total=6696.7954
Iter 16: content=1853.2291, style=4794.4902, total=6647.7192
Iter 17: content=1823.0789, style

KeyboardInterrupt: 