# Generating images of new classes based on input images

Perform directed image generation: encode an input image and, for every class the
generator supports, generate a new image that represents the input image in that class.

In [None]:
# List the NVIDIA GPUs visible to this runtime.
!nvidia-smi -L

#### Mount the Drive and navigate to the project directory
Only relevant if you are running in Google Colab.

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored under MyDrive
# become reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd "/content/drive/MyDrive/MastersDegree/Semester3/046211_project/046211_project_repo"

#### Add all the imports

In [None]:
import os
import cv2
import torch
from matplotlib import pyplot as plt

import dnnlib
import legacy
from facenet_encoder.inception_resnet_v1 import InceptionResnetV1
from facenet_encoder.utils import one_hot_vector

Determine whether an NVIDIA GPU is available.

In [None]:
# Prefer the first CUDA device when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Running on device: {device}')

#### Import a Generator

In [None]:
def import_generator(generator_pkl):
    """
    Load a StyleGAN2-ADA generator from a pickle file.

    Args:
        generator_pkl:  Location of the pkl network snapshot, opened via dnnlib.util.open_url.

    Returns: The generator module stored under the 'G_ema' key of the loaded snapshot.
    """
    # NOTE(review): the snapshot is presumably unpickled, which can execute arbitrary
    # code -- only load pkl files from trusted sources.
    with dnnlib.util.open_url(generator_pkl) as pkl_stream:
        # size / scale_type are forwarded as keyword options to the project's custom loader.
        snapshot = legacy.load_network_pkl(pkl_stream, custom=True, size=None, scale_type='pad')
        loaded_generator = snapshot['G_ema']
    return loaded_generator

#### Import an Encoder

In [None]:
def import_encoder(encoder_pt, map_location='cpu'):
    """
    Load an InceptionResnetV1 encoder from a PyTorch state-dict file.

    Args:
        encoder_pt:    Path to the .pt file holding the encoder's state_dict.
        map_location:  Device the checkpoint tensors are deserialized onto before they are
                       copied into the freshly constructed model (default 'cpu'). This lets a
                       checkpoint saved on a GPU load on a CPU-only machine.

    Returns: Instance of the encoder with the weights loaded from the pt file. The model is
             freshly constructed, so it is on PyTorch's default device; move it with
             `.to(device)` as needed.
    """
    encoder = InceptionResnetV1()
    # Without map_location, tensors saved from a GPU are restored onto that GPU, and loading
    # fails when CUDA is unavailable. load_state_dict copies them into the model either way.
    state_dict = torch.load(encoder_pt, map_location=map_location)
    encoder.load_state_dict(state_dict)
    return encoder


#### Generate images

Load the input image used for the generation.

In [None]:
# Path of the image to re-render in every class.
input_path = os.path.join('..','data','0001.png')
# cv2.imread returns None instead of raising when the file is missing or unreadable.
input_bgr = cv2.imread(input_path)
if input_bgr is None:
    raise FileNotFoundError('Could not read input image: {}'.format(input_path))
# OpenCV loads BGR; convert to RGB. A NumPy array has no `.to()`, so turn the HxWx3 uint8
# array into a 1x3xHxW float tensor on `device` before it is fed to the encoder.
input_rgb = cv2.cvtColor(input_bgr, cv2.COLOR_BGR2RGB)
input_image = torch.from_numpy(input_rgb).permute(2, 0, 1).unsqueeze(0).float().to(device)
# NOTE(review): pixel values are left in [0, 255]; confirm which normalization the encoder
# was trained with (e.g. fixed standardization or [0, 1] scaling) and apply it here.

Import the generator from a pkl file.

In [None]:
# Load the generator snapshot, then move it onto the selected device.
generator_pkl = os.path.join('..', 'pretrained', 'generator.pkl')
generator = import_generator(generator_pkl)
generator = generator.to(device)
print(f'Loaded generator model from {generator_pkl}.')

Import the encoder from a pt file.

In [None]:
# Load the encoder weights, then move the model onto the selected device.
encoder_pt = os.path.join('..', 'pretrained', 'encoder.pt')
encoder = import_encoder(encoder_pt).to(device)
# Report the encoder file that was actually loaded.
print('Loaded encoder model from {}.'.format(encoder_pt))

Encode the input image and generate a new image from it for every class.

In [None]:
if generator.c_dim < 1:
    raise ValueError('The generator is unconditional (c_dim == 0); there are no classes to generate.')

# Inference only: eval mode affects layers such as BatchNorm/Dropout (if present), and
# no_grad below skips building an autograd graph.
encoder.eval()
generator.eval()

# list for saving the generated image of every class
new_images = []
with torch.no_grad():
    # pass the input image to the encoder to get the latent vector (embedding)
    latent_img = encoder(input_image).to(device)
    for new_class in range(generator.c_dim):
        # pass the latent vector and the new class to the generator to get the new image
        class_vector = one_hot_vector(generator.c_dim, new_class, device)
        new_images.append(generator(latent_img, class_vector))

# A Python list has no `.concatenate`; join the NCHW batches side by side along the width axis.
classes_tensor = torch.cat(new_images, dim=3)
# NOTE(review): assumes the StyleGAN2-ADA output convention (NCHW floats roughly in [-1, 1]).
# Map the first sample to an HxWx3 uint8 RGB array for matplotlib / OpenCV -- confirm.
classes_image = (
    (classes_tensor[0].permute(1, 2, 0) * 127.5 + 128)
    .clamp(0, 255)
    .to(torch.uint8)
    .cpu()
    .numpy()
)

Display and save the new generated images

In [None]:
# Display the generated class images.
fig, ax = plt.subplots()
ax.imshow(classes_image)
ax.set_title('Input image generated in each class')
ax.axis('off')
plt.show()

output_path = os.path.join('..', 'out', 'generated_images.jpg')
# cv2.imwrite does not create missing directories and reports failure only by returning False.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# classes_image is RGB; OpenCV expects BGR when writing.
if not cv2.imwrite(output_path, cv2.cvtColor(classes_image, cv2.COLOR_RGB2BGR)):
    raise IOError('Failed to write generated images to {}'.format(output_path))
print('Saved generated images to {}.'.format(output_path))

Enjoy your generated images!