In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

**Load VGG19 Model**

In [None]:
# Load ImageNet-pretrained VGG19. torchvision >= 0.13 deprecates `pretrained=True`
# in favour of the `weights=` enum (IMAGENET1K_V1 is the same checkpoint), so use
# the enum when it exists and fall back for older torchvision releases.
if hasattr(torchvision.models, 'VGG19_Weights'):
    vgg19 = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)
else:
    vgg19 = torchvision.models.vgg19(pretrained=True)

In [None]:
# Wrap the torchvision VGG19 in rf_pool's FeedForwardNetwork; later cells address
# its layers by block and name (e.g. block 'features', layer 'conv2d0').
model = rf_pool.models.FeedForwardNetwork(vgg19)

In [None]:
# Sanity check that the pretrained weights were loaded: visualize the weights of
# the first conv layer ('conv2d0_weight' in the 'features' block) of the wrapped model.
rf_pool.utils.visualize.show_weights(model, 'features', 'conv2d0_weight')

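**Create Image Datasets**

Content images are fetched from loremflickr.com and style images are scraped from the random-ize.com art gallery page. Both are resized, center-cropped to 224×224 and normalized for VGG19.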

In [None]:
# set transform
def to_rgb_unit_max(x, eps=1e-8):
    """Coerce an image tensor to 3 channels and rescale it so its maximum is 1.

    Applied right after ``ToTensor``, so ``x`` is a (C, H, W) float tensor.
    Images with fewer than 3 channels (e.g. grayscale) have their first channel
    repeated, and channels beyond the third (e.g. alpha) are dropped, so the
    3-channel ImageNet ``Normalize`` that follows never fails to broadcast.
    The divisor is clamped to ``eps`` so an all-black image yields zeros
    instead of NaNs (the original ``x / torch.max(x)`` divided by zero).

    Parameters
    ----------
    x : torch.Tensor
        Image tensor of shape (C, H, W).
    eps : float, optional
        Lower bound on the divisor. Default 1e-8.

    Returns
    -------
    torch.Tensor
        Tensor of shape (3, H, W) whose maximum is 1 (or all zeros).
    """
    if x.shape[0] < 3:
        x = x[:1].expand(3, -1, -1)
    else:
        x = x[:3]
    return x / torch.clamp(torch.max(x), min=eps)


# Resize/crop to a 224x224 input, stretch intensities so the brightest value is 1,
# then apply the standard ImageNet mean/std expected by torchvision's pretrained models.
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop((224, 224)),
                                transforms.ToTensor(),
                                rf_pool.ops.Op(to_rgb_unit_max),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225]),
                               ])

In [None]:
# Style images: scrape the random-ize.com art gallery page for .jpg URLs.
# find_img_url=True presumably makes URLDataset search the fetched page for
# url_pattern, with url_replace turning the site-relative path into an absolute
# URL — TODO confirm against rf_pool.utils.datasets.URLDataset.
base_url = 'https://random-ize.com/random-art-gallery/'
styleset = rf_pool.utils.datasets.URLDataset(
    '.',
    urls=[base_url],
    transform=transform,
    find_img_url=True,
    # Raw string: '\.' in a plain string is an invalid escape (SyntaxWarning on
    # Python 3.12+). `[^"]+` keeps each match inside a single src="..." attribute;
    # the previous greedy `.+` could span several attributes on one HTML line.
    url_pattern=r'src="([^"]+\.jpg)"',
    url_replace=['/random-art-gallery/', base_url],
)

In [None]:
# Content images: the loremflickr URL is presumably used directly as the image
# source (find_img_url=False, so no page scraping) — confirm against URLDataset.
base_url = 'https://loremflickr.com/300/300'
contentset = rf_pool.utils.datasets.URLDataset(
    '.',
    urls=[base_url],
    transform=transform,
    find_img_url=False,
)

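**Style Transfer**

Starting from the content image, optimize its pixels so that VGG19 features match the content image at one layer (MSE loss) and match the style image's Gram matrices at several layers.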

In [None]:
def gram_matrix(features, normalize=False):
    """Return the Gram matrix of ``features`` over its last two dimensions.

    The last two dims are flattened into one (assumed to be the spatial H, W of a
    conv feature map, e.g. a (batch, C, H, W) tensor — confirm at call sites), and
    the channel-by-channel inner products are taken.

    Parameters
    ----------
    features : torch.Tensor
        Tensor of shape (..., C, H, W).
    normalize : bool, optional
        If True, divide by C * H * W so the scale does not grow with the size of
        the feature map. Default False (unnormalized, as originally written).

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., C, C).
    """
    flat = torch.flatten(features, -2)                   # (..., C, H*W)
    gram = torch.matmul(flat, flat.transpose(-2, -1))    # (..., C, C)
    if normalize:
        gram = gram / (flat.shape[-2] * flat.shape[-1])
    return gram


def gram_loss_fn(target, seed, normalize=False):
    """Style loss: mean squared error between the Gram matrices of two feature maps.

    Parameters
    ----------
    target : torch.Tensor
        Feature map of the style image, shape (..., C, H, W).
    seed : torch.Tensor
        Feature map of the image being optimized, same shape as ``target``.
    normalize : bool, optional
        Forwarded to ``gram_matrix``. Default False keeps the original unnormalized
        behavior. Enabling it shrinks the loss by (C*H*W)**2, so the style weight
        passed to MultiLoss would need retuning.

    Returns
    -------
    torch.Tensor
        Scalar loss tensor.
    """
    return torch.nn.functional.mse_loss(gram_matrix(target, normalize),
                                        gram_matrix(seed, normalize))

In [None]:
# Get Content and Style Images (first sample of each dataset, with a batch dim added)
content_img = contentset[0][0].unsqueeze(0)
style_img = styleset[0][0].unsqueeze(0)

# Show both images side by side with titles. normalize_range presumably rescales
# the ImageNet-normalized tensor into a displayable range; permute converts
# (C, H, W) -> (H, W, C) as imshow expects.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, img, title in zip(axes, [content_img, style_img], ['Content', 'Style']):
    ax.imshow(rf_pool.utils.functions.normalize_range(img[0]).permute(1, 2, 0))
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Style Transfer with Content and Style Losses
# Hyper-parameters. Layer indices become rf_pool layer names 'conv2d<idx>';
# presumably these are positions in vgg19.features (0/5/10/19/28 would be
# conv1_1..conv5_1) — confirm against FeedForwardNetwork's naming.
CONTENT_LAYERS = [30]
STYLE_LAYERS = [0, 5, 10, 19, 28]
LOSS_WEIGHTS = [1e-3, 1e-2]   # [content, style], same order as `losses` below
N_ITERATIONS = 1000
LEARNING_RATE = 5e-2

# Start the optimized image from the content image. `.clone()` is required:
# `.detach()` alone shares storage with `content_img`, so Adam's in-place updates
# to `seed` would also overwrite `content_img`. That makes the cell non-idempotent
# and, if LayerLoss re-reads input_target, it also corrupts the content-loss target.
seed = content_img.detach().clone().requires_grad_(True)

content_loss = rf_pool.losses.LayerLoss(
    model, {'features': {f'conv2d{d}': [] for d in CONTENT_LAYERS}},
    torch.nn.MSELoss(), input_target=content_img)
style_loss = rf_pool.losses.LayerLoss(
    model, {'features': {f'conv2d{d}': [] for d in STYLE_LAYERS}},
    gram_loss_fn, input_target=style_img)
loss_fn = rf_pool.losses.MultiLoss(losses=[content_loss, style_loss],
                                   weights=LOSS_WEIGHTS)

# Only the image pixels are optimized; the network is not in the optimizer.
optim = torch.optim.Adam([seed], lr=LEARNING_RATE)
model.optimize_texture(N_ITERATIONS, seed, loss_fn, optim, monitor=5)