导入需要使用的包

In [4]:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import numpy as np


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

定义一些使用到的超参数

In [7]:
# Input image paths (relative to the notebook's working directory).
content = 'png/content.png'
style = 'png/style.png'

# Target length, in pixels, of the content image's longer side when loaded.
max_size = 400

# Optimisation schedule: total steps, print losses every `log_step` steps,
# save an intermediate image every `sample_step` steps.
total_step = 2000
log_step = 10
sample_step = 500

# Weight of the style loss relative to the content loss, and Adam's learning rate.
style_weight = 100
lr = 3e-3

创建一个加载图片的helper函数

In [2]:
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it, and optionally convert it to a tensor.

    Args:
        image_path: path of the image file to open.
        transform: optional callable applied to the PIL image (e.g. a
            torchvision ``transforms.Compose``). Its output gets a leading
            batch dimension via ``unsqueeze(0)``.
        max_size: if given, the image is rescaled (up or down) so that its
            longer side is ``max_size`` pixels, keeping the aspect ratio.
        shape: if given, the image is resized to exactly this size. PIL
            interprets it as ``(width, height)``.

    Returns:
        When ``transform`` is given, the transformed tensor with a batch
        dimension of size 1, moved to ``device``. Otherwise the resized RGB
        PIL image itself (a PIL image has no ``.to`` method, so it cannot be
        moved to a device).
    """
    # Use a context manager so the file handle is closed. convert('RGB') drops
    # alpha/palette channels so a 3-channel Normalize always receives 3 channels.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')

    if max_size:
        scale = max_size / max(image.size)
        # Truncate to int like the original numpy astype(int); never collapse a side to 0.
        new_size = tuple(max(1, int(dim * scale)) for dim in image.size)
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        image = image.resize(new_size, Image.LANCZOS)

    if shape:
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform:
        tensor = transform(image).unsqueeze(0)
        return tensor.to(device)
    return image

定义一个图片预处理pipeline

In [5]:
# PIL image -> float tensor in [0, 1], then per-channel normalisation.
# These are the standard ImageNet mean/std used with torchvision's pretrained
# models (presumably chosen here to match the pretrained VGG weights).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

加载内容图片和风格图片

In [8]:
# Content image: rescaled so its longer side is `max_size` pixels.
content = load_image(content, transform, max_size=max_size)
# Style image: resized to the content image's spatial size. The tensor is laid
# out as (N, C, H, W), but PIL's resize expects (width, height), so pass
# (size(3), size(2)). The old [size(2), size(3)] swapped width and height.
style = load_image(style, transform, shape=[content.size(3), content.size(2)])

定义VGG网络，这里使用`torchvision.models`进行创建，返回第0,5,10,19,28层作为特征。由于`pretrained=True`所以会下载weights文件，国内IP可能会出现如下错误：
~~~bash
Downloading: "https://download.pytorch.......pth" to C:\Users\......./.cache\torch\checkpoints\vgg19-dcbb9e9d.pth

.....

urllib.error.URLError: <urlopen error [WinError 10054] 远程主机强迫关闭了一个现有的连接。>
~~~

这时可以在浏览器中直接打开报错信息中的http链接，手动下载权重文件，并把它放到报错信息所提示的缓存目录（例如`C:\Users\<用户名>\.cache\torch\checkpoints\`）下，然后重新运行该单元格即可。

In [3]:
class VGGNet(nn.Module):
    """VGG-19 feature extractor that returns several intermediate feature maps."""

    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps.

        ``self.select`` holds the child names (indices) of
        ``vgg19().features`` whose outputs ``forward`` collects.
        """
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']
        # NOTE(review): torchvision >= 0.13 prefers `weights=` over the
        # deprecated `pretrained=True`; kept for compatibility with older versions.
        self.vgg = models.vgg19(pretrained=True).features
        # In this notebook only the input image is optimised, never the VGG
        # weights. Freezing them stops every backward pass from computing, and
        # accumulating (they are never zeroed), unused parameter gradients.
        # Gradients still flow through the activations back to the input.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Extract multiple convolutional feature maps.

        Runs ``x`` through every layer of the VGG feature stack in order.

        Returns:
            list of tensors, one per name in ``self.select``, in layer order.
        """
        # NOTE(review): if a selected layer is followed by an in-place ReLU
        # (as in torchvision's VGG), the appended tensor is mutated by it, so
        # the stored maps would effectively be post-activation. Confirm.
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

以内容图片的副本初始化目标图片`target`并开启梯度，`target`就是需要优化的参数

In [9]:
# The image being optimised: starts as a copy of the content image and is
# marked as requiring gradients so the optimiser can update its pixels.
target = torch.clone(content).requires_grad_()

创建优化器（只优化`target`的像素）和VGG特征提取模型

In [10]:
# Adam updates only the pixels of `target`; beta1 is lowered to 0.5.
optimizer = torch.optim.Adam(
    params=[target],
    lr=lr,
    betas=[0.5, 0.999],
)
# Feature extractor on the chosen device, switched to evaluation mode.
vgg = VGGNet()
vgg = vgg.to(device).eval()

训练：迭代优化目标图片，使其内容接近内容图片、风格接近风格图片

In [15]:
def _gram_matrix(feature):
    """Return the (c, c) Gram matrix of a single feature map of shape (1, c, h, w)."""
    _, c, h, w = feature.size()
    flat = feature.view(c, h * w)
    return torch.mm(flat, flat.t())


# Inverse of the ImageNet Normalize in `transform`:
# x = y * std + mean  <=>  Normalize(mean=-mean/std, std=1/std).
# These values are derived exactly instead of hard-coding rounded approximations.
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)
denorm = transforms.Normalize(
    tuple(-m / s for m, s in zip(_imagenet_mean, _imagenet_std)),
    tuple(1 / s for s in _imagenet_std),
)

# The content and style images never change, so extract their features (and
# the style Gram matrices) once instead of on every step. no_grad() avoids
# building autograd graphs for these constant targets.
with torch.no_grad():
    content_features = vgg(content)
    style_features = vgg(style)
    style_grams = [_gram_matrix(f) for f in style_features]

for step in range(total_step):
    # Features of the image being optimised: one tensor per selected VGG layer.
    target_features = vgg(target)

    style_loss = 0
    content_loss = 0
    for f_target, f_content, g_style in zip(target_features, content_features, style_grams):
        # Content loss: mean squared difference of the raw feature maps.
        content_loss += torch.mean((f_target - f_content) ** 2)

        # Style loss: mean squared difference of Gram matrices, scaled by the
        # target feature map's size.
        _, c, h, w = f_target.size()
        g_target = _gram_matrix(f_target)
        style_loss += torch.mean((g_target - g_style) ** 2) / (c * h * w)

    # Weighted total loss; only `target` is updated by the optimiser.
    loss = content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % log_step == 0:
        print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'
              .format(step + 1, total_step, content_loss.item(), style_loss.item()))

    if (step + 1) % sample_step == 0:
        # Save the generated image: detach from autograd, drop the batch
        # dimension, undo the normalisation, and clamp to [0, 1].
        with torch.no_grad():
            img = denorm(target.detach().clone().squeeze(0)).clamp_(0, 1)
        torchvision.utils.save_image(img, 'output-{}.png'.format(step + 1))

Step [10/2000], Content Loss: 7.4953, Style Loss: 23746.6035
Step [20/2000], Content Loss: 15.1861, Style Loss: 21374.9062
Step [30/2000], Content Loss: 22.3211, Style Loss: 18908.2129
Step [40/2000], Content Loss: 28.8174, Style Loss: 16537.4043
Step [50/2000], Content Loss: 34.8026, Style Loss: 14334.9492
Step [60/2000], Content Loss: 40.2554, Style Loss: 12319.2197
Step [70/2000], Content Loss: 45.2621, Style Loss: 10493.8389
Step [80/2000], Content Loss: 49.8830, Style Loss: 8877.5225
Step [90/2000], Content Loss: 54.0880, Style Loss: 7477.4932
Step [100/2000], Content Loss: 57.8993, Style Loss: 6292.5620
Step [110/2000], Content Loss: 61.3147, Style Loss: 5307.6030
Step [120/2000], Content Loss: 64.3333, Style Loss: 4499.2031
Step [130/2000], Content Loss: 66.9763, Style Loss: 3840.0518
Step [140/2000], Content Loss: 69.2697, Style Loss: 3300.7930
Step [150/2000], Content Loss: 71.2308, Style Loss: 2859.2915
Step [160/2000], Content Loss: 72.9064, Style Loss: 2496.5149
Step [170/2

KeyboardInterrupt: 