In [4]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image

In [20]:
# Per-channel mean / std used to normalize images in `preprocess` and to undo
# that normalization in `postprocess`. These are the ImageNet statistics that
# torchvision's pretrained models expect (blue mean is 0.406, not 0.405).
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
    """Turn an image into a normalized tensor with a leading batch dimension.

    The image is resized to ``image_shape``, converted with ``ToTensor`` and
    normalized channel-wise with the module-level ``rgb_mean`` / ``rgb_std``.

    Args:
        img: Image accepted by ``Resize`` / ``ToTensor`` (presumably a PIL
            image -- confirm against the caller).
        image_shape: Target size, passed unchanged to ``Resize``.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    process = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    # Run the pipeline on the image and add a batch axis so the result can be
    # fed straight into the convolutional layers as a batch of one.
    return process(img).unsqueeze(dim=0)

def postprocess(img_tensor):
    """Convert a normalized image batch back into a PIL image.

    Only the first element of the batch (``img_tensor[0]``) is converted. The
    normalization by ``rgb_mean`` / ``rgb_std`` is inverted, values are
    clamped to ``[0, 1]``, and the result is handed to ``ToPILImage``.

    Args:
        img_tensor: Tensor whose first element along dim 0 is the image to
            convert.

    Returns:
        The PIL image produced by ``ToPILImage``.
    """
    # Normalize(mean=-m/s, std=1/s) computes (x + m/s) * s == x * s + m,
    # i.e. the inverse of Normalize(mean=m, std=s).
    inv_normalize = torchvision.transforms.Normalize(
        mean=-rgb_mean / rgb_std, std=1 / rgb_std)
    to_PIL_image = torchvision.transforms.ToPILImage()
    # Clamp the *tensor* before conversion: PIL images have no clamp().
    # detach()/cpu() make this safe for tensors that track gradients or
    # live on another device.
    restored = inv_normalize(img_tensor[0].detach().cpu()).clamp(0, 1)
    return to_PIL_image(restored)

In [14]:
# The feature extractor must carry trained weights: with pretrained=False the
# network is randomly initialised and its activations are not useful targets
# for the content/style losses. Note that the first run downloads the weights.
pretrained_net = torchvision.models.vgg19(pretrained=True)

In [13]:
# Indices into `pretrained_net.features` whose outputs are collected.
# Style is matched at several depths, content at a single deep layer
# (the two lists were previously assigned the other way round).
style_layer, content_layer = [0, 5, 10, 19, 28], [25]

In [17]:
# Keep only the leading layers of the VGG feature stack, up to and including
# the deepest index that is tapped for content or style features.
num_layers_needed = max(content_layer + style_layer) + 1
net_list = [pretrained_net.features[idx] for idx in range(num_layers_needed)]

net = nn.Sequential(*net_list)

In [18]:
def extract_features(X, content_layer, style_layer):
    """Run ``X`` through the module-level ``net`` and collect activations.

    Each layer of ``net`` is applied in order. After layer ``i`` the current
    activation is appended to the content list when ``i`` is in
    ``content_layer`` and to the style list when ``i`` is in ``style_layer``.

    Args:
        X: Input passed to the first layer of ``net``.
        content_layer: Collection of layer indices to record as content.
        style_layer: Collection of layer indices to record as style.

    Returns:
        A ``(contents, style)`` pair of lists, in layer order.
    """
    content_feats, style_feats = [], []
    for idx, layer in enumerate(net):
        X = layer(X)
        if idx in content_layer:
            content_feats.append(X)
        if idx in style_layer:
            style_feats.append(X)

    return content_feats, style_feats

In [33]:
def get_contents(image_shape):
    """Preprocess the content image and extract its content features.

    Reads the module-level ``content_img``, ``content_layer`` and
    ``style_layer``.

    Args:
        image_shape: Size forwarded to ``preprocess``.

    Returns:
        ``(content_X, contents_Y)``: the output of ``preprocess`` and the
        first element of what ``extract_features`` returns for it.
    """
    content_input = preprocess(content_img, image_shape)
    # extract_features yields (contents, styles); only the contents are kept.
    content_targets = extract_features(content_input, content_layer, style_layer)[0]
    return content_input, content_targets

def get_styles(image_shape):
    """Preprocess the style image and extract its style features.

    Reads the module-level ``style_img``, ``content_layer`` and
    ``style_layer``.

    Args:
        image_shape: Size forwarded to ``preprocess``.

    Returns:
        ``(style_X, styles_Y)``: the output of ``preprocess`` and the second
        element of what ``extract_features`` returns for it.
    """
    style_X = preprocess(style_img, image_shape)
    # extract_features yields (contents, styles); only the styles are kept.
    _, styles_Y = extract_features(style_X, content_layer, style_layer)
    return style_X, styles_Y

In [22]:
def gram(X):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    ``X`` is flattened to ``(X.shape[1], X.shape[2] * X.shape[3])`` and
    multiplied by its own transpose; the product is divided by the total
    number of elements in that flattened view. Because the view drops the
    leading dimension, it only succeeds when ``X.shape[0]`` is 1.

    Args:
        X: Tensor with four dimensions.

    Returns:
        A square tensor of side ``X.shape[1]``.
    """
    channels = X.shape[1]
    positions = X.shape[2] * X.shape[3]
    flat = X.view(channels, positions)
    return (flat @ flat.t()) / (channels * positions)

In [23]:
def style_loss(Y_hat, gram_Y):
    """Mean squared error between ``gram(Y_hat)`` and a precomputed Gram target.

    Args:
        Y_hat: Feature tensor; its Gram matrix is computed here via ``gram``.
        gram_Y: Target Gram matrix, used as-is.

    Returns:
        The value of ``F.mse_loss`` for the two Gram matrices.
    """
    gram_Y_hat = gram(Y_hat)
    return F.mse_loss(gram_Y_hat, gram_Y)

In [24]:
def content_loss(Y_hat, Y):
    """Mean squared error between generated features and target features.

    Args:
        Y_hat: Features of the image being optimized.
        Y: Target features.

    Returns:
        The value of ``F.mse_loss(Y_hat, Y)``.
    """
    mse = F.mse_loss(input=Y_hat, target=Y)
    return mse

In [25]:
def tv_loss(Y_hat):
    """Total-variation style penalty on a 4-D tensor.

    Computes the mean absolute difference between neighbouring entries along
    dim 2 and along dim 3, and returns half of their sum.

    Args:
        Y_hat: Tensor with four dimensions.

    Returns:
        A scalar tensor.
    """
    # Neighbour differences along dim 2 (each slice vs. the one before it).
    diff_dim2 = F.l1_loss(Y_hat[:, :, 1:, :], Y_hat[:, :, :-1, :])
    # Neighbour differences along dim 3.
    diff_dim3 = F.l1_loss(Y_hat[:, :, :, 1:], Y_hat[:, :, :, :-1])
    # Explicit multiplication: `0.5(...)` would try to call a float.
    return 0.5 * (diff_dim2 + diff_dim3)

In [27]:
# Multipliers applied by compute_loss to the content, style and
# total-variation loss terms respectively.
content_weight = 1
style_weight = 1e3
tv_weight = 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y):
    """Combine the weighted content, style and total-variation losses.

    Args:
        X: Image tensor; only used for the total-variation term.
        contents_Y_hat: Content features of the image being optimized.
        styles_Y_hat: Style features of the image being optimized.
        contents_Y: Target content features, paired with ``contents_Y_hat``.
        styles_Y: Style targets, paired with ``styles_Y_hat`` and passed to
            ``style_loss`` as its ``gram_Y`` argument (so these are expected
            to be Gram matrices already).

    Returns:
        ``(contents_l, styles_l, l)``: the list of weighted content losses,
        the list of weighted style losses, and the total including the
        weighted total-variation term (which is not returned separately).
    """
    contents_l = []
    for feat_hat, feat_target in zip(contents_Y_hat, contents_Y):
        contents_l.append(content_loss(feat_hat, feat_target) * content_weight)

    styles_l = []
    for feat_hat, gram_target in zip(styles_Y_hat, styles_Y):
        styles_l.append(style_loss(feat_hat, gram_target) * style_weight)

    tv_term = tv_loss(X) * tv_weight
    total = sum(styles_l) + sum(contents_l) + tv_term
    return contents_l, styles_l, total

In [28]:
class GenerateImage(nn.Module):
    """Module whose only parameter is the image being optimized."""

    def __init__(self, img_shape):
        """Create the image parameter.

        Args:
            img_shape: Iterable of ints giving the shape of the parameter.
        """
        super(GenerateImage, self).__init__()
        # Use the constructor argument, not a same-looking global name.
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        """Return the image parameter itself (no inputs, no computation)."""
        return self.weight

In [29]:
def get_inits(X, lr, styles_Y):
    """Set up the trainable image, the style targets and the optimizer.

    Args:
        X: Tensor used as the starting value of the generated image.
        lr: Learning rate for the Adam optimizer.
        styles_Y: Iterable of style feature tensors; each is converted to a
            Gram matrix with ``gram``.

    Returns:
        ``(image, styles_Y_gram, optimizer)``: the trainable image parameter,
        the list of Gram matrices, and an Adam optimizer over that parameter.
    """
    gen_img = GenerateImage(X.shape)
    # Start from a private copy of X: sharing storage would let every
    # optimizer step overwrite the caller's tensor in place.
    gen_img.weight.data = X.data.clone()
    optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, optimizer

In [30]:
def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch):
    """Optimize an image so that it minimizes the combined loss.

    Args:
        X: Starting image tensor; handed to ``get_inits``.
        contents_Y: Target content features, passed to ``compute_loss``.
        styles_Y: Style features; ``get_inits`` turns them into Gram targets.
        device: Only reported in the start-up message. This function does not
            move any tensor or module to it.
        lr: Learning rate for the optimizer created by ``get_inits``.
        max_epochs: Number of optimization steps.
        lr_decay_epoch: ``step_size`` of the ``StepLR`` schedule (gamma 0.1).

    Returns:
        The optimized image, detached from the autograd graph.
    """
    print("training on ", device)
    # get_inits takes (X, lr, styles_Y); it has no device parameter.
    X, styles_Y_gram, optimizer = get_inits(X, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1)
    for i in range(max_epochs):
        start = time.time()

        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layer, style_layer)
        # compute_loss returns (content losses, style losses, total).
        contents_l, styles_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)

        log_now = i % 50 == 0 and i != 0
        if log_now:
            # The TV term is not returned by compute_loss; recompute it for
            # reporting only, before the step so it matches the other losses.
            with torch.no_grad():
                tv_l = tv_loss(X) * tv_weight

        optimizer.zero_grad()
        # NOTE(review): retain_graph is presumably needed because the target
        # features still reference the graph of their own forward pass, which
        # is reused on every step -- confirm before removing.
        l.backward(retain_graph=True)
        optimizer.step()
        scheduler.step()

        if log_now:
            print('epoch %3d, content loss %.2f, style loss %.2f, '
                  'TV loss %.2f, %.2f sec'
                  % (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(),
                     time.time() - start))
    return X.detach()

In [34]:
# Target size handed to get_contents / get_styles, which forward it to
# preprocess (presumably (height, width), as Resize expects -- confirm).
image_shape =  (150, 225)
# NOTE(review): get_contents reads the module-level `content_img` and
# get_styles reads `style_img`; no visible cell defines either, so both images
# must be loaded in an earlier cell or this cell fails with a NameError.
content_X, contents_Y = get_contents(image_shape)
style_X, styles_Y = get_styles(image_shape)

NameError: name 'content_img' is not defined