# Style transfer image2image (photo to sketch)
### https://www.kaggle.com/code/theblackmamba31/photo-to-sketch-using-autoencoder

In [107]:
!pip install wandb



In [108]:
import matplotlib.pyplot as plt       # Plotting
import numpy as np                    # Tableau Multidimensionnel
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2, glob
import wandb
import random


from tqdm import tqdm
from datetime import datetime
from torchvision import models, datasets, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from glob import glob

IM_SIZE = 256

In [109]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a plain string because it is also passed to `summary(..., device=device)`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [110]:
device

'cuda'

## Loading data - b&w sketch

In [126]:
class PhotoDatasetAugmented(Dataset):
    """Paired (photo, drawing) dataset with optional joint augmentation.

    Every ``*.jpeg`` file found in ``drawings_data_dir`` is paired with the photo
    ``<id>_photo.jpeg`` in ``photo_data_dir``, where ``<id>`` is the
    second-to-last ``_``-separated token of the drawing's file name.

    Args:
        drawings_data_dir: Folder containing the drawing ``.jpeg`` files.
        photo_data_dir: Folder containing the matching ``<id>_photo.jpeg``
            files (a trailing path separator is optional).
        train: When True, ``__getitem__`` applies random augmentation.
        im_size: Side length in pixels that both images are resized to.

    Each item is a ``(photo, drawing)`` tuple of tensors produced by
    ``ToTensor`` followed by ``Resize([im_size, im_size])``.
    """

    def __init__(self, drawings_data_dir, photo_data_dir, train=True, im_size=256):
        self.train = train
        # Sorted so the index -> sample mapping is the same on every run,
        # whatever order the filesystem lists the files in.
        self.drawings_path = sorted(glob(os.path.join(drawings_data_dir, '*.jpeg')))
        self.photos_path = [
            self._photo_path_for(drawing_path, photo_data_dir)
            for drawing_path in self.drawings_path
        ]
        # Always applied: array -> tensor, then resize to a square image.
        self.transforms_initial = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize([im_size, im_size], antialias=True),
        ])
        # Geometric augmentation, applied to photo and drawing together.
        # fill=1.0 is the value used for the corners uncovered by the rotation.
        self.transforms_fliprotate = transforms.Compose([
            transforms.RandomRotation(180, fill=1.0),
            transforms.RandomHorizontalFlip(0.25),
        ])
        # Colour augmentation, applied to the photo only.
        self.transforms_color = transforms.ColorJitter(hue=0.5)

    @staticmethod
    def _photo_path_for(drawing_path, photo_data_dir):
        """Return the path of the photo paired with ``drawing_path``.

        Raises:
            ValueError: If the drawing file name contains no ``_`` separator,
                so no photo id can be extracted from it.
        """
        name_parts = os.path.basename(drawing_path).split('_')
        if len(name_parts) < 2:
            raise ValueError(
                f"Cannot derive a photo id from drawing file name: {drawing_path}"
            )
        # os.path.join works whether or not photo_data_dir ends with a separator.
        return os.path.join(photo_data_dir, name_parts[-2] + '_photo.jpeg')

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` with OpenCV and convert BGR -> RGB.

        Raises:
            FileNotFoundError: If OpenCV could not read the file
                (``cv2.imread`` returns ``None`` in that case).
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __len__(self):
        """Number of (photo, drawing) pairs."""
        return len(self.photos_path)

    def __getitem__(self, ix):
        """Return the ``(photo, drawing)`` tensor pair at index ``ix``."""
        photo = self.transforms_initial(self._read_rgb(self.photos_path[ix]))
        drawing = self.transforms_initial(self._read_rgb(self.drawings_path[ix]))
        if self.train:
            # Hue jitter on the photo only, for roughly half of the samples.
            if random.random() > 0.5:
                photo = self.transforms_color(photo)
            # Stack the pair so one random rotation/flip is drawn and applied
            # to both images, keeping photo and drawing aligned.
            pair_tensor = torch.stack((photo, drawing))
            pair_tensor = self.transforms_fliprotate(pair_tensor)
            photo, drawing = torch.unbind(pair_tensor)
        return photo, drawing


In [127]:
# Both splits read their photos from the same folder and differ only in the
# drawings folder. train=False on both, so __getitem__ returns the resized
# pairs without any random augmentation.
train_dataset = PhotoDatasetAugmented(
    drawings_data_dir='data/dessin/train_dessin/bw',
    photo_data_dir='data/photo/',
    train=False,
)
test_dataset = PhotoDatasetAugmented(
    drawings_data_dir='data/dessin/test_dessin/bw',
    photo_data_dir='data/photo/',
    train=False,
)

In [128]:
# Mini-batch loaders for both splits: 4 pairs per batch, reshuffled on every
# pass, loaded by 2 worker processes.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

In [129]:
class AE(torch.nn.Module):
    """Convolutional encoder/decoder mapping a 3-channel image to a 3-channel image.

    The encoder stacks six stride-2 4x4 convolutions
    (3 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512 channels). The decoder uses six
    stride-2 4x4 transposed convolutions followed by two stride-1 2x2
    transposed convolutions that bring the channel count down to 3.
    No padding is used, so a 256x256 input shrinks to 2x2 in the encoder and
    is expanded back to 256x256 by the decoder.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, followed_by_batchnorm)
        encoder_plan = [
            (3, 16, False),
            (16, 32, True),
            (32, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        encoder_layers = []
        for n_in, n_out, use_batchnorm in encoder_plan:
            # Conv2d(in_channels, out_channels, kernel_size, stride)
            encoder_layers.append(nn.Conv2d(n_in, n_out, 4, stride=2, bias=False))
            if use_batchnorm:
                encoder_layers.append(nn.BatchNorm2d(n_out))
            encoder_layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # (in_channels, out_channels, followed_by_dropout)
        decoder_plan = [
            (512, 512, True),
            (512, 256, False),
            (256, 128, True),
            (128, 64, False),
            (64, 32, False),
            (32, 16, False),
        ]
        decoder_layers = []
        for n_in, n_out, use_dropout in decoder_plan:
            decoder_layers.append(nn.ConvTranspose2d(n_in, n_out, 4, stride=2, bias=False))
            if use_dropout:
                decoder_layers.append(nn.Dropout(0.1))
            decoder_layers.append(nn.LeakyReLU())
        # Two stride-1 2x2 transposed convolutions (with bias, no activation)
        # each add one pixel per side and reduce the channels 16 -> 8 -> 3.
        decoder_layers.append(nn.ConvTranspose2d(16, 8, (2, 2), stride=(1, 1)))
        decoder_layers.append(nn.ConvTranspose2d(8, 3, (2, 2), stride=(1, 1)))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for that encoding."""
        return self.decoder(self.encoder(x))


In [130]:
# Build the network, move it to the selected device, then print a per-layer
# summary for a 3x256x256 input. batch_size=32 is only what shows up as the
# leading dimension of the printed output shapes.
autoencoder = AE()
autoencoder = autoencoder.to(device)
summary(autoencoder, (3, 256, 256), batch_size=32, device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [32, 16, 127, 127]             768
         LeakyReLU-2         [32, 16, 127, 127]               0
            Conv2d-3           [32, 32, 62, 62]           8,192
       BatchNorm2d-4           [32, 32, 62, 62]              64
         LeakyReLU-5           [32, 32, 62, 62]               0
            Conv2d-6           [32, 64, 30, 30]          32,768
         LeakyReLU-7           [32, 64, 30, 30]               0
            Conv2d-8          [32, 128, 14, 14]         131,072
       BatchNorm2d-9          [32, 128, 14, 14]             256
        LeakyReLU-10          [32, 128, 14, 14]               0
           Conv2d-11            [32, 256, 6, 6]         524,288
      BatchNorm2d-12            [32, 256, 6, 6]             512
        LeakyReLU-13            [32, 256, 6, 6]               0
           Conv2d-14            [32, 51

In [131]:
# Criterion: mean squared error between the network output and the target drawing
loss_function = nn.MSELoss()

# Optimizer: Adam over every autoencoder parameter, learning rate 0.001
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.001)


In [132]:
# Hyper-parameters recorded with the W&B run. Wherever possible the values are
# read from the objects that are actually used for training, so the logged
# config cannot silently drift from the real run.
config = {
    'batch_size': train_dataloader.batch_size,
    'lr': optimizer.param_groups[0]['lr'],
    'epochs': 100,
    'image_size': IM_SIZE,
    'dataset_training_mode': train_dataset.train,
    'num_workers': train_dataloader.num_workers,
    }

run = wandb.init(project='magacrea', job_type='train', save_code=True, config=config)
# Single source of truth for the number of epochs: the logged config.
epochs = config['epochs']

# Make sure dropout/batch-norm are in training mode even if this cell is
# re-run after an interrupted evaluation left the model in eval mode.
autoencoder.train()

for epoch in tqdm(range(epochs)):
    for ix, batch in enumerate(train_dataloader):
        # Each batch is a (photo, drawing) pair of tensors shaped
        # [BATCH_SIZE, CHANNEL, SIZE, SIZE].
        photo, drawing = batch
        photo, drawing = photo.to(device), drawing.to(device)

        # One-off sanity check of the value range of the very first batch.
        if epoch == 0 and ix == 0:
            print(f'photo min : {photo.min()}, photo max : {photo.max()}')
            print(f'drawing min : {drawing.min()}, drawing max : {drawing.max()}')

        # Forward pass: predict the drawing from the photo.
        reconstructed_drawing = autoencoder(photo)

        # Loss between the prediction and the target drawing.
        loss = loss_function(reconstructed_drawing, drawing)

        # Reset the gradients, back-propagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the loss of every training step.
        run.log({'training/loss': loss.item()})

    # Once per epoch, log the images of the last batch that was processed.
    run.log({'training/reconstructed': wandb.Image(reconstructed_drawing), 'training/drawing': wandb.Image(drawing), 'training/photo': wandb.Image(photo)})

# Evaluate on the test set without tracking gradients.
# `losses` holds the per-batch loss as plain Python floats and `batch_sizes`
# the number of pairs in each batch, so the reported mean can be weighted per
# sample (a smaller last batch would otherwise be over-weighted).
losses = []
batch_sizes = []
autoencoder.eval()
try:
    with torch.no_grad():
        for photo, drawing in test_dataloader:
            photo, drawing = photo.to(device), drawing.to(device)
            reconstructed_drawing = autoencoder(photo)
            loss = loss_function(reconstructed_drawing, drawing)
            losses.append(loss.item())
            batch_sizes.append(photo.size(0))
            run.log({'test/reconstructed': wandb.Image(reconstructed_drawing), 'test/drawing': wandb.Image(drawing), 'test/photo': wandb.Image(photo)})

    n_samples = sum(batch_sizes)
    if n_samples:
        mean_loss = sum(l * b for l, b in zip(losses, batch_sizes)) / n_samples
    else:
        # Empty test set: there is no meaningful loss to report.
        mean_loss = float('nan')
    run.log({'test/mean_loss': mean_loss})
finally:
    # Always put the model back in training mode, even if evaluation failed.
    autoencoder.train()

# Save the trained weights locally, attach the file to a W&B artifact of type
# "model", log that artifact with the run, then close the run.
torch.save(autoencoder.state_dict(), 'autoencoder.pth')
artifact = wandb.Artifact(name="autoencoder", type="model")
artifact.add_file(local_path="./autoencoder.pth")  # Attach the saved weights file to the artifact
run.log_artifact(artifact)  # Log the "autoencoder" model artifact with this run
run.finish()

VBox(children=(Label(value='0.528 MB of 0.528 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

100%|██████████| 100/100 [1:18:51<00:00, 47.31s/it]


VBox(children=(Label(value='757.985 MB of 758.595 MB uploaded\r'), FloatProgress(value=0.9991957025184749, max…

0,1
test/mean_loss,▁
training/loss,▅█▅▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/mean_loss,0.05667
training/loss,0.0273
