# Run Inference with FuncX
Large-scale inference with FuncX on Theta

In [1]:
from funcx.sdk.client import FuncXClient
from datetime import datetime
from time import sleep
from rdkit import Chem
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import os



Configuration

In [2]:
batch_size = 4096
output_file = 'ena+db_tox21_screening.csv'

## Prepare the FuncX Client
This is what we'll be using to connect to Theta for sending/recieving tasks

In [3]:
fxc = FuncXClient()
theta_ep = 'd3a23590-3282-429a-8bce-e0ca0f4177f3'
with open('func_uuid.json') as fp:
    func_id = json.load(fp)
print(f'Running inference on {func_id}')

Running inference on 0549fe6d-7eae-4949-aee6-468658f2ce93


## Send Inference Requests
Send out inferenece requests for each dataset

Parse the data and make sure the SMILES are valid

In [4]:
database = pd.read_csv(os.path.join('..', '..', 'databases', 'ena+db.can'), header=None, delim_whitespace=True)
print(f'Loaded {len(database)} molecules')

Loaded 310782 molecules


In [5]:
database.rename(columns={0: 'smiles'}, inplace=True)

In [6]:
database.drop_duplicates('smiles', inplace=True)
print(f'Found {len(database)} unique SMILES')

Found 310650 unique SMILES


In [7]:
database['invalid'] = database['smiles'].apply(lambda x: Chem.MolFromSmiles(x) is None)

RDKit ERROR: [07:38:47] Explicit valence for atom # 3 O, 3, is greater than permitted
RDKit ERROR: [07:38:47] Explicit valence for atom # 12 Cl, 5, is greater than permitted
RDKit ERROR: [07:38:47] Explicit valence for atom # 19 N, 5, is greater than permitted
RDKit ERROR: [07:38:47] Explicit valence for atom # 16 Ga, 6, is greater than permitted
RDKit ERROR: [07:38:47] Can't kekulize mol.  Unkekulized atoms: 2 3 5 7 10 12 13 19 21 22 28 29
RDKit ERROR: 
RDKit ERROR: [07:38:47] Explicit valence for atom # 0 O, 3, is greater than permitted
RDKit ERROR: [07:38:47] Explicit valence for atom # 3 Be, 4, is greater than permitted
RDKit ERROR: [07:38:47] Can't kekulize mol.  Unkekulized atoms: 18 19 21
RDKit ERROR: 
RDKit ERROR: [07:38:48] Explicit valence for atom # 20 Be, 3, is greater than permitted
RDKit ERROR: [07:38:48] Explicit valence for atom # 24 N, 4, is greater than permitted
RDKit ERROR: [07:38:48] Explicit valence for atom # 1 Cl, 4, is greater than permitted
RDKit ERROR: [07:38

In [8]:
database.query('not invalid', inplace=True)
print(f'Found {len(database)} valid SMILES')

Found 310635 valid SMILES


Make the tasks

In [9]:
fxc.max_requests = 5000  # Enable faster task submission

In [10]:
db_tasks = []
for chunk in tqdm(np.array_split(database['smiles'], len(database) // batch_size)):
    db_tasks.append(fxc.run(chunk.tolist(), endpoint_id=theta_ep, function_id=func_id))
    sleep(0.1)
print(f'Submitted {len(db_tasks)} tasks')

100%|██████████| 75/75 [00:25<00:00,  2.94it/s]

Submitted 75 tasks





## Save Results
As results are returned, save them to disk

In [11]:
columns = ['smiles']

In [12]:
def write_results(status, path): 
    # Loop over all results in the status message
    for key, result in status.items():
        result = result['result']
        # Parse the data
        data = pd.DataFrame(result['result'])
        exists = os.path.isfile(path)
        
        # Get the runtime and save it
        data['task_id'] = key
        data['runtime'] = (datetime.fromisoformat(result['end']) - datetime.fromisoformat(result['start'])).total_seconds()
        data['start_time'] = result['start']
        data['end_time'] = result['end']
        data['host'] = result['hostname']
        data['core_count'] = result['core_count']
        
        # Save the result to disk
        data.to_csv(path, mode='a', header=not exists, index=False)

In [13]:
if os.path.isfile(output_file):
    print(f'Removing old file: {output_file}')
    os.unlink(output_file)

In [14]:
remaining_results = set(db_tasks)
pbar = tqdm(total=len(db_tasks))
while len(remaining_results) > 0:
    # Get the status of the current tasks
    status = fxc.get_batch_status(list(remaining_results))
    
    # Write the results to disk
    write_results(status, output_file)
    
    # Update the list of results that are remaining
    remaining_results.difference_update(status.keys())
    pbar.update(len(status))
    sleep(120)

100%|██████████| 75/75 [08:22<00:00,  4.52s/it]