# Texture synthesis (Gatys)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg19
from torchvision.transforms import ToTensor
from PIL import Image
from IPython.display import clear_output, display
from time import sleep

In [2]:
class TextureSynthesisNetwork(nn.Module):
    """Frozen VGG-19 feature extractor that returns every layer's activation.

    Args:
        layers: Number of leading modules of ``vgg19().features`` to keep.
    """

    def __init__(self, layers):
        super().__init__()
        # NOTE: recent torchvision releases deprecate `pretrained=True` in
        # favour of the `weights=` argument; kept for compatibility with the
        # version this notebook was written against.
        self.vgg = vgg19(pretrained=True).features[:layers].eval()
        self._prepare_backbone()

    def _prepare_backbone(self):
        """Freeze the backbone weights and make every ReLU out-of-place."""
        # Disable grad: only the synthesized image is optimised.
        for param in self.vgg.parameters():
            param.requires_grad = False

        # An in-place ReLU overwrites the tensor produced by the previous
        # layer, so the activation already stored in `features` would silently
        # become the post-ReLU activation (both list entries alias one
        # tensor). Out-of-place ReLUs keep every returned activation distinct.
        for module in self.vgg.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False

    def forward(self, x):
        """Run `x` through the kept layers.

        Returns:
            A list with the output of each layer, in network order.
        """
        features = []
        for layer in self.vgg:
            x = layer(x)
            features.append(x)
        return features

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def _gram_matrix(feature_map):
    """Return the Gram matrix of a ``(batch, channels, H, W)`` activation.

    The result is divided by ``channels * H * W`` so that layers with
    different sizes contribute to the loss on a comparable scale.
    """
    batch, channels, height, width = feature_map.shape
    flat = feature_map.reshape(batch, channels, height * width)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (channels * height * width)


def _to_pil(image_tensor):
    """Convert a ``(1, 3, H, W)`` tensor with values in [0, 1] to a PIL image."""
    # `.cpu()` is required: `.numpy()` raises for tensors living on CUDA.
    pixels = torch.clamp(image_tensor.detach(), 0, 1).squeeze(0).cpu().numpy()
    return Image.fromarray((pixels.transpose(1, 2, 0) * 255).astype('uint8'))


def texture_synthesis(image, layers, iterations, learning_rate):
    """Synthesize a texture by matching VGG Gram matrices of `image`.

    Starting from Gaussian noise, the pixels of a new image are optimised
    with Adam so that the Gram matrix of every returned VGG layer matches
    the one computed from `image`.

    Args:
        image: ``(3, H, W)`` tensor of the example texture (a batch axis is
            added internally).
        layers: Number of leading VGG-19 feature layers to use.
        iterations: Number of optimisation steps.
        learning_rate: Adam learning rate.

    Returns:
        The synthesized ``(3, H, W)`` tensor. It lives on `device` and is
        still attached to autograd; call ``.detach().cpu()`` before
        converting it to NumPy.

    Raises:
        ValueError: If `layers` selects no VGG layer at all.
    """
    image = image.unsqueeze(0).to(device)
    # `image` is already on `device`, so this creates the noise there directly
    # and keeps it a leaf tensor that the optimizer can update.
    synthesized_image = torch.randn_like(image, requires_grad=True)

    criterion = nn.MSELoss()
    optimizer = optim.Adam([synthesized_image], lr=learning_rate)

    network = TextureSynthesisNetwork(layers).to(device)

    # The example image never changes, so its statistics are computed once,
    # without building an autograd graph.
    with torch.no_grad():
        target_grams = [_gram_matrix(f) for f in network(image)]
    if not target_grams:
        raise ValueError("layers must select at least one VGG layer")

    for i in range(iterations):
        clear_output(wait=True)
        print(f'Iteration {i + 1}/{iterations}')

        # Preview the current state every 10 iterations.
        if i % 10 == 0:
            _to_pil(synthesized_image).show()

        optimizer.zero_grad()
        features_synthesized = network(synthesized_image)

        # Matching Gram matrices (feature correlations) rather than raw
        # activations is what makes this texture synthesis instead of a
        # reconstruction of the input image.
        loss = sum(
            criterion(_gram_matrix(f_synthesized), target_gram)
            for f_synthesized, target_gram in zip(features_synthesized, target_grams)
        )

        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Iteration {i + 1}/{iterations}, Loss: {loss.item()}")

    return synthesized_image.squeeze(0)

In [22]:
from skimage.io import imread

In [24]:
# Load the example texture with skimage and keep only the first three
# channels (presumably dropping an alpha channel -- confirm for other inputs).
input_image = imread("bookshelf.png")[..., :3]
# No batch axis here: texture_synthesis adds it itself.
input_tensor = ToTensor()(input_image)

# Texture parameters
layers = 5
iterations = 170
learning_rate = 0.1

synthesized_texture = texture_synthesis(input_tensor, layers, iterations, learning_rate)

# Move the result to the CPU before converting: `.numpy()` raises for CUDA tensors.
synthesized_image = torch.clamp(synthesized_texture.detach(), 0, 1).squeeze(0).cpu().numpy().transpose(1, 2, 0)
synthesized_image = Image.fromarray((synthesized_image * 255).astype("uint8"))

synthesized_image.show()

Iteration 70/70


In [5]:
# Scratch cell: open the example texture as a PIL image.
input_image = Image.open("bookshelf.png")

In [7]:
# Scratch cell: convert `input_image` to a tensor with torchvision's ToTensor.
in_tensor = ToTensor()(input_image)

In [19]:
# Scratch cell: read the texture with skimage and keep the first three
# entries of the last axis (presumably RGB without alpha -- confirm).
test = imread("bookshelf.png")[..., :3]

In [20]:
# Scratch cell: convert the array loaded above to a tensor, rebinding `test`.
test = ToTensor()(test)