# Paper Information
Paper: SinGAN: Learning a Generative Model from a Single Natural Image, https://arxiv.org/abs/1905.01164

Authors: Tamar Rott Shaham, Tali Dekel, Tomer Michaeli

Code Authors: Ataberk Dönmez, Deniz Sayın 

# Initialization & Hyperparameters

The module imports and device selection cells must always be run before any of the other sections.

## Module Imports

In [None]:
import os
import subprocess as sp
from time import time

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

# our own helper modules
from models import *
from utils import *

## Device Selection

In [None]:
# arguments cell
DEVICE = 'cuda'

## Training Hyperparameters

Please note that running this cell and setting the parameters is only necessary for training; generating the results does not require any of these parameters.

In [None]:
# arguments cell
# image to train on
IMG_PATH = 'images/birds.png'
SAVE_PATH = 'models/birds.pt'

# training hyperparameters, 
# as given in the paper
GEN_LEARNING_RATE = 0.0005
CRIT_LEARNING_RATE = 0.0005
BETA_1 = 0.5  # beta parameters for the ADAM optimizer
BETA_2 = 0.999
NUM_ITERS = 2000  # number of iterations at each scale
LR_DROP_STEP = 1600  # step at which to decay learning rate
LR_DROP_MULT = 0.1  # lr decay multiplier
GEN_STEP_PER_ITER = 3  # optimization step per iteration for the generator
CRIT_STEP_PER_ITER = 3  # ... and the critic
REC_ALPHA = 10.0  # reconstruction loss weight
GP_WEIGHT = 0.1  # gradient penalty weight

# architecture details
NUM_SCALES = 9  # number of training scales, the most important parameter!
KERNEL_SIZE = 3  # kernel size of the convolutions, no need to change
NUM_BLOCKS = 5  # number of blocks in each network, no need to change
INITIAL_KERNEL_COUNT = 32  # the initial amount of kernels in each conv layer
INCREASE_KERNEL_COUNT_EVERY = 4  # ... and how often to double their amount
NOISE_BASE_STD = 0.5  # base noise stdev multiplier at upper scales (scaled with rmse), different values in the range [0.1, 1.0] seem to work
FIRST_SCALE_NOISE_STD = 1.0  # noise stdev at the first scale, standard normal, different because there is no input image
MAX_INPUT_SIZE = 250  # if the input image is larger than this, resize its long edge
UPSAMPLING_FACTOR = 4/3  # the factor by which images are upsampled at each scale
UPSAMPLING_MODE = 'bilinear'  # mode used when upscaling during training
DOWNSAMPLING_MODE = 'bicubic'  # mode used when downscaling the original input

# extra settings
PRINT_EVERY = 500  # show progress every X iterations
SEED = 796  # random seed value
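
To make the effect of `NUM_SCALES`, `UPSAMPLING_FACTOR` and `MAX_INPUT_SIZE` more concrete, the sketch below estimates the image size at each scale. It is only an illustration: `pyramid_sizes` is a hypothetical helper, and the rounding used by the actual `load_with_reverse_pyramid` in `utils` may differ. With the defaults above, a long edge of 250 pixels shrinks to about 25 pixels at the coarsest of the 9 scales, since $250 \cdot (3/4)^8 \approx 25$.

```python
def pyramid_sizes(height, width, num_scales, upsampling_factor, max_input_size=None):
    """Return an estimated (height, width) per scale, ordered from coarsest to finest."""
    # optionally shrink the image so that its long edge fits max_input_size
    if max_input_size is not None and max(height, width) > max_input_size:
        ratio = max_input_size / max(height, width)
        height, width = height * ratio, width * ratio
    sizes = []
    for steps_below_finest in range(num_scales - 1, -1, -1):
        # each step down the pyramid divides both edges by the upsampling factor
        scale = (1.0 / upsampling_factor) ** steps_below_finest
        sizes.append((max(1, round(height * scale)), max(1, round(width * scale))))
    return sizes


# example image size (188 is an arbitrary short edge chosen for illustration)
sizes = pyramid_sizes(250, 188, num_scales=9, upsampling_factor=4 / 3)
print('coarsest:', sizes[0], 'finest:', sizes[-1])
```

Increasing `NUM_SCALES` makes the coarsest image smaller, so a larger part of the global layout is decided at the first scale.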

# Training & Saving a Model

Training takes between half an hour and a few hours, depending on the GPU you have access to. Training a full model using only a CPU is not viable!

In [None]:
# training cell
# closures for easy use depending on settings, so that
# we can skip providing every single argument at each iteration
def make_generator(kernel_count, noise_std):
  sgnet = SGNet(NUM_BLOCKS, kernel_count, KERNEL_SIZE, final_activation=nn.Tanh(), output_channels=3).to(DEVICE)
  return SGGen(sgnet, noise_std)
  
def make_critic(kernel_count):
  return SGNet(NUM_BLOCKS, kernel_count, KERNEL_SIZE, final_activation=None, output_channels=1).to(DEVICE)

def make_optimizer_and_scheduler(net, net_learning_rate):
  optimizer = torch.optim.Adam(net.parameters(), net_learning_rate, (BETA_1, BETA_2))
  sched = torch.optim.lr_scheduler.StepLR(optimizer, LR_DROP_STEP, LR_DROP_MULT)
  return optimizer, sched
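
One detail worth keeping in mind about `StepLR`: it counts calls to the scheduler's `step()` method, not training iterations. The sketch below shows when the drop happens for the `LR_DROP_STEP` and `LR_DROP_MULT` values above. `lr_after` is a hypothetical helper written only for this illustration; how often the scheduler is actually stepped during training depends on `optimization_step` in `utils`, which is not shown here.

```python
import torch


def lr_after(num_scheduler_steps, base_lr=0.0005, drop_step=1600, drop_mult=0.1):
    """Learning rate reported by StepLR after the given number of scheduler steps."""
    param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never trained
    optimizer = torch.optim.Adam([param], base_lr, (0.5, 0.999))
    sched = torch.optim.lr_scheduler.StepLR(optimizer, drop_step, drop_mult)
    for _ in range(num_scheduler_steps):
        optimizer.step()  # no gradients are set, so the parameter stays unchanged
        sched.step()
    return sched.get_last_lr()[0]


# learning rate just before and right at the drop
print(lr_after(1599), lr_after(1600))
```

If the scheduler were stepped once per optimization step rather than once per iteration, the drop would be reached after fewer than `LR_DROP_STEP` iterations whenever there are several optimization steps per iteration.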

In [None]:
# training cell
# load the image along with its downsampled versions and their exact sizes
downsampling_factor = 1.0 / UPSAMPLING_FACTOR
scaled_origs, exact_sizes = load_with_reverse_pyramid(IMG_PATH, MAX_INPUT_SIZE, downsampling_factor, NUM_SCALES, 
                                                      mode=DOWNSAMPLING_MODE, device=DEVICE, verbose=True)
  
original_image = scaled_origs[-1]
print('Input image:')
plt.imshow(normed_tensor_to_np_image(original_image))
plt.show()

In [None]:
# training cell
# seed the random number generators for reproducibility
seed_rngs(SEED)

# exact size of the coarsest scale
coarsest_exact_size = exact_sizes[0]

# initialize the constant noise used in reconstruction
z_rec_coarsest = FIRST_SCALE_NOISE_STD * torch.randn_like(scaled_origs[0], device=DEVICE)
z_rec = [z_rec_coarsest] # a zero tensor is appended after each scale

# constant zero input for the coarsest scale during training
coarsest_zero_input = torch.zeros_like(z_rec_coarsest)


In [None]:
# training cell
# loop values
training_start = time()
kernel_count = INITIAL_KERNEL_COUNT
generators, critics = [], []
gen_losses, crit_losses = [], []
for scale_index in range(NUM_SCALES):
  print('****************************\nScale {}'.format(scale_index))

  # get the original image at the current scale
  scale_orig_img = scaled_origs[scale_index]

  # things to be done after the first scale
  if scale_index > 0:
    # use RMSE to determine the standard deviation of the input noise
    with torch.no_grad():
      reconstruction = generator(z_rec)  # specific reconstruction noise
    scaled_reconstruction, _ = exact_interpolate(reconstruction, UPSAMPLING_FACTOR, exact_sizes[scale_index-1], UPSAMPLING_MODE)
    rmse = torch.sqrt(F.mse_loss(scaled_reconstruction, scale_orig_img))
    print('RMSE: {:.2f}'.format(rmse))
    # if the scale matches, increase kernel count
    if scale_index % INCREASE_KERNEL_COUNT_EVERY == 0:
      kernel_count *= 2
    # add a zero tensor to the reconstruction noise list
    # since it is defined as [z*, 0, 0, 0...] for some z*
    z_rec.append(torch.zeros_like(scale_orig_img))
      
  # create the noise sampler based on the RMSE
  # the first scale's stdev is different due to the zero input,
  # the noise has to be stronger than in the upper layers, although
  # this is alleviated by the rmse multiplier even if 
  # FIRST_SCALE_NOISE_STD == NOISE_BASE_STD
  scale_noise_std = FIRST_SCALE_NOISE_STD if scale_index == 0 else rmse * NOISE_BASE_STD

  ## initialize the generator
  # create a generator for this specific scale and initialize it
  scale_generator = make_generator(kernel_count, scale_noise_std)
  # copy weights from previous if possible, and add to the list
  initialize_net(scale_generator, generators)
  
  # create a single generator view from the stack of generators
  generic_generator = MultiScaleSGGenView(generators, UPSAMPLING_FACTOR, UPSAMPLING_MODE)
  # fix the input parameters for easier forward calls
  generator = FixedInputSGGenView(generic_generator, coarsest_zero_input, coarsest_exact_size)
  
  ## initialize the critic (discriminator)
  critic = make_critic(kernel_count)
  initialize_net(critic, critics)

  # create the optimizers and schedulers
  gen_optimizer, gen_sched = make_optimizer_and_scheduler(generator, GEN_LEARNING_RATE)
  crit_optimizer, crit_sched = make_optimizer_and_scheduler(critic, CRIT_LEARNING_RATE)

  # print norms to ensure correct operation
  gen_norms = ['G{}: {:.3f}'.format(i, sum_param_norms(g)) for i, g in enumerate(generators)]
  crit_norms = ['C{}: {:.3f}'.format(i, sum_param_norms(c)) for i, c in enumerate(critics)]
  print('Generator norms: ' + ', '.join(gen_norms))
  print('Critic norms: ' + ', '.join(crit_norms))
  
  # perform training
  for step in range(NUM_ITERS):

    for _ in range(CRIT_STEP_PER_ITER):
      crit_optimizer.zero_grad()
      
      # the model handles noise sampling on its own
      fake_img = generator()
      
      # gradient & adversarial loss
      grad_loss = gradient_penalty(critic, fake_img, scale_orig_img)
      fake_loss = critic(fake_img).mean()
      real_loss = -critic(scale_orig_img).mean()
      crit_loss =  fake_loss + real_loss + GP_WEIGHT * grad_loss
      
      optimization_step(crit_loss, crit_optimizer, crit_sched, crit_losses)

    # zero gradient before beginning because
    # generator was used in the crit. training
    for _ in range(GEN_STEP_PER_ITER):
      gen_optimizer.zero_grad()

      fake_img = generator()

      # adversarial & reconstruction loss
      adv_loss = -critic(fake_img).mean()
      rec_img = generator(z_rec)
      rec_loss = F.mse_loss(scale_orig_img, rec_img)
      gen_loss = adv_loss + REC_ALPHA * rec_loss
      
      optimization_step(gen_loss, gen_optimizer, gen_sched, gen_losses)

    if step % PRINT_EVERY == 0:
      # print some details
      print('Step: {}'.format(step))
      print('Generator adv: {:.3f}, rec: {:.3f}'.format(adv_loss.item(), rec_loss.item()))
      print('Critic fake: {:.3f} real: {:.3f} grad: {:.3f}'.format(fake_loss.item(), real_loss.item(), grad_loss.item()))
      if step != 0:
        elapsed = time() - last_print
        print('Steps per second: {:.2f}'.format(PRINT_EVERY / elapsed))
        
      # example noise sample at highest scale
      with torch.no_grad():
        fake_example = generator()
      plt.imshow(normed_tensor_to_np_image(fake_example))
      plt.show()
      last_print = time()

  # show the reconstruction at the end of training
  print('Reconstruction vs Original:')
  with torch.no_grad():
    final_rec = generator(z_rec)
  plt.subplot(121)
  plt.imshow(normed_tensor_to_np_image(final_rec))
  plt.subplot(122)
  plt.imshow(normed_tensor_to_np_image(scale_orig_img))
  plt.show()

# save the model when done
save_model(SAVE_PATH, original_image, generators, critics, UPSAMPLING_FACTOR, UPSAMPLING_MODE, DOWNSAMPLING_MODE)

# show the total time the training took
training_duration = time() - training_start
print('Total training time in hours: {:.2f}'.format(training_duration / 3600))

# Load a Pre-trained Model and Sample Qualitative Results

In [None]:
# training hyper-parameters have no effect here,
# all the necessary information is in the model file
MODEL_PATH = 'models/birds9.pt'  # model file
NUM_SAMPLES = 10  # number of samples to display
INPUT_SCALE = 1  # input scale
OUTPUT_SIZE = None  # None keeps the size the model was trained with. Only works with input scale 0, try larger values such as (200, 1000)!
SEED = 797 # ideally, different from the one used in training
OUTPUT_FOLDER = 'samples' # if not None, the folder is created and samples are saved

# about INPUT_SCALE:
# 0 is the coarsest scale, 1 is the next
# finer one, etc.
# when input_scale > 0, the scaled original
# image is provided as input to that scale,
# and all lower scales are ignored entirely

# re-seed for reproducibility
seed_rngs(SEED)

# load the pre-trained model and maybe set-up custom input
gen, original = load_generator(MODEL_PATH, INPUT_SCALE, OUTPUT_SIZE, device=DEVICE)

# show the original
print('Original: ')
plt.imshow(normed_tensor_to_np_image(original))
plt.show()

# generate and show (and maybe save) samples
if OUTPUT_FOLDER:
  os.makedirs(OUTPUT_FOLDER, exist_ok=True)
for i in range(NUM_SAMPLES):
  random_sample = gen()
  print('Sample {}: '.format(i))
  sample_uint = normed_tensor_to_np_image(random_sample)
  plt.imshow(sample_uint)
  plt.show()
  if OUTPUT_FOLDER:
    path = os.path.join(OUTPUT_FOLDER, '{}.png'.format(i))
    Image.fromarray(sample_uint).save(path, 'PNG')
  

# Quantitative Results

In this section, we aim to reproduce the results of Table 2 (page 7) in the original paper with our own implementation of SIFID (Single-Image Fréchet Inception Distance).

However, we cannot hope to reproduce the results exactly, because the authors calculated the SIFID values on a survey dataset they prepared for use on Amazon Mechanical Turk. The dataset contains 50 different images, with a single high-variance and a single mid-variance sample selected for each (probably the best sample they obtained from an unknown number of generated samples). Since we did not have the time to fully train 50 different models and select the best possible sample for each, we instead opted for two different approaches, each explained above its corresponding cell below.
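
For readers unfamiliar with the metric, the Fréchet distance underlying FID-style scores compares two Gaussians fitted to feature vectors: $\lVert \mu_a - \mu_b \rVert^2 + \mathrm{Tr}\left(\Sigma_a + \Sigma_b - 2(\Sigma_a \Sigma_b)^{1/2}\right)$. The sketch below computes it with NumPy only. It is an illustration, not the code in our `sifid` module: `frechet_distance` is a hypothetical name, and the features are assumed to be given as a `(num_vectors, dim)` array (for SIFID, these would be Inception feature vectors taken from a single image).

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two (num_vectors, dim) feature sets."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^(1/2)) is the sum of the square roots of the eigenvalues
    # of the product; tiny negative or imaginary parts are numerical noise
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


# toy features: shifting every vector changes only the mean term
rng = np.random.default_rng(0)
feats = rng.normal(size=(200, 4))
print(frechet_distance(feats, feats), frechet_distance(feats, feats + 1.0))
```

Computing the trace of the matrix square root through eigenvalues avoids a dependency on `scipy.linalg.sqrtm`; this is a choice made for the sketch only.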

### Important note:
- Due to an issue in scipy version 1.4+, loading the pre-trained Inception model takes a very long time, so scipy has to be downgraded to 1.3 to ensure proper operation (see the related issue: https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756/12). If you are using your own computer, please configure your virtual environment accordingly. If you are using Google Colab, you can use the following code cell to downgrade scipy (remember to restart your kernel/runtime afterwards!):

In [None]:
# DO NOT RUN THIS CELL UNLESS NECESSARY, AS EXPLAINED ABOVE!!
!pip uninstall scipy -y
!pip install scipy==1.3.3
!pip list | grep scipy

### Initialization

In [None]:
from sifid import *

# the first time you run this cell, the pretrained inception model
# weights have to be downloaded (around 100 MB), so please be patient!

calc = SIFIDCalculator(DEVICE)

### Table 2 Reproduction

In this cell, we use the survey images provided by the authors of the paper on GitHub. We calculate the pairwise SIFID values using our own implementation and verify that it works, since we obtain the same results as shown in Table 2 of the paper.

In [None]:
USER_STUDY_BASE = 'user study'
USER_STUDY_REAL = os.path.join(USER_STUDY_BASE, 'real')
USER_STUDY_MVAR = os.path.join(USER_STUDY_BASE, 'fake_mid_variance')
USER_STUDY_HVAR = os.path.join(USER_STUDY_BASE, 'fake_high_variance')

mid_sifid = calc.calculate_average_sifid_folders(USER_STUDY_REAL, USER_STUDY_MVAR)
high_sifid = calc.calculate_average_sifid_folders(USER_STUDY_REAL, USER_STUDY_HVAR)

print('SIFID at scale N (coarsest): {:.2f}'.format(round(high_sifid, 2)))
print('SIFID at scale N-1: {:.2f}'.format(round(mid_sifid, 2)))

### SIFID on our own models
In this section, we calculate an average SIFID over a number of samples for each of the models we trained, at different input scales. Then, we display each average SIFID in a table, along with the average over all our models. The setting differs somewhat from the paper: we have only a small number of models, some images have multiple models trained with different settings, and each value is an average over multiple samples per model rather than the best sample for each. Still, we believe that the results are similar enough for our purposes.

In [None]:
from tqdm import tqdm
MODEL_DIR = 'models'
SAMPLES_PER_MODEL = 50  # around 4 sec per model with a GPU, 5-10 times slower using CPU
SEED = 798  # yet another seed for reproducibility of the cell!

seed_rngs(SEED)

# calculate an average SIFID for each model and scale
sifids = []
scales = [0, 1]
entries = list(os.scandir(MODEL_DIR))
for model_entry in tqdm(entries):
  row = []
  for input_scale in scales:
    gen, original = load_generator(model_entry.path, input_scale, device=DEVICE)
    avg_sifid = calc.calculate_average_sifid(gen, original, SAMPLES_PER_MODEL)
    row.append(avg_sifid)
  sifids.append(row)


In [None]:
import plotly.graph_objects as go

# calculate the average and round the values for printing
models = [entry.name for entry in entries]
average_per_scale = np.mean(sifids, axis=0, keepdims=True)
table_array = np.concatenate((sifids, average_per_scale), axis=0)
table_array = np.around(table_array, 2)

# format into table content and print
row_header = ['Input Scale', *models, 'Average']
col_header = [str(s) for s in scales]
table_content = [col_header] + [['{:.2f}'.format(x) for x in row] for row in table_array]

fig = go.Figure(data=[go.Table(header=dict(values=row_header),
                cells=dict(values=table_content),
                columnwidth=40)])
fig.update_layout(width=1500)
fig.show()

# Implementation challenges

## Related to the paper

Thanks to the detailed treatment of the training process given by the authors both in the original paper and the paper's supplementary material (which can be accessed from the official webpage: https://webee.technion.ac.il/people/tomermic/SinGAN/SinGAN.htm), we knew most of the details before getting into the implementation and did not have too much difficulty, except for a few details: 

- The most difficult part for us was getting the gradient penalty right. The formulation in the original paper is straightforward: the critic (discriminator) is viewed as a single function, and the norm of its gradients is used for regularization. Note that the gradients have the same shape as the network's input. Let's go over three increasingly difficult cases:
  - The fully connected case is straightforward: given an $(N, V)$ input with $N$ samples and $V$ dimensions, we can simply compute the gradient norm for each sample and take the average.
  - The convolutional case is a little more confusing. The input is $(N, C, H, W)$, with $C$ being the number of channels and $H$ and $W$ the spatial dimensions (the output is still a scalar fakeness score). If we take the norm over the whole input (flattening all dimensions except the batch), the gradient norm scales with the size of the image. In a setting where the input size never changes, this can be offset with the gradient penalty multiplier, but in our case we re-use the weights of the networks for input images of different sizes, so the output must not change too much with the dimensions. Thus, one correct approach is to take the norm only over the channel dimension and then average over every pixel of every sample.
  - Our case is harder than the simple discriminator case because we use a patch discriminator, whose output is an image. A simple approach would be to take the mean of this output to obtain a single scalar and then apply the standard discriminator approach. However, this defeats our purpose: the mean is then taken both before computing the gradient and afterwards, so the gradient norm becomes smaller with larger input sizes instead of staying constant. Instead, we need to take the sum of the discriminator's output rather than its mean, so that the mean is applied only once. This was the part that took us the longest to figure out before getting the gradient penalty right.
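
The recipe from the last case can be sketched as follows. This is a minimal illustration rather than the exact `gradient_penalty` from our `utils` module: the function name, the random interpolation between real and fake images, and the $(\lVert \nabla \rVert - 1)^2$ form are assumptions following the usual WGAN-GP formulation.

```python
import torch
import torch.nn as nn


def patch_gradient_penalty(critic, fake, real):
    """Gradient penalty for a patch critic that stays comparable across input sizes."""
    # random point between the real and the fake image, one coefficient per sample
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    mixed = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)
    # sum (not mean) over the patch output, so that averaging happens only once below
    score_sum = critic(mixed).sum()
    grads, = torch.autograd.grad(score_sum, mixed, create_graph=True)
    # norm over the channel dimension only: one norm per pixel, shape (N, H, W)
    pixel_norms = grads.norm(p=2, dim=1)
    # penalize deviation from unit norm, averaged over every pixel of every sample
    return ((pixel_norms - 1.0) ** 2).mean()


# toy 1x1-convolution critic: the penalty should not depend on the image size
toy_critic = nn.Conv2d(3, 1, kernel_size=1, bias=False)
for size in (8, 32):
    real, fake = torch.randn(1, 3, size, size), torch.randn(1, 3, size, size)
    print(size, patch_gradient_penalty(toy_critic, fake, real).item())
```

With the toy critic, the gradient at every pixel is simply the weight vector of the convolution, so both sizes give the same penalty. Replacing `.sum()` with `.mean()` would divide the per-pixel gradients by the number of output pixels, which is exactly the size dependence described above.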

We did face a few more issues during the implementation which caused us to chase bugs for a while, but those were not related to the paper or lack of information; they were simply our own small coding mistakes.