In [2]:
from torch.utils.data import DataLoader, random_split
from utils.preprocess import FaceCompletionDataset
from utils.network_seq_contour import Parser
from utils.models import Generator, Discriminator
import torchvision.models as models
import torch
import copy
import pandas as pd
import matplotlib.pyplot as plt
import os
import cv2
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor



In [3]:
image_dir = "preprocessed_images/"
save_path = './model_trained/'
celeba_dataset = FaceCompletionDataset(image_dir)

# train-test split
train_size = int(0.8 * len(celeba_dataset))  # Use 80% of the dataset for training
val_size = len(celeba_dataset) - train_size
train_dataset, val_dataset = random_split(celeba_dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
validation_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

import joblib
joblib.dump(val_dataset, 'val_dataset.pkl') ## preserve val set
joblib.dump(train_dataset, 'train_dataset.pkl')

In [29]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

In [4]:
import yaml
from argparse import ArgumentParser, Namespace

## for visualize result
def return_image_numpy(images, scale=False):
    if scale:
        completed_images_np = images[0].cpu().detach().numpy()
 
        completed_images_np = ((completed_images_np + 1) * 127.5)
        print(completed_images_np.shape)
        return completed_images_np.transpose(1, 2, 0)
    else:
        return images[0].detach().cpu().numpy().transpose(1, 2, 0)

def adjust_learning_rate(optimizer, gamma=0.1, num_steps=1):
    for i in range(num_steps):
        for param_group in optimizer.param_groups:
            param_group['lr'] *= gamma
    
## for loading parsing network for perceptual loss
def get_config(config_file):
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
    return config

def get_args():
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default='config/seg_config.yaml', help="training configuration")
    parser.add_argument('--seed', type=int, default=2023, help='manual seed')

    try:
        args = parser.parse_args()
    except SystemExit:
        args = Namespace(config='config/seg_config.yaml', seed=2023)
    return args

def load_face_parsing_model(model_path):    
    netG.load_state_dict(torch.load(model_path), strict=False)
    netG.eval()
    return netG

## load parser network
args = get_args()
config = get_config(args.config)
netG = Parser(config)

# load parsing model
face_parsing_model = load_face_parsing_model('pretrained_model/parser_00100000.pt')
face_parsing_model = face_parsing_model.to(device)
face_parsing_model.eval()

usage: ipykernel_launcher.py [-h] [--config CONFIG] [--seed SEED]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/x1112373/.local/share/jupyter/runtime/kernel-7b2b1a6e-bc9c-4c33-b6be-90a56fcac0b1.json


Parser(
  (conv1_1): Conv2dBlock(
    (pad): ZeroPad2d((0, 0, 0, 0))
    (activation): ELU(alpha=1.0, inplace=True)
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
  )
  (conv1_2): Conv2dBlock(
    (pad): ZeroPad2d((0, 0, 0, 0))
    (activation): ELU(alpha=1.0, inplace=True)
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (conv2_1): Conv2dBlock(
    (pad): ZeroPad2d((0, 0, 0, 0))
    (activation): ELU(alpha=1.0, inplace=True)
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv2_2): Conv2dBlock(
    (pad): ZeroPad2d((0, 0, 0, 0))
    (activation): ELU(alpha=1.0, inplace=True)
    (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (conv3_1): Conv2dBlock(
    (pad): ZeroPad2d((0, 0, 0, 0))
    (activation): ELU(alpha=1.0, inplace=True)
    (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv3_2): Conv2dBlock(
    (pad): Z

In [33]:
class GANLoss(nn.Module):
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real).to(input.device)
        return self.loss(input, target_tensor)

In [34]:
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.models import Generator, Discriminator
from utils.networks import GatedGenerator

## initialize first model
generator = GatedGenerator()
discriminator_global = Discriminator(in_channels=3)
discriminator_local = Discriminator(in_channels=3)

## loss
criterion_context = nn.MSELoss(reduction='sum')
criterion_adv = GANLoss(target_real_label=0.9, target_fake_label=0.1)
criterion_parsing = nn.SmoothL1Loss()
criterion_rec = nn.SmoothL1Loss()
criterion_ssim = SSIM(window_size = 11)

# Create Optimizer
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_discriminator_global = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_discriminator_local = torch.optim.Adam(generator.parameters(), lr=0.0001)

# Create the schedulers - when there is scheduler, learning rate goes too small value
# scheduler_generator = ReduceLROnPlateau(optimizer_generator, mode='min', factor=0.1, patience=5, verbose=True)
# scheduler_discriminator_global = ReduceLROnPlateau(optimizer_discriminator_global, mode='min', factor=0.1, patience=5, verbose=True)
# scheduler_discriminator_local = ReduceLROnPlateau(optimizer_discriminator_local, mode='min', factor=0.1, patience=5, verbose=True)

# device setting
generator = generator.to(device)
discriminator_global = discriminator_global.to(device)
discriminator_local = discriminator_local.to(device)


In [36]:
def load_checkpoint(generator, discriminator_local, discriminator_global, path, device):
    state = torch.load(path,map_location=device)
    generator.load_state_dict(state['G'])
    discriminator_global.load_state_dict(state['D_G'])
    discriminator_local.load_state_dict(state['D_L'])
    print('Loaded checkpoint successfully')
    return generator, discriminator_local, discriminator_global

In [None]:
import random
import torch
import pandas as pd
from utils.evaluate import evaluate_models, evaluate_model_external
resume = True
num_epochs = 20
best_loss = float('inf')
patience = 5
lambda_gp = 10  # Gradient penalty weight
result_df = pd.DataFrame(columns= ['epoch', 'step', 'val_gen_loss', 'val_disc_global_loss', 'val_disc_local_loss', 'val_psnr', 'val_ssim' ])
generator.train()
discriminator_local.train()
discriminator_global.train()
for epoch in range(num_epochs):
    for i, (images, masks, masked_images) in enumerate(train_dataloader):
        # Resize the images and masks to a consistent size

        images = images.to(device)
        masked_images = masked_images.to(device)
        masks = masks.to(device)

        ## train discriminator
        optimizer_discriminator_global.zero_grad()
        optimizer_discriminator_local.zero_grad()
        optimizer_generator.zero_grad()

        first_out, second_out  = generator(images, masks)
        first_out_wholeimg = images * (1 - masks) + first_out * masks     
        second_out_wholeimg = images * (1 - masks) + second_out * masks 

        # Compute the adversarial loss for the generator using the global and local discriminators

        local_real_D = discriminator_local(images)
        local_fake_D = discriminator_local(second_out_wholeimg.detach())

        global_real_D = discriminator_global(images)
        global_fake_D = discriminator_global(second_out_wholeimg.detach())

        loss_local_fake_D = criterion_adv(local_fake_D, target_is_real=False)
        loss_local_real_D = criterion_adv(local_real_D, target_is_real=True)

        loss_global_fake_D = criterion_adv(global_fake_D, target_is_real=False)
        loss_global_real_D = criterion_adv(global_real_D, target_is_real=True)

        #gp = gradient_penalty(discriminator_global, images, completed_images)
        loss_d = (loss_local_real_D + loss_local_fake_D + loss_global_fake_D + loss_global_real_D) * 0.25 #+ lambda_gp * gp
        loss_d.backward(retain_graph=True)        

        optimizer_discriminator_global.step()
        optimizer_discriminator_local.step()

        local_real_output = None
        global_real_output = None

        # train generator
        optimizer_discriminator_global.zero_grad()
        optimizer_discriminator_local.zero_grad()
        optimizer_generator.zero_grad()

        # generator loss 
        local_fake_output = discriminator_local(second_out_wholeimg)
        global_fake_output = discriminator_global(second_out_wholeimg)
        fake_D = (local_fake_output + global_fake_output) * 0.5
        loss_G = criterion_adv(fake_D, target_is_real=True)

        local_fake_output = None
        global_fake_output = None
        fake_D = None


        # generator reconstruction loss
        # Reconstruction loss
        loss_l1_1 = criterion_rec(first_out_wholeimg, images)
        loss_l1_2 = criterion_rec(second_out_wholeimg, images)
        loss_ssim_1 = criterion_ssim(first_out_wholeimg, images)
        loss_ssim_2 = criterion_ssim(second_out_wholeimg, images)
        loss_rec_1 = 0.5 * loss_l1_1 + 0.5 * (1 - loss_ssim_1)
        loss_rec_2 = 0.5 * loss_l1_2 + 0.5 * (1 - loss_ssim_2)

        lambda_G = 1.0
        lambda_rec_1 = 100.0
        lambda_rec_2 = 100.0
        lambda_per = 10.0

        loss_P = criterion_parsing(face_parsing_model(second_out_wholeimg), face_parsing_model(images))
        loss_generator = lambda_G * loss_G + lambda_rec_1 * loss_rec_1 + lambda_rec_2 * loss_rec_2 + lambda_per * loss_P
        loss_generator.backward(retain_graph=True)
        optimizer_generator.step()
        if epoch == 5:
            adjust_learning_rate(optimizer_discriminator_global)
            adjust_learning_rate(optimizer_discriminator_local)
            adjust_learning_rate(optimizer_generator)
        if epoch == 10:
            adjust_learning_rate(optimizer_discriminator_global)
            adjust_learning_rate(optimizer_discriminator_local)
            adjust_learning_rate(optimizer_generator)
        if epoch == 15:
            adjust_learning_rate(optimizer_discriminator_global)
            adjust_learning_rate(optimizer_discriminator_local)
            adjust_learning_rate(optimizer_generator)


        if i % 10000 == 0:
            fig, axs = plt.subplots(1, 4, figsize=(15, 5))
            axs[0].imshow(return_image_numpy(images))
            axs[0].set_title("Original Image")
            axs[1].imshow(return_image_numpy(masks))
            axs[1].set_title("masks")
            axs[2].imshow(return_image_numpy(first_out_wholeimg))
            axs[2].set_title("first_out_wholeimg")
            axs[3].imshow(return_image_numpy(second_out_wholeimg))
            axs[3].set_title("second_out_wholeimg")

            for ax in axs:
                ax.set_xticks([])
                ax.set_yticks([])

            plt.show()
    
    # Evaluation
    # Inside your training loop, after each epoch
    print(f"Epoch: {epoch}, Step: {i}, Training Losses - Generator: {loss_generator.item()}") #, Discriminator Global: {loss_adv_global.item()}, Discriminator Local: {loss_adv_local.item()}")
    val_gen_loss, val_disc_global_loss, val_disc_local_loss = evaluate_models(generator, discriminator_global, discriminator_local, validation_dataloader, criterion_rec, criterion_ssim, face_parsing_model, device)
    print(f"Epoch: {epoch}, Validation Losses - Generator: {val_gen_loss:.8f}, Discriminator Global: {val_disc_global_loss:.8f}, Discriminator Local: {val_disc_local_loss:.8f}")
    val_psnr, val_ssim = evaluate_model_external(generator, validation_dataloader, device)
    print(f"Epoch: {epoch}, Validation PSNR: {val_psnr:.8f}, SSIM: {val_ssim:.8f}")
    if epoch == 0:
        result_df = pd.DataFrame([{'epoch': epoch, 'step': i, 'val_gen_loss': val_gen_loss, 'val_disc_global_loss': val_disc_global_loss, 'val_disc_local_loss': val_disc_local_loss, 'val_psnr': val_psnr, 'val_ssim': val_ssim}])
    else:
        aa = pd.DataFrame([{'epoch': epoch, 'step': i, 'val_gen_loss': val_gen_loss, 'val_disc_global_loss': val_disc_global_loss, 'val_disc_local_loss': val_disc_local_loss, 'val_psnr': val_psnr, 'val_ssim': val_ssim}])
        result_df = pd.concat([result_df, aa], axis=0)
        result_df.to_csv('result.csv') 
    # Update the best loss and save the model if necessary
    
    if (val_gen_loss + val_disc_global_loss) < best_loss:
        best_loss = (val_gen_loss + val_disc_global_loss)
        best_generator = copy.deepcopy(generator.state_dict())
        best_discriminator_global = copy.deepcopy(discriminator_global.state_dict())
        best_discriminator_local = copy.deepcopy(discriminator_local.state_dict())
        torch.save({
                'D_G': best_discriminator_global,
                'D_L': best_discriminator_local,
                'G': best_generator,
        }, os.path.join("model_trained", f"model_{epoch}.pth"))
        
        counter = 0
    else:
        counter += 1

    # Check for early stopping
    if counter >= patience:
        print("Early stopping triggered.")
        break
