# **Download Dataset**

In [None]:
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz

In [None]:
!tar -xvf maps.tar.gz

In [None]:
!pip install sklego torch_snippets 


# **Import packages**

In [None]:

import itertools
from PIL import Image
from torch_snippets import *
from torchvision.utils import make_grid


In [None]:
from glob import glob

# glob() returns paths in arbitrary, filesystem-dependent order; sorting makes
# the index -> file mapping of the datasets reproducible across runs/machines.
train_set = sorted(glob('maps/train/*.jpg'))
val_set = sorted(glob('maps/val/*.jpg'))

In [None]:
# Report how many image files were found for each split.
n_train, n_val = len(train_set), len(val_set)
print(f'Train set size: {n_train}\n Val set size: {n_val}')

# **Define Dataset class**

In [None]:
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Geometric transforms applied jointly to the input/target pair. "image0" is
# registered as an additional image target so that it goes through the same
# Compose call as "image".
both_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Normalisation settings shared by the two per-image pipelines below.
_NORMALIZE_KWARGS = dict(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    max_pixel_value=255.0,
)

# Input-only pipeline: occasional colour jitter, then normalise and convert
# to a tensor.
transform_only_input = A.Compose([
    A.ColorJitter(p=0.2),
    A.Normalize(**_NORMALIZE_KWARGS),
    ToTensorV2(),
])

# Target-only pipeline: normalise and convert to a tensor (no colour jitter).
transform_only_mask = A.Compose([
    A.Normalize(**_NORMALIZE_KWARGS),
    ToTensorV2(),
])
class MapDataset(Dataset):
    """Paired-image dataset read from side-by-side composite files.

    Every file holds two pictures next to each other: the columns left of
    ``split_col`` are used as the input image and the remaining columns as the
    target image.

    Args:
        images: Sequence of paths to the composite image files.
        split_col: Column index at which each composite is cut. Defaults to
            600, the value that was previously hard-coded.
    """

    def __init__(self, images, split_col=600):
        self.list_files = images
        self.split_col = split_col

    def __len__(self):
        """Return the number of composite files."""
        return len(self.list_files)

    def __getitem__(self, index):
        """Load, split and transform the pair stored at ``index``.

        Returns:
            Tuple ``(input_image, target_image)`` as produced by
            ``transform_only_input`` and ``transform_only_mask``.
        """
        img_path = self.list_files[index]
        # The context manager closes the file handle deterministically.
        # convert("RGB") guarantees an (H, W, 3) array, so the 3-axis slicing
        # and the 3-channel Normalize below also work for grayscale, palette
        # or RGBA files.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        input_image = image[:, :self.split_col, :]
        target_image = image[:, self.split_col:, :]

        # Joint call so that both halves go through the same Compose pass.
        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

In [None]:
# Wrap the file lists in datasets, then in loaders that yield one shuffled
# (input, target) pair per batch.
train_dataset, val_dataset = MapDataset(train_set), MapDataset(val_set)
loader_kwargs = dict(batch_size=1, shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# **Visualize some examples**

In [None]:
# Take up to 16 batches from ONE pass over the loader. Calling
# next(iter(train_loader)) repeatedly would build a fresh iterator for every
# sample and could return the same image several times.
images = []
for samples_input, samples_target in itertools.islice(train_loader, 16):
    images.append(torch.cat([samples_input, samples_target], dim=0))
images = torch.cat(images, dim=0)
# Map the normalised values back with x * 127.5 + 127.5 before plotting, and
# move the channel axis last for display.
images = make_grid((images * 127.5) + 127.5).permute(1, 2, 0)
show(images)

# **Define Models**

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise use the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: shows the selected device as the cell output.
device

In [None]:
# Print the current working directory.
# NOTE(review): presumably a check that generator.py / discriminator.py are
# importable from here — confirm, or remove if it is leftover debugging.
!pwd

In [None]:
import os
# Expose only GPU 0 to this process.
# NOTE(review): CUDA_VISIBLE_DEVICES is normally only honoured when it is set
# before CUDA is first queried/initialised; an earlier cell already calls
# torch.cuda.is_available(). Confirm this assignment still has an effect here,
# or move it to the top of the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
from generator import Generator
from discriminator import Discriminator

# Instantiate both networks with in_channels=3, then move them to `device`.
generator = Generator(in_channels=3, features=64)
discriminator = Discriminator(in_channels=3)
generator, discriminator = generator.to(device), discriminator.to(device)

In [None]:
# Bare expression: displays the generator's repr as the cell output.
generator

In [None]:
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D tensor.

    The tensor is flattened to a 2-D matrix ``F`` of shape ``(a * b, c * d)``
    and the result is ``F @ F.T`` divided by the total number of elements
    ``a * b * c * d``.

    Args:
        input: Tensor of size ``(a, b, c, d)``. The original comments describe
            ``a`` as the batch size, ``b`` as the number of feature maps and
            ``(c, d)`` as the dimensions of one feature map.

    Returns:
        Tensor of shape ``(a * b, a * b)``.

    Raises:
        ValueError: If ``input`` does not have exactly four dimensions (the
            size unpacking fails).
    """
    a, b, c, d = input.size()

    # reshape() rather than view(): view() raises on non-contiguous tensors
    # (e.g. after permute/transpose), reshape() copies only when it has to.
    features = input.reshape(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # 'Normalize' the Gram matrix by dividing by the total number of elements.
    return G.div(a * b * c * d)

class StyleLoss(nn.Module):
    """Loss module comparing the Gram matrices of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input_map, target_map):
        """Return the MSE between the two Gram matrices.

        The Gram matrix of ``target_map`` is detached, so no gradient flows
        into the target.
        """
        gram_input = gram_matrix(input_map)
        gram_target = gram_matrix(target_map).detach()
        return F.mse_loss(gram_input, gram_target)

In [None]:
# Both networks are optimised with Adam using identical hyper-parameters.
adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), **adam_kwargs)
opt_gen = optim.Adam(generator.parameters(), **adam_kwargs)

# Loss functions used by the training steps.
BCE = nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits
L1_LOSS = nn.L1Loss()
Style_loss = StyleLoss()

In [None]:
# Separate gradient scalers: g_scaler is used with the generator optimizer,
# d_scaler with the discriminator optimizer.
g_scaler, d_scaler = torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler()

In [None]:
def weights_init_normal(m):
    """Initialise a single module in place; meant to be passed to ``apply``.

    Modules whose class name contains ``"Conv"`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``"BatchNorm2d"`` get
    weights drawn from N(1, 0.02) and a zero bias. Everything else is left
    untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # Matching is done on the class NAME, so it can also hit modules without
    # a weight tensor (a custom container called e.g. "ConvBlock", or a norm
    # layer built with affine=False). Those are skipped instead of raising.
    has_weight = getattr(m, "weight", None) is not None
    if "Conv" in classname:
        if has_weight:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if has_weight:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
      
def discriminator_train_step(real_src, real_trg, fake_trg):
    """Run one optimisation step of the discriminator and return its loss.

    Args:
        real_src: Source batch, passed as first argument to the discriminator.
        real_trg: Real target batch paired with ``real_src``.
        fake_trg: Generated target batch; it is detached here, so this step
            does not backpropagate into whatever produced it.

    Returns:
        The discriminator loss tensor: the mean of the BCE on the real pair
        (labelled 1) and the BCE on the fake pair (labelled 0).
    """
    with torch.cuda.amp.autocast():
        prediction_real = discriminator(real_src, real_trg)
        error_real = BCE(prediction_real, torch.ones_like(prediction_real))

        prediction_fake = discriminator(real_src, fake_trg.detach())
        # Labels are shaped after prediction_fake itself (the previous code
        # used prediction_real as the template).
        error_fake = BCE(prediction_fake, torch.zeros_like(prediction_fake))
        D_loss = (error_real + error_fake) / 2

    opt_disc.zero_grad()
    # Scale the loss before backward, then step/update through the scaler.
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()
    return D_loss

def generator_train_step(real_src, real_trg, fake_trg):
    """Run one optimisation step of the generator and return its loss.

    The loss is the sum of:
      * BCE between the discriminator output for ``(real_src, fake_trg)`` and
        an all-ones label,
      * ``lambda_pixel`` times the L1 distance between ``fake_trg`` and
        ``real_trg``,
      * the style loss between ``fake_trg`` and ``real_trg`` when
        ``Style_loss`` is not ``None``.

    Args:
        real_src: Source batch, passed as first argument to the discriminator.
        real_trg: Real target batch.
        fake_trg: Generated target batch; it is not detached here, so the
            backward pass reaches whatever produced it.

    Returns:
        The total generator loss tensor.
    """
    with torch.cuda.amp.autocast():
        disc_out = discriminator(real_src, fake_trg)

        adversarial_term = BCE(disc_out, torch.ones_like(disc_out))
        pixel_term = L1_LOSS(fake_trg, real_trg)
        loss_G = adversarial_term + lambda_pixel * pixel_term
        if Style_loss is not None:
            loss_G = loss_G + Style_loss(fake_trg, real_trg)

    opt_gen.zero_grad()
    # Scale the loss before backward, then step/update through the scaler.
    g_scaler.scale(loss_G).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    return loss_G

In [None]:
# Apply weights_init_normal to both networks via Module.apply.
generator.apply(weights_init_normal)
# Last expression of the cell: its return value is shown as the cell output.
discriminator.apply(weights_init_normal)

In [None]:
epochs = 200
lambda_pixel = 100  # weight of the L1 term; read by generator_train_step

log = Report(epochs)
N = len(train_loader)  # batches per epoch (loop-invariant)

for epoch in range(epochs):
    generator.train()
    discriminator.train()
    for bx, (real_src, real_trg) in enumerate(train_loader):
        real_src, real_trg = real_src.to(device), real_trg.to(device)
        fake_trg = generator(real_src)

        # The discriminator step detaches fake_trg; the generator step then
        # reuses the same fake_trg with its graph intact.
        errD = discriminator_train_step(real_src, real_trg, fake_trg)
        errG = generator_train_step(real_src, real_trg, fake_trg)
        log.record(pos=epoch+(1+bx)/N, errD=errD.item(), errG=errG.item(), end='\r')
    log.report_avgs(epoch+1)

    # Every 10th epoch, show a grid of (source, generated) validation pairs.
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            images = []
            # Take up to 16 batches from ONE pass over the loader. Calling
            # next(iter(val_loader)) repeatedly would build a fresh iterator
            # for every sample and could return the same image several times.
            for real_src, real_trg in itertools.islice(val_loader, 16):
                real_src, real_trg = real_src.to(device), real_trg.to(device)
                fake_trg = generator(real_src)
                images.append(torch.cat([real_src, fake_trg], dim=0))
            images = torch.cat(images, dim=0)
            # Map normalised values back with x * 127.5 + 127.5 and move the
            # channel axis last for display.
            images = make_grid((images * 127.5) + 127.5).permute(1, 2, 0)
            show(images)

In [None]:
# Serialise the whole generator object (not just its state_dict).
# NOTE(review): loading such a file presumably requires the Generator class to
# be importable under the same module path — confirm this is intended.
torch.save(generator, 'generator.pt')

In [None]:
# Serialise the whole discriminator object (not just its state_dict).
# NOTE(review): loading such a file presumably requires the Discriminator
# class to be importable under the same module path — confirm this is intended.
torch.save(discriminator, 'discriminator.pt')