In [None]:
# Mount Google Drive so the test images and checkpoints under /content/gdrive are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import cv2
from google.colab.patches import cv2_imshow

from PIL import Image

In [None]:
# Device configuration
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print(device)

cpu


### Scaling function
Rescales generator outputs from the tanh range [-1, 1] to [0, 1], the range needed when converting the tensors to images for display.

In [None]:
def scale_0_1(x):
  """Map values from [-1, 1] (the generators' tanh output range) to [0, 1].

  Works element-wise on tensors as well as on plain numbers.
  """
  shifted = x + 1
  return shifted / 2

# **Generators**

## Pix2Pix
 [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)

Check also [SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis](https://arxiv.org/abs/1801.02753).

In [None]:
class Pix2Pix(nn.Module):
  """U-Net generator from Pix2Pix ("Image-to-Image Translation with Conditional
  Adversarial Networks").

  Translates a 1-channel image into a 3-channel image with values in [-1, 1]
  (tanh output). Each of the eight encoder convolutions halves the spatial size,
  so height and width must be multiples of 2**8 = 256, e.g.
  (N, 1, 256, 256) -> (N, 3, 256, 256).

  Args:
    d: Base number of filters; every layer width is a multiple of d.
  """

  def __init__(self, d = 64):     #(N,1,256,256)
    super(Pix2Pix, self).__init__()
    self.d = d

    #Unet encoder: kernel 4, stride 2, padding 1 halves H and W at every layer
    self.conv1 = nn.Conv2d(1, d, 4, 2, 1)
    self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
    self.conv2_bn = nn.BatchNorm2d(d*2)
    self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
    self.conv3_bn = nn.BatchNorm2d(d*4)
    self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
    self.conv4_bn = nn.BatchNorm2d(d*8)
    self.conv5 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv5_bn = nn.BatchNorm2d(d*8)
    self.conv6 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv6_bn = nn.BatchNorm2d(d*8)
    self.conv7 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv7_bn = nn.BatchNorm2d(d*8)
    self.conv8 = nn.Conv2d(d*8, d*8, 4, 2, 1)

    #Unet decoder: from deconv2 on, each input is the previous decoder output
    #concatenated with the mirrored encoder feature map, hence the doubled
    #input channel counts (d*8*2, d*4*2, d*2*2, d*2)
    self.deconv1 = nn.ConvTranspose2d(d*8, d*8, 4, 2, 1)
    self.deconv1_bn = nn.BatchNorm2d(d*8)
    self.deconv2 = nn.ConvTranspose2d(d*8*2, d*8, 4, 2, 1)
    self.deconv2_bn = nn.BatchNorm2d(d*8)
    self.deconv3 = nn.ConvTranspose2d(d*8*2, d*8 , 4, 2, 1)
    self.deconv3_bn = nn.BatchNorm2d(d*8)
    self.deconv4 = nn.ConvTranspose2d(d*8*2, d*8, 4, 2, 1)
    self.deconv4_bn = nn.BatchNorm2d(d*8)
    self.deconv5 = nn.ConvTranspose2d(d*8*2, d*4, 4, 2, 1)
    self.deconv5_bn = nn.BatchNorm2d(d*4)
    self.deconv6 = nn.ConvTranspose2d(d*4*2 , d*2, 4, 2, 1)
    self.deconv6_bn = nn.BatchNorm2d(d*2)
    self.deconv7 = nn.ConvTranspose2d(d*2*2, d, 4, 2, 1)
    self.deconv7_bn = nn.BatchNorm2d(d)
    self.deconv8 = nn.ConvTranspose2d(d*2, 3, 4, 2, 1)

    #Weight initialization
    self.weight_init()



  def forward(self, input):
    """Run the U-Net on `input` (N, 1, H, W); returns (N, 3, H, W) in [-1, 1]."""
    # Encoder
    e1 = F.leaky_relu(self.conv1(input), 0.2)
    e2 = F.leaky_relu(self.conv2_bn(self.conv2(e1)), 0.2)
    e3 = F.leaky_relu(self.conv3_bn(self.conv3(e2)), 0.2)
    e4 = F.leaky_relu(self.conv4_bn(self.conv4(e3)), 0.2)
    e5 = F.leaky_relu(self.conv5_bn(self.conv5(e4)), 0.2)
    e6 = F.leaky_relu(self.conv6_bn(self.conv6(e5)), 0.2)
    e7 = F.leaky_relu(self.conv7_bn(self.conv7(e6)), 0.2)
    e8 = F.relu(self.conv8(e7))
    # Decoder with skip connections to the mirrored encoder maps.
    # Dropout on the first three decoder layers passes training=True, so it stays
    # active even after .eval(); presumably intentional, following the Pix2Pix
    # paper's use of dropout as the generator's noise source at test time.
    d1 = F.dropout(self.deconv1_bn(self.deconv1(e8)), 0.5 , training=True)
    d1 = F.relu(torch.cat([d1, e7], 1))
    d2 = F.dropout(self.deconv2_bn(self.deconv2(d1)), 0.5, training=True)
    d2 = F.relu(torch.cat([d2, e6], 1))
    d3 = F.dropout(self.deconv3_bn(self.deconv3(d2)), 0.5 , training=True)
    d3 = F.relu(torch.cat([d3, e5], 1))
    d4 = self.deconv4_bn(self.deconv4(d3))
    d4 = F.relu(torch.cat([d4, e4], 1))
    d5 = self.deconv5_bn(self.deconv5(d4))
    d5 = F.relu(torch.cat([d5, e3], 1))
    d6 = self.deconv6_bn(self.deconv6(d5))
    d6 = F.relu(torch.cat([d6, e2], 1))
    d7 = self.deconv7_bn(self.deconv7(d6))
    d7 = F.relu(torch.cat([d7, e1], 1))
    d8 = self.deconv8(d7)
    output = torch.tanh(d8)

    return output



  def weight_init(self):
    """Draw every Conv2d / ConvTranspose2d weight from N(0, 0.02) and zero its bias."""
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0.0, 0.02)
        # Test the parameter itself: a layer built with bias=False has
        # m.bias == None, so `m.bias.data` would raise AttributeError.
        if m.bias is not None:
          m.bias.data.zero_()
        print(m, ': Weights initialized')

## StylePix2Pix
Combination of [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004) (i.e. Pix2Pix) and [Conditional Image Synthesis With Auxiliary Classifier GANs](https://arxiv.org/abs/1610.09585) (i.e. ACGAN).

Check also [SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis](https://arxiv.org/abs/1801.02753).

In [None]:
class StylePix2Pix(nn.Module):
  """Style-conditioned Pix2Pix U-Net generator (Pix2Pix combined with an
  ACGAN-style class label).

  An integer style label is embedded, projected by a linear layer to a single
  img_size x img_size map, and concatenated to the 1-channel input as a second
  channel; the resulting 2-channel tensor goes through the Pix2Pix U-Net.

  Args:
    d: Base number of filters; every layer width is a multiple of d.
    n: Number of style labels (rows of the embedding table).
    img_size: Height/width of the input images; the U-Net needs a multiple of
      256. The label map used to be hard-wired to d*4 x d*4, which only equals
      the required 256 for d=64; the default of 256 keeps default-constructed
      models (and their checkpoints) identical while making other d usable.

  Input: (N, 1, img_size, img_size) sketches and (N,) integer labels.
  Output: (N, 3, img_size, img_size) with values in [-1, 1].
  """

  def __init__(self, d = 64, n = 3, img_size = 256):     #(N,1,256,256)
    super(StylePix2Pix, self).__init__()
    self.d = d
    self.img_size = img_size

    #Concatenate Label: label -> 100-dim embedding -> one img_size x img_size map
    self.embedding = nn.Embedding(n, 100)
    self.linear = nn.Linear(100, img_size*img_size)

    #Unet encoder (2 input channels: sketch + label map); each layer halves H and W
    self.conv1 = nn.Conv2d(2, d, 4, 2, 1)
    self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
    self.conv2_bn = nn.BatchNorm2d(d*2)
    self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
    self.conv3_bn = nn.BatchNorm2d(d*4)
    self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
    self.conv4_bn = nn.BatchNorm2d(d*8)
    self.conv5 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv5_bn = nn.BatchNorm2d(d*8)
    self.conv6 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv6_bn = nn.BatchNorm2d(d*8)
    self.conv7 = nn.Conv2d(d*8, d*8, 4, 2, 1)
    self.conv7_bn = nn.BatchNorm2d(d*8)
    self.conv8 = nn.Conv2d(d*8, d*8, 4, 2, 1)

    #Unet decoder: from deconv2 on, each input is the previous decoder output
    #concatenated with the mirrored encoder feature map (doubled in-channels)
    self.deconv1 = nn.ConvTranspose2d(d*8, d*8, 4, 2, 1)
    self.deconv1_bn = nn.BatchNorm2d(d*8)
    self.deconv2 = nn.ConvTranspose2d(d*8*2, d*8, 4, 2, 1)
    self.deconv2_bn = nn.BatchNorm2d(d*8)
    self.deconv3 = nn.ConvTranspose2d(d*8*2, d*8 , 4, 2, 1)
    self.deconv3_bn = nn.BatchNorm2d(d*8)
    self.deconv4 = nn.ConvTranspose2d(d*8*2, d*8, 4, 2, 1)
    self.deconv4_bn = nn.BatchNorm2d(d*8)
    self.deconv5 = nn.ConvTranspose2d(d*8*2, d*4, 4, 2, 1)
    self.deconv5_bn = nn.BatchNorm2d(d*4)
    self.deconv6 = nn.ConvTranspose2d(d*4*2 , d*2, 4, 2, 1)
    self.deconv6_bn = nn.BatchNorm2d(d*2)
    self.deconv7 = nn.ConvTranspose2d(d*2*2, d, 4, 2, 1)
    self.deconv7_bn = nn.BatchNorm2d(d)
    self.deconv8 = nn.ConvTranspose2d(d*2, 3, 4, 2, 1)

    #Weight initialization
    self.weight_init()



  def forward(self, input, label):
    """Translate `input` (N, 1, S, S) in the style given by `label` (N,); S == img_size."""
    # Turn each label into one extra S x S channel and stack it onto the sketch.
    embedding = self.embedding(label)
    linear = self.linear(embedding)
    label_map = torch.reshape(linear, (-1, 1, self.img_size, self.img_size))

    merged =  torch.cat((input, label_map), 1)

    # Encoder
    e1 = F.leaky_relu(self.conv1(merged), 0.2)
    e2 = F.leaky_relu(self.conv2_bn(self.conv2(e1)), 0.2)
    e3 = F.leaky_relu(self.conv3_bn(self.conv3(e2)), 0.2)
    e4 = F.leaky_relu(self.conv4_bn(self.conv4(e3)), 0.2)
    e5 = F.leaky_relu(self.conv5_bn(self.conv5(e4)), 0.2)
    e6 = F.leaky_relu(self.conv6_bn(self.conv6(e5)), 0.2)
    e7 = F.leaky_relu(self.conv7_bn(self.conv7(e6)), 0.2)
    e8 = F.relu(self.conv8(e7))
    # Decoder with skip connections. Dropout passes training=True, so it stays
    # active even after .eval() (presumably Pix2Pix-style test-time noise).
    d1 = F.dropout(self.deconv1_bn(self.deconv1(e8)), 0.5 , training=True)
    d1 = F.relu(torch.cat([d1, e7], 1))
    d2 = F.dropout(self.deconv2_bn(self.deconv2(d1)), 0.5, training=True)
    d2 = F.relu(torch.cat([d2, e6], 1))
    d3 = F.dropout(self.deconv3_bn(self.deconv3(d2)), 0.5 , training=True)
    d3 = F.relu(torch.cat([d3, e5], 1))
    d4 = self.deconv4_bn(self.deconv4(d3))
    d4 = F.relu(torch.cat([d4, e4], 1))
    d5 = self.deconv5_bn(self.deconv5(d4))
    d5 = F.relu(torch.cat([d5, e3], 1))
    d6 = self.deconv6_bn(self.deconv6(d5))
    d6 = F.relu(torch.cat([d6, e2], 1))
    d7 = self.deconv7_bn(self.deconv7(d6))
    d7 = F.relu(torch.cat([d7, e1], 1))
    d8 = self.deconv8(d7)
    output = torch.tanh(d8)

    return output



  def weight_init(self):
    """Draw every Conv2d / ConvTranspose2d weight from N(0, 0.02) and zero its bias.

    The embedding and linear layers keep PyTorch's default initialisation.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0.0, 0.02)
        # bias=False layers have m.bias == None; `m.bias.data` would raise.
        if m.bias is not None:
          m.bias.data.zero_()
        print(m, ': Weights initialized')

## CycleGAN
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)

In [None]:
class CycleGAN(nn.Module):
  """ResNet-based CycleGAN generator with nine residual blocks.

  Attribute names follow the CycleGAN paper's layer notation
  (c7s1-64, d128, d256, R256 x 9, u128, u64, c7s1-3). Maps a 3-channel image,
  e.g. (N, 3, 256, 256), to a 3-channel image of the same spatial size with
  values in [-1, 1] (tanh output). H and W should be divisible by 4 so the two
  stride-2 downsamplings are undone exactly by the two upsamplings.
  """

  def __init__(self):     #(N,3,256,256)
    super(CycleGAN, self).__init__()

    # 7x7 conv, stride 1, 64 filters; reflection padding keeps H and W
    self.c7s1_64 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(3, 64, 7, 1, 0),
        nn.InstanceNorm2d(64),
        nn.ReLU(True)
    )

    # Two stride-2 downsampling convs
    self.d128 = nn.Sequential(
        nn.Conv2d(64, 128, 3, 2, 1),
        nn.InstanceNorm2d(128),
        nn.ReLU(True)
    )

    self.d256 = nn.Sequential(
        nn.Conv2d(128, 256, 3, 2, 1),
        nn.InstanceNorm2d(256),
        nn.ReLU(True)
    )

    # Nine identical residual blocks registered as r256_1 ... r256_9, in this
    # order and with the same inner layer layout, so state_dict keys
    # (e.g. 'r256_1.1.weight') match checkpoints saved from explicitly
    # defined blocks.
    for i in range(1, 10):
      setattr(self, 'r256_{}'.format(i), self._residual_block(256))

    # Two stride-2 upsampling transposed convs
    self.u128 = nn.Sequential(
        nn.ConvTranspose2d(256, 128, 3, 2, 1, output_padding=1),
        nn.InstanceNorm2d(128),
        nn.ReLU(True)
    )

    self.u64 = nn.Sequential(
        nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1),
        nn.InstanceNorm2d(64),
        nn.ReLU(True)
    )

    # Final 7x7 conv to 3 channels, squashed to [-1, 1]
    self.c7s1_3 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(64, 3, 7, 1, 0),
        nn.Tanh()
    )

    self.weight_init()



  @staticmethod
  def _residual_block(channels):
    """Return the pad-conv-IN-ReLU-pad-conv-IN body of one residual block.

    The skip connection (adding the block input back) is applied in forward().
    """
    return nn.Sequential(
        nn.ReflectionPad2d(1),
        nn.Conv2d(channels, channels, 3, 1, 0),
        nn.InstanceNorm2d(channels),
        nn.ReLU(True),
        nn.ReflectionPad2d(1),
        nn.Conv2d(channels, channels, 3, 1, 0),
        nn.InstanceNorm2d(channels)
    )



  def forward(self, input):
    """Translate `input` (N, 3, H, W) into an (N, 3, H, W) image in [-1, 1]."""
    x = self.c7s1_64(input)

    x = self.d128(x)
    x = self.d256(x)

    # Residual blocks: x <- block(x) + x
    for i in range(1, 10):
      x = getattr(self, 'r256_{}'.format(i))(x) + x

    x = self.u128(x)
    x = self.u64(x)

    out = self.c7s1_3(x)

    return out



  def weight_init(self):
    """Draw every Conv2d / ConvTranspose2d weight from N(0, 0.02) and zero its bias.

    Walks all sub-modules recursively (self.modules()) instead of relying on a
    bare `except` to skip non-iterable children, so real errors are no longer
    swallowed.
    """
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0.0, 0.02)
        # bias=False layers have m.bias == None; `m.bias.data` would raise.
        if m.bias is not None:
          m.bias.data.zero_()
        print(m, ': Weights initialized')

# **Test**

## On pictures in folders

In [None]:
# Root folder holding one sub-folder per test set.
PATH = '/content/gdrive/MyDrive/VanGoghGAN_Landscapes/Test'
# Test sets to process, listed explicitly; sorted(os.listdir(PATH)) would
# instead pick up every entry found under PATH.
folders = [
    'Constable',
    'Drawings',
    'Lucien Pissarro',
    'Matisse',
    'Photos',
    'Rousseau',
    'Sargent',
    'Seurat',
    'Van Gogh Drawings',
]
print(folders)

['Constable', 'Drawings', 'Lucien Pissarro', 'Matisse', 'Photos', 'Rousseau', 'Sargent', 'Seurat', 'Van Gogh Drawings']


#### Pix2Pix

In [None]:
# to hide output of this cell
%%capture

CHECKPOINTS_FOLDER = '/content/gdrive/MyDrive/VanGoghGAN_Landscapes/Pix2Pix_Checkpoints'

pix2pix = Pix2Pix()
pix2pix.to(device)

epoch = 20
pix2pix.load_state_dict(torch.load(CHECKPOINTS_FOLDER+'/generator_{}.ckpt'.format(epoch), map_location=torch.device(device)))

In [None]:
# Translate every test sketch with Pix2Pix and save the result under Images_Pix2Pix.
# NOTE(review): sketch i of '<folder> Sketches.pt' is assumed to belong to the i-th
# entry of os.listdir(<folder>/Originals); os.listdir order is arbitrary, so confirm
# it matches how the .pt files were built.
with torch.no_grad():  # inference only: don't build autograd graphs
  for folder in folders:
    print(folder+'\n')
    folder_path = os.path.join(PATH, folder)
    # map_location='cpu' lets tensors saved on a GPU load on a CPU-only runtime;
    # each slice is moved to `device` below.
    sketch_dataset = torch.load(os.path.join(folder_path, folder + ' Sketches.pt'), map_location='cpu')
    out_dir = os.path.join(folder_path, 'Images_Pix2Pix')
    os.makedirs(out_dir, exist_ok=True)
    for (i, filename) in enumerate(os.listdir(os.path.join(folder_path, 'Originals'))):
      pred = scale_0_1(pix2pix(sketch_dataset[i:i+1].to(device)).squeeze())
      image = transforms.ToPILImage()(pred).convert('RGB')
      display(image)
      print(filename+'\n')
      image.save(os.path.join(out_dir, filename))

#### StylePix2Pix

In [None]:
# to hide output of this cell
%%capture

CHECKPOINTS_FOLDER = '/content/gdrive/MyDrive/VanGoghGAN_Landscapes/StylePix2Pix_Checkpoints'

stylepix2pix = StylePix2Pix()
stylepix2pix.to(device)

epoch = 50
stylepix2pix.load_state_dict(torch.load(CHECKPOINTS_FOLDER+'/generator_{}.ckpt'.format(epoch), map_location=torch.device(device)))

In [None]:
# (style label, output-file suffix) for each style the model is conditioned on.
STYLE_LABELS = [
    (0, ' 0_Monet.jpg'),
    (1, ' 1_Van_Gogh.jpg'),
    (2, ' 2_Corot_Shishkin.jpg'),
]

# Translate every test sketch into each style and save the results under
# Images_StylePix2Pix as '<original name without extension><suffix>'.
# NOTE(review): sketch i of '<folder> Sketches.pt' is assumed to belong to the i-th
# entry of os.listdir(<folder>/Originals); os.listdir order is arbitrary, so confirm
# it matches how the .pt files were built.
with torch.no_grad():  # inference only: don't build autograd graphs
  for folder in folders:
    print(folder+'\n')
    folder_path = os.path.join(PATH, folder)
    # map_location='cpu' lets tensors saved on a GPU load on a CPU-only runtime.
    sketch_dataset = torch.load(os.path.join(folder_path, folder + ' Sketches.pt'), map_location='cpu')
    out_dir = os.path.join(folder_path, 'Images_StylePix2Pix')
    os.makedirs(out_dir, exist_ok=True)
    for (i, filename) in enumerate(os.listdir(os.path.join(folder_path, 'Originals'))):
      sketch = sketch_dataset[i:i+1].to(device)
      # splitext handles any extension length (filename[:-4] assumed 3 letters).
      stem = os.path.splitext(filename)[0]
      for label, suffix in STYLE_LABELS:
        pred = scale_0_1(stylepix2pix(sketch, torch.LongTensor([label]).to(device)).squeeze())
        image = transforms.ToPILImage()(pred).convert('RGB')
        display(image)
        print(filename+'\n')
        image.save(os.path.join(out_dir, stem + suffix))

Output hidden; open in https://colab.research.google.com to view.

#### CycleGAN

In [None]:
# to hide output of this cell
%%capture

CHECKPOINTS_FOLDER = '/content/gdrive/MyDrive/VanGoghGAN_Landscapes/CycleGAN_Checkpoints'

cyclegan = CycleGAN()
cyclegan.to(device)

epoch = 102
cyclegan.load_state_dict(torch.load(CHECKPOINTS_FOLDER+'/generatorS2P_{}.ckpt'.format(epoch), map_location=torch.device(device)))

In [None]:
# Translate every 3-channel test sketch with CycleGAN and save the result under Images_CycleGAN.
# NOTE(review): sketch i of '<folder> Sketches 3D.pt' is assumed to belong to the i-th
# entry of os.listdir(<folder>/Originals); os.listdir order is arbitrary, so confirm
# it matches how the .pt files were built.
with torch.no_grad():  # inference only: don't build autograd graphs
  for folder in folders:
    print(folder+'\n')
    folder_path = os.path.join(PATH, folder)
    # map_location='cpu' lets tensors saved on a GPU load on a CPU-only runtime;
    # each slice is moved to `device` below.
    sketch_dataset = torch.load(os.path.join(folder_path, folder + ' Sketches 3D.pt'), map_location='cpu')
    out_dir = os.path.join(folder_path, 'Images_CycleGAN')
    os.makedirs(out_dir, exist_ok=True)
    for (i, filename) in enumerate(os.listdir(os.path.join(folder_path, 'Originals'))):
      pred = scale_0_1(cyclegan(sketch_dataset[i:i+1].to(device)).squeeze())
      image = transforms.ToPILImage()(pred).convert('RGB')
      display(image)
      print(filename+'\n')
      image.save(os.path.join(out_dir, filename))