In [None]:
from collections import OrderedDict
import pickle
import re
import urllib

import imp
from IPython.display import clear_output, display
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage.filters
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

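**Load VGG16 Model** (ImageNet-pretrained weights from torchvision; only the convolutional `features` layers are used below)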

In [None]:
# Load ImageNet-pretrained VGG16. torchvision >= 0.13 deprecates `pretrained=True`
# in favour of the `weights=` enum (IMAGENET1K_V1 is the same checkpoint), so use
# the new API when it exists and fall back for older torchvision releases.
if hasattr(torchvision.models, 'VGG16_Weights'):
    vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
else:
    vgg16 = torchvision.models.vgg16(pretrained=True)

In [None]:
# get weights from vgg16: keep only the convolutional `features.*` parameters
# (the `classifier.*` head is dropped), in the order named_parameters() yields them.
model_dict = OrderedDict(
    (param_name, tensor)
    for param_name, tensor in vgg16.named_parameters()
    if param_name.startswith('features')
)

In [None]:
# Build the convolutional ("features") part of VGG16 as an rf_pool network.
# One entry per conv layer: (in_channels, out_channels, followed_by_max_pool).
# The order must match torchvision's `vgg16.features`, because the pretrained
# weights are mapped onto these layers by position in a later cell.
vgg16_conv_cfg = [
    (3, 64, False), (64, 64, True),                           # block 1
    (64, 128, False), (128, 128, True),                       # block 2
    (128, 256, False), (256, 256, False), (256, 256, True),   # block 3
    (256, 512, False), (512, 512, False), (512, 512, True),   # block 4
    (512, 512, False), (512, 512, False), (512, 512, True),   # block 5
]

model = rf_pool.models.FeedForwardNetwork()
for layer_idx, (in_ch, out_ch, has_pool) in enumerate(vgg16_conv_cfg):
    # padding=1 matches torchvision's VGG16 convs: each 3x3 conv keeps the spatial
    # size, so only the 2x2 max-pools downsample (224 -> 7 after five blocks).
    # Without it every conv trims 2 pixels and deep layers collapse to ~1x1.
    layer_kwargs = {'hidden': torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
                    'activation': torch.nn.ReLU()}
    if has_pool:
        layer_kwargs['pool'] = torch.nn.MaxPool2d(2, 2)
    model.append(str(layer_idx), rf_pool.modules.FeedForward(**layer_kwargs))

In [None]:
# Map each rf_pool parameter name to the VGG16 `features.*` parameter at the same
# position. The pairing is purely positional, so it relies on both networks listing
# their parameters in the same order. Check the counts first: zip() would otherwise
# silently drop unmatched parameters and leave some layers randomly initialised.
param_keys = list(model.download_weights().keys())
assert len(param_keys) == len(model_dict), (
    'parameter count mismatch: model has {} tensors, vgg16 features has {}'.format(
        len(param_keys), len(model_dict)))
param_dict = OrderedDict(zip(param_keys, model_dict.keys()))

# load vgg16 weights into model
model.load_weights(model_dict, param_dict)

In [None]:
# Sanity check that the pretrained weights were loaded: display the weights of
# the first conv layer ('0'), which should no longer look like random init.
first_layer_id = '0'
model.show_weights(first_layer_id)

In [None]:
# Persist the model with its loaded VGG16 weights so later sessions can reload it
# instead of re-downloading and re-mapping the torchvision parameters.
vgg_checkpoint_path = 'vgg.pkl'
model.save_model(vgg_checkpoint_path);

In [None]:
# Reload a previously saved model (e.g. after rebuilding `model` in a fresh session,
# instead of re-downloading the VGG16 weights).
# NOTE(review): the .pkl extension suggests this unpickles the file -- only load
# checkpoints you created or otherwise trust.
saved_checkpoint = 'vgg.pkl'
model.load_model(saved_checkpoint);

**Create Image Datasets** (style images from a random-art gallery, content images from loremflickr)

In [None]:
# Preprocessing: center-crop to 224x224, convert to a float tensor in [0, 1],
# stretch so the brightest value is 1, then normalize with the ImageNet channel
# mean/std used by torchvision's pretrained models.
def _scale_to_unit_max(x, eps=1e-8):
    """Divide ``x`` by its maximum value.

    The divisor is clamped to ``eps`` so an all-black image (max == 0) stays
    all-zero instead of becoming NaN (0 / 0). For any image with a non-zero
    maximum this is identical to ``x / torch.max(x)``.
    """
    return x / torch.clamp(torch.max(x), min=eps)

transform = transforms.Compose([transforms.CenterCrop((224, 224)),
                                transforms.ToTensor(),
                                rf_pool.ops.Op(_scale_to_unit_max),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225]),
                                ])

In [None]:
# Style images: scrape a random-art gallery page for .jpg image links.
base_url = 'https://random-ize.com/random-art-gallery/'
# The raw string avoids the invalid "\." escape warning. `[^"]+` keeps each match
# inside a single src="..." attribute; the previous greedy `.+` could run across
# several attributes on one HTML line and produce a bogus URL.
# NOTE(review): presumably url_replace turns relative '/random-art-gallery/' src
# paths into absolute URLs -- confirm against URLDataset.
paintset = rf_pool.utils.datasets.URLDataset('.', urls=[base_url], transform=transform,
                                             find_img_url=True, url_pattern=r'src="([^"]+\.jpg)"',
                                             url_replace=['/random-art-gallery/', base_url])

In [None]:
# Content images: the URL itself is presumably an image endpoint (loremflickr
# serves a random photo at /<width>/<height>), so no page scraping is requested.
# NOTE(review): paintset passes `find_img_url=True` while this passes `find_url=False`
# -- confirm `find_url` is a real URLDataset keyword and not a typo.
# NOTE(review): the URL has no keyword such as `/cat`, so images are probably not
# limited to cats despite the dataset's name.
base_url = 'https://loremflickr.com/300/300'
catset = rf_pool.utils.datasets.URLDataset(
    '.',
    urls=[base_url],
    transform=transform,
    find_url=False,
)

**Style Transfer**

In [None]:
def gram_loss_fn(target, seed):
    """Mean-squared error between the Gram matrices of ``target`` and ``seed``.

    The last two dimensions are flattened into one, so an input of shape
    (..., C, H, W) gives a Gram matrix of shape (..., C, C). The Gram matrices
    are not normalized by H*W. The squared error is averaged over all entries.
    """
    def _gram(features):
        flat = features.flatten(-2)
        return flat @ flat.transpose(-2, -1)

    return torch.nn.functional.mse_loss(_gram(target), _gram(seed))

In [None]:
def _show_first_image(batch):
    """Display ``batch[0]`` (channels-first), range-normalized and with channels moved last for imshow."""
    displayable = rf_pool.utils.functions.normalize_range(batch[0])
    plt.imshow(displayable.permute(1, 2, 0))
    plt.show()

# Get Content and Style Images, each as a batch of one (unsqueeze adds the batch dim)
content_img = catset[0][0].unsqueeze(0)
_show_first_image(content_img)

style_img = paintset[0][0].unsqueeze(0)
_show_first_image(style_img)

In [None]:
# Style transfer: optimize the pixels of `seed` against a weighted sum of a content
# loss (MSE, presumably on the model's activations at CONTENT_LAYERS) and a style
# loss (Gram-matrix MSE via gram_loss_fn at STYLE_LAYERS).
CONTENT_LAYERS = ['2', '4']
STYLE_LAYERS = ['2', '4', '7', '10', '11']
CONTENT_WEIGHT, STYLE_WEIGHT = 100., 0.001
N_STEPS = 1000
RANDOM_SEED = 0

# Seed torch's RNG so the random starting image is repeatable for a given
# content/style pair (the previous version started from a different noise image each run).
torch.manual_seed(RANDOM_SEED)
seed = torch.rand_like(content_img, requires_grad=True)

content_loss = rf_pool.losses.LayerLoss(model, torch.nn.MSELoss(), CONTENT_LAYERS, input_target=content_img)
style_loss = rf_pool.losses.LayerLoss(model, gram_loss_fn, STYLE_LAYERS, input_target=style_img)
loss_fn = rf_pool.losses.MultiLoss(losses=[content_loss, style_loss], weights=[CONTENT_WEIGHT, STYLE_WEIGHT])

optim = torch.optim.SGD([seed], lr=5e-3, momentum=0.9)
# NOTE(review): argument meanings assumed from names -- N_STEPS = optimization steps,
# [] = layer ids (apparently unused because loss_fn supplies the layers), monitor=5 =
# display interval; confirm against rf_pool's optimize_texture.
model.optimize_texture(N_STEPS, [], seed, loss_fn,
                       optim, monitor=5, show_images=[content_img, style_img, seed], figsize=(10, 10))