In [1]:
from generic_dataset import GenericDataset, VisualizationDataset, CELEBA_FORMAT_DATASET, RAFD_FORMAT_DATASET, get_loader
from custom_solver import CustomSolver

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchvision.utils import save_image
from pathlib import Path
import torch

In [3]:
# Two CelebA attribute groups; their concatenation (8 attributes) is the
# label set used by the visualization datasets below.
celeba_selected_attrs1 = [
    "Male",
    "Young",
    "Black_Hair",
    "Blond_Hair",
    "Brown_Hair",
]
celeba_selected_attrs2 = ["Smiling", "Arched_Eyebrows", "Bangs"]
selected_attrs = [*celeba_selected_attrs1, *celeba_selected_attrs2]

In [4]:
# CelebA-format dataset labelled with all eight selected attributes; it is the
# input for the CelebA visualizations further down.
vis_celeba = GenericDataset(
    "data/celeba",
    CELEBA_FORMAT_DATASET,
    selected_attrs,
    "train",
)

Finished preprocessing CelebA-format dataset at data/celeba: 198548 train images, 4051 test images


In [5]:
# Build solver from two CelebA label groups plus the comics dataset.
# The zero padding printed below (left [0, 5, 8]) suggests the list order
# fixes each dataset's label offset -- keep the order stable.
celeba1 = GenericDataset(
    "data/celeba", CELEBA_FORMAT_DATASET, celeba_selected_attrs1, "train"
)
celeba2 = GenericDataset(
    "data/celeba", CELEBA_FORMAT_DATASET, celeba_selected_attrs2, "train"
)
# NOTE(review): when this cell last ran, data/comics reported 0 train and
# 0 test images -- confirm the dataset path before retraining.
comics = GenericDataset(
    "data/comics", RAFD_FORMAT_DATASET, ["faces", "comics"], "train"
)
visualization_ds = VisualizationDataset("data/visualization")

solver = CustomSolver(
    [celeba1, celeba2, comics],
    visualization_ds,
    "75k",  # printed as "Workdir set at 75k"
)

Finished preprocessing CelebA-format dataset at data/celeba: 198548 train images, 4051 test images
Finished preprocessing CelebA-format dataset at data/celeba: 198548 train images, 4051 test images
Finished preprocessing RAFD-format dataset at data/comics: 0 train images, 0 test images
sum_label_size: 10, mask_vector_size: 3
Zero padding left: [0, 5, 8]
Zero padding right: [5, 2, 0]
Workdir set at 75k
Generator(
  (main): Sequential(
    (0): Conv2d(16, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): InstanceNorm2d(256, eps=1e-05, momentum=0.1, aff

In [6]:
def visualize(solver, restore_iters, dataset, save_dir, toggle=False,
              max_batches=100, batch_size=6, mode="test"):
    """Save side-by-side attribute translations produced by a trained generator.

    For every batch drawn from ``dataset``, the input images are followed by one
    generated image per target attribute (label indices 0-7 and 9). The images
    are concatenated along dim 3 and written to ``save_dir/<batch number>.jpg``
    with one row per sample (``nrow=1``).

    Layout of the 13-column target vector passed to ``solver.G`` (cf. the
    solver's "sum_label_size: 10, mask_vector_size: 3" and its zero padding):
        0-4    celeba_selected_attrs1 (Male, Young, Black/Blond/Brown_Hair)
        5-7    celeba_selected_attrs2 (Smiling, Arched_Eyebrows, Bangs)
        8-9    comics labels ("faces", "comics")
        10-12  one-hot mask selecting which of the three label groups is active

    Args:
        solver: trained solver; ``restore_model``, ``G``, ``device`` and
            ``denorm`` are used.
        restore_iters: training step whose checkpoint is restored.
        dataset: dataset whose labels fill the first target columns. Presumably
            these are the 8 ``selected_attrs`` columns; the remainder up to 13
            is zero-padded.
        save_dir: output directory (str or Path); created if missing.
        toggle: if False, force the target attribute on; if True, flip each
            sample's current value of that attribute.
        max_batches: maximum number of batches (= output files) to write.
        batch_size: images per batch, i.e. rows per output file.
        mode: mode string forwarded to ``get_loader``.

    Warns:
        UserWarning: if the loader yields no batches, so nothing is written.
    """
    import warnings

    solver.restore_model(restore_iters)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    num_labels, target_size = 10, 13
    # Hair colours are cleared together so only the requested one is set.
    hair_colors = [2, 3, 4]
    # (attribute indices of a label group, column of that group's mask bit).
    # Index 8 ("faces") is in no group, so it is not visualized.
    groups = [([0, 1, 2, 3, 4], 10), ([5, 6, 7], 11), ([9], 12)]

    # The plan does not depend on the batch, so it is built once here:
    # (target attribute, mask column, label indices to zero).
    plan = []
    for group, mask_col in groups:
        outside = [i for i in range(num_labels) if i not in group]
        for b in group:
            plan.append((b, mask_col, outside))

    dataloader = get_loader(dataset, batch_size, mode, num_workers=2)
    written = 0
    with torch.no_grad():
        for j, (x, label) in enumerate(dataloader):
            # Check the limit before doing any work, so exactly max_batches files
            # are written at most (the old post-save check wrote one extra).
            if j >= max_batches:
                break
            x = x.to(solver.device)
            x_fake_list = [x]

            for b, mask_col, outside in plan:
                # Original labels, zero-padded up to the full target width.
                pad = torch.zeros(label.size(0), target_size - label.size(1))
                target = torch.cat([label, pad], dim=1)

                target[:, num_labels:] = 0  # clear the mask vector
                if b in hair_colors:
                    target[:, hair_colors] = 0
                # Zero the labels outside the active group ("unknown attributes").
                target[:, outside] = 0
                target[:, mask_col] = 1

                if toggle:
                    target[:, b] = 1 - target[:, b]
                else:
                    target[:, b] = 1

                # x lives on solver.device; the target must be on the same device.
                x_fake_list.append(solver.G(x, target.to(solver.device)))

            # dim 3 is presumably the width axis (NCHW), giving one wide row per sample.
            x_concat = torch.cat(x_fake_list, dim=3)
            vis_path = save_dir / f'{j+1}.jpg'
            save_image(solver.denorm(x_concat.cpu()), vis_path, nrow=1, padding=0)
            written += 1

    if written == 0:
        warnings.warn(
            f"visualize: loader in mode {mode!r} yielded no batches; "
            f"nothing was written to {save_dir}"
        )

In [21]:
visualize(
    solver,
    restore_iters=75000,
    dataset=vis_celeba,
    save_dir="save_vis_nontoggle",
    toggle=False,
)

Loading the trained models from step 75000...


In [22]:
visualize(
    solver,
    restore_iters=75000,
    dataset=vis_celeba,
    save_dir="save_vis_toggle",
    toggle=True,
)

Loading the trained models from step 75000...


In [15]:
# Small CelebA-format dataset (12 images) labelled with the same eight attributes.
# NOTE(review): preprocessing reported 0 test images for data/authors, while
# visualize() requests a "test" loader -- confirm images are actually produced.
authors = GenericDataset(
    "data/authors",
    CELEBA_FORMAT_DATASET,
    selected_attrs,
    "train",
)

Finished preprocessing CelebA-format dataset at data/authors: 12 train images, 0 test images


In [16]:
# NOTE(review): data/authors has 0 test images per its preprocessing output, so
# this "test"-mode visualization may write nothing -- TODO confirm.
visualize(
    solver,
    restore_iters=75000,
    dataset=authors,
    save_dir="authors_nontoggle",
    toggle=False,
)

Loading the trained models from step 75000...


In [17]:
# NOTE(review): data/authors has 0 test images per its preprocessing output, so
# this "test"-mode visualization may write nothing -- TODO confirm.
visualize(
    solver,
    restore_iters=75000,
    dataset=authors,
    save_dir="authors_toggle",
    toggle=True,
)

Loading the trained models from step 75000...
