# Run Inference with FuncX
This notebook evaluates the performance of a large-scale inference run that predicts molecular toxicity with our graph-conv model.
The inferences are performed by sending batches of tasks to Theta via FuncX at various batch sizes; the batch size is the main knob we can tune for the inference.

In [1]:
from funcx.sdk.client import FuncXClient
from datetime import datetime
from time import sleep
from rdkit import Chem
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import os



## Configuration
Batch sizes to benchmark and the CSV file that results are appended to.

In [2]:
# Nominal number of SMILES per task: powers of two from 2**7 (128) to 2**13 (8192)
batch_sizes = [2 ** exponent for exponent in range(7, 14)]
# CSV file that task results and timings are appended to
output_file = 'funcx_perf_test.csv'

## Prepare the FuncX Client
This is what we'll be using to connect to Theta for sending/receiving tasks.

In [3]:
# UUID of the funcX endpoint running on Theta
theta_ep = 'd3a23590-3282-429a-8bce-e0ca0f4177f3'
fxc = FuncXClient()

# UUID of the registered inference function, stored next to this notebook
with open('func_uuid.json') as uuid_file:
    func_id = json.load(uuid_file)
print(f'Running inference on {func_id}')

Running inference on 627d9b72-8f4f-4020-9c76-696596e6eac8


## Send Inference Requests
Send out inference requests for each dataset.

Parse the data and make sure the SMILES are valid

In [4]:
# One molecule per line; there is no header row in the file
smiles_path = os.path.join('..', 'databases', 'drugbank', 'smiles.txt')
drugbank = pd.read_csv(smiles_path, header=None)
print(f'Loaded {len(drugbank)} molecules')

Loaded 9678 molecules


In [5]:
# Give the single unnamed column (0) a descriptive name; reassign rather than mutate in place
drugbank = drugbank.rename(columns={0: 'smiles'})

In [6]:
# Drop the last 8 characters of every entry.
# NOTE(review): assumes each line carries a fixed 8-character suffix after the SMILES
# (presumably a separator plus a DrugBank ID) -- confirm against the source file.
drugbank['smiles'] = drugbank['smiles'].map(lambda smiles: smiles[:-8])

In [7]:
# Flag SMILES that RDKit cannot turn into a molecule: Chem.MolFromSmiles returns None on failure
drugbank['invalid'] = drugbank['smiles'].map(lambda smiles: Chem.MolFromSmiles(smiles) is None)

RDKit ERROR: [07:44:10] Explicit valence for atom # 2 O, 3, is greater than permitted
RDKit ERROR: [07:44:10] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [07:44:10] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [07:44:10] Explicit valence for atom # 0 N, 4, is greater than permitted
RDKit ERROR: [07:44:10] Explicit valence for atom # 13 Cl, 5, is greater than permitted
RDKit ERROR: [07:44:10] SMILES Parse Error: syntax error while parsing: OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]
RDKit ERROR: [07:44:10] SMILES Parse Error: Failed parsing SMILES 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]' for input: 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C

In [8]:
# Keep only the rows whose SMILES RDKit could parse
drugbank = drugbank.query('not invalid')
print(f'Found {len(drugbank)} valid SMILES')

Found 9655 valid SMILES


Make the tasks: for each batch size, split the valid SMILES into chunks and submit one funcX task per chunk.

In [9]:
max_requests = 5000  # raised cap to enable faster task submission
fxc.max_requests = max_requests

In [10]:
# Submit one funcX task per chunk of SMILES, for every nominal batch size.
# np.array_split yields `n_chunks` nearly equal chunks, so actual chunk sizes can be
# slightly larger than `batch_size` (e.g. 9,655 molecules // 128 -> 75 chunks of 128-129).
# Splitting a plain ndarray (rather than the Series) gives the same lists of SMILES and
# avoids relying on pandas' Series `swapaxes`, which newer pandas deprecates.
smiles_values = drugbank['smiles'].to_numpy()
db_tasks = []
for batch_size in batch_sizes:
    # Floor division gives 0 when batch_size exceeds the dataset size, and
    # np.array_split raises ValueError for 0 sections -- always submit at least one chunk.
    n_chunks = max(1, len(smiles_values) // batch_size)
    for chunk in np.array_split(smiles_values, n_chunks):
        db_tasks.append(fxc.run(chunk.tolist(), endpoint_id=theta_ep, function_id=func_id))
        sleep(0.1)  # brief pause between consecutive submissions
print(f'Submitted {len(db_tasks)} tasks')

Submitted 146 tasks


## Save Results
As results are returned, save them to disk

In [11]:
# NOTE(review): not referenced by any visible cell; likely a leftover -- confirm before removing
columns = list(('smiles',))

In [12]:
def write_results(status, path):
    """Append the results of completed funcX tasks to a CSV file.

    Args:
        status: Mapping of task ID -> task record, as returned by
            ``fxc.get_batch_status``. Each record must hold a ``'result'``
            entry that is itself a dict with keys ``'result'`` (anything
            ``pd.DataFrame`` accepts), ``'start'`` and ``'end'`` (ISO-8601
            timestamp strings accepted by ``datetime.fromisoformat``).
        path: CSV file to append to. The header row is written only when the
            file does not exist yet.
    """
    for task_id, task_record in status.items():
        payload = task_record['result']
        data = pd.DataFrame(payload['result'])
        # Re-checked for every task: once the first task has been written the
        # file exists, so later appends must not repeat the header
        write_header = not os.path.isfile(path)

        # Record the task's wall-clock window and its duration in seconds
        start, end = payload['start'], payload['end']
        data['task_id'] = task_id
        data['runtime'] = (datetime.fromisoformat(end) - datetime.fromisoformat(start)).total_seconds()
        data['start_time'] = start
        data['end_time'] = end  # previously (wrongly) recorded the start time

        data.to_csv(path, mode='a', header=write_header, index=False)

In [14]:
# Poll funcX until every submitted task has reported back, appending each
# batch of returned results to `output_file` as it arrives.
poll_interval = 15  # seconds between status checks
remaining_results = set(db_tasks)
# Context manager closes the progress bar when polling finishes
with tqdm(total=len(db_tasks)) as pbar:
    while remaining_results:
        # NOTE(review): write_results reads each entry's result, so this assumes
        # get_batch_status only returns finished tasks -- confirm for the SDK version in use
        status = fxc.get_batch_status(list(remaining_results))

        # Write the results to disk
        write_results(status, output_file)

        # Update the set of tasks still outstanding
        remaining_results.difference_update(status.keys())
        pbar.update(len(status))

        # No need to wait once every task has been collected
        if remaining_results:
            sleep(poll_interval)


  0%|          | 0/146 [00:00<?, ?it/s][A
  0%|          | 0/146 [00:00<?, ?it/s][A
  0%|          | 0/146 [00:15<?, ?it/s][A
  0%|          | 0/146 [00:30<?, ?it/s][A
  0%|          | 0/146 [00:45<?, ?it/s][A
  0%|          | 0/146 [01:01<?, ?it/s][A
  0%|          | 0/146 [01:16<?, ?it/s][A
  2%|▏         | 3/146 [01:31<12:10,  5.11s/it][A
 32%|███▏      | 46/146 [01:47<06:08,  3.69s/it][A
 58%|█████▊    | 84/146 [02:03<02:47,  2.71s/it][A
 74%|███████▍  | 108/146 [02:19<01:19,  2.09s/it][A
 84%|████████▍ | 123/146 [02:35<00:40,  1.78s/it][A
 91%|█████████ | 133/146 [02:50<00:22,  1.71s/it][A
 95%|█████████▌| 139/146 [03:05<00:13,  1.96s/it][A
 97%|█████████▋| 142/146 [03:21<00:11,  2.91s/it][A
 99%|█████████▊| 144/146 [03:36<00:08,  4.35s/it][A
 99%|█████████▉| 145/146 [03:52<00:07,  7.67s/it][A
100%|██████████| 146/146 [04:07<00:00, 10.00s/it][A