In [2]:
# VGG, decoder 모델 다운로드 링크
# drive.google.com/drive/folders/12g2eYD4oqd8F269nFlZ0qNEUEUGxcCPe?usp=drive_link

In [1]:
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import net
from function import calc_mean_std

In [2]:
def adaptive_instance_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    adain_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return adain_feat * style_std.expand(size) + style_mean.expand(size)

def test_transform(size, crop):
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def style_transfer(vgg, decoder, content, style, alpha=1.0):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    
    feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)

In [11]:
content_path = './input/content/lenna.jpg'
style_path = './input/style/asheville.jpg'
vgg_path = 'models/vgg.pth'
decoder_path = 'models/decoder.pth'
content_size = 512
style_size = 512
crop = False
output = './output'
alpha = 1.0

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

output_dir = Path(output)
output_dir.mkdir(exist_ok=True, parents=True)

assert (content_path or style_path)
content_path = Path(content_path)
style_path = Path(style_path)

decoder = net.decoder
decoder.eval()

vgg = net.vgg
vgg.eval()

decoder.load_state_dict(torch.load(decoder_path))
vgg.load_state_dict(torch.load(vgg_path))
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder.to(device)

Sequential(
  (0): ReflectionPad2d((1, 1, 1, 1))
  (1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
  (2): ReLU()
  (3): Upsample(scale_factor=2.0, mode='nearest')
  (4): ReflectionPad2d((1, 1, 1, 1))
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): ReflectionPad2d((1, 1, 1, 1))
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (9): ReLU()
  (10): ReflectionPad2d((1, 1, 1, 1))
  (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
  (12): ReLU()
  (13): ReflectionPad2d((1, 1, 1, 1))
  (14): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
  (15): ReLU()
  (16): Upsample(scale_factor=2.0, mode='nearest')
  (17): ReflectionPad2d((1, 1, 1, 1))
  (18): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (19): ReLU()
  (20): ReflectionPad2d((1, 1, 1, 1))
  (21): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
  (22): ReLU()
  (23): Upsample(scale_factor=2.0, mode='nearest')
  (24): ReflectionPad2d((1, 1, 1, 1))
  (25): Conv2d(64

In [13]:
content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

content = content_tf(Image.open(content_path))
style = style_tf(Image.open(style_path))

style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)
with torch.no_grad():
    output = style_transfer(vgg, decoder, content, style, alpha)
output = output.cpu()

output_name = output_dir / '{:s}_stylized_{:s}.jpg'.format(
    content_path.stem, style_path.stem)
save_image(output, str(output_name))