### Milestone

1. Run Graph CNN with Tox21 (DeepChem Tutorial)
2. Change input dataset to DUD-E + PDBBind
3. Compare DeepChem GraphCNN (tensorflow) and pyGCN (pytorch) and migrate.
4. Change pyGCN and pyGAT to GAGAN 

The source of this demo is [here](https://deepchem.io/docs/notebooks/graph_convolutional_networks_for_tox21.html)

In [2]:
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import ast
import collections.abc
import os

import deepchem as dc
import numpy as np
import pandas as pd
import tensorflow as tf





## Part I. Building Graph from Molecule (SMILES)

### Original Code

In [None]:
# Load the Tox21 dataset with the graph-convolution featurizer.
# The loader's result is unpacked into the task names, a (train, valid, test)
# triple of datasets, and the transformers object.
tox21_tasks, tox21_datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv')
train_dataset, valid_dataset, test_dataset = tox21_datasets

### Analyzing Output of Code

In [None]:
# Show the first entry of the features (X), labels (y) and weights (w) of the test split.
print('Sample of X: ' + str(test_dataset.X[0]))
print('Sample of y: ' + str(test_dataset.y[0]))
print('Sample of w: ' + str(test_dataset.w[0]))


X = test_dataset.X
y = test_dataset.y
w = test_dataset.w

print('X.shape: ' + str(X.shape))
print('y.shape: ' + str(y.shape))
# NOTE(review): w is presumably the per-sample, per-task weight array (not a
# transformed copy of y) -- confirm against the DeepChem Dataset documentation.
print('w.shape: ' + str(w.shape))


In [None]:
# Only the last expression of a cell is rendered, so the task names (the label
# columns of the dataset) are printed explicitly.
print(tox21_tasks)
# It has BalancingTransformer - it balances positive and negative examples.
transformers

In [18]:
# Try it with PDBBind

# NOTE: this rebinds `transformers`, replacing the Tox21 transformers loaded earlier.
pdbbind_tasks, pdbbind_datasets, transformers = dc.molnet.load_pdbbind_grid(featurizer='GraphConv', subset='full')
train_dataset_p, valid_dataset_p, test_dataset_p = pdbbind_datasets

Loading dataset from disk.
Loading dataset from disk.
Loading dataset from disk.


### My Code

I extracted the code that I would re-use in my own model. Download the data below and untar it.

In [6]:
# Comparison to deepchem load_tox21 module: read the raw Tox21 table directly.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that file exists.

tox21_df = pd.read_csv('/home/joanna/mindslab/drug_discovery/data/tox21.csv')
tox21_df.head(10)

Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
0,0.0,0.0,1.0,,,0.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX3021,CCOc1ccc2nc(S(N)(=O)=O)sc2c1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3020,CCN1C(=O)NC(c2ccccc2)C1=O
2,,,,,,,,0.0,,0.0,,,TOX3024,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3027,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX20800,CC(O)(P(=O)(O)O)P(=O)(O)O
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,TOX5110,CC(C)(C)OOC(C)(C)CCC(C)(C)OOC(C)(C)C
6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX6619,O=S(=O)(Cl)c1ccccc1
7,0.0,,0.0,,1.0,,,1.0,0.0,1.0,0.0,1.0,TOX25232,O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1
8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,TOX22514,OC[C@H](O)[C@@H](O)[C@H](O)CO
9,,,,,,,,0.0,,0.0,,,TOX22517,CCCCCCCC(=O)[O-].CCCCCCCC(=O)[O-].[Zn+2]


In [None]:
# Read the PDBBind grid table and preview its first ten rows.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that file exists.
pdbbind_df = pd.read_csv('/home/joanna/mindslab/drug_discovery/data/pdbbind_grid_full.csv')
pdbbind_df.head(10)

In [7]:
# Read the PDBBind table that carries the raw protein/ligand file contents per complex.
# NOTE(review): the variable is named "full" but the file read is 'pdbbind_core_df.csv' -- confirm which subset is intended.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that file exists.
pdbbind_full_df = pd.read_csv('/home/joanna/mindslab/drug_discovery/data/pdbbind_core_df.csv')
pdbbind_full_df.head(10)

Unnamed: 0,pdb_id,smiles,complex_id,protein_pdb,ligand_pdb,ligand_mol2,label
0,2d3u,CC1CCCCC1S(O)(O)NC1CC(C2CCC(CN)CC2)SC1C(O)O,2d3uCC1CCCCC1S(O)(O)NC1CC(C2CCC(CN)CC2)SC1C(O)O,"['HEADER 2D3U PROTEIN\n', 'COMPND 2D3U P...","['COMPND 2d3u ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",6.92
1,3cyx,CC(C)(C)NC(O)C1CC2CCCCC2C[NH+]1CC(O)C(CC1CCCCC...,3cyxCC(C)(C)NC(O)C1CC2CCCCC2C[NH+]1CC(O)C(CC1C...,"['HEADER 3CYX PROTEIN\n', 'COMPND 3CYX P...","['COMPND 3cyx ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",8.0
2,3uo4,OC(O)C1CCC(NC2NCCC(NC3CCCCC3C3CCCCC3)N2)CC1,3uo4OC(O)C1CCC(NC2NCCC(NC3CCCCC3C3CCCCC3)N2)CC1,"['HEADER 3UO4 PROTEIN\n', 'COMPND 3UO4 P...","['COMPND 3uo4 ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Fri Aug 2...",6.52
3,1p1q,CC1ONC(O)C1CC([NH3+])C(O)O,1p1qCC1ONC(O)C1CC([NH3+])C(O)O,"['HEADER 1P1Q PROTEIN\n', 'COMPND 1P1Q P...","['COMPND 1p1q ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",4.89
4,3ag9,NC(O)C(CCC[NH2+]C([NH3+])[NH3+])NC(O)C(CCC[NH2...,3ag9NC(O)C(CCC[NH2+]C([NH3+])[NH3+])NC(O)C(CCC...,"['HEADER 3AG9 PROTEIN\n', 'COMPND 3AG9 P...","['COMPND 3ag9 ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",8.05
5,2wtv,OC(O)C1CCC(NC2NCC3CNC(C4C(F)CCCC4F)C4CC(Cl)CCC...,2wtvOC(O)C1CCC(NC2NCC3CNC(C4C(F)CCCC4F)C4CC(Cl...,"['HEADER 2WTV PROTEIN\n', 'COMPND 2WTV P...","['COMPND 2wtv ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",8.74
6,3dxg,OC1C(O)C(N2CCC(O)NC2O)OC1CO[PH](O)(O)O,3dxgOC1C(O)C(N2CCC(O)NC2O)OC1CO[PH](O)(O)O,"['HEADER 3DXG PROTEIN\n', 'COMPND 3DXG P...","['COMPND 3dxg ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",2.4
7,4tmn,CC(C)CC(N[PH](O)(O)C(CC1CCCCC1)NC(O)OCC1CCCCC1...,4tmnCC(C)CC(N[PH](O)(O)C(CC1CCCCC1)NC(O)OCC1CC...,"['HEADER 4TMN PROTEIN\n', 'COMPND 4TMN P...","['COMPND 4tmn ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",10.17
8,2zcq,O[PH](O)(O)C(CCCC1CCCC(OC2CCCCC2)C1)S(O)(O)O,2zcqO[PH](O)(O)C(CCCC1CCCC(OC2CCCCC2)C1)S(O)(O)O,"['HEADER 2ZCQ PROTEIN\n', 'COMPND 2ZCQ P...","['COMPND 2zcq ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",8.82
9,1q8t,CC([NH3+])C1CCC(C(O)NC2CCNCC2)CC1,1q8tCC([NH3+])C1CCC(C(O)NC2CCNCC2)CC1,"['HEADER 1Q8T PROTEIN\n', 'COMPND 1Q8T P...","['COMPND 1q8t ligand \n', 'AUTHOR GENERA...","['### \n', '### Created by X-TOOL on Thu Aug 2...",4.76


In [11]:
# The 'protein_pdb' cell stores the repr of a Python list of PDB lines.
# ast.literal_eval parses that literal without executing arbitrary code from the CSV, unlike eval.
example_protein = ast.literal_eval(pdbbind_full_df.loc[0, 'protein_pdb'])

In [17]:
# Show the tail of the parsed record list (at most its last 10000 entries).
example_protein[-10000:]

['ATOM    703  NZ  LYS A  90      10.763  71.347  49.866  1.00  0.00           N1+\n',
 'ATOM    704  N   LEU A  91      13.811  64.043  48.251  1.00  0.00           N  \n',
 'ATOM    705  CA  LEU A  91      13.544  62.608  48.377  1.00  0.00           C  \n',
 'ATOM    706  C   LEU A  91      14.802  61.753  48.246  1.00  0.00           C  \n',
 'ATOM    707  O   LEU A  91      14.752  60.533  48.407  1.00  0.00           O  \n',
 'ATOM    708  CB  LEU A  91      12.525  62.163  47.330  1.00  0.00           C  \n',
 'ATOM    709  CG  LEU A  91      11.081  62.617  47.551  1.00  0.00           C  \n',
 'ATOM    710  CD1 LEU A  91      10.261  62.285  46.321  1.00  0.00           C  \n',
 'ATOM    711  CD2 LEU A  91      10.503  61.935  48.792  1.00  0.00           C  \n',
 'ATOM    712  N   THR A  92      15.929  62.401  47.966  1.00  0.00           N  \n',
 'ATOM    713  CA  THR A  92      17.204  61.710  47.800  1.00  0.00           C  \n',
 'ATOM    714  C   THR A  92      17.913  6

In [None]:
# Explore rdkit: parse one SMILES string and list its atoms.
from rdkit import Chem
from rdkit.Chem import rdmolfiles
from rdkit.Chem import rdmolops


# SMILES of the first row (TOX3021) of the Tox21 table previewed above.
test_smiles_seq = 'CCOc1ccc2nc(S(N)(=O)=O)sc2c1'

mol = Chem.MolFromSmiles(test_smiles_seq)
# Atom count, then the atomic number and the element symbol of every atom.
print(mol.GetNumAtoms())
print([atom.GetAtomicNum() for atom in mol.GetAtoms()])
print([atom.GetSymbol() for atom in mol.GetAtoms()])

In [None]:
# TODO change atom list to all atoms
# Element symbols recognised by the atom featurizer. The final 'Unknown' entry is
# the catch-all bucket that one_of_k_encoding_unk falls back to. Every symbol must
# appear exactly once: a repeated symbol adds a redundant feature column that is
# switched on together with its first occurrence.
possible_atom_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
    'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
    'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg',
    'Pb', 'Unknown']
# Allowed values for the one-hot encoded hydrogen count and implicit valence.
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]

# Allowed atom degrees (number of bonded neighbours) for the one-hot encoding.
possible_degree_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Hybridization states that get their own one-hot column; the last entry doubles
# as the fallback when one_of_k_encoding_unk sees any other state.
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
# Not referenced by get_atom_features, which appends the raw radical-electron count instead.
possible_number_radical_e_list = [0, 1, 2]
# CIP labels used when chirality features are requested.
possible_chirality_list = ['R', 'S']

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Returns a list of booleans, one per element of ``allowable_set``, that is
    True where the element equals ``x``. Raises ``Exception`` when ``x`` is not
    a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(
        x, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element.

    Returns a list of booleans, one per element of ``allowable_set``, that is
    True where the element equals ``x`` (or the last element when ``x`` is not
    in the set).
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def get_atom_features(atom, bool_id_feat=False, explicit_H=False, use_chirality=False):
    """
    From deepchem.feat.graph_features

    Build the feature vector of a single RDKit atom.

    Parameters
    ----------
    atom:
        Atom object exposing the RDKit accessors used below (GetSymbol, GetDegree, ...).
    bool_id_feat: bool
        When True, return a one-element array holding DeepChem's ``atom_to_id``
        value instead of the full feature vector.
    explicit_H: bool
        When False, append a one-hot encoding of the total hydrogen count.
    use_chirality: bool
        When True, append a one-hot encoding of the '_CIPCode' property
        followed by the '_ChiralityPossible' flag.

    Returns
    -------
    np.ndarray
        1-D array: one-hot symbol, degree and implicit valence, the raw formal
        charge and radical-electron count, one-hot hybridization, the
        aromaticity flag, plus the optional blocks described above.

    Raises
    ------
    Exception
        From one_of_k_encoding when the atom degree is not in possible_degree_list.
    """
    if bool_id_feat:
        # atom_to_id is not defined in this notebook; use DeepChem's implementation.
        from deepchem.feat.graph_features import atom_to_id
        return np.array([atom_to_id(atom)])

    ## why not one-hot get_formal_charge and get_num_radical_electrons?
    result = (
        one_of_k_encoding_unk(atom.GetSymbol(), possible_atom_list)
        + one_of_k_encoding(atom.GetDegree(), possible_degree_list)
        + one_of_k_encoding_unk(atom.GetImplicitValence(), possible_valence_list)
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
        + one_of_k_encoding_unk(atom.GetHybridization(), possible_hybridization_list)
        + [atom.GetIsAromatic()]
    )

    if not explicit_H:
        result = result + one_of_k_encoding_unk(atom.GetTotalNumHs(), possible_numH_list)
    if use_chirality:
        chirality_possible = [atom.HasProp('_ChiralityPossible')]
        try:
            result = result + one_of_k_encoding_unk(
                atom.GetProp('_CIPCode'), possible_chirality_list) + chirality_possible
        except KeyError:
            # GetProp raises KeyError when the atom carries no CIP code:
            # encode "neither R nor S" and keep the vector length fixed.
            result = result + [False, False] + chirality_possible

    return np.array(result)

In [None]:
# Comparison to "featurizer" in deepchem
# featurizer._get_atom_properties gets additional features of the atom so it is ignored for now.
from deepchem.feat.mol_graphs import ConvMol

def smiles_2_graph(smiles_arr, featurize_fn):
    """
    from deepchem.data.data_loader.featurize_smiles_df

    Featurize every SMILES string in ``smiles_arr``.

    Parameters
    ----------
    smiles_arr: iterable of str
        SMILES strings to convert.
    featurize_fn: callable
        Called with an RDKit molecule; must return a ``(raw_feature, feature)`` pair.

    Returns
    -------
    (raw_features, features): tuple of two lists
        Both lists are index-aligned with ``smiles_arr``. A SMILES string that
        RDKit cannot parse contributes ``None`` to both lists instead of being
        handed to ``featurize_fn``.
    """
    raw_features = []  # modified
    features = []
    for smiles in smiles_arr:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: keep a placeholder so positions still match the input.
            raw_features.append(None)
            features.append(None)
            continue

        # Found in deepchem.data.data_loader featurize_smiles_df:
        # TODO (ytz) this is a bandage solution to reorder the atoms so
        # that they're always in the same canonical order. Presumably this
        # should be correctly implemented in the future for graph mols.
        new_order = rdmolfiles.CanonicalRankAtoms(mol)
        mol = rdmolops.RenumberAtoms(mol, new_order)

        raw_feature, feature = featurize_fn(mol)
        raw_features.append(raw_feature)
        features.append(feature)

    return raw_features, features

def get_graph_from_molecule(mol, use_master_atom=False):
    """
    From ConvMolFeaturizer._featurize

    Input:
        mol - rdkit.Chem.rdchem.Mol
        use_master_atom - when True, append one extra "master" node whose
                          features are the mean of all atom features.

    Output:
        (nodes, canon_adj_list), ConvMol(nodes, canon_adj_list) where
        nodes - np.ndarray of shape (num_atoms, num_feat), plus one row when
                use_master_atom is True
        canon_adj_list - list. index corresponds to the index of node
                         and canon_adj_list[index] corresponds to indices of the nodes that node i is connected to.
    """
    # Order the rows by atom index so that row i of `nodes` describes atom i.
    idx_nodes = sorted(
        ((atom.GetIdx(), get_atom_features(atom)) for atom in mol.GetAtoms()),
        key=lambda pair: pair[0])
    nodes = np.vstack([features for _, features in idx_nodes])

    # Master atom is the "average" of all atoms that is connected to all atom
    # Introduced in https://arxiv.org/pdf/1704.01212.pdf
    if use_master_atom:
        master_atom_features = np.expand_dims(np.mean(nodes, axis=0), axis=0)
        nodes = np.concatenate([nodes, master_atom_features], axis=0)

    edge_list = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]

    # len() of a 2-D array is its number of rows, i.e. the number of nodes.
    canon_adj_list = [[] for _ in range(len(nodes))]

    # Bonds are undirected: record each one on both of its atoms.
    for begin_idx, end_idx in edge_list:
        canon_adj_list[begin_idx].append(end_idx)
        canon_adj_list[end_idx].append(begin_idx)

    if use_master_atom:
        fake_atom_index = len(nodes) - 1

        # Every real atom points at the master atom; the master atom's own
        # adjacency list is left empty.
        for i in range(fake_atom_index):
            canon_adj_list[i].append(fake_atom_index)

    return (nodes, canon_adj_list), ConvMol(nodes, canon_adj_list)

In [None]:
# extracted from ConvMol. Do we really need this??

def cumulative_sum(l, offset=0):
    """Cumulative sums of a list of counts, with a leading zero.

    The first returned value is 0 and the final total is kept, e.g.
    [3, 2, 4] -> [0, 3, 5, 9]. Useful for reindexing.

    Parameters
    ----------
    l: list
        List of integers. Typically small counts.
    offset: int
        Value added to every returned element.
    """
    running_totals = np.cumsum(l)
    with_leading_zero = np.insert(running_totals, 0, 0)
    return with_leading_zero + offset

def cumulative_sum_minus_last(l, offset=0):
    """Cumulative sums of a list of counts, with a leading zero and no final total.

    The first returned value is 0 and the last running total is dropped, e.g.
    [3, 2, 4] -> [0, 3, 5] (the total 9 is missing). Useful for reindexing.

    Parameters
    ----------
    l: list
        List of integers. Typically small counts.
    offset: int
        Value added to every returned element.
    """
    with_leading_zero = np.insert(np.cumsum(l), 0, 0)
    without_total = with_leading_zero[:-1]
    return without_total + offset


def get_convmol_features(atom_features, adj_list, max_deg, min_deg):
    """Degree-sort a molecular graph (extracted from ConvMol's _deg_sort).

    Atoms are reordered by ascending degree (ties keep their original order),
    the adjacency lists are rewritten in terms of the new atom positions, and
    the atoms are grouped per degree.

    Parameters
    ----------
    atom_features: np.ndarray
        2-D array with one row per atom. Only its row count is used here.
    adj_list: list of lists
        adj_list[i] holds the indices of the atoms bonded to atom i.
    max_deg, min_deg: int
        Inclusive degree range to build groups for.

    Returns
    -------
    deg_list: list of int
        Degree of every atom, in the new (degree-sorted) order.
    deg_adj_lists: list of np.ndarray
        Entry ``deg - min_deg`` stacks the adjacency lists (new indices) of all
        atoms with that degree; an empty (0, deg) array when there are none.
    deg_slice: np.ndarray of shape (max_deg + 1 - min_deg, 2)
        Per degree, ``[start, size]`` of its block of atoms in the sorted order.
        The start is forced to 0 for empty blocks.

    NOTE(review): the degree-sorted atom features are not returned, so callers
    must apply the same ordering themselves to use the returned indices.
    """
    n_atoms = atom_features.shape[0]
    deg_list = [len(edges) for edges in adj_list]

    ## start of _deg_sort()
    old_ind = range(n_atoms)
    # lexsort uses the LAST key as primary: sort by degree, then by original index.
    new_ind = list(np.lexsort((old_ind, deg_list)))

    deg_list = [deg_list[i] for i in new_ind]
    adj_list = [adj_list[i] for i in new_ind]

    # new_ind[j] is the old index of the atom now at position j, so zipping it
    # with 0..n-1 yields the old-index -> new-position lookup.
    old_to_new = dict(zip(new_ind, old_ind))
    adj_list = [[old_to_new[k] for k in neighbours] for neighbours in adj_list]

    # construct adj_lists
    deg_array = np.array(deg_list)
    n_degrees = max_deg + 1 - min_deg
    deg_adj_lists = n_degrees * [0]
    atom_positions = np.arange(n_atoms)

    for deg in range(min_deg, max_deg + 1):
        indices = atom_positions[deg_array == deg]
        if len(indices) > 0:
            # Separate name on purpose: adj_list must stay intact for the next degree.
            deg_adj_lists[deg - min_deg] = np.vstack([adj_list[i] for i in indices])
        else:
            deg_adj_lists[deg - min_deg] = np.zeros([0, deg], dtype=np.int32)

    # construct slice info
    deg_slice = np.zeros([n_degrees, 2], dtype=np.int32)

    for deg in range(min_deg, max_deg + 1):
        row = deg - min_deg
        if deg == 0:
            deg_size = np.sum(deg_array == deg)
        else:
            deg_size = deg_adj_lists[row].shape[0]

        deg_slice[row, 1] = deg_size

        # get cumulative indices after the first index
        if deg > min_deg:
            deg_slice[row, 0] = deg_slice[row - 1, 0] + deg_slice[row - 1, 1]

    # set indices with zerosized slices to zero to avoid indexing errors
    deg_slice[:, 0] *= (deg_slice[:, 1] != 0)
    ## end of _deg_sort()

    return deg_list, deg_adj_lists, deg_slice
   

In [None]:
# smiles_2_graph returns (raw_features, features); each raw feature is the
# (nodes, canon_adj_list) pair, so the node matrix has to be unpacked before reading its shape.
example_raw_features, _ = smiles_2_graph([test_smiles_seq], featurize_fn=get_graph_from_molecule)
example_nodes, _ = example_raw_features[0]
num_atom_features = example_nodes.shape[1]
print(num_atom_features)

# The TensorGraph cells further down declare 75 features per atom. A value of 76
# here means possible_atom_list contains a duplicated symbol ('As' listed twice).

## Part 2. GraphConvolution Model in Keras

The `GraphConv` module in DeepChem implements graph convolution with Keras.
Its documentation does not cite a reference paper.

### Original Code

In [None]:
from deepchem.models.tensorgraph.models.graph_models import GraphConvModel
# Regression model with one output per PDBBind task.
model = GraphConvModel(
    len(pdbbind_tasks), batch_size=50, mode='regression')
# Train for 10 epochs on the PDBBind training split.
model.fit(train_dataset_p, nb_epoch=10)



















Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor








































Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "








In [None]:
# `model` was trained above as a PDBBind regressor, so it is scored with a
# regression metric on the PDBBind splits; a classification ROC-AUC on the
# Tox21 splits does not apply to it.
# `transformers` holds the PDBBind transformers at this point, because the
# PDBBind load cell rebinds that name.
metric = dc.metrics.Metric(
    dc.metrics.pearson_r2_score, np.mean, mode="regression")

print("Evaluating model")
train_scores = model.evaluate(train_dataset_p, [metric], transformers)
print("Training R^2 Score: %f" % train_scores["mean-pearson_r2_score"])
valid_scores = model.evaluate(valid_dataset_p, [metric], transformers)
print("Validation R^2 Score: %f" % valid_scores["mean-pearson_r2_score"])

### My Code

In [None]:
from tensorflow.keras.layers import Input, Dense, Reshape, Softmax, Dropout, Activation, BatchNormalization

class GraphConvModel:
    """Work-in-progress Keras re-implementation of DeepChem's graph-convolution classifier.

    The constructor wires graph-convolution / batch-norm / pooling stages, a
    dense layer, a graph-gather readout and a softmax head into a
    ``tf.keras.Model`` and stores it on the instance.

    NOTE(review): defining this class rebinds the name ``GraphConvModel`` that
    an earlier cell imports from DeepChem.
    NOTE(review): GraphConv, SwitchedDropout, GraphPool and GraphGather are
    still placeholders in this notebook, and TrimGraphOutput and
    SoftmaxCrossEntropy are not defined here, so construction cannot succeed
    until they are provided.

    Parameters
    ----------
    graph_conv_layers: list of int
        Output width of each graph-convolution stage.
    dense_layer_size: int
        Width of the dense layer applied before the graph gather.
    dropout: float or sequence of float
        One rate per graph-convolution stage plus one for the dense layer; a
        single number is repeated for all of them.
    num_atom_features: int
        Length of the per-atom feature vector (defaults to the module-level
        value at class-definition time).
    n_classes: int
        Number of output classes.
    batch_size: int
        Batch size handed to the graph-gather layer.
    max_deg: int
        Largest atom degree; one adjacency input is created per value in
        0..max_deg.
    **kwargs:
        Kept on ``self.model_kwargs`` for the future training wrapper.

    Raises
    ------
    ValueError
        If ``dropout`` is a sequence whose length is not
        ``len(graph_conv_layers) + 1``.
    """

    def __init__(self, 
                 graph_conv_layers=[64, 64], 
                 dense_layer_size=128,
                 dropout=0.0, 
                 num_atom_features=num_atom_features,
                 n_classes=2,
                 batch_size=100,
                 max_deg=10,
                 **kwargs):

        self.dense_layer_size = dense_layer_size
        # Copy so an instance never aliases the shared default list.
        self.graph_conv_layers = list(graph_conv_layers)
        self.number_atom_features = num_atom_features
        self.n_classes = n_classes
        self.batch_size = batch_size
        if not isinstance(dropout, collections.abc.Sequence):
            dropout = [dropout] * (len(graph_conv_layers) + 1)
        if len(dropout) != len(graph_conv_layers) + 1:
            raise ValueError('Wrong number of dropout probabilities provided')
        self.dropout = dropout
        self.max_deg = max_deg

        # Build model
        atom_features = Input(shape=(self.number_atom_features,))
        degree_slice = Input(shape=(2,), dtype=tf.int32)
        membership = Input(shape=tuple(), dtype=tf.int32)
        n_samples = Input(shape=tuple(), dtype=tf.int32)
        dropout_switch = Input(shape=tuple())

        # One integer input per index 0..max_deg, with widths 1..max_deg + 1
        # (presumably the per-degree adjacency lists -- confirm against ConvMol).
        self.deg_adjs = []
        for i in range(0, self.max_deg + 1):
            deg_adj = Input(shape=(i + 1,), dtype=tf.int32)
            self.deg_adjs.append(deg_adj)

        in_layer = atom_features
        # zip stops at the shorter list, so the last dropout rate is left for the dense layer.
        for layer_size, layer_dropout in zip(self.graph_conv_layers, self.dropout):
            gc1_in = [in_layer, degree_slice, membership] + self.deg_adjs
            gc1 = GraphConv(layer_size, activation_fn=tf.nn.relu)(gc1_in)
            batch_norm1 = BatchNormalization(fused=False)(gc1)

            if layer_dropout > 0.0:
                batch_norm1 = SwitchedDropout(rate=layer_dropout)([batch_norm1, dropout_switch])
            gp_in = [batch_norm1, degree_slice, membership] + self.deg_adjs
            in_layer = GraphPool()(gp_in)

        dense = Dense(self.dense_layer_size, activation=tf.nn.relu)(in_layer)
        batch_norm3 = BatchNormalization(fused=False)(dense)

        if self.dropout[-1] > 0.0:
            batch_norm3 = SwitchedDropout(rate=self.dropout[-1])([batch_norm3, dropout_switch])
        self.neural_fingerprint = GraphGather(batch_size=batch_size, activation_fn=tf.nn.tanh)(
            [batch_norm3, degree_slice, membership] + self.deg_adjs)

        # Apply the Dense layer to the fingerprint first, then reshape its output tensor.
        logits = Reshape((1, n_classes))(Dense(self.n_classes)(self.neural_fingerprint))
        logits = TrimGraphOutput()([logits, n_samples])
        output = Softmax()(logits)
        outputs = [output, logits]
        output_types = ['prediction', 'loss']
        loss = SoftmaxCrossEntropy()

        model = tf.keras.Model(
            inputs=[atom_features, degree_slice, membership, n_samples, dropout_switch] + self.deg_adjs,
            outputs=outputs
        )
        ## TODO Check KerasModel's fit
        # This class has no base class, so there is no parent constructor that
        # accepts these values; keep them on the instance for the training code.
        self.model = model
        self.loss = loss
        self.output_types = output_types
        self.model_kwargs = kwargs
        
        
# TODO Implement the layers

class GraphConv:
    """Placeholder for the graph-convolution layer; not implemented yet."""
    # NOTE(review): GraphConvModel builds this as GraphConv(layer_size, activation_fn=...)
    # and then calls the instance; this zero-argument stub supports neither yet.
    def __init__(self):
        pass

class SwitchedDropout:
    """Placeholder for the switchable dropout layer; not implemented yet."""
    # NOTE(review): GraphConvModel builds this as SwitchedDropout(rate=...) and then
    # calls the instance; this zero-argument stub supports neither yet.
    def __init__(self):
        pass
    
class GraphPool:
    """Placeholder for the graph-pooling layer; not implemented yet."""
    # NOTE(review): GraphConvModel calls the constructed instance on a list of
    # inputs; this stub does not define __call__ yet.
    def __init__(self):
        pass
    
class GraphGather:
    """Placeholder for the graph-gather (readout) layer; not implemented yet."""
    # NOTE(review): GraphConvModel builds this as GraphGather(batch_size=..., activation_fn=...)
    # and then calls the instance; this zero-argument stub supports neither yet.
    def __init__(self):
        pass
        
# TODO Study this code. What is ConvMol doing for deg_slice & membership?
def default_generator(self,
                        dataset,
                        epochs=1,
                        mode='fit',
                        deterministic=True,
                        pad_batches=True):
    """Yield ``(inputs, [y], [w])`` batches for the graph-convolution model.

    For every epoch, ``dataset.iterbatches`` is walked with ``self.batch_size``.
    Each batch of molecules is merged with ``ConvMol.agglomerate_mols`` and
    turned into the input list: atom features, degree slice, membership,
    sample count, dropout switch, then the adjacency lists from index 1 onward.

    ``self`` must provide ``batch_size`` and ``mode``; when ``self.mode`` is
    'classification' it must also provide ``n_classes`` and ``n_tasks``, which
    are used to one-hot encode and reshape the labels.

    The dropout switch is 0.0 when ``mode == 'predict'`` and 1.0 otherwise.
    """
    for _ in range(epochs):
        batches = dataset.iterbatches(
            batch_size=self.batch_size,
            deterministic=deterministic,
            pad_batches=pad_batches)
        for X_b, y_b, w_b, _ids_b in batches:
            if self.mode == 'classification':
                one_hot = to_one_hot(y_b.flatten(), self.n_classes)
                y_b = one_hot.reshape(-1, self.n_tasks, self.n_classes)
            merged_mol = ConvMol.agglomerate_mols(X_b)
            n_samples = np.array(X_b.shape[0])
            dropout = np.array(0.0) if mode == 'predict' else np.array(1.0)
            inputs = [
                merged_mol.get_atom_features(),
                merged_mol.deg_slice,
                np.array(merged_mol.membership),
                n_samples,
                dropout,
            ]
            # Index 0 of the adjacency lists is skipped on purpose.
            for deg_index in range(1, len(merged_mol.get_deg_adjacency_lists())):
                inputs.append(merged_mol.get_deg_adjacency_lists()[deg_index])
            yield (inputs, [y_b], [w_b])
        
# Why yield?
# We should use yield when we want to iterate over a sequence, but don’t want to store the entire sequence in memory.

In [None]:
from deepchem.models.tensorgraph.tensor_graph import TensorGraph

# Empty TensorGraph that the following cells attach outputs and a loss to.
tg = TensorGraph(use_queue=False)

In [None]:
from deepchem.models.tensorgraph.layers import Feature

# Input placeholders for one batch of merged molecules.
# NOTE(review): 75 is hard-coded and presumably must equal the per-atom feature
# length produced by the featurizer -- confirm.
atom_features = Feature(shape=(None, 75))
degree_slice = Feature(shape=(None, 2), dtype=tf.int32)
membership = Feature(shape=(None,), dtype=tf.int32)

# Eleven adjacency placeholders with widths 1..11; data_generator feeds
# deg_adjs[i - 1] with the adjacency list at index i.
deg_adjs = []
for i in range(0, 10 + 1):
    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
    deg_adjs.append(deg_adj)

In [None]:
# NOTE: these imports rebind GraphConv, GraphPool and GraphGather, replacing the
# placeholder classes defined earlier in the notebook.
from deepchem.models.tensorgraph.layers import Dense, GraphConv, BatchNorm
from deepchem.models.tensorgraph.layers import GraphPool, GraphGather

# Also read by data_generator and by the GraphGather readout below.
batch_size = 50

# Two (graph convolution -> batch norm -> graph pool) stages of width 64.
gc1 = GraphConv(
    64,
    activation_fn=tf.nn.relu,
    in_layers=[atom_features, degree_slice, membership] + deg_adjs)
batch_norm1 = BatchNorm(in_layers=[gc1])
gp1 = GraphPool(in_layers=[batch_norm1, degree_slice, membership] + deg_adjs)
gc2 = GraphConv(
    64,
    activation_fn=tf.nn.relu,
    in_layers=[gp1, degree_slice, membership] + deg_adjs)
batch_norm2 = BatchNorm(in_layers=[gc2])
gp2 = GraphPool(in_layers=[batch_norm2, degree_slice, membership] + deg_adjs)
# Dense layer of width 128, batch norm, then the graph-gather readout.
dense = Dense(out_channels=128, activation_fn=tf.nn.relu, in_layers=[gp2])
batch_norm3 = BatchNorm(in_layers=[dense])
readout = GraphGather(
    batch_size=batch_size,
    activation_fn=tf.nn.tanh,
    in_layers=[batch_norm3, degree_slice, membership] + deg_adjs)

In [None]:
from deepchem.models.tensorgraph.layers import Dense, SoftMax, \
    SoftMaxCrossEntropy, WeightedError, Stack
from deepchem.models.tensorgraph.layers import Label, Weights

# One two-way classification head per Tox21 task, all reading the shared readout.
costs = []
labels = []
for task in range(len(tox21_tasks)):
    classification = Dense(
        out_channels=2, activation_fn=None, in_layers=[readout])

    # The softmax is registered as a model output; the cross-entropy is taken
    # on the pre-softmax `classification` layer.
    softmax = SoftMax(in_layers=[classification])
    tg.add_output(softmax)

    label = Label(shape=(None, 2))
    labels.append(label)
    cost = SoftMaxCrossEntropy(in_layers=[label, classification])
    costs.append(cost)
# Stack the per-task costs along axis 1 and combine them with the weights input.
all_cost = Stack(in_layers=costs, axis=1)
weights = Weights(shape=(None, len(tox21_tasks)))
loss = WeightedError(in_layers=[all_cost, weights])
tg.set_loss(loss)


In [None]:
from deepchem.metrics import to_one_hot
from deepchem.feat.mol_graphs import ConvMol

def data_generator(dataset, epochs=1, predict=False, pad_batches=True):
    """Yield one feed dictionary per batch for the TensorGraph model.

    Each dictionary maps the module-level placeholders to batch data: every
    entry of ``labels`` to the one-hot labels of its task column, ``weights``
    to the batch weights, and ``atom_features`` / ``degree_slice`` /
    ``membership`` / ``deg_adjs`` to the fields of the merged ConvMol.

    Batches are drawn deterministically with the module-level ``batch_size``.
    Unless ``predict`` is True, a progress line is printed at the start of
    every epoch.
    """
    for epoch in range(epochs):
        if not predict:
            print('Starting epoch %i' % epoch)
        batches = dataset.iterbatches(
            batch_size, pad_batches=pad_batches, deterministic=True)
        for X_b, y_b, w_b, _ids_b in batches:
            feed = {
                label: to_one_hot(y_b[:, task_index])
                for task_index, label in enumerate(labels)
            }
            feed[weights] = w_b
            merged_mol = ConvMol.agglomerate_mols(X_b)
            feed[atom_features] = merged_mol.get_atom_features()
            feed[degree_slice] = merged_mol.deg_slice
            feed[membership] = merged_mol.membership
            # The adjacency list at index i feeds placeholder i - 1; index 0 is skipped.
            for deg_index in range(1, len(merged_mol.get_deg_adjacency_lists())):
                feed[deg_adjs[deg_index - 1]] = merged_mol.get_deg_adjacency_lists()[deg_index]
            yield feed

In [None]:
# Train the TensorGraph on the Tox21 training split.
# Epochs set to 1 to render tutorials online.
# Set epochs=10 for better results.
tg.fit_generator(data_generator(train_dataset, epochs=1))

In [None]:
# ROC-AUC averaged over tasks with np.mean, used to score the TensorGraph classifier below.
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score, np.mean, mode="classification")

def reshape_y_pred(y_true, y_pred):
    """
    TensorGraph.Predict returns a list of arrays, one for each output.
    The arrays are stacked along axis 1 and the padding on the last batch is
    removed by keeping only the first len(y_true) rows.
    Metrics take results of shape (samples, n_task, prob_of_class)
    """
    stacked = np.stack(y_pred, axis=1)
    return stacked[:len(y_true)]


# Score the TensorGraph model on the Tox21 train and validation splits:
# predict, trim/stack the per-task outputs, then compute the weighted metric.
print("Evaluating model")
train_predictions = tg.predict_on_generator(data_generator(train_dataset, predict=True))
train_predictions = reshape_y_pred(train_dataset.y, train_predictions)
train_scores = metric.compute_metric(train_dataset.y, train_predictions, train_dataset.w)
print("Training ROC-AUC Score: %f" % train_scores)

valid_predictions = tg.predict_on_generator(data_generator(valid_dataset, predict=True))
valid_predictions = reshape_y_pred(valid_dataset.y, valid_predictions)
valid_scores = metric.compute_metric(valid_dataset.y, valid_predictions, valid_dataset.w)
print("Valid ROC-AUC Score: %f" % valid_scores)