In [1]:
from tqdm import tqdm
import sys
sys.path.append('/home/stravsmi/msmsgym/MSNovelist-private')


from fp_management import database as db
from fp_management import mist_fingerprinting as fpr
from fp_management import fingerprint_map as fpm
import os

In [2]:


import smiles_process as sp
import importlib
from importlib import reload
import smiles_config as sc

import infrastructure.generator as gen

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, \
    LambdaCallback, Callback

import numpy as np
import pandas as pd
import time
import math
import pickle
import json

In [3]:
# Add config.MULTIUSR.yaml to the list of config files read by smiles_config,
# then reload the configuration. Guarded so that re-running this cell does not
# append the same file a second time.
if 'config.MULTIUSR.yaml' not in sc.config_file:
    sc.config_file.append('config.MULTIUSR.yaml')
sc.config_reload()

In [4]:

# Logger used throughout the notebook: "<timestamp> - <message>" lines,
# INFO level and above.
import logging

logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%d-%b-%y %H:%M:%S',
)
logger = logging.getLogger("MSNovelist")
logger.setLevel("INFO")
logger.info("training startup")


31-May-24 13:04:21 - training startup


In [5]:

# Run parameters: optional fingerprint sampler module, the x/y feature
# pipelines, the training id and the dataset group names.

# Sampler module `fp_sampling.<sampler_name>`, or None if no sampler is set.
sampler_name = sc.config['sampler_name']
sampler_module = None
if sampler_name != '':
    sampler_module = importlib.import_module(
        'fp_sampling.' + sampler_name, 'fp_sampling')

pipeline_x = sc.config['pipeline_x']
pipeline_y = sc.config['pipeline_y']
logger.info(f"pipeline_x: {pipeline_x}")
logger.info(f"pipeline_y: {pipeline_y}")

# A non-empty training_id in the config wins; otherwise use the Unix time.
training_id = (sc.config['training_id'] if sc.config['training_id'] != ''
               else str(int(time.time())))

# cv_fold is not read elsewhere in this notebook; presumably other modules
# consume it through sc.config.
sc.config.setdefault('cv_fold', 0)

# Group names inside the training database (the evaluation cell also queries
# its database with `validation_set`).
# NOTE(review): the config carries its own training_set / validation_set keys
# ('train' / 'validate' in the displayed config) which are ignored here.
training_set = "train"
validation_set = 'test'

31-May-24 13:04:22 - pipeline_x: ['fingerprint_selected', 'mol_form', 'tokens_X']
31-May-24 13:04:22 - pipeline_y: ['tokens_y', 'n_hydrogen']


In [6]:

logger.info(f"Training model id {training_id}, training set")

# Run tag of the form "m-<training_id>-<model_tag>".
model_tag_id = "m-" + training_id + "-" + sc.config['model_tag']
logger.info(f"Tag: {model_tag_id}")

# Per-run output folders: <weights_folder>/<tag>/train holds the weights and a
# config snapshot; <log_folder>/<tag>/train receives the TensorBoard logs.
weights_path, log_path = (
    os.path.join(sc.config[folder_key], model_tag_id, "train")
    for folder_key in ("weights_folder", "log_folder")
)
config_dump_path = os.path.join(weights_path, 'config.yaml')

# os.makedirs without exist_ok raises FileExistsError if a folder already
# exists, e.g. when this cell is re-run with a fixed training_id.
os.makedirs(weights_path)
os.makedirs(log_path)

sc.config_dump(config_dump_path)


31-May-24 13:04:22 - Training model id 1717153462, training set
31-May-24 13:04:22 - Tag: m-1717153462-msnovelmist


In [23]:
# Show the active configuration (presumably the same config written to
# config.yaml by sc.config_dump above).
# NOTE(review): the output contains absolute local paths; clear it before
# sharing the notebook.
sc.config

{'base_folder': '/home/stravsmi/msmsgym/MSNovelist-private/',
 'db_path': '',
 'fp_source': '',
 'fp_map': '/home/stravsmi/msmsgym/msnovelist-data/fingerprint_map_pseudo.tsv',
 'epochs': 30,
 'training_id': '',
 'cv_fold': 1,
 'cv_folds': 10,
 'steps_per_epoch': -1,
 'steps_per_epoch_validation': -1,
 'batch_size': 256,
 'hdf5_lock': 'FALSE',
 'model_name': 'models.model_flex_20190401',
 'sampler_name': 'gamma_bitmatrix',
 'fingerprinter_path': '',
 'fingerprinter_threads': 2,
 'fingerprinter_cache': '/tmp/fingerprint_cache.db',
 'weights_folder': '/target/evaluation/m-36719628/1',
 'eval_folder': '/tmp/mistnovelist-eval/',
 'log_folder': '/tmp/',
 'weights': 'w-20-0.069-0.074.hdf5',
 'sirius_bin': '/usr/local/bin/sirius',
 'training_set': 'train',
 'validation_set': 'validate',
 'pipeline_x': ['fingerprint_selected', 'mol_form', 'tokens_X'],
 'pipeline_y': ['tokens_y', 'n_hydrogen'],
 'pipeline_x_eval': [],
 'eval_n': 8,
 'eval_n_total': 300,
 'eval_k': 128,
 'eval_kk': 128,
 'eval_st

In [24]:


# Open the training database and select the training / validation groups.
logger.info(f"Datasets - loading database")
fp_db = db.FpDatabase.load_from_config(sc.config['db_path_train'])
fp_train, fp_val = (fp_db.get_grp(group) for group in (training_set, validation_set))

# Evaluation data (CSI:FingerID validation data) is loaded from the database
# configured under "db_path_eval".
logger.info(f"Datasets - loading evaluation")
data_eval_ = sc.config["db_path_eval"]
# Note: with CV, the evaluation set name is the same as the validation set
# name, so the evaluation database is also queried with `validation_set`.
db_eval = db.FpDatabase.load_from_config(data_eval_)
dataset_eval = db_eval.get_grp(validation_set)


31-May-24 13:14:04 - Datasets - loading database
31-May-24 13:16:03 - Datasets - loading evaluation


In [25]:
ro = next(iter(fp_train))

In [29]:
fpfp.shape

(1, 4096)

In [27]:
fpr.MistFingerprinter.init_instance()
fingerprinter = fpr.MistFingerprinter.get_instance()
fpfp = fingerprinter.get_fp(ro["fingerprint"])

In [31]:
# Inspect the pipeline options contributed by the training database; they are
# forwarded as **kwargs to gen.smiles_pipeline and gen.dataset_zip below.
fp_db.get_pipeline_options()

{'embed_X': False, 'fingerprint_selected': 'fingerprint_sampled'}

In [34]:

logger.info(f"Datasets - building pipeline for database")

# Batched training pipeline built from the training group. map_fingerprints is
# disabled explicitly; the remaining keyword options come from the database's
# own get_pipeline_options().
fp_dataset_train_ = gen.smiles_pipeline(
    fp_train,
    batch_size=sc.config['batch_size'],
    map_fingerprints=False,
    **fp_db.get_pipeline_options()
)


31-May-24 13:24:53 - Datasets - building pipeline for database
31-May-24 13:25:09 - using unpickle_mf
31-May-24 13:25:48 - not using unpack
31-May-24 13:25:48 - not using fp_map
31-May-24 13:25:49 - not using embed_X


In [35]:
# Inspect the training pipeline: a dict with one tf.data dataset per named
# feature (fingerprint, mol_form, tokens_X, tokens_y, ...).
fp_dataset_train_

{'fingerprint': <MapDataset shapes: (None,), types: tf.float32>,
 'fingerprint_degraded': <MapDataset shapes: (None,), types: tf.float32>,
 'mol_form': <MapDataset shapes: (None, 10), types: tf.float32>,
 'smiles_generic': <MapDataset shapes: (None,), types: tf.string>,
 'smiles_canonical': <MapDataset shapes: (None,), types: tf.string>,
 'n_hydrogen': <MapDataset shapes: (None,), types: tf.float32>,
 'tokens_X': <MapDataset shapes: (None, None, 40), types: tf.float32>,
 'tokens_y': <MapDataset shapes: (None, None, 40), types: tf.float32>}

In [None]:

# Validation pipeline: same settings as the training pipeline, fed from the
# validation group of the training database.
fp_dataset_val_ = gen.smiles_pipeline(
    fp_val,
    batch_size=sc.config['batch_size'],
    map_fingerprints=False,
    **fp_db.get_pipeline_options()
)


In [None]:

logger.info(f"Datasets - building pipeline for evaluation")

# Evaluation pipeline: takes its extra options from the evaluation database
# rather than the training database.
fp_dataset_eval_ = gen.smiles_pipeline(
    dataset_eval,
    batch_size=sc.config['batch_size'],
    map_fingerprints=False,
    **db_eval.get_pipeline_options()
)

logger.info(f"Datasets - pipelines built")


In [13]:
# Peek at one batch of training fingerprints.
# NOTE(review): this cell ran as In [13], before the pipeline cells above
# (In [34]), so the output below may stem from an earlier pipeline state.
# NOTE(review): the values shown (65, 67, 61) equal the ASCII codes of 'A',
# 'C' and '=', and '=' sits in the last column like base64 padding; the
# fingerprints may be undecoded base64 bytes (the pipeline logged "not using
# unpack"). Verify before training.
next(iter(fp_dataset_train_['fingerprint']))

<tf.Tensor: shape=(256, 684), dtype=float32, numpy=
array([[65., 65., 65., ..., 65., 65., 61.],
       [65., 65., 65., ..., 65., 65., 61.],
       [65., 65., 67., ..., 65., 65., 61.],
       ...,
       [65., 65., 65., ..., 65., 65., 61.],
       [65., 67., 65., ..., 65., 65., 61.],
       [65., 65., 65., ..., 65., 65., 61.]], dtype=float32)>

In [None]:

# Training cell: attach the optional fingerprint sampler, build the (x, y)
# datasets, compile the model (optionally resuming from saved weights) and
# run fit().
#
# The sampled datasets get new names instead of overwriting fp_dataset_train_
# / fp_dataset_val_, so re-running this cell does not apply the sampler twice.

# --- Datasets ---------------------------------------------------------------
fp_dataset_train_sampled = fp_dataset_train_
fp_dataset_val_sampled = fp_dataset_val_
# If fingerprint sampling is configured: load the sampler and map it onto the
# training and validation pipelines (the evaluation pipeline is not sampled).
if sampler_module is not None:
    logger.info(f"Sampler {sampler_name} loading")
    sampler_factory = sampler_module.SamplerFactory(sc.config)
    sampler = sampler_factory.get_sampler()
    logger.info(f"Sampler {sampler_name} loaded")
    fp_dataset_train_sampled = sampler.map_dataset(fp_dataset_train_)
    fp_dataset_val_sampled = sampler.map_dataset(fp_dataset_val_)

# Zip the named pipelines into (x, y) per pipeline_x / pipeline_y. Training and
# validation are repeated `epochs` times, presumably so the finite datasets
# last for the whole fit() run.
fp_dataset_train = gen.dataset_zip(fp_dataset_train_sampled, pipeline_x, pipeline_y,
                                   **fp_db.get_pipeline_options())
fp_dataset_train = fp_dataset_train.repeat(sc.config['epochs'])
fp_dataset_train = fp_dataset_train.prefetch(tf.data.experimental.AUTOTUNE)

blueprints = gen.dataset_blueprint(fp_dataset_train_sampled)

fp_dataset_val = gen.dataset_zip(fp_dataset_val_sampled, pipeline_x, pipeline_y,
                                 **fp_db.get_pipeline_options())
fp_dataset_val = fp_dataset_val.repeat(sc.config['epochs'])
fp_dataset_val = fp_dataset_val.prefetch(tf.data.experimental.AUTOTUNE)

fp_dataset_eval = gen.dataset_zip(fp_dataset_eval_, pipeline_x, pipeline_y,
                                  **db_eval.get_pipeline_options())
fp_dataset_eval = fp_dataset_eval.prefetch(tf.data.experimental.AUTOTUNE)

# --- Steps per epoch --------------------------------------------------------
batch_size = sc.config["batch_size"]
epochs = sc.config['epochs']

# Default: number of complete batches; a positive config value overrides it.
training_total = len(fp_train)
validation_total = len(fp_val)
training_steps = math.floor(training_total / batch_size)
if sc.config['steps_per_epoch'] > 0:
    training_steps = sc.config['steps_per_epoch']

validation_steps = math.floor(validation_total / batch_size)
if sc.config['steps_per_epoch_validation'] > 0:
    validation_steps = sc.config['steps_per_epoch_validation']

# Fail early, before the (expensive) model build, with a clear message when a
# split holds less than one full batch, instead of passing 0 steps to fit().
if training_steps < 1 or validation_steps < 1:
    raise ValueError(
        f"Need at least one batch per epoch: {training_total} training and "
        f"{validation_total} validation records, batch size {batch_size} "
        f"(training_steps={training_steps}, validation_steps={validation_steps})")

logger.info(f"Preparing training: {epochs} epochs, {training_steps} steps per epoch, batch size {batch_size}")

# --- Model ------------------------------------------------------------------
# Rounding of fingerprints at inference is decided by the sampler factory;
# same guard as the sampler block above, so sampler_factory is defined here.
round_fingerprints = False
if sampler_module is not None:
    round_fingerprints = sampler_factory.round_fingerprint_inference()

import model
transcoder_model = model.TranscoderModel(
    blueprints = blueprints,
    config = sc.config,
    round_fingerprints = round_fingerprints
    )

initial_epoch = 0

logger.info("Building model")

transcoder_model.compile()

# If set correspondingly: load model weights and optimizer state, and continue
# training from continue_training_epoch.
if sc.config.get('continue_training_epoch', 0) > 0:
    transcoder_model.load_weights(os.path.join(
        sc.config['weights_folder'],
        sc.config['weights']))
    # NOTE(review): _make_train_function is a private Keras API; confirm it
    # exists for TranscoderModel on the TensorFlow version in use.
    transcoder_model._make_train_function()
    # NOTE: pickle.load can execute arbitrary code; only resume from a trusted
    # weights folder.
    with open(os.path.join(
            sc.config['weights_folder'],
            sc.config['weights_optimizer']), 'rb') as f:
        weight_values = pickle.load(f)
    transcoder_model.optimizer.set_weights(weight_values)
    initial_epoch = sc.config['continue_training_epoch']

logger.info("Model built")

# --- Callbacks --------------------------------------------------------------
# {eval_loss:.3f}
filepath = os.path.join(
    weights_path,
    "w-{epoch:02d}-{loss:.3f}-{val_loss:.3f}.hdf5"
    )

tensorflow_trace = sc.config["tensorflow_trace"]
# Profile batch 2 only when tracing is requested; 0 disables profiling.
tensorboard_profile_batch = 2 if tensorflow_trace else 0
verbose = sc.config["training_verbose"]

# Monitors the *training* loss (not val_loss); only improving epochs are saved.
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,
                             save_best_only=True, mode='min',
                             save_weights_only=True)
tensorboard = TensorBoard(log_dir=log_path,
                          histogram_freq=1,
                          profile_batch = tensorboard_profile_batch,
                          write_graph=tensorflow_trace,
                          write_images=tensorflow_trace)

save_optimizer = model.resources.SaveOptimizerCallback(weights_path)
evaluation = model.resources.AdditionalValidationSet(fp_dataset_eval,
                                                     "eval",
                                                     verbose = 0)

print_logs = LambdaCallback(
    on_epoch_end = lambda epoch, logs: print(logs)
    )

# One JSON line per epoch; line-buffered (buffering=1) so the file can be read
# while training runs.
json_log = open(os.path.join(weights_path, 'loss_log.json'),
                mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs}) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

callbacks_list = [evaluation,
                  tensorboard,
                  print_logs,
                  json_logging_callback,
                  checkpoint,
                  save_optimizer]

# --- Training ---------------------------------------------------------------
logger.info("Training - start")
try:
    transcoder_model.fit(x=fp_dataset_train,
                         epochs=epochs,
                         steps_per_epoch=training_steps,
                         callbacks = callbacks_list,
                         validation_data = fp_dataset_val,
                         validation_steps = validation_steps,
                         initial_epoch = initial_epoch,
                         verbose = verbose)
finally:
    # on_train_end (which closes json_log) may not run if fit() raises or is
    # interrupted; closing an already-closed file is a no-op.
    json_log.close()
logger.info("Training - done")
# Both databases were opened in the loading cell; release them after training.
fp_db.close()
db_eval.close()

logger.info("training end")

