In [None]:
#|default_exp generative_callback

In [None]:
#|hide
# Development-only: re-import edited library modules automatically so changes are picked up without restarting the kernel
%reload_ext autoreload
%autoreload 2

In [None]:
#|hide
from nbdev.showdoc import show_doc

In [None]:
from fastai.callback.core import Callback

In [None]:
#|export
from typing import List, Optional
from fastai.callback.core import Callback
from fastinference.inference import *
from denovo_design.generative_basics import *
from fcd_torch import *
from random import choices
from guacamol.distribution_learning_benchmark import ValidityBenchmark, UniquenessBenchmark, NoveltyBenchmark
from guacamol.frechet_benchmark import FrechetBenchmark
from guacamol.distribution_matching_generator import DistributionMatchingGenerator

2022-08-15 23:46:17.385419: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-08-15 23:46:17.385474: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [None]:
#|export    
class MockGenerator(DistributionMatchingGenerator):
    """
    Fake generator that replays a fixed, pre-computed list of molecules.

    Every call to ``generate`` hands out the next ``number_samples`` entries
    of ``molecules``, continuing where the previous call stopped, so one list
    can be consumed across several calls. Once the list runs out, calls
    return shorter (eventually empty) slices.
    """

    def __init__(self, molecules) -> None:
        self.cursor = 0              # position of the next molecule to hand out
        self.molecules = molecules

    def generate(self, number_samples: int):
        start = self.cursor
        self.cursor += number_samples
        return self.molecules[start:self.cursor]

In [None]:
#|export
class GenerativeCallback(Callback):
    """
    A callback to generate molecules while training a model.

    Before each validation phase it samples ``max_mols`` SMILES strings from
    the model being trained (see ``sampling``) and appends them to
    ``self.smiles``, which is cleared at the start of every epoch. The
    ``_validity_score``, ``_uniqueness_score``, ``_novelty_score`` and
    ``_fcd_score`` helpers score the current ``self.smiles`` with GuacaMol
    benchmarks.
    """

    def __init__(self, reference_mols:Optional[List]=None, text:str='', max_size:int=100, temperature:float=0.7, max_mols:int=100):
        """
        Arguments:

            reference_mols : list, optional
                A list of reference SMILES used by the FCD and novelty
                metrics. ``None`` (the default) is treated as an empty list.

            text : str
                Seed string (default = '')

            max_size : int
                Maximum number of tokens sampled per molecule (default = 100)

            temperature : float
                Sampling temperature; must be > 0 (default = 0.7)

            max_mols : int
                Number of molecules generated before each validation phase
                (default = 100)

        Raises:

            ValueError
                If ``temperature`` is not strictly positive (sampling raises
                the probabilities to the power ``1 / temperature``).
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        # Fresh list per instance instead of a shared mutable default argument.
        if reference_mols is None:
            reference_mols = []

        super().__init__()
        self.reference_mols = reference_mols
        self.text = text
        self.max_size = max_size
        self.temperature = temperature
        self.max_mols = max_mols
        self.smiles = []
        self.valid_mols = []

        # Create the FCD benchmark once, up front, from the reference set.
        # NOTE(review): the original intent was to precompute the reference mean/covariance
        # here; confirm FrechetBenchmark does that in __init__ rather than in assess_model.
        self.fcd_benchmark = FrechetBenchmark(training_set=reference_mols, sample_size=len(reference_mols))

    def sampling(self):
        """
        Sample one SMILES string from the model, token by token.

        Starting from the seed ``self.text``, the whole sequence generated so
        far is fed to the model at every step. The next token is drawn from the
        temperature-scaled scores of the last position. Generation stops after
        ``self.max_size`` tokens, or earlier when the ``BOS`` token is drawn
        (used here as the end-of-molecule marker; it is not appended).

        Returns:

            decoded : str
                The sampled tokens concatenated, with ``BOS`` and ``PAD``
                tokens removed.
        """
        self.model.reset()    # Reset the model before generating a new molecule

        nums = self.dls.numericalize
        stop_index = self.dls.train.vocab.index(BOS)

        idxs = self.dls.test_dl([self.text]).items[0].to(self.dls.device)
        for _ in range(self.max_size):
            # decoded_loss is presumably the option of the fastinference-patched get_preds
            preds = self.get_preds(dl=[(idxs[None],)], decoded_loss=False)
            # NOTE(review): assumes preds[0][0] holds per-position, non-negative scores
            # (probabilities) for the single input sequence; confirm against fastinference.
            res = tensor(preds[0][0][-1])
            if self.temperature != 1.: res.pow_(1 / self.temperature)
            idx = torch.multinomial(res, 1).item()
            if idx == stop_index:
                break
            idxs = torch.cat([idxs, idxs.new([idx])])

        # Decode predicted tokens, dropping special BOS/PAD tokens
        return ''.join(nums.vocab[o] for o in idxs if nums.vocab[o] not in [BOS, PAD])

    def _validity_score(self):
        """GuacaMol validity score of the molecules currently in ``self.smiles``."""
        gen = MockGenerator(self.smiles)
        return ValidityBenchmark(number_samples=len(gen.molecules)).assess_model(gen).score

    def _fcd_score(self):
        """Frechet ChemNet Distance score of ``self.smiles`` against the reference set."""
        gen = MockGenerator(self.smiles)
        return self.fcd_benchmark.assess_model(gen).score

    def _uniqueness_score(self):
        """GuacaMol uniqueness score of the molecules currently in ``self.smiles``."""
        gen = MockGenerator(self.smiles)
        return UniquenessBenchmark(number_samples=len(gen.molecules)).assess_model(gen).score

    def _novelty_score(self):
        """GuacaMol novelty score of ``self.smiles`` relative to ``self.reference_mols``."""
        gen = MockGenerator(self.smiles)
        return NoveltyBenchmark(number_samples=len(gen.molecules), training_set=self.reference_mols).assess_model(gen).score

    def before_epoch(self):
        """Clear the per-epoch metric values and the molecules sampled in the previous epoch."""
        self.val, self.unq, self.nov = 0, 0, 0
        self.smiles = []
        self.valid_mols = []

    def before_validate(self, **kwargs):
        """Sample ``max_mols`` molecules into ``self.smiles`` and print the first five."""
        self.smiles += [self.sampling() for _ in range(self.max_mols)]
        # Sampling leaves the model in whatever state the last generated sequence produced
        # (presumably a recurrent hidden state). Reset it so validation does not start from it.
        self.model.reset()
        print(self.smiles[:5])

In [None]:
# Documentation cell (not exported): render GenerativeCallback's signature and docstring
show_doc(GenerativeCallback)

---

### GenerativeCallback

>      GenerativeCallback (reference_mols:List=[], text:str='',
>                          max_size:int=100, temperature:float=0.7,
>                          max_mols:int=100)

A callback to generate molecules while training a model. 

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#|export` cells
# (this notebook's exported cells go to the module named by `#|default_exp`)
from nbdev import nbdev_export
nbdev_export()