# Information about paper and code

## Code Authors
Erce Guder - Adnan Harun Dogan

## Paper Name:
FEW-SHOT CROSS-DOMAIN IMAGE GENERATION VIA INFERENCE-TIME LATENT-CODE LEARNING (ICLR2023) https://openreview.net/pdf?id=sCYXJr3QJM8

## Paper Summary
_"Can a GAN trained on a single large-scale source dataset be adapted to multiple target domains containing very few examples without re-training the pre-trained source generator?"_

The goal of the paper is to learn a latent-generation network (during inference stage) that maps random Gaussian noise to latents in the W space (style) of a pre-trained StyleGAN2 **without updating the generator** (we don't want to overfit / forget rich prior knowledge) such that the generator can sample from the target domain. There are two loss functions used to ensure that the samples belong to target domain:

  - Adversarial loss
  - Style Loss (Content Loss / VGG-Loss)

## Datasets: 
We shall take StyleGAN2 checkpoints trained on:
* *Flickr Faces HQ (FFHQ)* dataset. Our target domains will be

  - FFHQ-Babies,
  - FFHQ-Sunglasses, 
  - Face sketches, 
  - Emoji faces from bitmoji.com, 
  - Portrait paintings from the artistic faces dataset.

* *LSUN Church* as a source domain and adapt to
  - the haunted houses, 
  - Van Gogh’s house paintings.

# Hyper-parameters of the model

Like the authors of the paper, we used the pre-trained checkpoints provided by the StyleGAN2 repositories, so we are not sure whether the hyper-parameters of StyleGAN2 itself need to be reported here.

Nevertheless, the hyper-parameters of StyleGAN2:
 - Output size of 256x256 pixels (generator & discriminator)
 - W dimension of 512 (generator)
 - Number of layers mapping Z to W: 8 (generator)

The latent-generation network is (as specified in the paper):
 - 3-Layer MLP with
 - ReLU activations

Rest of the hyper-parameters are as follows:
 - Learning rate for discriminator & latent generation network: 5e-4
 - Optimizer for discriminator & latent generation network:
     - Adam with betas (0.0, 0.99)
 - Batch size: 4

In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are reloaded automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

# Training and saving a model. 


In [None]:
from model import Generator, Discriminator
from latent_learner import LatentLearner
from dataset import Dataset
from tqdm import tqdm
import loss
import random

import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt

import gc
import os

In [None]:
@torch.no_grad()
def save_samples(x_1, x_2, generator, latent_learner, iter_):
    """Render the current latent learner's outputs and write them to disk.

    The two noise tensors are concatenated along their last axis, mapped to
    latent codes by ``latent_learner`` and decoded by ``generator``. The
    resulting image grid is written to ``samples/samples_{iter_}.png``.

    Args:
        x_1: First noise tensor.
        x_2: Second noise tensor; must be concatenable with ``x_1`` along
            the last axis.
        generator: Generator called as ``generator([w], input_is_latent=True)``
            and returning a tuple whose first element is the image batch.
        latent_learner: Callable mapping the concatenated noise to latents.
        iter_: Value embedded in the output file name (the iteration count).
    """
    os.makedirs("samples", exist_ok=True)

    noise = torch.cat([x_1, x_2], dim=-1)

    # Map to latent
    w = latent_learner(noise)

    # The latent learner already emits latent codes, so tell the generator to
    # use them directly -- exactly as the training loop does. Without the flag
    # the saved previews would not show what is actually being optimised.
    samples, _ = generator([w], input_is_latent=True)

    # Save for later examination. `value_range` is the current name of the
    # normalisation bounds; the older `range` keyword was deprecated and then
    # removed from torchvision.
    torchvision.utils.save_image(
        samples,
        f"samples/samples_{iter_}.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )

In [None]:
def disable_grad(model):
    """Freeze ``model`` in place by turning off gradients on all its parameters."""
    for weight in model.parameters():
        weight.requires_grad_(False)

def train(device, max_iters=150):
    """Train a latent learner (and discriminator) against a frozen generator.

    The generator weights are loaded from ``550000.pt`` and never updated.
    Each iteration performs three optimisation steps:

    1. a discriminator step on a batch of 4 real images vs. generated samples,
    2. a latent-learner step on the non-saturating generator loss, weighted by
       ``5 * (1 - idx / max_iters)`` so that it decays linearly over training,
    3. a latent-learner step on the style loss (weight 50, averaged over a
       batch of 5 images) computed with VGG19 sub-networks.

    Sample grids are written through ``save_samples`` every 50 iterations.

    Args:
        device: Device identifier the models and tensors are placed on.
        max_iters: Number of training iterations.

    Returns:
        A tuple ``(noise, latent_learner, generator, d_loss_hist, g_loss_hist)``
        where ``noise`` is the fixed noise bank ``x_1``/``x_2`` concatenated
        along the last axis (shape ``(10, 14, 1024)``) and the two histories
        are lists holding one detached CPU loss tensor per iteration.
    """
    # Fixed seeds for reproducibility; must precede model construction.
    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    generator = Generator(size=256, style_dim=512, n_mlp=8).to(device)
    discriminator = Discriminator(size=256).to(device)
    latent_learner = LatentLearner().to(device)

    vgg = torchvision.models.vgg19(weights='IMAGENET1K_V1').features.to(device).eval()

    # No need for gradients on the parameters of these
    disable_grad(generator)
    disable_grad(vgg)

    # Take sub-networks from the vgg, later used to compute style loss
    subnetworks = loss.subnetworks(vgg, max_layers=5)

    # Drop the local handle and release cached memory where possible
    del vgg
    gc.collect()
    torch.cuda.empty_cache()

    # Load checkpoint and weights. `map_location` lets a checkpoint that was
    # saved from GPU tensors be loaded when `device` is the CPU fallback.
    ckpt = torch.load("550000.pt", map_location=device)

    generator.load_state_dict(ckpt["g_ema"], strict=False)
    generator.eval()

    discriminator.load_state_dict(ckpt["d"])

    # Initialize optimizers (no optimizer for generator :)
    disc_opt = torch.optim.Adam(
        discriminator.parameters(),
        lr=5e-4,
        betas=(0.0, 0.99),
    )
    latent_learner_opt = torch.optim.Adam(
        latent_learner.parameters(),
        lr=5e-4,
        betas=(0.0, 0.99),
    )

    # Simple transformation pipeline: resize, convert, scale to [-1, 1]
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # Size of the fixed noise bank and of the index range used to pick images.
    # NOTE(review): assumes the dataset holds at least this many images.
    num_items = 10

    # Two fixed noise banks; a training input is one row from each,
    # concatenated along the last axis.
    x_1 = torch.randn(num_items, 14, 512, device=device)
    x_2 = torch.randn(num_items, 14, 512, device=device)

    dataset = Dataset(path="./babies", device=device, transforms=transforms)

    bar = tqdm(range(max_iters))

    d_loss_hist = list()
    g_loss_hist = list()

    for idx in bar:
        batch_idx = np.random.choice(num_items, size=4, replace=False)
        imgs = dataset[batch_idx]

        idx_1 = np.random.choice(num_items, size=imgs.shape[0], replace=False)
        idx_2 = np.random.choice(num_items, size=imgs.shape[0], replace=False)

        x = torch.cat([x_1[idx_1], x_2[idx_2]], dim=-1)

        ##### Adversarial Loss #####
        # First forward pass. Only the discriminator is updated from it, so no
        # autograd graph is needed through the generator / latent learner
        # (their gradients would be zeroed before being used anyway).
        with torch.no_grad():
            w = latent_learner(x)
            samples, _ = generator([w], input_is_latent=True)

        real_scores = discriminator(imgs)
        fake_scores = discriminator(samples)

        d_loss = loss.d_logistic_loss(real_scores, fake_scores)

        # optimization step on discriminator
        disc_opt.zero_grad()
        d_loss.backward()
        disc_opt.step()

        # Second forward pass (needed): this time with gradients, and scored
        # by the freshly updated discriminator.
        w = latent_learner(x)
        samples, _ = generator([w], input_is_latent=True)

        fake_scores = discriminator(samples)

        # Adversarial weight decays linearly from 5 towards 0 over training.
        g_loss = 5 * (1 - idx/max_iters) * loss.g_nonsaturating_loss(fake_scores)

        # optimization step on latent learner
        latent_learner_opt.zero_grad()
        g_loss.backward()
        latent_learner_opt.step()

        d_loss_hist.append(d_loss.detach().cpu())
        g_loss_hist.append(g_loss.detach().cpu())

        ##### Style Loss #####
        img_idx = np.random.choice(num_items, size=5, replace=False)
        imgs = dataset[img_idx]

        idx_1 = np.random.choice(num_items, size=imgs.shape[0], replace=False)

        # The second noise half is indexed with the image indices, which pairs
        # each target image with a fixed x_2 row.
        x = torch.cat([x_1[idx_1], x_2[img_idx]], dim=-1)

        w = latent_learner(x)
        samples, _ = generator([w], input_is_latent=True)

        style_loss = 0.0

        for sample_pos, img in enumerate(imgs):
            style_loss += 50 * loss.style_loss(subnetworks, img, samples[sample_pos])

        # take mean over the batch
        style_loss /= len(imgs)

        # optimization step on latent learner
        latent_learner_opt.zero_grad()
        style_loss.backward()
        latent_learner_opt.step()

        bar.set_description(f"d_loss: {d_loss.cpu():.2f}, g_loss: {g_loss.cpu():.2f}, style_loss: {style_loss:.2f}")

        if (idx+1) % 50 == 0:
            save_samples(x_1, x_2, generator, latent_learner, idx+1)

    return torch.cat([x_1, x_2], dim=-1), latent_learner, generator, d_loss_hist, g_loss_hist

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Run a 50-iteration training session and unpack everything it returns:
# the noise bank, the trained latent learner, the generator and both
# per-iteration loss histories.
training_outputs = train(device=device, max_iters=50)
noise, latent_learner, generator, d_loss_hist, g_loss_hist = training_outputs

In [None]:
# Plot the discriminator and (weighted) generator loss recorded at every
# training iteration.
plt.plot(d_loss_hist, label="d_loss")
plt.plot(g_loss_hist, label="g_loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
# Render the figure; the previous bare `plt.plot()` call drew nothing.
plt.show()

In [None]:
# Persist everything needed to reproduce the samples later: the noise bank
# plus the weights of the latent learner and of the generator.
os.makedirs("ckpts", exist_ok=True)

artifacts = (
    (noise, "ckpts/noise.pt"),
    (latent_learner.state_dict(), "ckpts/latent_learner.pt"),
    (generator.state_dict(), "ckpts/generator.pt"),
)
for payload, destination in artifacts:
    torch.save(payload, destination)

# Loading a pre-trained model and computing qualitative samples/outputs from that model

In [None]:
# Rebuild the networks and restore the saved weights and noise bank.
# `map_location=device` remaps the stored tensors onto the active device, so
# checkpoints written during a GPU run can still be loaded on a CPU-only
# machine (and the noise ends up on the same device as the models).
generator = Generator(size=256, style_dim=512, n_mlp=8).to(device)
generator.load_state_dict(torch.load("ckpts/generator.pt", map_location=device))
generator.eval()

latent_learner = LatentLearner().to(device)
latent_learner.load_state_dict(torch.load("ckpts/latent_learner.pt", map_location=device))

noise = torch.load("ckpts/noise.pt", map_location=device)

In [None]:
from PIL import Image

# Generate one image per stored noise entry, write the grid to disk and show
# it inline. No gradients are needed for inference.
with torch.no_grad():
    # Map to latent
    w = latent_learner(noise)

    # The latent learner outputs latent codes, so pass them to the generator
    # as latents -- the same way the training loop calls it.
    samples, _ = generator([w], input_is_latent=True)

    # Save for later examination. `value_range` is the current name of the
    # normalisation bounds; the older `range` keyword was deprecated and then
    # removed from torchvision.
    torchvision.utils.save_image(
        samples,
        "samples.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )

display(Image.open("samples.png"))

**It is obvious that the generator collapsed, and the results are not baby images.**

We believe that this is a direct consequence of our not having fully understood the algorithm.