# Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization

## Imports

In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image

## Define the model

In [2]:
from model import StyleTransfer, train_decoder
model = StyleTransfer()

Using cache found in /home/mathis/.cache/torch/hub/pytorch_vision_v0.9.0


## Training of the decoder

In [3]:
# Dataset
from dataset import CIFAR10, ArtBench10
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

cifar10_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
cifar10_loader = torch.utils.data.DataLoader(cifar10_dataset, batch_size=64, shuffle=True)

artbench_dataset = ArtBench10(root='./data', train=True, transform=transform, download=True)
artbench_loader = DataLoader(artbench_dataset, batch_size=64, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Training
res = train_decoder(model=model,
                    content_loader=cifar10_loader,
                    style_loader=artbench_loader,
                    nb_epochs=10,
                    learning_rate=0.001,
                    lam=2.0)

  0%|          | 0/10 [00:00<?, ?it/s]

In model train decoder-------------
content_batch: torch.Size([64, 3, 32, 32])
style_batch: torch.Size([64, 3, 32, 32])
In style_loss-----------
output_features.shape: torch.Size([64, 64, 16, 16])
style_features.shape: torch.Size([64, 64, 32, 32])
In style_loss-----------
output_features.shape: torch.Size([64, 256, 4, 4])
style_features.shape: torch.Size([64, 256, 8, 8])
In style_loss-----------
output_features.shape: torch.Size([64, 256, 4, 4])
style_features.shape: torch.Size([64, 256, 8, 8])
In style_loss-----------
output_features.shape: torch.Size([64, 512, 1, 1])
style_features.shape: torch.Size([64, 512, 2, 2])
loss: tensor(nan)


  gen_std = torch.std(output_features, dim=[2, 3], keepdim=True)
  0%|          | 0/10 [00:06<?, ?it/s]


KeyboardInterrupt: 

## Load and preprocess the images

In [None]:
# Load the content and style images
content_img = Image.open("images/content/golden_gate.jpg")
style_img = Image.open("images/style/sketch.png")

# Preprocess the images
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

content_tensor = preprocess(content_img).unsqueeze(0)
style_tensor = preprocess(style_img).unsqueeze(0)

print("content_tensor:", content_tensor.shape)
print("style_tensor:", style_tensor.shape)

transforms.ToPILImage()(content_tensor.squeeze(0).cpu().clamp(0, 1)).show()
transforms.ToPILImage()(style_tensor.squeeze(0).cpu().clamp(0, 1)).show()

## Run the model

In [None]:
with torch.no_grad():
    # Extract content and style features
    content_features = encoder(content_tensor)
    style_features = encoder(style_tensor)

    # Perform AdaIN
    stylized_features = adain(content_features, style_features)

    # Decode the stylized features
    stylized_img = decoder(stylized_features)

## Print the result

In [None]:
# Output the stylized image
output_img = transforms.ToPILImage()(stylized_img.squeeze(0).cpu().clamp(0, 1))
output_img.show()