In [1]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import numpy as np
from matplotlib import pyplot as plt


In [3]:
# dataset
# The images are read from Google Drive, so Drive must be mounted before the
# directory is scanned. The separate drive.mount cell sits *below* this one
# (it was executed out of order as In [2]), so a top-to-bottom Run All would
# fail here without this mount. When the later cell mounts again, Colab should
# only report that Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive')

DATA_DIR = '/content/drive/MyDrive/data'
b_size = 10
# labels=None -> the dataset yields image batches only (no labels).
# image_size is (height, width): every image is resized to 512x256 without
# preserving aspect ratio (crop_to_aspect_ratio=False).
dataset = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    labels=None,
    image_size=(512, 256),
    shuffle=True,
    batch_size=b_size,
    color_mode='rgb',
    crop_to_aspect_ratio=False,
)
def crop_image(image, offset_height=70, offset_width=150,
               target_height=512, target_width=256):
    """Crop a fixed rectangular window out of an image (or image batch) tensor.

    The defaults reproduce the original hard-coded window: top-left corner at
    row 70 / column 150, output 512x256. ``tf.image.crop_to_bounding_box``
    requires the window to fit inside the input, i.e. the input must be at
    least (offset_height + target_height) x (offset_width + target_width).
    The dataset cell resizes images to exactly 512x256, so the default window
    does NOT fit those images; pass smaller offsets/targets to use it there.

    Args:
        image: Image tensor passed straight to ``tf.image.crop_to_bounding_box``.
        offset_height: Row of the window's top-left corner.
        offset_width: Column of the window's top-left corner.
        target_height: Height of the cropped output.
        target_width: Width of the cropped output.

    Returns:
        The cropped tensor.
    """
    return tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, target_height, target_width)


def normalize(ds):
    """Rescale pixel values from the [0, 255] range to [0, 1].

    Works on anything that supports true division by a float, such as the
    image batches produced by the dataset or a NumPy array.
    """
    return ds / 255.0
# Scale pixels to [0, 1]. The map runs in parallel and batches are prefetched,
# so the input pipeline overlaps with training. tf.data keeps element order
# for parallel map by default. crop_image is not applied because the images
# are already resized to 512x256 when the dataset is built.
dataset = dataset.map(normalize, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)


Found 1801 files belonging to 1 classes.


In [2]:
# Mount Google Drive so the training images under /content/drive/MyDrive are readable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
latent_dim = 20  # size of the random vector fed to the generator

## data augmentation model
# Takes a 512x256 RGB image and crops a random 256x256 window (the same size
# the generator produces). Random flips, rotations (reflect padding) and
# zooms are applied on top of the crop.
da_input = Input(shape=(512, 256, 3))
augmentations = [
    RandomCrop(256, 256),
    RandomFlip(),
    RandomRotation(factor=0.4, fill_mode="reflect"),
    RandomZoom(0.2, 0.2),
]
x = da_input
for augmentation in augmentations:
    x = augmentation(x)
da_model = Model(da_input, x)

## generative model
# The generator maps a latent vector to a 256x256x3 image in three stages:
#   1. A dense projection is reshaped to an 8x8x8 feature map.
#   2. Five stride-2 transposed convolutions upsample it (8 -> 256).
#   3. One stride-1 layer refines the features before the 3-channel projection.
gen_input = Input(shape=(latent_dim,))  # shape must be a tuple, not a bare int
x = Dense(8 * 8 * 8)(gen_input)
x = Reshape((8, 8, 8))(x)
for filters, stride in [(16, 2), (32, 2), (64, 2), (128, 2), (256, 2), (512, 1)]:
    x = Conv2DTranspose(filters, (4, 4), strides=stride, activation='linear', padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
# Sigmoid keeps generated pixels in [0, 1], the same range as the normalized
# real images. With a linear output the discriminator could tell real from
# fake by value range alone.
x = Conv2DTranspose(3, (5, 5), strides=1, activation='sigmoid', padding='same')(x)
gen_model = Model(gen_input, x)

# discriminator model
# Six stride-2 convolutions downsample the 256x256x3 input to 4x4x16.
# After flattening and dropout, a sigmoid unit scores the image as real (1)
# or fake (0).
dis_input = Input((256, 256, 3))
x = dis_input
for filters in [512, 256, 128, 64, 32, 16]:
    # No built-in activation: LeakyReLU is the nonlinearity. With
    # activation='relu' the conv outputs would already be non-negative, and
    # the LeakyReLU negative slope would never take effect.
    x = Conv2D(filters, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)
dis_model = Model(dis_input, x)

In [5]:
class GAN(tf.keras.Model):
    """GAN that alternates one discriminator and one generator update per batch.

    Args:
        da: Model applied to real images before they reach the discriminator
            (in this notebook, the augmentation model).
        gen: Generator; its (single) input's last dimension is used as the
            latent size.
        dis: Discriminator; expected to output a (batch, 1) probability that
            the image is real.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, da, gen, dis, **kwargs):
        super().__init__(**kwargs)
        self.da = da
        self.gen = gen
        self.dis = dis
        # Take the latent size from the generator itself so this class does
        # not depend on the notebook-level ``latent_dim`` global.
        self.latent_dim = gen.input_shape[-1]
        self.dis_loss_tracker = tf.keras.metrics.Mean(name='dis_loss')
        self.gen_loss_tracker = tf.keras.metrics.Mean(name='gen_loss')

    @property
    def metrics(self):
        """Trackers that Keras resets at the start of each epoch."""
        return [self.dis_loss_tracker,
                self.gen_loss_tracker]

    def compile(self, optimizers, **kwargs):
        """Configure the model with one optimizer per sub-network.

        Args:
            optimizers: dict with keys ``'dis_optimizer'`` and ``'gen_optimizer'``.
            **kwargs: Forwarded to ``tf.keras.Model.compile``.
        """
        # kwargs must be unpacked: passing the dict positionally would bind it
        # to the base class's ``optimizer`` parameter.
        super().compile(**kwargs)
        self.dis_opt = optimizers['dis_optimizer']
        self.gen_opt = optimizers['gen_optimizer']

    def train_step(self, data):
        """Run one discriminator update, then one generator update.

        Args:
            data: A batch of real images. If the dataset yields ``(x, y)``
                tuples, only ``x`` is used.

        Returns:
            dict with the running means of ``dis_loss`` and ``gen_loss``.
        """
        if isinstance(data, tuple):
            data = data[0]
        batch_size = tf.shape(data)[0]

        # Discriminator step: real (augmented) images are labelled 1, generated ones 0.
        random_in = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_generated = self.gen(random_in, training=True)
        # Pass training=True explicitly. Without it Keras may run these models
        # in inference mode, where the augmentation layers and Dropout are no-ops.
        augmented_images = self.da(data, training=True)
        whole_images = tf.concat([augmented_images, fake_generated], axis=0)
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        with tf.GradientTape() as tape:
            pred = self.dis(whole_images, training=True)
            # Reduce per-sample losses to a scalar so the gradient scale does
            # not depend on the batch size.
            dis_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(labels, pred))
        grads = tape.gradient(dis_loss, self.dis.trainable_weights)
        self.dis_opt.apply_gradients(zip(grads, self.dis.trainable_weights))
        self.dis_loss_tracker.update_state(dis_loss)

        # Generator step: push the discriminator toward predicting "real" (1) on fakes.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            out = self.dis(self.gen(random_latent_vectors, training=True), training=True)
            gen_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(misleading_labels, out))
        grads = tape.gradient(gen_loss, self.gen.trainable_weights)
        self.gen_opt.apply_gradients(zip(grads, self.gen.trainable_weights))
        self.gen_loss_tracker.update_state(gen_loss)

        return {'dis_loss': self.dis_loss_tracker.result(),
                'gen_loss': self.gen_loss_tracker.result(),
                }
# Bundle augmentation, generator and discriminator into one trainable model.
gan = GAN(da=da_model, gen=gen_model, dis=dis_model)

In [None]:
     gan.compile(optimizers = {
    'dis_optimizer':tf.keras.optimizers.Adam(),
    'gen_optimizer':tf.keras.optimizers.Adam()
})
gan.fit(dataset,epochs = 3)

Epoch 1/3

In [None]:
size = 8
random_latent_vector = tf.random.normal(shape=(size, latent_dim))
generated = np.asarray(gan.gen.predict(random_latent_vector))

# Min-max rescale the whole batch into [0, 1] so imshow can display float
# images. A constant output (zero range) must not cause a divide-by-zero.
low = generated.min()
span = generated.max() - low
generated = (generated - low) / (span if span > 0 else 1.0)

# Grid with 2 columns and as many rows as needed (4x2 for size=8).
n_cols = 2
n_rows = -(-size // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
for ax in axes.flat:
    ax.axis('off')
for ax, image in zip(axes.flat, generated):
    ax.imshow(image)
plt.show()
