# In [1]:
import numpy as np
import tensorflow as tf

# In [2]:
def leakyrelu(x, alpha=0.01):
    """Return the element-wise maximum of ``x`` and ``alpha * x``.

    For ``0 <= alpha <= 1`` this is the leaky ReLU: ``x`` is passed through
    where it is non-negative and scaled by ``alpha`` where it is negative.
    """
    scaled = alpha * x
    return tf.maximum(x, scaled)
    

class StackGAN():
    """Graph builders for a text-conditioned GAN (TensorFlow 1.x graph API).

    Every method adds ops and variables to the current default graph and
    returns the resulting tensor(s); nothing is executed here.

    NOTE(review): all weights are created through ``tf.get_variable`` with
    fixed names and no variable scope, and ``text_embedding`` is used by both
    ``get_generator`` and ``call_discriminator``.  Building both in the same
    scope, or calling one builder twice (e.g. discriminator on real and fake
    images), will presumably fail with a "variable already exists" error
    unless the caller wraps the calls in ``tf.variable_scope`` with suitable
    ``reuse`` settings -- confirm the intended usage.

    NOTE(review): ``tf.layers.batch_normalization`` is called with its
    default ``training`` flag, i.e. in inference mode -- confirm whether a
    training switch is needed.
    """

    def __init__(self, image_dims):
        """Store the hyper-parameters used by the graph builders.

        Args:
            image_dims: side length of the (square) generated image.  It is
                integer-divided by 2, 4, 8 and 16 to get the intermediate
                resolutions; presumably it should be a multiple of 16 so the
                four stride-2 convolutions in ``image_embedding`` end at
                ``s16`` x ``s16`` -- TODO confirm.
        """
        self.s = image_dims
        self.s2 = self.s // 2
        self.s4 = self.s // 4
        self.s8 = self.s // 8
        self.s16 = self.s // 16
        self.gfdim = 128
        # Base channel count of the discriminator.  It was read by
        # ``discriminator``/``image_embedding`` but never defined.
        # NOTE(review): 64 is an assumed value -- confirm.
        self.dfdim = 64
        self.initializer = tf.truncated_normal_initializer(stddev=0.02)
        self.e_dim = 128
        self.batch_size = 64
        self.noise_dim = 100

    def get_conv(self, tensor, shape, initializer, name, padding='SAME',
                 strides=(1, 1, 1, 1), isrelu=True, isbn=True):
        """Apply a 2-D convolution, optionally followed by batch norm and ReLU.

        Args:
            tensor: input tensor for ``tf.nn.conv2d``.
            shape: filter shape ``[height, width, in_channels, out_channels]``.
            initializer: initializer for the filter variable.
            name: name of the filter variable (batch norm uses an
                auto-generated layer name).
            padding: padding mode passed to ``tf.nn.conv2d``.
            strides: four-element stride sequence passed to ``tf.nn.conv2d``.
            isrelu: apply ``tf.nn.relu`` to the result when true.
            isbn: apply ``tf.layers.batch_normalization`` when true.

        Returns:
            The output tensor of the last applied op.
        """
        W = tf.get_variable(name=name, shape=shape, initializer=initializer)
        conv = tf.nn.conv2d(tensor, filter=W, strides=list(strides),
                            padding=padding)
        if isbn:
            conv = tf.layers.batch_normalization(conv)
        if isrelu:
            conv = tf.nn.relu(conv)
        return conv

    def generator(self, z):
        """Build the generator: latent vector -> ``s`` x ``s`` x 3 image.

        Args:
            z: latent input; flattened to 2-D before the fully connected
                layer, so its flattened size must be statically known.

        Returns:
            A tanh-activated tensor reshaped/upsampled to
            ``[-1, self.s, self.s, 3]``.
        """
        z = tf.contrib.layers.flatten(inputs=z)
        # The FC output is reshaped to s16 x s16 x (gfdim * 8), so it must
        # produce exactly that many features.
        fc_units = (self.s16 ** 2) * self.gfdim * 8
        W1 = tf.get_variable(
            name='W1', shape=[int(z.shape[1]), fc_units],
            initializer=tf.random_normal_initializer(stddev=0.2))
        b1 = tf.get_variable(
            name='b1', shape=[fc_units],
            initializer=tf.constant_initializer(0.0))
        fc1 = tf.matmul(z, W1) + b1
        bn1 = tf.layers.batch_normalization(fc1, axis=1)
        rs1 = tf.reshape(
            bn1, shape=[-1, self.s16, self.s16, self.gfdim * 8])

        initializer = self.initializer

        # Residual branch at s16 x s16: 1x1 bottleneck, 3x3, then 3x3 back to
        # gfdim * 8 channels (no ReLU before the skip addition).
        res1 = self.get_conv(
            rs1, [1, 1, self.gfdim * 8, self.gfdim * 2], initializer, 'Wc1')
        res1 = self.get_conv(
            res1, [3, 3, self.gfdim * 2, self.gfdim * 2], initializer, 'Wc2')
        res1 = self.get_conv(
            res1, [3, 3, self.gfdim * 2, self.gfdim * 8], initializer, 'Wc3',
            isrelu=False)
        rel3 = tf.nn.relu(rs1 + res1)

        # Upsample to s8 x s8 and reduce to gfdim * 4 channels.
        convT1 = tf.image.resize_nearest_neighbor(
            images=rel3, size=[self.s8, self.s8])
        bn5 = self.get_conv(
            convT1, [3, 3, self.gfdim * 8, self.gfdim * 4], initializer,
            'Wc4', isrelu=False)

        # Second residual branch at s8 x s8.
        conv5 = self.get_conv(
            bn5, [1, 1, self.gfdim * 4, self.gfdim], initializer, 'Wc5')
        conv6 = self.get_conv(
            conv5, [3, 3, self.gfdim, self.gfdim], initializer, 'Wc6')
        conv7 = self.get_conv(
            conv6, [3, 3, self.gfdim, self.gfdim * 4], initializer, 'Wc7',
            isrelu=False)
        sum2 = tf.nn.relu(bn5 + conv7)

        # Three upsample + conv stages: s8 -> s4 -> s2 -> s.
        convT2 = tf.image.resize_nearest_neighbor(
            images=sum2, size=[self.s4, self.s4])
        conv8 = self.get_conv(
            convT2, [3, 3, self.gfdim * 4, self.gfdim * 2], initializer,
            'Wc8')
        convT3 = tf.image.resize_nearest_neighbor(
            images=conv8, size=[self.s2, self.s2])
        conv9 = self.get_conv(
            convT3, [3, 3, self.gfdim * 2, self.gfdim], initializer, 'Wc9')
        convT4 = tf.image.resize_nearest_neighbor(
            images=conv9, size=[self.s, self.s])
        # Output layer stays raw: a ReLU before tanh would restrict the
        # image values to [0, 1).
        conv10 = self.get_conv(
            convT4, [3, 3, self.gfdim, 3], initializer, 'Wc10',
            isrelu=False, isbn=False)
        return tf.nn.tanh(conv10)

    def discriminator(self, x):
        """Map joint image/text features to discriminator logits.

        Args:
            x: 4-D feature tensor with a statically known channel count;
                ``call_discriminator`` passes an ``s16`` x ``s16`` map.

        Returns:
            The un-activated output of an ``s16`` x ``s16`` convolution with
            stride ``s16`` and a single output channel.
        """
        initializer = self.initializer
        conv1 = self.get_conv(
            x, [1, 1, int(x.shape[3]), self.dfdim * 8], initializer, 'Wd1',
            isrelu=False)
        r1 = leakyrelu(conv1, 0.2)
        # Raw logits: no batch norm and no ReLU, which would clamp every
        # score to be non-negative.
        logits = self.get_conv(
            r1, [self.s16, self.s16, self.dfdim * 8, 1], initializer, 'Wd2',
            strides=[1, self.s16, self.s16, 1], isrelu=False, isbn=False)
        return logits

    def image_embedding(self, x):
        """Encode an image into a ``dfdim * 8``-channel feature map.

        Four 4x4 stride-2 convolutions downsample the input, then a residual
        branch (1x1, 3x3, 3x3) is added back and passed through leaky ReLU.

        Args:
            x: 4-D image tensor with a statically known channel count.

        Returns:
            The leaky-ReLU-activated sum of the downsampled features and the
            residual branch.
        """
        initializer = self.initializer
        down = [1, 2, 2, 1]

        conv1 = self.get_conv(
            x, [4, 4, int(x.shape[3]), self.dfdim], initializer, 'We1',
            strides=down, isbn=False, isrelu=False)
        r1 = leakyrelu(conv1, 0.2)
        conv2 = self.get_conv(
            r1, [4, 4, self.dfdim, self.dfdim * 2], initializer, 'We2',
            strides=down, isrelu=False)
        r2 = leakyrelu(conv2, 0.2)
        conv3 = self.get_conv(
            r2, [4, 4, self.dfdim * 2, self.dfdim * 4], initializer, 'We3',
            strides=down, isrelu=False)
        conv4 = self.get_conv(
            conv3, [4, 4, self.dfdim * 4, self.dfdim * 8], initializer, 'We4',
            strides=down, isrelu=False)

        # Residual branch: each stage consumes the previous stage's
        # activation so the channel counts line up.
        conv5 = self.get_conv(
            conv4, [1, 1, self.dfdim * 8, self.dfdim * 2], initializer, 'We5',
            isrelu=False)
        r3 = leakyrelu(conv5, 0.2)
        conv6 = self.get_conv(
            r3, [3, 3, self.dfdim * 2, self.dfdim * 2], initializer, 'We6',
            isrelu=False)
        r4 = leakyrelu(conv6, 0.2)
        conv7 = self.get_conv(
            r4, [3, 3, self.dfdim * 2, self.dfdim * 8], initializer, 'We7',
            isrelu=False)
        return leakyrelu(conv4 + conv7, 0.2)

    def text_embedding(self, embedding):
        """Project a text embedding to ``e_dim`` features with leaky ReLU.

        Args:
            embedding: 2-D tensor whose second dimension is statically known.

        Returns:
            Tensor with ``self.e_dim`` features per row.
        """
        output_shape = self.e_dim
        W1 = tf.get_variable(
            name='Wt1', shape=[int(embedding.shape[1]), output_shape],
            initializer=tf.random_normal_initializer(stddev=0.2))
        b1 = tf.get_variable(
            name='bt1', shape=[output_shape],
            initializer=tf.constant_initializer(0.0))
        fc1 = tf.matmul(embedding, W1) + b1
        return leakyrelu(fc1, 0.2)

    def cond_aug(self, context_embed):
        """Produce the two halves used for conditioning augmentation.

        Args:
            context_embed: 2-D tensor whose second dimension is statically
                known.

        Returns:
            ``(mean, sigma)``: the first and second ``e_dim`` columns of a
            leaky-ReLU-activated fully connected layer.  ``sample_cont``
            exponentiates ``sigma``, i.e. treats it as a log standard
            deviation.
        """
        output_shape = self.e_dim * 2
        W1 = tf.get_variable(
            name='Wca1', shape=[int(context_embed.shape[1]), output_shape],
            initializer=tf.random_normal_initializer(stddev=0.2))
        b1 = tf.get_variable(
            name='bca1', shape=[output_shape],
            initializer=tf.constant_initializer(0.0))
        fc1 = tf.matmul(context_embed, W1) + b1
        r1 = leakyrelu(fc1, 0.2)
        mean = r1[:, :self.e_dim]
        sigma = r1[:, self.e_dim:]
        return mean, sigma

    def call_discriminator(self, image, embed):
        """Score an (image, text embedding) pair with the discriminator.

        The projected text embedding is tiled to ``s16`` x ``s16`` and
        concatenated with the image features along the channel axis.

        Args:
            image: image tensor passed to ``image_embedding``.
            embed: text embedding passed to ``text_embedding``.

        Returns:
            The logits returned by ``discriminator``.
        """
        X = self.image_embedding(image)
        t_embed = self.text_embedding(embed)
        t_embed_exp = tf.expand_dims(tf.expand_dims(t_embed, axis=1), axis=1)
        t_embed_exp = tf.tile(t_embed_exp, [1, self.s16, self.s16, 1])
        # TF >= 1.0 argument order: values first, then axis.
        xt = tf.concat([X, t_embed_exp], axis=3)
        return self.discriminator(xt)

    def sample_cont(self, embed):
        """Sample a conditioning vector via the reparameterization trick.

        Args:
            embed: tensor passed to ``cond_aug``.

        Returns:
            ``(c, kl_loss)`` where ``c = mean + exp(log_sigma) * epsilon``
            with ``epsilon`` drawn from a truncated normal, and ``kl_loss``
            is the KL divergence between N(mean, exp(log_sigma)^2) and
            N(0, 1), averaged over all elements.
        """
        mean, log_sigma = self.cond_aug(embed)
        epsilon = tf.truncated_normal(tf.shape(mean))
        std_dev = tf.exp(log_sigma)
        c = mean + (std_dev * epsilon)

        # Per-element KL(N(mu, sigma^2) || N(0, 1))
        #   = -log(sigma) + (sigma^2 + mu^2 - 1) / 2
        kl = -log_sigma + 0.5 * (
            -1.0 + tf.exp(2.0 * log_sigma) + tf.square(mean))
        kl_loss = tf.reduce_mean(kl)

        return c, kl_loss

    def get_generator(self, embed):
        """Build the generator output for a batch of text embeddings.

        Args:
            embed: text embedding tensor; its batch size must equal
                ``self.batch_size`` because the noise is created with that
                fixed batch size.

        Returns:
            The image tensor returned by ``generator``.  The KL loss from
            ``sample_cont`` is discarded here.
        """
        t_embed = self.text_embedding(embed)
        noise = tf.random_normal([self.batch_size, self.noise_dim])
        c, _ = self.sample_cont(t_embed)
        # TF >= 1.0 argument order: values first, then axis.
        z = tf.concat([c, noise], axis=1)
        return self.generator(z)