In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist, fashion_mnist
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dropout, BatchNormalization as BN, LeakyReLU, Flatten, Dense, Reshape
from tensorflow.keras.optimizers import Adam
from google.colab import drive
from tqdm import tqdm
from google.colab import files
# Mount Google Drive at /content/gdrive (requires Google Colab).
# NOTE(review): nothing in the visible cells reads or writes under /content/gdrive;
# outputs are saved to the working directory and fetched via files.download.
drive.mount('/content/gdrive')

In [None]:
def data(name, index):
    """Load a Keras image dataset and return the images of a single class.

    The train and test splits are merged (train first), filtered to the samples
    whose label equals ``index`` and rescaled from ``[0, 255]`` to ``[-1, 1]``
    (assuming uint8 pixel values, as the Keras image datasets provide).

    Args:
        name: Dataset module whose ``load_data()`` returns
            ``(x_train, y_train), (x_test, y_test)``, e.g. ``fashion_mnist``.
        index: Class label to keep.

    Returns:
        float32 array of the selected images; each image keeps the shape it
        has in the source dataset.
    """
    # load_data() yields the *train* split first; the old code named it "test".
    (x_train, y_train), (x_test, y_test) = name.load_data()
    images = np.concatenate((x_train, x_test), axis=0)
    # ravel() flattens labels whether they arrive as (n,) or (n, 1).
    labels = np.concatenate((y_train.ravel(), y_test.ravel()))
    selected = images[labels == index]
    # uint8 - 127.5 promotes to float before the cast, so no wrap-around occurs.
    return (selected - 127.5).astype(np.float32) / 127.5

# Keep only the Fashion-MNIST images with label 7 (presumably the "Sneaker"
# class — confirm against the dataset's label table), scaled to [-1, 1] by data().
fashion_dataset = data(fashion_mnist,7)

In [None]:
class Discriminator:
    """Convolutional binary classifier scoring images as real (high) or fake (low).

    The compiled Keras model is built in the constructor and exposed as ``self.d``.

    Args:
        lr: Learning rate for the Adam optimizer the model is compiled with.
        input_shape: Shape of a single input image, ``(height, width, channels)``.
    """

    def __init__(self, lr, input_shape):
        self.lr = lr
        self.input_shape = input_shape
        self.d = self.model()

    def model(self):
        """Build, compile and return the discriminator network.

        Two stride-2 5x5 convolutions (64 then 128 filters), each followed by
        LeakyReLU(0.2) and Dropout(0.3), feed a single sigmoid output unit.
        Compiled with Adam(beta_1=0.5) and binary cross-entropy.
        """
        net = Sequential()
        net.add(Conv2D(filters = 64, kernel_size = (5,5), strides = (2,2), padding = 'same', input_shape = self.input_shape))
        net.add(LeakyReLU(alpha = 0.2))
        net.add(Dropout(0.3))

        net.add(Conv2D(filters = 128, kernel_size = (5,5), strides = (2,2), padding = 'same'))
        net.add(LeakyReLU(alpha = 0.2))
        net.add(Dropout(0.3))

        net.add(Flatten())
        net.add(Dense(1, activation = 'sigmoid'))

        # `learning_rate` is the supported keyword; `lr` is a deprecated alias
        # that recent Keras versions no longer accept.
        net.compile(optimizer = Adam(learning_rate = self.lr, beta_1 = 0.5), loss = 'binary_crossentropy')

        # summary() prints the table itself and returns None; wrapping it in
        # print() only emitted an extra "None" line.
        net.summary()

        return net
    
        
class Generator:
    """Transposed-convolution generator mapping latent vectors to images.

    The Keras model is built (but not compiled) in the constructor and exposed
    as ``self.g``.

    Args:
        lr: Learning rate; stored on the instance but not used here because the
            generator is not compiled in this class.
        input_shape: Shape of one generated image, ``(height, width, channels)``.
            Height and width should be divisible by 4, since the network
            upsamples twice by a factor of 2.
        latent_dim: Length of the latent noise vector fed to the first layer.
    """

    def __init__(self, lr, input_shape, latent_dim):
        self.lr = lr
        self.latent_dim = latent_dim
        self.output_shape = input_shape
        self.g = self.model()

    def model(self):
        """Build and return the (uncompiled) generator network.

        A dense layer projects the latent vector to a
        ``(height/4, width/4, latent_dim)`` feature map. Two stride-2 transposed
        convolutions upsample it to ``(height, width)``. A final tanh transposed
        convolution emits ``channels`` maps in ``[-1, 1]``, the same range that
        ``data()`` scales real images to.
        """
        height, width, channels = self.output_shape
        # Two stride-2 upsamplings follow, so start at a quarter of the target
        # size (7x7 for 28x28 images, the previously hard-coded value).
        base_h, base_w = height // 4, width // 4

        net = Sequential()
        net.add(Dense(base_h * base_w * self.latent_dim, input_dim = self.latent_dim))
        net.add(LeakyReLU(alpha = 0.2))
        net.add(Reshape((base_h, base_w, self.latent_dim)))

        net.add(Conv2DTranspose(128, kernel_size = (5,5), strides = (2,2), padding = 'same'))
        net.add(BN())
        net.add(LeakyReLU(alpha = 0.2))

        net.add(Conv2DTranspose(64, kernel_size = (4,4), strides = (2,2), padding = 'same'))
        net.add(BN())
        net.add(LeakyReLU(alpha = 0.2))

        net.add(Conv2DTranspose(channels, kernel_size = (3,3), padding = 'same', activation = 'tanh'))
        # Flatten + Reshape leaves the data unchanged when the sizes agree and
        # errors at build time if the generated image does not match output_shape.
        net.add(Flatten())
        net.add(Reshape(self.output_shape))

        # summary() prints the table itself and returns None.
        net.summary()

        return net
    
    
class GAN:
    """DCGAN wrapper: builds D, G and the stacked G->D model, then trains them.

    Args:
        dataset: Array of images shaped ``(n, height, width)``, expected to be
            scaled to [-1, 1] (e.g. the output of ``data()``). A trailing
            channel axis is added here.
        epochs: Number of training iterations. Each iteration trains D on one
            random half-batch of real and one of generated images, then G on
            one batch; it is not a full pass over the dataset.
        batch: Generator batch size. D sees ``batch // 2`` real and
            ``batch // 2`` fake images per iteration.
        latent_dim: Length of the uniform latent noise vector.
        lr: Adam learning rate for the discriminator and the combined model.
    """

    def __init__(self, dataset, epochs = 1000, batch = 64, latent_dim = 128, lr = 2e-4):
        self.epochs = epochs
        self.batch = batch
        self.latent_dim = latent_dim
        self.lr = lr
        # Add a trailing channel axis: (n, h, w) -> (n, h, w, 1).
        self.dataset = dataset.reshape(*dataset.shape, 1)
        self.input_size = (self.dataset[0].shape[0], self.dataset[0].shape[1], 1)
        self.half_batch = self.batch // 2

        self.D = Discriminator(self.lr, self.input_size).d
        self.G = Generator(self.lr, self.input_size, self.latent_dim).g
        self.GAN_model = self.gan_model()

        self.DLoss = []  # discriminator loss per training iteration
        self.GLoss = []  # combined-model (generator) loss per training iteration
        # Progress strip: starts as one column of samples, widened in train().
        self.images = self.stack()
        self.real = []  # NOTE(review): not used by the visible code
        self.fake = []  # NOTE(review): not used by the visible code

    def gan_model(self):
        """Stack G and a frozen D into the model used to train the generator."""
        # Under Keras 2 (tf.keras) semantics `trainable` is taken at compile()
        # time, so freezing D here affects only the combined model; self.D was
        # already compiled as trainable.
        # NOTE(review): Keras 3 reads `trainable` when the train step runs, in
        # which case self.D.train_on_batch would stop updating D — confirm version.
        self.D.trainable = False

        net = Sequential()
        net.add(self.G)
        net.add(self.D)
        net.compile(optimizer = Adam(learning_rate = self.lr, beta_1 = 0.5), loss = 'binary_crossentropy')
        return net

    def z(self, batch):
        """Sample ``batch`` latent vectors uniformly from [-1, 1).

        Returns shape ``(batch, latent_dim)``. That is the 2-D input G's first
        layer (``Dense(input_dim=latent_dim)``) expects; the previous
        ``(batch, 1, 1, latent_dim)`` shape did not match it.
        """
        return np.random.uniform(-1, 1, (batch, self.latent_dim))

    def stack(self, size = 15):
        """Generate ``size`` images and stack them vertically into one column.

        Returns an array of shape ``(size * height, width, channels)``.
        """
        imgs = self.G.predict(self.z(size))
        # Concatenating the (size, h, w, c) batch along axis 0 places images one
        # under another; it matches the old repeated vstack without quadratic copying.
        return np.concatenate(imgs, axis = 0)

    def plot(self, epoch, size = 5):
        """Save, download and display a ``size`` x ``size`` grid of samples.

        Writes ``img{epoch}.png`` to the working directory and triggers a Colab
        download of it.
        """
        height, width = self.input_size[0], self.input_size[1]
        # One predict call for the whole grid instead of one call per cell.
        samples = self.G.predict(self.z(size ** 2))
        fig = plt.figure()
        for i, sample in enumerate(samples):
            plt.subplot(size, size, i + 1)
            plt.axis('off')
            plt.imshow(sample.reshape(height, width), cmap = 'gray')
        plt.savefig('img{}.png'.format(epoch), dpi = 400)
        files.download('img{}.png'.format(epoch))
        plt.show()
        plt.close(fig)

    def loss_graph(self, epoch):
        """Save and download the loss curves and the progress-image strip.

        Writes ``graph.png`` (generator and discriminator loss per training
        iteration) and ``progress{epoch}.png`` (the columns accumulated in
        ``self.images``). Triggers Colab downloads for both.
        """
        steps = range(len(self.GLoss))
        fig = plt.figure()
        plt.plot(steps, self.GLoss, color = 'b', label = 'Generator')
        plt.plot(steps, self.DLoss, color = 'y', label = 'Discriminator')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        # Without legend() the `label=` arguments above are never displayed.
        plt.legend()
        plt.savefig('graph.png', dpi = 400)
        plt.show()
        plt.close(fig)

        # Distinct file name: plot() may already have written and downloaded
        # img{epoch}.png for this same final epoch, which this used to overwrite.
        strip = self.images.reshape(self.images.shape[0], self.images.shape[1])
        plt.imsave('progress{}.png'.format(epoch), strip, cmap = 'gray', dpi = 400)
        files.download('progress{}.png'.format(epoch))
        files.download('graph.png')

    def _discriminator_labels(self):
        """Return ``(real_labels, fake_labels)``, each shaped ``(half_batch, 1)``.

        With probability 0.95 the labels are smoothed: real in [0.9, 1), fake in
        [0, 0.2). Otherwise they are flipped (real = 0, fake in [0.9, 1)) to add
        label noise to D's training.
        """
        shape = (self.half_batch, 1)
        if np.random.random() > 0.05:
            real_labels = np.random.uniform(0.9, 1, shape)
            fake_labels = np.random.uniform(0, 0.2, shape)
        else:
            fake_labels = np.random.uniform(0.9, 1, shape)
            real_labels = np.zeros(shape)
        return real_labels, fake_labels

    def _snapshot_interval(self):
        """Iterations between progress snapshots (about 50 per run, at least 1)."""
        # epochs // 50 is 0 when epochs < 50, which made `epoch % 0` raise.
        return max(1, self.epochs // 50)

    def train(self):
        """Run the alternating discriminator / generator training loop.

        Each iteration ("epoch" here) does the following:
        - trains D on one half-batch of real and one half-batch of generated images;
        - trains G on a full batch through the frozen-D combined model;
        - records both losses.

        Periodically it also:
        - appends a column of samples to ``self.images`` (see ``_snapshot_interval``);
        - in the second half of training, saves G checkpoints and sample grids.

        The loss graph is produced at the end.
        """
        snapshot_every = self._snapshot_interval()

        for epoch in tqdm(range(1, self.epochs + 1), ascii = True, unit = 'Epoch'):
            real_labels, fake_labels = self._discriminator_labels()

            real_imgs = self.dataset[np.random.randint(0, len(self.dataset), self.half_batch)]
            real_loss = self.D.train_on_batch(real_imgs, real_labels)

            fake_imgs = self.G.predict(self.z(self.half_batch))
            fake_loss = self.D.train_on_batch(fake_imgs, fake_labels)

            self.DLoss.append(0.5 * (real_loss + fake_loss))

            #-----------------------------------------------------#

            noise = self.z(self.batch)
            # G improves when the frozen D scores its samples as real
            # (smoothed targets in [0.9, 1)).
            targets = np.random.uniform(0.9, 1, self.batch)
            self.GLoss.append(self.GAN_model.train_on_batch(noise, targets))

            #-----------------------------------------------------#

            if epoch % snapshot_every == 0:
                self.images = np.hstack((self.images, self.stack()))
                if epoch > self.epochs // 2:
                    self.G.save('model{}.h5'.format(epoch))

            if epoch % 100 == 0 and epoch > 0.5 * self.epochs:
                self.plot(epoch)

        # Use self.epochs, not the loop variable, which is unbound when epochs == 0.
        self.loss_graph(self.epochs)
              
                        