In [17]:
# Make modules living at the repository root (e.g. `reward`, `utils`) importable
# when this notebook is run locally from inside the repository.
import sys

repo_root = "../"  # repository root, relative to the notebook's working directory
if not any(entry == repo_root for entry in sys.path):
    # Prepend so the local checkout wins over any installed copy.
    sys.path.insert(0, repo_root)

In [18]:
from typing import Optional
from guacamol.goal_directed_generator import GoalDirectedGenerator
from guacamol.scoring_function import ScoringFunction
from guacamol.assess_goal_directed_generation import assess_goal_directed_generation
from rdkit import Chem

from reward import MolReward
from utils import conf_from_yaml, generator_from_conf

class GuacaMolReward(MolReward):
    """Adapt a GuacaMol ``ScoringFunction`` to the ``MolReward`` interface.

    There is a single objective: the GuacaMol score of the molecule's SMILES.
    The reward is that objective value, passed through unchanged.
    """

    def __init__(self, scoring_function: ScoringFunction):
        # NOTE(review): MolReward.__init__ is not called here (same as before);
        # confirm the base class needs no initialisation.
        self.scoring_function = scoring_function

    def mol_objective_functions(self):
        """Return a one-element list holding the raw GuacaMol scoring callable."""

        def raw_score(mol):
            # GuacaMol scores SMILES strings, so serialise the RDKit molecule first.
            # `self.scoring_function` is looked up on every call, not captured here.
            return self.scoring_function.score(Chem.MolToSmiles(mol))

        objectives = [raw_score]
        return objectives

    def reward_from_objective_values(self, objective_values):
        """Use the first (and only) objective value directly as the reward."""
        return objective_values[0]

class V3DeNovoGenerator(GoalDirectedGenerator):
    """GuacaMol goal-directed generator backed by the de novo MCTS configuration.

    Settings (including ``max_generations`` and ``time_limit``) are read from
    ``config/mcts_guacamol_de_novo.yaml``, resolved against ``repo_root``.
    """

    def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int, starting_population: Optional[list[str]] = None) -> list[str]:
        """Run a de novo search guided by ``scoring_function``.

        Args:
            scoring_function: GuacaMol objective, installed as the search reward.
            number_molecules: Number of molecules GuacaMol asks for.
            starting_population: Not used by this generator.

        Returns:
            At most ``number_molecules`` generated keys. When the search produced
            more than that, only the highest-scoring ones (according to
            ``scoring_function``) are returned, best first.
        """
        yaml_path = "config/mcts_guacamol_de_novo.yaml"
        conf = conf_from_yaml(yaml_path, repo_root)
        generator = generator_from_conf(conf)
        generator.reward = GuacaMolReward(scoring_function=scoring_function)
        generator.generate(max_generations=conf.get("max_generations"), time_limit=conf.get("time_limit"))

        keys = list(generator.generated_keys())
        # The GuacaMol interface asks for `number_molecules` molecules (as
        # V3LeadGenerator already honours). Previously every generated key was
        # returned. Ranking with the same scoring function keeps the best ones,
        # presumably the molecules the benchmark's top-k scoring picks anyway.
        if len(keys) <= number_molecules:
            return keys
        scores = scoring_function.score_list(keys)
        # `sorted` is stable, so ties keep generation order.
        ranked = sorted(zip(scores, keys), key=lambda pair: pair[0], reverse=True)
        return [key for _, key in ranked[:number_molecules]]

class V3LeadGenerator(GoalDirectedGenerator):
    """GuacaMol goal-directed generator backed by the lead-optimisation MCTS configuration.

    Settings are read from ``config/mcts_guacamol_lead.yaml``, resolved against
    ``repo_root``. The search root is overridden per call: it is the supplied
    starting population, or benzene (``"c1ccccc1"``) when none is given.
    """

    def generate_optimized_molecules(self, scoring_function: ScoringFunction, number_molecules: int, starting_population: Optional[list[str]] = None) -> list[str]:
        """Run a lead-optimisation search guided by ``scoring_function``.

        Args:
            scoring_function: GuacaMol objective, installed as the search reward.
            number_molecules: Number of keys requested from the generator.
            starting_population: Optional seed SMILES used as the search root.

        Returns:
            ``generated_keys(last=number_molecules)`` of the generator. This is
            presumably the most recently generated keys; confirm against the
            generator API.
        """
        settings = conf_from_yaml("config/mcts_guacamol_lead.yaml", repo_root)
        # NOTE(review): `root` receives a list[str] when a population is given
        # but a str otherwise. Confirm generator_from_conf accepts both.
        settings["root"] = starting_population if starting_population else "c1ccccc1"

        search = generator_from_conf(settings)
        search.reward = GuacaMolReward(scoring_function=scoring_function)

        limits = {
            "max_generations": settings.get("max_generations"),
            "time_limit": settings.get("time_limit"),
        }
        search.generate(**limits)

        return search.generated_keys(last=number_molecules)

In [None]:
from guacamol.assess_goal_directed_generation import _evaluate_goal_directed_benchmarks
from guacamol.benchmark_suites import goal_directed_benchmark_suite
from guacamol.goal_directed_benchmark import GoalDirectedBenchmarkResult

def assess(goal_directed_molecule_generator: GoalDirectedGenerator, benchmark_version='v1') -> list[GoalDirectedBenchmarkResult]:
    """Run the GuacaMol goal-directed benchmark suite against a generator.

    Args:
        goal_directed_molecule_generator: The generator under evaluation.
        benchmark_version: GuacaMol suite name passed to
            ``goal_directed_benchmark_suite`` (default ``'v1'``).

    Returns:
        The per-benchmark results produced by
        ``_evaluate_goal_directed_benchmarks``.
    """
    suite = goal_directed_benchmark_suite(version_name=benchmark_version)
    # Calls GuacaMol's private evaluation helper directly and returns its results.
    return _evaluate_goal_directed_benchmarks(
        goal_directed_molecule_generator=goal_directed_molecule_generator,
        benchmarks=suite,
    )

# Evaluate the de novo generator, print each benchmark's score, then the total.
results = assess(V3DeNovoGenerator())

scores = []
for result in results:
    print(result.benchmark_name, result.score)
    scores.append(result.score)

# Sum of per-benchmark scores across the suite.
print(sum(scores))