<a href="https://colab.research.google.com/github/leojklarner/Q-SAVI/blob/staging/qsavi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Running Q-SAVI on Colab

This notebook demonstrates how to train and evaluate Q-SAVI models and reproduces the experimental results from the paper.
Specifically, it:

1. clones the Q-SAVI source code from GitHub
2. downloads a pre-processed subsample of the ZINC database (featurized as ECFPs) to use as a context point distribution
3. installs any dependencies not available in the default Colab environment
4. specifies the hyperparameter combinations with the lowest validation NLL, identified by running the `qsavi_hyperparam_search.py` script
5. trains 10 models for each hyperparameter combination and split using different random seeds
6. computes the test set AUC-ROC and Brier score of every run and reports their mean and standard error per split



## Setup

In [None]:
# download Q-SAVI source code from GitHub Repo
!git clone -b staging https://github.com/leojklarner/Q-SAVI.git

# download and extract pre-processed context point files
!mkdir -p /content/Q-SAVI/data/datasets/zinc
!wget "https://www.dropbox.com/s/xsbz8wyewupnpe8/zinc_context_points_ecfp.tar.gz?dl=0" -P /content/Q-SAVI/data/datasets/zinc
!tar -xf "/content/Q-SAVI/data/datasets/zinc/zinc_context_points_ecfp.tar.gz?dl=0" -C /content/Q-SAVI/data/datasets/zinc

Cloning into 'Q-SAVI'...
remote: Enumerating objects: 60, done.
remote: Counting objects: 100% (20/20), done.
remote: Compressing objects: 100% (18/18), done.
remote: Total 60 (delta 6), reused 10 (delta 2), pack-reused 40
Unpacking objects: 100% (60/60), 162.72 MiB | 8.73 MiB/s, done.
Updating files: 100% (22/22), done.
--2023-07-11 17:24:14--  https://www.dropbox.com/s/xsbz8wyewupnpe8/zinc_context_points_ecfp.tar.gz?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/xsbz8wyewupnpe8/zinc_context_points_ecfp.tar.gz [following]
--2023-07-11 17:24:15--  https://www.dropbox.com/s/raw/xsbz8wyewupnpe8/zinc_context_points_ecfp.tar.gz
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc7f77bcc89ce784a1e7cac1145e.dl.dropboxuser

In [None]:
%cd /content/Q-SAVI/

# install packages not already available in Colab environment
!pip install dm-haiku
!pip install -q gwpy

/content/Q-SAVI
Collecting dm-haiku
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.9 jmp-0.0.4
  Preparing metadata (setup.py) ... done
  Building wheel for ligo-segments (setup.py) ... don

## Model Training

In [None]:
import os
import argparse
import json
import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score, brier_score_loss

from qsavi.qsavi import QSAVI
from qsavi.config import add_qsavi_args, arg_map

Num GPUs Available (TF):  1
JAX is using gpu
JAX devices: [gpu(id=0)]


In [None]:
# specify the hyperparameters with the lowest validation set NLL
# obtained from running qsavi_hyper_search.py

best_hyperparams = {
    "spectral_split": {
        "learning_rate": 1e-4, "num_layers": 2, "embed_dim": 32,
        "prior_cov": 100.0, "n_context_points": 16,
    },
    "mw_split": {
        "learning_rate": 1e-4, "num_layers": 4, "embed_dim": 32,
        "prior_cov": 100.0, "n_context_points": 16,
    },
    "scaffold_split": {
        "learning_rate": 1e-4, "num_layers": 6, "embed_dim": 64,
        "prior_cov": 10.0, "n_context_points": 16,
    },
    "random_split": {
        "learning_rate": 1e-4, "num_layers": 4, "embed_dim": 32,
        "prior_cov": 100.0, "n_context_points": 16,
    },
}
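
The dictionary above is the end product of a hyperparameter search. As a rough sketch of how such a per-split selection might be made, the snippet below assumes the search results are available as records with the hypothetical keys `split`, `val_nll` and `hyperparams`. The actual output format of the search script is not shown in this notebook, and the values are made up for illustration.

```python
# Hypothetical search records; keys and values are illustrative assumptions,
# not output of the actual search script.
search_results = [
    {"split": "random_split", "val_nll": 0.31, "hyperparams": {"num_layers": 2, "embed_dim": 32}},
    {"split": "random_split", "val_nll": 0.28, "hyperparams": {"num_layers": 4, "embed_dim": 32}},
    {"split": "mw_split", "val_nll": 0.22, "hyperparams": {"num_layers": 4, "embed_dim": 32}},
    {"split": "mw_split", "val_nll": 0.25, "hyperparams": {"num_layers": 6, "embed_dim": 64}},
]


def select_best(records):
    """Return {split: hyperparams} for the record with the lowest val_nll per split."""
    best = {}
    for rec in records:
        current = best.get(rec["split"])
        # keep the record if it is the first for this split or improves on the best so far
        if current is None or rec["val_nll"] < current["val_nll"]:
            best[rec["split"]] = rec
    return {split: rec["hyperparams"] for split, rec in best.items()}


selected = select_best(search_results)
print(selected)
```

Selecting on validation NLL rather than on a test set metric keeps the test set untouched until the final evaluation.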

In [None]:
def run_qsavi(split):
  """
  Run Q-SAVI algorithm for the hyperparameter combination with
  the lowest validation set NLL with 10 different seeds, using
  the specified data split and pre-processed ECFPs.

  Args:
    split: data split to train on (
      "spectral_split", "mw_split",
      "scaffold_split", "random_split")

  Returns:
    list of 10 dicts of test set predictions
  """

  parser = argparse.ArgumentParser(description='Q-SAVI Command Line Interface')
  parser.add_argument('-f')  # Colab launches the kernel with a -f <connection file> argument; accept it so parse_args() does not fail
  add_qsavi_args(parser)
  kwargs = parser.parse_args()


  print(f"Using best hyperparameters for {split}:")
  hypers = best_hyperparams[split]
  for k, v in hypers.items():
    print("\t-", k, ":", v)

  kwargs.split = split
  kwargs.featurization = "ec_bit_fp"
  kwargs.learning_rate = hypers["learning_rate"]
  kwargs.num_layers = hypers["num_layers"]
  kwargs.embed_dim = hypers["embed_dim"]
  kwargs.prior_cov = hypers["prior_cov"]
  kwargs.n_context_points = hypers["n_context_points"]
  kwargs.datadir = "/content/Q-SAVI/data/datasets"

  # rerun Q-SAVI with 10 different random seeds
  qsavi_results = []

  for i in range(10):

    kwargs.seed = i

    print(
        "\n\nFull input arguments:",
        json.dumps(vars(kwargs), indent=4, separators=(",", ":")),
        "\n\n",
    )

    qsavi = QSAVI(kwargs)
    val_metrics, test_metrics = qsavi.train()

    qsavi_results.append({"split": split, **test_metrics})

  return qsavi_results
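
The extra `-f` argument in `run_qsavi` is needed because the notebook kernel is started with its own command-line arguments, which `parse_args()` would otherwise reject. One possible alternative, sketched below, is to pass an explicit empty list so that `argparse` never looks at `sys.argv`. The `--learning_rate` and `--num_layers` flags here are illustrative stand-ins and are not claimed to match what `add_qsavi_args` registers.

```python
import argparse

# Illustrative parser; these flags are assumptions standing in for the real Q-SAVI arguments.
parser = argparse.ArgumentParser(description="toy CLI")
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--num_layers", type=int, default=2)

# An explicit empty list makes argparse ignore sys.argv entirely,
# so kernel-specific arguments never reach the parser.
args = parser.parse_args(args=[])

# Override the defaults programmatically on the returned namespace.
args.learning_rate = 1e-4
args.num_layers = 4

print(vars(args))
```

This only works if every argument has a default or is set afterwards, since required arguments would still raise an error on an empty argument list.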

In [None]:
%%capture

# this takes around 25 minutes for all splits;
# remove the %%capture to see progress logs

results = []

for split in ["spectral_split", "mw_split", "scaffold_split", "random_split"]:
  results.extend(run_qsavi(split))

## Model Evaluation

In [None]:
results = pd.DataFrame(results)
results["auc_roc"] = results.apply(lambda x: roc_auc_score(x["labels"], x["preds"]), axis=1)
results["brier"] = results.apply(lambda x: brier_score_loss(x["labels"], x["preds"]), axis=1)
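
To make the two metrics concrete, here is a minimal pure-Python sketch of what they measure for binary labels. It is not scikit-learn's implementation, and the function names and toy inputs are chosen only for illustration.

```python
def brier_score(labels, probs):
    """Mean squared difference between predicted probabilities and binary labels."""
    return sum((p - y) ** 2 for y, p in zip(labels, probs)) / len(labels)


def auc_roc(labels, probs):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [p for y, p in zip(labels, probs) if y == 1]
    neg = [p for y, p in zip(labels, probs) if y == 0]
    wins = sum(
        (1.0 if p > n else 0.5 if p == n else 0.0)
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


# Toy labels and predicted probabilities, for illustration only.
labels = [0, 0, 1, 1]
probs = [0.1, 0.4, 0.35, 0.8]

print(auc_roc(labels, probs))      # ranking quality: higher is better
print(brier_score(labels, probs))  # calibration and accuracy: lower is better
```

AUC-ROC depends only on how the predictions are ordered, while the Brier score also penalizes probabilities that are far from the observed labels, so the two can disagree about which model is better.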

In [None]:
pd.options.display.float_format = '{:.3f}'.format

results.groupby("split")[["auc_roc", "brier"]].agg(
    ["mean", lambda x: np.std(x)/np.sqrt(len(x))]
    ).rename(columns={"<lambda_0>": "standard error"})

| split | auc_roc (mean) | auc_roc (standard error) | brier (mean) | brier (standard error) |
|---|---|---|---|---|
| mw_split | 0.650 | 0.002 | 0.047 | 0.000 |
| random_split | 0.708 | 0.001 | 0.088 | 0.000 |
| scaffold_split | 0.657 | 0.004 | 0.102 | 0.000 |
| spectral_split | 0.606 | 0.003 | 0.130 | 0.000 |
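
Note that `np.std` computes the population standard deviation (`ddof=0`) by default, so the standard error reported above is slightly smaller than the one based on the sample standard deviation (`ddof=1`). The sketch below compares the two using only the standard library; the per-seed scores are made up for illustration and are not results from this notebook.

```python
import math
import statistics

# Made-up per-seed scores for a single split, for illustration only.
scores = [0.70, 0.71, 0.69, 0.72, 0.70, 0.71, 0.70, 0.69, 0.71, 0.72]
n = len(scores)

# Population standard deviation (ddof=0), matching np.std's default.
sem_population = statistics.pstdev(scores) / math.sqrt(n)

# Sample standard deviation (ddof=1), often used when estimating a standard error.
sem_sample = statistics.stdev(scores) / math.sqrt(n)

print(f"mean={statistics.mean(scores):.3f}")
print(f"SEM (ddof=0)={sem_population:.4f}, SEM (ddof=1)={sem_sample:.4f}")
```

With 10 seeds the two estimates differ by a factor of `sqrt(10 / 9)`, roughly 5%, which is unlikely to change the conclusions but is worth stating when reporting results.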
