# Check out PyTorch StyleGAN Encoder



In [None]:
# Clone the encoder repo (with its InterFaceGAN submodule) and make it the working directory;
# later cells import from `InterFaceGAN.*` and `models.*` relative to it.
# NOTE(review): not idempotent — after the %cd, re-running this cell clones a nested copy
# inside the repo and cds into that. Restart the runtime (or cd back) before re-running.
# Uncomment the next line to discard a previous clone before cloning again.
# !rm -r pytorch_stylegan_encoder
!git clone --recurse-submodules https://github.com/ficinator/pytorch_stylegan_encoder.git
%cd pytorch_stylegan_encoder

In [None]:
from InterFaceGAN.models.stylegan_generator import StyleGANGenerator
from models.latent_optimizer import PostSynthesisProcessing
from models.image_to_latent import ImageToLatent, ImageLatentDataset
from models.losses import LogCoshLoss
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
from glob import glob
from tqdm import tqdm_notebook as tqdm
import numpy as np
from pathlib import Path
from datetime import datetime

In [None]:
# Mount Google Drive so the model checkpoints written under DRIVE_DIR outlive the Colab session.
import google.colab.drive as drive
drive.mount('/content/drive')

# Generate Images

* Download the pretrained StyleGAN FFHQ model.
* Use it to generate 50,000 faces together with their corresponding dlatents (saved as `*.jpg` images and `wp.npy`).

In [None]:
# Download the pretrained StyleGAN FFHQ weights into InterFaceGAN/models/pretrain/
# (presumably where StyleGANGenerator looks up `<MODEL_NAME>.pth` — verify).
# NOTE(review): --quiet also hides download failures; check the file exists if generation fails.
!wget https://www.dropbox.com/s/qyv37eaobnow7fu/stylegan_ffhq.pth?dl=1 -O InterFaceGAN/models/pretrain/stylegan_ffhq.pth --quiet

In [None]:
MODEL_NAME = 'stylegan_ffhq'
DRIVE_DIR = Path('../drive/MyDrive/ML')
# DATA_DIR = DRIVE_DIR/'data'/MODEL_NAME
DATA_DIR = Path('data')/MODEL_NAME
NUM_IMAGES = 50000
!rm -r $DATA_DIR
!python InterFaceGAN/generate_data.py -m $MODEL_NAME -o $DATA_DIR -n $NUM_IMAGES

# Create Dataloaders
Uses the 50,000-image dataset generated above with the `generate_data.py` script from https://github.com/ShenYujun/InterFaceGAN, split 80% for training and 20% for validation.

In [None]:
# Preprocessing: resize to image_size, then normalize with the ImageNet mean/std
# (normalized_to_normal_image in "Test Model" inverts this for display).
image_size = 256
num_images_train = int(.8 * NUM_IMAGES)  # 80% train / 20% validation

augments = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Sorting pairs the i-th image with row i of wp.npy; this assumes generate_data.py writes
# zero-padded, sequentially numbered filenames — TODO confirm.
filenames = sorted(glob(str(DATA_DIR/'*.jpg')))
dlatents = np.load(DATA_DIR/'wp.npy')
# Catch a stale or partial dataset (e.g. NUM_IMAGES changed without regenerating) before training.
assert len(filenames) == len(dlatents) == NUM_IMAGES, (
    f"expected {NUM_IMAGES} images and latents, found {len(filenames)} and {len(dlatents)}")

train_filenames, validation_filenames = filenames[:num_images_train], filenames[num_images_train:]
train_dlatents, validation_dlatents = dlatents[:num_images_train], dlatents[num_images_train:]

train_dataset = ImageLatentDataset(train_filenames, train_dlatents, transforms=augments)
validation_dataset = ImageLatentDataset(validation_filenames, validation_dlatents, transforms=augments)

# Shuffle only the training set so every epoch sees a fresh batch order; validation order stays fixed.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_generator = torch.utils.data.DataLoader(validation_dataset, batch_size=32)

# Instantiate Model

In [None]:
# Loss, GPU-resident image->latent encoder, and an Adam optimizer over the encoder's weights.
criterion = LogCoshLoss()
image_to_latent = ImageToLatent(image_size).cuda()
optimizer = torch.optim.Adam(params=image_to_latent.parameters())

# Train Model

In [None]:
# Train the encoder to regress StyleGAN dlatents from images; report mean train/validation loss.
epochs = 20
validation_loss = 0.0

progress_bar = tqdm(range(epochs))
for epoch in progress_bar:
    # --- training pass ---
    image_to_latent.train()
    running_loss = 0.0
    num_train_steps = 0
    for num_train_steps, (images, latents) in enumerate(train_generator, 1):
        optimizer.zero_grad()

        images, latents = images.cuda(), latents.cuda()
        pred_latents = image_to_latent(images)
        loss = criterion(pred_latents, latents)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_description("Step: {0}, Loss: {1:.4f}, Validation Loss: {2:.4f}".format(
            num_train_steps, running_loss / num_train_steps, validation_loss))

    # --- validation pass: no gradients needed ---
    image_to_latent.eval()
    validation_loss = 0.0
    num_validation_steps = 0
    with torch.no_grad():
        for num_validation_steps, (images, latents) in enumerate(validation_generator, 1):
            images, latents = images.cuda(), latents.cuda()
            pred_latents = image_to_latent(images)
            validation_loss += criterion(pred_latents, latents).item()
    # max(..., 1) guards against an empty loader (counter would stay 0).
    validation_loss /= max(num_validation_steps, 1)

    # Average the training loss over the *training* step count, not the validation loop's counter.
    progress_bar.set_description("Step: {0}, Loss: {1:.4f}, Validation Loss: {2:.4f}".format(
        num_train_steps, running_loss / max(num_train_steps, 1), validation_loss))

# Save Model

In [None]:
# Persist the trained weights to Drive; the filename records the UTC timestamp and NUM_IMAGES.
model_dir = DRIVE_DIR/'models/image2latent'
model_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
model_path = model_dir/f"{datetime.utcnow().strftime('%Y-%m-%d_%H:%M')}_{NUM_IMAGES}.pt"
torch.save(image_to_latent.state_dict(), model_path)
model_path

# Load Model

In [None]:
# Rebuild the encoder and restore saved weights. This cell is meant to replace
# "Train Model"/"Save Model", so it defines model_dir itself instead of relying on the Save cell.
model_dir = DRIVE_DIR/'models/image2latent'
# NOTE(review): per the Save cell's `<timestamp>_<NUM_IMAGES>.pt` naming, this checkpoint was
# trained with NUM_IMAGES=1000 — update the name after retraining.
checkpoint_name = '2022-11-17_13:17_1000.pt'

image_to_latent = ImageToLatent(image_size).cuda()
image_to_latent.load_state_dict(torch.load(model_dir/checkpoint_name, map_location='cuda'))
image_to_latent.eval();

# Test Model

In [None]:
def normalized_to_normal_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and return a displayable uint8 image.

    Args:
        image: tensor of shape (1, C, H, W) or (C, H, W); for a batch only the
            first item is converted. Any device; the input is never modified.
        mean, std: per-channel statistics used by the forward Normalize
            (ImageNet defaults, matching the dataloader's `augments`).

    Returns:
        numpy uint8 array of shape (H, W, C), clipped to [0, 255].
    """
    mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)

    # For CPU tensors .detach() shares storage and .cpu() is a no-op, so use only
    # out-of-place ops below to avoid corrupting the caller's tensor.
    image = image.detach().cpu().float()
    if image.dim() == 4:
        image = image[0]

    pixels = (image * std + mean) * 255
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    pixels = pixels.clamp(0, 255).numpy()
    return np.transpose(pixels, (1, 2, 0)).astype(np.uint8)


# Encode the first few validation images, re-synthesize them with StyleGAN from the
# predicted dlatents, and collect both versions as channel-last uint8 arrays for plotting.
num_test_images = min(5, len(validation_dataset))
images = [validation_dataset[i][0].unsqueeze(0).cuda() for i in range(num_test_images)]
normal_images = [normalized_to_normal_image(image) for image in images]

synthesizer = StyleGANGenerator(MODEL_NAME).model.synthesis
post_processing = PostSynthesisProcessing()
post_process = lambda image: post_processing(image).detach().cpu().numpy().astype(np.uint8)[0]

image_to_latent.eval()
# Inference only: without no_grad every forward pass would build and hold an autograd graph.
with torch.no_grad():
    pred_dlatents = [image_to_latent(image) for image in images]
    pred_images = [np.transpose(post_process(synthesizer(dlatent)), (1, 2, 0))
                   for dlatent in pred_dlatents]

In [None]:
# Top row: reference images; bottom row: StyleGAN reconstructions from the predicted latents.
rows = 2
columns = len(normal_images)
figure, axes_grid = plt.subplots(rows, columns, figsize=(25, 10), squeeze=False)

row_specs = [
    ("Reference Image", normal_images),
    ("Generated With Predicted Latents", pred_images),
]

axis = []
for row_axes, (title, row_images) in zip(axes_grid, row_specs):
    for ax, picture in zip(row_axes, row_images):
        ax.set_title(title)
        ax.imshow(picture)
        axis.append(ax)

plt.show()

In [None]:
!zip -q $DATA_DIR $DATA_DIR