**PatchGan**

In [2]:
import torch
import torch.nn as nn

class Block(nn.Module):
  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )
  def forward(self, x):
    return self.conv(x)
class Discriminator(nn.Module):
  def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.init = nn.Sequential(
        nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(features[0], inplace=True),
    )
    in_channels=features[0]
    layers=[]
    for i in features[1:]:
      layers.append(Block(in_channels, i, stride=1 if i == features[-1] else 2))
      in_channels = i
    
    layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")) #BatchSize x 30 x 30
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    x = self.init(x)
    x = self.model(x)
    return torch.sigmoid(x)




In [3]:
def test():
  x = torch.randn([1,3,256,256])
  y = torch.randn([1,3,256,256])
  model = Discriminator()
  preds = model(x)
  return preds.shape

if __name__ == "__main__":
  print(test())

torch.Size([1, 1, 30, 30])


**Generator**

In [4]:
class CNNBlock(nn.Module):
  def __init__(self, in_channels, out_channels, down=True, act=True, **kwargs):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, **kwargs) if down
        else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(inplace=True) if act else nn.Identity(),
    )

  def forward(self, x):
    return self.conv(x)

class Residual(nn.Module):
  def __init__(self, channel):
      super().__init__()
      self.res = nn.Sequential(
          CNNBlock(channel, channel, kernel_size=3, stride=1, padding=1),
          CNNBlock(channel, channel, act=False, kernel_size=3, stride=1, padding=1, padding_mode="reflect"),
      )
  def forward(self, x):
    x = x + self.res(x)
    return x

class Generator(nn.Module):
  def __init__(self, img_channels, num_res=9):
      super().__init__()
      self.init = nn.Sequential(
          nn.Conv2d(img_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
          nn.InstanceNorm2d(64),
          nn.ReLU(inplace=True),
      )

      self.down = nn.ModuleList(
          [
           CNNBlock(64, 128, kernel_size=3, stride=2, padding=1, padding_mode="reflect"),
           CNNBlock(128, 256, kernel_size=3, stride=2, padding=1, padding_mode="reflect"),
          ]
      ) 

      self.resi = nn.Sequential(
          *[Residual(256) for _ in range(num_res)]
      )

      self.up = nn.ModuleList(
          [
           CNNBlock(256, 128, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
           CNNBlock(128, 64, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
          ]
      ) 

      self.last = nn.Conv2d(64, img_channels,kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    x = self.init(x)
    for down_layer in self.down:
      x = down_layer(x)
    x = self.resi(x)
    for up_layer in self.up:
      x = up_layer(x)
    x = self.last(x)
    return torch.tanh(x)

def test_gen():
  x = torch.randn([1,3,256,256])
  model = Generator(img_channels=3, num_res=9)
  preds = model(x)
  print(model)
  print(preds.shape)

if __name__ == "__main__":
  test_gen()


Generator(
  (init): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), padding_mode=reflect)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
  )
  (down): ModuleList(
    (0): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
    (1): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
  )
  (resi): Sequential(
    (0): Residual(
      (res): Sequential(
        (0): CNNBlock(
          (conv): Sequential(
            (0)

In [5]:
!pip install albumentations==1.0.3

Collecting albumentations==1.0.3
  Downloading albumentations-1.0.3-py3-none-any.whl (98 kB)
[?25l[K     |███▎                            | 10 kB 38.7 MB/s eta 0:00:01[K     |██████▋                         | 20 kB 15.8 MB/s eta 0:00:01[K     |██████████                      | 30 kB 9.5 MB/s eta 0:00:01[K     |█████████████▎                  | 40 kB 8.0 MB/s eta 0:00:01[K     |████████████████▋               | 51 kB 4.3 MB/s eta 0:00:01[K     |████████████████████            | 61 kB 5.1 MB/s eta 0:00:01[K     |███████████████████████▎        | 71 kB 5.2 MB/s eta 0:00:01[K     |██████████████████████████▌     | 81 kB 5.3 MB/s eta 0:00:01[K     |█████████████████████████████▉  | 92 kB 5.9 MB/s eta 0:00:01[K     |████████████████████████████████| 98 kB 4.2 MB/s 
Collecting opencv-python-headless>=4.1.1
  Downloading opencv_python_headless-4.5.5.64-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (47.8 MB)
[K     |████████████████████████████████| 47.8 MB 130 

In [6]:
!pip install opencv-python-headless==4.1.2.30

Collecting opencv-python-headless==4.1.2.30
  Downloading opencv_python_headless-4.1.2.30-cp37-cp37m-manylinux1_x86_64.whl (21.8 MB)
[K     |████████████████████████████████| 21.8 MB 4.4 MB/s 
Installing collected packages: opencv-python-headless
  Attempting uninstall: opencv-python-headless
    Found existing installation: opencv-python-headless 4.5.5.64
    Uninstalling opencv-python-headless-4.5.5.64:
      Successfully uninstalled opencv-python-headless-4.5.5.64
Successfully installed opencv-python-headless-4.1.2.30


In [7]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

**DataSet**

In [24]:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "../content/drive/MyDrive/AI/maps/maps/train"
VAL_DIR = "../content/drive/MyDrive/AI/maps/maps/val"
BATCH_SIZE = 16
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4

L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = True
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

both_transform = A.Compose( 
    [A.Resize(width=256, height=256),
    A.HorizontalFlip(p=0.5),],
    additional_targets={"image0":"image"},)
transform_only_input = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)


In [25]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import sys
class MapDataset():
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.list_files =os.listdir(self.root_dir)
    print(self.list_files)
  def __len__(self):
    return(len(self.list_files))
  def __getitem__(self, index):
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    image = np.array(Image.open(img_path))
    input_image = image[:,:600,:] #channel, width, height
    target_image = image[:,600:,:]
    augmentations = both_transform(image=input_image, image0=target_image)
    input_image, target_image=augmentations["image"], augmentations["image0"]
    input_image = transform_only_input(image=input_image)["image"]
    target_image = transform_only_mask(image=target_image)["image"]
    return input_image, target_image




In [26]:
if __name__ == "__main__":
    dataset = MapDataset("../content/drive/MyDrive/AI/maps/maps/train")
    loader = DataLoader(dataset, batch_size=5)
    for x,y in loader:
      print(x.shape)
      save_image(x, "x.png")
      save_image(y, "y.png")
      sys.exit()

['1066.jpg', '1080.jpg', '1095.jpg', '149.jpg', '1032.jpg', '168.jpg', '143.jpg', '1058.jpg', '139.jpg', '1086.jpg', '1012.jpg', '1079.jpg', '1006.jpg', '11.jpg', '117.jpg', '1085.jpg', '119.jpg', '1048.jpg', '1040.jpg', '137.jpg', '1017.jpg', '1090.jpg', '106.jpg', '1023.jpg', '1077.jpg', '1074.jpg', '1073.jpg', '153.jpg', '1071.jpg', '1061.jpg', '108.jpg', '157.jpg', '1087.jpg', '1091.jpg', '112.jpg', '123.jpg', '128.jpg', '10.jpg', '1045.jpg', '1037.jpg', '1069.jpg', '1083.jpg', '1094.jpg', '1036.jpg', '1.jpg', '1056.jpg', '166.jpg', '113.jpg', '161.jpg', '141.jpg', '145.jpg', '150.jpg', '1000.jpg', '15.jpg', '1075.jpg', '1038.jpg', '1020.jpg', '100.jpg', '1051.jpg', '1049.jpg', '1053.jpg', '1039.jpg', '12.jpg', '167.jpg', '110.jpg', '105.jpg', '124.jpg', '1005.jpg', '1025.jpg', '133.jpg', '1008.jpg', '1022.jpg', '154.jpg', '104.jpg', '101.jpg', '1076.jpg', '121.jpg', '107.jpg', '132.jpg', '158.jpg', '204.jpg', '317.jpg', '327.jpg', '246.jpg', '173.jpg', '29.jpg', '297.jpg', '32.jpg

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [27]:
# example of calculating the frechet inception distance in Keras for cifar10
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import shuffle
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from skimage.transform import resize
def scale_images(images, new_shape):
	images_list = list()
	for image in images:
		# resize with nearest neighbor interpolation
		new_image = resize(image, new_shape, 0)
		# store
		images_list.append(new_image)
	return asarray(images_list)
 
# calculate frechet inception distance
def calculate_fid(model, images1, images2):
	# calculate activations
	act1 = model.predict(images1)
	act2 = model.predict(images2)
	# calculate mean and covariance statistics
	mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
	mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
	# calculate sum squared difference between means
	ssdiff = numpy.sum((mu1 - mu2)**2.0)
	# calculate sqrt of product between cov
	covmean = sqrtm(sigma1.dot(sigma2))
	# check and correct imaginary numbers from sqrt
	if iscomplexobj(covmean):
		covmean = covmean.real
	# calculate score
	fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
	return fid
 


**train**


In [28]:
import torch.optim as optim
from tqdm import tqdm
import torch
torch.cuda.empty_cache()

def train_fn(dis_X, dis_Y, gen_X, gen_Y, opt_dis, opt_gen, L1_LOSS, MSE_LOSS, train_loader, g_scaler, d_scaler):
  loop = tqdm(train_loader, leave=True)
  
  for idx, (x, y) in enumerate(loop):
    x = x.to(DEVICE)
    y = y.to(DEVICE)


    #train discriminator of x and y

    with torch.cuda.amp.autocast():
      fake_y = gen_Y(x)

      D_Y_real = dis_Y(y)
      D_Y_fake = dis_Y(fake_y.detach())

      D_Y_real_loss = MSE_LOSS(D_Y_real, torch.ones_like(D_Y_real))
      D_Y_fake_loss = MSE_LOSS(D_Y_fake, torch.zeros_like(D_Y_fake))

      D_Y_loss = D_Y_real_loss + D_Y_fake_loss

      fake_x = gen_X(y)

      D_X_real = dis_X(x)
      D_X_fake = dis_X(fake_x.detach())

      D_X_real_loss = MSE_LOSS(D_X_real, torch.ones_like(D_X_real))
      D_X_fake_loss = MSE_LOSS(D_X_fake, torch.zeros_like(D_X_fake))

      D_X_loss = D_X_real_loss + D_X_fake_loss

      D_loss = (D_Y_loss + D_X_loss)/2

    opt_dis.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_dis)
    d_scaler.update()


    #Generator Loss
    with torch.cuda.amp.autocast():
      D_Y_fake = dis_Y(fake_y)
      D_X_fake = dis_X(fake_x)

      loss_G_X = MSE_LOSS(D_X_fake, torch.ones_like(D_X_fake))
      loss_G_Y = MSE_LOSS(D_Y_fake, torch.ones_like(D_Y_fake))


      #cycle loss
      cycle_x = gen_X(fake_y)
      cycle_y = gen_Y(fake_x)

      CYCLE_X_loss = L1_LOSS(x, cycle_x)
      CYCLE_Y_loss = L1_LOSS(y, cycle_y)



      G_loss = (loss_G_X + loss_G_Y + CYCLE_X_loss * LAMBDA_CYCLE + CYCLE_Y_loss * LAMBDA_CYCLE)
    
    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    if idx % 50 == 0:
      save_image(fake_y*0.5+0.5, f"../content/drive/MyDrive/AI/maps/cycleganeval/fake_y_{idx}.png")
      save_image(y*0.5+0.5, f"../content/drive/MyDrive/AI/maps/cycleganeval/real_y_{idx}.png")
      save_image(x*0.5+0.5, f"../content/drive/MyDrive/AI/maps/cycleganeval/x_{idx}.png")


def train():
  dis_X = Discriminator(in_channels=3).to(DEVICE)
  dis_Y = Discriminator(in_channels=3).to(DEVICE)

  gen_X = Generator(img_channels=3, num_res=9).to(DEVICE)
  gen_Y = Generator(img_channels=3, num_res=9).to(DEVICE)

  opt_dis = optim.Adam(
        list(dis_X.parameters()) + list(dis_Y.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

  opt_gen = optim.Adam(
        list(gen_X.parameters()) + list(gen_Y.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
  L1_LOSS = nn.L1Loss()
  MSE_LOSS = nn.MSELoss()

  train_dataset = MapDataset(root_dir=TRAIN_DIR)
  train_loader = DataLoader(
      train_dataset,
      batch_size=BATCH_SIZE,
      shuffle=True,
      num_workers=NUM_WORKERS,
  )

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  val_dataset = MapDataset(root_dir=VAL_DIR)
  val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


  for epoch in range(NUM_EPOCHS):
    train_fn(dis_X, dis_Y, gen_X, gen_Y, opt_dis, opt_gen, L1_LOSS, MSE_LOSS, train_loader, g_scaler, d_scaler)

if __name__ == "__main__":
  train()

['1066.jpg', '1080.jpg', '1095.jpg', '149.jpg', '1032.jpg', '168.jpg', '143.jpg', '1058.jpg', '139.jpg', '1086.jpg', '1012.jpg', '1079.jpg', '1006.jpg', '11.jpg', '117.jpg', '1085.jpg', '119.jpg', '1048.jpg', '1040.jpg', '137.jpg', '1017.jpg', '1090.jpg', '106.jpg', '1023.jpg', '1077.jpg', '1074.jpg', '1073.jpg', '153.jpg', '1071.jpg', '1061.jpg', '108.jpg', '157.jpg', '1087.jpg', '1091.jpg', '112.jpg', '123.jpg', '128.jpg', '10.jpg', '1045.jpg', '1037.jpg', '1069.jpg', '1083.jpg', '1094.jpg', '1036.jpg', '1.jpg', '1056.jpg', '166.jpg', '113.jpg', '161.jpg', '141.jpg', '145.jpg', '150.jpg', '1000.jpg', '15.jpg', '1075.jpg', '1038.jpg', '1020.jpg', '100.jpg', '1051.jpg', '1049.jpg', '1053.jpg', '1039.jpg', '12.jpg', '167.jpg', '110.jpg', '105.jpg', '124.jpg', '1005.jpg', '1025.jpg', '133.jpg', '1008.jpg', '1022.jpg', '154.jpg', '104.jpg', '101.jpg', '1076.jpg', '121.jpg', '107.jpg', '132.jpg', '158.jpg', '204.jpg', '317.jpg', '327.jpg', '246.jpg', '173.jpg', '29.jpg', '297.jpg', '32.jpg

100%|██████████| 69/69 [00:27<00:00,  2.50it/s]
100%|██████████| 69/69 [00:27<00:00,  2.54it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.54it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.52it/s]
100%|██████████| 69/69 [00:27<00:00,  2.53it/s]
100%|██████████| 69/69 [00:27<00:00,  2.

In [29]:
path_fake_y = "../content/drive/MyDrive/AI/maps/cycleganeval/fake_y_0.png"
path_real_y = "../content/drive/MyDrive/AI/maps/cycleganeval/real_y_0.png"   
# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))
images1 = np.array(Image.open(path_fake_y)).astype('float32')
images2 = np.array(Image.open(path_real_y)).astype('float32')
# resize images
images1 = scale_images(images1, (299,299,3))
images2 = scale_images(images2, (299,299,3))
# pre-process images
images1 = preprocess_input(images1)
images2 = preprocess_input(images2)
# calculate fid
fid = calculate_fid(model, images1, images2)

print('FID: %.3f' % fid)

FID: 22.973


In [None]:
from google.colab import files
import matplotlib.pyplot as plt


files.download("../content/saved_images/y_0.png")  

files.download("../content/saved_images/y_50.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>