In [10]:
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem

from synthemol.constants import (
    BUILDING_BLOCKS_PATH, FINGERPRINT_TYPES, MODEL_TYPES, OPTIMIZATION_TYPES, REAL_BUILDING_BLOCK_ID_COL, SCORE_COL, SMILES_COL
)
from synthemol.generate import Generator, Node, generate
from synthemol.generate.utils import create_model_scoring_fn, save_generated_molecules
from synthemol.models import chemprop_load, chemprop_load_scaler, chemprop_predict_on_molecule_ensemble
from synthemol.reactions import Reaction, QueryMol, set_all_building_blocks, load_and_set_allowed_reaction_building_blocks
from synthemol.utils import convert_to_mol, random_choice

# Helper function to print test results
def print_test_result(test_name: str, passed: bool) -> None:
    status = "PASSED" if passed else "FAILED"
    print(f"{test_name}: {status}")

# Test convert_to_mol function
def test_convert_to_mol() -> None:
    smiles = "CCO"
    mol = convert_to_mol(smiles)
    passed = mol is not None and Chem.MolToSmiles(mol) == smiles
    print_test_result("test_convert_to_mol", passed)

# Test random_choice function
def test_random_choice() -> None:
    rng = np.random.default_rng(seed=0)
    array = [1, 2, 3, 4, 5]
    choice = random_choice(rng, array)
    passed = choice in array
    print_test_result("test_random_choice", passed)

# Test QueryMol class
def test_query_mol() -> None:
    smarts = "[OH1][C:1]([*:2])=[O:3]"
    query_mol = QueryMol(smarts)
    # Test with a molecule that matches the SMARTS pattern
    passed = query_mol.has_match("CCOC(=O)O")  # Example molecule that matches
    print_test_result("test_query_mol", passed)

# Test Reaction class
def test_reaction() -> None:
    reactants = [QueryMol("[OH1][C:1]([*:2])=[O:3]")]
    product = QueryMol("[OH1][C:1]([*:2])=[O:3]")
    reaction = Reaction(reactants, product, reaction_id=1)
    passed = reaction.num_reactants == 1
    print_test_result("test_reaction", passed)

# Test Node class
def test_node() -> None:
    scoring_fn = lambda x: 1.0
    node = Node(explore_weight=1.0, scoring_fn=scoring_fn, node_id=1)
    passed = node.P == 0.0
    print_test_result("test_node", passed)

# Test Generator class
def test_generator() -> None:
    building_block_smiles_to_id = {"CCO": 1}
    scoring_fn = lambda x: 1.0
    generator = Generator(
        building_block_smiles_to_id=building_block_smiles_to_id,
        max_reactions=1,
        scoring_fn=scoring_fn,
        explore_weight=1.0,
        num_expand_nodes=None,
        optimization="maximize",
        reactions=(),
        rng_seed=0,
        no_building_block_diversity=False,
        store_nodes=False,
        verbose=False
    )
    passed = generator is not None
    print_test_result("test_generator", passed)

# Test Chemprop model loading
def test_chemprop_load() -> None:
    model_path = Path("path_to_chemprop_model.pt")
    if model_path.exists():
        try:
            model = chemprop_load(model_path)
            passed = model is not None
            print_test_result("test_chemprop_load", passed)
        except Exception as e:
            print(f"Error loading Chemprop model: {e}")
            print_test_result("test_chemprop_load", False)
    else:
        print("Chemprop model file not found. Skipping test.")
        print_test_result("test_chemprop_load", False)

# Test Chemprop prediction
def test_chemprop_predict_on_molecule_ensemble() -> None:
    model_path = Path("path_to_chemprop_model.pt")
    if model_path.exists():
        try:
            model = chemprop_load(model_path)
            scaler = chemprop_load_scaler(model_path)
            smiles = "CCO"
            prediction = chemprop_predict_on_molecule_ensemble([model], smiles, scalers=[scaler])
            passed = isinstance(prediction, float)
            print_test_result("test_chemprop_predict_on_molecule_ensemble", passed)
        except Exception as e:
            print(f"Error predicting with Chemprop model: {e}")
            print_test_result("test_chemprop_predict_on_molecule_ensemble", False)
    else:
        print("Chemprop model file not found. Skipping test.")
        print_test_result("test_chemprop_predict_on_molecule_ensemble", False)

# Test full molecule generation pipeline
def test_generate_molecules() -> None:
    model_path = Path("path_to_model")
    save_dir = Path("test_output")
    save_dir.mkdir(exist_ok=True)

    try:
        # Load building blocks data
        building_blocks_data = pd.read_csv(BUILDING_BLOCKS_PATH)
        
        # Check if the 'score' column exists
        if 'score' not in building_blocks_data.columns:
            print("Warning: 'score' column not found in building_blocks.csv. Using a default score of 1.0.")
            building_blocks_data['score'] = 1.0  # Add a default score column

        # Map building blocks SMILES to IDs, IDs to SMILES, and SMILES to scores
        building_block_smiles_to_id = dict(zip(
            building_blocks_data[SMILES_COL],
            building_blocks_data[REAL_BUILDING_BLOCK_ID_COL]
        ))
        building_block_id_to_smiles = dict(zip(
            building_blocks_data[REAL_BUILDING_BLOCK_ID_COL],
            building_blocks_data[SMILES_COL]
        ))
        building_block_smiles_to_score = dict(zip(
            building_blocks_data[SMILES_COL],
            building_blocks_data['score']
        ))

        # Define model scoring function
        model_scoring_fn = create_model_scoring_fn(
            model_path=model_path,
            model_type="random_forest",
            fingerprint_type="morgan",  # Specify fingerprint_type for scikit-learn models
            smiles_to_score=building_block_smiles_to_score
        )

        # Set up Generator
        generator = Generator(
            building_block_smiles_to_id=building_block_smiles_to_id,
            max_reactions=1,
            scoring_fn=model_scoring_fn,
            explore_weight=1.0,
            num_expand_nodes=None,
            optimization="maximize",
            reactions=(),
            rng_seed=0,
            no_building_block_diversity=False,
            store_nodes=False,
            verbose=False
        )

        # Generate molecules
        nodes = generator.generate(n_rollout=1)

        # Save generated molecules
        save_generated_molecules(
            nodes=nodes,
            building_block_id_to_smiles=building_block_id_to_smiles,
            save_path=save_dir / "molecules.csv"
        )

        passed = (save_dir / "molecules.csv").exists()
        print_test_result("test_generate_molecules", passed)

    except Exception as e:
        print(f"Error in test_generate_molecules: {e}")
        print_test_result("test_generate_molecules", False)

# Run all tests
def run_tests() -> None:
    test_convert_to_mol()
    test_random_choice()
    test_query_mol()
    test_reaction()
    test_node()
    test_generator()
    test_chemprop_load()
    test_chemprop_predict_on_molecule_ensemble()
    test_generate_molecules()

if __name__ == "__main__":
    run_tests()

test_convert_to_mol: PASSED
test_random_choice: PASSED
test_query_mol: PASSED
test_reaction: PASSED
test_node: PASSED
test_generator: PASSED
Chemprop model file not found. Skipping test.
test_chemprop_load: FAILED
Chemprop model file not found. Skipping test.
test_chemprop_predict_on_molecule_ensemble: FAILED
Error in test_generate_molecules: [Errno 2] No such file or directory: 'path_to_model'
test_generate_molecules: FAILED
