In [1]:
import tensorflow as tf
import gin

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Ethane ('CC') parsed from SMILES and passed through add_hydrogen; the
# resulting molecule unpacks into the atom array and the adjacency map used below.
mol = gin.deterministic.hydrogen.add_hydrogen(
    gin.i_o.from_smiles.to_mol('CC'))
atoms, adjacency_map = mol

In [4]:
# Evaluate alkane_energy at a dummy geometry with every atom at the origin.
# Coordinates are laid out as (1, n_atoms, 3); n_atoms is read from `atoms`
# rather than hard-coded (it was 8, which only fits ethane with hydrogens),
# so this cell keeps working if the SMILES in the cell above changes.
gin.deterministic.mm.alkane_energy.alkane_energy(
    atoms,
    adjacency_map,
    tf.zeros((1, tf.shape(atoms)[0], 3)))

(<tf.Tensor: id=2170, shape=(1, 7), dtype=float32, numpy=
 array([[2988.0952, 1682.8805, 1682.8805, 1682.8805, 1682.8805, 1682.8805,
         1682.8805]], dtype=float32)>,
 <tf.Tensor: id=2205, shape=(1, 12), dtype=float32, numpy=
 array([[715.7538, 715.7538, 589.9715, 715.7538, 589.9715, 589.9715,
         715.7538, 715.7538, 589.9715, 715.7538, 589.9715, 589.9715]],
       dtype=float32)>,
 <tf.Tensor: id=2243, shape=(1, 9), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: id=2279, shape=(1, 0), dtype=float32, numpy=array([], shape=(1, 0), dtype=float32)>)

In [5]:
class GraphConv(tf.keras.Model):
    """ Spectral graph convolution (Kipf & Welling).

    https://arxiv.org/pdf/1609.02907.pdf

    Applies three stacked propagation steps ``H <- tanh(A_hat @ H @ W + b)``
    starting from ``H = one_hot(atoms, depth)``, where
    ``A_hat = D~^{-1/2} (A + I) D~^{-1/2}`` is the symmetrically normalised
    adjacency matrix with self-loops (Eq. 2 of the paper) and ``D~`` is the
    diagonal degree matrix of ``A + I``.

    Parameters
    ----------
    units : int
        Output width of each of the three dense layers.
    depth : int
        Number of classes used to one-hot encode ``atoms``.
    """
    def __init__(self, units=64, depth=10):
        super(GraphConv, self).__init__()
        # One dense layer per propagation step; attribute names are kept
        # stable so existing weights remain addressable.
        self.d0 = tf.keras.layers.Dense(units=units, activation='tanh')
        self.d1 = tf.keras.layers.Dense(units=units, activation='tanh')
        self.d2 = tf.keras.layers.Dense(units=units, activation='tanh')
        self.depth = depth

    @staticmethod
    def _normalized_adjacency(adjacency_map):
        """ Return ``D~^{-1/2} (A + I) D~^{-1/2}`` with ``A = M + M^T``. """
        # NOTE(review): adding the transpose assumes each bond is stored only
        # once in ``adjacency_map`` (e.g. upper triangle) — confirm with gin.
        a = tf.math.add(adjacency_map, tf.transpose(adjacency_map))
        # Self-loops so every atom also aggregates its own features.
        a_tilde = tf.math.add(a, tf.eye(tf.shape(a)[0], dtype=a.dtype))
        # a_tilde is symmetric, so its column sums are the node degrees; with
        # non-negative bond weights each degree is >= 1, keeping rsqrt finite.
        d_inv_sqrt = tf.math.rsqrt(tf.reduce_sum(a_tilde, axis=0))
        # Normalise a_tilde (not a): same result as diag(d) @ a_tilde @ diag(d),
        # computed by broadcasting instead of two dense diagonal matmuls.
        return d_inv_sqrt[:, tf.newaxis] * a_tilde * d_inv_sqrt[tf.newaxis, :]

    def call(self, atoms, adjacency_map):
        """ Run three graph-convolution steps over the molecule.

        Parameters
        ----------
        atoms : integer tensor
            Atom-type indices, one-hot encoded with ``self.depth`` classes;
            presumably shape ``(n_atoms,)`` — TODO confirm with callers.
        adjacency_map : float tensor, shape ``(n_atoms, n_atoms)``
            Square bond matrix of the molecule.

        Returns
        -------
        tf.Tensor
            Node embeddings, ``(n_atoms, units)`` for 1-D ``atoms``.
        """
        a_hat = self._normalized_adjacency(adjacency_map)
        h = tf.one_hot(atoms, self.depth)
        # Same order as before: d0, then d1, then d2, each after propagation.
        for dense in (self.d0, self.d1, self.d2):
            h = dense(tf.matmul(a_hat, h))
        return h

In [9]:
class GraphFlow(tf.keras.Model):
    """ Graph Flow model.

    Work in progress: ``__init__`` builds the GRUs and one pair of dense
    layers per ``(degree, existing_degree)``; ``gen_child_xyz`` turns the
    local context of a parent atom into a weight tensor. No ``call`` is
    defined yet.

    Parameters
    ----------
    units : int
        Hidden size of both GRUs and of every ``d_0_*`` dense layer.
    depth : int
        Leading size of the tensor returned by ``gen_child_xyz``: each
        ``d_1_<degree>_<existing_degree>`` layer emits
        ``depth * (3 * (degree - existing_degree)) ** 2`` values.
    max_degree : int
        Layers are created for every ``1 <= degree <= max_degree`` and
        ``0 <= existing_degree < degree``.
    """
    def __init__(
            self,
            units=64,
            depth=3,
            max_degree=4):

        super(GraphFlow, self).__init__()
        self.units = units
        self.depth = depth
        self.max_degree = max_degree

        self.gru_xyz = tf.keras.layers.GRU(units=units)
        # Keras spells this flag ``return_state`` (singular); an unknown
        # keyword such as ``return_states`` is rejected at construction.
        self.gru_graph = tf.keras.layers.GRU(
            units=units,
            return_sequences=True,
            return_state=True)

        for degree in range(1, max_degree + 1):
            for existing_degree in range(degree):
                dim = 3 * (degree - existing_degree)
                # Assigned through setattr so Keras tracks each layer under a
                # stable attribute name that gen_child_xyz can look up.
                # d_1 is created before d_0, as originally, so auto-generated
                # layer names are unchanged.
                setattr(
                    self,
                    self._layer_name('d_1', degree, existing_degree),
                    tf.keras.layers.Dense(depth * dim ** 2))
                setattr(
                    self,
                    self._layer_name('d_0', degree, existing_degree),
                    tf.keras.layers.Dense(units, activation='tanh'))

    @staticmethod
    def _layer_name(prefix, degree, existing_degree):
        """ Attribute name of the dense layer for one (degree, existing_degree). """
        return '{}_{}_{}'.format(prefix, degree, existing_degree)

    def gen_child_xyz(
            self,
            z,
            degree,
            existing_degree,
            parent_xyz, # (3, )
            parnet_h_gru_graph, # (d_h, )
            parnet_h_gru_xyz, # (d_h, )
            other_child_xyz, # (n_child, 3)
            other_child_h_gru_graph, # (n_child, d_h),
            other_child_h_gru_xyz, # (n_child, d_h),
            global_h_gru,
            ):
        """ Build the weight tensor for placing new children of a parent atom.

        The parent's coordinates and GRU states, the other children's
        coordinates (absolute and relative to the parent), their GRU states
        and ``global_h_gru`` are concatenated into a single row, passed
        through ``d_0_<degree>_<existing_degree>`` and then
        ``d_1_<degree>_<existing_degree>``, and reshaped.

        Parameter names (including the ``parnet`` spelling) are kept as
        declared so keyword callers are unaffected.

        Parameters
        ----------
        z
            Not used yet. TODO: presumably the latent sample ``w`` should act
            on — confirm and implement.
        degree, existing_degree : int
            Select the dense-layer pair; require ``1 <= degree <= max_degree``
            and ``0 <= existing_degree < degree``.
        global_h_gru : tensor
            Assumed 1-D so it can be concatenated on axis 0 — TODO confirm.

        Returns
        -------
        tf.Tensor
            Shape ``(depth, 3 * k, 3 * k)`` with ``k = degree - existing_degree``.
        """
        h = tf.reshape(
                tf.concat(
                    [
                        parent_xyz,
                        parnet_h_gru_graph,
                        parnet_h_gru_xyz,
                        tf.reshape(other_child_xyz, [-1]),
                        # Broadcasts (n_child, 3) - (3,): offsets from the parent.
                        tf.reshape(other_child_xyz - parent_xyz, [-1]),
                        tf.reshape(other_child_h_gru_graph, [-1]),
                        tf.reshape(other_child_h_gru_xyz, [-1]),
                        global_h_gru
                    ],
                    axis=0),
                # Dense layers expect a batch axis.
                [1, -1])

        dim = 3 * (degree - existing_degree)
        d_0 = getattr(self, self._layer_name('d_0', degree, existing_degree))
        d_1 = getattr(self, self._layer_name('d_1', degree, existing_degree))

        w = tf.reshape(d_1(d_0(h)), [-1, dim, dim])
        return w

In [10]:
# Instantiate the (work-in-progress) flow model with its default sizes spelled out.
graph_flow = GraphFlow(units=64, depth=3, max_degree=4)

'dense'