## Edge smoothing

In [None]:
from GAN.edge_smooth import make_edge_smooth

# Pre-process the style images in `style_dir` with make_edge_smooth (GAN.edge_smooth).
# NOTE(review): presumably this produces the edge-smoothed images that the dataset
# later yields as `neon_smt_gray` — confirm against GAN.edge_smooth.
# NOTE(review): this path is absolute ('/TRAINING/...') while args.data_dir further
# down is relative ('TRAINING'); they only refer to the same place when the
# notebook's working directory is '/'.
style_dir = '/TRAINING/neon_img/neon_img_resize'
size = 256  # image size handed to make_edge_smooth — assumed pixels, TODO confirm

make_edge_smooth(style_dir, size)

## Dataloader

In [None]:
from GAN.dataset import NeonDataSet
import argparse
from multiprocessing import cpu_count
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch

In [None]:
# Training configuration. parse_args(args=[]) ignores the real command line, so every
# value below is its default unless edited here (needed inside a notebook kernel).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='neon_img', type=str)
parser.add_argument('--data_dir', default='TRAINING', type=str)
# The underscore and dash spellings below share one dest ('batch_size' /
# 'debug_samples'). They used to be registered as two separate arguments, with
# '--batch-size' defaulting to 6. argparse applies the first-registered default,
# so the effective default was already 4 and the second default was dead code.
# One argument with both option strings keeps both spellings working.
parser.add_argument('--batch_size', '--batch-size', default=4, type=int)
parser.add_argument('--debug_samples', '--debug-samples', default=0, type=int)
# NOTE(review): not read by the DataLoader cell, whose num_workers is commented out.
parser.add_argument('--num_parallel_workers', default=1, type=int)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--init-epochs', type=int, default=5)
parser.add_argument('--checkpoint-dir', type=str, default='SAVE_POINT/checkpoints')
parser.add_argument('--save-image-dir', type=str, default='SAVE_POINT/images')
parser.add_argument('--gan-loss', type=str, default='lsgan', help='lsgan / hinge / bce')
# NOTE: parsed as a plain string, so the default 'False' is truthy if tested with `if`.
parser.add_argument('--resume', type=str, default='False')
parser.add_argument('--use_sn', action='store_true')
parser.add_argument('--save-interval', type=int, default=1)
parser.add_argument('--lr-g', type=float, default=2e-4)
parser.add_argument('--lr-d', type=float, default=4e-4)
parser.add_argument('--init-lr', type=float, default=1e-3)
parser.add_argument('--wadvg', type=float, default=10.0, help='Adversarial loss weight for G')
parser.add_argument('--wadvd', type=float, default=10.0, help='Adversarial loss weight for D')
parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')
parser.add_argument('--wgra', type=float, default=3.0, help='Gram loss weight')
parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')
parser.add_argument('--d-layers', type=int, default=3, help='Discriminator conv layers')
parser.add_argument('--d-noise', action='store_true')
args = parser.parse_args(args=[])
plt.figure()

def collate_fn(batch):
    """Turn a list of per-sample 4-tuples into a 4-tuple of batched tensors.

    Each sample is expected to be ``(img, neon, neon_gray, neon_smt_gray)``.
    The tensors in each position are stacked along a new leading batch axis.
    A batch whose samples do not have exactly four fields raises ``ValueError``
    when it is unpacked.
    """
    images, neons, neon_grays, neon_smooth_grays = zip(*batch)
    columns = (images, neons, neon_grays, neon_smooth_grays)
    return tuple(torch.stack(column, 0) for column in columns)

In [None]:
# Shuffled, batched loader over NeonDataSet. collate_fn stacks each of the four
# tensors a sample yields (img, neon, neon_gray, neon_smt_gray) into its own batch.
data_loader = DataLoader(
    NeonDataSet(args),
    batch_size=args.batch_size,
    # num_workers is left at DataLoader's default (0: load in the main process).
    # NOTE(review): args.num_parallel_workers is defined but not wired in here.
    #num_workers=cpu_count(),
    pin_memory=True,  # NOTE(review): only helps CUDA transfers; PyTorch warns on CPU-only hosts
    shuffle=True,
    collate_fn=collate_fn,
)

## GAN

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torch.optim as optim

import os
import cv2
import numpy as np
from tqdm import tqdm

from GAN.conv_blocks import DownConv
from GAN.conv_blocks import UpConv
from GAN.conv_blocks import SeparableConv2D
from GAN.conv_blocks import InvertedResBlock
from GAN.conv_blocks import ConvBlock
from GAN.utils import initialize_weights
from GAN.image_processing import denormalize_input
#from utils.common import load_checkpoint
from GAN.utils import save_checkpoint
from GAN.utils import set_lr

from GAN.losses import NeonGanLoss
from GAN.losses import LossSummary

In [None]:
# Mean and standard deviation of the noise optionally added to D's inputs (--d-noise).
gaussian_mean = torch.tensor(0.0)
gaussian_std = torch.tensor(0.1)


def gaussian_noise():
    """Draw one sample from N(gaussian_mean, gaussian_std**2) as a 0-dim tensor.

    The training loop calls this when ``args.d_noise`` is set. It was never
    defined before, so enabling ``--d-noise`` raised ``NameError``. The sample is
    a scalar, so adding it to an image tensor shifts every element by the same
    amount. Being a 0-dim CPU tensor, it can be added to tensors on any device.
    """
    return torch.normal(gaussian_mean, gaussian_std)

In [None]:
class Generator(nn.Module):
    """Image-to-image generator: encoder -> residual stack -> decoder.

    ``encode_blocks`` maps 3 input channels up to 256 feature channels, with
    two DownConv stages. NOTE(review): DownConv/UpConv presumably halve/double
    the spatial size — confirm in GAN.conv_blocks.
    ``res_blocks`` applies eight 256->256 InvertedResBlocks.
    ``decode_blocks`` maps back to 3 channels, with two UpConv stages, and ends
    in a Tanh, so outputs lie in [-1, 1].
    """

    def __init__(self, dataset=''):
        super().__init__()
        self.name = f'generator_{dataset}'
        use_bias = False

        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=use_bias),
            ConvBlock(64, 128, bias=use_bias),
            DownConv(128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            SeparableConv2D(128, 256, bias=use_bias),
            DownConv(256, bias=use_bias),
            ConvBlock(256, 256, bias=use_bias),
        )

        # Eight identical residual blocks. They are built in order, so module
        # indices (and therefore state_dict keys) match an explicit listing.
        self.res_blocks = nn.Sequential(
            *[InvertedResBlock(256, 256, bias=use_bias) for _ in range(8)]
        )

        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            SeparableConv2D(128, 128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            ConvBlock(128, 64, bias=use_bias),
            ConvBlock(64, 64, bias=use_bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=use_bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        features = self.res_blocks(self.encode_blocks(x))
        return self.decode_blocks(features)

class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel score map.

    Architecture, as read from the layer list:
    - A stem conv: 3 -> 32 channels.
    - ``args.d_layers`` stages. Each stage is a stride-2 conv (c -> 2c), then a
      stride-1 conv (2c -> 4c) with InstanceNorm, so each stage multiplies the
      width by 4.
    - A final conv + InstanceNorm, then a 1-channel conv.

    With ``args.use_sn`` every Conv2d is wrapped in spectral normalization.
    """

    def __init__(self,  args):
        super().__init__()
        self.name = f'discriminator_{args.dataset}'
        self.bias = False
        width = 32

        layers = [
            nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]

        for _ in range(args.d_layers):
            next_width = width * 4
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(width * 2, next_width, kernel_size=3, stride=1, padding=1, bias=self.bias),
                nn.InstanceNorm2d(next_width),
                nn.LeakyReLU(0.2, True),
            ])
            width = next_width

        layers.extend([
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.InstanceNorm2d(width),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ])

        if args.use_sn:
            # Wrap convolutions in list order; other layers pass through untouched.
            layers = [
                spectral_norm(layer) if isinstance(layer, nn.Conv2d) else layer
                for layer in layers
            ]

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        return self.discriminate(img)
    
def check_params(args):
    """Validate the training configuration and create missing output folders.

    Raises ``FileNotFoundError`` if ``data_dir/dataset`` is absent. Creates
    ``save_image_dir`` and ``checkpoint_dir``, in that order, when they are
    missing, and prints a notice for each. Finally asserts that ``gan_loss``
    is one of the supported names.
    """
    dataset_path = os.path.join(args.data_dir, args.dataset)
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f'Dataset not found {dataset_path}')

    # Attributes are read one at a time, in the same order as before.
    for attr_name in ('save_image_dir', 'checkpoint_dir'):
        target_dir = getattr(args, attr_name)
        if os.path.exists(target_dir):
            continue
        print(f'* {target_dir} does not exist, creating...')
        os.makedirs(target_dir)

    supported_losses = {'lsgan', 'hinge', 'bce'}
    assert args.gan_loss in supported_losses, f'{args.gan_loss} is not supported'


def save_samples(generator, loader, args, max_imgs=2, subname='gen'):
    '''
    Generate and save up to ``max_imgs`` sample images.

    Runs ``generator`` on batches from ``loader`` until at least ``max_imgs``
    outputs exist. Only the first element of each batch, the input image, is
    used. The first ``max_imgs`` outputs are written as ``{subname}_{i}.jpg``
    into ``args.save_image_dir``.

    Args:
        generator: model to sample from; its train/eval mode is restored on exit.
        loader: iterable of batches whose first element is an image tensor.
        args: namespace providing ``save_image_dir``.
        max_imgs: maximum number of images to write; ``<= 0`` writes nothing.
        subname: file-name prefix for the saved images.
    '''
    if max_imgs <= 0:
        return

    was_training = generator.training
    generator.eval()

    fake_imgs = []
    num_generated = 0
    try:
        with torch.no_grad():
            for img, *_ in loader:
                fake_img = generator(img.to(device)).cpu().numpy()
                # Channel first -> channel last
                fake_img = fake_img.transpose(0, 2, 3, 1)
                fake_imgs.append(denormalize_input(fake_img, dtype=np.int16))
                num_generated += fake_img.shape[0]
                # Stop as soon as enough images exist. The old
                # `max_imgs // batch_size + 1` count could run an extra batch,
                # and the surplus images were never trimmed.
                if num_generated >= max_imgs:
                    break
    finally:
        # Leave the model in whatever mode the caller had it in.
        generator.train(was_training)

    if not fake_imgs:
        # Empty loader: nothing to write (np.concatenate([]) would raise).
        return

    fake_imgs = np.concatenate(fake_imgs, axis=0)[:max_imgs]

    for i, img in enumerate(fake_imgs):
        save_path = os.path.join(args.save_image_dir, f'{subname}_{i}.jpg')
        # Reverse the channel axis for OpenCV — presumably RGB -> BGR; confirm the
        # channel order denormalize_input returns.
        cv2.imwrite(save_path, img[..., ::-1])

In [None]:
# Compute device used by save_samples and the training loop: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
G = Generator(args.dataset).to(device)
D = Discriminator(args).to(device)

loss_tracker = LossSummary()
loss_fn = NeonGanLoss(args)

optimizer_g = optim.Adam(G.parameters(), lr=args.lr_g, betas=(0.5, 0.999))
optimizer_d = optim.Adam(D.parameters(), lr=args.lr_d, betas=(0.5, 0.999))

# cv2.imwrite (used by save_samples) returns False instead of raising when the
# target directory is missing, so create the output directories up front.
# NOTE(review): presumably save_checkpoint writes under args.checkpoint_dir — confirm.
os.makedirs(args.save_image_dir, exist_ok=True)
os.makedirs(args.checkpoint_dir, exist_ok=True)

start_e = 0

for e in range(start_e, args.epochs):
    print(f"Epoch {e}/{args.epochs}")
    bar = tqdm(data_loader)
    G.train()

    if e < args.init_epochs:
        # Warm-up epochs: train G with the VGG content loss only, at init_lr.
        set_lr(optimizer_g, args.init_lr)
        content_loss_total = 0.0
        num_steps = 0
        for img, *_ in bar:
            img = img.to(device)

            optimizer_g.zero_grad()

            fake_img = G(img)
            loss = loss_fn.content_loss_vgg(img, fake_img)
            loss.backward()
            optimizer_g.step()

            # Running mean. Re-summing a growing list each step was O(n^2) per epoch.
            content_loss_total += loss.item()
            num_steps += 1
            avg_content_loss = content_loss_total / num_steps
            bar.set_description(f'[Init Training G] content loss: {avg_content_loss:2f}')

        set_lr(optimizer_g, args.lr_g)
        save_checkpoint(G, optimizer_g, e, args, posfix='_init')
        save_samples(G, data_loader, args, subname='initg')
        continue

    loss_tracker.reset()
    for img, neon, neon_gray, neon_smt_gray in bar:
        # To device
        img = img.to(device)
        neon = neon.to(device)
        neon_gray = neon_gray.to(device)
        neon_smt_gray = neon_smt_gray.to(device)

        # ---------------- TRAIN D ---------------- #
        optimizer_d.zero_grad()
        fake_img = G(img).detach()

        # Optionally add Gaussian noise to D's inputs only. The noise is applied
        # out of place because neon_gray is reused below as G's loss target; an
        # in-place `+=` would perturb that target as well.
        d_fake_in, d_neon_in, d_neon_gray_in, d_neon_smt_gray_in = (
            fake_img, neon, neon_gray, neon_smt_gray)
        if args.d_noise:
            d_fake_in = fake_img + gaussian_noise()
            d_neon_in = neon + gaussian_noise()
            d_neon_gray_in = neon_gray + gaussian_noise()
            d_neon_smt_gray_in = neon_smt_gray + gaussian_noise()

        fake_d = D(d_fake_in)
        real_neon_d = D(d_neon_in)
        real_neon_gray_d = D(d_neon_gray_in)
        real_neon_smt_gray_d = D(d_neon_smt_gray_in)

        loss_d = loss_fn.compute_loss_D(
            fake_d, real_neon_d, real_neon_gray_d, real_neon_smt_gray_d)

        loss_d.backward()
        optimizer_d.step()

        loss_tracker.update_loss_D(loss_d)

        # ---------------- TRAIN G ---------------- #
        optimizer_g.zero_grad()

        fake_img = G(img)
        fake_d = D(fake_img)

        adv_loss, con_loss, gra_loss, col_loss = loss_fn.compute_loss_G(
            fake_img, img, fake_d, neon_gray)

        loss_g = adv_loss + con_loss + gra_loss + col_loss

        loss_g.backward()
        optimizer_g.step()

        loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)

        avg_adv, avg_gram, avg_color, avg_content = loss_tracker.avg_loss_G()
        avg_adv_d = loss_tracker.avg_loss_D()
        bar.set_description(f'loss G: adv {avg_adv:2f} con {avg_content:2f} gram {avg_gram:2f} color {avg_color:2f} / loss D: {avg_adv_d:2f}')

    if e % args.save_interval == 0:
        save_checkpoint(G, optimizer_g, e, args)
        save_checkpoint(D, optimizer_d, e, args)
        save_samples(G, data_loader, args)
