# Estimating parameter sensitivities of moment estimations with trained StochNetV2 neural networks

## Helper functions

In [None]:
import random
random.seed(42)

def generate_colours(num_colours):
    """Return a list of ``num_colours`` random colour strings of the form ``#XXXXXX``.

    Each of the six hexadecimal digits is drawn with ``random.choice`` from the
    module-level ``random`` generator, so the result depends on that
    generator's current state (and is reproducible after ``random.seed``).
    """
    hex_digits = '0123456789ABCDEF'
    palette = []
    for _ in range(num_colours):
        # Six independent draws per colour, in order, one per hex digit.
        digits = [random.choice(hex_digits) for _ in range(6)]
        palette.append("#" + ''.join(digits))
    return palette

# Shared palette of 50 random colours, generated once when this cell runs
# (after random.seed(42) above); plot_moment slices it to colour its bars.
COLOURS = generate_colours(50)

def get_moments(dataset, n_species):
    """Return the sample mean and mean square of one column at the final step.

    For every trajectory in ``dataset`` the value at ``trajectory[-1, n_species]``
    (last row, column ``n_species``) is collected; the function returns the
    average of those values and the average of their squares.

    Args:
        dataset: iterable of trajectories, each indexable as ``[row, column]``.
        n_species: column index to read from the last row of each trajectory.

    Returns:
        Tuple ``(avg_1, avg_2)``: first and second raw sample moments.

    Raises:
        ZeroDivisionError: if ``dataset`` yields no trajectories.
    """
    final_values = [trajectory[-1, n_species] for trajectory in dataset]
    n_trajectories = len(final_values)

    # Sequential sums starting from 0, first of the values, then of their squares.
    first_sum = sum(final_values)
    second_sum = sum(np.power(value, 2) for value in final_values)

    return first_sum / n_trajectories, second_sum / n_trajectories

def plot_moment(data, name, moment, ssa=True):
    """Save a bar chart of ``data`` (label -> bar height) as a PNG file.

    Args:
        data: mapping whose keys become x-axis labels and whose values become
            bar heights.
        name: used in the plot title and in the output file name.
        moment: moment label, used in the plot title and in the file name.
        ssa: when truthy the plot is titled/saved with the ``ssa`` suffix,
            otherwise with the ``nn`` suffix.

    The figure is written to ``{name}-{moment}-ssa.png`` or
    ``{name}-{moment}-nn.png`` in the current directory and then closed.
    """
    labels = list(data.keys())
    heights = list(data.values())
    positions = np.arange(len(labels))

    # One colour per bar, taken from the front of the shared palette.
    plt.bar(positions, heights, color=COLOURS[:len(labels)])
    plt.xticks(positions, labels)

    # Title and file name differ only in the source tag (ssa vs nn).
    if ssa:
        title, filename = f"{name}, {moment} moment, ssa", f"{name}-{moment}-ssa.png"
    else:
        title, filename = f"{name}, {moment} moment, nn", f"{name}-{moment}-nn.png"

    plt.title(title)
    plt.savefig(filename)
    plt.close()

def plot_moment_comparison(ssa_data, nn_data, name, moment):
    """Save a two-bar chart comparing the mean of ``nn_data`` and ``ssa_data``.

    Args:
        ssa_data: mapping whose values are averaged for the ``ssa`` bar.
        nn_data: mapping whose values are averaged for the ``nn`` bar.
        name: used in the output file name.
        moment: moment label, used in the title and the file name.

    The figure is written to ``comparison-{name}-{moment}.png`` and closed.

    Raises:
        ZeroDivisionError: if either mapping is empty.
    """
    ssa_mean = sum(ssa_data.values()) / len(ssa_data)
    nn_mean = sum(nn_data.values()) / len(nn_data)

    labels = ['nn', 'ssa']
    heights = [nn_mean, ssa_mean]
    positions = np.arange(len(labels))

    plt.bar(positions, heights, color=['#4266f5', '#f59042'])
    plt.xticks(positions, labels)

    # Annotate each bar with its value, rounded to two decimals.
    for bar_index, bar_value in enumerate(heights):
        plt.text(bar_index, bar_value, f"{bar_value:.2f}")

    plt.title(f"{moment} moment")
    plt.savefig(f"comparison-{name}-{moment}.png")
    plt.close()
    
def plot_output_fluctuations(param_output_list, param_name, model_name):
    """Plot g_1 and g_2 output deviations against the parameter deviation.

    Args:
        param_output_list: sequence of tuples whose element 0 is the parameter
            deviation, element 1 the g_1 deviation and element 2 the g_2
            deviation.
        param_name: parameter label, used on the x-axis and in the file name.
        model_name: model label, used in the subplot titles and the file name.

    Two stacked line plots are drawn and the figure is saved as
    ``{model_name}_{param_name}_plots.png`` before being closed.
    """
    param_devs = [entry[0] for entry in param_output_list]
    g_1_devs = [entry[1] for entry in param_output_list]
    g_2_devs = [entry[2] for entry in param_output_list]

    panels = [
        (g_1_devs, f'g_1 fluctuations | {model_name}'),
        (g_2_devs, f'g_2 fluctuations | {model_name}'),
    ]

    # One row per moment: subplot indices are 1-based.
    for row, (output_devs, panel_title) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plt.plot(param_devs, output_devs)
        plt.title(panel_title)
        plt.ylabel('function output deviation')
        plt.xlabel(f'{param_name} deviation')

    plt.tight_layout()

    plt.savefig(f'{model_name}_{param_name}_plots.png')
    plt.close()
    
def compute_param_sensitivity(outputs, param_name, moment):
    """Estimate a parameter sensitivity as the mean finite-difference slope.

    Each entry of ``outputs[param_name]`` is a tuple
    ``(param_diff, g_1_diff, g_2_diff)``. The slope of an entry is the chosen
    moment difference divided by ``param_diff``; the function returns the
    average of all slopes that lie strictly inside ``(-100000, 100000)``.

    Args:
        outputs: mapping from parameter name to a list of
            ``(param_diff, g_1_diff, g_2_diff)`` tuples.
        param_name: key of ``outputs`` to evaluate.
        moment: ``'g_1'`` selects element 1 of each tuple; any other value
            selects element 2.

    Returns:
        The mean of the accepted slopes, or ``float('nan')`` when no entry
        yields a usable slope (empty list, only zero parameter differences,
        or only out-of-range slopes).

    Raises:
        KeyError: if ``param_name`` is not a key of ``outputs``.
    """
    # Position of the requested moment difference inside each tuple.
    moment_index = 1 if moment == 'g_1' else 2
    # Slopes at or beyond this magnitude are treated as numerical outliers
    # (e.g. a near-zero parameter difference) and excluded from the mean.
    slope_limit = 100000

    tot = 0
    counter = 0
    for entry in outputs[param_name]:
        param_diff = entry[0]
        # A zero parameter difference has no defined slope; skip it instead of
        # dividing by zero (plain floats would raise, NumPy would give inf/nan).
        if param_diff == 0:
            continue
        s = entry[moment_index] / param_diff
        if -slope_limit < s < slope_limit:
            tot += s
            counter += 1

    # No usable slope: report "undefined" rather than raising ZeroDivisionError.
    if counter == 0:
        return float('nan')
    return tot / counter

def plot_param_sensitivities(model_name, param_names, param_sensitivities, moment):
    """Save a bar chart of sensitivity estimates, one bar per parameter.

    Args:
        model_name: model label, used in the title and the file name.
        param_names: sized iterable of parameter names (x-axis categories).
        param_sensitivities: sequence indexable by position; the first
            ``len(param_names)`` entries are plotted.
        moment: ``'g_1'`` titles the plot with E(X_n); any other value titles
            it with E(X^2_n). Also used in the file name.

    The figure is written to ``{model_name}_{moment}_param_sensitivities.png``
    and then closed.
    """
    heights = [param_sensitivities[position] for position, _ in enumerate(param_names)]

    plt.bar(param_names, heights)
    plt.ylabel('Parameter Sensitivity Estimate')

    # LaTeX label of the plotted moment: first or second raw moment.
    moment_label = r"$\mathbb{E}(X_n)$" if moment == 'g_1' else r"$\mathbb{E}(X^2_n)$"
    plt.title(moment_label + f" | {model_name}")

    plt.savefig(f'{model_name}_{moment}_param_sensitivities.png')
    plt.close()

## Imports

In [None]:
# Jupyter magic to keep track of file changes in real-time
%load_ext autoreload
%autoreload 2

# Core imports
import os, sys, json, h5py, pickle, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from importlib import import_module
from time import time
from pathlib import Path
from pprint import pprint

# StochNetV2 library imports
from stochnet_v2.dataset.dataset import DataTransformer, HDF5Dataset

from stochnet_v2.static_classes.model import StochNet
from stochnet_v2.static_classes.trainer import ToleranceDropLearningStrategy

from stochnet_v2.dynamic_classes.model import NASStochNet
from stochnet_v2.dynamic_classes.trainer import Trainer

from stochnet_v2.utils.file_organisation import ProjectFileExplorer
from stochnet_v2.utils.util import merge_species_and_param_settings, plot_random_traces, visualize_genotypes
# from stochnet_v2.utils.evaluation import evaluate

# StochNetV2 local imports
from simulation_gillespy import build_simulation_dataset
from util import generate_gillespy_traces
from evaluation import evaluate

# Seed NumPy's global random generator so NumPy-based randomness is repeatable.
np.random.seed(42)

# Prints a tuple: (TensorFlow built with CUDA support, a GPU device is available).
# NOTE(review): tf.test.is_gpu_available is deprecated in TensorFlow 2.x in favour
# of tf.config.list_physical_devices('GPU') — confirm the targeted TF version.
print(f"GPU is available: {tf.test.is_built_with_cuda(), tf.test.is_gpu_available()}")

## Estimate parameter sensitivities

In [None]:
# Models to analyse. Each entry is a tuple
#   (model_name, nb_features, params_to_randomize)
# as unpacked by the sensitivity loop below:
#   - model_name: used as project sub-folder, dataset/model id, and as the name
#     of both the module and the class imported for the CRN;
#   - nb_features: passed to StochNet as `nb_features` and used as the number of
#     leading (non-parameter) entries of each starting state;
#   - params_to_randomize: parameter name -> one-element array holding the
#     default value around which the parameter is perturbed.
networks = [
    ('BD20', 20, {'k': np.array([10.]), 'gamma': np.array([1.])}),
    ('NSC5', 5, {'b':np.array([1.]), 'k_m': np.array([100.]), 'k_0': np.array([10.]), 'H': np.array([1.]), 'beta_0': np.array([10.]), 'gamma': np.array([1.])}),
    ('LSC2', 2, {'beta_0': np.array([10.]), 'k': np.array([5.]), 'gamma': np.array([1.])}),
    ('LSC5', 5, {'beta_0': np.array([10.]), 'k': np.array([5.]), 'gamma': np.array([1.])}),
    ('LSC10', 10, {'beta_0': np.array([10.]), 'k': np.array([5.]), 'gamma': np.array([1.])}),
    ('LSCF5', 5, {'b': np.array([1.]), 'k_m': np.array([100.]), 'k_0': np.array([10.]), 'H': np.array([1.]), 'k': np.array([5.]), 'gamma': np.array([1.])}),
    ('LSCF2', 2, {'b': np.array([1.]), 'k_m': np.array([100.]), 'k_0': np.array([10.]), 'H': np.array([1.]), 'k': np.array([5.]), 'gamma': np.array([1.])})
]

In [None]:
# Main experiment: for every model in `networks`
#   1. load the trained StochNet in inference mode,
#   2. generate traces with the default parameter values and record the
#      baseline moments (g_1 = mean, g_2 = mean square, via get_moments),
#   3. perturb one parameter at a time within +/- sigma of its default value,
#      regenerate traces and record how the moments move,
#   4. plot the fluctuations and the averaged slope (sensitivity) per parameter.
# `start`/`end` bracket the whole loop for the timing cell that follows.
start = time()
for model_name, nb_features, params_to_randomize in networks:
    print(f"Producing traces with {model_name}.")
    timestep = 0.02
    endtime =  1.0
    dataset_id = model_name
    model_id = model_name

    project_folder = Path(r'trained_models/randomised')/model_name
    project_explorer = ProjectFileExplorer(project_folder)
    # NOTE(review): dataset_explorer and model_explorer are not used again in
    # this loop — confirm whether constructing them is needed for side effects.
    dataset_explorer = project_explorer.get_dataset_file_explorer(timestep, dataset_id)
    model_explorer = project_explorer.get_model_file_explorer(timestep, model_id)

    # The CRN definition lives in a module named after the model and exposes a
    # class with the same name.
    CRN_module = import_module(model_name)
    CRN_class = getattr(CRN_module, model_name)

    nn = StochNet(
            nb_past_timesteps=1,
            nb_features=nb_features,
            nb_randomized_params=len(params_to_randomize),
            project_folder=project_folder,
            timestep=timestep,
            dataset_id=dataset_id,
            model_id=model_id,
            mode='inference'
        )
    
    # get baseline outputs with default parameter values
    n_settings = 1
    traj_per_setting = 500
    n_steps = 50

    m = CRN_class(endtime, timestep)
    
    initial_settings = m.get_initial_settings(n_settings)
    settings = merge_species_and_param_settings(initial_settings, params_to_randomize)
    # Insert a length-1 middle axis: (n_settings, 1, n_values).
    curr_state = settings[:, np.newaxis, :]
    print(curr_state)
    
    
    nn_traces = nn.generate_traces(
        curr_state,
        n_steps=n_steps,
        n_traces=traj_per_setting,
        curr_state_rescaled=False,
        scale_back_result=True,
        round_result=True,
        add_timestamps=True,
    )
    
    # One row per trajectory; presumably the extra (+1) column is the timestamp
    # added by add_timestamps=True — TODO confirm column order against
    # StochNet.generate_traces.
    traces = np.reshape(nn_traces, (traj_per_setting, n_steps+1, nb_features+1))
    # get_moments reads column index `nb_features` (the last column) at the
    # final time step of every trajectory.
    g_1_baseline, g_2_baseline = get_moments(traces, nb_features)
    print(f"Function outputs with default parameter values: {g_1_baseline}, {g_2_baseline}.")
    
    baseline_outputs = (g_1_baseline, g_2_baseline)
    # baseline_outputs = (21.24, 536.32)

    # controlled parameter generation:
    # 'sigma' is how far to shift the default parameter value percentage-wise;
    # 'randomised' will contain arrays of initial states where only one parameter
    # value changes, for each parameter.
    randomised = {}
    sigma = 0.2

    for param in params_to_randomize.keys():
        val = params_to_randomize[param]
        # Offsets run from -sigma*val up to (but excluding) +sigma*val in steps
        # of 1% of val, because np.arange leaves out the stop value.
        randomised[param] = [val+i for i in np.arange(-val*sigma, val*sigma, val/100)]
    print("Created value ranges for each parameter.")
    
    # list of tuples (param name, diff from the default value, starting state)
    starting_states = []
    n_params = len(params_to_randomize.keys())

    for param in params_to_randomize.keys():
        n_values = len(randomised[param])
        for i in range(n_values):
            diff = 0
            current_params = {}
            # Build a full parameter set in which only `param` is shifted.
            for p in params_to_randomize.keys():
                if p == param:
                    current_params[p] = randomised[p][i]
                    # Sign convention: default value minus shifted value.
                    diff = params_to_randomize[p] - randomised[p][i]
                else:
                    current_params[p] = params_to_randomize[p]
            # Skip the grid point that coincides exactly with the default value.
            # Floating-point steps may leave a tiny non-zero diff instead; the
            # resulting huge slopes are discarded by compute_param_sensitivity.
            if diff != 0:
                # State layout: nb_features zeros followed by the parameter
                # values in dict order.
                # NOTE(review): the baseline above starts from
                # m.get_initial_settings(...), whereas these states start from
                # all-zero leading entries — confirm the two are equivalent.
                state = np.array([0.0]*nb_features + [current_params[k] for k in current_params.keys()])
                state = np.reshape(state, (1, nb_features+n_params))
                state = state[:, np.newaxis, :]
                starting_states.append((param, diff, state))

    print(f"Generated {len(starting_states)} starting states for {n_params} parameters.")
    print(f"Example state:\n{starting_states[0]}")
    
    # perform trace generation
    outputs = {}

    for param in params_to_randomize.keys():
        # list of tuples [(param_diff, g_1_diff, g_2_diff), ...]
        outputs[param] = []
        
    for param, param_diff, starting_state in starting_states:
        nn_traces = nn.generate_traces(
            starting_state,
            n_steps=n_steps,
            n_traces=traj_per_setting,
            curr_state_rescaled=False,
            scale_back_result=True,
            round_result=True,
            add_timestamps=True,
        )

        traces = np.reshape(nn_traces, (traj_per_setting, n_steps+1, nb_features+1))
        g_1, g_2 = get_moments(traces, nb_features)

        # Same sign convention as param_diff (baseline minus perturbed), so the
        # ratio output_diff / param_diff is a finite-difference slope.
        g_1_diff = baseline_outputs[0] - g_1
        g_2_diff = baseline_outputs[1] - g_2

        outputs[param].append((param_diff, g_1_diff, g_2_diff))

    print('Done.')
    
    # One two-panel fluctuation figure per parameter.
    for param_name in outputs.keys():
        plot_output_fluctuations(outputs[param_name], param_name, model_name)
        
    # One sensitivity bar chart per moment, with one bar per parameter.
    for moment in ['g_1', 'g_2']:
        param_sensitivities = []
        for param in outputs.keys():
            param_sensitivities.append(float(compute_param_sensitivity(outputs, param, moment)))
        
        plot_param_sensitivities(model_name, outputs.keys(), param_sensitivities, moment)
end = time()

In [None]:
# Elapsed wall-clock time of the sensitivity loop, in seconds
# (difference between the two time() readings taken around it).
print(f"time taken: {end - start}")