<a href="https://colab.research.google.com/github/imnawar/SD2GAN/blob/main/SN_SD2GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import time
import random
import zipfile
import numpy as np
import tensorflow as tf
from keras import backend as K
import matplotlib.pyplot as plt

In [2]:
class SD2GAN():
    def __init__(self, **params):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = params['z_dim']
        self.dataset = params['dataset']
        self.batch_size = params['bs']
        self.epochs = params['epochs']

        self.optimizer1 = tf.keras.optimizers.Adam(params['lr'], params['beta'])
        self.optimizer2 = tf.keras.optimizers.Adam(params['lr_disc2'], params['beta'])

        self.sn = self.SN(**params)

        if self.dataset == 'mnist': 
            self.PATH = 'SD2GAN' + '/' + self.dataset +'/' + 'SN'
            (self.x_train, self.y_train), (self.x_test, self.y_test) = tf.keras.datasets.mnist.load_data()
            self.x_train = self.x_train / 127.5 - 1.
            self.x_test = self.x_test / 127.5 - 1.
            self.img_rows = self.x_train.shape[1]
            self.img_cols = self.x_train.shape[2]
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
        elif self.dataset == 'fashion':
            self.PATH = 'SD2GAN' + '/' + self.dataset +'/' + 'SN'
            (self.x_train, self.y_train), (self.x_test, self.y_test) = tf.keras.datasets.fashion_mnist.load_data()
            self.x_train = self.x_train / 127.5 - 1.
            self.x_test = self.x_test / 127.5 - 1.
            self.img_rows = self.x_train.shape[1]
            self.img_cols = self.x_train.shape[2]
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)

        #json_file = open(self.PATH + '.json', 'r')
        #loaded_model_json = json_file.read()
        #json_file.close()
        #self.Siamese_net = tf.keras.models.model_from_json(loaded_model_json)
        #self.Siamese_net.load_weights(self.PATH + '.h5')


        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer1, metrics=['accuracy'])

        self.discriminator2 = self.build_discriminator2()
        self.discriminator2.compile(loss='binary_crossentropy', optimizer=self.optimizer2, metrics=['accuracy'])

        self.generator = self.build_generator()

        self.Siamese_net = self.sn.train(self.generator)

        z = tf.keras.layers.Input(shape=(self.latent_dim,))
        img = self.generator(z)
        imgs1 = tf.keras.layers.Input((self.img_shape))
        imgs2 = tf.keras.layers.Input(self.img_shape)
        self.discriminator.trainable = False
        self.discriminator2.trainable = False
        self.Siamese_net.trainable = False
        similarity = self.Siamese_net([imgs1, imgs2])
        similarity_validity = self.discriminator2(similarity)
        validity = self.discriminator(img)

        self.combined = tf.keras.models.Model([z, [imgs1, imgs2]], [validity, similarity_validity])
        self.combined.compile(loss='binary_crossentropy', optimizer=self.optimizer1)

    def build_generator(self):
        """Return a Keras model that maps a latent vector to an image.

        A dense layer produces a 7x7x128 tensor which is upsampled twice
        (7 -> 14 -> 28); the last convolution has one filter and a tanh
        activation, so outputs lie in [-1, 1].
        """
        layers = tf.keras.layers

        stack = tf.keras.models.Sequential([
            layers.Dense((128*7*7), input_shape=(self.latent_dim,)),
            layers.BatchNormalization(),
            layers.Activation('tanh'),
            layers.Reshape((7, 7, 128)),
            layers.UpSampling2D(size=(2, 2)),
            layers.Conv2D(64, (5, 5), padding='same'),
            layers.Activation('tanh'),
            layers.UpSampling2D(size=(2, 2)),
            layers.Conv2D(1, (5, 5), padding='same'),
            layers.Activation('tanh'),
        ])

        # Wrap the stack in a functional model with an explicit latent input.
        latent = layers.Input(shape=(self.latent_dim,))
        return tf.keras.models.Model(latent, stack(latent))

    def build_discriminator(self):
        """Return the image discriminator.

        Takes a tensor of shape ``self.img_shape`` and outputs a single
        sigmoid value per image.
        """
        layers = tf.keras.layers

        stack = tf.keras.models.Sequential([
            layers.Conv2D(128, (5, 5), padding='same', input_shape=self.img_shape),
            layers.Activation('tanh'),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dense(512),
            layers.Activation('tanh'),
            layers.Dense(1),
            layers.Activation('sigmoid'),
        ])

        image_in = layers.Input(shape=self.img_shape)
        return tf.keras.models.Model(image_in, stack(image_in))

    def build_discriminator2(self):
        """Return the similarity discriminator.

        Takes one scalar per sample (the model is fed Siamese distances by
        ``train``) and outputs a single sigmoid value.
        """
        layers = tf.keras.layers

        stack = tf.keras.models.Sequential([
            layers.Dense(128, input_shape=(1,), activation=tf.nn.relu),
            layers.Dense(1, activation='sigmoid'),
        ])

        distance_in = layers.Input(shape=(1,))
        return tf.keras.models.Model(distance_in, stack(distance_in))

    def train(self):
        start_time = time.time()
        valid = np.ones((self.batch_size, 1))
        fake = np.zeros((self.batch_size, 1))
        fake_s = np.zeros((int(self.batch_size/2), 1))
        real_s = np.ones((int(self.batch_size/2), 1))

        for epoch in range(self.epochs):

            idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
            imgs = self.x_train[idx]
            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            splited_batch_real = np.split(imgs, 2)
            part1_real = splited_batch_real[0]
            part2_real = splited_batch_real[1]
            sim_real = self.Siamese_net.predict([part1_real, part2_real])

            noise2 = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            gen_imgs2 = self.generator.predict(noise2)

            splited_batch = np.split(gen_imgs2, 2)
            part1 = splited_batch[0]
            part2 = splited_batch[1]

            sim_fake = self.Siamese_net.predict([part1, part2])
            

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            d_loss_real2 = self.discriminator2.train_on_batch(sim_real, real_s)
            d_loss_fake2 = self.discriminator2.train_on_batch(sim_fake, fake_s)
            d_loss2 = 0.5 * np.add(d_loss_real2, d_loss_fake2)

            noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
            noise2 = np.random.normal(0, 1, (self.batch_size*2, self.latent_dim))
            gen_imgs2 = self.generator.predict(noise2)
            g_loss = self.combined.train_on_batch([noise, [gen_imgs2[:self.batch_size,:,:,:], gen_imgs2[self.batch_size:,:,:,:]]], valid)

            if epoch % 300 == 0:
                print ("%d [D loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], d_loss2[0], 100*d_loss2[1], g_loss[0]))

        print("Training done in: %s minutes ---" % ((time.time() - start_time)/60))

        model_json = self.generator.to_json()
        with open("output/generator.json", "w") as json_file:
            json_file.write(model_json)
        self.generator.save_weights("output/generator.h5")
        print("Saved Generator to output/")
        r, c = 8, 8
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("output/%d.png" % epoch)
        plt.close()
        print("Saved generated imgs to output/")


    class SN():
        def __init__(self, **params):
            self.img_rows = 28
            self.img_cols = 28
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
            self.latent_dim = params['z_dim']
            self.dataset = params['dataset']
            self.batch_size = params['bs']
            self.epochs = params['epochs']
            self.num_classes = 10

        def euclid_dis(self, vects):
            """Return the Euclidean distance between two batches of vectors.

            ``vects`` must unpack into exactly two tensors; both are cast to
            float32.  The squared distance is summed over axis 1 (keeping the
            axis) and clamped to at least ``K.epsilon()`` before the square root.
            """
            left, right = (tf.cast(v, tf.float32) for v in vects)
            squared = K.sum(K.square(left - right), axis=1, keepdims=True)
            return K.sqrt(K.maximum(squared, K.epsilon()))

        def eucl_dist_output_shape(self, shapes):
            shape1, shape2 = shapes
            return (shape1[0], 1)

        def contrastive_loss(self, y_true, y_pred):
            """Contrastive loss with a margin of 1.

            Entries with ``y_true == 1`` are penalised by the squared predicted
            distance; entries with ``y_true == 0`` by the squared shortfall of
            the distance below the margin.  The mean over all entries is returned.
            """
            margin = 1
            similar_term = y_true * K.square(y_pred)
            dissimilar_term = (1 - y_true) * K.square(K.maximum(margin - y_pred, 0))
            return K.mean(similar_term + dissimilar_term)

        def create_pairs(self, x, digit_indices):
            pairs = []
            labels = []
            
            n=min([len(digit_indices[d]) for d in range(self.num_classes)]) -1 
            for d in range(self.num_classes):
              for i in range(n):
                z1, z2 = digit_indices[d][i], digit_indices[d][i+1]
                pairs += [[x[z1], x[z2]]]
                inc = random.randrange(1, self.num_classes)
                dn = (d + inc) % self.num_classes
                z1, z2 = digit_indices[d][i], digit_indices[dn][i]
                pairs += [[x[z1], x[z2]]]
                labels += [1,0]
            return np.array(pairs), np.array(labels)

        def create_base_net(self, input_shape):
            """Build, print a summary of, and return the shared embedding network.

            Two Conv2D (5x5, tanh) + 2x2 average-pooling stages are followed by
            a flatten and a 10-unit tanh dense layer.
            """
            layers = tf.keras.layers

            image_in = layers.Input(shape = input_shape)
            features = layers.Conv2D(4, (5,5), activation = 'tanh')(image_in)
            features = layers.AveragePooling2D(pool_size = (2,2))(features)
            features = layers.Conv2D(16, (5,5), activation = 'tanh')(features)
            features = layers.AveragePooling2D(pool_size = (2,2))(features)
            features = layers.Flatten()(features)
            embedding = layers.Dense(10, activation = 'tanh')(features)

            embedding_net = tf.keras.models.Model(image_in, embedding)
            embedding_net.summary()
            return embedding_net

        def prepare_data(self, G):
            """Load the dataset and append generator samples to both splits.

            Sets ``x_train``/``y_train``, ``x_test``/``y_test`` and
            ``input_shape`` on the instance.  Images are reshaped to
            rows x cols x channels, converted to float32 and rescaled to [-1, 1].

            Args:
                G: model whose ``predict`` turns latent noise into images.
            """
            # Follow the configured dataset.  Previously MNIST was loaded
            # unconditionally, even when 'fashion' was requested.  Any other
            # name still falls back to MNIST, as before.
            if self.dataset == 'fashion':
                source = tf.keras.datasets.fashion_mnist
            else:
                source = tf.keras.datasets.mnist
            (x_train, y_train), (x_test, y_test) = source.load_data()

            self.x_train, self.y_train = self._append_generated(G, x_train, y_train)
            self.x_test, self.y_test = self._append_generated(G, x_test, y_test)

            self.input_shape = self.x_train.shape[1:]

        def _append_generated(self, G, images, labels, n_generated=10000):
            """Normalise ``images`` and append ``n_generated`` samples drawn from ``G``."""
            images = images.reshape(images.shape[0], self.img_rows, self.img_cols, self.channels)
            images = images.astype('float32') / 127.5 - 1.

            noise = np.random.normal(0, 1, (n_generated, self.latent_dim))
            generated = G.predict(noise)
            # Generated samples get the first label not used by a real class
            # (10 with the default num_classes of 10).
            generated_labels = np.full(n_generated, self.num_classes, dtype=int)

            return np.r_[images, generated], np.r_[labels, generated_labels]

        def compute_accuracy(self, y_true, y_pred):
            pred = y_pred.ravel() < 0.5
            return np.mean(pred == y_true)


        def accuracy(self, y_true, y_pred):
            """Keras metric: share of entries where ``y_pred < 0.5`` matches ``y_true``.

            The thresholded prediction is cast to the dtype of ``y_true`` before
            the comparison.
            """
            predicted_similar = K.cast(y_pred < 0.5, y_true.dtype)
            return K.mean(K.equal(y_true, predicted_similar))


        def train(self, G):
            """Train the Siamese network and return the compiled Keras model.

            Prepares the data, builds positive/negative pairs for the training
            and test splits, and fits a two-branch network that shares one
            embedding network and outputs the Euclidean distance between the
            two embeddings.  The test pairs are used as validation data.

            Args:
                G: generator model, forwarded to ``prepare_data``.

            Returns:
                The trained model taking ``[image_a, image_b]``.
            """
            self.prepare_data(G)

            # Only labels 0..num_classes-1 are indexed, so samples carrying any
            # other label (such as the generator samples appended by
            # prepare_data) never appear in a pair.
            # NOTE(review): confirm that leaving them out is intended.
            class_ids = range(self.num_classes)

            train_indices = [np.where(self.y_train == i)[0] for i in class_ids]
            tr_pairs, tr_y = self.create_pairs(self.x_train, train_indices)
            tr_y = tr_y.astype('float32')

            test_indices = [np.where(self.y_test == i)[0] for i in class_ids]
            te_pairs, te_y = self.create_pairs(self.x_test, test_indices)
            te_y = te_y.astype('float32')

            base_network = self.create_base_net(self.input_shape)

            input_a = tf.keras.layers.Input(shape=self.input_shape)
            input_b = tf.keras.layers.Input(shape=self.input_shape)

            # The same network instance embeds both inputs, so weights are shared.
            processed_a = base_network(input_a)
            processed_b = base_network(input_b)

            distance = tf.keras.layers.Lambda(self.euclid_dis,output_shape=self.eucl_dist_output_shape)([processed_a, processed_b])

            self.model = tf.keras.models.Model([input_a, input_b], distance)

            # train
            rms = tf.keras.optimizers.RMSprop()
            self.model.compile(loss=self.contrastive_loss, optimizer=rms, metrics=[self.accuracy])
            # The test pairs used to be built and then discarded; passing them
            # as validation data reports held-out loss/accuracy every epoch
            # without affecting the weight updates.
            self.model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
                      batch_size=128,
                      epochs=10,
                      validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))
            return self.model

In [None]:
# Hyper-parameters consumed by SD2GAN and its nested Siamese network.
params = {
    'dataset': "mnist", # 'mnist' , 'fashion'
    'z_dim' : 100,
    'lr' : 0.0001,
    'lr_disc2' : 0.00009,
    'beta' : 0.5,
    'bs' : 32,
    'epochs' : 30000
}
# Making output folder. exist_ok tolerates an existing directory while still
# surfacing real failures (e.g. permissions) that a bare except would hide.
os.makedirs('output', exist_ok=True)
gan = SD2GAN(**params)
gan.train()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 24, 24, 4)         104       
_________________________________________________________________
average_pooling2d (AveragePo (None, 12, 12, 4)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 16)          1616      
_________________________________________________________________
average_pooling2d_1 (Average (None, 4, 4, 16)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 256)               0         
_________________________________