In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 15:36:30 2020

@author: stravsm
"""

import importlib
from importlib import reload
from tqdm import tqdm
import os


In [14]:
# Automatically reload edited modules before executing each cell.
%reload_ext autoreload
%autoreload 2 

import tensorflow as tf
import numpy as np
import pandas as pd

from fp_management import database as db
from fp_management import mist_fingerprinting as fpr
from fp_management import base_fingerprinting as bfpr
import smiles_config as sc

# Add the EULER evaluation config file to the list of config files and reload.
# NOTE(review): this happens before the infrastructure.* imports below —
# presumably those modules read sc.config at import time; keep this order.
sc.config_file.append("config.EULER-eval.yaml")
sc.config_reload()

import infrastructure.generator as gen
import infrastructure.decoder as dec

import time
from datetime import datetime
import pickle
import pathlib


from rdkit import RDLogger
# Only let CRITICAL RDKit log messages through (silences RDKit warnings/errors).
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
import infrastructure.score as msc
import gc
import random

# Disable dropout for evaluation. Is there a more elegant way to adapt config at runtime?
sc.config["model_config"]["training"] = False

# Randomness is relevant for stochastic sampling: when a global seed is
# configured ('' means unseeded), seed every RNG in use.
random_seed = sc.config['random_seed_global']
if random_seed != '':
    random.seed(random_seed)
    np.random.seed(random_seed)
    # NOTE(review): tf.random.set_seed is the documented TF2 entry point;
    # confirm tf.random.experimental.set_seed exists in the pinned TF version.
    tf.random.experimental.set_seed(random_seed)

# Setup logger
import logging
logging.basicConfig(format='%(asctime)s - %(message)s', 
                    datefmt='%d-%b-%y %H:%M:%S')
logger = logging.getLogger("MSNovelist")
logger.setLevel(logging.INFO)
logger.info("evaluation startup")

eval_folder = pathlib.Path(sc.config["eval_folder"])
eval_folder.mkdir(parents=True, exist_ok=True)

# Run ids. Both default to the current timestamp. A configured eval_id
# overrides eval_id only. When eval_counter is configured, pickle_id becomes
# "<eval_id>-<eval_counter>", built from the *effective* eval_id so that an
# unset eval_id cannot produce a bare "-<counter>". The f-string also
# tolerates a non-str counter value.
# The weights loop further down recomputes both ids for every weights file.
eval_id = str(int(time.time()))
pickle_id = eval_id
if sc.config['eval_id'] != '':
    eval_id = sc.config['eval_id']
if sc.config['eval_counter'] != '':
    pickle_id = f"{eval_id}-{sc.config['eval_counter']}"

# 'weights' may be a single file name or a list of file names; normalise to a list.
if isinstance(sc.config['weights'], list):
    weights_list = sc.config['weights']
else:
    weights_list = [sc.config['weights']]


03-Jun-24 13:21:03 - evaluation startup


In [3]:

# First, do everything independent of weights

# Initialise the MIST fingerprinter through its class-level
# init_instance/get_instance accessors (presumably a shared singleton) and keep
# a handle to it; it translates the dataset fingerprints in entry_for_row.
fpr.MistFingerprinter.init_instance()
fingerprinter = fpr.MistFingerprinter.get_instance()


In [36]:

  
# Evaluation parameters
n = sc.config["eval_n"]              # batch size of the evaluation pipeline
n_total = sc.config["eval_n_total"]  # number of entries to evaluate; -1 = all
k = sc.config["eval_k"]
kk = sc.config["eval_kk"]
steps = sc.config["eval_steps"]

decoder_name = sc.config["decoder_name"]

evaluation_set = sc.config["evaluation_set"]

# Load the CSI:FingerID validation database and the pipeline options stored with it.
data_eval_ = sc.config["db_path_eval"]
db_eval = db.FpDatabase.load_from_config(data_eval_)
pipeline_options = db_eval.get_pipeline_options()

pipeline_encoder = sc.config['pipeline_encoder']
pipeline_reference = sc.config['pipeline_reference']

dataset_val = db_eval.get_grp(evaluation_set)
if n_total != -1:
    dataset_val = dataset_val[:n_total]
# Always reflect the number of entries actually selected. The requested
# eval_n_total may exceed the size of the evaluation group, in which case the
# slice above returns fewer entries.
n_total = len(dataset_val)


In [37]:
# Show the pipeline options stored with the evaluation database; they are
# forwarded to the smiles_pipeline and dataset_zip calls below.
pipeline_options

{'embed_X': False,
 'unpackbits': False,
 'unpack': False,
 'fingerprint_selected': 'fingerprint_degraded'}

In [65]:
# Translate the database rows on the fly. Each row becomes a plain dict whose
# two fingerprint fields have been passed through the MIST fingerprinter.

def entry_for_row(row):
    """Return a dict copy of ``row`` with its fingerprints translated.

    All keys of ``row`` are copied unchanged, except ``fingerprint`` and
    ``fingerprint_degraded``. Each of these is replaced by row 0 of
    ``fingerprinter.get_fp(<original value>)`` (plain fingerprint first,
    then the degraded one).
    """
    entry = dict(row)
    for fp_key in ("fingerprint", "fingerprint_degraded"):
        entry[fp_key] = fingerprinter.get_fp(row[fp_key])[0, :]
    return entry

dataset_val_mapped = list(map(entry_for_row, dataset_val))

In [67]:
# Take the first translated entry for a quick sanity check.
r = dataset_val_mapped[0]

In [68]:
# The translated degraded fingerprint is row 0 of get_fp's result, i.e. a flat
# vector (the saved output shows (4096,)).
r["fingerprint_degraded"].shape

(4096,)

In [69]:
# Build the batched evaluation pipeline from the translated entries, then zip it
# according to the configured encoder and reference pipelines.
# map_fingerprints=False: fingerprints were already passed through the MIST
# fingerprinter in entry_for_row (presumably why no mapping is needed here).
# The sampler, if configured, is applied to fp_dataset_val_ in a later cell
# (so we can also evaluate from fingerprint_sampled).
fp_dataset_val_ = gen.smiles_pipeline(dataset_val_mapped,
                                      batch_size=n,
                                      **pipeline_options,
                                      map_fingerprints=False,
                                      degraded_fingerprint_type="uint8")

fp_dataset_val = gen.dataset_zip(fp_dataset_val_,
                                 pipeline_encoder, pipeline_reference,
                                 **pipeline_options)


03-Jun-24 13:45:20 - using unpickle_mf
03-Jun-24 13:45:20 - not using unpack
03-Jun-24 13:45:20 - not using fp_map
03-Jun-24 13:45:21 - not using embed_X
03-Jun-24 13:45:21 - Selecting fingerprint fingerprint_degraded


In [70]:
# Peek at the first batch of the zipped evaluation dataset using a fresh
# iterator. Element [0] is the dict of encoder inputs (fingerprint_selected,
# mol_form, n_hydrogen, ...).
next(iter(fp_dataset_val))

({'fingerprint_selected': <tf.Tensor: shape=(8, 4096), dtype=float32, numpy=
  array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
  'mol_form': <tf.Tensor: shape=(8, 10), dtype=float32, numpy=
  array([[15.,  0.,  0.,  0.,  7.,  3.,  0.,  0.,  1., 19.],
         [53.,  0.,  0.,  0.,  0., 22.,  0.,  0.,  0., 90.],
         [47.,  0.,  0.,  0.,  2., 13.,  0.,  0.,  0., 56.],
         [54.,  0.,  0.,  0.,  0., 23.,  0.,  0.,  0., 92.],
         [16.,  0.,  0.,  0.,  2.,  3.,  0.,  0.,  1., 16.],
         [21.,  0.,  0.,  0.,  2.,  3.,  0.,  0.,  0., 28.],
         [23.,  0.,  0.,  2.,  7.,  1.,  0.,  0.,  0., 31.],
         [21.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,  1., 19.]], dtype=float32)>,
  'n_hydrogen': <tf.Tensor: shape=(8,), dtype=float32, numpy=array([19., 90., 56.,

In [71]:

# Optionally load a fingerprint sampler from fp_sampling.<sampler_name> and
# apply it to the raw pipeline. The sampler factory also decides whether
# fingerprints are rounded at inference time (used when building the models).
sampler_name = sc.config['sampler_name']
round_fingerprints = True
if sampler_name != '':
    logger.info(f"Sampler {sampler_name} loading")
    sampler_module = importlib.import_module('fp_sampling.' + sampler_name, 'fp_sampling')
    sampler_factory = sampler_module.SamplerFactory(sc.config)
    round_fingerprints = sampler_factory.round_fingerprint_inference()
    sampler = sampler_factory.get_sampler()
    logger.info(f"Sampler {sampler_name} loaded")
    fp_dataset_val_ = sampler.map_dataset(fp_dataset_val_)
    # fp_dataset_val was zipped from the *unsampled* pipeline in an earlier
    # cell. Rebinding fp_dataset_val_ does not update it, so re-zip here so
    # that consumers of fp_dataset_val also get the sampler's output.
    # NOTE(review): re-running this cell applies the sampler a second time;
    # re-run the pipeline cell first.
    fp_dataset_val = gen.dataset_zip(fp_dataset_val_,
                                     pipeline_encoder, pipeline_reference,
                                     **pipeline_options)



03-Jun-24 13:45:29 - Sampler basic_tp_fp loading
03-Jun-24 13:45:29 - Sampler basic_tp_fp loaded


In [72]:

# Prepare one evaluation per configured weights file: derive the run ids and
# pickle path, resolve the weights file, and set up dataset iterator and
# blueprints.
#
# NOTE(review): the model-building cells below are indented as though they were
# part of this loop, but they are separate cells. IPython strips their leading
# indent and runs them once, after this loop has finished. With more than one
# entry in weights_list, only the last iteration's `weights`, `picklepath`,
# `blueprints` and `fp_dataset_iter` are actually used.
# TODO: move the per-weights work into a function and call it from this loop.
for weights_i, weights_ in enumerate(weights_list):
    # Both ids default to the current timestamp; a configured eval_id
    # overrides eval_id only.
    eval_id = str(int(time.time()))
    pickle_id = eval_id
    if sc.config['eval_id'] != '':
        eval_id = sc.config['eval_id']
    if sc.config['eval_counter'] != '':
        # Built from the effective eval_id; the f-string also tolerates a
        # non-str counter value.
        pickle_id = f"{eval_id}-{sc.config['eval_counter']}"
        if len(weights_list) > 1:
            # weights_i is an int from enumerate(), so it must be formatted,
            # not concatenated with `+`.
            pickle_id = f"{pickle_id}-{weights_i}"

    picklepath = eval_folder / ("eval_" + pickle_id + ".pkl")
    logger.info(picklepath)
    logger.info(weights_)
    weights = os.path.join(sc.config["weights_folder"], weights_)

    # NOTE(review): not used in the cells visible here; presumably consumed later.
    retain_single_duplicate = True

    fp_dataset_iter = iter(fp_dataset_val)
    blueprints = gen.dataset_blueprint(fp_dataset_val_)

    # Load models (done in the following cells)


03-Jun-24 13:45:30 - /tmp/mistnovelist-eval/eval_1717422330.pkl
03-Jun-24 13:45:30 - w-02-0.086-0.079.hdf5


In [73]:
    import model

    # Construct the encoder-only, decoder-only and full transcoder models from
    # the dataset blueprints and the global config. The trained weights are
    # loaded into the transcoder and copied into the other two afterwards.
    # NOTE(review): although indented, this cell runs on its own (IPython strips
    # the common leading indent), i.e. after the weights loop has finished.
    model_encode = model.EncoderModel(blueprints=blueprints,
                                      config=sc.config,
                                      round_fingerprints=round_fingerprints)
    model_decode = model.DecoderModel(blueprints=blueprints,
                                      config=sc.config)
    model_transcode = model.TranscoderModel(blueprints=blueprints,
                                            config=sc.config,
                                            round_fingerprints=round_fingerprints)
    


using fingerprint rounding in model
using fingerprint rounding in model


In [74]:
    # Build models by calling them
    # (presumably Keras models, which create their variables on the first call,
    # so each model is called once before any weights are loaded or copied).
    # NOTE(review): although indented, this cell runs on its own after the
    # weights loop, so it uses the last iteration's `weights`.
    y_ = model_transcode(blueprints)
    # NOTE(review): this consumes the first batch of fp_dataset_iter; if that
    # iterator is reused for the evaluation later, its first batch is skipped — confirm.
    enc = model_encode(next(fp_dataset_iter)[0])
    _ = model_decode(enc)
    
    # Load the trained weights into the transcoder (matching layers by name),
    # then copy them from the transcoder into the encoder and decoder models.
    model_transcode.load_weights(weights, by_name=True)
    model_encode.copy_weights(model_transcode)
    model_decode.copy_weights(model_transcode)
    

03-Jun-24 13:45:37 - Loading layer encoder weights
03-Jun-24 13:45:37 - Loaded
03-Jun-24 13:45:37 - Loading layer hydrogen_estimator weights
03-Jun-24 13:45:37 - Loaded
03-Jun-24 13:45:37 - Loading layer tokens_y weights
03-Jun-24 13:45:37 - Loaded
