# Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization

In [None]:
# Run this cell if you are on an external server such as Colab or Kaggle (useful for training on a GPU).
# !git clone "https://github.com/mathisemb/AdaIN.git"
# %cd AdaIN

## Imports

In [None]:
import torch
from PIL import Image
from utils.plot_tools import *
from utils.dataloader_maker2 import dataloader_maker
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # shown as the cell output in a notebook

## Training of the decoder

### Dataloaders

In [None]:
# One loader per image domain (content photos vs. style paintings), both built
# with the same settings: nb_of_images=32 (presumably a 32-image subset) and
# batches of 8.
content_path = 'utils/datasets/MS_COCO_val'
style_path = 'utils/datasets/wikiart/wikiart'

content_loader = dataloader_maker(folder_path=content_path, nb_of_images=32, batch_size=8)
style_loader = dataloader_maker(folder_path=style_path, nb_of_images=32, batch_size=8)

In [None]:
# Show the first three images of the first batch of each loader (content first, then style).
# NOTE(review): assumes each batch is an (images, labels) pair, so batch[0] is the
# image tensor — confirm against dataloader_maker.
for loader in (content_loader, style_loader):
    for batch in loader:
        first_images = batch[0]
        plot_img(first_images[0], first_images[1], first_images[2])
        break  # only the first batch is needed

## Define the model

In [None]:
from model import StyleTransfer

# Hyperparameters handed to the model.
# NOTE(review): `lr` is presumably the optimiser learning rate and `lam` the weight of
# the style loss relative to the content loss (lambda in the AdaIN paper) — confirm in model.py.
lam = 2.0
lr = 3e-4
model = StyleTransfer(lam=lam, lr=lr)

### Training the decoder

In [None]:
# Training: run train_decoder for `nb_epochs` epochs over the content and style loaders.
nb_epochs = 1
model.train_decoder(
    content_loader,
    style_loader,
    nb_epochs,
)

### Plot Loss

In [None]:
# Plot the content-loss and style-loss histories stored on the model.
plot_losses(
    model.content_LOSS,
    model.style_LOSS,
)

### Saving model

In [None]:
# Save the model. Per this notebook's convention the checkpoint is written
# automatically into the saving path 'model_checkpoints/Adain/'.
model.save()

### Loading model to retrain it

Loading restores the training state: the epoch counter resumes from the saved checkpoint, and the recorded loss lists are extended rather than overwritten.

In [None]:
# Resume training from a saved checkpoint, then save again.
# Hyperparameters are passed by keyword (as in the model-definition cell) so the
# call does not depend on the positional order of StyleTransfer's parameters.
model = StyleTransfer(lr=lr, lam=lam)

# Epoch of the checkpoint to resume from; the available checkpoints are in the
# saving path 'model_checkpoints/Adain/'.
checkpoint_epoch = 50

model.load(epoch=checkpoint_epoch)

model.train_decoder(content_loader, style_loader, nb_epochs)
model.save()

In [None]:
# Plot the loss history stored in `model.LOSS`.
# NOTE(review): `plt` is not imported explicitly in this notebook; it presumably comes
# from `from utils.plot_tools import *` — confirm.
loss_history = model.LOSS
plt.plot(loss_history)

# Evaluating the model

## Load and preprocess the images

In [None]:
import torchvision.transforms as transforms

# Load the content and style images.
# `.convert("RGB")` forces three channels: grayscale ("L") or RGBA/palette files
# would otherwise yield 1- or 4-channel tensors after ToTensor.
# NOTE(review): the model is assumed to expect 3-channel RGB input — confirm in model.py.
# The `with` blocks close the file handles that Image.open keeps open lazily;
# `convert` returns an independent, fully loaded image, so it stays usable afterwards.
with Image.open("images/content/000000000785.jpg") as src:
    content_img = src.convert("RGB")
with Image.open(
    "images/style/albert-marquet_life-class-at-the-cole-des-beaux-arts-fauvist-nude-1898.jpg"
) as src:
    style_img = src.convert("RGB")

# Resize to a fixed 32x32 resolution (aspect ratio is not preserved) and convert
# to a float (C, H, W) tensor with values in [0, 1].
# NOTE(review): 32x32 presumably matches the training dataloader's resolution — confirm.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Add a leading batch dimension of size 1 and move to the selected device.
content_tensor = preprocess(content_img).unsqueeze(0).to(device)
style_tensor = preprocess(style_img).unsqueeze(0).to(device)

## Load and Run the model

In [None]:
# Checkpoint produced by the resumed training run: its epoch counter continues
# from `checkpoint_epoch`, so after `nb_epochs` more epochs the latest save is at
# checkpoint_epoch + nb_epochs. Replace with a literal (e.g. 50) to evaluate
# another checkpoint from 'model_checkpoints/Adain/'.
load_checkpoint = checkpoint_epoch + nb_epochs
print(load_checkpoint)

# Hyperparameters are irrelevant for inference, so the defaults are used.
model = StyleTransfer()
# Pass the epoch by keyword, as the resume-training cell does.
model.load(epoch=load_checkpoint)

# Inference only: disable autograd tracking.
with torch.no_grad():
    stylized_img = model(content_tensor, style_tensor)

## Display the result

In [None]:
# Display the content, style and stylized images side by side.
# The tensors here carry a leading batch dimension of size 1 and may live on the
# GPU, whereas the earlier plot_img calls pass single images taken out of a batch.
# NOTE(review): assumes plot_img expects an unbatched (C, H, W) tensor — confirm in
# utils/plot_tools.
# The decoder output is presumably not bounded to [0, 1], so it is clamped for display.
plot_img(
    content_tensor.squeeze(0).cpu(),
    style_tensor.squeeze(0).cpu(),
    stylized_img.squeeze(0).detach().cpu().clamp(0, 1),
)