In [1]:
import tensorflow as tf 
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import numpy as np 
from matplotlib import pyplot as plt


In [11]:
latent_dim = 20  # length of the noise vector fed to the generator

## data augmentation model
# Randomly crops the 512x256x3 input down to 256x256x3, then applies a random
# flip, a random rotation (factor 0.4, reflected borders) and a random zoom
# (0.2 in height and width). Layers are created and applied in list order.
da_input = Input(shape=(512, 256, 3))
augmentation_layers = [
    RandomCrop(height=256, width=256),
    RandomFlip(),
    RandomRotation(factor=0.4, fill_mode="reflect"),
    RandomZoom(height_factor=0.2, width_factor=0.2),
]
x = da_input
for layer in augmentation_layers:
    x = layer(x)
da_model = Model(inputs=da_input, outputs=x)

## generative model
# Maps a latent vector of length `latent_dim` to a 256x256x3 image:
# Dense projection -> reshape to 8x8x8 -> five stride-2 transposed convs
# (8 -> 16 -> 32 -> 64 -> 128 -> 256) -> two stride-1 transposed convs.
gen_input = Input(shape=(latent_dim,))  # shape must be a tuple; `(latent_dim)` is just an int
x = Dense(8 * 8 * 8)(gen_input)
x = Reshape((8, 8, 8))(x)
for n_filters in (16, 32, 64, 128, 256):
    x = Conv2DTranspose(n_filters, (4, 4), strides=2, activation='linear', padding='same')(x)
    x = LeakyReLU(alpha=0.1)(x)
# NOTE(review): this layer keeps the full 256x256 resolution with 512 channels
# (~33.5M activations per image), by far the most memory-hungry layer here.
x = Conv2DTranspose(512, (4, 4), strides=1, activation='linear', padding='same')(x)
x = LeakyReLU(alpha=0.1)(x)
# NOTE(review): 'relu' output is unbounded above ([0, inf)); confirm this matches
# the value range of the real images the discriminator sees (tanh/sigmoid with
# matching input scaling is the usual choice).
x = Conv2DTranspose(3, (5, 5), strides=1, activation='relu', padding='same')(x)
gen_model = Model(gen_input, x)

# discriminator model
# Six stride-2 convolutions downsample the 256x256x3 input to 4x4x16
# (256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4), each followed by LeakyReLU(0.2),
# then Flatten -> Dropout(0.3) -> a single sigmoid unit.
dis_input = Input((256, 256, 3))
x = dis_input
for n_filters in (512, 256, 128, 64, 32, 16):
    # Keep the conv linear: a ReLU here would zero every negative value and
    # turn the following LeakyReLU into a no-op.
    x = Conv2D(n_filters, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
x = Flatten()(x)
x = Dropout(0.3)(x)
x = Dense(1, activation='sigmoid')(x)
dis_model = Model(dis_input, x)

In [1]:
class GAN(tf.keras.Model):
    """Adversarial trainer combining an augmentation model, a generator and a discriminator.

    Real batches pass through ``da`` before reaching the discriminator, and
    fake batches come from ``gen``. ``train_step`` updates the discriminator
    first, then the generator, both with ``self.optimizer`` (set via ``compile``).

    Args:
        da: model applied to each batch of real images. It is called with
            ``training=True`` so random augmentation layers are active.
        gen: generator mapping latent vectors of shape (batch, latent_dim) to images.
        dis: discriminator. Its output is scored against 0/1 labels with
            ``binary_crossentropy`` (``from_logits=False``), so it is expected
            to output probabilities.
        latent_dim: length of the noise vector fed to ``gen``. If None, it is
            read from the last dimension of ``gen``'s first (functional) input.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, da, gen, dis, latent_dim=None, **kwargs):
        super().__init__(**kwargs)
        self.da = da
        self.gen = gen
        self.dis = dis
        # Derive the noise size from the generator instead of relying on a
        # notebook-level global, so the class is self-contained.
        if latent_dim is None:
            latent_dim = int(gen.inputs[0].shape[-1])
        self.latent_dim = latent_dim
        self.dis_loss_tracker = tf.keras.metrics.Mean(name='dis_loss')
        self.gen_loss_tracker = tf.keras.metrics.Mean(name='gen_loss')

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch.
        return [self.dis_loss_tracker,
                self.gen_loss_tracker]

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Real images are labelled 1 and generated images 0 for the discriminator.
        The generator is then updated so the discriminator outputs 1 for its samples.

        Args:
            data: a batch of real images accepted by ``self.da``.

        Returns:
            dict with the running means of ``dis_loss`` and ``gen_loss``.
        """
        batch_size = tf.shape(data)[0]

        # ---- Discriminator step ----
        real_images = self.da(data, training=True)
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # The fakes must come from the generator (not the augmentation model).
        fake_images = self.gen(random_latent_vectors)
        # tf.concat, not np.concatenate: train_step runs in graph mode on
        # symbolic tensors.
        whole_images = tf.concat([real_images, fake_images], axis=0)
        labels = tf.concat([tf.ones((batch_size, 1)),
                            tf.zeros((batch_size, 1))], axis=0)

        with tf.GradientTape() as tape:
            pred = self.dis(whole_images, training=True)
            # binary_crossentropy returns one loss per sample; reduce to a scalar.
            dis_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(labels, pred))
        grads = tape.gradient(dis_loss, self.dis.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.dis.trainable_weights))
        self.dis_loss_tracker.update_state(dis_loss)

        # ---- Generator step ----
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Real images are labelled 1, so the generator is trained to make the
        # discriminator answer 1 ("real") for its samples.
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            out = self.dis(self.gen(random_latent_vectors, training=True), training=True)
            gen_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(misleading_labels, out))
        grads = tape.gradient(gen_loss, self.gen.trainable_weights)
        # Only the generator's weights are updated here.
        # NOTE(review): the same optimizer instance updates both networks; some
        # Keras versions (e.g. Keras 3) build an optimizer for a fixed variable
        # set, so separate optimizers may be required. Confirm for your version.
        self.optimizer.apply_gradients(zip(grads, self.gen.trainable_weights))
        self.gen_loss_tracker.update_state(gen_loss)

        return {'dis_loss': self.dis_loss_tracker.result(),
                'gen_loss': self.gen_loss_tracker.result(),
                }


NameError: name 'tf' is not defined