# Imports


In [6]:
# test_script.py

import requests
from schemas import InferenceInput, InferenceConfig
import time

# Define the URL of the FastAPI server
api_url = "http://127.0.0.1:8000/inference/"
progress_url = "http://127.0.0.1:8000/inference/progress/"

# Example input data
inference_input = InferenceInput(
    protein_path="examples/6ahs_protein_processed.pdb",
    ligand_description="examples/6ahs_ligand.sdf"
)

inference_config = InferenceConfig(
    actual_steps=19,
    ckpt="best_ema_inference_epoch_model.pt",
    confidence_ckpt="best_model_epoch75.pt",
    confidence_model_dir="./workdir/v1.1/confidence_model",
    different_schedules=False,
    inf_sched_alpha=1,
    inf_sched_beta=1,
    inference_steps=20,
    initial_noise_std_proportion=1.4601642460337794,
    limit_failures=5,
    model_dir="./workdir/v1.1/score_model",
    no_final_step_noise=True,
    no_model=False,
    no_random=False,
    no_random_pocket=False,
    ode=False,
    old_filtering_model=True,
    old_score_model=False,
    resample_rdkit=False,
    samples_per_complex=10,
    sigma_schedule="expbeta",
    temp_psi_rot=0.9022615585677628,
    temp_psi_tor=0.5946212391366862,
    temp_psi_tr=0.727287304570729,
    temp_sampling_rot=2.06391612594481,
    temp_sampling_tor=7.044261621607846,
    temp_sampling_tr=1.170050527854316,
    temp_sigma_data_rot=0.7464326999906034,
    temp_sigma_data_tor=0.6943254174849822,
    temp_sigma_data_tr=0.9299802531572672,
    loglevel="WARNING",
    choose_residue=False,
    out_dir="results/user_inference",
    save_visualisation=False,
    batch_size=10
)

def test_inference(api_url, progress_url, inference_input, inference_config):
    response = requests.post(api_url, json={"input": inference_input.dict(), "config": inference_config.dict()})
    if response.status_code == 200:
        print("Inference initiated successfully")
        response_data = response.json()
        task_id = response_data["task_id"]
        print("Task ID:", task_id)
        
        # Check progress periodically
        while True:
            progress_response = requests.get(f"{progress_url}{task_id}")
            if progress_response.status_code == 200:
                progress_data = progress_response.json()
                print("Progress:", progress_data["progress"])
                if "Post-processing Results" in progress_data["progress"]:
                    break
            else:
                print("Failed to get progress")
            time.sleep(5)
            
    else:
        print("Failed to start inference")
        print("Status Code:", response.status_code)
        print("Response:", response.text)

if __name__ == "__main__":
    test_inference(api_url, progress_url, inference_input, inference_config)


Inference initiated successfully
Task ID: 5d2dbacb-b806-46b8-8643-ff218e07c3c6
Progress: No such task
Progress: No such task
Progress: No such task


KeyboardInterrupt: 