In [None]:
# Show GPU status so it is clear whether the training cells below can use CUDA.
!nvidia-smi

In [None]:
# !pip install torchsummary
# !pip install lpips

In [None]:
import sys

# Make the directory one level up importable so `modules` and `utils` resolve;
# the membership check keeps re-runs of this cell from appending duplicates.
if sys.path.count("..") == 0:
    sys.path.append("..")

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary
from hydra.utils import instantiate

import argparse
from datetime import datetime
import os
import yaml
from omegaconf import OmegaConf
from tqdm import tqdm

# from modules.dataset import CityscapesDataset
from modules.dataset import StyleGANFaces, scale_width
from modules.loss import Pix2PixHDLoss
from utils import parse_config, get_lr_lambda, weights_init, freeze_encoder, show_tensor_images

In [None]:
def parse_arguments(argv=None):
    '''Parse the training options.

    Args:
        argv: optional sequence of command-line tokens, e.g. ``['-c', 'cfg.yml', '-r']``.
            Defaults to no tokens rather than ``sys.argv``: inside a Jupyter kernel
            ``sys.argv`` typically holds the kernel's own launch arguments.

    Returns:
        argparse.Namespace with ``config`` (str or None) and ``high_res`` (bool).
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str)
    parser.add_argument('-r', '--high_res', action='store_true', default=False)
    # Always hand argparse an explicit list so it never falls back to sys.argv.
    return parser.parse_args([] if argv is None else list(argv))

## Create a dataloader for the generated faces and their interpolated A→B / B→A results.

In [None]:
import os
from collections.abc import Iterable
import glob
import numpy as np
from PIL import Image


In [None]:
# Four parallel image folders: domain A, domain B, and the A->B / B->A result images.
DATA_ROOT = "../../dataset/image-to-image"

dataset = StyleGANFaces(
    path_A=f"{DATA_ROOT}/trainA",
    path_B=f"{DATA_ROOT}/trainB",
    path_AtoB=f"{DATA_ROOT}/images_AtoB",
    path_BtoA=f"{DATA_ROOT}/images_BtoA",
)

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=True, num_workers=0, drop_last=True,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Grab a single batch to sanity-check the dataset. next(iter(...)) fails right here
# if the loader is empty, instead of leaving img_A undefined for the next cell.
img_A, img_B, img_AtoB, img_BtoA = next(iter(dataloader))
print(img_A.shape, img_B.shape)

In [None]:
# Visual check of the batch loaded above: domain A, domain B, then the A->B and B->A images.
show_tensor_images(img_A)
show_tensor_images(img_B)
show_tensor_images(img_AtoB)
show_tensor_images(img_BtoA)



In [None]:
class Encoder(nn.Module):
    ''' Convolutional encoder-decoder that maps an input stack to a same-size feature map.

    Layout: reflection-padded 7x7 stem conv -> ``n_layers`` stride-2 convs that each halve
    the channel count -> ``n_layers`` stride-2 transposed convs that each double it back ->
    reflection-padded 7x7 output conv followed by ReLU (so outputs are non-negative).
    Every hidden conv is followed by InstanceNorm2d and ReLU.

    Instance-wise average pooling is not implemented: ``forward`` accepts ``inst`` so
    callers using that interface keep working, but ignores it.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the produced feature map.
        base_channels: channel count after the stem conv. Must be at least ``2 ** n_layers``
            so the narrowest stage keeps at least one channel.
        n_layers: number of downsampling stages (and matching upsampling stages). The
            output has the input's spatial size when H and W are divisible by ``2 ** n_layers``.

    Raises:
        ValueError: if ``base_channels`` cannot be halved ``n_layers`` times.
    '''

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        base_channels: int = 64,
        n_layers: int = 4,
    ):
        super().__init__()

        if (base_channels >> n_layers) < 1:
            raise ValueError(
                f'base_channels={base_channels} cannot be halved {n_layers} times '
                'without reaching zero channels'
            )

        self.out_channels = out_channels
        channels = base_channels

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=0),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves the spatial size (rounding up) and the channels.
        for _ in range(n_layers):
            layers += [
                nn.Conv2d(channels, channels // 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2

        # Upsampling: each stage doubles the spatial size and the channels.
        for _ in range(n_layers):
            layers += [
                nn.ConvTranspose2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2

        # After halving with floor division and doubling back, `channels` equals
        # base_channels only when base_channels is divisible by 2 ** n_layers, so the
        # output conv must take the actual post-upsampling width.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_channels, kernel_size=7, padding=0),
            nn.ReLU(inplace=True),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x, inst=None):
        ''' Encodes ``x``; ``inst`` is accepted for interface compatibility and ignored. '''
        return self.layers(x)

In [None]:
from modules.networks import VGG19
import lpips

import matplotlib.pyplot as plt

In [None]:
############# main() ###############
#

args = parse_arguments()
args.config = "notebook_lowres_custom.yml"

with open(args.config, 'r') as f:
    config = yaml.safe_load(f)
    config = OmegaConf.create(config)
    config = parse_config(config)

print(config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# forward_loss and the training/eval cells call `encoder(torch.cat((img_A, img_B), dim=1))`
# and concatenate the result with img_A and img_B for the generator. The encoder is also
# optimized together with the generator and checkpointed via its state_dict. That needs a
# trainable single-input module mapping 6 channels to 3 (the generator takes 9 channels),
# i.e. the `Encoder` defined above. lpips.LPIPS compares two images and returns a
# distance map, so it cannot be called that way; it is only used separately
# (exploration and the LPIPS loss term).
encoder = Encoder(in_channels=6, out_channels=3, base_channels=128).to(device).apply(weights_init)
generator = instantiate(config.generator).to(device).apply(weights_init)
discriminator = instantiate(config.discriminator).to(device).apply(weights_init)

# summary(encoder, (6, 256, 256))
# summary(generator, (9, 256, 256))
# summary(discriminator, (9, 256, 256))

In [None]:
# Exploration: LPIPS (VGG backbone) spatial distance map between an image from B and
# its B->A counterpart. Kept under its own name so it does not replace the trainable
# `encoder` that the optimizers, checkpoints and forward_loss rely on.
lpips_vgg = lpips.LPIPS(net='vgg', spatial=True).to(device)
with torch.no_grad():
    x = lpips_vgg(img_B.to(device), img_BtoA.to(device))

In [None]:
print(x.shape)
print(img_A.shape)
show_tensor_images(x)
# `.detach()` is the supported way to drop autograd history (`.data` bypasses autograd's checks).
plt.imshow(x[0, 0].detach().cpu().numpy())
plt.colorbar();

In [None]:
# Inspect the first channel of the first sample of `x`. The original loop always broke
# after one iteration, and its second `break` was unreachable.
# NOTE(review): for an LPIPS spatial map this is presumably an (H, W) tensor, so `[1:3]`
# selects two rows; confirm this is the intended view.
first_map = x[0][0]
print(first_map.shape)
show_tensor_images(first_map[1:3])

In [None]:
# b_enc = encoder(img_B.cuda())
# print(b_enc.shape)
# print(img_A.shape)

# x = torch.cat([img_A.cuda(), b_enc], dim=1)
# print(x.shape)

# x = generator(x)
# print(x.shape)

# y = discriminator(x)
# print()
# for dis in y:
#     for b in dis:
#         print(b.shape)

# # show_tensor_images(b_enc)
# # show_tensor_images(x)


In [None]:
# Generator optimizer: in the low-res stage the encoder is optimized jointly with the
# generator (generator parameters first); in the high-res stage only the generator is.
g_params = list(generator.parameters())
if not args.high_res:
    g_params += list(encoder.parameters())
g_optimizer = torch.optim.Adam(g_params, **config.optim)
d_optimizer = torch.optim.Adam(list(discriminator.parameters()), **config.optim)

# Both optimizers follow the same epoch-based schedule; each gets its own lambda.
g_scheduler, d_scheduler = (
    torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        get_lr_lambda(config.train.epochs, config.train.decay_after),
    )
    for optimizer in (g_optimizer, d_optimizer)
)



In [None]:
start_epoch = 0
if config.resume_checkpoint is not None:
    # map_location lets a checkpoint written on a GPU machine be resumed on CPU.
    state_dict = torch.load(config.resume_checkpoint, map_location=device)

    encoder.load_state_dict(state_dict['e_model_dict'])
    generator.load_state_dict(state_dict['g_model_dict'])
    discriminator.load_state_dict(state_dict['d_model_dict'])
    g_optimizer.load_state_dict(state_dict['g_optim_dict'])
    d_optimizer.load_state_dict(state_dict['d_optim_dict'])
    # Checkpoints store the index of the epoch that had just finished when they were
    # saved, so training continues with the next one instead of repeating it.
    start_epoch = state_dict['epoch'] + 1
    # NOTE(review): scheduler state is not checkpointed, so the LambdaLR schedules are
    # evaluated from epoch 0 again after resuming.

    msg = 'high-res' if args.high_res else 'low-res'
    print(f'Starting {msg} training from checkpoints')

elif args.high_res:
    # Here the entries are presumably file paths to separately saved state dicts.
    pretrain_ckpt = config.pretrain_checkpoint
    if pretrain_ckpt is not None:
        encoder.load_state_dict(torch.load(pretrain_ckpt['e_model_dict'], map_location=device))
        encoder = freeze_encoder(encoder)
        generator.g1.load_state_dict(torch.load(pretrain_ckpt['g_model_dict'], map_location=device))
        print('Starting high-res training from pretrained low-res checkpoints')
    else:
        print('Starting high-res training from scratch (no valid checkpoint detected)')

else:
    print('Starting low-res training from random initialization')

In [None]:
import lpips

# LPIPS (AlexNet backbone) perceptual metric used as an extra generator term in forward_loss.
# `.to(device)` instead of `.cuda()` so the notebook also runs without a GPU.
lpips_loss = lpips.LPIPS(net='alex').to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.networks import VGG19

# Same CPU fallback as the setup cell; an unconditional 'cuda' breaks CPU-only runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lambda0 = 1.0    # weight of the adversarial term
lambda1 = 10.    # weight of the feature-matching term
lambda2 = 10.    # weight of the VGG perceptual term
norm_weight_to_one = True

vgg = VGG19().to(device)
# One weight per VGG feature map compared in vgg_loss.
# NOTE(review): presumably ordered shallow -> deep; confirm against VGG19's outputs.
vgg_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

# Keep the ratio between the loss terms but rescale so the largest weight is 1.0.
scale = max(lambda0, lambda1, lambda2) if norm_weight_to_one else 1.0

lambda0 = lambda0 / scale
lambda1 = lambda1 / scale
lambda2 = lambda2 / scale


In [None]:
def vgg_loss(x_real, x_fake):
    ''' Weighted L1 distance between VGG features of ``x_real`` and ``x_fake``.

    Features of ``x_real`` are detached, so gradients only flow through ``x_fake``.
    Uses the module-level ``vgg`` network and ``vgg_weights``; pairs beyond the shorter
    of the two feature lists / weight list are ignored (``zip`` semantics).
    '''
    feats_real = vgg(x_real)
    feats_fake = vgg(x_fake)
    weighted_terms = (
        weight * F.l1_loss(real.detach(), fake)
        for real, fake, weight in zip(feats_real, feats_fake, vgg_weights)
    )
    return sum(weighted_terms, 0.0)


def fm_loss(real_preds, fake_preds):
    ''' Feature-matching loss: summed L1 distance between paired discriminator features.

    ``real_preds`` and ``fake_preds`` are nested sequences (one inner sequence of feature
    tensors per discriminator). Real features are detached, so only the fake branch
    receives gradients. Returns 0.0 when there is nothing to compare.
    '''
    return sum(
        (
            F.l1_loss(real_feat.detach(), fake_feat)
            for real_feats, fake_feats in zip(real_preds, fake_preds)
            for real_feat, fake_feat in zip(real_feats, fake_feats)
        ),
        0.0,
    )


def adv_loss(discriminator_preds, is_real):
    ''' Least-squares adversarial loss summed over discriminators.

    Only the last entry of each discriminator's output list (its final prediction) is
    scored, by mean-squared error against an all-ones target when ``is_real`` is true
    and an all-zeros target otherwise.
    '''
    make_target = torch.ones_like if is_real else torch.zeros_like
    final_preds = (outputs[-1] for outputs in discriminator_preds)
    return sum((F.mse_loss(pred, make_target(pred)) for pred in final_preds), 0.0)


def enc_loss(f_map, img_orig):
    ''' Mean absolute (L1) error between the feature map ``f_map`` and ``img_orig``. '''
    return F.l1_loss(input=f_map, target=img_orig)


def forward_loss(img_A, img_B, img_AtoB, encoder, generator, discriminator):
    ''' Computes the generator and discriminator losses for one batch.

    The encoder sees ``img_A`` and ``img_B`` stacked on the channel axis. Its output is
    appended to that stack and fed to the generator, whose target is ``img_AtoB``.
    The discriminator is likewise conditioned on ``img_A`` and ``img_B``.

    Returns:
        (g_loss, d_loss, x_fake): scalar losses and the detached generated image.
    '''
    x_real = img_AtoB

    feature_map = encoder(torch.cat((img_A, img_B), dim=1))
    x_fake = generator(torch.cat((img_A, img_B, feature_map), dim=1))

    # Fake with gradients (for G), fake detached and real (for D).
    fake_preds_for_g = discriminator(torch.cat((img_A, img_B, x_fake), dim=1))
    fake_preds_for_d = discriminator(torch.cat((img_A, img_B, x_fake.detach()), dim=1))
    real_preds_for_d = discriminator(torch.cat((img_A, img_B, x_real.detach()), dim=1))

    g_loss = (
        # The generator wants its fakes classified as real, so its adversarial target is True.
        lambda0 * adv_loss(fake_preds_for_g, True)
        + lambda1 * fm_loss(real_preds_for_d, fake_preds_for_g) / discriminator.n_discriminators
        # vgg_loss(x_real, x_fake) detaches its FIRST argument; the real image must go
        # first or the perceptual term sends no gradient to the generator.
        + lambda2 * vgg_loss(x_real, x_fake)
        # LPIPS is not reduced over the batch (presumably (N, 1, 1, 1)); take the mean so
        # g_loss is a scalar and backward()/item() work for any batch size.
        + 2.0 * lpips_loss(x_fake, img_AtoB).mean()
    )

    d_loss = 0.5 * (
        adv_loss(real_preds_for_d, True)
        + adv_loss(fake_preds_for_d, False)
    )

    return g_loss, d_loss, x_fake.detach()


In [None]:
# Short smoke-test training run using forward_loss on (img_A, img_B, img_AtoB, img_BtoA) batches.
n_epochs = 2  # deliberately small; raise for a longer run
for epoch in range(n_epochs):
    mean_g_loss = 0.0
    mean_d_loss = 0.0
    epoch_steps = 0
    if not args.high_res:
        encoder.train()
    generator.train()
    discriminator.train()

    # One progress bar per epoch, with the epoch in its label. An outer tqdm over epochs
    # plus this bar, both at position 0, would overwrite each other.
    pbar = tqdm(dataloader, position=0, desc=f'epoch {epoch} train [G loss: -.----][D loss: -.----]')
    for batch in pbar:
        img_A, img_B, img_AtoB, img_BtoA = (t.to(device) for t in batch)

        g_loss, d_loss, x_fake = forward_loss(
            img_A, img_B, img_AtoB, encoder, generator, discriminator,
        )

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # d_loss is built only from detached images, so the generator step above does not
        # affect it; zero_grad clears discriminator grads accumulated by g_loss.backward().
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        mean_g_loss += g_loss.item()
        mean_d_loss += d_loss.item()
        epoch_steps += 1

        pbar.set_description(desc=f'epoch {epoch} train [G loss: {mean_g_loss/epoch_steps:.4f}][D loss: {mean_d_loss/epoch_steps:.4f}]')

    g_scheduler.step()
    d_scheduler.step()
    

In [None]:
# Qualitative check: switch the networks to eval mode, keep the batch at index 2 of the
# (shuffled) loader, or the last batch if the loader is shorter, then render inputs,
# targets, the generated image and the encoder feature map.
if not args.high_res:
    encoder.eval()
generator.eval()
discriminator.eval()

for i, batch in enumerate(dataloader):
    img_A, img_B, img_AtoB, img_BtoA = (tensor.to(device) for tensor in batch)
    if i == 2:
        break

with torch.no_grad():
    feature_map = encoder(torch.cat((img_A, img_B), dim=1))
    x_fake = generator(torch.cat((img_A, img_B, feature_map), dim=1))

show_tensor_images(img_A)
show_tensor_images(img_B)
show_tensor_images(img_AtoB)
show_tensor_images(img_BtoA)
show_tensor_images(x_fake)
show_tensor_images(feature_map)


In [None]:
# Weighted LPIPS term for the sample shown above (same 2.0 weight as in forward_loss).
lpips_loss(x_fake, img_AtoB) * 2.0

In [None]:
# Full training run built on this notebook's forward_loss and the
# (img_A, img_B, img_AtoB, img_BtoA) batches yielded by `dataloader`.
# NOTE(review): assumes config.train provides `epochs`, `save_every` and `log_dir`
# (`epochs` and `decay_after` are already read from it for the LR schedulers).
train_config = config.train
log_dir = os.path.join(train_config.log_dir, datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
os.makedirs(log_dir, mode=0o775, exist_ok=True)

use_amp = (device == 'cuda')
# fp16 backward passes under autocast need loss scaling to avoid gradient underflow;
# one scaler may serve both optimizers as long as update() runs once per iteration.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(start_epoch, train_config.epochs):
    mean_g_loss = 0.0
    mean_d_loss = 0.0
    epoch_steps = 0
    if not args.high_res:
        encoder.train()
    generator.train()
    discriminator.train()
    pbar = tqdm(dataloader, position=0, desc='train [G loss: -.----][D loss: -.----]')
    for batch in pbar:
        img_A, img_B, img_AtoB, img_BtoA = (t.to(device) for t in batch)

        with torch.cuda.amp.autocast(enabled=use_amp):
            g_loss, d_loss, x_fake = forward_loss(
                img_A, img_B, img_AtoB, encoder, generator, discriminator,
            )

        g_optimizer.zero_grad()
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)

        d_optimizer.zero_grad()
        scaler.scale(d_loss).backward()
        scaler.step(d_optimizer)
        scaler.update()

        mean_g_loss += g_loss.item()
        mean_d_loss += d_loss.item()
        epoch_steps += 1

        pbar.set_description(desc=f'train [G loss: {mean_g_loss/epoch_steps:.4f}][D loss: {mean_d_loss/epoch_steps:.4f}]')

    # Parenthesised on purpose: `epoch+1 % n` parses as `epoch + (1 % n)`.
    if (epoch + 1) % train_config.save_every == 0:
        torch.save({
            'e_model_dict': encoder.state_dict(),
            'g_model_dict': generator.state_dict(),
            'd_model_dict': discriminator.state_dict(),
            'g_optim_dict': g_optimizer.state_dict(),
            'd_optim_dict': d_optimizer.state_dict(),
            'epoch': epoch,
        }, os.path.join(log_dir, f'epoch={epoch}.pt'))

    g_scheduler.step()
    d_scheduler.step()

