In [1]:
from litellm import completion
from typing import Optional
import litellm
from dotenv import load_dotenv
import rootutils

root_path = rootutils.setup_root(".", indicator=".project-root", pythonpath=True)

In [2]:
root_path

WindowsPath('C:/Users/irahu/git_workspace/recursiveLLM')

## Imports

In [3]:
# from tests.variables_test import CLAUDE_ADV_MODEL
# from scipy.sparse import linalg
from src.prithvi import run_prithvi

In [4]:
DEEPSEEK_FIREWORKS_MODEL = "fireworks_ai/accounts/fireworks/models/deepseek-r1:adv"
molecule = "NC(=O)C1=NC(F)=CN=C1O"

In [None]:
""" Module to run prithvi retrosynthesis on a molecule """
import time
import os
import rootutils
import structlog

from src.utils.parse import format_output
from src.rec_prithvi import rec_run_prithvi
from src.utils.job_context import logger as context_logger
from src.utils.custom_logging import add_job_specific_handler
from src.metadata import reagent_agent, conditions_agent, literature_agent

root_dir = rootutils.setup_root(__file__,
                                indicator=".project-root",
                                pythonpath=True)

date_dir = f'{root_dir}/logs/{time.strftime("%Y-%m-%d")}'

ENABLE_LOGGING = False if os.getenv("ENABLE_LOGGING",
                                    "true").lower() == "false" else True


def log_message(message: str, logger=None):
    """Log the message

    Parameters
    ----------
    message : str
        The message to be logged
    logger : _type_, optional
        The logger object, by default None

    Returns
    -------
    None
    """
    if logger is not None:
        logger.info(message)
    else:
        print(message)


def run_prithvi(molecule: str,
                llm="claude-3-opus-20240229",
                az_model: str = "USPTO") -> dict:
    """Run prithvi services to generate retrosynthesis on a molecule.

    Parameters
    ----------
    molecule : str
        SMILE String of the molecule.
    llm : str, optional
        LLM Model, by default "claude-3-opus-20240229"

    Returns
    -------
    dict
        Result after running prithvi.
    """

    logger = context_logger.get() if ENABLE_LOGGING else None
    log_message(f"Calling {llm} with molecule: {molecule}", logger)

    # Generate a unique job ID using timestamp and a random suffix
    job_id = f"{time.strftime('%Y%m%d_%H%M%S')}_{os.getpid()}"

    job_log_file = f"{date_dir}/job_{job_id}.log"
    log = structlog.get_logger().bind(job_id=job_id)
    # Set the logger in the context variable
    token = context_logger.set(log)

    # Add job-specific handler
    handler = add_job_specific_handler(log, job_id)
    # log.info(f"Starting new synthesis job {job_id} for molecule {molecule}")
    log_message(f"Starting new synthesis job {job_id} for molecule {molecule}")

    try:
        result_dict, _ = rec_run_prithvi(molecule=molecule,
                                         job_id=job_id,
                                         llm=llm,
                                         az_model=az_model)
        output_data = format_output(result_dict)
        output_data = add_metadata(output_data)
        return output_data
    finally:
        # Clean up handlers
        log._logger.removeHandler(handler)
        handler.close()
        context_logger.reset(token)


def add_metadata(output_data: dict) -> dict:
    """method to add metadata to reaction metrics

    Parameters
    ----------
    output_data : dict
        json output without metadata

    Returns
    -------
    dict
        json output with metadata
    """
    for idx, step in enumerate(output_data['steps']):
        status, reagents = reagent_agent(step['reactants'], step['products'])
        output_data['steps'][idx]['reagents'].extend(reagents)

        status, conditions = conditions_agent(step['reactants'],
                                              step['products'],
                                              step['reagents'])
        output_data['steps'][idx]['conditions'] = conditions

        status, literature = literature_agent(step['reactants'],
                                              step['products'],
                                              step['reagents'],
                                              step['conditions'])
        output_data['steps'][idx]['reactionmetrics'][0][
            'closestliterature'] = literature

    return output_data


In [None]:
""" Recursive function to run Prithvi on a molecule """

from src.utils.llm import llm_pipeline
from src.utils.az import run_az
from src.utils.job_context import logger as context_logger


def rec_run_prithvi(molecule: str,
                    job_id: str,
                    llm: str = "claude-3-opus-20240229",
                    az_model: str = "USPTO") -> tuple[dict, bool]:
    """Recursive function to run Prithvi on a molecule

    Parameters
    ----------
    molecule : str
        Molecule SMILES
    job_id : str
        Job ID
    llm : str, optional
        LLM to be used, by default "claude-3-opus-20240229"

    Returns
    -------
    tuple(dict, bool)
        result_dict: result of retrosynthesis.
        solved: boolean value indicating if the molecule was solved.
    """
    solved, result_dict = run_az(smiles=molecule, az_model=az_model)
    result_dict = result_dict[0]
    logger = context_logger.get()
    if not solved:
        logger.info(f"AZ failed for {molecule}, running LLM")
        out_pathways, out_explained, out_confidence = llm_pipeline(
            molecule, llm)
        result_dict = {
            'type':
            'mol',
            'smiles':
            molecule,
            # 'confidence': out_confidence,
            "is_chemical":
            True,
            "in_stock":
            False,
            'children': [{
                "type": "reaction",
                "is_reaction": True,
                "metadata": {
                    "policy_probability": out_confidence,
                },
                "children": []
            }]
        }
        logger.info(f"LLM returned {out_pathways}")
        logger.info(f"LLM explained {out_explained}")
        for pathway in out_pathways:
            if isinstance(pathway, list):
                temp_stat = []
                for mol in pathway:
                    res, stat = rec_run_prithvi(mol, job_id, llm, az_model=az_model)
                    if stat:
                        temp_stat.append(True)
                        result_dict['children'][0]['children'].append(res)
                logger.info(f"temp_stat: {temp_stat}")
                if all(temp_stat):
                    solved = True
            else:
                res, solved = rec_run_prithvi(pathway, job_id, llm, az_model=az_model)
                result_dict['children'][0]['children'].append(res)
            if solved:
                logger.info('breaking')
                break
    else:
        logger.info(f"AZ solved {molecule}")
    # print(f"Solved : {solved}, Returning {result_dict}")
    return result_dict, solved


In [14]:
res = run_prithvi(molecule=molecule, llm=DEEPSEEK_FIREWORKS_MODEL, az_model="USPTO")


AttributeError: 'PrintLogger' object has no attribute 'addHandler'

# New tests

In [5]:
from src.prithvi import run_prithvi, add_metadata

In [17]:
import pytest
from unittest import mock as mock
from unittest.mock import patch, MagicMock

In [8]:
@pytest.fixture
def mock_context_logger():
    with patch('src.prithvi.context_logger') as mock_logger:
        yield mock_logger

In [9]:
@pytest.fixture
def mock_structlog():
    with patch('src.prithvi.structlog') as mock_log:
        yield mock_log

In [10]:
@pytest.fixture
def mock_add_job_specific_handler():
    with patch('src.prithvi.add_job_specific_handler') as mock_handler:
        yield mock_handler

In [11]:
@pytest.fixture
def mock_rec_run_prithvi():
    with patch('src.prithvi.rec_run_prithvi') as mock_run:
        yield mock_run

In [12]:
@pytest.fixture
def mock_format_output():
    with patch('src.prithvi.format_output') as mock_format:
        yield mock_format

In [13]:
MOLECULE = "NC(=O)C1=NC(F)=CN=C1O"

In [None]:
def test_run_prithvi_success(mock_context_logger, mock_structlog, mock_add_job_specific_handler, mock_rec_run_prithvi, mock_format_output):
    
    mock_rec_run_prithvi.return_value = ({'steps': []}, None)
    mock_format_output.return_value = {'steps': []}
    mock_add_job_specific_handler.return_value = MagicMock()

    result = run_prithvi(molecule=MOLECULE)

    assert result == {'steps': []}
    mock_rec_run_prithvi.assert_called_once()
    mock_format_output.assert_called_once()

In [None]:
mock_rec_run_prithvi.return_value = ({'steps': []}, None)
mock_format_output.return_value = {'steps': []}


In [None]:
# @mock.patch("src.prithvi.run_prithvi")
def test_run_prithvi(mock_structlog, mock_context_logger, mock_add_job_specific_handler, mock_rec_run_prithvi):

    assert isinstance(result, dict)
    assert 'steps' in result

In [8]:
output_data = run_prithvi("NC(=O)C1=NC(F)=CN=C1O", llm="claude-3-opus-20240229", az_model="USPTO")

AttributeError: 'PrintLogger' object has no attribute 'addHandler'