In [1]:
from __future__ import division
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import transforms
import numpy as np
import argparse
from PIL import Image

In [3]:
# Compute device: run on CUDA when torch reports a usable GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def load_image(image_path,transform=None,max_size=None,shape=None):
    """Load an image, optionally resize it, and convert it to a torch tensor.

    Args:
        image_path: File path (or file object) passed to ``PIL.Image.open``.
        transform: Optional callable applied to the PIL image. Its result must
            support ``unsqueeze(0)`` (e.g. a tensor); a leading batch dimension
            is added and the result is moved to ``device``.
        max_size: Optional target length for the longer side. Both sides are
            scaled by the same factor, so the aspect ratio is kept.
        shape: Optional explicit size handed to ``Image.resize``. PIL expects
            ``(width, height)`` order, which is the reverse of a tensor's
            trailing ``(height, width)`` dimensions. Applied after ``max_size``.

    Returns:
        The transformed, batched tensor on ``device`` when ``transform`` is
        given; otherwise the (possibly resized) RGB ``PIL.Image``.
    """
    # Open inside a context manager so the file handle is released right away;
    # convert() returns an independent in-memory copy. Forcing RGB keeps
    # RGBA / grayscale / palette inputs compatible with 3-channel processing.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    if max_size:
        scale = max_size / max(image.size)
        # Truncate like the previous astype(int) did, but never let a side
        # collapse to 0, and hand PIL a plain tuple of Python ints.
        scaled_size = tuple(max(1, int(side * scale)) for side in image.size)
        # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in
        # Pillow >= 10; LANCZOS is the same filter and matches the call below.
        image = image.resize(scaled_size,Image.LANCZOS)

    if shape:
        image = image.resize(tuple(shape),Image.LANCZOS)

    if not transform:
        # A PIL image has no .to(); return it unchanged instead of crashing.
        return image

    return transform(image).unsqueeze(0).to(device)

In [4]:
class VGGNet(nn.Module):
    """Feature extractor built on the convolutional part of a pretrained VGG-19.

    ``forward`` returns the outputs of the layers listed in ``self.select``
    instead of a single final activation.
    """

    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet,self).__init__()
        # Names (stringified indices) of the child layers of ``self.vgg`` whose
        # outputs are collected by ``forward``.
        self.select = ['0','5','10','19','28']
        # Keep only the ``features`` sub-module of VGG-19.
        # NOTE(review): recent torchvision releases deprecate ``pretrained=True``
        # in favour of the ``weights=`` argument; kept as-is so older
        # torchvision versions continue to work.
        self.vgg = models.vgg19(pretrained=True).features

    def forward(self,x):
        """Extract multiple convolutional feature maps.

        Args:
            x: Input passed through the child layers of ``self.vgg`` in order.

        Returns:
            A list with the output of every layer whose name is in
            ``self.select``, in the order the layers are executed.
        """
        features = []
        # Names still to be collected; once this is empty no later layer can
        # contribute, so the remaining layers are not evaluated at all.
        pending = set(self.select)
        # Walk the layers in order, keeping the outputs of the selected ones.
        for name,layer in self.vgg.named_children():
            if not pending:
                break
            x = layer(x)
            if name in pending:
                features.append(x)
                pending.discard(name)
        return features

def main(config):
    # Image preprocessing
    # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # We use the same normalization statistics here.
    transform = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                       std=(0.229, 0.224, 0.225))])
    # Load content and style images
    # Make the style image same size as the content image
    content = load_image(config.content,transform,max_size=config.max_size)
    style = load_image(config.style,transform,shape=[content.size(2),content.size(3)])
    # Initialize a target image with the content image
    target = content.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target],lr=config.lr,betas=[0.5,0.999])
    vgg = VGGNet().to(device).eval()
    for step in range(config.total_step):
        # Extract multiple(5) conv feature vectors
        target_features = vgg(target)
        content_features = vgg(content)
        style_features = vgg(style)
        
        style_loss = 0
        content_loss = 0
        for f1,f2,f3 in zip(target_features,content_features,style_features):
            # Compute content loss with target and content images
            content_loss += torch.mean((f1-f2)**2)

SyntaxError: unexpected EOF while parsing (<ipython-input-4-4ebd9a1976fe>, line 13)