# Estimating species concentration moments with trained StochNetV2 neural networks

## Helper functions

In [None]:
import random

# Seed the stdlib RNG so the colour palette built below is the same on every run.
random.seed(42)


def generate_colours(num_colours):
    """Return ``num_colours`` random colours as ``'#RRGGBB'`` hex strings.

    Each colour is six hexadecimal digits drawn with ``random.choice`` from the
    global ``random`` module, so the result depends on that module's current
    state. Duplicate colours are possible.
    """
    digits = '0123456789ABCDEF'
    palette = []
    for _ in range(num_colours):
        palette.append("#" + ''.join(random.choice(digits) for _ in range(6)))
    return palette


# Shared bar-chart palette (50 colours) used by the plotting helpers.
COLOURS = generate_colours(50)

def get_moments(dataset, n_species, *, time_index=-1):
    """Estimate the first two raw moments of one column at one time point.

    For every trajectory in ``dataset`` the value
    ``trajectory[time_index, n_species]`` is read. The function returns the
    sample mean of that value and the sample mean of its square, i.e. estimates
    of E[X] and E[X**2]. The second moment is *not* centred.

    Parameters
    ----------
    dataset : iterable
        Trajectories, each indexable as ``trajectory[time, column]`` (e.g. 2-D
        NumPy arrays). Any iterable works, including generators.
    n_species : int
        Column index to read. Despite its name it is used directly as an index,
        not as a count of species.
    time_index : int, optional (keyword-only)
        Row (time step) to read. Defaults to ``-1``, the last time point.

    Returns
    -------
    tuple
        ``(avg_1, avg_2)``: the first and second raw sample moments.

    Raises
    ------
    ValueError
        If ``dataset`` yields no trajectories.
    """
    total_1 = 0
    total_2 = 0
    counter = 0
    # Plain sequential accumulation keeps results bit-identical to the original
    # loop and works for arbitrary iterables (not only stacked arrays).
    for trajectory in dataset:
        concentration = trajectory[time_index, n_species]
        total_1 += concentration
        total_2 += np.power(concentration, 2)
        counter += 1

    if counter == 0:
        raise ValueError("get_moments() needs at least one trajectory, got an empty dataset")

    return total_1 / counter, total_2 / counter

def plot_moment(data, name, moment, ssa=True):
    """Save a bar chart of the values in ``data`` to a PNG file.

    ``data`` maps bar labels (x-axis tick names) to bar heights. Bars are
    coloured from the module-level ``COLOURS`` palette. The title and file name
    are tagged "ssa" when ``ssa`` is true and "nn" otherwise. The current
    figure is closed after saving.
    """
    labels, values = [], []
    for label, value in data.items():
        labels.append(label)
        values.append(value)
    positions = np.arange(len(labels))

    # One colour per bar, taken from the start of the shared palette.
    plt.bar(positions, values, color=COLOURS[:len(labels)])
    plt.xticks(positions, labels)

    if not ssa:
        plt.title(f"{name}, {moment} moment, nn")
        plt.savefig(f"{name}-{moment}-nn.png")
    else:
        plt.title(f"{name}, {moment} moment, ssa")
        plt.savefig(f"{name}-{moment}-ssa.png")

    plt.close()

def plot_moment_comparison(ssa_data, nn_data, name, moment):
    """Save a two-bar chart comparing the mean NN and mean SSA moment values.

    Parameters
    ----------
    ssa_data, nn_data : dict
        Mappings whose values are averaged (arithmetic mean of ``.values()``)
        into a single bar each.
    name : str
        Used in the output file name ``comparison-{name}-{moment}.png``.
    moment : str
        Used in the plot title and the output file name.

    Raises
    ------
    ValueError
        If either mapping is empty (there is nothing to average).
    """
    if not ssa_data or not nn_data:
        raise ValueError("plot_moment_comparison() needs non-empty ssa_data and nn_data")

    ssa_avg = sum(ssa_data.values()) / len(ssa_data)
    nn_avg = sum(nn_data.values()) / len(nn_data)

    height = [nn_avg, ssa_avg]
    bars = ['nn', 'ssa']
    x_pos = np.arange(len(bars))

    # Create bars with different colors
    plt.bar(x_pos, height, color=['#4266f5', '#f59042'])

    # Create names on the x-axis
    plt.xticks(x_pos, bars)

    # Centre each value label over its bar and place it just above the bar top.
    # matplotlib's default (ha='left') would start the text at the bar centre
    # and run it off to the right.
    for index, value in enumerate(height):
        plt.text(index, value, f"{value:.2f}", ha="center", va="bottom")

    plt.title(f"{moment} moment")

    # Save plot
    plt.savefig(f"comparison-{name}-{moment}.png")
    plt.close()

## Imports

In [None]:
# Jupyter magic to keep track of file changes in real-time
%load_ext autoreload
%autoreload 2

# Core imports
import os, sys, json, h5py, pickle, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from importlib import import_module
from time import time
from pathlib import Path
from pprint import pprint

# StochNetV2 library imports (unchanged files from the framework)
from stochnet_v2.dataset.dataset import DataTransformer, HDF5Dataset

from stochnet_v2.static_classes.model import StochNet
from stochnet_v2.static_classes.trainer import ToleranceDropLearningStrategy

from stochnet_v2.dynamic_classes.model import NASStochNet
from stochnet_v2.dynamic_classes.trainer import Trainer

from stochnet_v2.utils.file_organisation import ProjectFileExplorer
from stochnet_v2.utils.util import merge_species_and_param_settings, plot_random_traces, visualize_genotypes

# StochNetV2 local imports (files that have been modified)
from simulation_gillespy import build_simulation_dataset
from util import generate_gillespy_traces
from evaluation import evaluate

np.random.seed(42)

# Report (built with CUDA?, GPU available?) as a tuple, as returned by tf.test.
_gpu_flags = (tf.test.is_built_with_cuda(), tf.test.is_gpu_available())
print(f"GPU is available: {_gpu_flags}")

## Compute moment estimates

In [None]:
def _param_arrays(**values):
    """Return ``{name: np.array([value])}`` with a fresh float64 array per parameter."""
    return {key: np.array([float(val)]) for key, val in values.items()}


# Each entry: (model name, nb_features, parameter settings).
# - The model name is also the name of the module and class that define the CRN.
# - nb_features is passed to StochNet; presumably it is the number of species.
# - The parameter dict is merged into the initial settings via
#   merge_species_and_param_settings.
networks = [
    ('BD20', 20, _param_arrays(k=10., gamma=1.)),
    ('NSC5', 5, _param_arrays(b=1., k_m=100., k_0=10., H=1., beta_0=10., gamma=1.)),
    ('LSC2', 2, _param_arrays(beta_0=10., k=5., gamma=1.)),
    ('LSC5', 5, _param_arrays(beta_0=10., k=5., gamma=1.)),
    ('LSC10', 10, _param_arrays(beta_0=10., k=5., gamma=1.)),
    ('LSCF5', 5, _param_arrays(b=1., k_m=100., k_0=10., H=1., k=5., gamma=1.)),
    ('LSCF2', 2, _param_arrays(b=1., k_m=100., k_0=10., H=1., k=5., gamma=1.)),
]

In [None]:
# Sampling configuration shared by every network. It does not change between
# iterations, so it is set once here rather than inside the loop.
timestep = 0.02
endtime = 1.0
# Number of NN steps needed to reach `endtime` (1.0 / 0.02 -> 50). It is derived
# rather than hard-coded so it stays consistent if either value changes.
n_steps = int(round(endtime / timestep))
n_settings = 1
traj_per_setting = 500

for model_name, nb_features, params_to_randomize in networks:
    print(f"Producing traces with {model_name}.")
    dataset_id = model_name
    model_id = model_name

    # 'randomised' here signifies that the neural networks were trained with randomised
    # parameters
    project_folder = Path(r'trained_models/randomised') / model_name
    project_explorer = ProjectFileExplorer(project_folder)
    dataset_explorer = project_explorer.get_dataset_file_explorer(timestep, dataset_id)
    model_explorer = project_explorer.get_model_file_explorer(timestep, model_id)

    # The CRN class is looked up in a module that has the same name as the model.
    CRN_module = import_module(model_name)
    CRN_class = getattr(CRN_module, model_name)

    nn = StochNet(
        nb_past_timesteps=1,
        nb_features=nb_features,
        nb_randomized_params=len(params_to_randomize),
        project_folder=project_folder,
        timestep=timestep,
        dataset_id=dataset_id,
        model_id=model_id,
        mode='inference'
    )

    m = CRN_class(endtime, timestep)

    initial_settings = m.get_initial_settings(n_settings)
    settings = merge_species_and_param_settings(initial_settings, params_to_randomize)
    # Add a singleton time axis: (n_settings, 1, n_columns) as the starting state.
    curr_state = settings[:, np.newaxis, :]
    print(curr_state)

    nn_traces = nn.generate_traces(
        curr_state,
        n_steps=n_steps,
        n_traces=traj_per_setting,
        curr_state_rescaled=False,
        scale_back_result=True,
        round_result=True,
        add_timestamps=True,
    )

    # Stack the trajectories of all settings into (n_trajectories, time, columns).
    # The original used traj_per_setting alone as the leading dimension, which
    # only works for n_settings == 1. Note that trajectories from every setting
    # are pooled into one moment estimate.
    # NOTE(review): assumes nn_traces is
    # (n_settings, traj_per_setting, n_steps + 1, nb_features + 1) — confirm.
    traces = np.reshape(
        nn_traces, (n_settings * traj_per_setting, n_steps + 1, nb_features + 1)
    )
    # NOTE(review): column `nb_features` is the last column. With
    # add_timestamps=True column 0 is presumably the time stamp, so this reads the
    # last species — confirm that this is the intended output species.
    g_1, g_2 = get_moments(traces, nb_features)

    print(f"Function outputs with default parameter values for {model_name}: {g_1}, {g_2}.")