In [None]:
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 15:36:30 2020

@author: stravsm
"""

import importlib
from importlib import reload
from tqdm import tqdm
import os

import tensorflow as tf
import numpy as np
import pandas as pd

from fp_management import database as db
from fp_management import mist_fingerprinting
import smiles_config as sc

sc.config_file.append("config.EULER-eval.yaml")
sc.config_reload()

import infrastructure.generator as gen
import infrastructure.decoder as dec

import time
from datetime import datetime
import pickle
import pathlib


from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
import infrastructure.score as msc
import gc
import random

# Disable dropout by switching the model config to inference mode.
# TODO: is there a more elegant way to adapt config at runtime?
sc.config["model_config"]["training"] = False

# Randomness is relevant for stochastic sampling: when a global seed is
# configured (an empty string means "unseeded"), seed every RNG in use.
random_seed = sc.config['random_seed_global']
if random_seed != '':
    random.seed(random_seed)
    np.random.seed(random_seed)
    # tf.random.set_seed is the public TF2 API for the global seed;
    # tf.random.experimental has no set_seed and raised AttributeError here.
    tf.random.set_seed(random_seed)

# Logging: configure the root handler format once, then use a named
# logger at INFO level for this evaluation run.
import logging

_LOG_FORMAT = '%(asctime)s - %(message)s'
_LOG_DATEFMT = '%d-%b-%y %H:%M:%S'
logging.basicConfig(datefmt=_LOG_DATEFMT, format=_LOG_FORMAT)

logger = logging.getLogger("MSNovelist")
logger.setLevel(logging.INFO)
logger.info("evaluation startup")

# Evaluation results are written below eval_folder; create it (and any
# missing parents) up front.
eval_folder = pathlib.Path(sc.config["eval_folder"])
eval_folder.mkdir(exist_ok=True, parents=True)

# Both identifiers default to the current UNIX timestamp; non-empty config
# values override them.
eval_id = pickle_id = str(int(time.time()))
if sc.config['eval_id'] != '':
    eval_id = sc.config['eval_id']
if sc.config['eval_counter'] != '':
    pickle_id = "-".join((sc.config['eval_id'], sc.config['eval_counter']))

# The config may name a single weights entry or a list of them; always
# work with a list.
weights_config = sc.config['weights']
weights_list = weights_config if isinstance(weights_config, list) else [weights_config]
    


In [None]:

# First, do everything that is independent of the weights being evaluated.

# NOTE(review): this code used the name `fpr`, which is never defined in this
# notebook. The only fingerprinting module imported above is
# `mist_fingerprinting`, so it is bound to `fpr` here. Confirm that this
# module exposes both `MistFingerprinter` and `Fingerprinter`.
fpr = mist_fingerprinting
fpr.MistFingerprinter.init_instance()
fingerprinter = fpr.Fingerprinter.get_instance()

# Evaluation parameters from the config.
n = sc.config["eval_n"]              # batch size for the SMILES pipeline below
n_total = sc.config["eval_n_total"]  # records to evaluate; -1 means "all"
k = sc.config["eval_k"]
kk = sc.config["eval_kk"]
steps = sc.config["eval_steps"]

decoder_name = sc.config["decoder_name"]

# The evaluation set is the configured cross-validation fold (default 0).
sc.config.setdefault('cv_fold', 0)
cv_fold = sc.config["cv_fold"]
evaluation_set = f"fold{cv_fold}"

# Load the CSI:FingerID validation database and the pipeline options it
# provides.
data_eval_ = sc.config["db_path_eval"]
db_eval = db.FpDatabase.load_from_config(data_eval_)
pipeline_options = db_eval.get_pipeline_options()

pipeline_encoder = sc.config['pipeline_encoder']
pipeline_reference = sc.config['pipeline_reference']

# Keep only the first n_total records, or record the full size when no
# limit is configured.
dataset_val = db_eval.get_grp(evaluation_set)
if n_total != -1:
    dataset_val = dataset_val[:n_total]
else:
    n_total = len(dataset_val)

# Raw batched pipeline (not yet zipped into encoder/reference inputs).
fp_dataset_val_ = gen.smiles_pipeline(dataset_val,
                                      batch_size=n,
                                      **pipeline_options,
                                      map_fingerprints=False)

# Optionally load a fingerprint sampler and apply it to the raw pipeline
# (so we can also evaluate from sampled fingerprints).
sampler_name = sc.config['sampler_name']
round_fingerprints = True
if sampler_name != '':
    logger.info(f"Sampler {sampler_name} loading")
    sampler_module = importlib.import_module('fp_sampling.' + sampler_name, 'fp_sampling')
    sampler_factory = sampler_module.SamplerFactory(sc.config)
    round_fingerprints = sampler_factory.round_fingerprint_inference()
    sampler = sampler_factory.get_sampler()
    logger.info(f"Sampler {sampler_name} loaded")
    fp_dataset_val_ = sampler.map_dataset(fp_dataset_val_)

# Zip only AFTER the optional sampler has been applied. Previously the zip ran
# first, so fp_dataset_val (which feeds the models) never saw the sampled
# fingerprints, while fp_dataset_val_ (used for the blueprints) did.
fp_dataset_val = gen.dataset_zip(fp_dataset_val_,
                                 pipeline_encoder, pipeline_reference,
                                 **pipeline_options)


import model

# Evaluate each configured weights entry in turn.
for weights_i, weights_ in enumerate(weights_list):
    # eval_id defaults to the current UNIX timestamp unless configured.
    # pickle_id names the results pickle for this weights entry.
    eval_id = str(int(time.time()))
    pickle_id = eval_id
    if sc.config['eval_id'] != '':
        eval_id = sc.config['eval_id']
    if sc.config['eval_counter'] != '':
        # eval_id equals the configured value when one is set, and otherwise
        # falls back to the timestamp (avoids a leading "-").
        pickle_id = eval_id + "-" + sc.config['eval_counter']
        if len(weights_list) > 1:
            # weights_i is an int from enumerate(); concatenating it directly
            # to a str raised TypeError.
            pickle_id = pickle_id + "-" + str(weights_i)

    picklepath = eval_folder / ("eval_" + pickle_id + ".pkl")
    logger.info(picklepath)
    logger.info(weights_)
    weights = os.path.join(sc.config["weights_folder"], weights_)

    retain_single_duplicate = True

    fp_dataset_iter = iter(fp_dataset_val)
    blueprints = gen.dataset_blueprint(fp_dataset_val_)

    # Load models
    model_encode = model.EncoderModel(
        blueprints=blueprints,
        config=sc.config,
        round_fingerprints=round_fingerprints)
    model_decode = model.DecoderModel(
        blueprints=blueprints,
        config=sc.config)
    model_transcode = model.TranscoderModel(
        blueprints=blueprints,
        config=sc.config,
        round_fingerprints=round_fingerprints)

    # Build the models by calling them once. Weights are then loaded into the
    # transcoder and copied from it into the encoder and decoder.
    # NOTE(review): next(fp_dataset_iter) consumes the first batch of
    # fp_dataset_iter. If later code iterates fp_dataset_iter, that batch is
    # skipped; confirm this is intended.
    y_ = model_transcode(blueprints)
    enc = model_encode(next(fp_dataset_iter)[0])
    _ = model_decode(enc)

    model_transcode.load_weights(weights, by_name=True)
    model_encode.copy_weights(model_transcode)
    model_decode.copy_weights(model_transcode)
    