# Training a neural network on each example from the DeepCME paper.

For each CRN, the following configuration holds:
- End time: 1 second
- Time steps: 50
- Initial concentration for all species: 0

## Helper functions

In [None]:
import random
random.seed(42)

def generate_colours(num_colours):
    """Return a list of `num_colours` random hex colour strings ('#' + 6 hex digits).

    Digits are drawn one at a time from the module-level `random` generator,
    so the result depends on that generator's current state (seeded above).
    """
    hex_digits = '0123456789ABCDEF'
    palette = []
    for _ in range(num_colours):
        digits = [random.choice(hex_digits) for _ in range(6)]
        palette.append("#" + ''.join(digits))
    return palette


# Palette of 50 random colours, generated once when this cell runs
COLOURS = generate_colours(50)

## Initialise

In [None]:
# Jupyter magic to keep track of file changes in real-time
%load_ext autoreload
%autoreload 2

# Core imports
import os, sys, json, h5py, pickle, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from importlib import import_module
from time import time
from pathlib import Path
from pprint import pprint

# StochNetV2 library imports (unchanged files from the framework)
from stochnet_v2.dataset.dataset import DataTransformer, HDF5Dataset

from stochnet_v2.static_classes.model import StochNet
from stochnet_v2.static_classes.trainer import ToleranceDropLearningStrategy

from stochnet_v2.dynamic_classes.model import NASStochNet
from stochnet_v2.dynamic_classes.trainer import Trainer

from stochnet_v2.utils.file_organisation import ProjectFileExplorer
from stochnet_v2.utils.util import merge_species_and_param_settings, plot_random_traces, visualize_genotypes

# StochNetV2 local imports (files that have been modified)
from simulation_gillespy import build_simulation_dataset
from util import generate_gillespy_traces
from evaluation import evaluate

# Seed NumPy's global RNG so the np.random draws made later in this notebook are repeatable
np.random.seed(42)

# Print a (built_with_cuda, gpu_available) tuple as reported by TensorFlow.
# NOTE(review): tf.test.is_gpu_available is deprecated in TensorFlow 2.x in favour of
# tf.config.list_physical_devices('GPU') — confirm against the installed TF version.
print(f"GPU is available: {tf.test.is_built_with_cuda(), tf.test.is_gpu_available()}")

## Define CRNs

In [None]:
# Each entry of `networks` is (network name, number of species, [params to randomise]).
# The name is the family prefix followed by the species count; every size in a
# family shares the same randomised parameters, so the table is generated per family.
networks = [
    (f'{family}{size}', size, list(family_params))
    for family, sizes, family_params in (
        ('BD', (5, 10, 20), ('k', 'gamma')),
        ('LSC', (2, 5, 10), ('beta_0', 'k', 'gamma')),
        ('NSC', (2, 5, 10), ('b', 'k_m', 'k_0', 'H', 'beta_0', 'gamma')),
        ('LSCF', (2, 5, 10), ('b', 'k_m', 'k_0', 'H', 'k', 'gamma')),
    )
    for size in sizes
]

## Train NNs on 5000-trace datasets (500 settings × 10 trajectories each)

In [None]:
# For every CRN in `networks`: simulate the training and histogram datasets,
# convert them for ML, write the architecture configs, train a NASStochNet
# ('search' + 'finetune' modes), evaluate it, and finally compare NN-generated
# traces against Gillespy2 traces (timings and plots).
for name, n_species, params in networks:
    print(f">>> Working with model {name}.")
    # Configure the model parameters
    model_name = name
    timestep = 0.02
    endtime = 1.0
    dataset_id = name
    model_id = name
    nb_features = n_species
    params_to_randomize = params

    # Configure the simulation parameters
    nb_settings = 500
    nb_trajectories = 10

    nb_histogram_settings = 500
    nb_histogram_trajectories = 10

    # File-handling and housekeeping.
    # The project folder is <current working directory>/<model name>.
    project_folder = Path('').parent.resolve()/model_name
    project_explorer = ProjectFileExplorer(project_folder)
    dataset_explorer = project_explorer.get_dataset_file_explorer(timestep, dataset_id)
    model_explorer = project_explorer.get_model_file_explorer(timestep, model_id)

    body_config_path = model_explorer.body_config_fp
    mixture_config_path = model_explorer.mixture_config_fp

    # The CRN definition is expected in a module named after the network,
    # exposing a class with the same name.
    CRN_module = import_module(model_name)
    CRN_class = getattr(CRN_module, model_name)

    # Generate and save the initial species concentrations
    settings = CRN_class.get_initial_settings(nb_settings)
    print(f"Settings shape: {settings.shape}")
    print(f"Saving settings to {dataset_explorer.settings_fp}\n")
    np.save(dataset_explorer.settings_fp, settings)

    histogram_settings = CRN_class.get_initial_settings(nb_histogram_settings)
    print(f"Histogram settings shape: {histogram_settings.shape}")
    print(f"Saving histogram_settings to {dataset_explorer.histogram_settings_fp}")
    np.save(dataset_explorer.histogram_settings_fp, histogram_settings)

    # Generate the dataset of trajectories of shape (n_settings * n_trajectories, n_steps, n_species)
    dataset = build_simulation_dataset(
        model_name,
        nb_settings,
        nb_trajectories,
        timestep,
        endtime,
        dataset_explorer.dataset_folder,
        params_to_randomize=params_to_randomize,
        how='concat'
    )

    np.save(dataset_explorer.dataset_fp, dataset)

    # Generate the histogram dataset of shape (n_settings, n_trajectories, n_steps, n_species)
    histogram_dataset = build_simulation_dataset(
        model_name,
        nb_histogram_settings,
        nb_histogram_trajectories,
        timestep,
        endtime,
        dataset_explorer.dataset_folder,
        params_to_randomize=params_to_randomize,
        prefix='histogram_partial_',
        how='stack',
        settings_filename=os.path.basename(dataset_explorer.histogram_settings_fp),
    )

    np.save(dataset_explorer.histogram_dataset_fp, histogram_dataset)

    # Transform the dataset
    dt = DataTransformer(
        dataset_explorer.dataset_fp,
        with_timestamps=True,
        nb_randomized_params=len(params_to_randomize)
    )

    # Convert and save as HDF5
    dt.save_data_for_ml_hdf5(
        dataset_folder=dataset_explorer.dataset_folder,
        nb_past_timesteps=1,
        test_fraction=0.2,
        keep_timestamps=False,
        rescale=True,
        positivity=False,
        shuffle=True,
        slice_size=100,
        force_rewrite=True
    )

    # Architecture parameters
    body_n_cells = 2
    body_cell_size = 2
    body_expansion_multiplier = 20
    body_n_states_reduce = 2
    body_kernel_constraint = "none"
    body_bias_constraint = "none"
    body_kernel_regularizer = "l2"
    body_bias_regularizer = "l2"
    body_regularizer = "none"

    components_hidden_size = "none"
    n_normal_diag = 6
    n_normal_tril = 0
    n_log_normal_tril = 0
    components_activation = "none"
    components_regularizer = "none"
    components_kernel_constraint = "none"
    components_bias_constraint = "none"
    components_kernel_regularizer = "l2"
    components_bias_regularizer = "l2"

    # Architecture configurations
    body_config = {
        "n_cells": body_n_cells,
        "cell_size": body_cell_size,
        "expansion_multiplier": body_expansion_multiplier,
        "n_states_reduce": body_n_states_reduce,
        "kernel_constraint": body_kernel_constraint,
        "kernel_regularizer": body_kernel_regularizer,
        "bias_constraint": body_bias_constraint,
        "bias_regularizer": body_bias_regularizer,
        "activity_regularizer": body_regularizer,
    }

    # NOTE(review): the categorical component reuses the *body* constraints
    # rather than the components_* ones (both are "none" here) — confirm intent.
    categorical_config = {
        "hidden_size": components_hidden_size,
        "activation": components_activation,
        "coeff_regularizer": "none",
        "kernel_constraint": body_kernel_constraint,  # unitnorm
        "bias_constraint": body_bias_constraint,  # unitnorm
        "kernel_regularizer": components_kernel_regularizer,
        "bias_regularizer": components_bias_regularizer
    }

    normal_diag_config = {
        "hidden_size": components_hidden_size,
        "activation": components_activation,
        "mu_regularizer": components_regularizer,
        "diag_regularizer": "l2",
        "kernel_constraint": components_kernel_constraint,
        "bias_constraint": components_bias_constraint,
        "kernel_regularizer": components_kernel_regularizer,
        "bias_regularizer": components_bias_regularizer
    }

    normal_tril_config = {
        "hidden_size": components_hidden_size,
        "activation": components_activation,
        "mu_regularizer": components_regularizer,
        "diag_regularizer": components_regularizer,
        "sub_diag_regularizer": components_regularizer,
        "kernel_constraint": components_kernel_constraint,
        "bias_constraint": components_bias_constraint,
        "kernel_regularizer": components_kernel_regularizer,
        "bias_regularizer": components_bias_regularizer
    }

    log_normal_tril_config = {
        "hidden_size": components_hidden_size,
        "activation": components_activation,
        "mu_regularizer": components_regularizer,
        "diag_regularizer": components_regularizer,
        "sub_diag_regularizer": components_regularizer,
        "kernel_constraint": components_kernel_constraint,
        "bias_constraint": components_bias_constraint,
        "kernel_regularizer": components_kernel_regularizer,
        "bias_regularizer": components_bias_regularizer
    }

    # Mixture = one categorical component followed by the requested number of
    # each distribution component (with the counts above: 6 normal_diag only).
    mixture_config = (
        [["categorical", categorical_config]]
        + [["normal_diag", normal_diag_config] for _ in range(n_normal_diag)]
        + [["normal_tril", normal_tril_config] for _ in range(n_normal_tril)]
        + [["log_normal_tril", log_normal_tril_config] for _ in range(n_log_normal_tril)]
    )

    # Write the configurations to disk
    with open(body_config_path, 'w+') as f:
        json.dump(body_config, f, indent='\t')

    with open(mixture_config_path, 'w+') as f:
        json.dump(mixture_config, f, indent='\t')

    # Training parameters
    n_epochs_main = 100
    n_epochs_heat_up = 20
    n_epochs_interval = 5
    n_epochs_arch = 5
    n_epochs_finetune = 40

    batch_size = 256
    dataset_kind = 'hdf5'
    add_noise = False
    stddev = 0.01

    # Training strategy (one learning-rate schedule per training phase)
    learning_strategy_main = ToleranceDropLearningStrategy(
        optimizer_type='adam',
        initial_lr=1e-4,
        lr_decay=0.3,
        epochs_tolerance=7,
        minimal_lr=1e-7,
    )

    learning_strategy_arch = ToleranceDropLearningStrategy(
        optimizer_type='adam',
        initial_lr=1e-3,
        lr_decay=0.5,
        epochs_tolerance=20,
        minimal_lr=1e-7,
    )

    learning_strategy_finetune = ToleranceDropLearningStrategy(
        optimizer_type='adam',
        initial_lr=1e-4,
        lr_decay=0.3,
        epochs_tolerance=5,
        minimal_lr=1e-7,
    )

    nn = NASStochNet(
        nb_past_timesteps=1,
        nb_features=nb_features,
        nb_randomized_params=len(params_to_randomize),
        project_folder=project_folder,
        timestep=timestep,
        dataset_id=dataset_id,
        model_id=model_id,
    )

    # Train without an initial checkpoint and time the whole run
    start = time()
    ckpt_path = None
    ckpt_path = Trainer().train(
        nn,
        n_epochs_main=n_epochs_main,
        n_epochs_heat_up=n_epochs_heat_up,
        n_epochs_arch=n_epochs_arch,
        n_epochs_interval=n_epochs_interval,
        n_epochs_finetune=n_epochs_finetune,
        batch_size=batch_size,
        learning_strategy_main=learning_strategy_main,
        learning_strategy_arch=learning_strategy_arch,
        learning_strategy_finetune=learning_strategy_finetune,
        ckpt_path=ckpt_path,
        dataset_kind=dataset_kind,
        add_noise=add_noise,
        stddev=stddev,
        mode=['search', 'finetune']
    )
    end = time()
    time_taken = end - start
    print(f"Training time: {time_taken:.1f} s")

    # Load the genotypes from the model folder (presumably written during
    # training — confirm). pickle must only be used on trusted, locally
    # produced files such as this one.
    with open(os.path.join(nn.model_explorer.model_folder, 'genotypes.pickle'), 'rb') as f:
        genotypes = pickle.load(f)

    # os.path.join keeps the output inside the <name>/ folder on every OS;
    # a hard-coded backslash is only a path separator on Windows.
    visualize_genotypes(genotypes, os.path.join(name, f'genotypes_{name}'))

    distance_kind = 'dist'
    target_species_names = [f'S{species_number}' for species_number in range(1, n_species + 1)]
    # target_species_names = [ f'S{n_species}']
    time_lag_range = [1, 3, 5, 10, 15, 20]
    settings_idxs_to_save_histograms = list(range(nb_settings))

    histogram_explorer = dataset_explorer.get_histogram_file_explorer(model_id, 0)
    nn_histogram_data_fp = os.path.join(histogram_explorer.model_histogram_folder, 'nn_histogram_data.npy')

    evaluate(
        model_name=model_name,
        project_folder=project_folder,
        timestep=timestep,
        dataset_id=dataset_id,
        model_id=model_id,
        nb_randomized_params=len(params_to_randomize),
        nb_past_timesteps=1,
        n_bins=100,
        distance_kind=distance_kind,
        with_timestamps=True,
        save_histograms=True,
        time_lag_range=time_lag_range,
        target_species_names=target_species_names,
        path_to_save_nn_traces=nn_histogram_data_fp,
        settings_idxs_to_save_histograms=settings_idxs_to_save_histograms,
    )

    # Initialise the network and corresponding parameters
    nn = StochNet(
        nb_past_timesteps=1,
        nb_features=nb_features,
        nb_randomized_params=len(params_to_randomize),
        project_folder=project_folder,
        timestep=timestep,
        dataset_id=dataset_id,
        model_id=model_id,
        mode='inference'
    )

    # Reuse the simulation sizes configured above so the comparison cannot
    # drift away from the training configuration.
    n_settings = nb_settings
    traj_per_setting = nb_trajectories
    n_steps = int(round(endtime / timestep))  # 50 steps for 1.0 s at 0.02 s

    m = CRN_class(endtime, timestep)

    initial_settings = m.get_initial_settings(n_settings)
    randomized_params = m.get_randomized_parameters(params_to_randomize, n_settings)
    settings = merge_species_and_param_settings(initial_settings, randomized_params)

    # Get the current state to be fed into the network:
    # one randomly chosen setting, with an extra axis inserted at position 1.
    setting_idx = np.random.randint(0, n_settings)
    curr_state = settings[setting_idx:setting_idx+1, np.newaxis, :]
    print(f"Current state shape: {curr_state.shape}")
    print(f"Current state: {curr_state}")

    # Predict the next state
    n_samples = 10000
    next_state_samples = nn.next_state(
        curr_state_values=curr_state,
        curr_state_rescaled=False,
        scale_back_result=True,
        round_result=False,
        n_samples=n_samples,
    )

    random_sample_idx = np.random.randint(0, n_samples)

    print(f"Shape: {next_state_samples.shape}")
    print(f"Random sample index: {random_sample_idx}")
    print(f"Random concentration prediction:\n{next_state_samples[random_sample_idx, :, :, :]}")

    # Visualise the distribution of concentration predictions for each histogram species across all samples.
    # NOTE(review): the position in `species` is used as the feature index, which
    # assumes the histogram species are the leading features — confirm.
    species = CRN_class.get_species_for_histogram()
    for species_idx, _ in enumerate(species):
        samples = np.squeeze(next_state_samples, -2)[..., species_idx]
        _ = plt.hist(samples, bins=50)
    plt.legend(species)
    plt.savefig(os.path.join(name, f"{name}-concetrations.png"))
    plt.close()

    # Run and time a gillespy2 simulation
    start = time()

    gillespy_traces = generate_gillespy_traces(
        settings=settings,
        n_steps=n_steps,
        timestep=timestep,
        gillespy_model=m,
        params_to_randomize=params_to_randomize,
        traj_per_setting=traj_per_setting,
    )

    gillespy_time = time() - start

    # Run and time a neural network simulation
    start = time()

    nn_traces = nn.generate_traces(
        settings[:, np.newaxis, :],
        n_steps=n_steps,
        n_traces=traj_per_setting,
        curr_state_rescaled=False,
        scale_back_result=True,
        round_result=True,
        add_timestamps=True,
    )

    nn_time = time() - start

    print(f"Gillespy2 shape and time: {gillespy_traces.shape, gillespy_time}")
    print(f"StochNetV2 shape and time: {nn_traces.shape, nn_time}")

    # Plot one random trace from setting index k for each simulator:
    # dashed lines for Gillespy2, solid lines for the network.
    k = 1
    n_traces = 1

    plt.figure(figsize=(16, 10))
    # Only the first nb_features+1 columns of the Gillespy2 traces are plotted
    # (presumably time + species, dropping parameter columns — confirm).
    plot_random_traces(gillespy_traces[k][...,:nb_features+1], n_traces, linestyle='--', marker='')
    plot_random_traces(nn_traces[k], n_traces, linestyle='-', marker='')
    plt.savefig(os.path.join(name, f"{name}-ssa-vs-nn-traces.png"))
    plt.close()

    # Flatten the NN traces to (n_settings * traj_per_setting, n_steps + 1, n_species + 1)
    new_traces = np.reshape(nn_traces, (traj_per_setting*n_settings, n_steps+1, n_species+1))
    print(f"new traces shape: {new_traces.shape}")

    print(f">>> FINISHED WITH {name}")
    print(f"Generated traces shape: {new_traces.shape}")