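# Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization

This notebook follows AdaIN style transfer (Huang & Belongie, ICCV 2017). It trains the decoder of the `StyleTransfer` model on a small subset of CIFAR-10 (content images) and ArtBench-10 (style images), both resized to 32×32. It then stylizes a single content/style image pair.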

## Imports

In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image

## Define the model

In [2]:
from model import StyleTransfer, train_decoder

# Seed torch's global RNG before building the model so that any random weight
# initialization inside StyleTransfer() and the shuffling of the DataLoaders
# created later (which draw from the global RNG when no generator is given)
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = StyleTransfer()

Using cache found in /home/mathis/.cache/torch/hub/pytorch_vision_v0.9.0


## Training of the decoder

### Datasets

In [3]:
from dataset import CIFAR10, ArtBench10
from torch.utils.data import DataLoader
from torch.utils.data import Subset

# Settings for this small demo run.
IMAGE_SIZE = 32          # images are resized to IMAGE_SIZE x IMAGE_SIZE
NUM_TRAIN_IMAGES = 128   # images kept from each dataset; raise for a real run
BATCH_SIZE = 64

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

# datasets: CIFAR-10 feeds the content loader and ArtBench-10 the style loader
# (see the train_decoder call below)
cifar10_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
artbench_dataset = ArtBench10(root='./data', train=True, transform=transform, download=True)

# Keep the first `size` images of each dataset. The count is clamped to each
# dataset's length: Subset stores indices without validating them, so an
# oversized `size` would otherwise only fail once a batch is loaded.
# NOTE(review): taking the first N images (not a random sample) may be biased
# if the dataset is stored grouped by class — confirm the on-disk ordering.
size = NUM_TRAIN_IMAGES
cifar10_dataset = Subset(cifar10_dataset, list(range(min(size, len(cifar10_dataset)))))
artbench_dataset = Subset(artbench_dataset, list(range(min(size, len(artbench_dataset)))))

# create the dataloaders
cifar10_loader = DataLoader(cifar10_dataset, batch_size=BATCH_SIZE, shuffle=True)
artbench_loader = DataLoader(artbench_dataset, batch_size=BATCH_SIZE, shuffle=True)

print("len(cifar10_dataset) =", len(cifar10_dataset), "images")
print("len(artbench_dataset) =", len(artbench_dataset), "images")

print("len(cifar10_loader) =", len(cifar10_loader), "batches")
print("len(artbench_loader) =", len(artbench_loader), "batches")

Files already downloaded and verified
Files already downloaded and verified
len(cifar10_dataset) = 128 images
len(artbench_dataset) = 128 images
len(cifar10_loader) = 2 batches
len(artbench_loader) = 2 batches


### Training

In [4]:
# Train the decoder: CIFAR-10 batches supply content, ArtBench-10 batches style.
NB_EPOCHS = 3
LEARNING_RATE = 1e-3
LAM = 2.0  # passed as `lam`; presumably the style-loss weight — TODO confirm in train_decoder

# NOTE(review): the logged loss rises monotonically over the six batches
# (about 54.6 -> 320.8). That looks like a running total that is never reset
# rather than a per-batch loss — TODO confirm inside train_decoder.
res = train_decoder(
    model=model,
    content_loader=cifar10_loader,
    style_loader=artbench_loader,
    nb_epochs=NB_EPOCHS,
    learning_rate=LEARNING_RATE,
    lam=LAM,
)

Epochs:   0%|          | 0/3 [00:00<?, ?it/s]

loss = 54.57487487792969


Epochs:  33%|███▎      | 1/3 [00:11<00:23, 11.91s/it]

loss = 107.22421264648438
loss = 159.7381591796875


Epochs:  67%|██████▋   | 2/3 [00:23<00:11, 11.80s/it]

loss = 252.36868286132812
loss = 282.654052734375


Epochs: 100%|██████████| 3/3 [00:35<00:00, 11.81s/it]

loss = 320.7870788574219





## Load and preprocess the images

In [5]:
# Load the content and style images. convert("RGB") normalizes RGBA, palette
# and grayscale files to 3 channels; without it, a PNG with an alpha channel
# would produce a (1, 4, H, W) tensor. The `with` blocks close the files once
# the pixel data has been copied out.
with Image.open("images/content/golden_gate.jpg") as img:
    content_img = img.convert("RGB")
with Image.open("images/style/sketch.png") as img:
    style_img = img.convert("RGB")

# Preprocess the images the same way as the training `transform`
# (resize to 32x32, then convert to a float tensor in [0, 1]).
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

# unsqueeze(0) adds a batch dimension of size 1
content_tensor = preprocess(content_img).unsqueeze(0)
style_tensor = preprocess(style_img).unsqueeze(0)

assert content_tensor.shape[1] == 3 and style_tensor.shape[1] == 3, "expected RGB inputs"

print("content_tensor:", content_tensor.shape)
print("style_tensor:", style_tensor.shape)

# Render the preprocessed 32x32 inputs inline. Image.show() would open an
# external viewer and leave no output in the notebook. `display` is provided
# by the IPython kernel.
to_pil = transforms.ToPILImage()
display(to_pil(content_tensor.squeeze(0)), to_pil(style_tensor.squeeze(0)))

content_tensor: torch.Size([1, 3, 32, 32])
style_tensor: torch.Size([1, 3, 32, 32])


## Run the model

In [6]:
# Switch to inference mode before stylizing. torch.no_grad() only disables
# gradient tracking; it does not change train-mode behaviour of layers such as
# dropout or batch norm, which model.eval() does.
# NOTE(review): assumes train_decoder leaves the model in train mode — harmless
# if it already calls eval().
model.eval()
with torch.no_grad():
    stylized_img = model(content_tensor, style_tensor)

## Display the result

In [7]:
# Convert the stylized output to a PIL image and render it inline.
# squeeze(0) assumes the model returns a single-image batch (TODO confirm).
# clamp(0, 1) keeps out-of-range decoder outputs valid for ToPILImage.
# Ending the cell with the image uses Jupyter's rich display; Image.show()
# would open an external viewer and leave no output in the notebook.
output_img = transforms.ToPILImage()(stylized_img.squeeze(0).cpu().clamp(0, 1))
output_img