# 3: MDN training component

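This notebook trains the MDN deep abstraction method on the SSA datasets generated in notebook 1. For each model and training configuration, it trains, validates and saves one MDN to `trained_mdn/`, and it records the training times in `<model name>_mdn_training_times.pickle`.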

## Step 0: Setup

In [None]:
# Automatically reload edited project modules (e.g. simulation_manager, mdn_manager)
# before executing code, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from simulation_manager import SimulationManager
from mdn_manager import MdnManager

import numpy as np
import pandas as pd
import tellurium as te
import pickle
import os
from time import time

In [None]:
# Absolute working directory of the kernel; the model file paths below are resolved against it.
current_dir = os.path.abspath(os.curdir)

In [None]:
def train_per_config(name, path, train_config, dataset_dir="ssa_datasets", model_dir="trained_mdn"):
    """Train, validate and save one MDN per training configuration of a CRN model.

    For every case in ``train_config`` this:

    1. loads the pickled training set ``<dataset_dir>/<config_name>_train.pickle``;
    2. trains an MDN on it and validates it;
    3. saves it to ``<model_dir>/<config_name>``.

    Here ``config_name`` is
    ``f"{name}_{n_steps}_{end_time}_{n_init_conditions}_{n_sims_per_init_condition}"``.
    The time spent in ``mm.train`` for each case (seconds) is pickled to
    ``<name>_mdn_training_times.pickle`` in the current working directory.

    Args:
        name: Model name. It is passed to ``SimulationManager`` as ``model_name``
            and prefixes every file name.
        path: Path to the CRN model file, passed to ``SimulationManager`` as
            ``path_to_sbml``.
        train_config: Mapping ``case -> config dict``. Each config must contain the keys
            ``end_time``, ``n_steps``, ``n_init_conditions``,
            ``n_sims_per_init_condition``, ``n_epochs``, ``batch_size`` and ``patience``.
        dataset_dir: Directory holding the training pickles (default ``"ssa_datasets"``).
        model_dir: Directory trained models are saved to. It is created if missing
            (default ``"trained_mdn"``).

    Returns:
        dict mapping each case to its training time in seconds.

    Raises:
        KeyError: if any case lacks a required key. This is checked before any
            training starts.
        FileNotFoundError: if a training pickle does not exist.
    """
    # Monotonic clock: better suited to measuring durations than wall-clock time().
    from time import perf_counter

    required_keys = (
        "end_time",
        "n_steps",
        "n_init_conditions",
        "n_sims_per_init_condition",
        "n_epochs",
        "batch_size",
        "patience",
    )
    # Validate every case up front. Otherwise a typo in a later case would abort
    # the run only after the earlier (possibly long) trainings had finished.
    for case, config in train_config.items():
        missing = [key for key in required_keys if key not in config]
        if missing:
            raise KeyError(f"train_config[{case!r}] is missing required keys: {missing}")

    os.makedirs(model_dir, exist_ok=True)

    training_times = {}
    for case, config in train_config.items():
        end_time = config["end_time"]
        n_steps = config["n_steps"]
        n_init_conditions = config["n_init_conditions"]
        n_sims_per_init_condition = config["n_sims_per_init_condition"]

        config_name = f"{name}_{n_steps}_{end_time}_{n_init_conditions}_{n_sims_per_init_condition}"

        # NOTE: pickle.load can execute arbitrary code. Only load datasets you generated yourself.
        with open(os.path.join(dataset_dir, f"{config_name}_train.pickle"), "rb") as f:
            training_data = pickle.load(f)

        # Only the model name and the species count are read from the SimulationManager below.
        sm = SimulationManager(
            path_to_sbml=path,
            model_name=name,
            n_init_conditions=n_init_conditions,
            n_sims_per_init_condition=n_sims_per_init_condition,
            end_time=end_time,
            n_steps=n_steps
        )

        mm = MdnManager(
            sm.model_name,
            sm.get_num_species()
        )

        mm.load_data(training_data)
        mm.prepare_data_loaders(batch_size=config["batch_size"])

        # Only the training call is timed; validation and saving are excluded.
        start = perf_counter()
        mm.train(n_epochs=config["n_epochs"], patience=config["patience"])
        training_times[case] = perf_counter() - start

        mm.validate()  # return value (if any) is not used here

        mm.save_model(os.path.join(model_dir, config_name))
        print(f"Finished training MDN for {case}.")

    with open(f"{name}_mdn_training_times.pickle", "wb") as f:
        pickle.dump(training_times, f)

    return training_times

## Step 1: Multifeedback model

In [None]:
# Model file for the multifeedback CRN, resolved against the kernel's working directory.
name, relative_path = "multifeedback", "crn_models/1_multifeedback.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
# Both cases use the same time grid (16 steps up to t = 32) and the same total of
# 100 * 200 = 200 * 100 = 20,000 simulations (n_init_conditions x n_sims_per_init_condition).
# They differ only in how those simulations are split between initial conditions
# and repeats per condition.
train_config = {
    "case_1": dict(
        end_time=32,
        n_steps=16,
        n_init_conditions=100,
        n_sims_per_init_condition=200,
        n_epochs=100,
        batch_size=256,
        patience=20,
    ),
    "case_2": dict(
        end_time=32,
        n_steps=16,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
        n_epochs=100,
        batch_size=256,
        patience=20,
    ),
}

In [None]:
# Train, validate and save one MDN per case of the multifeedback configuration.
train_per_config(name=name, path=path, train_config=train_config)

## Step 2: Repressilator model

In [None]:
# Model file for the repressilator CRN, resolved against the kernel's working directory.
name, relative_path = "repressilator", "crn_models/2_repressilator.txt"
path = os.path.join(current_dir, relative_path)

In [None]:
# Single training case: 32 steps up to t = 128, using 200 initial conditions
# with 100 simulations each.
train_config = {
    "case_1": dict(
        end_time=128,
        n_steps=32,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
        n_epochs=100,
        batch_size=256,
        patience=20,
    ),
}

In [None]:
# Train, validate and save one MDN per case of the repressilator configuration.
train_per_config(name=name, path=path, train_config=train_config)