In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import datetime

import torch
from torch import nn
from torch.nn import functional as F

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets

from PIL import Image

import datetime
import sys
sys.path.append('..')

In [None]:
# Run configuration: every knob a reader might tune lives in this dict.
args = {
    'epochs': 200,
    'img_size': 256,         # images and masks are resized to img_size x img_size
    'batch_size': 32,
    'cuda': torch.cuda.is_available(),
    'sample_interval': 20,   # log + plot samples every `sample_interval` batches
    'checkpoint_interval': -1,
    'dataset_name': 'facade',
    'seed': 42,              # global RNG seed, for (more) reproducible runs

    # Adam hyper-parameters
    'lr': 0.0002,
    'b1': 0.5,
    'b2': 0.999,
}

# Derive the device from the flag above so the two can never disagree.
args['device'] = torch.device('cuda' if args['cuda'] else 'cpu')

# Seed the global RNGs. The torch RNG drives weight initialisation and
# DataLoader shuffling; GPU kernels may still be non-deterministic.
torch.manual_seed(args['seed'])
np.random.seed(args['seed'])

print(args['device'])

In [None]:
# ## Download dos dados Facades (descomente as linhas abaixo para baixar)
# !mkdir 'facades'

# !wget -N https://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_base.zip
# !unzip -o CMP_facade_DB_base.zip -d ./facades/

# !wget -N https://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_extended.zip
# !unzip -o CMP_facade_DB_extended.zip -d ./facades/

In [None]:
# Função que recebe um "mapa de classes", como os mostrados acima, para uma imagem colorida
# cada classe é mapeada para uma cor diferente, de acordo com o dicionário "color_map"
def label2color_map(mask, color_map):
    h, w = mask.shape
    mask_img = np.zeros(shape=[h, w, 3])

    for i in np.unique(mask):
        mask_img[np.where(mask == i)] = color_map[i]

    return (mask_img.transpose(2, 0, 1) / 127.5) - 1.0

# Implementando um Dataset personalizado para ler os dados que acabamos de fazer o download
class FacadesDataset(Dataset):
    def __init__(self, root = '/pgeoprj/ciag2023/datasets/facades/base/', transforms_= None, max_label=12):
        self.transform = transforms_
        self.root = root
        self.max_label = max_label

        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize( mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5) )
        ])

        ## Aqui criamos um mapa de cores, que determina a cor que cada classe tem na máscara de classes
        ## Esses valores podem ser alterados livremente, porém vale a pena sempre verificar se não existem
        ## duas classes diferentes com a mesma cor (ou cores muito parecidas.)
        self.color_maps = {
            1: [0, 0, 0],
            2: [254, 127, 45],
            3: [252, 202, 70],
            4: [161, 193, 129],
            5: [97, 155, 138],
            6: [35, 61, 77],
            7: [120, 0, 0],
            8: [193, 18, 31],
            9: [253, 240, 213],
            10: [102, 155, 188],
            11: [254, 109, 115],
            12: [106, 153, 78]
        }

        files_list = os.listdir(root)
        files_list = [s.split('.')[0] for s in files_list]
        self.files_list = np.unique(files_list)[1:]

    def __getitem__(self, idx):
        img_file = self.files_list[idx] + '.jpg'
        mask_file = self.files_list[idx] + '.png'

        img = Image.open( os.path.join(self.root, img_file) )
        mask = Image.open( os.path.join(self.root, mask_file) )

        if self.transform is not None:
            img = self.normalize( self.transform(img) )
            mask = self.transform(mask)

        img = np.array(img)
        mask = label2color_map(np.array(mask), self.color_maps)

        return img.astype(np.float32), mask.astype(np.float32)

    def __len__(self):
        return len(self.files_list)

transforms_ = transforms.Compose([
    transforms.Resize( size=(args['img_size'], args['img_size']) ),
])

facades_data = FacadesDataset(transforms_=transforms_)

In [None]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


In [None]:
facades_dataloader = DataLoader(dataset = facades_data,
                                batch_size = args['batch_size'],
                                shuffle = True,
                                num_workers = 1)

In [None]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

In [None]:
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Calculate output of image discriminator (PatchGAN)
patch = (1, args['img_size'] // 2 ** 4, args['img_size'] // 2 ** 4)

In [None]:
from models import GeneratorUNet, Discriminator
generator = GeneratorUNet().to(args['device'])
discriminator = Discriminator().to(args['device'])

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal);

In [None]:
def sample_images(data_loader):
    """Saves a generated sample from the validation set"""
    imgs_real, mask = next( iter(data_loader) )
    imgs_real = imgs_real.to(args['device'])
    mask = mask.to(args['device'])

    imgs_fake = generator(mask)

    fig, ax = plt.subplots(nrows=min(imgs_real.size(0), 2), ncols=3, figsize=(12, 8))

    for i in range( min(imgs_real.size(0), 2) ):

        ax[i, 0].imshow(imgs_real.data[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        ax[i, 0].set_yticks([])
        ax[i, 0].set_xticks([])
        ax[i, 0].set_title('Goal')

        ax[i, 1].imshow(imgs_fake.data[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        ax[i, 1].set_yticks([])
        ax[i, 1].set_xticks([])
        ax[i, 1].set_title('Generated')

        ax[i, 2].imshow(mask.data[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        ax[i, 2].set_yticks([])
        ax[i, 2].set_xticks([])
        ax[i, 2].set_title('Mask')

    plt.show()

In [None]:
sample_images(facades_dataloader)

In [None]:
# Optimizers
optimizer_G = torch.optim.Adam( generator.parameters(),
                               lr = args['lr'],
                               betas = (args['b1'], args['b2']) )

optimizer_D = torch.optim.Adam( discriminator.parameters(),
                               lr = args['lr'],
                               betas = (args['b1'], args['b2']) )

In [None]:
################################################################################
#  Training ####################################################################
################################################################################

prev_time = time.time()

for epoch in range(1, args['epochs'] + 1):

    for i, (real_img, mask) in enumerate(facades_dataloader):

        # Transfer images and masks to GPU
        real_img = real_img.to(args['device'])
        mask = mask.to(args['device'])

        # Adversarial ground truths
        y_true = torch.ones(size=(real_img.size(0), *patch), requires_grad=False).to(args['device'])
        y_fake = torch.zeros(size=(real_img.size(0), *patch), requires_grad=False).to(args['device'])

        # ------------------
        #  Train Generators
        # ------------------

        # Clearing gradients for G optimizer.
        optimizer_G.zero_grad()

        # GAN loss.
        fake_img = generator(mask)
        pred_fake = discriminator(fake_img, mask)
        loss_GAN = criterion_GAN(pred_fake, y_true)

        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_img, real_img)

        # Total loss
        loss_G = (loss_GAN + lambda_pixel * loss_pixel) / (1 + lambda_pixel)

        # G backward and optimizer step.
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Clearing gradients for D optimizer.
        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_img, mask)
        loss_real = criterion_GAN(pred_real, y_true)

        # Fake loss
        pred_fake = discriminator(fake_img.detach(), mask)
        loss_fake = criterion_GAN(pred_fake, y_fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        # D backward and optimizer step.
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(facades_dataloader) + i
        batches_left = args['epochs'] * len(facades_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # If at sample interval save image
        if batches_done % args['sample_interval'] == 0:

            # Print log
            print(f'[Epoch {epoch}/{ args["epochs"] }] [Batch {i}/{len(facades_dataloader)}] [D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}, pixel: {loss_pixel.item():.4f}, adv: {loss_GAN.item():.4f}] ETA: {time_left}')
            sample_images(facades_dataloader)