In [17]:
# StarGAN imports
from data_loader import get_loader
from solver import Solver

from torchvision.utils import save_image
import torch
import os

In [2]:
class Config:
    """Plain, empty attribute container for StarGAN settings.

    ``get_default_config`` populates it attribute by attribute. The result is
    read by ``get_loader`` and ``Solver`` through plain attribute access.
    """


def get_default_config(**overrides):
    """Return a ``Config`` populated with the defaults used by this notebook.

    The defaults target running the generator in ``mode="test"`` on 128x128
    CelebA images, with ``test_iters=200000`` and checkpoints read from
    ``stargan_celeba_128/models``.

    Args:
        **overrides: Optional ``name=value`` pairs applied after the defaults,
            e.g. ``get_default_config(batch_size=16, mode="train")``. Only
            names that already exist on the default config are accepted, so a
            typo fails loudly instead of being silently ignored.

    Returns:
        Config: a fresh object on every call (lists are not shared between
        calls).

    Raises:
        TypeError: if an override name is not a known config field.

    Note:
        ``c_dim`` is passed to ``create_labels`` together with
        ``selected_attrs``, so for CelebA the two must agree. It is derived
        from the default attribute list. It is re-derived when
        ``selected_attrs`` is overridden without an explicit ``c_dim``.
        NOTE(review): the RaFD loader is built without ``selected_attrs``.
        Presumably RaFD-only runs need ``c_dim`` passed explicitly; confirm.
    """
    config = Config()

    # CelebA attributes used as target domains. c_dim is derived from this
    # list so the two cannot drift apart.
    selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']

    # Model configuration.
    config.c_dim = len(selected_attrs)
    config.c2_dim = 8
    config.celeba_crop_size = 178
    config.rafd_crop_size = 256
    config.image_size = 128
    config.g_conv_dim = 64
    config.d_conv_dim = 64
    config.g_repeat_num = 6
    config.d_repeat_num = 6
    config.lambda_cls = 1
    config.lambda_rec = 10
    config.lambda_gp = 10

    # Training configuration.
    config.dataset = "CelebA"
    config.batch_size = 4
    config.num_iters = 200000
    config.num_iters_decay = 100000
    config.g_lr = 0.0001
    config.d_lr = 0.0001
    config.n_critic = 5
    config.beta1 = 0.5
    config.beta2 = 0.999
    config.resume_iters = None
    config.selected_attrs = selected_attrs

    # Test configuration.
    config.test_iters = 200000

    # Miscellaneous.
    config.num_workers = 1
    config.mode = "test"
    config.use_tensorboard = False

    # Directories.
    config.celeba_image_dir = 'data/celeba/images'
    config.attr_path = 'data/celeba/list_attr_celeba.txt'
    config.rafd_image_dir = 'data/RaFD/train'
    config.log_dir = 'stargan/logs'
    config.model_save_dir = 'stargan_celeba_128/models'
    config.sample_dir = 'stargan/samples'
    config.result_dir = 'stargan_celeba_128/results'

    # Step size.
    config.log_step = 10
    config.sample_step = 1000
    config.model_save_step = 10000
    config.lr_update_step = 1000

    # Apply caller overrides last so they win over every default. Unknown
    # names are rejected so that e.g. ``batch_sze=8`` does not go unnoticed.
    for name, value in overrides.items():
        if not hasattr(config, name):
            raise TypeError(
                "get_default_config() got an unknown config field {!r}".format(name))
        setattr(config, name, value)

    # Keep c_dim in step with a custom attribute list unless the caller chose
    # c_dim explicitly.
    if ('selected_attrs' in overrides and 'c_dim' not in overrides
            and config.selected_attrs is not None):
        config.c_dim = len(config.selected_attrs)

    return config

In [3]:
# Notebook-wide settings: test mode, CelebA, 128x128 images, test_iters=200000.
config = get_default_config()

In [4]:
def get_default_solver(config):
    """Build the StarGAN ``Solver`` together with the data loaders it needs.

    A loader is created for each dataset selected by ``config.dataset``:
    ``'CelebA'`` builds only the CelebA loader, ``'RaFD'`` only the RaFD
    loader, and ``'Both'`` builds both. A loader that is not needed is passed
    to ``Solver`` as ``None``.

    Args:
        config: Configuration object (e.g. from ``get_default_config()``)
            providing the dataset name, image directories, crop/image sizes,
            batch size, mode and worker count.

    Returns:
        Solver: the solver constructed from the loaders and ``config``.

    Raises:
        ValueError: if ``config.dataset`` is not 'CelebA', 'RaFD' or 'Both'.
    """
    valid_datasets = ('CelebA', 'RaFD', 'Both')
    if config.dataset not in valid_datasets:
        # Fail fast: otherwise both loaders stay None and the Solver would be
        # built with no data at all, failing much later and less clearly.
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(config.dataset, valid_datasets))

    celeba_loader = None
    rafd_loader = None

    if config.dataset in ('CelebA', 'Both'):
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ('RaFD', 'Both'):
        # The RaFD loader takes no attribute file or attribute list.
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)

    return Solver(celeba_loader, rafd_loader, config)

In [5]:
# Build the Solver and the loader(s) selected by config.dataset; the text below
# is what this call printed (preprocessing message and the Generator repr).
solver = get_default_solver(config)

Finished preprocessing the CelebA dataset...
Generator(
  (main): Sequential(
    (0): Conv2d(8, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ResidualBlock(
      (main): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 

In [21]:
def test(solver, test_data_loader):
    # Create results dir if not exists.
    if not os.path.exists(solver.result_dir):
        os.makedirs(solver.result_dir) 
    
    # Load the trained generator.
    solver.restore_model(solver.test_iters)
    
    with torch.no_grad():
        for i, (x_real, c_org) in enumerate(test_data_loader):

            # Prepare input images and target domain labels.
            x_real = x_real.to(solver.device)
            c_trg_list = solver.create_labels(c_org, solver.c_dim, solver.dataset, solver.selected_attrs)

            # Translate images.
            x_fake_list = [x_real]
            for c_trg in c_trg_list:
                x_fake_list.append(solver.G(x_real, c_trg))

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(solver.result_dir, '{}-images.jpg'.format(i+1))
            save_image(solver.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))

In [12]:
# Loader for the qualitative test run, built with the CelebA paths/attributes.
# NOTE(review): the dataset name passed here is 'CelebASmall' rather than
# config.dataset; presumably a reduced CelebA split supported by the local
# data_loader. Confirm. The Solver was built from config, whose dataset is
# "CelebA", and that is presumably what create_labels sees.
test_data_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                           config.celeba_crop_size, config.image_size, config.batch_size,
                           'CelebASmall', config.mode, config.num_workers)

Finished preprocessing the CelebA dataset...


In [13]:
# Peek at a single batch to confirm the image and label shapes before the full run.
x, y = next(iter(test_data_loader))
print(x.shape, y.shape, sep="\n")

torch.Size([4, 3, 128, 128])
torch.Size([4, 5])


In [22]:
# Translate every test batch and write one image strip per batch under solver.result_dir.
test(solver, test_data_loader)

Loading the trained models from step 200000...
Saved real and fake images into stargan_celeba_128/results/1-images.jpg...
Saved real and fake images into stargan_celeba_128/results/2-images.jpg...
Saved real and fake images into stargan_celeba_128/results/3-images.jpg...
Saved real and fake images into stargan_celeba_128/results/4-images.jpg...
Saved real and fake images into stargan_celeba_128/results/5-images.jpg...
Saved real and fake images into stargan_celeba_128/results/6-images.jpg...
Saved real and fake images into stargan_celeba_128/results/7-images.jpg...
Saved real and fake images into stargan_celeba_128/results/8-images.jpg...
