In [1]:
import sumie
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import skimage

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Pretrained VGG16, switched to evaluation mode; it is used below only as a
# fixed feature extractor.
model = torchvision.models.vgg16(pretrained=True)
model.eval()

# Both helpers are applied to the network itself (their return values are not
# used). The repr printed below shows the result: plain ReLU() layers and
# AvgPool2d layers where VGG normally has max pooling.
sumie.utils.remove_inplace(model)
sumie.utils.max_to_avg_pool(model)

model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU()
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1,

In [3]:
import math

class ImageBatch(torch.nn.Module):
    """Batch of trainable image modules followed by fixed reference tensors.

    Calling the module evaluates every entry of ``modules`` (each one is
    called with no arguments), appends the fixed tensors, concatenates
    everything with ``torch.cat`` (default dimension 0) and passes the result
    through ``transforms``.

    The fixed tensors are registered as non-persistent buffers. That way
    ``.to(...)``, ``.cuda()``, ``.double()`` and similar calls move or cast
    them together with the rest of the module, while ``state_dict()`` keeps
    exactly the keys it had when the tensors lived in a plain list.

    Args:
        modules: iterable of ``torch.nn.Module`` objects. Each is called
            without arguments in ``forward``; the first one must also provide
            ``get_image()``.
        tensors: iterable of tensors appended after the module outputs.
        transforms: iterable of modules chained with ``torch.nn.Sequential``.
    """

    def __init__(self, modules, tensors, transforms):
        super(ImageBatch, self).__init__()
        self.module_ims = torch.nn.ModuleList(modules)
        # Names of the buffers that hold the fixed tensors, in batch order.
        self._fixed_names = []
        self._register_fixed(tensors)
        self.transforms = torch.nn.Sequential(*transforms)

    def _register_fixed(self, tensors):
        """Replace the currently registered fixed tensors with ``tensors``."""
        for name in self._fixed_names:
            delattr(self, name)
        self._fixed_names = []
        for index, tensor in enumerate(tensors):
            name = 'fixed_im_{}'.format(index)
            # Non-persistent: tracked for device/dtype moves, but kept out of
            # state_dict() so saved checkpoints keep their original keys.
            self.register_buffer(name, tensor, persistent=False)
            self._fixed_names.append(name)

    @property
    def fixed_ims(self):
        """List of the fixed tensors in the order they were supplied."""
        return [getattr(self, name) for name in self._fixed_names]

    @fixed_ims.setter
    def fixed_ims(self, tensors):
        self._register_fixed(tensors)

    def forward(self):
        """Return the transformed batch: module outputs first, then fixed tensors."""
        batch = [im() for im in self.module_ims]
        batch.extend(self.fixed_ims)
        return self.transforms(torch.cat(batch))

    def get_image(self):
        """Delegate to ``get_image()`` of the first image module."""
        return self.module_ims[0].get_image()
    

def change_scale(opt, i):
    """Optimiser callback that grows the scale of the second-to-last transform.

    Every call multiplies ``opt.image.transforms[-2].factor`` by
    ``10 ** (1/1024)``, so 1024 calls multiply it by 10 in total. In the
    transform lists built in this notebook, the second-to-last entry is the
    ``Interpolate`` transform.

    Args:
        opt: object exposing ``image.transforms``, an indexable sequence whose
            second-to-last item has a numeric ``factor`` attribute.
        i: second callback argument; not used here.
    """
    growth_per_call = 10 ** (1/1024)
    target = opt.image.transforms[-2]
    target.factor *= growth_per_call
    
# Size passed to both image loaders (as a square (imsize, imsize)) and to sumie.Image.
imsize = 512

# Reference images. An alternative style image that can be swapped in:
#style_url = 'https://upload.wikimedia.org/wikipedia/commons/3/36/Vassily_Kandinsky%2C_1912_-_Improvisation_27%2C_Garden_of_Love_II.jpg'
style_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c2/Florecillas.jpg/729px-Florecillas.jpg'
content_url = 'https://upload.wikimedia.org/wikipedia/commons/6/69/Phoenix_Hall%2C_Byodo-in%2C_November_2016_-01.jpg'
style_image = sumie.io.load_url(style_url, size=(imsize, imsize))
content_image = sumie.io.load_url(content_url, size=(imsize, imsize))

# The image that the optimiser will update.
im = sumie.Image(size=imsize)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
im.to(device)


Image(
  (base_image): FftImage()
  (decorrelation): DecorrelateColours()
  (limit): Sigmoid()
  (transforms): Sequential()
)

In [None]:

model.to(device)

# ImageBatch appends these after the generated image, so the batch order is:
# 0 = `im`, 1 = style image, 2 = content image.
images = [style_image.to(device), content_image.to(device)]
combined_im = ImageBatch(
    [im],
    images,
    [
        sumie.transforms.PositionJitter(16),
        sumie.transforms.ScaleJitter(1.01),
        sumie.transforms.PositionJitter(16),
        # change_scale() rescales this transform's `factor` via transforms[-2].
        sumie.transforms.Interpolate(0.1),
        sumie.transforms.Normalise(),
    ],
)
combined_im.to(device)

# Style term: one Gram-matrix matching objective per selected VGG feature layer.
# The trailing (0, 1) presumably pairs batch entry 0 with entry 1 (the style
# image) -- TODO confirm against sumie.objectives.BatchMatchActivations.
styles = []
for i in [10, 14, 17, 24]:
    styles.append(sumie.objectives.BatchMatchActivations(
        model.features[i], 0, 1,
        func=lambda x: sumie.utils.gram_matrix(x.unsqueeze(0))))
# Exactly one weight per style objective.
style_objective = sumie.objectives.Composite(styles, weights=[1, 2, 4, 8])

# Content term: compares batch entry 0 with entry 2 (the content image) at
# model.features[19].
content_objective = sumie.objectives.BatchMatchActivations(model.features[19], 0, 2)
style_transfer = sumie.objectives.Composite((style_objective, content_objective), weights=[10000, 1])

opt = sumie.Optimiser(combined_im, model.features, style_transfer)
opt.add_callback(change_scale)
# NOTE(review): the run is disabled, so the figure and loss curve below show
# the state without any optimisation from this cell. change_scale's schedule
# (a factor of 10 over 1024 calls) matches the 1024 iterations given here.
#opt.run(iterations=1024, lr=0.05, progress=True, output='tmp')
sumie.vis.show(im.get_image(), figsize = (10, 10))
plt.semilogy([-1*x.item() for x in opt.history])

In [28]:

# Second experiment: the style target is handled by sumie.objectives.Style, so
# only the content image is batched with the generated image.
# `style_url` comes from the earlier setup cell.
imsize = 512
content_url = 'https://upload.wikimedia.org/wikipedia/commons/6/69/Phoenix_Hall%2C_Byodo-in%2C_November_2016_-01.jpg'
style_image = sumie.io.load_url(style_url, size=(imsize, imsize))
content_image = sumie.io.load_url(content_url, size=(imsize, imsize))

# Start again from a fresh image so this run does not continue a previous one.
im = sumie.Image(size=imsize)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
im.to(device)

# Batch order produced by ImageBatch: 0 = `im`, 1 = content image.
images = [content_image.to(device)]
combined_im = ImageBatch(
    [im],
    images,
    [
        sumie.transforms.PositionJitter(16),
        sumie.transforms.ScaleJitter(1.01),
        sumie.transforms.PositionJitter(16),
        sumie.transforms.Interpolate(0.5),
        sumie.transforms.Normalise(),
    ],
)
combined_im.to(device)
model.to(device)

# Indices into model.features used as style layers, one weight per layer.
modules = [10, 14, 17, 24]
style_image = sumie.utils.normalise(style_image)
style_objective = sumie.objectives.Style(
    style_image.to(device),
    model.features,
    [model.features[x] for x in modules],
    weights=[1, 2, 4, 8],
)

# Content term: compares batch entry 0 with entry 1 (the content image) at
# model.features[19].
content_objective = sumie.objectives.BatchMatchActivations(model.features[19], 0, 1)
style_transfer = sumie.objectives.Composite((style_objective, content_objective), weights=[10000, 1])

opt = sumie.Optimiser(combined_im, model.features, style_transfer)
# The scale schedule is left disabled for this run.
#opt.add_callback(change_scale)
opt.run(iterations=512, lr=0.05, progress=True, output='tmp')
sumie.vis.show(im.get_image(), figsize = (10, 10))
plt.semilogy([-1*x.item() for x in opt.history])

 80%|███████▉  | 409/512 [01:27<00:22,  4.65it/s]

KeyboardInterrupt: 