### Debugger

In [1]:
# Imports
import os
import time
import datetime
import random
import numpy as np
import shutil
import argparse
import torch
from torch.backends import cudnn
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.utils import data
from torchvision import transforms as T
from PIL import Image

#### Data Loader

In [2]:
# Data Loader
class ImageNet(data.Dataset):
    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)
    
    def preprocess(self):
        lines = [line.rstrip() for line in open(self.attr_path, 'r')]
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(values[idx] == '1')

            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print("[INFO] Finished processing the 'ImageNet' dataset.")
    
    def __getitem__(self, index):       
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)
    
    def __len__(self):     
        """Return the number of images."""
        return self.num_images

In [3]:
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, batch_size=16, dataset='CelebA', mode='train', 
               num_workers=0):    
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())
    transform.append(T.CenterCrop(crop_size))
    transform.append(T.Resize(image_size))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    # TODO: Change the directory of the test mode.
    if mode == 'train':
        dataset = ImageNet(image_dir, attr_path, selected_attrs, transform, mode)
    elif mode == 'test':
        dataset = ImageNet(image_dir, attr_path, selected_attrs, transform, mode)

    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=(mode=='train'), num_workers=num_workers)
    return data_loader

In [4]:
def get_loader_class(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, batch_size=16, dataset='ImageNet', mode='train',
                     num_workers=0):
    """Build and return a data loader for training the classifier with data augmentation."""
    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip(p=0.5))
    transform.append(T.CenterCrop(crop_size))
    transform.append(T.Resize(image_size))
    transform.append(T.RandomRotation(degrees = (-20,20)))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    dataset = ImageNet(image_dir, attr_path, selected_attrs, transform, mode)
         
    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=(mode=='train'), num_workers=num_workers)
    return data_loader

#### Solver

In [5]:
class Solver(object):
    """Solver for training and testing StarGAN."""

    def __init__(self, imagenet_loader, imagenet_class_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.imagenet_loader = imagenet_loader
        self.imagenet_class_loader = imagenet_class_loader

        # Model configurations.
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.c_conv_dim = config.c_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.c_repeat_num = config.c_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.c_lr = config.c_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.c_beta1 = config.c_beta1
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        # if self.use_tensorboard:
        #     self.build_tensorboard()

    def build_model(self):
        """Create a generator and a discriminator."""

        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
        self.C = Classifier(self.image_size, self.c_conv_dim, self.c_dim, self.c_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.c_lr, [self.c_beta1, self.beta2])

        # self.print_network(self.G, 'G')
        # self.print_network(self.D, 'D')
        # self.print_network(self.C, 'C')

        self.G.to(self.device)
        self.D.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('[INFO] Loading the trained models from step {}...'.format(resume_iters))

        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(resume_iters))

        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))

    # def build_tensorboard(self):
    #     """Build a tensorboard logger."""
    #     from logger import Logger
    #     self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr, c_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr
        for param_group in self.c_optimizer.param_groups:
            param_group['lr'] = c_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()
        self.c_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight, retain_graph=True, create_graph=True, only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        c_trg_list = []
        for i in range(c_dim):
            c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)
            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def classification_loss(self, logit, target):
        """Compute binary or softmax cross entropy loss."""
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

    def train(self):
        """Train StarGAN within a single dataset."""
        data_loader = self.imagenet_loader
        data_loader_class = self.imagenet_class_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.selected_attrs)
        data_iter_class = iter(data_loader_class)

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr
        c_lr = self.c_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('[INFO] Training started!')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            try:
                x_real, label_org = next(data_iter)
            except:
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            try:
                x_real_class, label_org_class = next(data_iter_class)
            except:
                data_iter_class = iter(data_loader_class)
                x_real_class, label_org_class = next(data_iter_class)

            # Generate target domain labels randomly.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            # TODO: Plan to add the ImageNet dataset with labels.
            if self.dataset == 'ImageNet':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'AFHQ':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = x_real.to(self.device) # Input images.
            x_real_class = x_real_class.to(self.device)

            c_org = c_org.to(self.device) # Original domain labels.
            c_trg = c_trg.to(self.device) # Target domain labels.
            
            label_org = label_org.to(self.device) # Labels for computing classification loss.
            label_trg = label_trg.to(self.device) # Labels for computing classification loss.
            label_org_class = label_org_class.to(self.device)

            # =================================================================================== #
            #                             2-0. Train the Classifier                               #
            # =================================================================================== #

            # Compute loss with real images.
            out_cls = self.C(x_real_class)

            c_loss = self.classification_loss(out_cls, label_org_class)

            self.reset_grad()
            c_loss.backward(retain_graph=True)
            self.c_optimizer.step()

            # Logging.
            loss = {}
            loss['C/loss'] = c_loss.item()

            # =================================================================================== #
            #                             2-1. Train the discriminator                            #
            # =================================================================================== #

            # Compute loss with real images.
            out_src = self.D(x_real)
            # d_loss_real = - torch.mean(out_src)
            d_loss_real = torch.mean(F.relu(1. - torch.mul(out_src, 1.0)))

            # Compute loss with fake images.
            x_fake = self.G(x_real, c_trg)
            out_src = self.D(x_fake.detach())
            d_loss_fake = torch.mean(F.relu(1. - torch.mul(out_src, -1.0)))

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               2-2. Train the generator                              #
            # =================================================================================== #

            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                out_cls_f = self.C(x_fake)
                c_loss_f = self.classification_loss(out_cls_f, c_trg)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * c_loss_f
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = self.lambda_rec * g_loss_rec.item()
                loss['G/loss_cls'] = self.lambda_cls * c_loss_f.item()

            # =================================================================================== #
            #                                 3. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                # if self.use_tensorboard:
                #     for tag, value in loss.items():
                #         self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.G(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print("[INFO] Saving images into '{}'.".format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i+1))

                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.C.state_dict(), C_path)
                print("[INFO] Saving checkpoints into '{}'".format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                c_lr -= (self.c_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr, c_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}, c_lr: {}.'.format(g_lr, d_lr, c_lr))

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        
        # Set data loader.
        data_loader = self.imagenet_loader

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):
                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.selected_attrs)
                # Translate images.
                x_fake_list = []
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))
                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
                break

class HingeLoss(torch.nn.Module):
    def __init__(self):
        super(HingeLoss, self).__init__()

    def forward(self, output, target):
        hinge_loss = 1. - torch.mul(output, target)
        return torch.mean(F.relu(hinge_loss))

In [6]:
def str2bool(v):
    return v.lower() in ('true')

def clean_and_create_dir(directory):
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

def main(config):
    cudnn.benchmark = True

    if config.mode == 'train':
        # Checks if the log folder exists.
        if not os.path.exists(config.log_dir):
            os.makedirs(config.log_dir)

        # Checks if the model folder exists and clean it.
        clean_and_create_dir(config.model_save_dir)

        # Checks if the sample folder exists and clean it.
        clean_and_create_dir(config.sample_dir)

        # Checks if the result folder exists and clean it.
        clean_and_create_dir(config.result_dir)

    if config.mode == 'test':
        # Checks if the result folder exists and clean it.
        clean_and_create_dir(config.result_dir)
        
    imagenet_loader = None
    imagenet_class_loader = None

    if config.dataset == 'ImageNet':
        imagenet_loader = get_loader(config.imagenet_image_dir, config.attr_path, config.selected_attrs, config.crop_size, config.image_size,
                                     config.batch_size, 'ImageNet', config.mode, config.num_workers)

        imagenet_class_loader = get_loader_class(config.imagenet_image_dir, config.attr_path, config.selected_attrs, config.crop_size, config.
                                                 image_size, config.batch_size, 'ImageNet', config.mode, config.num_workers)
    
    solver = Solver(imagenet_loader, imagenet_class_loader, config)

    if config.mode == 'train':
        if config.dataset in ['CelebA', 'AFHQ', 'ImageNet']:
            solver.train()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'AFHQ', 'ImageNet']:
            solver.test()

In [7]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=40)
    parser.add_argument('--crop_size', type=int, default=178)
    parser.add_argument('--image_size', type=int, default=128)
    parser.add_argument('--g_conv_dim', type=int, default=8)
    parser.add_argument('--d_conv_dim', type=int, default=8)
    parser.add_argument('--c_conv_dim', type=int, default=8) 
    parser.add_argument('--g_repeat_num', type=int, default=6)
    parser.add_argument('--d_repeat_num', type=int, default=6)
    parser.add_argument('--c_repeat_num', type=int, default=6)     
    parser.add_argument('--lambda_cls', type=float, default=0.25)  
    parser.add_argument('--lambda_rec', type=float, default=1.3)
    parser.add_argument('--lambda_gp', type=float, default=1)
                                            
    # Training configuration.
    parser.add_argument('--dataset', type=str, default='ImageNet')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--num_iters', type=int, default=1000000)
    parser.add_argument('--num_iters_decay', type=int, default=100000)
    parser.add_argument('--g_lr', type=float, default=0.0001)
    parser.add_argument('--d_lr', type=float, default=0.0001)
    parser.add_argument('--c_lr', type=float, default=0.00012)      
    parser.add_argument('--n_critic', type=int, default=5)
    parser.add_argument('--beta1', type=float, default=0.0)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--c_beta1', type=float, default=0.9)
    parser.add_argument('--resume_iters', type=int, default=None)  
    parser.add_argument('--selected_attrs', '--list', nargs='+', default=['original', 'perturbation'])

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=1000000)

    # Miscellaneous.
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)

    # Directories.
    parser.add_argument('--attr_path', type=str, default='../../Datasets/ImageNet5/image_data.txt')
    parser.add_argument('--imagenet_image_dir', type=str, default='../../Datasets/ImageNet5/Images')
    parser.add_argument('--log_dir', type=str, default='../logs/imagenet')
    parser.add_argument('--model_save_dir', type=str, default='../models/imagenet')
    parser.add_argument('--sample_dir', type=str, default='../samples/imagenet')
    parser.add_argument('--result_dir', type=str, default='../results/imagenet')
    
    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--model_save_step', type=int, default=10000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    config = parser.parse_args()

    # Writing the parsing arguments into a logging file.
    log_file = os.path.join(config.log_dir, 'log_arguments.txt')
    with open(log_file, "w") as log_arg:
        log_arg.write(str(config))
    print("[INFO] Parameters saved into '{}'.".format(log_file))

    print(log_file)
    
    # main(config)

usage: ipykernel_launcher.py [-h] [--c_dim C_DIM] [--crop_size CROP_SIZE]
                             [--image_size IMAGE_SIZE]
                             [--g_conv_dim G_CONV_DIM]
                             [--d_conv_dim D_CONV_DIM]
                             [--c_conv_dim C_CONV_DIM]
                             [--g_repeat_num G_REPEAT_NUM]
                             [--d_repeat_num D_REPEAT_NUM]
                             [--c_repeat_num C_REPEAT_NUM]
                             [--lambda_cls LAMBDA_CLS]
                             [--lambda_rec LAMBDA_REC] [--lambda_gp LAMBDA_GP]
                             [--dataset DATASET] [--batch_size BATCH_SIZE]
                             [--num_iters NUM_ITERS]
                             [--num_iters_decay NUM_ITERS_DECAY] [--g_lr G_LR]
                             [--d_lr D_LR] [--c_lr C_LR] [--n_critic N_CRITIC]
                             [--beta1 BETA1] [--beta2 BETA2]
                             [--c_beta1 C_BETA1]

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
