# Paper Information
Name: GANSeg: Learning to Segment by Unsupervised Hierarchical Image Generation
Link: https://arxiv.org/pdf/2112.01036.pdf
    
# Authors of Code
Ahmet Kağan Kaya - 2598555 - "kagan.kaya@metu.edu.tr"

# Paper Summary
This paper proposes a GAN-based approach that generates images conditioned on latent masks, thereby removing the need for the full or weak annotations required by previous approaches. Part and background appearances are controlled by latent vectors, and images are generated with both a point generator and a background generator. Without requiring mask or point supervision, this strategy makes the masks more robust to changes in viewpoint and object position. It also makes it possible to generate image-mask pairs for training a segmentation network, which, according to the paper, outperforms state-of-the-art unsupervised segmentation methods on established benchmarks.

<div align="center">
<img src="archi.png" width="700"/>
<figcaption>Overall Architecture with 3 Different Levels</figcaption>
</div>

# Method
## Level 1: Point Generation and Part Scale
In the first level, independent noise vectors are used to generate the locations and appearances of $K$ parts. The paper finds that training can be stabilized by first predicting $n_\text{per} \times K$ points. The location and scale of each part are then computed from the mean and standard deviation of its $n_\text{per}$ points, which also regularizes training.

$$ x_k = \frac{1}{n_\text{per}}\sum_{i=1}^{n_\text{per}} x_k^i, \quad \sigma_k = \sqrt{\frac{\sum_{i=1}^{n_\text{per}}\|x_k^i- x_k\|^2}{n_\text{per}-1}}$$ 
$$ \text{with }\{x_k^{1}, ...,x_k^{n_\text{per}}\}_{k=1}^K = \text{MLP}_\text{point}(z_\text{point}) $$ 

A 3-layer multi-layer perceptron (MLP) maps $z_\text{point}$ to the $n_\text{per}\times K$ points $\{x_k^{1}, ...,x_k^{n_\text{per}}\}_{k=1}^K$. 
The part locations $\{x_1,...,x_K\}$ and part scales $\{\sigma_1,...,\sigma_K\}$ are then computed from these points.
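
The computation of a part's location and scale can be illustrated with a minimal, dependency-free sketch. It assumes 2-D points in normalized coordinates and uses the sample standard deviation (divisor $n_\text{per}-1$) for the scale, since the text describes it as a standard deviation. The function name `part_location_and_scale` and the example points are hypothetical and are not taken from the notebook's `Generator`.

```python
import math

def part_location_and_scale(points):
    """Return ((mean_x, mean_y), scale) for one part from its n_per 2-D points.

    Assumption: the scale is the sample standard deviation of the points
    around their mean, i.e. sqrt(sum of squared distances / (n_per - 1)).
    """
    n = len(points)
    mean_x = sum(p[0] for p in points) / n
    mean_y = sum(p[1] for p in points) / n
    squared_dist = sum((p[0] - mean_x) ** 2 + (p[1] - mean_y) ** 2 for p in points)
    scale = math.sqrt(squared_dist / (n - 1))
    return (mean_x, mean_y), scale

# Hypothetical MLP output for K = 2 parts with n_per = 4 points each.
points_per_part = [
    [(-0.5, -0.5), (-0.3, -0.5), (-0.5, -0.3), (-0.3, -0.3)],
    [(0.2, 0.4), (0.6, 0.4), (0.2, 0.8), (0.6, 0.8)],
]
locations_and_scales = [part_location_and_scale(pts) for pts in points_per_part]
```

In this example the second part has its points spread twice as far apart as the first, so its scale is twice as large, which later translates into a wider Gaussian heatmap.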

<div align="center">
<img src="level1.png" width="700"/>
<figcaption>Strategy for Level 1</figcaption>
</div>

## Level 2: From Points to Masks
In the second level, Gaussian heatmaps are used to model local independence, and a positional encoding is used to generate masks relative to the predicted part locations. A Gaussian heatmap is generated for each part from the mean and standard deviation defined in Level 1. The embedding $w_k$ is then multiplied with every pixel of the corresponding heatmap, producing a spatially localized embedding map. All $K$ part-specific embedding maps are summed to form a single feature map $W_\text{mask}\in\mathbb{R}^{D_\text{emb}\times H\times W}$.
$$ H_k(p)=\exp\left(-\|p- x_k\|_2^2 / \sigma_k^2\right) $$
$$ W_\text{mask}(p) = \sum_{k=1}^K H_k(p)w_k. $$
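
The two equations above can be traced on a tiny grid with the following sketch. It is plain Python rather than tensor code, and it assumes pixel coordinates normalized to $[-1, 1]^2$; the function names, the grid size and the example values are hypothetical.

```python
import math

def gaussian_heatmap(center, sigma, size):
    """H_k(p) = exp(-||p - x_k||^2 / sigma_k^2) on a size x size grid over [-1, 1]^2."""
    coords = [-1.0 + 2.0 * i / (size - 1) for i in range(size)]
    return [
        [math.exp(-((x - center[0]) ** 2 + (y - center[1]) ** 2) / sigma ** 2) for x in coords]
        for y in coords
    ]

def embedding_map(centers, sigmas, embeddings, size):
    """W_mask(p) = sum_k H_k(p) * w_k, returned as a nested list [D_emb][H][W]."""
    heatmaps = [gaussian_heatmap(c, s, size) for c, s in zip(centers, sigmas)]
    dim = len(embeddings[0])
    return [
        [
            [sum(h[row][col] * w[d] for h, w in zip(heatmaps, embeddings)) for col in range(size)]
            for row in range(size)
        ]
        for d in range(dim)
    ]

# Hypothetical values: K = 2 parts, D_emb = 3, a 5 x 5 grid.
centers = [(0.0, 0.0), (1.0, 1.0)]
sigmas = [0.5, 0.25]
embeddings = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
w_mask = embedding_map(centers, sigmas, embeddings, size=5)
```

At the grid centre the map is dominated by the first part's embedding, and at the bottom-right corner by the second part's, which is the spatial localization the heatmaps are meant to provide.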

The generated embedding map $W_\text{mask}$ is subsequently used to generate masks, together with the mask starting tensor $M^{(0)}\in\mathbb{R}^{D_\text{emb}\times H\times W}$. However, instead of using a constant tensor, it is better to use a low-frequency positional embedding:

$$ M^{(0)}(p) = [\sin(\pi\text{FC}([p-x^1_1, ..., p-x^{n_\text{per}}_K])), 
        \cos(\pi\text{FC}([p-x^1_1, ..., p-x^{n_\text{per}}_K]))] $$
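
As a rough illustration of this starting tensor, the sketch below evaluates the positional embedding at a single pixel. The learned fully connected layer `FC` is replaced by hypothetical fixed `weights` and `bias`, and the function name is an assumption; only the sine and cosine structure follows the equation.

```python
import math

def positional_embedding(pixel, points, weights, bias):
    """M0(p) = [sin(pi * FC(d)), cos(pi * FC(d))], with d the flattened offsets p - x_k^i."""
    offsets = []
    for x, y in points:
        offsets.extend([pixel[0] - x, pixel[1] - y])
    # Linear layer: one output per weight row.
    linear = [sum(w * o for w, o in zip(row, offsets)) + b for row, b in zip(weights, bias)]
    return [math.sin(math.pi * v) for v in linear] + [math.cos(math.pi * v) for v in linear]

# Hypothetical setup: 2 points in total and an FC layer with 2 outputs,
# so the embedding has 2 * 2 = 4 channels.
points = [(0.0, 0.0), (0.5, -0.5)]
weights = [[0.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
bias = [0.0, 0.0]
m0 = positional_embedding((1.0, 0.0), points, weights, bias)
```

Because the encoding depends only on offsets to the predicted points, shifting all points and the pixel by the same amount leaves the embedding unchanged.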

The resulting tensors are then passed through the SPADE ResBlocks used in the paper. A SPADE block takes two feature maps as input: it first normalizes the first input with BatchNorm and then applies two convolutions to the second input to predict the new mean and the new standard deviation.

$$ M^{(i)} = \text{SPADE ResBlock} (M^{(i-1)}, W_\text{mask}) $$
$$ M = \text{softmax}(M^{(T_\text{mask})}) $$
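
A simplified sketch of the modulation step and of the final softmax is given below. It assumes that the predicted mean and standard deviation maps are already available (the two convolutions are omitted), normalizes a single channel as a stand-in for BatchNorm, and applies the softmax over the channel values of one pixel. All names and numbers are hypothetical.

```python
import math

def normalize(channel, eps=1e-5):
    """Zero-mean, unit-variance normalization of one channel (stand-in for BatchNorm)."""
    values = [v for row in channel for v in row]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [[(v - mean) / math.sqrt(var + eps) for v in row] for row in channel]

def spade_modulate(channel, new_mean, new_std):
    """Per-pixel modulation: normalized feature * predicted std + predicted mean."""
    normed = normalize(channel)
    return [
        [n * s + m for n, s, m in zip(n_row, s_row, m_row)]
        for n_row, s_row, m_row in zip(normed, new_std, new_mean)
    ]

def softmax(logits):
    """Numerically stable softmax over the channel values of one pixel."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

feature = [[1.0, 3.0], [5.0, 7.0]]
new_mean = [[0.0, 0.0], [10.0, 10.0]]
new_std = [[1.0, 2.0], [1.0, 2.0]]
modulated = spade_modulate(feature, new_mean, new_std)
mask_at_pixel = softmax([2.0, 1.0, 0.0])
```

The softmax makes the mask channels at every pixel sum to one, so each pixel is softly assigned to one of the parts or to the background.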
        
<div align="center">
<img src="mask_gen.png" width="700"/>
<figcaption>Strategy for Level 2</figcaption>
</div>

## Level 3: Mask-conditioned Image Generation
In this level, the foreground and the background are generated separately and then blended linearly by reusing the masks from the previous level. The embedding maps of the foreground and the background are also generated separately:

$$ W_\text{fg}(p) = \sum_{k=1}^K M_k(p) w_k. $$
$$ W_\text{bg} = \text{MLP}_\text{bg\_app}(z_\text{bg\_app}). $$

These embedding maps are used to generate the foreground and background features:
$$ F^{(i)} =  \text{SPADE ResBlock} (F^{(i-1)}, W_\text{fg}) $$
$$ B^{(i)} = \text{AdaIN ConvBlock}  (B^{(i-1)}, W_\text{bg}) $$

The foreground and background features are then blended with the background mask and passed through a convolution to obtain the final image:
$$ I = \text{Conv}((1-M_\text{bg})\otimes F + M_\text{bg}\otimes B) $$
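
The blending inside the convolution can be sketched for a single feature channel as follows. The final `Conv` is left out, and the function name and toy values are hypothetical.

```python
def blend(foreground, background, bg_mask):
    """Per-pixel linear blend (1 - M_bg) * F + M_bg * B for one feature channel."""
    return [
        [(1.0 - m) * f + m * b for f, b, m in zip(f_row, b_row, m_row)]
        for f_row, b_row, m_row in zip(foreground, background, bg_mask)
    ]

# Toy example: foreground features are all 1, background features are all 0.
foreground = [[1.0, 1.0], [1.0, 1.0]]
background = [[0.0, 0.0], [0.0, 0.0]]
bg_mask = [[0.0, 0.25], [0.5, 1.0]]
blended = blend(foreground, background, bg_mask)
# blended == [[1.0, 0.75], [0.5, 0.0]]
```

Where the background mask is 0 the output equals the foreground feature, where it is 1 the output equals the background feature, and in between the two are mixed proportionally.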

<p>
  <img src="foreground.png" width="700" />
  <img src="background.png" width="700" /> 
</p>


In [None]:
# standard library
import argparse
import importlib
import json
import math
import os
import warnings

# third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# project modules
from dataset import CelebAWildTrain
from evaluate import evaluate
from model import Generator
from other_models import Discriminator
from utils import *

warnings.filterwarnings("ignore")

## Parameters
All parameters are stored in a `Params` object.

Important hyperparameters such as the learning rates are taken from the paper.

In [None]:
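class Params:
    def __init__(self):
        self.latent_dim = 256          # latent dimension
        self.cluster_number = 8        # number of parts K
        self.n_per_kp = 4              # points predicted per part (n_per)
        self.batch_size = 16
        self.lr_gen = 1e-4             # generator learning rate
        self.lr_disc = 4e-4            # discriminator learning rate
        self.num_workers = 0           # DataLoader worker processes
        self.data_root = ''            # dataset root directory
        self.class_name = 'celeba_wild'
        self.image_size = 128          # images are image_size x image_size
        self.embedding_dim = 128       # part embedding dimension D_emb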

## Create Model

In [28]:
args = Params()
args.log = "GanSEG"

os.makedirs(args.log, exist_ok=True)
with open(os.path.join(args.log, 'parameters.json'), 'wt') as f:
    json.dump(args.__dict__, f, indent=2)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generator = Generator(args).to(device)
discriminator = Discriminator().to(device)
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.5, 0.9))
optim_gen = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=args.lr_gen, betas=(0.5, 0.9))

generator = torch.nn.DataParallel(generator)
discriminator = torch.nn.DataParallel(discriminator)

checkpoint_dir = os.path.join(args.log, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

In [29]:
print("Generator Architecture")
print(generator)
print("Discriminator Architecture")
print(discriminator)

Generator Architecture
DataParallel(
  (module): Generator(
    (keypoints_embedding): Embedding(8, 128)
    (mask_spade_blocks): ModuleList(
      (0): SPADE(
        (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_std1): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_mean1): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_std2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_mean2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(128, 

## Dataset Preparation

In [30]:
dataset = CelebAWildTrain(args.data_root, args.image_size)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                       num_workers=args.num_workers, pin_memory=True, drop_last=True)
test_input_batch = {'input_noise{}'.format(noise_i): torch.randn(args.batch_size, *noise_shape).to(device)
                    for noise_i, noise_shape in enumerate(generator.module.noise_shapes)}
test_input_batch['bg_trans'] = torch.rand(args.batch_size, 1, 2).to(device) * 2 - 1

In [31]:
dataset.imgs.shape

torch.Size([45609, 3, 128, 128])

In [None]:
for epoch in range(500):
    discriminator.train()
    generator.train()
    total_disc_loss = 0
    total_gen_loss = 0
    prog = tqdm(enumerate(data_loader))
    for batch_index, batch in prog:
        # ---- update discriminator ----
        optim_disc.zero_grad()
        optim_gen.zero_grad()

        batch = {'img': batch['img'].to(device)}
        # real images need gradients for the penalty term below
        batch['img'].requires_grad_()
        input_gen = {'input_noise{}'.format(noise_i): torch.randn(args.batch_size, *noise_shape).to(device)
                     for noise_i, noise_shape in enumerate(generator.module.noise_shapes)}
        input_gen['bg_trans'] = torch.rand(args.batch_size, 1, 2).to(device) * 2 - 1

        fake_input = generator(input_gen)
        real_disc = discriminator(batch)
        fake_disc = discriminator(fake_input)
        disc_loss = F.softplus(fake_disc).mean() + F.softplus(-real_disc).mean()
        # add the penalty term from utils (computed on the real images) every 4th batch
        if batch_index % 4 == 0:
            disc_loss = disc_loss + penalty(batch['img'], real_disc)
        disc_loss.backward()
        total_disc_loss += disc_loss.item()
        optim_disc.step()

        # ---- update generator ----
        optim_disc.zero_grad()
        optim_gen.zero_grad()

        input_gen = {'input_noise{}'.format(noise_i): torch.randn(args.batch_size, *noise_shape).to(device)
                     for noise_i, noise_shape in enumerate(generator.module.noise_shapes)}
        input_gen['bg_trans'] = torch.rand(args.batch_size, 1, 2).to(device) * 2 - 1
        fake_input = generator(input_gen)
        fake_disc = discriminator(fake_input)
        gen_loss = F.softplus(-fake_disc).mean()
        gen_loss.backward()
        total_gen_loss += gen_loss.item()
        optim_gen.step()

        # running values for the progress bar, kept under separate names so the
        # loss tensors are not overwritten; they are normalized by the full
        # loader length (and the discriminator value is additionally halved)
        avg_disc_loss = total_disc_loss / len(data_loader) / 2
        avg_gen_loss = total_gen_loss / len(data_loader)
        # cap each epoch at roughly 1000 batches
        if batch_index == 1000:
            break
        prog.set_description(f"Epoch: {epoch + 1}, disc_loss:  {avg_disc_loss}, gen_loss: {avg_gen_loss}")
    # test_input_batch is the fixed noise batch created in the dataset cell
    evaluate(generator, test_input_batch, args, epoch)

    # save a checkpoint after every epoch
    torch.save({'generator': generator.module.state_dict(),
                'discriminator': discriminator.module.state_dict(),
                'optim_gen': optim_gen.state_dict(),
                'optim_disc': optim_disc.state_dict()},
               os.path.join(checkpoint_dir, 'epoch_{}.model'.format(epoch)))
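
The two loss expressions in the training loop, `F.softplus(fake_disc).mean() + F.softplus(-real_disc).mean()` and `F.softplus(-fake_disc).mean()`, can be checked by hand with the scalar sketch below. It is plain Python rather than PyTorch, leaves out the penalty term, and uses hypothetical function names and logits.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def discriminator_loss(real_logits, fake_logits):
    """Mean softplus(fake) + mean softplus(-real), without the penalty term."""
    fake_term = sum(softplus(v) for v in fake_logits) / len(fake_logits)
    real_term = sum(softplus(-v) for v in real_logits) / len(real_logits)
    return fake_term + real_term

def generator_loss(fake_logits):
    """Mean softplus(-fake): small when the discriminator scores fakes as real."""
    return sum(softplus(-v) for v in fake_logits) / len(fake_logits)

# An undecided discriminator (all logits 0) gives 2*log(2) and log(2).
undecided_d = discriminator_loss([0.0, 0.0], [0.0, 0.0])
undecided_g = generator_loss([0.0, 0.0])
# A confident, correct discriminator has a low loss; the generator's is then high.
confident_d = discriminator_loss([5.0, 5.0], [-5.0, -5.0])
confident_g = generator_loss([-5.0, -5.0])
```

The generator loss falls as the discriminator's logits for generated images rise, which is the signal the generator update in the loop follows.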

In [26]:
with open("logs.txt") as f:
    lines = f.read()
    print(lines)

Epoch: 1, disc_loss:  0.91602554521552, gen_loss: 2.9154146088120627: : 1000it [17:04,  1.02s/it]  
Epoch: 2, disc_loss:  0.19575286904447958, gen_loss: 1.3049921820260453: : 1000it [17:09,  1.03s/it]  
Epoch: 3, disc_loss:  0.16392073070793822, gen_loss: 1.1133889933113466: : 1000it [17:23,  1.04s/it]  
Epoch: 4, disc_loss:  0.1673950523660894, gen_loss: 1.0478422773302647: : 1000it [17:18,  1.04s/it]    
Epoch: 5, disc_loss:  0.16378850753892932, gen_loss: 0.906951866683207: : 1000it [17:26,  1.05s/it]   
Epoch: 6, disc_loss:  0.1642864903337077, gen_loss: 0.8528890935204139: : 1000it [17:25,  1.05s/it]   
Epoch: 7, disc_loss:  0.15886002495100623, gen_loss: 0.7847563005016561: : 1000it [17:24,  1.04s/it]   
Epoch: 8, disc_loss:  0.16731682119139454, gen_loss: 0.7638920359601055: : 1000it [17:26,  1.05s/it]   
Epoch: 9, disc_loss:  0.1768046492233611, gen_loss: 0.7328230114464174: : 1000it [17:24,  1.04s/it]    
Epoch: 10, disc_loss:  0.18231592197167246, gen_loss: 0.7279043149006994

## Results from Train Dataset

In [None]:
import os

import imageio

results_dir = "./GanSEG/results/"
# sorted() gives a deterministic (lexicographic) frame order; os.listdir order is arbitrary
filenames = sorted(f for f in os.listdir(results_dir) if "seg" not in f)
filenames_seg = sorted(f for f in os.listdir(results_dir) if "seg" in f)

# generated images, one frame per saved result
images = [imageio.imread(os.path.join(results_dir, f)) for f in filenames]
imageio.mimsave('./GanSEG/gan.gif', images, duration=2)

# corresponding segmentation maps
images_seg = [imageio.imread(os.path.join(results_dir, f)) for f in filenames_seg]
imageio.mimsave('./GanSEG/gan_seg.gif', images_seg, duration=2)

<img src="./GanSEG/gan.gif" align="left"/><img src="./GanSEG/gan_seg.gif" align="left"/>

## Load the Pretrained Model

In [25]:
# recreate the parameters so this cell also works in a fresh session
args = Params()
args.log = "GanSEG"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generator = Generator(args).to(device)
# load the checkpoint on the CPU; load_state_dict copies the weights to the model's device
gen_checkpoint = torch.load("./GanSEG/checkpoints/best_model.model",
                            map_location=lambda storage, location: storage)
generator.load_state_dict(gen_checkpoint['generator'])
generator

Generator(
  (keypoints_embedding): Embedding(8, 128)
  (mask_spade_blocks): ModuleList(
    (0): SPADE(
      (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_std1): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_mean1): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_std2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_mean2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm3): Batc

In [None]:
loaded_input = {'input_noise{}'.format(noise_i): torch.randn(args.batch_size, *noise_shape).to(device)
                    for noise_i, noise_shape in enumerate(generator.noise_shapes)}
loaded_input['bg_trans'] = torch.rand(args.batch_size, 1, 2).to(device) * 2 - 1

evaluate(generator, loaded_input, args, "test")


<img src="./GanSEG/results/test.png" align="left"/><img src="./GanSEG/results/test_segmaps.png" align="left"/>

## Challenges Faced
### Model Architecture

The GANSeg architecture consists of two modules: a Generator and a Discriminator. The Discriminator is basically a fully connected layer and could be implemented as the paper suggests. The Generator, however, requires several submodules that are not explained in detail, for example the SPADE and AdaIN submodules. Some assumptions were therefore made while implementing this module. Implementing the forward function of the Generator was also demanding, because all three levels explained above had to be implemented there. In particular, the "Level 2: From Points to Masks" part could not be implemented properly, because the equations given in the paper seem incomplete in terms of matrix dimensions. Some transpose operations were added to make the dimensions match, and these changes may have caused some drop in performance.

### Dataset
The Dataset and DataLoader modules were implemented independently. However, the paper suggests that the random input noises for generation and the background and foreground of the images should be stored in a dictionary. The foreground and background information in particular is not in the default dataset folder, so it was taken from the authors' official implementation. Nevertheless, the prepared dataset can be downloaded from the drive link given in download_dataset.sh. It will also be implemented in v2.

### Gradient Penalty
The paper mentions the gradient penalty in several places, but its mathematical formulation seems incomplete. Therefore, the gradient penalty could not be implemented in v2.

### Hyperparameter Tuning and Results
Hyperparameters were selected as the paper suggests. These values lead to a fairly good training loss. However, the results show some degradation for certain classes; for example, the blue region should not cover the mouth and nose regions. Other hyperparameter values were tried, but they did not change performance noticeably.