In [5]:
# test_script.py

import requests
from schemas import InferenceInput, InferenceConfig
import time

# Endpoints of the locally running FastAPI inference server.
api_url = "http://127.0.0.1:8000/inference/"
progress_url = "http://127.0.0.1:8000/inference/progress/"

# Example inputs: a processed protein PDB file and a ligand SDF file.
inference_input = InferenceInput(
    ligand_description="examples/6ahs_ligand.sdf",
    protein_path="examples/6ahs_protein_processed.pdb",
)

# Settings sent to the server together with the input. The keyword arguments are
# grouped by topic purely for readability; every value is copied verbatim.
inference_config = InferenceConfig(
    # checkpoints and model directories
    model_dir="./workdir/v1.1/score_model",
    ckpt="best_ema_inference_epoch_model.pt",
    old_score_model=False,
    confidence_model_dir="./workdir/v1.1/confidence_model",
    confidence_ckpt="best_model_epoch75.pt",
    old_filtering_model=True,
    no_model=False,
    # step counts and schedule
    inference_steps=20,
    actual_steps=19,
    sigma_schedule="expbeta",
    inf_sched_alpha=1,
    inf_sched_beta=1,
    different_schedules=False,
    initial_noise_std_proportion=1.4601642460337794,
    no_final_step_noise=True,
    ode=False,
    # sample counts, batching and randomisation switches
    samples_per_complex=10,
    batch_size=10,
    limit_failures=5,
    no_random=False,
    no_random_pocket=False,
    resample_rdkit=False,
    choose_residue=False,
    # temp_* parameters
    temp_sampling_tr=1.170050527854316,
    temp_sampling_rot=2.06391612594481,
    temp_sampling_tor=7.044261621607846,
    temp_psi_tr=0.727287304570729,
    temp_psi_rot=0.9022615585677628,
    temp_psi_tor=0.5946212391366862,
    temp_sigma_data_tr=0.9299802531572672,
    temp_sigma_data_rot=0.7464326999906034,
    temp_sigma_data_tor=0.6943254174849822,
    # output and logging
    out_dir="results/user_inference",
    save_visualisation=False,
    loglevel="WARNING",
)

def test_inference(api_url, progress_url, inference_input, inference_config):
    response = requests.post(api_url, json={"input": inference_input.dict(), "config": inference_config.dict()})
    if response.status_code == 200:
        print("Inference initiated successfully")
        response_data = response.json()
        task_id = response_data["task_id"]
        print("Task ID:", task_id)
        
        # Check progress periodically
        while True:
            progress_response = requests.get(f"{progress_url}{task_id}")
            if progress_response.status_code == 200:
                progress_data = progress_response.json()
                print("Progress:", progress_data["progress"])
                if "Post-processing Results" in progress_data["progress"]:
                    break
            else:
                print("Failed to get progress")
            time.sleep(5)
            
    else:
        print("Failed to start inference")
        print("Status Code:", response.status_code)
        print("Response:", response.text)

if __name__ == "__main__":
    test_inference(api_url, progress_url, inference_input, inference_config)

Failed to start inference
Status Code: 422
Response Body: {'detail': [{'type': 'missing', 'loc': ['body', 'config'], 'msg': 'Field required', 'input': None, 'url': 'https://errors.pydantic.dev/2.5/v/missing'}]}


# Imports, configuration, and helpers for the ZIP-upload workflow


In [2]:

import time
import os
# Imported here so this cell does not depend on an earlier cell having imported it.
import requests
from schemas import InferenceInput, InferenceConfig

# API URLs
api_url = "http://127.0.0.1:8000/inference/"
progress_url = "http://127.0.0.1:8000/inference/progress/"
download_url = "http://127.0.0.1:8000/inference/download/"

# Path to the ZIP file uploaded for inference. Set DIFFDOCK_TEST_ZIP to point at a
# different archive instead of editing this machine-specific absolute path.
zip_file_path = os.environ.get("DIFFDOCK_TEST_ZIP", "/root/projects/DiffDock/data/1a0q_test_data.zip")

def upload_zip_and_start_inference(api_url, zip_file_path, request_timeout=120):
    """Upload a ZIP archive to ``api_url`` and return the started task's id.

    The archive is sent as the multipart form field ``file`` (filename
    ``archive.zip``, content type ``application/zip``). It is opened in a
    ``with`` block, so the handle is closed once the upload finishes; the
    previous version left it open.

    NOTE(review): the recorded output of this cell is a 422 reporting that
    body field ``config`` is required. The endpoint appears to expect a config
    alongside the upload; confirm the server contract.

    Args:
        api_url: Endpoint that accepts the upload and starts inference.
        zip_file_path: Local path of the ZIP archive to upload.
        request_timeout: Timeout in seconds passed to ``requests.post``.

    Returns:
        The ``task_id`` from the JSON response on HTTP 200, otherwise None.
    """
    try:
        with open(zip_file_path, "rb") as zip_file:
            files = {"file": ("archive.zip", zip_file, "application/zip")}
            response = requests.post(api_url, files=files, timeout=request_timeout)
    except requests.exceptions.RequestException as exc:
        print("Failed to start inference:", exc)
        return None

    if response.status_code == 200:
        print("Inference initiated successfully")
        return response.json()["task_id"]

    print("Failed to start inference")
    print("Status Code:", response.status_code)
    print("Response:", response.text)
    return None

def monitor_progress(progress_url, task_id, poll_interval=5, max_wait=3600,
                     max_missing=3, request_timeout=30):
    """Poll ``f"{progress_url}{task_id}"`` until the server reports "Completed".

    Sleeps ``poll_interval`` seconds after every unfinished poll. Previously
    the sleep sat only in the failure branch, so while a job was running the
    loop re-requested the endpoint with no delay at all.

    Args:
        progress_url: Progress endpoint prefix; the task id is appended to it.
        task_id: Identifier returned when the job was started.
        poll_interval: Seconds to sleep between polls.
        max_wait: Overall polling budget in seconds, or ``None`` for no limit.
        max_missing: Consecutive "No such task" replies tolerated before giving up.
        request_timeout: Timeout in seconds passed to each ``requests.get``.

    Returns:
        True once the progress equals "Completed". False if the server
        answers "No such task" ``max_missing`` times in a row, if
        ``max_wait`` elapses, or if the server cannot be reached.
    """
    deadline = None if max_wait is None else time.monotonic() + max_wait
    missing = 0
    while True:
        try:
            progress_response = requests.get(f"{progress_url}{task_id}", timeout=request_timeout)
        except requests.exceptions.RequestException as exc:
            # A dead server (connection refused) ends the wait instead of raising.
            print("Failed to reach progress endpoint:", exc)
            return False

        if progress_response.status_code == 200:
            progress = progress_response.json()["progress"]
            print("Progress:", progress)
            if progress == "Completed":
                return True
            if progress == "No such task":
                missing += 1
                if missing >= max_missing:
                    print(f"Server does not know task {task_id} after {missing} polls; giving up")
                    return False
            else:
                missing = 0
        else:
            print("Failed to get progress")

        if deadline is not None and time.monotonic() + poll_interval > deadline:
            print(f"Timed out after {max_wait} s waiting for task {task_id}")
            return False
        time.sleep(poll_interval)
def download_results(download_url, task_id, request_timeout=60):
    """Download the results archive for ``task_id`` into the current directory.

    Fetches ``f"{download_url}{task_id}"`` and, on HTTP 200, writes the
    response body to ``{task_id}_output.zip``.

    Args:
        download_url: Download endpoint prefix; the task id is appended to it.
        task_id: Identifier returned when the job was started.
        request_timeout: Timeout in seconds passed to ``requests.get``, so an
            unresponsive server cannot block the cell indefinitely.

    Returns:
        The path of the written file, or None if the download failed.
    """
    output_path = f"{task_id}_output.zip"
    try:
        download_response = requests.get(f"{download_url}{task_id}", timeout=request_timeout)
    except requests.exceptions.RequestException as exc:
        print("Failed to download results:", exc)
        return None

    if download_response.status_code != 200:
        print("Failed to download results")
        print("Status Code:", download_response.status_code)
        print("Response:", download_response.text)
        return None

    with open(output_path, "wb") as f:
        f.write(download_response.content)
    print(f"Results downloaded and saved as {output_path}")
    return output_path


if __name__ == "__main__":
    task_id = upload_zip_and_start_inference(api_url, zip_file_path)
    if task_id:
        completed = monitor_progress(progress_url, task_id)
        if completed:
            download_results(download_url, task_id)


Failed to start inference
Status Code: 422
Response: {"detail":[{"type":"missing","loc":["body","config"],"msg":"Field required","input":null,"url":"https://errors.pydantic.dev/2.5/v/missing"}]}


In [5]:

import requests
from schemas import InferenceInput, InferenceConfig
import time

# Endpoints of the locally running FastAPI inference server (start, poll, download).
api_url = "http://127.0.0.1:8000/inference/"
progress_url = "http://127.0.0.1:8000/inference/progress/"
download_url = "http://127.0.0.1:8000/inference/download/"

# Example inputs: a processed protein PDB file and a ligand SDF file.
inference_input = InferenceInput(
    ligand_description="examples/1a46_ligand.sdf",
    protein_path="examples/1a46_protein_processed.pdb",
)

# Settings sent to the server together with the input. The keyword arguments are
# grouped by topic purely for readability; every value is copied verbatim.
inference_config = InferenceConfig(
    # checkpoints and model directories
    model_dir="./workdir/v1.1/score_model",
    ckpt="best_ema_inference_epoch_model.pt",
    old_score_model=False,
    confidence_model_dir="./workdir/v1.1/confidence_model",
    confidence_ckpt="best_model_epoch75.pt",
    old_filtering_model=True,
    no_model=False,
    # step counts and schedule
    inference_steps=20,
    actual_steps=19,
    sigma_schedule="expbeta",
    inf_sched_alpha=1,
    inf_sched_beta=1,
    different_schedules=False,
    initial_noise_std_proportion=1.4601642460337794,
    no_final_step_noise=True,
    ode=False,
    # sample counts, batching and randomisation switches
    samples_per_complex=10,
    batch_size=10,
    limit_failures=5,
    no_random=False,
    no_random_pocket=False,
    resample_rdkit=False,
    choose_residue=False,
    # temp_* parameters
    temp_sampling_tr=1.170050527854316,
    temp_sampling_rot=2.06391612594481,
    temp_sampling_tor=7.044261621607846,
    temp_psi_tr=0.727287304570729,
    temp_psi_rot=0.9022615585677628,
    temp_psi_tor=0.5946212391366862,
    temp_sigma_data_tr=0.9299802531572672,
    temp_sigma_data_rot=0.7464326999906034,
    temp_sigma_data_tor=0.6943254174849822,
    # output and logging
    out_dir="results/user_inference",
    save_visualisation=False,
    loglevel="WARNING",
)

def test_inference(api_url, progress_url, download_url, inference_input, inference_config):
    response = requests.post(api_url, json={"input": inference_input.dict(), "config": inference_config.dict()})
    if response.status_code == 200:
        print("Inference initiated successfully")
        response_data = response.json()
        task_id = response_data["task_id"]
        print("Task ID:", task_id)
        
        # Check progress periodically
        while True:
            progress_response = requests.get(f"{progress_url}{task_id}")
            if progress_response.status_code == 200:
                progress_data = progress_response.json()
                print("Progress:", progress_data["progress"])
                if progress_data["progress"] == "Completed":
                    break
            else:
                print("Failed to get progress")
            time.sleep(5)
            
        # Download results
        download_response = requests.get(f"{download_url}{task_id}")
        if download_response.status_code == 200:
            with open(f"{task_id}_output.zip", "wb") as f:
                f.write(download_response.content)
            print(f"Results downloaded and saved as {task_id}_output.zip")
        else:
            print("Failed to download results")
            print("Status Code:", download_response.status_code)
            print("Response:", download_response.text)
            
    else:
        print("Failed to start inference")
        print("Status Code:", response.status_code)
        print("Response:", response.text)

if __name__ == "__main__":
    test_inference(api_url, progress_url, download_url, inference_input, inference_config)


Inference initiated successfully
Task ID: 03f9824f-5945-4339-8b74-865d35b870c0
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
Progress: No such task
P

ConnectionError: HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /inference/progress/03f9824f-5945-4339-8b74-865d35b870c0 (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f2c9047a0d0>: Failed to establish a new connection: [Errno 111] Connection refused'))