boschresearch/MCCGAN

Multi-Class Multi-Instance Count Conditioned Adversarial Image Generation MCC-StyleGAN2

Implementation of the ICCV 2021 paper "Multi-Class Multi-Instance Count Conditioned Adversarial Image Generation MCC-StyleGAN2". The paper can be found here. The code allows users to reproduce and extend the results reported in the paper. Please cite the paper when reporting, reproducing, or extending the results.

Purpose of the project

This software is a research prototype, solely developed for and published as part of the publication. It will neither be maintained nor monitored in any way.

Setup

MCC-SimpleGAN

MCC-SimpleGAN is a PyTorch implementation of a simple convolution-based architecture that applies the same count-conditioning concept, intended for toy experiments.

Requirements

Please add the project folder to PYTHONPATH and install the required dependencies:

Dependencies

  • python 3.6.10
  • pytorch 1.4.0

pip install -r requirements.txt

Training

To train the model, run this command:

python main.py --input_images <path_to_image_array> --countvec_path <path_to_count_csvfile> 
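
The README does not document the file formats behind the two arguments above. As a rough sketch only, the following assumes that the image array is a NumPy `.npy` file of shape (N, H, W, C) and that the count file is a header-less CSV with one row per image and one column per object class. The helper name `make_toy_inputs`, the file names, and both layouts are hypothetical; check `main.py` for the format it actually expects.

```python
import csv
import os
import tempfile

import numpy as np


def make_toy_inputs(out_dir, num_images=8, image_size=32, num_classes=3,
                    max_count=4, seed=0):
    """Write a toy image array and a count-vector CSV; return both paths.

    Assumed layout (not confirmed by the repository):
      images: float32 array of shape (N, H, W, 3) with values in [0, 1)
      counts: CSV with one row per image and one column per object class
    """
    rng = np.random.default_rng(seed)
    images = rng.random((num_images, image_size, image_size, 3), dtype=np.float32)
    counts = rng.integers(0, max_count + 1, size=(num_images, num_classes))

    image_path = os.path.join(out_dir, "images.npy")
    count_path = os.path.join(out_dir, "counts.csv")
    np.save(image_path, images)
    with open(count_path, "w", newline="") as f:
        csv.writer(f).writerows(counts.tolist())
    return image_path, count_path


if __name__ == "__main__":
    image_path, count_path = make_toy_inputs(tempfile.mkdtemp())
    # Print the training command with the generated paths filled in.
    print(f"python main.py --input_images {image_path} --countvec_path {count_path}")
```

Random pixels are of course not a meaningful dataset; the point is only to show one plausible shape for the two inputs so that the pairing of image index and count row is explicit.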

MCC-StyleGAN2

For the MCC-StyleGAN2 repository, we adapted the official TensorFlow implementation of StyleGAN2 "https://github.com/NVlabs/stylegan2". For CityCount images, the adaptive discriminator augmentation technique, based on the implementation in "https://github.com/NVlabs/stylegan2-ada", is used during training. Note that the network configuration and the loss functions remain the same for all datasets. The main differences from the original StyleGAN2 are in the network architecture and the loss functions used in training.

training/networks_stylegan2.py - The modified generator and discriminator for count-conditioned image generation. Specifically, we introduced a count vector mapping to each layer in the mapping network. We also introduced dense connectivity, in which the output of a layer is connected to all following layers.

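    import numpy as np

    def denseblock(x, resolution):
        # Schematic excerpt. As written, `dense` is re-created on every call;
        # the skip features must persist across calls for the branches below
        # to see the outputs of earlier layers.
        dense = [None] * 5
        # Store the current feature map in its slot (index = resolution - 3).
        if 2 < resolution < 7:
            dense[resolution - 3] = x

        # Pass each previously stored skip feature through a convolution.
        if resolution - 3 > 0:
            # Denseskip0
            dense[0] = conv2dlayer(dense[0])
        if resolution - 4 > 0:
            # Denseskip1
            dense[1] = conv2dlayer(dense[1])
        if resolution - 5 > 0:
            # Denseskip2
            dense[2] = conv2dlayer(dense[2])
        if resolution - 6 > 0:
            # Denseskip3
            dense[3] = conv2dlayer(dense[3])
        if resolution - 7 > 0:
            # Denseskip4
            dense[4] = conv2dlayer(dense[4])

        if resolution > 3:
            # Denseskipx
            dense[resolution - 4] = conv2dlayer(dense[resolution - 4])
            # Add all earlier skip features to x, then rescale the sum.
            for i in range(0, resolution - 3):
                x = x + dense[i]
            x = x * (1 / np.sqrt(max(2, resolution - 3)))
        return x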
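
To make the data flow of the dense connectivity easier to follow, here is a simplified, self-contained NumPy sketch. It is not the repository's layer: the class name `DenseSkips` and the helper `upsample2x` are hypothetical, nearest-neighbour upsampling stands in for the learned convolution, and the sketch assumes that `resolution` is the log2 of the feature-map size and that each earlier feature is resampled once per step. Only the summation over earlier outputs and the `1 / sqrt(max(2, resolution - 3))` scaling are taken from the excerpt above.

```python
import numpy as np


def upsample2x(feature):
    """Nearest-neighbour 2x upsampling; a stand-in for the learned conv layer."""
    return feature.repeat(2, axis=0).repeat(2, axis=1)


class DenseSkips:
    """Keeps the inputs of earlier blocks and adds them to later blocks."""

    def __init__(self, first_res=3, last_res=6):
        self.first_res = first_res
        self.last_res = last_res
        self.stored = []  # earlier features, kept at the current spatial size

    def __call__(self, x, resolution):
        # Outside the dense range the input passes through unchanged.
        if not (self.first_res <= resolution <= self.last_res):
            return x
        # Bring every earlier feature up to the current spatial size.
        self.stored = [upsample2x(f) for f in self.stored]
        earlier = list(self.stored)
        # Remember the current input for the blocks that follow.
        self.stored.append(x)
        if not earlier:
            return x
        for f in earlier:
            x = x + f
        # len(earlier) equals resolution - 3 when first_res is 3.
        return x * (1.0 / np.sqrt(max(2, len(earlier))))


if __name__ == "__main__":
    skips = DenseSkips()
    for res in range(3, 7):
        size = 2 ** res
        y = skips(np.ones((size, size, 4), dtype=np.float32), res)
        print(res, y.shape, float(y[0, 0, 0]))
```

Keeping the stored features on an object, rather than in a list created inside the function, is one way to let later resolutions see earlier outputs; the repository may handle this state differently.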

training/loss.py - An additional count loss function used for training the network.

countlossfake = tf.reduce_mean(tf.squared_difference(fake_count_out, count_labels))
countlossreal = tf.reduce_mean(tf.squared_difference(real_count_out, count_labels))
countloss = (countlossfake + countlossreal) / 2
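
As a worked illustration of the three lines above, the same arithmetic can be reproduced in NumPy on a tiny batch. The function name `count_loss` and the example values are made up for illustration; in the repository the predictions would come from the network's count outputs for fake and real images.

```python
import numpy as np


def count_loss(fake_count_out, real_count_out, count_labels):
    """Average of the mean squared count errors on fake and real images."""
    fake = np.mean((fake_count_out - count_labels) ** 2)
    real = np.mean((real_count_out - count_labels) ** 2)
    return (fake + real) / 2.0


if __name__ == "__main__":
    # Two images, three object classes (hypothetical values).
    labels = np.array([[2.0, 0.0, 1.0], [1.0, 3.0, 0.0]])
    fake_pred = np.array([[2.0, 1.0, 1.0], [1.0, 3.0, 2.0]])  # squared errors sum to 5
    real_pred = np.array([[2.0, 0.0, 1.0], [1.0, 2.0, 0.0]])  # squared errors sum to 1
    # (5/6 + 1/6) / 2 = 0.5
    print(count_loss(fake_pred, real_pred, labels))
```

The mean runs over both the batch and the class dimension, so each term is the average squared error per count entry.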

Incorporate the additional files provided in the MCC-StyleGAN2 directory into the corresponding folders of the original repo.

Dataset

  • To create the MultiMNIST dataset, please refer to the repo here.

  • To create the CLEVR dataset, please refer to the CLEVR directory adapted from the repo. Incorporate the additional files provided in the CLEVR directory into the corresponding folders of the original repo. The CLEVR directory includes image generation scripts specifically modified for CLEVR2 and CLEVR3, with multiprocessing enabled. Refer to the original repo README for instructions on running the image generation.

  • To create the CityCount dataset, please refer to the Citycount directory. Please download the required archives (leftImg8bit_trainvaltest.zip, gtBbox3d_trainvaltest.zip, and gtBbox_cityPersons_trainval.zip) from here.

License

MCC-StyleGAN2 is open-sourced under the AGPL-3.0 license. See the LICENSE file for details. For a list of other open source components included in MCC-StyleGAN2, see the file 3rd-party-licenses.txt. For further queries, please contact amrutha.saseendran21@gmail.com
