# Drug-Target Interaction Prediction in ChEMBL using Distributed GNN based Link Prediction

In this notebook, we demonstrate how to use the `Katana Graph Platform` to build a distributed GNN-based pipeline for predicting drug-target interactions. For this demo, we use the `chembl29` dataset, which consists of roughly 2.1M drug (compound) nodes and 15K target (gene) nodes, with 3.2M connections between drug-target pairs overall. The input to the GNN model is the `chembl29` bipartite graph, and the model predicts whether the corresponding drug-target pairs interact. We pose this as a link prediction problem and, in this notebook, treat it as a `classification` task.

The ML pipeline consists of five distinct steps.

* **[Step 1: Katana Setup and Data Loading.](#step1)** Set up a `Katana Client` and load the data from a collection of CSV files into an `rdg`.
* **[Step 2: Data Preprocessing.](#step2)** Generate features for the drug and target data and store them in the `rdg`.
* **[Step 3: Data Splitting into Train-Val-Test.](#step3)** Partitioning compounds into disjoint sets for developing generalizable models.
* **[Step 4: Setting up Components of AI Training Pipeline.](#step4)** Defining GNN models, dataset abstractions, and trainer abstractions.
* **[Step 5: Putting Everything Together.](#step5)** Defining a remote function with all the pieces together.

## Requirements

Before you begin, make sure you meet these prerequisites:

* A running [Katana cluster](../../getting-started/index.rst) (cloud or local deployment).


<a id='step1'></a>
## Step 1: Katana Setup and Data Loading

Once you have a running Katana cluster, you need to initialize a `Client`. This requires setting the address of the Katana Server in the current environment.

In [None]:
import os
import warnings

warnings.filterwarnings("ignore")
os.environ["MODIN_ENGINE"] = "python"

# Docker container on macOS or Windows
# os.environ["KATANA_SERVER_ADDRESS"] = "host.docker.internal:8080"

# Docker container on Linux
# os.environ["KATANA_SERVER_ADDRESS"] = "localhost:8080"

print("--")


### Load Libraries and Initialize the Katana `Client`

Starting a Katana remote Client is required to interface with the Katana remote service and schedule distributed operations.

In [None]:
# Connect to the Katana Server
from katana import remote
from katana.remote import import_data

client = remote.Client()

print("--")


### Load data

Create a new graph with four partitions and load the example `chembl29` graph.

In [None]:
%%time

# Later cells assume four partitions (for example, the per-host row count
# checked in the validation-data test).
g = client.create_graph(num_partitions=4)

import_data.csv(
    g,
    input_node_path="gs://katana-demo-datasets/csv-datasets/chembl29/nodes.txt",
    input_edge_path="gs://katana-demo-datasets/csv-datasets/chembl29/edges.txt",
    input_dir="gs://katana-demo-datasets/csv-datasets/chembl29/",
    data_delimiter=",",
    schema_delimiter=",",
)

print("--")


### Inspect the Graph Schema and the Graph

In [None]:
%%time
print(f"The graph has {g.num_nodes()} nodes and {g.num_edges()} edges.")

<a id='step2'></a>
## Step 2: Data Preprocessing: Feature Generation and Storing

There are two types of nodes in the `chembl29` dataset: `compound` and `target`. To run ML tasks on this dataset, we first need to add numerical features to the nodes on which the ML models will operate. In this section, we show how to add features to specific node types.

We provide the `HlsPreprocessingGraph` class to perform preprocessing and save features as a property on the graph's nodes. We use a custom user-defined featurizer that assigns a random NumPy array as a feature. We use this option in the demo because it is extremely fast.

Note: The custom featurizer function passed to `upsert_featurizer_feature` must be deterministic.

In [None]:
%%time
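# pylint: disable=C0412 (ungrouped-imports)
import zlib

import deepchem
import numpy
from katana.ai.hls.hls_preprocessing_graph import HlsPreprocessingGraph
from katana.distributed import Graph


def remote_preprocessing(graph: Graph):

    feat_obj = HlsPreprocessingGraph(graph)

    def seed_from(value):
        # Stable across hosts and runs, unlike the built-in hash() of a string
        return zlib.crc32(str(value).encode("utf-8"))

    # Pseudo-random features, seeded from the input value so that the same
    # node always receives the same feature vector (see the note above)
    def random_feature_generator(value):
        nr_features = 100
        rng = numpy.random.default_rng(seed_from(value))
        return rng.random(size=(nr_features,), dtype=numpy.float32)

    def random_smiles_feature_generator(value):
        nr_features = 2
        rng = numpy.random.default_rng(seed_from(value))
        return rng.random(size=(nr_features,), dtype=numpy.float32)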

    feat_obj.upsert_featurizer_feature(
        in_feature_name="chembl_id",
        out_feature_name="random_100",
        featurizer=random_feature_generator,
        node_types=["compound", "target"],
    )

    feat_obj.upsert_featurizer_feature(
        in_feature_name="canonical_smiles",
        out_feature_name="random_2",
        featurizer=random_smiles_feature_generator,
        node_types=["compound"],
    )


g.run(remote_preprocessing)

print("--")


#### Adding embedding features required by the sampler/dataloader

In [None]:
%%time
# pylint: disable=C0412 (ungrouped-imports), W0404 (reimport)
from katana_enterprise.ai.preprocessing.preprocessing_graph import PreprocessingGraph
from katana_enterprise.distributed import Graph


def generate_features(graph):
    feature_obj = PreprocessingGraph(graph)
    # Node labels, not used but have to be specified for later in the pipeline.
    label_data = numpy.zeros((len(graph.nodes()),)).reshape((-1, 1))
    feature_obj.upsert_node_feature(feature_name="embedding", feature_data=label_data)


g.run(generate_features)

print("--")


#### Testing to make sure the features are correctly added.

In [None]:
%%time
# pylint: disable=C0412 (ungrouped-imports), W0404 (reimport)
from katana_enterprise.ai.preprocessing.preprocessing_graph import PreprocessingGraph
from katana_enterprise.distributed import Graph


def test_featurization(graph: Graph):

    preproc_graph = PreprocessingGraph(graph)
    random_feat = preproc_graph.get_node_feature(feature_name="random_100")
    assert random_feat.shape[1] == 100

    embed_feat = preproc_graph.get_node_feature(feature_name="embedding")
    assert len(embed_feat.shape) == 1


g.run(test_featurization)

print("--")


<a id='step3'></a>
## Step 3: Data Splitting into Train-Val-Test

In Drug-Target Interaction (DTI) prediction tasks, we extract a set of `(drug, target, label)` triplets from the `rdg` as our input data. To test the generalizability of the ML models, we split the input data so that each drug (or compound) is present in only one of the `train`, `validation`, and `test` sets. This ensures that trivial predictions (for example, two very similar drugs having very similar interactions with the same target) are not counted towards model performance. In this section, we show how to split the data into train-val-test splits, how to save the split information as properties on the compound nodes, and finally how to use Cypher queries to extract ML-ready distributed data.

### Partitioning Compounds into Disjoint Sets
To achieve this split, we create a `split_label` property on the drug (compound) nodes. This property takes one of three values, `0`, `1`, or `2`, indicating that the drug belongs to the train, validation, or test set, respectively.

In [None]:
%%time
# pylint: disable=C0412 (ungrouped-imports), W0404 (reimport)
import dataclasses
from enum import IntEnum

from katana_enterprise.ai.preprocessing import RandomSplitter
from katana_enterprise.ai.preprocessing.preprocessing_graph import PreprocessingGraph
from katana_enterprise.distributed import Graph


@dataclasses.dataclass
class SplitConfig:
    train_frac: float = 0.7
    val_frac: float = 0.1
    test_frac: float = 0.2


class SplitType(IntEnum):
    """SplitType encodes training, validation, and test data separation"""

    TRAIN = 0
    VAL = 1
    TEST = 2


def split_generator(graph: Graph, split_config: SplitConfig):

    feat_obj = PreprocessingGraph(graph)

    # Generate training, validation, and test splits
    split_arr = feat_obj.generate_split_property(
        target_property_name="random_2",
        split_encoder=RandomSplitter(
            split_ratio=[split_config.train_frac, split_config.val_frac, split_config.test_frac], random_state=42
        ),
    )
    # Commit the splits to the graph
    feat_obj.upsert_node_feature(feature_name="split_label", feature_data=split_arr)


# Set up the split configuration
split_config = SplitConfig(train_frac=0.7, val_frac=0.1, test_frac=0.2)


# Updating the split through a remote function
g.run(lambda g: split_generator(g, split_config))

print("--")
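
To make the idea of a compound-level split concrete, here is a minimal, library-free sketch. It does not use `RandomSplitter`; the helper `assign_split_labels` and the toy compound and target IDs are hypothetical, and shuffling followed by slicing is only one plausible way to realize the 70/10/20 ratios. The point it illustrates is that the label is attached to the compound, so every interaction of that compound lands in the same set.

```python
import random


def assign_split_labels(compound_ids, fractions=(0.7, 0.1, 0.2), seed=42):
    """Map each compound ID to 0 (train), 1 (validation), or 2 (test)."""
    ids = sorted(set(compound_ids))
    random.Random(seed).shuffle(ids)  # seeded, so the split is reproducible
    n_train = round(fractions[0] * len(ids))
    n_val = round(fractions[1] * len(ids))
    labels = {}
    for position, compound_id in enumerate(ids):
        if position < n_train:
            labels[compound_id] = 0
        elif position < n_train + n_val:
            labels[compound_id] = 1
        else:
            labels[compound_id] = 2
    return labels


# Toy data: 20 made-up compounds, each interacting with two made-up targets.
compounds = [f"C{i}" for i in range(20)]
interactions = [(c, t) for c in compounds for t in ("T1", "T2")]

split_labels = assign_split_labels(compounds)

# Each interaction inherits the split of its compound.
edges_by_split = {0: [], 1: [], 2: []}
for compound_id, target_id in interactions:
    edges_by_split[split_labels[compound_id]].append((compound_id, target_id))

split_sizes = {split: len(edges) for split, edges in edges_by_split.items()}
print(split_sizes)
```

Because the split is decided per compound, the three edge sets share targets but never share a compound.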


#### Testing the Split Generation Method

In [None]:
%%time
import numpy


def split_count_query(split_type: SplitType):
    """Build a query that counts the compounds assigned to one split."""
    return f"""MATCH (a:compound)
    WHERE a.split_label = {split_type}
    RETURN count(a) as count
    """


num_compounds = g.query("""MATCH (a:compound) RETURN count(a) as count""")["count"][0]
train_count = g.query(split_count_query(SplitType.TRAIN))["count"][0]
val_count = g.query(split_count_query(SplitType.VAL))["count"][0]
test_count = g.query(split_count_query(SplitType.TEST))["count"][0]


assert train_count + val_count + test_count == num_compounds
assert numpy.isclose(split_config.train_frac, train_count / num_compounds)
assert numpy.isclose(split_config.val_frac, val_count / num_compounds)
assert numpy.isclose(split_config.test_frac, test_count / num_compounds)

print("--")


<a id='step4'></a> 
## Step 4: Setting up Components of AI Training Pipeline

<a id='query'></a>
### Step 4.1: Query to Extract ML-ready Data with Train-Val-Test Split

Once we have saved the split labels on the compound nodes, we can extract an ML-ready dataset of `(compound, target, label)` triplets using a Cypher query. We execute this query on the distributed graph, which returns a distributed table of triplets. We apply a `threshold` to the `pchembl_value` to create a binary classification label on the edges.

In [None]:
%%time
def build_query_str(split_type: SplitType, num_samples, threshold=5):
    query_str = f"""MATCH (a:compound)-[r]->(b:target)
    WHERE r.pchembl_value is not NULL and a.split_label = {split_type}
    WITH a, b, r.pchembl_value > {threshold} as label
    RETURN a as src, b as dst, label
    LIMIT {num_samples}
    """

    return query_str


g.query(build_query_str(SplitType.TRAIN, 20000))

print("--")
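
The labelling rule inside the query can be mimicked on plain Python data. The sketch below is not how the query engine evaluates Cypher; the rows and the helper `label_interactions` are made up for illustration. It shows that edges without a `pchembl_value` are dropped, and that the remaining edges are labelled `True` only when the value is strictly greater than the threshold.

```python
def label_interactions(rows, threshold=5):
    """Drop rows with a missing pchembl_value and label the rest as value > threshold."""
    labelled = []
    for src, dst, pchembl_value in rows:
        if pchembl_value is None:
            continue  # mirrors the `is not NULL` filter in the query
        labelled.append((src, dst, pchembl_value > threshold))
    return labelled


# Hypothetical (compound, target, pchembl_value) rows
rows = [
    ("C1", "T1", 7.2),
    ("C1", "T2", 4.1),
    ("C2", "T1", None),
    ("C3", "T3", 5.0),
]

labelled_rows = label_interactions(rows)
print(labelled_rows)
```

Note that a value exactly equal to the threshold is labelled negative, since the comparison is strict.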


<a id='trainer'></a>
### Step 4.2: Training Abstractions and Helper Functions


#### A Utility Function to Share the Validation Data Across All Hosts

This function is only required when we are interested in `classification` tasks and the validation metric is `non-decomposable`, such as `roc-auc` or `pr-auc`. In this scenario, the validation metrics computed on each host cannot be combined into a single score. There are two possible approaches to handle this situation.
1. All hosts send their predictions on the validation data to a single source host after each epoch, and the validation metric is computed on the source host. This adds communication after every epoch.
2. All hosts share the same validation data. This adds a one-time communication cost before training starts, but there is no additional communication overhead during training.

In this notebook, we adopt the second approach.
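
To see why a ranking metric cannot simply be averaged across hosts, consider the small, self-contained sketch below. The two "hosts", their labels, and their scores are invented, and `average_precision` is a simplified pure-Python stand-in that assumes all scores are distinct; it is not the scikit-learn implementation.

```python
def average_precision(labels, scores):
    """Average precision for binary labels, assuming all scores are distinct."""
    ranked = sorted(zip(scores, labels), reverse=True)
    num_positive = sum(labels)
    hits = 0
    total = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label:
            hits += 1
            total += hits / rank  # precision at the rank of each positive
    return total / num_positive


# Hypothetical validation shards held by two hosts
host_a = {"labels": [1, 0], "scores": [0.9, 0.8]}
host_b = {"labels": [1, 0], "scores": [0.3, 0.2]}

per_host = [average_precision(h["labels"], h["scores"]) for h in (host_a, host_b)]
mean_of_host_scores = sum(per_host) / len(per_host)

pooled_score = average_precision(
    host_a["labels"] + host_b["labels"],
    host_a["scores"] + host_b["scores"],
)

print(mean_of_host_scores, pooled_score)
```

Each host ranks its own positive first, so both per-host scores are perfect, yet the pooled ranking places a negative from the first host above the positive from the second. The pooled score is therefore lower than the mean of the per-host scores.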

In [None]:
%%time
# pylint: disable=C0412 (ungrouped-imports), W0404 (reimport)
import pandas
import torch
from katana_enterprise.distributed import Graph
from katana_enterprise.distributed.pytorch import init_workers


def broadcast_val_data(val_df):
    """Gather the validation data from all workers so that every worker holds the combined set"""
    if not torch.distributed.is_initialized():
        init_workers()

    num_hosts = torch.distributed.get_world_size()
    gathered_val_data = [None for _ in range(num_hosts)]
    torch.distributed.all_gather_object(gathered_val_data, val_df)
    combined_val_df = pandas.concat(gathered_val_data)
    return combined_val_df


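def test_broadcast_val_data(graph: Graph):
    num_samples = 100
    query_str = build_query_str(SplitType.TEST, num_samples=num_samples)
    df = graph.query(query_str, balance_output=True).to_pandas()

    if not torch.distributed.is_initialized():
        init_workers()
    num_hosts = torch.distributed.get_world_size()
    # balance_output is expected to give each host an even share (within one row)
    assert abs(len(df) - num_samples / num_hosts) < 1

    df = broadcast_val_data(df)
    assert len(df) == num_samples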


g.run(test_broadcast_val_data)

print("--")



<a id='step5'></a>
## Step 5: Putting the Pipeline Together

We are now ready to put together an end-to-end pipeline for the DTI task.

### Setting up Model Hyperparameters and Training Hyperparameters

In [None]:
# pylint: disable=W0404 (reimport), E0602(undefined-variable)
# pylint: disable=too-many-instance-attributes
import dataclasses
from typing import Callable

import torch


@dataclasses.dataclass
class TrainingHyperParams:
    num_train_samples: int = 10000
    num_val_samples: int = 1000
    epochs: int = 5
    patience: int = 50
    optimizer: Callable = torch.optim.Adam
    lr: float = 0.01
    weight_decay: float = 0.001
    scheduler: Callable = torch.optim.lr_scheduler.StepLR
    step_size: int = 5


# Setting up training hyper parameters
training_hp = TrainingHyperParams()
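
The defaults above interact in a way that is easy to miss. The following sketch reproduces, in plain Python, the step decay rule documented for `torch.optim.lr_scheduler.StepLR` (multiply the learning rate by `gamma`, 0.1 by default, every `step_size` epochs). It assumes the trainer steps the scheduler once per epoch, which this notebook does not state; `step_lr` is a hypothetical helper, not part of PyTorch or Katana.

```python
def step_lr(initial_lr, step_size, epoch, gamma=0.1):
    """Learning rate in effect during `epoch` (0-based) under a step decay schedule."""
    return initial_lr * gamma ** (epoch // step_size)


# Values taken from TrainingHyperParams: lr=0.01, step_size=5, epochs=5
lr_per_epoch = [step_lr(0.01, step_size=5, epoch=e) for e in range(5)]
print(lr_per_epoch)

# The first decay would only apply to a sixth epoch
lr_epoch_5 = step_lr(0.01, step_size=5, epoch=5)
print(lr_epoch_5)
```

Under these assumptions, with `epochs=5` and `step_size=5` the learning rate stays at its initial value for the whole run. Similarly, `patience=50` exceeds the number of epochs, so patience-based early stopping is unlikely to trigger.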

### The Training Pipeline

In [None]:
%%time
# pylint: disable=C0412 (ungrouped-imports), W0404 (reimport)
import numpy
import torch
from katana_enterprise.ai import data, loss, model, train
from katana_enterprise.ai.torch import ReduceMethod
from katana_enterprise.ai.train import DistTrainer
from katana_enterprise.distributed import Graph
from katana_enterprise.distributed.pytorch import init_workers
from sklearn.metrics import average_precision_score, mean_squared_error


def remote_pipeline(graph, training_params):
    if not torch.distributed.is_initialized():
        init_workers()

    # Initialize random edge samples from the graph using queries
    # -----------------------------------------------------------

    train_data = graph.query(build_query_str(SplitType.TRAIN, training_params.num_train_samples), balance_output=True)
    train_src_nodes, train_dst_nodes, train_labels = (
        list(train_data["src"]),
        list(train_data["dst"]),
        list(train_data["label"]),
    )
    train_seeds = [(u, v, label.as_py()) for u, v, label in zip(train_src_nodes, train_dst_nodes, train_labels)]

    val_data = graph.query(build_query_str(SplitType.VAL, training_params.num_val_samples), balance_output=True)
    # val_df = broadcast_val_data(val_data)

    val_src_nodes, val_dst_nodes, val_labels = list(val_data["src"]), list(val_data["dst"]), list(val_data["label"])
    val_seeds = [(u, v, label.as_py()) for u, v, label in zip(val_src_nodes, val_dst_nodes, val_labels)]

    # train_samples = list(train_df.itertuples(index=False))
    # val_samples = list(val_df.itertuples(index=False))

    # Initialize the Katana multiminibatch sampler.
    # --------------------------------------------

    fan_in = [5, 10]
    batches_at_once = 10
    sampler_config = data.SampledSubgraphConfig(
        layer_fan=fan_in,
        max_minibatches=batches_at_once,
        property_batch_size=batches_at_once,
        feat_prop_name="random_100",
        multilayer_export=True,  # multi-layer export needs to be true to have a subgraph for each layer.
        sample_with_replacement=True,
        label_prop_name="random_100",
    )

    # Initialize the Katana Dataloader.
    # ---------------------------------

    train_sampler = data.LpSubgraphSampler(graph, sampler_config)
    train_dataloader = data.LpDataLoader(
        train_seeds, train_sampler, local_batch_size=50, shuffle=True, balance_seeds=True
    )

    val_sampler = data.LpSubgraphSampler(graph, sampler_config)
    val_dataloader = data.LpDataLoader(val_seeds, val_sampler, local_batch_size=50, shuffle=False, balance_seeds=True)

    # Define a DGL model wrapped in Katana
    # ------------------------------------

    nr_features = 100
    dgl_model = model.LpGcn(in_dim=nr_features, out_dim=2, gnn_hidden_dims=[256, 128], mlp_hidden_dims=[128],)

    # Define the loss function, the optimizer, the validation metric, and the distributed tracker
    # -------------------------------------------------------------------------------------------

    loss_fn = loss.CrossEntropyLoss([2])
    optimizer = training_params.optimizer(
        dgl_model.parameters(), lr=training_params.lr, weight_decay=training_params.weight_decay
    )
    scheduler = training_params.scheduler(optimizer, step_size=training_params.step_size)

    # Define the validation metric
    # y is a vector of length n and pred is an n-by-2 matrix of per-class scores.
    # Column 1 is treated as the log-probability of the positive class (an assumption about
    # the model output), so exponentiating it gives the score used to rank the edges.
    def pr_auc(y, pred):
        assert len(pred[0]) == 2
        edge_pred = torch.exp(pred[:, 1])
        return average_precision_score(y, edge_pred)

    # Define the tracker
    tracker = train.DistTracker(callback_fn=print, src_rank=0)

    # Initialize the distributed trainer
    # ---------------------------------

    trainer = DistTrainer(
        model=dgl_model,
        train_loss_fn=loss_fn,
        validation_metric_fn=pr_auc,
        validation_reduce_method=ReduceMethod.MEAN,
        train_loader=train_dataloader,
        validation_loader=val_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=training_params.epochs,
        patience=training_params.patience,
        maximization=True,
        tracker=tracker,
    )

    # Start the Model training
    # ------------------------

    trained_model, loss_val = trainer.train()
    print(f"Number of trained parameters: {sum(p.numel() for p in trained_model.parameters() if p.requires_grad)}")
    print(f"Trainer loss:{loss_val}")


#     # Time for inference
#     inference_sampler = data.PyGNodeSubgraphSampler(graph, sampler_config)
#     nodes_to_infer = graph.master_nodes()
#     inference_loader = data.NodeDataLoader(
#         inference_sampler, 1000, node_ids=nodes_to_infer, drop_last=False, balance_seeds=False
#     )
#     print("Starting the eval phase...")
#     trained_model.eval()
#     embedding = (
#         torch.vstack([trained_model.encode(data) for data in inference_loader if len(data.x) > 0]).detach().numpy()
#     )
#     print("Generated the node embeddings...")
#     print(embedding)


g.run(lambda g: remote_pipeline(g, training_hp))
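
The `pr_auc` metric in the pipeline exponentiates the second column of the predictions before scoring. The sketch below illustrates what that step does under the assumption, not confirmed by this notebook, that the model emits log-softmax outputs. The raw scores are invented and `log_softmax` is a small pure-Python helper written for this example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of raw scores."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_norm for x in logits]


# Hypothetical raw scores for three candidate edges: [no interaction, interaction]
raw_scores = [[2.0, 0.0], [0.0, 0.0], [-1.0, 3.0]]
log_probs = [log_softmax(row) for row in raw_scores]

# Same step as `torch.exp(pred[:, 1])`: probability of the positive class
positive_scores = [math.exp(row[1]) for row in log_probs]
print([round(p, 3) for p in positive_scores])
```

If the outputs really are log-probabilities, the exponentiated values lie in (0, 1) and the two class probabilities of each row sum to one. Since `exp` is monotonic, the ranking of edges, and hence the average precision, would be the same with or without it.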