In [None]:
!pip install rdkit==2023.9.3 tensorflow tensorflow_gnn

In [None]:
from os import environ
# Hide every GPU so TensorFlow runs on CPU only; set before importing tensorflow so it takes effect.
environ.update(CUDA_VISIBLE_DEVICES = '-1')
import tensorflow as tf
import tensorflow_gnn as tfgnn

def graph_tensor_spec():
  """Return the GraphTensorSpec of one molecule graph as fed to the models.

  The graph has a node set "atom" whose hidden state is a (num_atoms, 118)
  float32 matrix, and an edge set "bond" (atom -> atom) whose hidden state is a
  (num_bonds, 22) float32 matrix. Both sets carry a sizes vector of shape (1,),
  i.e. the graph holds exactly one component.
  """
  atom_features = {tfgnn.HIDDEN_STATE: tf.TensorSpec((None, 118), tf.float32)}
  bond_features = {tfgnn.HIDDEN_STATE: tf.TensorSpec((None, 22), tf.float32)}
  one_component = tf.TensorSpec((1,), tf.int32)

  atom_set = tfgnn.NodeSetSpec.from_field_specs(
    features_spec = atom_features,
    sizes_spec = one_component)
  bond_set = tfgnn.EdgeSetSpec.from_field_specs(
    features_spec = bond_features,
    sizes_spec = one_component,
    adjacency_spec = tfgnn.AdjacencySpec.from_incident_node_sets("atom", "atom"))

  return tfgnn.GraphTensorSpec.from_piece_specs(
    node_sets_spec = {"atom": atom_set},
    edge_sets_spec = {"bond": bond_set})

def FeatureExtract(channels = 256, layer_num = 4, drop_rate = 0.5):
  """Build the GNN encoder that turns a molecule graph into one pooled vector.

  Args:
    channels: width of every Dense layer in the encoder.
    layer_num: number of message-passing rounds (GraphUpdate layers).
    drop_rate: dropout rate applied after each hidden Dense layer.

  Returns:
    A tf.keras.Model taking a GraphTensor matching graph_tensor_spec() and
    returning the mean of the final "atom" hidden states per graph component.
  """
  def gelu_dense_dropout():
    # L2-regularized GELU projection followed by dropout; used both as the
    # message function and as the next-state transformation.
    l2 = tf.keras.regularizers.l2
    return tf.keras.Sequential([
      tf.keras.layers.Dense(channels, activation = tf.keras.activations.gelu,
                            kernel_regularizer = l2(5e-4), bias_regularizer = l2(5e-4)),
      tf.keras.layers.Dropout(drop_rate)
    ])

  def embed_atoms(node_set, *, node_set_name):
    # Linear projection of the one-hot atom features to `channels`.
    return tf.keras.layers.Dense(channels)(node_set[tfgnn.HIDDEN_STATE])

  def embed_bonds(edge_set, *, edge_set_name):
    # Linear projection of the one-hot bond features to `channels`.
    return tf.keras.layers.Dense(channels)(edge_set[tfgnn.HIDDEN_STATE])

  graph_in = tf.keras.Input(type_spec = graph_tensor_spec())
  # Fold a batch of graphs into a single graph whose components are the molecules.
  graph = graph_in.merge_batch_to_components()
  graph = tfgnn.keras.layers.MapFeatures(
    node_sets_fn = embed_atoms,
    edge_sets_fn = embed_bonds)(graph)

  # Only atom states are updated; bond states stay as produced by MapFeatures.
  # NOTE(review): SimpleConv is not passed `sender_edge_feature`, whose TF-GNN default
  # is believed to be None, so the projected bond states appear unused by message
  # passing — confirm. Enabling them changes weight shapes and would break the
  # released checkpoint, so it is left as-is.
  for _ in range(layer_num):
    bond_conv = tfgnn.keras.layers.SimpleConv(
      message_fn = gelu_dense_dropout(),
      reduce_type = "sum",
      receiver_tag = tfgnn.TARGET)
    atom_next_state = tfgnn.keras.layers.NextStateFromConcat(
      transformation = gelu_dense_dropout())
    atom_update = tfgnn.keras.layers.NodeSetUpdate(
      edge_set_inputs = {"bond": bond_conv},
      next_state = atom_next_state)
    graph = tfgnn.keras.layers.GraphUpdate(node_sets = {"atom": atom_update})(graph)

  # Graph-level readout: average the atom states of each component into the context.
  pooled = tfgnn.keras.layers.Pool(tag = tfgnn.CONTEXT, reduce_type = "mean", node_set_name = "atom")(graph)
  return tf.keras.Model(inputs = graph_in, outputs = pooled)

def Predictor(channels = 256, layer_num = 4, drop_rate = 0.5):
  """Build the two-class molecule classifier.

  Args:
    channels, layer_num, drop_rate: forwarded unchanged to FeatureExtract.

  Returns:
    A tf.keras.Model mapping a GraphTensor matching graph_tensor_spec() to a
    softmax over 2 classes per graph component.
  """
  graph_in = tf.keras.Input(type_spec = graph_tensor_spec())
  encoder = FeatureExtract(channels, layer_num, drop_rate)
  embedding = encoder(graph_in)
  class_head = tf.keras.layers.Dense(2, activation = tf.keras.activations.softmax)
  probabilities = class_head(embedding)
  return tf.keras.Model(inputs = graph_in, outputs = probabilities)


In [None]:
from requests import get
from shutil import rmtree
from os import system, remove
from os.path import exists
from hashlib import md5

# Download the released checkpoint archive, verify it, unpack it into ./ckpt and
# restore the classifier weights from it.
CKPT_URL = 'https://raw.githubusercontent.com/breadbread1984/tfgnn_example/classification/ckpt.tar.gz'
CKPT_ARCHIVE = 'ckpt.tar.gz'
CKPT_MD5 = 'e03810e81ccad15f8243332db4d6efd5'
CKPT_DIR = 'ckpt'

if exists(CKPT_ARCHIVE): remove(CKPT_ARCHIVE)
response = get(CKPT_URL, timeout = 60)
response.raise_for_status()
# Verify integrity before writing to disk; raise (not assert) so the check survives `python -O`.
digest = md5(response.content).hexdigest()
if digest != CKPT_MD5:
  raise ValueError('checkpoint archive md5 mismatch: expected %s, got %s' % (CKPT_MD5, digest))
with open(CKPT_ARCHIVE, 'wb') as f:
  f.write(response.content)
if exists(CKPT_DIR): rmtree(CKPT_DIR)
# os.system returns the shell exit status; a non-zero value means extraction failed.
if system('tar xzvf %s' % CKPT_ARCHIVE) != 0:
  raise RuntimeError('failed to extract %s' % CKPT_ARCHIVE)

predictor = Predictor()
optimizer = tf.keras.optimizers.Adam(1e-2)
checkpoint = tf.train.Checkpoint(model = predictor, optimizer = optimizer)
latest = tf.train.latest_checkpoint(CKPT_DIR)
# restore(None) restores nothing, which would silently leave the model with random weights.
if latest is None:
  raise FileNotFoundError('no checkpoint found under %s/' % CKPT_DIR)
checkpoint.restore(latest)

In [None]:
from rdkit import Chem

def smiles_to_sample(smiles):
  """Convert a SMILES string into a single-component GraphTensor.

  Atoms become nodes of the "atom" set, one-hot encoded by atomic number
  (depth 118). Every bond is visited from both of its atoms, so it yields two
  directed edges of the "bond" set, one-hot encoded by RDKit bond type
  (depth 22). Depths match graph_tensor_spec().

  NOTE(review): tf.one_hot(..., 118) maps atomic number 118 to an all-zero row.
  The depth is kept at 118 to stay compatible with graph_tensor_spec() and the
  trained checkpoint.

  Args:
    smiles: SMILES string of one molecule.

  Returns:
    A tfgnn.GraphTensor with one graph component.

  Raises:
    ValueError: if RDKit cannot parse `smiles`.
  """
  molecule = Chem.MolFromSmiles(smiles)
  if molecule is None:
    # MolFromSmiles signals a parse failure by returning None rather than raising.
    raise ValueError('RDKit could not parse SMILES: %r' % (smiles,))
  nodes = list()
  edges = list()
  for atom in molecule.GetAtoms():
    idx = atom.GetIdx()
    nodes.append(atom.GetAtomicNum())
    for neighbor_atom in atom.GetNeighbors():
      neighbor_idx = neighbor_atom.GetIdx()
      bond = molecule.GetBondBetweenAtoms(idx, neighbor_idx)
      # int() turns the RDKit BondType enum into an explicit integer index for one_hot.
      edges.append((idx, neighbor_idx, int(bond.GetBondType())))
  # tf.constant (unlike tf.stack) accepts empty lists, so molecules without
  # bonds (e.g. a single atom) still yield well-formed (0,) / (0, 3) tensors.
  nodes = tf.constant(nodes, dtype = tf.int32) # nodes.shape = (node_num,)
  edges = tf.reshape(tf.constant(edges, dtype = tf.int32), (-1, 3)) # edges.shape = (edge_num, 3)
  graph = tfgnn.GraphTensor.from_pieces(
    node_sets = {
      "atom": tfgnn.NodeSet.from_fields(
        sizes = tf.constant([nodes.shape[0]], dtype = tf.int32),
        features = {
          tfgnn.HIDDEN_STATE: tf.one_hot(nodes, 118)
        }
      )
    },
    edge_sets = {
      "bond": tfgnn.EdgeSet.from_fields(
        sizes = tf.constant([edges.shape[0]], dtype = tf.int32),
        adjacency = tfgnn.Adjacency.from_indices(
          source = ("atom", edges[:,0]),
          target = ("atom", edges[:,1])
        ),
        features = {
          tfgnn.HIDDEN_STATE: tf.one_hot(edges[:,2], 22)
        }
      )
    }
  )
  return graph

In [None]:
# Predict a binary label for every molecule of the test set and write the
# submission file with columns SMILES,TARGET.
TEST_CSV = '/bohr/ai4scup-cns-5zkz/v3/mol_test.csv'
SUBMISSION_CSV = 'submission.csv'
with open(TEST_CSV, 'r') as test_file, open(SUBMISSION_CSV, 'w') as submission:
  submission.write('SMILES,TARGET\n')
  for line, row in enumerate(test_file):
    if line == 0: continue  # header row
    row = row.rstrip('\r\n')
    if not row: continue  # tolerate blank / trailing lines
    # SMILES is the first column; any further columns (e.g. a label) are ignored.
    smiles = row.split(',')[0]
    graph = smiles_to_sample(smiles)
    pred = predictor(graph, training = False)  # explicit inference mode (no dropout)
    # pred[0, 1] is the softmax probability of class 1 for the single component.
    target = 1 if pred[0,1] > 0.5 else 0
    record = '%s,%d' % (smiles, target)
    print(record)
    submission.write(record + '\n')