In [1]:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image and convert it to a torch tensor."""
    image = Image.open(image_path)
    
    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.ANTIALIAS)
    
    if shape:
        image = image.resize(shape, Image.LANCZOS)
    
    if transform:
        image = transform(image).unsqueeze(0)
    
    return image.to(device)

In [11]:
class VGGNet(nn.Module):
    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28'] 
        self.vgg = models.vgg19(pretrained=True).features
        
    def forward(self, x):
        """Extract multiple convolutional feature maps."""
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features
#构建模型中，除了引入经典模型vgg19，还提取了特定层的特征

In [8]:
def main(config):
    
    # Image preprocessing
    # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # We use the same normalization statistics here. 将图片转换为tensor格式，再归一化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                             std=(0.229, 0.224, 0.225))])
    
    # Load content and style images
    # Make the style image same size as the content image 加载内容图片和风格图片
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])
    
    # Initialize a target image with the content image 目标图片由内容图片复制而来，将目标图片的参数放入优化器
    target = content.clone().requires_grad_(True)
    
    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
    vgg = VGGNet().to(device).eval()
    
    for step in range(config.total_step):
        
        # Extract multiple(5) conv feature vectors 提取目标函数、内容函数、风格函数的特定特征
        target_features = vgg(target)
        content_features = vgg(content)
        style_features = vgg(style)
        
        #初始化风格损失和内容损失
        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_features, content_features, style_features):
            # Compute content loss with target and content images 计算内容损失
            content_loss += torch.mean((f1 - f2)**2)

            # Reshape convolutional feature maps 将全连接层的结果重塑为矩阵格式
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            f3 = f3.view(c, h * w)

            # Compute gram matrix 内积得到gram矩阵
            f1 = torch.mm(f1, f1.t())
            f3 = torch.mm(f3, f3.t())

            # Compute style loss with target and style images
            #计算风格损失
            style_loss += torch.mean((f1 - f3)**2) / (c * h * w) 
        
        # Compute total loss, backprop and optimize 计算总损失，进行梯度置零，反向传播，参数更新
        loss = content_loss + config.style_weight * style_loss 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #每经过若干次迭代就输出迭代结果
        if (step+1) % config.log_step == 0:
            print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' 
                   .format(step+1, config.total_step, content_loss.item(), style_loss.item()))
            
        #每经过若干次迭代就输出一个图片结果
        if (step+1) % config.sample_step == 0:
            # Save the generated image
            denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{}.jpg'.format(step+1))


In [9]:
#设置参数
parser = argparse.ArgumentParser()
parser.add_argument('--content', type=str, default='jpg/content.jpg')
parser.add_argument('--style', type=str, default='jpg/style.jpg')
parser.add_argument('--target', type=str, default='jpg/target.jpg')
parser.add_argument('--max_size', type=int, default=400)
parser.add_argument('--total_step', type=int, default=2000)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=100)
parser.add_argument('--style_weight', type=float, default=100)
parser.add_argument('--lr', type=float, default=0.003)
config = parser.parse_known_args()[0]
print(config)

Namespace(content='jpg/content.jpg', log_step=10, lr=0.003, max_size=400, sample_step=100, style='jpg/style.jpg', style_weight=100, target='jpg/target.jpg', total_step=2000)


In [None]:
main(config)

Step [10/2000], Content Loss: 1.4381, Style Loss: 83.6343
Step [20/2000], Content Loss: 2.8423, Style Loss: 68.5622
Step [30/2000], Content Loss: 3.7688, Style Loss: 57.9365
Step [40/2000], Content Loss: 4.4737, Style Loss: 49.6662
Step [50/2000], Content Loss: 5.0545, Style Loss: 43.0076
Step [60/2000], Content Loss: 5.5489, Style Loss: 37.5587
Step [70/2000], Content Loss: 5.9760, Style Loss: 33.0521
Step [80/2000], Content Loss: 6.3534, Style Loss: 29.2967
Step [90/2000], Content Loss: 6.6839, Style Loss: 26.1461
Step [100/2000], Content Loss: 6.9778, Style Loss: 23.4865
Step [110/2000], Content Loss: 7.2411, Style Loss: 21.2264
Step [120/2000], Content Loss: 7.4843, Style Loss: 19.2936
Step [130/2000], Content Loss: 7.7092, Style Loss: 17.6300
Step [140/2000], Content Loss: 7.9169, Style Loss: 16.1899
Step [150/2000], Content Loss: 8.1128, Style Loss: 14.9353
Step [160/2000], Content Loss: 8.2975, Style Loss: 13.8362
Step [170/2000], Content Loss: 8.4696, Style Loss: 12.8668
Step [

Step [1400/2000], Content Loss: 13.1277, Style Loss: 0.9790
Step [1410/2000], Content Loss: 13.1386, Style Loss: 0.9614
Step [1420/2000], Content Loss: 13.1470, Style Loss: 0.9572
Step [1430/2000], Content Loss: 13.1515, Style Loss: 0.9510
Step [1440/2000], Content Loss: 13.1640, Style Loss: 0.9367
Step [1450/2000], Content Loss: 13.1650, Style Loss: 0.9407
Step [1460/2000], Content Loss: 13.1781, Style Loss: 0.9213
Step [1470/2000], Content Loss: 13.1862, Style Loss: 0.9204
Step [1480/2000], Content Loss: 13.1917, Style Loss: 0.9089
Step [1490/2000], Content Loss: 13.2028, Style Loss: 0.9006
Step [1500/2000], Content Loss: 13.2007, Style Loss: 0.9047
Step [1510/2000], Content Loss: 13.2167, Style Loss: 0.8841
Step [1520/2000], Content Loss: 13.2201, Style Loss: 0.8897
Step [1530/2000], Content Loss: 13.2300, Style Loss: 0.8744
Step [1540/2000], Content Loss: 13.2392, Style Loss: 0.8642
Step [1550/2000], Content Loss: 13.2358, Style Loss: 0.8748
Step [1560/2000], Content Loss: 13.2526,