In [1]:
import numpy as np
import gym
from gym import spaces
from rxitect.models.lightning.generator import Generator
from rxitect.models.vanilla.predictor import Predictor
from typing import List
import selfies as sf


class DrugPropEnv(gym.Env):
    """
    Single-step gym environment that scores molecules sampled from a generator.

    Each call to ``step`` draws encoded SELFIES from ``generator``, decodes
    them to SMILES with ``selfies.decoder`` and converts them into a scalar
    reward through the module-level ``get_reward`` helper and ``predictor``.
    The episode always ends after that one step (``done`` is ``True``).

    The action space contains a single action and ``step`` ignores its value.
    The observation is a constant placeholder that carries no information; it
    exists only so that ``reset`` and ``step`` return arrays matching
    ``observation_space``.

    NOTE: ``get_reward`` is looked up when ``step`` runs, so the cell that
    defines it must have been executed before the environment is stepped.
    """
    metadata = {'render.modes': ['console']}

    def __init__(self, generator: Generator, predictor: Predictor):
        """
        :param generator: used in ``step`` via ``sample(...)`` and ``voc.decode(...)``
        :param predictor: forwarded to ``get_reward`` in ``step``
        """
        super(DrugPropEnv, self).__init__()

        self.generator = generator
        self.predictor = predictor

        n_actions = 1
        self.action_space = spaces.Discrete(n_actions)
        # Fixed finite bounds: the observation is only a placeholder, and no
        # ``grid_size`` attribute exists on this class to derive a bound from.
        self.observation_space = spaces.Box(low=0.0, high=1.0,
                                            shape=(1,), dtype=np.float32)

    def _placeholder_observation(self):
        """Return the constant observation, shaped and typed like ``observation_space``."""
        return np.zeros(self.observation_space.shape,
                        dtype=self.observation_space.dtype)

    def reset(self):
        """
        Important: the observation must be a numpy array
        :return: (np.array) the constant placeholder observation
        """
        return self._placeholder_observation()

    def step(self, action):
        """
        Sample molecules, score them and end the episode.

        :param action: ignored; the action space has a single action
        :return: (observation, reward, done, info) with ``done`` always ``True``
        """
        # 10 is handed to ``generator.sample`` unchanged -- presumably the
        # number of sequences to draw; TODO confirm against Generator.sample.
        enc_selfies = self.generator.sample(10)
        selfies = self.generator.voc.decode(enc_selfies)
        smiles = [sf.decoder(selfie) for selfie in selfies]
        # Cast to a builtin float so the reward is a plain Python number
        # regardless of the numpy scalar type ``get_reward`` produces.
        reward = float(get_reward(predictor=self.predictor, mols=smiles))
        done = True

        # Optionally we can pass additional info, we are not using that for now
        info = {}

        return self._placeholder_observation(), reward, done, info

    def render(self, mode='console'):
        """No-op: this environment has nothing to render."""
        pass

    def close(self):
        """No-op: this environment holds no resources to release."""
        pass


In [None]:
from stable_baselines3.common.env_checker import check_env

In [None]:
# NOTE(review): GoLeftEnv is not defined or imported in any visible cell, so
# unless it comes from outside this view the cell raises NameError on a fresh
# kernel. The environment defined in this notebook is
# DrugPropEnv(generator, predictor) -- TODO confirm and switch to it.
env = GoLeftEnv()

In [None]:
# Run stable-baselines3's environment checker on `env`; warn=True is passed
# through to check_env (presumably enables its additional warnings -- confirm
# against the stable-baselines3 documentation).
check_env(env, warn=True)

In [2]:
# Build the predictor from the model file at a path relative to the notebook's
# working directory. type_="REG" presumably selects a regression model
# (the file name also contains "REG") -- TODO confirm against Predictor.
p = Predictor(path="../models/RF_REG_CHEMBL226.pkg", type_="REG")

In [None]:
# Exploratory call: run the predictor's calc_fp on two identical "CCC" inputs
# (presumably computes fingerprints, judging by the name). The resulting
# global `fps` is not read by any later visible cell.
fps = p.calc_fp(["CCC", "CCC"])

In [3]:
# Score the same two inputs through the predictor's own get_reward method;
# the recorded output below holds one value per input string.
p.get_reward(["CCC", "CCC"])

array([4.01651, 4.01651])

In [14]:
def get_scores(predictor: Predictor, mols: List[str]) -> np.ndarray:
    """Calculates the predictor scores for a list of SMILES
    Args:
        predictor: Object providing ``calc_fp(mols)`` and callable on the
            value that ``calc_fp`` returns.
        mols: A list of molecules in SMILES representation.
    Returns:
        The predictor's output for the computed features, returned unchanged
        (presumably one score per molecule -- confirm against Predictor).
    """
    fingerprints = predictor.calc_fp(mols)
    # Invoke the predictor through normal call syntax instead of spelling out
    # the dunder method.
    return predictor(fingerprints)

def get_reward(predictor: Predictor, mols: List[str]):
    """Turns the predictor score of the first molecule into a reward
    Args:
        predictor: Passed through to ``get_scores``.
        mols: A list of molecules in SMILES representation. Only the score at
            index 0 contributes to the result; the rest are scored and ignored.
    Returns:
        ``exp(score / 3)`` for the first score returned by ``get_scores``.
    """
    first_score = get_scores(predictor, mols)[0]
    return np.exp(first_score / 3)

In [15]:
# Smoke test of the module-level get_reward on a single "CCC" input. The
# recorded output (~3.8146) equals exp(4.01651 / 3), i.e. it is consistent
# with the 4.01651 score recorded for "CCC" in the earlier cell.
get_reward(predictor=p, mols=["CCC"])

3.814603267993802