In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 15:36:30 2020

@author: stravsm
"""

import importlib
from importlib import reload
from tqdm import tqdm
import os


In [22]:
# Reload edited project modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 2 

import tensorflow as tf
import numpy as np
import pandas as pd

from fp_management import database as db
from fp_management import mist_fingerprinting as fpr
from fp_management import fingerprint_map as fpm
import smiles_config as sc

# Add the evaluation config file and reload the configuration.
# NOTE(review): this runs before the `infrastructure.*` imports below,
# presumably because those modules read `sc.config` at import time —
# confirm before moving these imports to the top of the notebook.
sc.config_file.append("config.EULER-eval.yaml")
sc.config_reload()

import infrastructure.generator as gen
import infrastructure.decoder as dec

import time
from datetime import datetime
import pickle
import pathlib


# Only let CRITICAL-level RDKit log messages through (silences RDKit
# warnings and errors).
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
import infrastructure.score as msc
import gc
import random

# Disable dropout. Is there a more elegant way to adapt config at runtime?
sc.config["model_config"]["training"] = False

# Randomness is relevant for stochastic sampling: seed Python, NumPy and
# TensorFlow when a global seed is configured ('' means "no seed").
random_seed = sc.config['random_seed_global']
if random_seed != '':
    random.seed(random_seed)
    np.random.seed(random_seed)
    # `tf.random.set_seed` is the TF2 global-seed API.
    tf.random.set_seed(random_seed)

# Setup logger
import logging
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%d-%b-%y %H:%M:%S')
logger = logging.getLogger("MSNovelist")
logger.setLevel(logging.INFO)
logger.info("evaluation startup")

# Output folder for pickles / TSV exports; created if missing.
eval_folder = pathlib.Path(sc.config["eval_folder"])
eval_folder.mkdir(parents=True, exist_ok=True)

# Run identifiers: `eval_id` is the configured value, or the current Unix
# timestamp if none is configured. `pickle_id` (used in output file names)
# appends the eval counter when one is configured.
# NOTE: both are recomputed per weights file in the evaluation loop.
eval_id = str(int(time.time()))
pickle_id = eval_id
if sc.config['eval_id'] != '':
    eval_id = sc.config['eval_id']
if sc.config['eval_counter'] != '':
    # Build from `eval_id` (which falls back to the timestamp) rather than the
    # raw config value, so an empty config eval_id does not yield "-<counter>".
    # The f-string also tolerates non-string values from the YAML config.
    pickle_id = f"{eval_id}-{sc.config['eval_counter']}"

# `weights` may be configured as a single file name or as a list of names.
if isinstance(sc.config['weights'], list):
    weights_list = sc.config['weights']
else:
    weights_list = [sc.config['weights']]
    


03-Jun-24 14:31:58 - evaluation startup


In [3]:

# Set up everything that does not depend on the model weights first:
# initialise the shared MIST fingerprinter and keep a handle to it.
MistFingerprinter = fpr.MistFingerprinter
MistFingerprinter.init_instance()
fingerprinter = MistFingerprinter.get_instance()


In [4]:

  
# Evaluation parameters from the config.
n = sc.config["eval_n"]                # queries per batch (pipeline batch_size)
n_total = sc.config["eval_n_total"]    # samples to evaluate; -1 means all
#n_total_ = n_total // n * n
k = sc.config["eval_k"]                # copies per query fed to the encoder/beam
kk = sc.config["eval_kk"]              # sequences retrieved per sample
steps = sc.config["eval_steps"]

decoder_name = sc.config["decoder_name"]
evaluation_set = sc.config["evaluation_set"]

# Load the evaluation database (CSI:FingerID validation data) and the
# pipeline options stored with it.
data_eval_ = sc.config["db_path_eval"]
db_eval = db.FpDatabase.load_from_config(data_eval_)
pipeline_options = db_eval.get_pipeline_options()

pipeline_encoder = sc.config['pipeline_encoder']
pipeline_reference = sc.config['pipeline_reference']

# Select the evaluation group, optionally truncated to the first n_total entries.
dataset_val = db_eval.get_grp(evaluation_set)
if n_total != -1:
    dataset_val = dataset_val[:n_total]
# Always use the real count: the group may hold fewer than n_total entries,
# and later progress bars, log messages and scoring rely on n_total.
n_total = len(dataset_val)


In [5]:
# Inspect the pipeline options stored with the evaluation database.
pipeline_options

{'embed_X': False,
 'unpackbits': False,
 'unpack': False,
 'fingerprint_selected': 'fingerprint_degraded'}

In [6]:
# Translate the dataset on the fly: both stored fingerprint columns are
# converted with the MIST fingerprinter before entering the pipeline.

def entry_for_row(row):
    """Return a plain-dict copy of ``row`` in which ``fingerprint`` and
    ``fingerprint_degraded`` are replaced by the first row of
    ``fingerprinter.get_fp(...)`` applied to the stored value."""
    entry = dict((column, row[column]) for column in row.keys())
    for fp_column in ("fingerprint", "fingerprint_degraded"):
        entry[fp_column] = fingerprinter.get_fp(row[fp_column])[0, :]
    return entry

dataset_val_mapped = list(map(entry_for_row, dataset_val))

In [7]:
# Debug: grab the first translated entry for inspection.
r = dataset_val_mapped[0]

In [8]:
# Debug: shape of the translated degraded fingerprint vector.
r["fingerprint_degraded"].shape

(4096,)

In [9]:
%autoreload 2
# Load dataset and sampler, apply sampler to dataset
# (so we can also evaluate from fingerprint_sampled)
fp_dataset_val_ = gen.smiles_pipeline(dataset_val_mapped,
                                    batch_size = n,
                                    **pipeline_options,
                                    map_fingerprints=False,
                                    degraded_fingerprint_type = "uint8")

fp_dataset_val = gen.dataset_zip(fp_dataset_val_, 
                                 pipeline_encoder, pipeline_reference,
                                 **pipeline_options)


03-Jun-24 14:28:05 - using unpickle_mf
03-Jun-24 14:28:05 - not using unpack
03-Jun-24 14:28:05 - not using fp_map
03-Jun-24 14:28:06 - not using embed_X
03-Jun-24 14:28:06 - Selecting fingerprint fingerprint_degraded


In [10]:
# Debug: peek at the first batch. This creates a fresh iterator, so the
# evaluation loop further down is not affected.
next(iter(fp_dataset_val))

({'fingerprint_selected': <tf.Tensor: shape=(8, 4096), dtype=float32, numpy=
  array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
  'mol_form': <tf.Tensor: shape=(8, 10), dtype=float32, numpy=
  array([[15.,  0.,  0.,  0.,  7.,  3.,  0.,  0.,  1., 19.],
         [53.,  0.,  0.,  0.,  0., 22.,  0.,  0.,  0., 90.],
         [47.,  0.,  0.,  0.,  2., 13.,  0.,  0.,  0., 56.],
         [54.,  0.,  0.,  0.,  0., 23.,  0.,  0.,  0., 92.],
         [16.,  0.,  0.,  0.,  2.,  3.,  0.,  0.,  1., 16.],
         [21.,  0.,  0.,  0.,  2.,  3.,  0.,  0.,  0., 28.],
         [23.,  0.,  0.,  2.,  7.,  1.,  0.,  0.,  0., 31.],
         [21.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,  1., 19.]], dtype=float32)>,
  'n_hydrogen': <tf.Tensor: shape=(8,), dtype=float32, numpy=array([19., 90., 56.,

In [11]:

# Optionally load a fingerprint sampler from the `fp_sampling` package.
# The sampler factory also decides whether the models round fingerprints at
# inference time; without a sampler they do (round_fingerprints=True).
# NOTE(review): `fp_dataset_val` was already zipped from the un-sampled
# `fp_dataset_val_` in the pipeline cell, so the mapping below only affects
# `fp_dataset_val_` (later used for the model blueprints), not the batches
# iterated during prediction. If sampled fingerprints are meant to be
# evaluated, re-run `gen.dataset_zip` after the mapping — confirm intent.
# Re-running this cell maps the sampler onto an already-sampled dataset.
sampler_name = sc.config['sampler_name']
round_fingerprints = True
if sampler_name != '':
    logger.info(f"Sampler {sampler_name} loading")
    sampler_module = importlib.import_module('fp_sampling.' + sampler_name,
                                             package='fp_sampling')
    sampler_factory = sampler_module.SamplerFactory(sc.config)
    round_fingerprints = sampler_factory.round_fingerprint_inference()
    sampler = sampler_factory.get_sampler()
    logger.info(f"Sampler {sampler_name} loaded")
    fp_dataset_val_ = sampler.map_dataset(fp_dataset_val_)



03-Jun-24 14:28:06 - Sampler basic_tp_fp loading
03-Jun-24 14:28:07 - Sampler basic_tp_fp loaded


In [12]:

# NOTE(review): only the statements inside this loop run once per weights
# file. The indented cells that follow are separate notebook cells, so they
# run once after this loop has finished: with several weights files only the
# last one is actually evaluated. TODO: move the per-weights work (model
# loading, prediction, evaluation) into a function called from this loop.
for weights_i, weights_ in enumerate(weights_list):
    # eval_id: configured value, or the current Unix timestamp.
    eval_id = str(int(time.time()))
    pickle_id = eval_id
    if sc.config['eval_id'] != '':
        eval_id = sc.config['eval_id']
    # pickle_id (output file stem): eval_id[-counter[-weights index]].
    if sc.config['eval_counter'] != '':
        # Build from eval_id (timestamp fallback) using f-strings: `weights_i`
        # is an int, so plain `+` concatenation raised TypeError.
        pickle_id = f"{eval_id}-{sc.config['eval_counter']}"
        if len(weights_list) > 1:
            pickle_id = f"{pickle_id}-{weights_i}"

    picklepath = eval_folder / ("eval_" + pickle_id + ".pkl")
    logger.info(picklepath)
    logger.info(weights_)
    weights = os.path.join(sc.config["weights_folder"], weights_)

    # During evaluation, keep only the best-scoring candidate per (n, inchikey1).
    retain_single_duplicate = True

    fp_dataset_iter = iter(fp_dataset_val)
    blueprints = gen.dataset_blueprint(fp_dataset_val_)

    # Load models
    


03-Jun-24 14:28:07 - /data/MSNovelist-results/eval_1717424887.pkl
03-Jun-24 14:28:07 - w-05-0.071-0.069.hdf5


In [13]:
    import model

    # Encoder, decoder and combined transcoder share the same blueprints and
    # config; the encoder and transcoder also take the fingerprint-rounding
    # flag chosen in the sampler cell.
    model_encode = model.EncoderModel(blueprints=blueprints,
                                      config=sc.config,
                                      round_fingerprints=round_fingerprints)
    model_decode = model.DecoderModel(blueprints=blueprints,
                                      config=sc.config)
    model_transcode = model.TranscoderModel(blueprints=blueprints,
                                            config=sc.config,
                                            round_fingerprints=round_fingerprints)
    


using fingerprint rounding in model
using fingerprint rounding in model


In [14]:
# Debug: show the element spec of the zipped evaluation dataset.
fp_dataset_val

<ZipDataset shapes: ({fingerprint_selected: (None, 4096), mol_form: (None, 10), n_hydrogen: (None,)}, ((None,), (None, 4096))), types: ({fingerprint_selected: tf.float32, mol_form: tf.float32, n_hydrogen: tf.float32}, (tf.string, tf.float32))>

In [15]:
    # Build the models by calling them once; presumably required so that
    # their variables exist before any weights are loaded.
    y_ = model_transcode(blueprints)
    # Note: this consumes the first batch of `fp_dataset_iter`; the encoder
    # output is then used to build the decoder.
    enc = model_encode(next(fp_dataset_iter)[0])
    _ = model_decode(enc)
    
    # Load the checkpoint into the transcoder, matching layers by name, then
    # copy the corresponding layer weights into the standalone encoder and
    # decoder models.
    model_transcode.load_weights(weights, by_name=True)
    model_encode.copy_weights(model_transcode)
    model_decode.copy_weights(model_transcode)
    

03-Jun-24 14:28:09 - Loading layer encoder weights
03-Jun-24 14:28:09 - Loaded
03-Jun-24 14:28:09 - Loading layer hydrogen_estimator weights
03-Jun-24 14:28:09 - Loaded
03-Jun-24 14:28:09 - Loading layer tokens_y weights
03-Jun-24 14:28:09 - Loaded


In [16]:

    # Build the decoder selected in the config around the loaded models.
    decoder_factory = dec.get_decoder(decoder_name)
    decoder = decoder_factory(model_encode, model_decode, steps, n, k, kk,
                              config=sc.config)
    for message in ("Decoder initialized", "Processing and scoring predictions"):
        logger.info(message)
    


03-Jun-24 14:28:09 - Decoder initialized
03-Jun-24 14:28:09 - Processing and scoring predictions


In [17]:
    logger.info(f"Predicting {n_total} samples - start")
    logger.info(f"Beam block size {n}*{k}*{steps}, sequences retrieved per sample: {kk}")
    result_blocks = []
    reference_blocks = []
    for data in tqdm(fp_dataset_val, total = (n_total -1) // n + 1):
        # data = (encoder inputs dict, (reference SMILES bytes, reference fingerprints))

        # The decoder is constructed for a fixed batch size n. The last batch
        # may hold fewer queries, so use a dedicated decoder for it WITHOUT
        # overwriting `decoder`; otherwise re-running this cell would decode
        # full batches with the wrong-sized decoder.
        n_real = len(data[0]['n_hydrogen'])
        if n_real == n:
            block_decoder = decoder
        else:
            block_decoder = dec.get_decoder(decoder_name)(
                    model_encode, model_decode, steps, n_real, k, kk, config = sc.config)

        # Repeat the input data k times for each of the n queries; each copy is
        # encoded individually because the encoding may be probabilistic.
        data_k = {key: tf.repeat(x, k, axis=0) for key, x in data[0].items()}
        states_init = model_encode.predict(data_k)
        # Predict k sequences for each query.
        sequences, y, scores = block_decoder.decode_beam(states_init)
        seq, score, length = block_decoder.beam_traceback(sequences, y, scores)
        smiles = block_decoder.sequence_ytoc(seq)
        results_df = block_decoder.format_results(smiles, score)
        result_blocks.append(results_df)
        reference_df = block_decoder.format_reference(
            [x.decode('UTF-8') for x in data[1][0].numpy()],
            list(data[1][1].numpy()))
        reference_blocks.append(reference_df)
    results = pd.concat(result_blocks)
    logger.info(f"Predicting {n_total} samples - done")
    # Raw predictions are written next to the main pickle under the name
    # "<picklepath name>_all" (e.g. eval_<id>.pkl_all). The context manager
    # ensures the file handle is closed.
    with open(picklepath.with_name(picklepath.name + "_all"), "wb") as pickle_file:
        pickle.dump(results, pickle_file)


03-Jun-24 14:28:09 - Predicting 20 samples - start
03-Jun-24 14:28:09 - Beam block size 8*64*128, sequences retrieved per sample: 10
100%|██████████| 3/3 [00:07<00:00,  2.46s/it]
03-Jun-24 14:28:17 - Predicting 20 samples - done


In [18]:
# Preview the concatenated raw predictions from the prediction cell.
results

Unnamed: 0,smiles,score,id,n,k
0,CS(=O)(=O)N1CCc2nc(-c3cnc(N)nc3N3CCOCC3)ncc21,-3.732617,0,0,0
1,CS(=O)(=O)N1CCc2c(N)nc(-c3cnc(N)nc3N3CCOC3)cc21,-3.855933,1,0,1
2,CS(=O)(=O)N1CCc2c(N3CCOCC3)nc(-c3cnc(N)nc3)nc21,-3.999996,2,0,2
3,CS(=O)(=O)N1CCc2c(N)nc(-c3cnc(N4CCOCC4)nc3)nc21,-4.053775,3,0,3
4,CS(=O)(=O)N1CCc2c(-c3cnc(N)nc3)nc(N3CCOCC3)nc21,-4.145812,4,0,4
...,...,...,...,...,...
35,c1cc(O)c2c(c1)OC(c1ccc(O)c(O)c1)C(O)C2c1c(O)cc...,-4.734622,35,3,5
36,c1c2c(O)cc(O)c(C3c4c(O)cc(O)cc4OC(c4ccc(O)c(O)...,-4.850604,36,3,6
37,Oc1cc(O)c(C2c3c(O)cc(O)cc3OC(c3ccc(O)c(O)c3)C2...,-4.935221,37,3,7
38,Oc1ccc(C2Oc3c(c(O)cc(O)c3C3c4c(O)cc(O)cc4OC(c4...,-5.153457,38,3,8


In [19]:
    %autoreload
    logger.info(f"Evaluating {n_total} blocks - start")
    
    results_evaluated = []
    for block_, ref_, block_id in zip(tqdm(result_blocks), 
                                    reference_blocks,
                                    range(len(result_blocks))):
        # Make a block with molecule, MF, smiles for candidates and reference
        block = db.process_df(block_, fingerprinter,
                              construct_from = "smiles",
                              block_id = block_id)
        
        if retain_single_duplicate:
            block.sort_values("score", ascending = False, inplace = True)
            block = block.groupby(["n", "inchikey1"]).first().reset_index()
            
        ref = db.process_df(ref_, fingerprinter,
                              construct_from = "smiles",
                              block_id = block_id)
        # Also actually compute the true fingerprint for the reference

        if sc.config["eval_fingerprint_all"]:
            fingerprinter.process_df(ref,
                                    out_column = "fingerprint_ref_true",
                                    inplace=True)
            
        # Match ref to predictions
        block = block.join(ref, on="n", rsuffix="_ref")
        # Keep only correct formula
        block_ok = block.loc[block["inchikey1"].notna()].loc[block["mf"] == block["mf_ref"]]
        # Now actually compute the fingerprints, only for matching MF
        if sc.config["eval_fingerprint_all"]:
            fingerprinter.process_df(block_ok,
                                 inplace=True)
        block = block.merge(
            block_ok[["n","k","fingerprint"]],
            left_on = ["n", "k"],
            right_on = ["n", "k"],
            suffixes = ["_ref", ""],
            how = "left")
    
        results_evaluated.append(block)
        


03-Jun-24 14:28:18 - Evaluating 20 blocks - start
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  block["smiles_generic"] = smiles_generic
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  block["smiles_canonical"] = smiles
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  block["mol"] = mol
A value is trying to be set on a copy of a slice from a Data

failed parsing id 12 - s???????????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 13 - c=??????????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 14 - C?O?????????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 15 - CC(P????????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 16 - CNCCC???????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 17 - CC(C=C??????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 18 - COCC(=O??????????????????????????????????????????????????????????????????????????????????????????????????????????

 33%|███▎      | 1/3 [00:00<00:00,  4.84it/s]

failed parsing id 12 - COc1ccccc1C=C(C#N)C(=O)c1ccc(Oc2ncncc2C(=O)O)c1
failed parsing id 19 - COc1ccccc1C=C(C#N)C(=O)c1ccc(Oc2ncnc(C(=O)O)c2)c1
failed parsing id 33 - O=C(OCC1OC(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2ccc(O)c2)C1O)O
failed parsing id 35 - O=C(OCC1OC(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c2)C1)O
failed parsing id 36 - O=C(OCC1OC(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c(O)c2)C(O)C1OC(=O)c1cccc1)OO
failed parsing id 38 - CC(=O)OCC1OC(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c(O)c2)C(OC(=O)c2cc(O)c(O)c2)O1
failed parsing id 48 - CC(=O)OCC1OC(OC2C(OC(C)=O)C(OC(C)=O)C(OC(C)=O)C(OC(C)=O)C(OC(C)=O)C2OC(C)=O)C(OC(C)=O)C(=O)C1=CC(=O)c1cccc1
failed parsing id 50 - c
failed parsing id 53 - s????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????
failed parsing id 54 - 7C????????????????????????????????????????????????????????????????????????????????????????

100%|██████████| 3/3 [00:00<00:00,  4.96it/s]

failed parsing id 36 - c1c2c(O)cc(O)c(C3c4c(O)cc(O)cc4OC(c4ccc(O)c(O)c4)C3O)c2OC(c2ccc(O)c(O)c2)C1O





In [20]:
    # Log the number of evaluated blocks (n_total counts samples, not blocks).
    logger.info(f"Evaluating {len(results_evaluated)} blocks - merging")
    results_complete = pd.concat(results_evaluated)
    # Global query index across blocks: each block holds up to n queries.
    results_complete["nn"] = n * results_complete["block_id"] + results_complete["n"]
    results_complete["evaluation_set"] = evaluation_set

    logger.info(f"Pickling predictions from [{evaluation_set}]")
    # Context manager so the file handle is closed (and flushed) right away.
    with open(picklepath, "wb") as pickle_file:
        pickle.dump(results_complete, pickle_file)

    # Rows that received a fingerprint (formula-matching candidates) are the
    # ones passed on to scoring.
    results_ok = results_complete.loc[results_complete["fingerprint"].notna()].copy()


03-Jun-24 14:28:18 - Evaluating 20 blocks - merging
03-Jun-24 14:28:18 - Pickling predictions from [val]


In [None]:
# Debug: list the columns of the merged results.
results_complete.columns

In [None]:
# Export the full results table as TSV next to the pickle. This uses the
# configured `eval_folder` instead of overwriting it with a hard-coded
# absolute path.
# NOTE(review): object columns (presumably fingerprint arrays) are written
# via their string representation.
out_path = eval_folder / ("eval_" + pickle_id + ".tsv")

results_complete.to_csv(out_path, sep='\t')

In [None]:
# Compact export: candidate SMILES / InChIKey, reference InChIKey, global
# query index and score.
short_columns = ["smiles", "inchikey1", "inchikey1_ref", "nn", "score"]
short_out_path = eval_folder / f"eval_{pickle_id}_short.tsv"
results_complete.loc[:, short_columns].to_csv(short_out_path, sep='\t')

In [None]:
# Score the formula-matching candidates against the predicted fingerprints.
# NOTE(review): this cell used `fp_map`, `n_total_` and `f1_cutoff`, none of
# which were defined before it; they are defined here so the cell runs in
# notebook order on a fresh kernel.
fp_map = fpm.FingerprintMap(sc.config["fp_map"])
# TODO confirm the intended smoothing count: n_total rounded down to a
# multiple of the batch size n.
n_total_ = n_total // n * n
# Matches the default f1_cutoff=0.5 of the msc score_* functions.
f1_cutoff = 0.5
scores = msc.get_candidate_scores()
results_ok = msc.compute_candidate_scores(results_ok, fp_map,
                                          additive_smoothing_n = n_total_,
                                          f1_cutoff = f1_cutoff)

In [27]:
# Debug: the available candidate scoring functions.
scores

{'score_mod_platt': <function infrastructure.score.score_mod_platt(predicted, candidate, stats, f1_cutoff=0.5)>,
 'score_unit': <function infrastructure.score.score_unit(predicted, candidate, stats=None, f1_cutoff=0.5)>,
 'score_unit_pos': <function infrastructure.score.score_unit_pos(predicted, candidate, stats=None, f1_cutoff=0.5)>,
 'score_platt': <function infrastructure.score.score_platt(predicted, candidate, stats=None, f1_cutoff=0.5)>,
 'score_max_likelihood': <function infrastructure.score.score_max_likelihood(predicted, candidate, stats, f1_cutoff=0.5)>,
 'score_tanimoto': <function infrastructure.score.score_tanimoto(predicted, candidate, stats=None, f1_cutoff=0.5)>,
 'score_rel_mod_platt': <function infrastructure.score.score_rel_mod_platt(predicted, candidate, stats, f1_cutoff=0.5)>,
 'score_lim_mod_platt': <function infrastructure.score.score_lim_mod_platt(predicted, candidate, stats, f1_cutoff=0.5)>}

In [24]:

# Load the fingerprint map from the path configured under "fp_map".
# NOTE(review): the scoring cell uses `fp_map`, so this must run before it.
# The last run failed because the configured file does not exist.
fp_map = fpm.FingerprintMap(sc.config["fp_map"])



FileNotFoundError: [Errno 2] No such file or directory: '/data/MSNovelist-data/fingerprint_map_pseudo.tsv'