In [3]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
import sys

import torch
sys.path.append("../")
import models
from utils.stylegan2_utils import StyleGAN2SampleGenerator
from utils.segmentation_utils import FaceSegmentation, StuffSegmentation, GANLinearSegmentation
from lelsd import LELSD

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Training LELSD Latent Directions for Pretrained StyleGAN2 Models with Supervised Segmentation

### StyleGAN2 FFHQ

In [None]:
# FFHQ sweep: learn LELSD latent directions for face regions of the pretrained
# FFHQ StyleGAN2, with masks produced by the face BiSeNet segmenter.
device = torch.device('cuda')

exp_dir = "../out"
G2 = models.get_model("stylegan2", "../pretrained/stylegan2/ffhq.pkl")
stylegan2_sample_generator = StyleGAN2SampleGenerator(G=G2, device=device)

face_bisenet = models.get_model("face_bisenet", "../pretrained/face_bisenet/model.pth")
face_segmentation = FaceSegmentation(face_bisenet=face_bisenet, device=device)

# Editing-target name (used in the log path) -> BiSeNet labels that make up the target.
# A dict keeps every name next to its labels. The earlier pair of parallel lists had
# 10 names but 11 label groups (a stray ["hair"] entry), and zip() silently truncated
# them: "cloth" was trained on ["hair"], "eyeglass" on the cloth labels, and the
# ["eyeglass"] group was never used. If a hair-only target was intended instead of
# hair+hat, change the "hair" entry here.
ffhq_face_parts = {
    "mouth": ["mouth", "u_lip", "l_lip"],
    "skin": ["skin"],
    "eyes": ["l_eye", "r_eye"],
    "nose": ["nose"],
    "ears": ["l_ear", "r_ear", "earrings"],
    "background": ["background"],
    "eyebrows": ["l_brow", "r_brow"],
    "hair": ["hair", "hat"],
    "cloth": ["cloth", "neck", "necklace"],
    "eyeglass": ["eyeglass"],
}

# Hyperparameters shared by every run in the sweep.
n_layers = 18
lr = 0.001
min_alpha_value = -1.0
max_alpha_value = 1.0
min_abs_alpha_value = 0.0
gamma_correlation = 5.0
onehot_temperature = 0.001
batch_size = 4
localization_layer_weights = None

for latent_space in ["Z", "W", "W+"]:
    for loss_function in ["L2"]:
        for mask_aggregation in ['average', 'union', 'intersection']:
            for num_latent_dirs in [1, 2]:
                for part_name, face_parts in ffhq_face_parts.items():
                    # Layer indices 1 .. n_layers-1 (index 0 excluded), same as range(1, 18).
                    localization_layers = list(range(1, n_layers))
                    log_dir = f'{exp_dir}/lelsd_stylegan2_ffhq/{latent_space}_{loss_function}_{mask_aggregation}/{num_latent_dirs}D/face_bisenet/{part_name}'
                    lelsd = LELSD(device=device,
                                  localization_layers=localization_layers,
                                  # Defensive copy per run, in case LELSD mutates the list.
                                  semantic_parts=list(face_parts),
                                  loss_function=loss_function,
                                  localization_layer_weights=localization_layer_weights,
                                  mode='foreground',
                                  mask_aggregation=mask_aggregation,
                                  n_layers=n_layers,
                                  latent_dim=512,
                                  num_latent_dirs=num_latent_dirs,
                                  learning_rate=lr,
                                  batch_size=batch_size,
                                  gamma_correlation=gamma_correlation,
                                  unit_norm=False,
                                  latent_space=latent_space,
                                  onehot_temperature=onehot_temperature,
                                  min_alpha_value=min_alpha_value,
                                  max_alpha_value=max_alpha_value,
                                  min_abs_alpha_value=min_abs_alpha_value,
                                  log_dir=log_dir,
                                  )

                    # More directions get proportionally more training batches.
                    lelsd.fit(stylegan2_sample_generator, face_segmentation, num_batches=200 * num_latent_dirs,
                              num_lr_halvings=3,
                              pgbar=True, summary=True)
                    lelsd.save()


### StyleGAN2 LSUN Church

In [None]:
# LSUN Church sweep: learn LELSD directions for scene regions that a COCO-Stuff
# DeepLabv2 model segments, for every combination of the sweep settings below.
device = torch.device('cuda')
exp_dir = "../out"

G2 = models.get_model("stylegan2", "../pretrained/stylegan2/stylegan2-church-config-f.pkl")
stylegan2_sample_generator = StyleGAN2SampleGenerator(G=G2, device=device)

deeplab_weights = "../pretrained/cocostuff_deeplab/deeplabv2_resnet101_msc-cocostuff164k-100000.pth"
deeplabv2_resnet101 = models.get_model("cocostuff_deeplab", deeplab_weights)
segmentation_model = StuffSegmentation(deeplabv2_resnet101=deeplabv2_resnet101,
                                       config_path="../pretrained/cocostuff_deeplab/", device=device)

# Region name (used in the log path) -> COCO-Stuff labels grouped into that region.
church_regions = {
    "church": ["building-other", "house"],
    "sky": ["sky-other", "clouds"],
    "vegetation": ["tree", "grass", "bush", "plant-other"],
    "ground": ["dirt", "mud", "sand", "gravel", "ground-other", "road", "pavement"],
}

# Settings identical for every run.
lr = 0.001
min_alpha_value = -1.0
max_alpha_value = 1.0
min_abs_alpha_value = 0.0
gamma_correlation = 5.0
onehot_temperature = 0.001
batch_size = 4
localization_layer_weights = None

# The whole sweep as one flat, ordered list (same order as the nested loops it replaces).
church_runs = [
    (latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts)
    for latent_space in ["Z", "W", "W+"]
    for loss_function in ["L2"]
    for mask_aggregation in ['average', 'union', 'intersection']
    for num_latent_dirs in [1, 2]
    for part_name, sub_parts in church_regions.items()
]

for latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts in church_runs:
    localization_layers = list(range(1, 14))
    run_name = f'{latent_space}_{loss_function}_{mask_aggregation}'
    log_dir = f'{exp_dir}/lelsd_stylegan2_lsun_church/{run_name}/{num_latent_dirs}D/deeplab/{part_name}'
    lelsd = LELSD(device=device,
                  localization_layers=localization_layers,
                  # Fresh list per run, as the original re-built its label lists each pass.
                  semantic_parts=list(sub_parts),
                  loss_function=loss_function,
                  localization_layer_weights=localization_layer_weights,
                  mode='foreground',
                  mask_aggregation=mask_aggregation,
                  n_layers=14,
                  latent_dim=512,
                  num_latent_dirs=num_latent_dirs,
                  learning_rate=lr,
                  batch_size=batch_size,
                  gamma_correlation=gamma_correlation,
                  unit_norm=False,
                  latent_space=latent_space,
                  onehot_temperature=onehot_temperature,
                  min_alpha_value=min_alpha_value,
                  max_alpha_value=max_alpha_value,
                  min_abs_alpha_value=min_abs_alpha_value,
                  log_dir=log_dir,
                  )
    lelsd.fit(stylegan2_sample_generator, segmentation_model, num_batches=200 * num_latent_dirs,
              num_lr_halvings=3,
              pgbar=True, summary=True)
    lelsd.save()


### StyleGAN2 LSUN Car

In [None]:
# LSUN Car sweep: learn LELSD directions for vehicle/scene regions that a
# COCO-Stuff DeepLabv2 model segments, over the W and W+ latent spaces.
device = torch.device('cuda')
exp_dir = "../out"

G2 = models.get_model("stylegan2", "../pretrained/stylegan2/stylegan2-car-config-f.pkl")
stylegan2_sample_generator = StyleGAN2SampleGenerator(G=G2, device=device)

deeplab_weights = "../pretrained/cocostuff_deeplab/deeplabv2_resnet101_msc-cocostuff164k-100000.pth"
deeplabv2_resnet101 = models.get_model("cocostuff_deeplab", deeplab_weights)
segmentation_model = StuffSegmentation(deeplabv2_resnet101=deeplabv2_resnet101,
                                       config_path="../pretrained/cocostuff_deeplab/", device=device)

# Region name (used in the log path) -> COCO-Stuff labels grouped into that region.
car_regions = {
    "car": ["car", "truck", "bus", "motorcycle"],
    "road": ["road", "pavement", "dirt"],
    "sky": ["sky-other", "clouds"],
    "grass+tree": ["tree", "grass", "bush", "plant-other"],
}

# Settings identical for every run.
lr = 0.001
min_alpha_value = -1.0
max_alpha_value = 1.0
min_abs_alpha_value = 0.0
gamma_correlation = 5.0
onehot_temperature = 0.001
batch_size = 4
localization_layer_weights = None

# The whole sweep as one flat, ordered list (same order as the nested loops it replaces).
car_runs = [
    (latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts)
    for latent_space in ["W", "W+"]
    for loss_function in ["L2"]
    for mask_aggregation in ['average']
    for num_latent_dirs in [1, 2]
    for part_name, sub_parts in car_regions.items()
]

for latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts in car_runs:
    localization_layers = list(range(1, 16))
    run_name = f'{latent_space}_{loss_function}_{mask_aggregation}'
    log_dir = f'{exp_dir}/lelsd_stylegan2_lsun_car/{run_name}/{num_latent_dirs}D/deeplab/{part_name}'
    lelsd = LELSD(device=device,
                  localization_layers=localization_layers,
                  # Fresh list per run, as the original re-built its label lists each pass.
                  semantic_parts=list(sub_parts),
                  loss_function=loss_function,
                  localization_layer_weights=localization_layer_weights,
                  mode='foreground',
                  mask_aggregation=mask_aggregation,
                  n_layers=16,
                  latent_dim=512,
                  num_latent_dirs=num_latent_dirs,
                  learning_rate=lr,
                  batch_size=batch_size,
                  gamma_correlation=gamma_correlation,
                  unit_norm=False,
                  latent_space=latent_space,
                  onehot_temperature=onehot_temperature,
                  min_alpha_value=min_alpha_value,
                  max_alpha_value=max_alpha_value,
                  min_abs_alpha_value=min_abs_alpha_value,
                  log_dir=log_dir,
                  )
    lelsd.fit(stylegan2_sample_generator, segmentation_model, num_batches=200 * num_latent_dirs,
              num_lr_halvings=3,
              pgbar=True, summary=True)
    lelsd.save()


### StyleGAN2 LSUN Horse

In [None]:
# LSUN Horse sweep: learn LELSD directions for subject/scene regions that a
# COCO-Stuff DeepLabv2 model segments, over the W and W+ latent spaces.
device = torch.device('cuda')
exp_dir = "../out"

G2 = models.get_model("stylegan2", "../pretrained/stylegan2/stylegan2-horse-config-f.pkl")
stylegan2_sample_generator = StyleGAN2SampleGenerator(G=G2, device=device)

deeplab_weights = "../pretrained/cocostuff_deeplab/deeplabv2_resnet101_msc-cocostuff164k-100000.pth"
deeplabv2_resnet101 = models.get_model("cocostuff_deeplab", deeplab_weights)
segmentation_model = StuffSegmentation(deeplabv2_resnet101=deeplabv2_resnet101,
                                       config_path="../pretrained/cocostuff_deeplab/", device=device)

# Region name (used in the log path) -> COCO-Stuff labels grouped into that region.
horse_regions = {
    "horse": ["horse"],
    "person": ["person"],
    "sky": ["sky-other", "clouds"],
    "grass+tree": ["tree", "grass", "bush", "plant-other"],
    "ground": ["dirt", "mud", "sand", "gravel", "ground-other", "road", "pavement"],
}

# Settings identical for every run.
lr = 0.001
min_alpha_value = -1.0
max_alpha_value = 1.0
min_abs_alpha_value = 0.0
gamma_correlation = 5.0
onehot_temperature = 0.001
batch_size = 4
localization_layer_weights = None

# The whole sweep as one flat, ordered list (same order as the nested loops it replaces).
horse_runs = [
    (latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts)
    for latent_space in ["W", "W+"]
    for loss_function in ["L2"]
    for mask_aggregation in ['average']
    for num_latent_dirs in [1, 2]
    for part_name, sub_parts in horse_regions.items()
]

for latent_space, loss_function, mask_aggregation, num_latent_dirs, part_name, sub_parts in horse_runs:
    localization_layers = list(range(1, 14))
    run_name = f'{latent_space}_{loss_function}_{mask_aggregation}'
    log_dir = f'{exp_dir}/lelsd_stylegan2_lsun_horse/{run_name}/{num_latent_dirs}D/deeplab/{part_name}'
    lelsd = LELSD(device=device,
                  localization_layers=localization_layers,
                  # Fresh list per run, as the original re-built its label lists each pass.
                  semantic_parts=list(sub_parts),
                  loss_function=loss_function,
                  localization_layer_weights=localization_layer_weights,
                  mode='foreground',
                  mask_aggregation=mask_aggregation,
                  n_layers=14,
                  latent_dim=512,
                  num_latent_dirs=num_latent_dirs,
                  learning_rate=lr,
                  batch_size=batch_size,
                  gamma_correlation=gamma_correlation,
                  unit_norm=False,
                  latent_space=latent_space,
                  onehot_temperature=onehot_temperature,
                  min_alpha_value=min_alpha_value,
                  max_alpha_value=max_alpha_value,
                  min_abs_alpha_value=min_abs_alpha_value,
                  log_dir=log_dir,
                  )
    lelsd.fit(stylegan2_sample_generator, segmentation_model, num_batches=200 * num_latent_dirs,
              num_lr_halvings=3,
              pgbar=True, summary=True)
    lelsd.save()


### StyleGAN2 MetFaces

In [None]:
# MetFaces sweep: learn LELSD latent directions for face regions of the pretrained
# MetFaces StyleGAN2, with masks produced by the face BiSeNet segmenter.
device = torch.device('cuda')

exp_dir = "../out"
G2 = models.get_model("stylegan2", "../pretrained/stylegan2/metfaces.pkl")
stylegan2_sample_generator = StyleGAN2SampleGenerator(G=G2, device=device)

face_bisenet = models.get_model("face_bisenet", "../pretrained/face_bisenet/model.pth")
face_segmentation = FaceSegmentation(face_bisenet=face_bisenet, device=device)

# Editing-target name (used in the log path) -> BiSeNet labels that make up the target.
# A dict keeps every name next to its labels. The earlier pair of parallel lists had
# 9 names but 10 label groups (a stray ["hair"] entry), and zip() silently truncated
# them: "cloth" was trained on ["hair"] and the cloth/neck/necklace labels were never
# used. If a hair-only target was intended instead of hair+hat, change the "hair" entry.
metfaces_face_parts = {
    "mouth": ["mouth", "u_lip", "l_lip"],
    "skin": ["skin"],
    "eyes": ["l_eye", "r_eye"],
    "nose": ["nose"],
    "ears": ["l_ear", "r_ear", "earrings"],
    "background": ["background"],
    "eyebrows": ["l_brow", "r_brow"],
    "hair": ["hair", "hat"],
    "cloth": ["cloth", "neck", "necklace"],
}

# Hyperparameters shared by every run in the sweep.
n_layers = 18
lr = 0.001
min_alpha_value = -1.0
max_alpha_value = 1.0
min_abs_alpha_value = 0.0
gamma_correlation = 5.0
onehot_temperature = 0.001
batch_size = 4
localization_layer_weights = None

for latent_space in ["W", "W+"]:
    for loss_function in ["L2"]:
        for mask_aggregation in ['average']:
            for num_latent_dirs in [1, 2]:
                for part_name, face_parts in metfaces_face_parts.items():
                    # Layer indices 1 .. n_layers-1 (index 0 excluded), same as range(1, 18).
                    localization_layers = list(range(1, n_layers))
                    log_dir = f'{exp_dir}/lelsd_stylegan2_metfaces/{latent_space}_{loss_function}_{mask_aggregation}/{num_latent_dirs}D/face_bisenet/{part_name}'
                    lelsd = LELSD(device=device,
                                  localization_layers=localization_layers,
                                  # Defensive copy per run, in case LELSD mutates the list.
                                  semantic_parts=list(face_parts),
                                  loss_function=loss_function,
                                  localization_layer_weights=localization_layer_weights,
                                  mode='foreground',
                                  mask_aggregation=mask_aggregation,
                                  n_layers=n_layers,
                                  latent_dim=512,
                                  num_latent_dirs=num_latent_dirs,
                                  learning_rate=lr,
                                  batch_size=batch_size,
                                  gamma_correlation=gamma_correlation,
                                  unit_norm=False,
                                  latent_space=latent_space,
                                  onehot_temperature=onehot_temperature,
                                  min_alpha_value=min_alpha_value,
                                  max_alpha_value=max_alpha_value,
                                  min_abs_alpha_value=min_abs_alpha_value,
                                  log_dir=log_dir,
                                  )

                    # More directions get proportionally more training batches.
                    lelsd.fit(stylegan2_sample_generator, face_segmentation, num_batches=200 * num_latent_dirs,
                              num_lr_halvings=3,
                              pgbar=True, summary=True)
                    lelsd.save()
