# In [None]:
import cifar10
from ops import *
import time

# In [None]:
class TripleGan(object):
    """Triple-GAN networks: a label-conditioned discriminator, a
    label-conditioned generator and a classifier.

    Each network is built inside its own ``tf.variable_scope`` so that it can
    be re-applied with ``reuse=True``.

    NOTE(review): ``generator`` reads ``self.y_dim``. No ``__init__`` is
    defined here, so it is assumed to be set elsewhere — confirm.
    """

    def discriminator(self, x, label, scope='discriminator', is_training=True, reuse=False):
        """Build the label-conditioned discriminator graph.

        Args:
            x: batch of input images (assumed NHWC — TODO confirm).
            label: batch of labels. It is reshaped to ``[-1, 1, 1, y_dim]``
                and concatenated onto the feature maps with ``conv_concat``.
                It is also appended to the pooled features before the final
                linear layer.
            scope: variable scope that every layer of the network is created in.
            is_training: forwarded to the dropout layers so that dropout is
                only active during training.
            reuse: whether to reuse the variables already created under
                ``scope`` (e.g. when applying the discriminator a second time).

        Returns:
            ``(output, logit, features)``:
            - ``output`` is ``sigmoid(logit)``;
            - ``logit`` is the single-unit linear output;
            - ``features`` is the flattened, label-augmented tensor fed to
              that linear layer.
        """
        # All layers must be created inside the variable scope. Otherwise
        # ``reuse`` has no effect on them and their variables are not
        # grouped under ``scope``.
        with tf.variable_scope(scope, reuse=reuse):
            # Number of label classes. Falls back to 10, the value this method
            # originally hard-coded, when ``self.y_dim`` is not set.
            y_dim = getattr(self, 'y_dim', 10)

            # Input dropout; ``rate`` is the drop probability.
            x = dropout(x, rate=0.2, is_training=is_training)
            y = tf.reshape(label, [-1, 1, 1, y_dim])
            # Condition the input on the label.
            x = conv_concat(x, y)

            # Block 1: stride-1 conv, re-condition on the label, then a
            # stride-2 conv that downsamples the feature maps.
            x = conv_layer(x, filter_size=32, stride=1, kernel=[3, 3], layer_name=scope + '_conv1')
            x = lrelu(x, 0.2)
            x = conv_concat(x, y)
            x = conv_layer(x, filter_size=32, stride=2, kernel=[3, 3], layer_name=scope + '_conv2')
            x = dropout(x, rate=0.2, is_training=is_training)
            x = lrelu(x, 0.2)

            # Block 2: same pattern with 64 filters.
            x = conv_layer(x, filter_size=64, stride=1, kernel=[3, 3], layer_name=scope + '_conv3')
            x = lrelu(x, 0.2)
            x = conv_concat(x, y)
            x = conv_layer(x, filter_size=64, stride=2, kernel=[3, 3], layer_name=scope + '_conv4')
            x = dropout(x, rate=0.2, is_training=is_training)
            x = lrelu(x, 0.2)

            # Block 3: same pattern with 64 filters, without dropout.
            x = conv_layer(x, filter_size=64, stride=1, kernel=[3, 3], layer_name=scope + '_conv5')
            x = lrelu(x, 0.2)
            x = conv_concat(x, y)
            x = conv_layer(x, filter_size=64, stride=2, kernel=[3, 3], layer_name=scope + '_conv6')
            x = lrelu(x, 0.2)

            # Head: global average pooling, flatten, append the raw label,
            # then one linear unit.
            x = Global_Average_Pooling(x)
            x = flatten(x)
            x = concat([x, label])  # concat takes a list of tensors, as in generator
            logit = linear(x, unit=1, layer_name=scope + '_linear1')
            output = sigmoid(logit)

            return output, logit, x

    def generator(self, z, y, scope='generator', is_training=True, reuse=False):
        """Build the label-conditioned generator graph.

        Args:
            z: noise batch; concatenated with ``y`` before the first linear layer.
            y: label batch; also reshaped to ``[-1, 1, 1, self.y_dim]`` and
                concatenated onto the feature maps after each intermediate
                layer with ``conv_concat``.
            scope: variable scope that all layers are created in.
            is_training: forwarded to ``batch_norm``.
            reuse: whether to reuse the variables already created under ``scope``.

        Returns:
            The ``tanh`` output of the last deconvolution, with 3 channels and
            values in [-1, 1]. The result is presumably 32x32 if each stride-2
            ``deconv_layer`` doubles the spatial size — TODO confirm.
        """
        with tf.variable_scope(scope, reuse=reuse) :

            x = concat([z, y]) # mlp_concat

            # Project to a 4x4x512 feature map.
            x = relu(linear(x, unit=512*4*4, layer_name=scope+'_linear1'))
            x = batch_norm(x, is_training=is_training, scope=scope+'_batch1')

            x = tf.reshape(x, shape=[-1, 4, 4, 512])
            y = tf.reshape(y, [-1, 1, 1, self.y_dim])
            x = conv_concat(x,y)

            x = relu(deconv_layer(x, filter_size=256, kernel=[5,5], stride=2, layer_name=scope+'_deconv1'))
            x = batch_norm(x, is_training=is_training, scope=scope+'_batch2')
            x = conv_concat(x,y)

            x = relu(deconv_layer(x, filter_size=128, kernel=[5,5], stride=2, layer_name=scope+'_deconv2'))
            x = batch_norm(x, is_training=is_training, scope=scope+'_batch3')
            x = conv_concat(x,y)

            # NOTE(review): 'deconv3' lacks the '_' separator used by the
            # sibling layer names. It is left as-is so existing graph and
            # checkpoint names do not change.
            x = tanh(deconv_layer(x, filter_size=3, kernel=[5,5], stride=2, wn=False, layer_name=scope+'deconv3'))

        return x

    def classifier(self, x, scope='classifier', is_training=True, reuse=False):
        """Build the classifier graph.

        Args:
            x: batch of input images (assumed NHWC — TODO confirm).
            scope: variable scope that all layers are created in.
            is_training: forwarded to the dropout layers.
            reuse: whether to reuse the variables already created under ``scope``.

        Returns:
            Unnormalized 10-unit logits from the final linear layer; no
            softmax is applied here.
        """
        with tf.variable_scope(scope, reuse=reuse) :
            x = gaussian_noise_layer(x) # default = 0.15
            x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv1'))
            x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv2'))
            x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv3'))

            x = max_pooling(x, kernel=[2,2], stride=2)
            x = dropout(x, rate=0.5, is_training=is_training)

            x = lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv4'))
            x = lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv5'))
            x = lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv6'))

            x = max_pooling(x, kernel=[2,2], stride=2)
            x = dropout(x, rate=0.5, is_training=is_training)

            x = lrelu(conv_layer(x, filter_size=512, kernel=[3,3], layer_name=scope+'_conv7'))
            # Network-in-network layers (per the helper name) reduce channels to 256, then 128.
            x = nin(x, unit=256, layer_name=scope+'_nin1')
            x = nin(x, unit=128, layer_name=scope+'_nin2')

            x = Global_Average_Pooling(x)
            x = flatten(x)
            x = linear(x, unit=10, layer_name=scope+'_linear1')
        return x
        