# 6: Comparative analysis component

This component compares the datasets generated by the two deep abstract models (GAN and MDN) against the test datasets generated by the stochastic simulation algorithm (SSA).

## Step 0: Setup

In [None]:
# Reload all imported modules before each cell executes, so edits to
# simulation_manager / analysis_manager take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from simulation_manager import SimulationManager
from analysis_manager import AnalysisManager

import pickle
import os
from time import time

In [None]:
# Notebook working directory; the CRN model files below are joined onto it.
current_dir: str = os.getcwd()

In [None]:
def _load_pickle(file_path):
    """Return the object stored in the pickle file at ``file_path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only call this on
    dataset files produced by this project's own pipeline.
    """
    with open(file_path, "rb") as f:
        return pickle.load(f)


def _compare_against_ssa(
    am,
    ssa_data,
    ssa_data_split,
    ssa_moments,
    model_data,
    model_label,
    species_names,
    init_condition_indices,
    n_steps,
    n_init_conditions,
):
    """Produce every SSA-vs-model plot and metric for one abstract model.

    Args:
        am: ``AnalysisManager`` used for splitting, plotting and metrics.
        ssa_data: SSA test dataset used as the reference.
        ssa_data_split: ``ssa_data`` split per initial condition.
        ssa_moments: ``(first, second)`` as returned by
            ``am.compute_moments(ssa_data)``; passed in so the SSA moments are
            computed once and shared by all compared models.
        model_data: dataset generated by the abstract model under evaluation.
        model_label: label for the model in the plots (e.g. ``"GAN"``).
        species_names: species names obtained from the ``SimulationManager``.
        init_condition_indices: initial-condition indices whose trajectories
            are plotted.
        n_steps: number of time steps configured for the datasets.
        n_init_conditions: number of initial conditions in the datasets.
    """
    model_data_split = am.split_dataset_by_initial_conditions(model_data, n_init_conditions)

    am.plot_trajectory_comparison_for_initial_conditions(
        ssa_data_split,
        model_data_split,
        "SSA",
        model_label,
        species_names,
        init_condition_indices,
    )

    # Distribution snapshots at time indices 1, n_steps // 2 and n_steps
    # (presumably first, middle and final step -- confirm against data layout).
    am.plot_distribution_comparison(
        ssa_data,
        model_data,
        "SSA",
        model_label,
        species_names,
        [1, n_steps // 2, n_steps],
    )

    am.plot_state_transitions(
        ssa_data,
        model_data,
        "SSA",
        model_label,
        species_names,
        n_steps - 1,
    )

    am.compute_and_plot_mae_rmse(ssa_data, model_data, species_names)

    first_moments_ssa, second_moments_ssa = ssa_moments
    first_moments_model, second_moments_model = am.compute_moments(model_data)
    # Species names are passed for every model so both moment plots are labelled.
    am.plot_moment_differences(
        first_moments_model - first_moments_ssa,
        second_moments_model - second_moments_ssa,
        species_names,
    )


def compare_per_config(name, path, data_configs):
    """Compare the GAN and MDN datasets against the SSA test data per configuration.

    For every configuration in ``data_configs`` this loads the pickled GAN,
    MDN and SSA test datasets whose file names are derived from that
    configuration. It then plots trajectory, distribution, state-transition,
    MAE/RMSE and moment-difference comparisons of each model against SSA.

    Args:
        name: model name; used as the dataset file-name prefix and passed to
            ``SimulationManager`` as ``model_name``.
        path: path to the model definition, passed as ``path_to_sbml``.
        data_configs: mapping of case name -> dict with the keys
            ``"end_time"``, ``"n_steps"``, ``"n_init_conditions"`` and
            ``"n_sims_per_init_condition"``.

    Raises:
        KeyError: if a configuration lacks one of the keys above.
        FileNotFoundError: if one of the expected dataset pickles is missing.
    """
    for config in data_configs.values():
        end_time = config["end_time"]
        n_steps = config["n_steps"]
        n_init_conditions = config["n_init_conditions"]
        n_sims_per_init_condition = config["n_sims_per_init_condition"]

        # Dataset files are looked up by this name, so it must match how they were saved.
        config_name = (
            f"{name}_{n_steps}_{end_time}_"
            f"{n_init_conditions}_{n_sims_per_init_condition}"
        )

        sm = SimulationManager(
            path_to_sbml=path,
            model_name=name,
            n_init_conditions=n_init_conditions,
            n_sims_per_init_condition=n_sims_per_init_condition,
            end_time=end_time,
            n_steps=n_steps,
        )
        species_names = sm.get_species_names()
        am = AnalysisManager()

        # Presumably selects 2 random initial conditions; the same indices are
        # used for both models so their trajectory plots are comparable.
        init_condition_indices = am.pick_random_initial_condition_indices(n_init_conditions, 2)

        gan_data = _load_pickle(f"gan_datasets/{config_name}.pickle")
        mdn_data = _load_pickle(f"mdn_datasets/{config_name}.pickle")
        ssa_data = _load_pickle(f"ssa_datasets/{config_name}_test.pickle")

        ssa_data_split = am.split_dataset_by_initial_conditions(ssa_data, n_init_conditions)
        # SSA moments are shared by both comparisons, so compute them once.
        ssa_moments = am.compute_moments(ssa_data)

        for model_label, model_data in (("GAN", gan_data), ("MDN", mdn_data)):
            _compare_against_ssa(
                am,
                ssa_data,
                ssa_data_split,
                ssa_moments,
                model_data,
                model_label,
                species_names,
                init_condition_indices,
                n_steps,
                n_init_conditions,
            )

        # TODO(review): an `am.plot_histograms(...)` call was started here but
        # never completed (it was a SyntaxError); add it back once its
        # arguments are settled.

## Step 1: Multifeedback model

In [None]:
# Multifeedback CRN model; its definition file is resolved against the
# notebook's working directory.
name, relative_path = "multifeedback", "crn_models/1_multifeedback.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
data_config = {
    # depth: fewer initial conditions, more simulations per initial condition
    "case_1": dict(
        end_time=32,
        n_steps=16,
        n_init_conditions=100,
        n_sims_per_init_condition=200,
    ),
    # breadth: more initial conditions, fewer simulations per initial condition
    "case_2": dict(
        end_time=32,
        n_steps=16,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
    ),
}

In [None]:
compare_per_config(name=name, path=path, data_configs=data_config)

## Step 2: Repressilator model

In [None]:
# Repressilator CRN model; its definition file is resolved against the
# notebook's working directory.
name, relative_path = "repressilator", "crn_models/2_repressilator.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
gen_config = {
    # breadth: many initial conditions, fewer simulations per initial condition
    "case_1": dict(
        end_time=128,
        n_steps=32,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
    ),
}

In [None]:
# Use the repressilator configuration `gen_config` defined above, not the
# multifeedback `data_config`: the config determines the dataset file names.
compare_per_config(name, path, gen_config)