# Flax ogbg-molpcba Example

<a href="https://colab.research.google.com/github/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Demonstration notebook for the Flax ogbg-molpcba example:
https://github.com/google/flax/tree/main/examples/ogbg_molpcba.

## Setup

In [11]:
#@title pip installs
# Install clu, ml-collections, latest Flax version, and tensorflow_datasets.
!pip install -U -q clu ml-collections git+https://github.com/google/flax tensorflow_datasets

In [None]:
#@title Fetch example code
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@markdown **Fetch newest Flax version and copy of example code.**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes will be lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will *persist*. Even if you re-run this
#@markdown Colab notebook later on, the files will still exist. You can
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files.

example_directory = 'examples/ogbg_molpcba'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')
repo, branch = 'https://github.com/google/flax', 'main'

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(
        f'<h1 style="color:red;" class="blink">{DISCLAIMER}</h1>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')

In [12]:
#@title Display current working directory
# Note: In Colab, running the above cell changes the working directory.
# The local module imports and the relative `./models` workdir used below are
# resolved against this directory.
!pwd

/content


## Imports

In [9]:
#@title Base imports
# Standard library.
import pprint

# Third-party.
from absl import logging
import flax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow_datasets as tfds

# Emit absl log records at INFO level and above.
logging.set_verbosity(logging.INFO)

In [None]:
#@title Auto-reload setup
# Local imports from current directory - auto reload.
# Any changes you make to train.py and other modules below will appear automatically.
%load_ext autoreload
%autoreload 2
import train
import input_pipeline
from configs import default as config_lib
config = config_lib.get_config()

## Dataset

TensorFlow Datasets supports customizable visualization of the ogbg_molpcba dataset.

In [6]:
#@title Visualization helpers
# Dictionaries used to map nodes and edges to colors.
atomic_numbers_to_elements = {
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
    14: 'Si',
    15: 'P',
    16: 'S',
    17: 'Cl',
    35: 'Br',
}
elements_to_colors = {
    element: f'C{index}'
    for index, element in enumerate(atomic_numbers_to_elements.values())
}
bond_types_to_colors = {num: f'C{num}' for num in range(4)}

# Color used for elements / bond types that are missing from the tables above,
# so that plotting an uncommon molecule does not fail with a KeyError.
UNKNOWN_COLOR = 'gray'


def _atomic_numbers(graph):
  """Returns an array with the atomic number of every node in `graph`.

  Assumes column 0 of `graph['node_feat']` stores the atomic number minus one
  (hence the `+ 1`) — TODO confirm against the dataset documentation.
  """
  return 1 + graph['node_feat'][:, 0].numpy()


def _element_name(atomic_number):
  """Returns the element symbol, or the atomic number as text if unmapped."""
  return atomic_numbers_to_elements.get(atomic_number, str(atomic_number))


# Node colors are atomic numbers.
def node_color_fn(graph):
  """Maps each node index to a matplotlib color chosen by its element."""
  return {
      index: elements_to_colors.get(_element_name(atomic_number), UNKNOWN_COLOR)
      for index, atomic_number in enumerate(_atomic_numbers(graph))
  }


# Node labels are element names.
def node_label_fn(graph):
  """Maps each node index to its element symbol (atomic number if unmapped)."""
  return {
      index: _element_name(atomic_number)
      for index, atomic_number in enumerate(_atomic_numbers(graph))
  }


# Edge colors are bond types.
def edge_color_fn(graph):
  """Maps each edge (a row of `edge_index`, as a tuple) to a bond-type color."""
  bonds = graph['edge_index'].numpy()
  bond_types = graph['edge_feat'][:, 0].numpy()
  return {
      tuple(bond): bond_types_to_colors.get(bond_type, UNKNOWN_COLOR)
      for bond, bond_type in zip(bonds, bond_types)
  }

In [None]:
#@title Visualize examples
# `show_examples` takes a single tf.data.Dataset. Without `split=`, `tfds.load`
# returns a dict of datasets keyed by split name, so load the train split only.
ds, ds_info = tfds.load('ogbg_molpcba', split='train', with_info=True)
tfds.visualization.show_examples(ds, ds_info,
                                 node_color_fn=node_color_fn,
                                 node_label_fn=node_label_fn,
                                 edge_color_fn=edge_color_fn)

## Training

In [None]:
#@title Start TensorBoard
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
# Colab only: `--logdir=.` points TensorBoard at the current working directory,
# under which the training cell writes to its `./models` workdir.
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.

In [None]:
#@title Training loop
# We don't use TPUs in this Colab because we do not distribute our
# training using pmap() - if you're looking for an example using TPUs
# check out the Colab notebook below:
# https://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb

# Override the number of training steps for this demo run.
# NOTE(review): confirm train.py reads `num_training_steps`. If the config field
# is named differently (e.g. `num_train_steps`), this assignment may only add an
# unused key and the default step count would be used — verify against
# configs/default.py.
config.num_training_steps = 10000
# Trains the model, writing summaries/checkpoints under ./models, and returns
# the final training state used by the inference cells below.
state = train.train_and_evaluate(config, workdir='./models')

In [None]:
#@title Upload to TensorBoard.dev
if 'google.colab' in str(get_ipython()):
  #@markdown You can upload the training results directly to [TensorBoard.dev](https://tensorboard.dev).
  #@markdown
  #@markdown Note that everybody with the link will be able to see the data.
  # NOTE(review): the hosted TensorBoard.dev service appears to have been
  # discontinued, so this upload will likely fail — verify and consider
  # removing this cell.
  upload_data = 'no' #@param ['yes', 'no']
  if upload_data == 'yes':
    # Uploads the logs written to ./models by the training cell.
    !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/ogbg_molpcba'

## Inference

In [None]:
#@title Create evaluation model
# Build a second copy of the network with deterministic=True (presumably this
# turns off stochastic layers such as dropout — confirm in train.create_model).
eval_net = train.create_model(
    config,
    deterministic=True,
)
# Derive an evaluation state from the trained `state` with its apply function
# swapped for the deterministic network's.
eval_state = state.replace(
    apply_fn=eval_net.apply,
)

In [None]:
#@title Compute metrics
# Compute accuracy and mean average precision on validation and test sets.
# No earlier cell defines `datasets`, so build it here from the same config used
# for training (otherwise this cell raises NameError on a fresh kernel).
# NOTE(review): keyword names assume input_pipeline.get_datasets accepts the
# graph-preprocessing flags from the config — confirm against input_pipeline.py.
datasets = input_pipeline.get_datasets(
    config.batch_size,
    add_virtual_node=config.add_virtual_node,
    add_undirected_edges=config.add_undirected_edges,
    add_self_loops=config.add_self_loops)
eval_metrics = train.evaluate_model(eval_state, datasets, splits=['validation', 'test'])
# NOTE(review): if evaluate_model returns un-reduced metric collections, call
# `.compute()` on each value to print the final numbers.
pprint.pprint(eval_metrics)