Architecture modified from: 3D Point Cloud Generative Adversarial Network Based on Tree Structured Graph Convolutions (TreeGAN).
https://github.com/seowok/TreeGAN 

In [None]:
import tensorflow as tf
import keras
import math
import numpy as np
from tensorflow.keras import layers
from keras.models import Sequential
from keras.layers import Conv1D, LeakyReLU, Dense
import numpy as np

In [None]:
class TreeGCN(keras.layers.Layer):
    """Tree-structured graph convolution layer (Keras port of TreeGAN's TreeGCN).

    The layer consumes ``tree``, a list holding one feature tensor per
    previous depth, and returns that list extended with the features for the
    next depth.

    Args:
        batch: Static batch size; used in the explicit reshapes in ``call``.
        depth: Index of this layer; ``tree`` must hold ``depth + 1`` tensors.
        features: Per-depth feature widths; this layer maps
            ``features[depth]`` -> ``features[depth + 1]``.
        degrees: Per-depth branching factors; ``degrees[depth]`` is the number
            of children each node is expanded into when ``upsample`` is True.
        support: Width multiplier of the hidden layer of the ``W_loop`` MLP.
        node: Number of nodes in ``tree[-1]`` (the input to this layer).
        upsample: If True, each node is branched into ``degree`` children.
        activation: If True, add the learned bias and apply LeakyReLU.
        kernel_initializer: Initializer for every kernel (the Dense layers and
            the branch weight). Defaults to Glorot-uniform.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, batch, depth, features, degrees, support=10, node=1,
                 upsample=False, activation=True,
                 kernel_initializer='glorot_uniform', **kwargs):
        # Initialise the Keras base class before assigning any attributes so
        # attribute tracking is set up.
        super(TreeGCN, self).__init__(**kwargs)
        self.batch = batch
        self.depth = depth
        self.in_feature = features[depth]
        self.out_feature = features[depth + 1]
        self.node = node
        self.degree = degrees[depth]
        self.upsample = upsample
        self.activation = activation
        self.kernel_initializer = kernel_initializer

        # One bias-free projection per ancestor depth, looked up by index in
        # ``call``. A plain list is used because ``Sequential`` would chain the
        # layers instead of letting us index them (nn.ModuleList semantics).
        # Input widths are inferred by each Dense on its first call.
        self.W_root = [
            Dense(self.out_feature, use_bias=False,
                  kernel_initializer=kernel_initializer)
            for _ in range(self.depth + 1)
        ]

        # Two bias-free linear layers: in_feature -> in_feature * support -> out_feature.
        self.W_loop = Sequential([
            Dense(self.in_feature * support, use_bias=False,
                  kernel_initializer=kernel_initializer),
            Dense(self.out_feature, use_bias=False,
                  kernel_initializer=kernel_initializer),
        ])

        self.leaky_relu = LeakyReLU(alpha=0.2)
        self.build_branch()

    def build_branch(self):
        """Create the branch weight (upsampling layers only) and the bias."""
        if self.upsample:
            # One (in_feature x in_feature * degree) matrix per node: expands
            # each node's features into ``degree`` child feature vectors.
            self.branch_input_shape = (self.node, self.in_feature,
                                       self.in_feature * self.degree)
            self.W_branch = self.add_weight(
                name='branch',
                shape=self.branch_input_shape,
                initializer=self.kernel_initializer,
                trainable=True)

        # Bias ~ Uniform(-1/sqrt(out_feature), 1/sqrt(out_feature)).
        stdv = 1. / math.sqrt(self.out_feature)
        self.bias = self.add_weight(
            name='bias',
            shape=(1, self.degree, self.out_feature),
            initializer=keras.initializers.RandomUniform(minval=-stdv, maxval=stdv),
            trainable=True)
        # gain -> relu used here in original code

    def call(self, tree):
        """Compute the next depth's features and append them to ``tree``.

        Args:
            tree: List of ``depth + 1`` tensors. Each ``tree[i]`` is assumed to
                be ``(batch, nodes_i, features[i])``; TODO confirm with callers.

        Returns:
            A new list: the elements of ``tree`` followed by the new tensor of
            shape ``(batch, node * degree, out_feature)`` when upsampling, else
            ``(batch, node, out_feature)``. The input list is not mutated.
        """
        # Sum of every ancestor's projection, each ancestor repeated so that
        # all of its descendants at the current depth receive its features.
        root = 0
        for inx in range(self.depth + 1):
            root_num = tree[inx].shape[1]
            repeat_num = self.node // root_num
            root_node = self.W_root[inx](tree[inx])
            # Tile along the feature axis and then reshape: each ancestor row is
            # repeated ``repeat_num`` times contiguously (torch .repeat + .view).
            root = root + tf.reshape(
                tf.tile(root_node, [1, 1, repeat_num]),
                (self.batch, -1, self.out_feature))

        if self.upsample:
            # Per-node matmul: (batch, node, in) x (node, in, in*degree)
            # -> (batch, node, in*degree).
            branch = tf.einsum('bni,nio->bno', tree[-1], self.W_branch)
            branch = self.leaky_relu(branch)
            branch = tf.reshape(
                branch, (self.batch, self.node * self.degree, self.in_feature))

            branch = self.W_loop(branch)
            # Every child inherits its parent's ancestor term.
            branch = tf.reshape(
                tf.tile(root, [1, 1, self.degree]),
                (self.batch, -1, self.out_feature)) + branch
        else:
            branch = self.W_loop(tree[-1])
            branch = root + branch

        if self.activation:
            branch = self.leaky_relu(
                branch + tf.tile(self.bias, [1, self.node, 1]))

        return list(tree) + [branch]


In [None]:
class Generator(keras.layers.Layer):
    """Stack of TreeGCN layers that grows a feature tree into a point cloud.

    Args:
        batch_size: Static batch size passed to every TreeGCN layer.
        features: Feature width per depth; ``len(features) - 1`` layers are built.
        degrees: Branching factor per layer; must have ``len(features) - 1`` items.
        support: Hidden-width multiplier forwarded to each TreeGCN.
    """

    def __init__(self, batch_size, features, degrees, support):
        # Initialise the Keras base class before assigning any attributes.
        super(Generator, self).__init__()
        self.batch_size = batch_size
        self.layer_num = len(features) - 1
        assert self.layer_num == len(degrees), "Number of features should be one more than number of degrees."
        self.pointcloud = None

        # Kept as a plain list and applied in order in ``call``. Each TreeGCN
        # takes and returns a *list* of tensors, which does not fit Sequential's
        # one-input/one-output-per-layer model, and ``Sequential.add`` accepts
        # a single layer (no name argument).
        gcn_layers = []
        vertex_num = 1
        for inx in range(self.layer_num):
            gcn_layers.append(
                TreeGCN(self.batch_size, inx, features, degrees,
                        support=support, node=vertex_num, upsample=True,
                        # Only the final layer skips bias + activation.
                        activation=(inx != self.layer_num - 1)))
            vertex_num = int(vertex_num * degrees[inx])
        self.gcn = gcn_layers

    def call(self, tree):
        """Run the TreeGCN stack and return the last tensor of the tree.

        Args:
            tree: List of input tensors for depth 0; presumably a single
                ``(batch, 1, features[0])`` latent tensor. TODO confirm.

        Returns:
            The final TreeGCN output (also cached in ``self.pointcloud``).
        """
        feat = tree
        for gcn in self.gcn:
            feat = gcn(feat)

        self.pointcloud = feat[-1]

        return self.pointcloud

    def forward(self, tree):
        """PyTorch-style alias for calling the layer; kept for compatibility."""
        return self(tree)

    def getPointcloud(self):
        """Return ``self.pointcloud[-1]``: the last entry along the first axis
        of the most recent output (presumably the last sample in the batch)."""
        return self.pointcloud[-1]

In [None]:
# TODO(Noah): Can play around with the discriminator architecture
class Discriminator(tf.keras.Model):
    """Point-cloud critic.

    Architecture: pointwise Conv1D layers, a global max over points, and a
    dense head producing one score per sample.

    Args:
        batch_size: Stored on the instance; not used by the computation.
        features: Channel widths. ``len(features) - 1`` pointwise Conv1D layers
            are built with ``features[1:]`` filters. ``features[0]`` is
            presumably the input point dimension; it is not needed here because
            Conv1D infers its input channels.
    """

    def __init__(self, batch_size, features):
        # Initialise the Keras base class before assigning any attributes.
        super(Discriminator, self).__init__()
        self.batch_size = batch_size
        self.layer_num = len(features) - 1

        # kernel_size=1 makes each Conv1D a linear map applied independently
        # (with shared weights) to every point.
        self.fc_layers = [
            tf.keras.layers.Conv1D(filters=features[inx + 1], kernel_size=1, strides=1)
            for inx in range(self.layer_num)
        ]

        self.leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.final_layer = keras.Sequential([
            tf.keras.layers.Dense(features[-1]),
            tf.keras.layers.Dense(features[-2]),
            tf.keras.layers.Dense(features[-2]),
            tf.keras.layers.Dense(1),
        ])

    def call(self, inputs, training=False):
        """Score a batch of point clouds.

        Args:
            inputs: Point clouds, assumed ``(batch, num_points, channels)``.
                Keras Conv1D is channels-last by default, so the channels-first
                transpose used in the PyTorch code is not needed. TODO confirm
                with callers.
            training: Unused; kept for the Keras ``call`` signature.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        feat = inputs
        for conv in self.fc_layers:
            feat = self.leaky_relu(conv(feat))

        # Global max over the point axis. This equals
        # MaxPool1D(pool_size=num_points) followed by squeezing that axis, but
        # needs no per-call layer and works with a dynamic number of points.
        out = tf.reduce_max(feat, axis=1)
        out = self.final_layer(out)

        return out