<a href="https://colab.research.google.com/github/mikaeilahmadi/CAMB/blob/master/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [0]:
import tensorflow

In [0]:
import os
import numpy as np
import imageio
import matplotlib.pyplot as plt

class ImageHelper(object):
    def save_image(self, generated, epoch, directory):
        fig, axs = plt.subplots(5, 5)
        count = 0
        for i in range(5):
            for j in range(5):
                axs[i,j].imshow(generated[count, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                count += 1
        fig.savefig("{}/{}.png".format(directory, epoch))
        plt.close()
        
    def makegif(self, directory):
        filenames = np.sort(os.listdir(directory))
        filenames = [ fnm for fnm in filenames if ".png" in fnm]
    
        with imageio.get_writer(directory + '/image.gif', mode='I') as writer:
            for filename in filenames:
                image = imageio.imread(directory + filename)
                writer.append_data(image)

In [0]:
from __future__ import print_function, division

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Keras modules
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

class GAN():
    def __init__(self, image_shape, generator_input_dim, image_hepler):
        optimizer = Adam(0.0002, 0.5)
        
        self._image_helper = image_hepler
        self.img_shape = image_shape
        self.generator_input_dim = generator_input_dim

        # Build models
        self._build_generator_model()
        self._build_and_compile_discriminator_model(optimizer)
        self._build_and_compile_gan(optimizer)

    def train(self, epochs, train_data, batch_size):
        
        real = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        history = []
        for epoch in range(epochs):
            #  Train Discriminator
            batch_indexes = np.random.randint(0, train_data.shape[0], batch_size)
            batch = train_data[batch_indexes]
            genenerated = self._predict_noise(batch_size)
            loss_real = self.discriminator_model.train_on_batch(batch, real)
            loss_fake = self.discriminator_model.train_on_batch(genenerated, fake)
            discriminator_loss = 0.5 * np.add(loss_real, loss_fake)

            #  Train Generator
            noise = np.random.normal(0, 1, (batch_size, self.generator_input_dim))
            generator_loss = self.gan.train_on_batch(noise, real)

            # Plot the progress
            print ("---------------------------------------------------------")
            print ("******************Epoch {}***************************".format(epoch))
            print ("Discriminator loss: {}".format(discriminator_loss[0]))
            print ("Generator loss: {}".format(generator_loss))
            print ("---------------------------------------------------------")
            
            history.append({"D":discriminator_loss[0],"G":generator_loss})
            
            # Save images from every hundereth epoch generated images
            if epoch % 100 == 0:
                self._save_images(epoch)
                
        self._plot_loss(history)
        self._image_helper.makegif("generated/")        
    
    def _build_generator_model(self):
        generator_input = Input(shape=(self.generator_input_dim,))
        generator_seqence = Sequential(
                [Dense(256, input_dim=self.generator_input_dim),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(512),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(1024),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(np.prod(self.img_shape), activation='tanh'),
                 Reshape(self.img_shape)])
    
        generator_output_tensor = generator_seqence(generator_input)       
        self.generator_model = Model(generator_input, generator_output_tensor)
        
    def _build_and_compile_discriminator_model(self, optimizer):
        discriminator_input = Input(shape=self.img_shape)
        discriminator_sequence = Sequential(
                [Flatten(input_shape=self.img_shape),
                 Dense(512),
                 LeakyReLU(alpha=0.2),
                 Dense(256),
                 LeakyReLU(alpha=0.2),
                 Dense(1, activation='sigmoid')])
    
        discriminator_tensor = discriminator_sequence(discriminator_input)
        self.discriminator_model = Model(discriminator_input, discriminator_tensor)
        self.discriminator_model.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.discriminator_model.trainable = False
    
    def _build_and_compile_gan(self, optimizer):
        real_input = Input(shape=(self.generator_input_dim,))
        generator_output = self.generator_model(real_input)
        discriminator_output = self.discriminator_model(generator_output)        
        
        self.gan = Model(real_input, discriminator_output)
        self.gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    
    def _save_images(self, epoch):
        generated = self._predict_noise(25)
        generated = 0.5 * generated + 0.5
        self._image_helper.save_image(generated, epoch, "generated/")
    
    def _predict_noise(self, size):
        noise = np.random.normal(0, 1, (size, self.generator_input_dim))
        return self.generator_model.predict(noise)
        
    def _plot_loss(self, history):
        hist = pd.DataFrame(history)
        plt.figure(figsize=(20,5))
        for colnm in hist.columns:
            plt.plot(hist[colnm],label=colnm)
        plt.legend()
        plt.ylabel("loss")
        plt.xlabel("epochs")
        plt.show()

In [0]:
def _build_generator_model(self):
        generator_input = Input(shape=(self.generator_input_dim,))
        generator_seqence = Sequential(
                [Dense(256, input_dim=self.generator_input_dim),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(512),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(1024),
                 LeakyReLU(alpha=0.2),
                 BatchNormalization(momentum=0.8),
                 Dense(np.prod(self.img_shape), activation='tanh'),
                 Reshape(self.img_shape)])
    
        generator_output_tensor = generator_seqence(generator_input)       
        self.generator_model = Model(generator_input, generator_output_tensor)

In [0]:
def _build_and_compile_discriminator_model(self, optimizer):
    discriminator_input = Input(shape=self.img_shape)
    discriminator_sequence = Sequential(
            [Flatten(input_shape=self.img_shape),
             Dense(512),
             LeakyReLU(alpha=0.2),
             Dense(256),
             LeakyReLU(alpha=0.2),
             Dense(1, activation='sigmoid')])

    discriminator_tensor = discriminator_sequence(discriminator_input)
    self.discriminator_model = Model(discriminator_input, discriminator_tensor)
    self.discriminator_model.compile(loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'])
    self.discriminator_model.trainable = False

In [0]:
def train(self, epochs, train_data, batch_size):
        
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    history = []
    for epoch in range(epochs):
        #  Train Discriminator
        batch_indexes = np.random.randint(0, train_data.shape[0], batch_size)
        batch = train_data[batch_indexes]
        genenerated = self._predict_noise(batch_size)
        loss_real = self.discriminator_model.train_on_batch(batch, real)
        loss_fake = self.discriminator_model.train_on_batch(genenerated, fake)
        discriminator_loss = 0.5 * np.add(loss_real, loss_fake)

        #  Train Generator
        noise = np.random.normal(0, 1, (batch_size, self.generator_input_dim))
        generator_loss = self.gan.train_on_batch(noise, real)

        # Plot the progress
        print ("---------------------------------------------------------")
        print ("******************Epoch {}***************************".format(epoch))
        print ("Discriminator loss: {}".format(discriminator_loss[0]))
        print ("Generator loss: {}".format(generator_loss))
        print ("---------------------------------------------------------")

        history.append({"D":discriminator_loss[0],"G":generator_loss})

        # Take a snapshot every 100th epoch
        if epoch % 100 == 0:
            self._save_images(epoch)

    self._plot_loss(history)
    self._image_helper.makegif("generated/")  

In [12]:
import numpy as np
from keras.datasets import fashion_mnist

from image_helper import ImageHelper
from gan import GAN

(X, _), (_, _) = fashion_mnist.load_data()
X_train = X / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

image_helper = ImageHelper()
generative_advarsial_network = GAN(X_train[0].shape, 100, image_helper)
generative_advarsial_network.train(30000, X_train, batch_size=32)

ModuleNotFoundError: ignored

In [10]:
# To determine which version you're using:
!pip show tensorflow

# For the current version: 
!pip install --upgrade tensorflow

# For a specific version:
!pip install tensorflow==1.2

# For the latest nightly build:
!pip install tf-nightly

Name: tensorflow
Version: 1.15.2
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /tensorflow-1.15.2/python3.6
Requires: tensorboard, protobuf, wheel, keras-applications, gast, google-pasta, numpy, wrapt, keras-preprocessing, astor, six, grpcio, tensorflow-estimator, absl-py, opt-einsum, termcolor
Required-by: stable-baselines, magenta, fancyimpute
Collecting tensorflow
[?25l  Downloading https://files.pythonhosted.org/packages/85/d4/c0cd1057b331bc38b65478302114194bd8e1b9c2bbc06e300935c0e93d90/tensorflow-2.1.0-cp36-cp36m-manylinux2010_x86_64.whl (421.8MB)
[K     |████████████████████████████████| 421.8MB 20kB/s 
[?25hCollecting tensorflow-estimator<2.2.0,>=2.1.0rc0
[?25l  Downloading https://files.pythonhosted.org/packages/18/90/b77c328a1304437ab1310b463e533fa7689f4bfc41549593056d812fab8e/tensorflow_estimator-2.1.0-py2.py3-none

Collecting tensorflow==1.2
[?25l  Downloading https://files.pythonhosted.org/packages/5e/55/7995cc1e9e60fa37ea90e6777d832e75026fde5c6109215d892aaff2e9b7/tensorflow-1.2.0-cp36-cp36m-manylinux1_x86_64.whl (35.0MB)
[K     |████████████████████████████████| 35.0MB 120kB/s 
Collecting bleach==1.5.0
  Downloading https://files.pythonhosted.org/packages/33/70/86c5fec937ea4964184d4d6c4f0b9551564f821e1c3575907639036d9b90/bleach-1.5.0-py2.py3-none-any.whl
Collecting backports.weakref==1.0rc1
  Downloading https://files.pythonhosted.org/packages/6a/f7/ae34b6818b603e264f26fe7db2bd07850ce331ce2fde74b266d61f4a2d87/backports.weakref-1.0rc1-py3-none-any.whl
Collecting markdown==2.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/ac/99/288a81a38526a42c98b5b9832c6e339ca8d5dd38b19a53abfac7c8037c7f/Markdown-2.2.0.tar.gz (236kB)
[K     |████████████████████████████████| 245kB 38.9MB/s 
[?25hCollecting html5lib==0.9999999
[?25l  Downloading https://files.pythonhosted.org/packages/ae/ae/bc

Collecting tf-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/19/71/67a8c442164b92e77685438f14483cb5a1e91bcaccf0dd1da2a23647c65e/tf_nightly-2.2.0.dev20200406-cp36-cp36m-manylinux2010_x86_64.whl (517.1MB)
[K     |████████████████████████████████| 517.2MB 32kB/s 
[?25hCollecting tb-nightly<2.4.0a0,>=2.3.0a0
[?25l  Downloading https://files.pythonhosted.org/packages/a6/63/a0b69e4c814581e52bfa76ec456fe7b910242b9497d8d57732ac1729da53/tb_nightly-2.3.0a20200406-py3-none-any.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 35.5MB/s 
Collecting gast==0.3.3
  Downloading https://files.pythonhosted.org/packages/d6/84/759f5dd23fec8ba71952d97bcc7e2c9d7d63bdc582421f3cd4be845f0c98/gast-0.3.3-py2.py3-none-any.whl
Collecting tf-estimator-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/51/09/88a430a478bc75d8cdc46edd2b3b1151380c8b7e72bf6e51080f88225109/tf_estimator_nightly-2.3.0.dev2020040601-py2.py3-none-any.whl (455kB)
[K     |████████████████████

In [1]:
# https://pypi.python.org/pypi/libarchive
!apt-get -qq install -y libarchive-dev && pip install -q -U libarchive
import libarchive

Selecting previously unselected package libarchive-dev:amd64.
(Reading database ... 144568 files and directories currently installed.)
Preparing to unpack .../libarchive-dev_3.2.2-3.1ubuntu0.6_amd64.deb ...
Unpacking libarchive-dev:amd64 (3.2.2-3.1ubuntu0.6) ...
Setting up libarchive-dev:amd64 (3.2.2-3.1ubuntu0.6) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
[K     |████████████████████████████████| 163kB 4.2MB/s 
[?25h  Building wheel for libarchive (setup.py) ... [?25l[?25hdone


In [0]:
# https://pypi.python.org/pypi/pydot
!apt-get -qq install -y graphviz && pip install -q pydot
import pydot

In [3]:
!apt-get -qq install python-cartopy python3-cartopy
import cartopy

Selecting previously unselected package python-pkg-resources.
(Reading database ... (Reading database ... 5%(Reading database ... 10%(Reading database ... 15%(Reading database ... 20%(Reading database ... 25%(Reading database ... 30%(Reading database ... 35%(Reading database ... 40%(Reading database ... 45%(Reading database ... 50%(Reading database ... 55%(Reading database ... 60%(Reading database ... 65%(Reading database ... 70%(Reading database ... 75%(Reading database ... 80%(Reading database ... 85%(Reading database ... 90%(Reading database ... 95%(Reading database ... 100%(Reading database ... 144624 files and directories currently installed.)
Preparing to unpack .../0-python-pkg-resources_39.0.1-2_all.deb ...
Unpacking python-pkg-resources (39.0.1-2) ...
Selecting previously unselected package python-pyshp.
Preparing to unpack .../1-python-pyshp_1.2.12+ds-1_all.deb ...
Unpacking python-pyshp (1.2.12+ds-1) ...
Selecting previously unselected package python-sha