# 4: GAN data generation component

This notebook generates new data using the trained GAN deep abstract models produced by component 2.

## Step 0: Setup

In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to the project modules imported below are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from simulation_manager import SimulationManager
from gan_manager import GanManager

import pickle
import os
from time import time

In [None]:
# Current working directory at the time this cell runs; the model file paths in
# later cells are joined onto it.
current_dir = os.getcwd()

In [None]:
def gen_per_config(name, path, data_configs, plot=False):
    """Generate GAN trajectories for every configuration and pickle the results.

    For each entry of ``data_configs`` this builds a ``SimulationManager`` and a
    ``GanManager``, loads the trained model, reads the matching test set from
    ``ssa_datasets/<config_name>_test.pickle``, generates trajectories for the
    initial conditions extracted from that test set, and writes the result to
    ``gan_datasets/<config_name>.pickle``. The ``gan_datasets`` directory is
    created if it does not exist.

    Args:
        name: Model name; passed to ``SimulationManager`` and used as the
            prefix of the dataset file names.
        path: Path to the model definition, passed as ``path_to_sbml``.
        data_configs: Mapping of case label to a dict with the keys
            ``"end_time"``, ``"n_steps"``, ``"n_init_conditions"``,
            ``"n_sims_per_init_condition"`` and ``"n_epochs"``.
        plot: When true, also call ``sm.plot_simulations`` on the generated
            data of each case.

    Returns:
        dict: case label -> seconds (wall clock) spent inside
        ``generate_trajectories_for_init_conditions`` for that case. Empty
        when ``data_configs`` is empty.
    """
    simulation_times = dict()
    # exist_ok avoids the check-then-create race and makes re-runs idempotent.
    os.makedirs("gan_datasets", exist_ok=True)

    for case, config in data_configs.items():
        end_time = config["end_time"]
        n_steps = config["n_steps"]
        n_init_conditions = config["n_init_conditions"]
        n_sims_per_init_condition = config["n_sims_per_init_condition"]
        n_epochs = config["n_epochs"]  # this is only for identifying the trained model

        # Shared identifier for the SSA input file, the GAN output file and the plot.
        config_name = f"{name}_{n_steps}_{end_time}_{n_init_conditions}_{n_sims_per_init_condition}"

        sm = SimulationManager(
            path_to_sbml=path,
            model_name=name,
            n_init_conditions=n_init_conditions,
            n_sims_per_init_condition=n_sims_per_init_condition,
            end_time=end_time,
            n_steps=n_steps
        )

        gm = GanManager(
            model_name=sm.model_name,
            n_epochs=n_epochs,
            batch_size=256,
            species_names=sm.get_species_names(),
            x_dim=sm.get_num_species(),
            traj_len=n_steps,
            end_time=end_time,
            id=config_name
        )

        gm.load_trained_model()
        # NOTE: pickle.load can execute arbitrary code; only load datasets that
        # were produced by this project's own pipeline.
        with open(f"ssa_datasets/{config_name}_test.pickle", "rb") as f:
            testing_data = pickle.load(f)

        # placeholder calls, needed to initialise the data layer of the GanManager
        gan_train_data = sm.transform_data_for_gan(testing_data)
        gan_test_data = sm.transform_validation_data_for_gan(testing_data, n_sims_per_init_condition)
        gm.init_dataset(
            train_data=gan_train_data,
            test_data=gan_test_data
        )

        ic = sm.extract_initial_conditions_from_dataset(testing_data, n_init_conditions)
        # Time only the generation call, not model loading or dataset preparation.
        start_time = time()
        gan_data = gm.generate_trajectories_for_init_conditions(ic, n_sims_per_init_condition=n_sims_per_init_condition)
        time_taken = time() - start_time

        simulation_times[case] = time_taken

        if plot:
            print("Plotting...")
            # NOTE(review): assumes plot_simulations can write under "plots/"
            # (i.e. the directory exists or is created there) - confirm.
            sm.plot_simulations(
                f"plots/{config_name}__gan",
                gan_data,
                n_init_conditions,
                n_sims_per_init_condition,
                sm.get_column_names()
            )

        with open(f"gan_datasets/{config_name}.pickle", "wb") as f:
            pickle.dump(gan_data, f)
        print(f"Generated GAN data of shape {gan_data.shape}.")

    # Previously the timings were collected and then dropped; hand them back.
    return simulation_times

## Step 1: Multifeedback model

In [None]:
# Multifeedback model: identifier and location of its definition file,
# resolved against the notebook's working directory.
name = "multifeedback"
relative_path = "crn_models/1_multifeedback.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
# Both cases share the time horizon, step count and epoch tag; they differ only
# in how the simulations are split between initial conditions and repeats.
gen_config = {
    case: {
        "end_time": 32,
        "n_steps": 16,
        "n_init_conditions": n_init,
        "n_sims_per_init_condition": n_sims,
        "n_epochs": 2,
    }
    for case, n_init, n_sims in (
        ("case_1", 100, 200),  # depth: 100 initial conditions x 200 simulations each
        ("case_2", 200, 100),  # breadth: 200 initial conditions x 100 simulations each
    )
}

In [None]:
# Generate, plot and save GAN data for every multifeedback configuration.
gen_per_config(name=name, path=path, data_configs=gen_config, plot=True)

## Step 2: Repressilator model

In [None]:
# Repressilator model: identifier and location of its definition file,
# resolved against the notebook's working directory.
name = "repressilator"
relative_path = "crn_models/2_repressilator.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
# Single "breadth" configuration: 200 initial conditions x 100 simulations each.
gen_config = dict(
    case_1=dict(
        end_time=128,
        n_steps=32,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
        n_epochs=2,
    ),
)

In [None]:
# Generate, plot and save GAN data for the repressilator configuration.
gen_per_config(name=name, path=path, data_configs=gen_config, plot=True)