<a href="https://colab.research.google.com/github/imnawar/SD2GAN/blob/main/SD2GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import time
import zipfile
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
class SD2GAN():
    """GAN with an image discriminator and a second, similarity discriminator.

    The model is assembled from four networks:

    * ``generator``      -- maps a latent vector to a 28x28x1 ``tanh`` image.
    * ``discriminator``  -- scores an image as real/fake.
    * ``Siamese_net``    -- pre-trained network loaded from
      ``SD2GAN/<dataset>/SN.json`` / ``.h5``; it is called on a pair of image
      batches (presumably returning one similarity value per pair, since its
      output is fed to a network with ``input_shape=(1,)`` -- verify against
      the saved model).
    * ``discriminator2`` -- scores a similarity value as coming from real or
      generated image pairs.

    Keyword parameters (all required):
        z_dim:     size of the generator's latent vector.
        dataset:   ``'mnist'`` or ``'fashion'``.
        bs:        batch size; must be even because every batch is split into
                   two halves that form the image pairs.
        epochs:    number of training iterations (one batch per iteration).
        lr, beta:  Adam learning rate / beta_1 for ``discriminator`` and the
                   combined model.
        lr_disc2:  Adam learning rate for ``discriminator2``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    # Supported ``dataset`` values -> attribute name under ``tf.keras.datasets``.
    # Stored as strings so the class body does not touch TensorFlow at
    # definition time.
    _DATASET_LOADERS = {'mnist': 'mnist', 'fashion': 'fashion_mnist'}

    def __init__(self, **params):
        self.dataset = params['dataset']
        # Fail fast: an unknown dataset would otherwise leave PATH / x_train
        # unset and surface later as an unrelated AttributeError.
        if self.dataset not in self._DATASET_LOADERS:
            raise ValueError(
                "Unsupported dataset %r; expected one of %s"
                % (self.dataset, sorted(self._DATASET_LOADERS)))

        self.latent_dim = params['z_dim']
        self.batch_size = params['bs']
        self.epochs = params['epochs']

        self.optimizer1 = tf.keras.optimizers.Adam(params['lr'], params['beta'])
        self.optimizer2 = tf.keras.optimizers.Adam(params['lr_disc2'], params['beta'])

        # Location (without extension) of the pre-trained Siamese network.
        self.PATH = 'SD2GAN' + '/' + self.dataset + '/' + 'SN'

        loader = getattr(tf.keras.datasets, self._DATASET_LOADERS[self.dataset])
        (self.x_train, self.y_train), (self.x_test, self.y_test) = loader.load_data()
        # Rescale pixel values to [-1, 1] to match the generator's tanh output.
        self.x_train = self.x_train / 127.5 - 1.
        self.x_test = self.x_test / 127.5 - 1.
        self.img_rows = self.x_train.shape[1]
        self.img_cols = self.x_train.shape[2]
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Load the pre-trained Siamese network (architecture + weights).
        with open(self.PATH + '.json', 'r') as json_file:
            loaded_model_json = json_file.read()
        self.Siamese_net = tf.keras.models.model_from_json(loaded_model_json)
        self.Siamese_net.load_weights(self.PATH + '.h5')

        # Both discriminators are compiled while still trainable; they are
        # frozen below only for the combined (generator-training) model.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer1, metrics=['accuracy'])

        self.discriminator2 = self.build_discriminator2()
        self.discriminator2.compile(loss='binary_crossentropy', optimizer=self.optimizer2, metrics=['accuracy'])

        self.generator = self.build_generator()

        z = tf.keras.layers.Input(shape=(self.latent_dim,))
        img = self.generator(z)
        imgs1 = tf.keras.layers.Input((self.img_shape))
        imgs2 = tf.keras.layers.Input(self.img_shape)
        self.discriminator.trainable = False
        self.discriminator2.trainable = False
        self.Siamese_net.trainable = False
        # NOTE(review): ``similarity`` is computed from the plain inputs
        # imgs1/imgs2 through frozen networks only, so this output has no
        # path to the generator's weights. Confirm this is intended.
        similarity = self.Siamese_net([imgs1, imgs2])
        similarity_validity = self.discriminator2(similarity)
        validity = self.discriminator(img)

        self.combined = tf.keras.models.Model([z, [imgs1, imgs2]], [validity, similarity_validity])
        self.combined.compile(loss='binary_crossentropy', optimizer=self.optimizer1)

    def build_generator(self):
        """Return a Keras model mapping ``(latent_dim,)`` noise to an image.

        The network is Dense -> 7x7x128, followed by two 2x upsampling +
        5x5 convolution stages, ending in a single-channel ``tanh`` output
        (28x28x1).
        """
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Dense((128*7*7), input_shape=(self.latent_dim,)))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation('tanh'))
        model.add(tf.keras.layers.Reshape((7, 7, 128)))
        model.add(tf.keras.layers.UpSampling2D(size=(2, 2)))
        model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same'))
        model.add(tf.keras.layers.Activation('tanh'))
        model.add(tf.keras.layers.UpSampling2D(size=(2, 2)))
        model.add(tf.keras.layers.Conv2D(1, (5, 5), padding='same'))
        model.add(tf.keras.layers.Activation('tanh'))

        noise = tf.keras.layers.Input(shape=(self.latent_dim,))
        img = model(noise)

        return tf.keras.models.Model(noise, img)

    def build_discriminator(self):
        """Return a Keras model mapping an ``img_shape`` image to a sigmoid score."""
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', input_shape=self.img_shape))
        model.add(tf.keras.layers.Activation('tanh'))
        model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(512))
        model.add(tf.keras.layers.Activation('tanh'))
        model.add(tf.keras.layers.Dense(1))
        model.add(tf.keras.layers.Activation('sigmoid'))

        img = tf.keras.layers.Input(shape=self.img_shape,)
        validity = model(img)

        return tf.keras.models.Model(img, validity)

    def build_discriminator2(self):
        """Return a small MLP mapping a ``(1,)`` similarity value to a sigmoid score."""
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Dense(128, input_shape=(1,), activation=tf.nn.relu))
        model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

        array_sim = tf.keras.layers.Input(shape=(1,))
        validity = model(array_sim)

        return tf.keras.models.Model(array_sim, validity)

    def train(self):
        """Train both discriminators and the generator, then save the results.

        Runs ``self.epochs`` iterations; each iteration draws ONE random batch
        (so "epoch" here means a training step, not a pass over the data).
        Afterwards the generator is written to ``output/generator.json`` /
        ``output/generator.h5`` and an 8x8 grid of samples to
        ``output/<last iteration index>.png``.

        Raises:
            ValueError: if ``self.batch_size`` is odd (batches are split into
                two equal halves to form image pairs).
        """
        # np.split(batch, 2) below needs an even batch; fail before doing work.
        if self.batch_size % 2:
            raise ValueError(
                "batch size must be even to form image pairs, got %d"
                % self.batch_size)
        half_batch = self.batch_size // 2

        start_time = time.time()
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        fake_s = np.zeros((half_batch, 1))
        real_s = np.ones((half_batch, 1))

        for epoch in range(self.epochs):

            # ---- sample a real batch and a generated batch ----
            idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
            imgs = self.x_train[idx]
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            # ---- similarity of real pairs (first half vs. second half) ----
            part1_real, part2_real = np.split(imgs, 2)
            sim_real = self.Siamese_net.predict([part1_real, part2_real])

            # ---- similarity of generated pairs ----
            noise2 = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            gen_imgs2 = self.generator.predict(noise2)

            part1, part2 = np.split(gen_imgs2, 2)
            sim_fake = self.Siamese_net.predict([part1, part2])

            # ---- image discriminator step ----
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---- similarity discriminator step ----
            d_loss_real2 = self.discriminator2.train_on_batch(sim_real, real_s)
            d_loss_fake2 = self.discriminator2.train_on_batch(sim_fake, fake_s)
            d_loss2 = 0.5 * np.add(d_loss_real2, d_loss_fake2)

            # ---- generator step through the combined model ----
            # 2 * batch_size images are generated so the two pair inputs each
            # receive a full batch, matching the batch size of ``noise``.
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            noise2 = np.random.normal(0, 1, (self.batch_size*2, self.latent_dim))
            gen_imgs2 = self.generator.predict(noise2)
            pair_inputs = [gen_imgs2[:self.batch_size, :, :, :],
                           gen_imgs2[self.batch_size:, :, :, :]]
            g_loss = self.combined.train_on_batch([noise, pair_inputs], valid)

            if epoch % 1500 == 0:
                print ("%d [D loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], d_loss2[0], 100*d_loss2[1], g_loss[0]))

        print("Training done in: %s minutes ---" % ((time.time() - start_time)/60))

        # The loop variable is undefined when epochs == 0, so derive the
        # index of the last iteration explicitly.
        last_epoch = max(self.epochs - 1, 0)

        os.makedirs("output", exist_ok=True)
        self._save_generator()
        self._save_sample_grid(last_epoch)

    def _save_generator(self):
        """Write the generator architecture (JSON) and weights (HDF5) to output/."""
        model_json = self.generator.to_json()
        with open("output/generator.json", "w") as json_file:
            json_file.write(model_json)
        self.generator.save_weights("output/generator.h5")
        print("Saved Generator to output/")

    def _save_sample_grid(self, epoch):
        """Save an 8x8 grid of freshly generated images to ``output/<epoch>.png``."""
        r, c = 8, 8
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("output/%d.png" % epoch)
        plt.close()
        print("Saved generated imgs to output/")

In [9]:
'''
    Setup settings and download all necessary files
'''
# Download pre-trained models 
!wget -O SD2GAN.zip https://www.dropbox.com/s/4ez4oqbo6sw8oex/SD2GAN.zip?dl=0
# Unzip the files 
fantasy_zip = zipfile.ZipFile('SD2GAN.zip')
fantasy_zip.extractall('')
fantasy_zip.close()
#Setup arguments
params = {
    'dataset': "mnist", # 'mnist' , 'fashion'
    'train' : False, 
    'pretrained': "SD2GAN", 
    'z_dim' : 100, 
    'lr' : 0.0001, 
    'lr_disc2' : 0.00009, 
    'beta' : 0.5, 
    'bs' : 32, 
    'epochs' : 30000
}
# Making output folder
try: os.mkdir('output')
except: pass
if params['train']:
    gan = SD2GAN(**params)
    gan.train()
else:
  gan = params['pretrained']
  dataset = params['dataset']
  latent_dim = params['z_dim']
  PATH = 'SD2GAN/' + dataset +'/'+ gan
  # Load the pretrained model
  json_file = open(PATH+".json", 'r')
  loaded_model_json = json_file.read()
  json_file.close()
  G = tf.keras.models.model_from_json(loaded_model_json)
  G.load_weights(PATH+".h5")
  print("Loaded Generator from disk. ", gan)
  print("Generating image")
  noise = np.random.normal(0, 1, (1, latent_dim))
  gen_imgs = G.predict(noise)
  plt.imshow(gen_imgs[0,:,:,0], cmap='gray')
  plt.savefig("output/%d.png" % 1)
  plt.close()
  print("you can find the output in output/ folder.")

--2021-05-07 19:04:56--  https://www.dropbox.com/s/4ez4oqbo6sw8oex/SD2GAN.zip?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.6.18, 2620:100:601a:18::a27d:712
Connecting to www.dropbox.com (www.dropbox.com)|162.125.6.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/4ez4oqbo6sw8oex/SD2GAN.zip [following]
--2021-05-07 19:04:57--  https://www.dropbox.com/s/raw/4ez4oqbo6sw8oex/SD2GAN.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc5813019f60b40cea4242020b8c.dl.dropboxusercontent.com/cd/0/inline/BODbiRldX2Bpla-lPRRVU-DBrqUHI0DwK5TQdlvDtdHIFlJO3JnVcnJ-IBnL1pEenjSO98QWoELcqp8Zka-hTC8j9ZBCG18scCmtVG3L5sPPKwFJZxHpZKDS-bAa-1pmxOjg--c8KtpPm3R4PopWb8_R/file# [following]
--2021-05-07 19:04:57--  https://uc5813019f60b40cea4242020b8c.dl.dropboxusercontent.com/cd/0/inline/BODbiRldX2Bpla-lPRRVU-DBrqUHI0DwK5TQdlvDtdHIFlJO3JnVcnJ-IBnL1pEenjSO98QWoELcqp8Zka-hTC8j9ZB