In [None]:
# Set the CUDA environment variables before torch is imported in the next cell:
# device ordering is set to PCI_BUS_ID and only device 0 is made visible.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import data
from networks import domain_generator, domain_classifier
from utils import renormalize, show
import torch
from PIL import Image
import numpy as np

In [None]:
# glyphs returned by get_symbol to mark a correct / incorrect prediction
checkmark, crossmark = '\u2713', '\u2717'

def _score_predictions(preds, label):
    """Compare postprocessed classifier outputs (numpy array) against `label`.

    Args:
        preds: 1-D array of binary scores, or 2-D array of per-class scores
            with one row per sample.
        label: the target label; compared against the thresholded score
            (1-D case) or the argmax class (2-D case).

    Returns:
        (preds, acc): for 1-D input the scores unchanged, for 2-D input the
        column of scores belonging to `label`; `acc` is a boolean array
        marking which samples were predicted as `label`.

    Raises:
        ValueError: if `preds` is neither 1-D nor 2-D.
    """
    ndim = np.ndim(preds)
    if ndim == 1:
        # binary prediction: threshold the score at 0.5
        acc = (preds > 0.5).astype(int) == label
    elif ndim == 2:
        # multi-class prediction: compare argmax, keep only the label's score
        acc = np.argmax(preds, axis=-1) == label
        preds = preds[:, label]
    else:
        # previously this fell through and failed with an UnboundLocalError on `acc`
        raise ValueError(
            'expected 1-D or 2-D predictions, got %d dimensions' % ndim)
    return preds, acc


def classify_image(image, classifier, label):
    """Run `classifier` on `image` and score the result against `label`.

    The classifier output is passed through domain_classifier.postprocess and
    moved to a numpy array before scoring. Gradients are disabled.

    Returns:
        (preds, acc) as described in _score_predictions.

    Raises:
        ValueError: if the postprocessed predictions are neither 1-D nor 2-D.
    """
    with torch.no_grad():
        preds = domain_classifier.postprocess(classifier(image))
        preds = preds.cpu().numpy()
    return _score_predictions(preds, label)


def classify_from_gan(image, classifier, tensor_transform, label):
    """Apply `tensor_transform` to a generated image, then classify it.

    Returns:
        (postprocessed_im, preds, acc): the transformed image that was fed to
        the classifier, followed by the outputs of classify_image.

    Raises:
        ValueError: if the postprocessed predictions are neither 1-D nor 2-D.
    """
    with torch.no_grad():
        postprocessed_im = tensor_transform(image)
    # same scoring path as real images, so both stay consistent
    preds, acc = classify_image(postprocessed_im, classifier, label)
    return postprocessed_im, preds, acc

def get_symbol(acc):
    """Return the check mark glyph when `acc` is truthy, else the cross mark."""
    if acc:
        return checkmark
    return crossmark

# stylegan2 generators

In [None]:
target_domain = 'cat'  # one of: 'celebahq', 'car', 'cat'

# Classifier used for each stylegan2 domain. Everything else in this cell is
# identical across domains, except that the celebahq dataset is also given the
# classifier name as a positional argument.
domain_classifier_names = {
    'celebahq': 'Smiling',
    'car': 'latentclassifier_stylemix_fine',
    'cat': 'latentclassifier_stylemix_coarse',
}
if target_domain not in domain_classifier_names:
    # explicit error rather than assert(False): assert statements are skipped
    # under `python -O` and carry no hint about the accepted values
    raise ValueError('unknown target_domain %r; expected one of %s'
                     % (target_domain, sorted(domain_classifier_names)))

dataset_name = target_domain
generator_name = 'stylegan2'
classifier_name = domain_classifier_names[target_domain]
val_transform = data.get_transform(dataset_name, 'imval')
dataset_extra_args = (classifier_name,) if target_domain == 'celebahq' else ()
dset = data.get_dataset(dataset_name, 'val', *dataset_extra_args,
                        load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name, load_encoder=False)
classifier = domain_classifier.define_classifier(dataset_name, classifier_name)
tensor_transform_val = data.get_transform(dataset_name, 'tensorbase') # centercrop to the appropriate dimension for classifier
tensor_transform_ensemble = data.get_transform(dataset_name, 'tensormixed') # alternatively, can just use tensorbase

In [None]:
# Classify one validation sample, its GAN reconstruction, and several kinds of
# latent perturbations of it, showing each image with its score and a
# correct/incorrect mark.
index = 100
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
# under its current name
resample = Image.LANCZOS
n_perturb = 4  # number of perturbed images generated per perturbation type

with torch.no_grad():
    # fetch the sample once; it is indexed as [0] image, [1] latent, [2] label
    sample = dset[index]
    image, label = sample[0], sample[2]

    # original image prediction
    pred_original, acc_original = classify_image(image[None].cuda(), classifier, label)
    show.a(['original: %0.2f %s' % (pred_original[0], get_symbol(acc_original[0])),
            renormalize.as_image(image).resize((256, 256), resample)])

    # gan reconstruction prediction
    latent = sample[1][None].cuda()
    reconstruction = generator.decode(latent)
    postprocessed_rec, pred_rec, acc_rec = classify_from_gan(
        reconstruction, classifier, tensor_transform_val, label)
    show.a(['reconstruction %0.2f %s' % (pred_rec[0], get_symbol(acc_rec[0])),
            renormalize.as_image(postprocessed_rec[0]).resize((256, 256), resample)])
    show.flush()

    def max_eps(setting):
        # largest perturbation magnitude listed for this setting
        return np.max(generator.perturb_settings[setting])

    def stylemix(layer):
        # mixing latents come from a fixed seed, so both layers use the same ones
        mix_latent = generator.seed2w(n=n_perturb, seed=0)
        return generator.perturb_stylemix(latent, layer, mix_latent, n=n_perturb)

    # (title, thunk) pairs; thunks are evaluated lazily inside the loop so the
    # perturbations are generated in the listed order, one row at a time
    perturbations = [
        ('isotropic fine', lambda: generator.perturb_isotropic(
            latent, 'fine', eps=max_eps('isotropic_eps_fine'), n=n_perturb)),
        ('isotropic coarse', lambda: generator.perturb_isotropic(
            latent, 'coarse', eps=max_eps('isotropic_eps_coarse'), n=n_perturb)),
        ('pca fine', lambda: generator.perturb_pca(
            latent, 'fine', eps=max_eps('pca_eps'), n=n_perturb)),
        ('pca coarse', lambda: generator.perturb_pca(
            latent, 'coarse', eps=max_eps('pca_eps'), n=n_perturb)),
        ('stylemix fine', lambda: stylemix('fine')),
        ('stylemix coarse', lambda: stylemix('coarse')),
    ]
    for title, perturb in perturbations:
        perturbed_im = perturb()
        postprocessed_perturbed, pred_perturbed, acc_perturbed = classify_from_gan(
            perturbed_im, classifier, tensor_transform_ensemble, label)
        for i in range(len(perturbed_im)):
            show.a(['%s %0.2f %s' % (title, pred_perturbed[i], get_symbol(acc_perturbed[i])),
                    renormalize.as_image(postprocessed_perturbed[i]).resize((150, 150), resample)])
        show.flush()


# cifar10 generator

In [None]:
# Configuration for the cifar10 experiment with the 'stylegan2-cc' generator.
# Rebinds dset / generator / classifier / tensor_transform_val, which the next
# cell reads. No tensor_transform_ensemble is defined here; the next cell uses
# tensor_transform_val for the perturbed images as well.
target_domain = 'cifar10'
dataset_name = 'cifar10'
generator_name = 'stylegan2-cc'
classifier_name = 'imageclassifier'
val_transform = data.get_transform(dataset_name, 'imval')
# samples from dset are indexed below as [0] image, [1] latent, [2] label
dset = data.get_dataset(dataset_name, 'val', load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name, load_encoder=False)
classifier = domain_classifier.define_classifier(dataset_name, classifier_name)
tensor_transform_val = data.get_transform(dataset_name, 'tensorbase') # centercrop to the appropriate dimension for classifier

In [None]:
# Classify one cifar10 validation sample, its GAN reconstruction, and
# fine-layer style-mixed variants of it.
index = 100
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
# under its current name
resample = Image.LANCZOS
n_perturb = 4  # number of style-mixed images to generate

with torch.no_grad():
    # fetch the sample once; it is indexed as [0] image, [1] latent, [2] label
    sample = dset[index]
    image, label = sample[0], sample[2]
    image_batch = image[None].cuda()

    # original image prediction
    pred_original, acc_original = classify_image(image_batch, classifier, label)
    show.a(['original: %0.2f %s' % (pred_original[0], get_symbol(acc_original[0])),
            renormalize.as_image(image).resize((256, 256), resample)])

    # gan reconstruction prediction
    latent = sample[1][None].cuda()
    reconstruction = generator.decode(latent)
    postprocessed_rec, pred_rec, acc_rec = classify_from_gan(
        reconstruction, classifier, tensor_transform_val, label)
    show.a(['reconstruction %0.2f %s' % (pred_rec[0], get_symbol(acc_rec[0])),
            renormalize.as_image(postprocessed_rec[0]).resize((256, 256), resample)])
    show.flush()

    # stylemix fine
    # use the predicted label to generate class-conditional random samples;
    # classify_image only returns the score of `label`, so the classifier is
    # run again here to get the scores of every class
    class_scores = domain_classifier.postprocess(classifier(image_batch))
    _, pred_label = class_scores.max(1)
    pred_label = pred_label.item()
    # one-hot class labels, one row per sample to generate
    lab = torch.zeros([n_perturb, generator.generator.c_dim],
                      device=generator.device)
    lab[:, pred_label] = 1
    # the seed is drawn at random, so the mixed samples differ between runs
    mix_latent = generator.seed2w(seed=np.random.randint(1000), n=n_perturb, labels=lab)
    perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=n_perturb)
    postprocessed_perturbed, pred_perturbed, acc_perturbed = classify_from_gan(
        perturbed_im, classifier, tensor_transform_val, label)
    for i in range(len(perturbed_im)):
        show.a(['stylemix fine %0.2f %s' % (pred_perturbed[i], get_symbol(acc_perturbed[i])),
                renormalize.as_image(postprocessed_perturbed[i]).resize((150, 150), resample)])
    show.flush()
    

# stylegan-idinvert generator

NOTE: the pretrained generator and encoder need to be downloaded before running the following blocks.


In [None]:
# Configuration for the 'stylegan-idinvert' generator on the
# 'celebahq-idinvert' dataset with the 'Smiling' classifier.
# Rebinds dset / generator / classifier and both tensor transforms, which the
# next cell reads.
dataset_name = 'celebahq-idinvert'
generator_name = 'stylegan-idinvert'
classifier_name = 'Smiling'
val_transform = data.get_transform(dataset_name, 'imval')
# here the classifier name is also passed positionally to the dataset;
# samples from dset are indexed below as [0] image, [1] latent, [2] label
dset = data.get_dataset(dataset_name, 'val', classifier_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name, load_encoder=False)
classifier = domain_classifier.define_classifier(dataset_name, classifier_name)
tensor_transform_val = data.get_transform(dataset_name, 'tensorbase') # centercrop to the appropriate dimension for classifier
tensor_transform_ensemble = data.get_transform(dataset_name, 'tensormixed') # alternatively, can just use tensorbase

In [None]:
# Classify one validation sample, its GAN reconstruction, and fine-layer
# style-mixed variants of it with the idinvert generator.
index = 100
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
# under its current name
resample = Image.LANCZOS
n_perturb = 4  # number of style-mixed images to generate

with torch.no_grad():
    # fetch the sample once; it is indexed as [0] image, [1] latent, [2] label
    sample = dset[index]
    image, label = sample[0], sample[2]

    # original image prediction
    pred_original, acc_original = classify_image(image[None].cuda(), classifier, label)
    show.a(['original: %0.2f %s' % (pred_original[0], get_symbol(acc_original[0])),
            renormalize.as_image(image).resize((256, 256), resample)])

    # gan reconstruction prediction
    latent = sample[1][None].cuda()
    reconstruction = generator.decode(latent)
    postprocessed_rec, pred_rec, acc_rec = classify_from_gan(
        reconstruction, classifier, tensor_transform_val, label)
    show.a(['reconstruction %0.2f %s' % (pred_rec[0], get_symbol(acc_rec[0])),
            renormalize.as_image(postprocessed_rec[0]).resize((256, 256), resample)])
    show.flush()

    # stylemix fine
    # the seed is drawn at random, so the mixed samples differ between runs
    mix_latent = generator.seed2w(n=n_perturb, seed=np.random.randint(1000))
    perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=n_perturb)
    postprocessed_perturbed, pred_perturbed, acc_perturbed = classify_from_gan(
        perturbed_im, classifier, tensor_transform_ensemble, label)
    for i in range(len(perturbed_im)):
        show.a(['stylemix fine %0.2f %s' % (pred_perturbed[i], get_symbol(acc_perturbed[i])),
                renormalize.as_image(postprocessed_perturbed[i]).resize((150, 150), resample)])
    show.flush()