# 2: GAN training component

This notebook trains the GAN deep abstraction models, one per training configuration, using the datasets generated in part 1.

## Step 0: Setup

In [3]:
%load_ext autoreload
%autoreload 2

In [13]:
from simulation_manager import SimulationManager
from gan_manager import GanManager

import pickle
import os
from time import time

In [5]:
# Kernel working directory; the CRN model paths in the steps below are joined onto it.
current_dir = os.getcwd()

In [1]:
def _validate_train_config(train_config):
    """
    Check that every case in `train_config` defines all required settings.

    Raises:
    - KeyError: naming the offending case and the settings it is missing.
    """
    required_keys = (
        "end_time",
        "n_steps",
        "n_init_conditions",
        "n_sims_per_init_condition",
        "n_epochs",
        "batch_size",
        "n_critic",
    )
    for case, config in train_config.items():
        missing = [key for key in required_keys if key not in config]
        if missing:
            raise KeyError(
                f"Training config '{case}' is missing required settings: {', '.join(missing)}"
            )


def train_per_config(name, path, train_config):
    """
    Train a GAN deep abstract model for each provided case.
    Example `train_config` dict:
        {"case_1": {
                "end_time": 32,
                "n_steps": 16,
                "n_init_conditions": 100,
                "n_sims_per_init_condition": 2000,
                "n_epochs": 1,
                "batch_size": 256,
                "n_critic": 5
            }
        }

    For every case the matching datasets are read from
    "ssa_datasets/{name}_{n_steps}_{end_time}_{n_init_conditions}_{n_sims_per_init_condition}_train.pickle"
    and the corresponding "_test.pickle" file.

    The "trained_gan" directory is created if it does not exist; the trained
    models are expected to be stored there by GanManager.
    The wall-clock training time (in seconds) of each case is written to
    "{name}_gan_training_times.pickle" in the current working directory.

    Parameters:
    - name: string representing the name of the CRN
    - path: string representing the filepath to the CRN definition (Antimony .txt or SBML .xml)
    - train_config: dict containing the different training configurations

    Raises:
    - KeyError: if any case is missing a required setting. All cases are checked
      before any training starts, so a typo in a later case fails immediately
      instead of after earlier cases have been trained.
    """
    # Fail fast: training is expensive, so reject malformed configs up front.
    _validate_train_config(train_config)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs("trained_gan", exist_ok=True)

    training_times = dict()
    for case, config in train_config.items():
        end_time = config["end_time"]
        n_steps = config["n_steps"]
        n_init_conditions = config["n_init_conditions"]
        n_sims_per_init_condition = config["n_sims_per_init_condition"]
        n_epochs = config["n_epochs"]
        batch_size = config["batch_size"]
        n_critic = config["n_critic"]

        # Identifies both the dataset files to load and the GanManager run.
        config_name = f"{name}_{n_steps}_{end_time}_{n_init_conditions}_{n_sims_per_init_condition}"

        # NOTE: pickle.load can execute arbitrary code; only load dataset files
        # that were produced locally by a trusted run.
        with open(f"ssa_datasets/{config_name}_train.pickle", "rb") as f:
            training_data = pickle.load(f)

        with open(f"ssa_datasets/{config_name}_test.pickle", "rb") as f:
            test_data = pickle.load(f)

        sm = SimulationManager(
            path_to_sbml=path,
            model_name=name,
            n_init_conditions=n_init_conditions,
            n_sims_per_init_condition=n_sims_per_init_condition,
            end_time=end_time,
            n_steps=n_steps
        )

        gm = GanManager(
            model_name=sm.model_name,
            n_epochs=n_epochs,
            batch_size=batch_size,
            species_names=sm.get_species_names(),
            x_dim=sm.get_num_species(),
            traj_len=n_steps,
            end_time=end_time,
            id=config_name
        )

        gan_train_data = sm.transform_data_for_gan(training_data)
        gan_test_data = sm.transform_validation_data_for_gan(test_data, n_sims_per_init_condition)

        gm.init_dataset(
            train_data=gan_train_data,
            test_data=gan_test_data
        )
        gm.init_models()

        # Only the call to train() is timed; data loading and model setup are excluded.
        start_time = time()
        gm.train(n_critic=n_critic)
        time_taken = time() - start_time
        training_times[case] = time_taken

        print(f"Finished training GAN for {case}.")

    with open(f"{name}_gan_training_times.pickle", "wb") as f:
        pickle.dump(training_times, f)

## Step 1: Multifeedback model

In [7]:
# Multifeedback CRN: model identifier plus the absolute path to its definition file.
name = "multifeedback"
relative_path = "crn_models/1_multifeedback.txt"
path = os.path.join(current_dir, relative_path)

In [8]:
# Two cases that differ only in how the 20,000 training trajectories are split
# between initial conditions and simulations per initial condition.
train_config = {
    case: {
        "end_time": 32,
        "n_steps": 16,
        "n_init_conditions": n_init,
        "n_sims_per_init_condition": n_sims,
        "n_epochs": 2,
        "batch_size": 256,
        "n_critic": 5,
    }
    for case, n_init, n_sims in (
        ("case_1", 100, 200),
        ("case_2", 200, 100),
    )
}

In [9]:
# Train one GAN per multifeedback case defined above.
train_per_config(name=name, path=path, train_config=train_config)

>> Initialised GanManager with model: multifeedback, ID: multifeedback_16_32_100_200, epochs: 2.
>> Loaded data.
>> Train shape: X (20000, 4, 16), Y (20000, 4, 1).
>> Test shape: X (100, 200, 4, 16), Y (100, 4, 1).
>> Initialised Generator and Discriminator.
[Epoch 1/2] [Batch 0/78] [D loss: 13.518375396728516] [G loss: 0.21728309988975525]
[Epoch 1/2] [Batch 5/78] [D loss: 8.718344688415527] [G loss: 0.2524811625480652]
[Epoch 1/2] [Batch 10/78] [D loss: 5.079339027404785] [G loss: 0.3166402280330658]
[Epoch 1/2] [Batch 15/78] [D loss: 3.152920722961426] [G loss: 0.3182479739189148]
[Epoch 1/2] [Batch 20/78] [D loss: 2.0068793296813965] [G loss: 0.42528122663497925]
[Epoch 1/2] [Batch 25/78] [D loss: 1.3987910747528076] [G loss: 0.4600575566291809]
[Epoch 1/2] [Batch 30/78] [D loss: 0.697594404220581] [G loss: 0.5850587487220764]
[Epoch 1/2] [Batch 35/78] [D loss: 0.1972091794013977] [G loss: 0.632701575756073]
[Epoch 1/2] [Batch 40/78] [D loss: -0.07169246673583984] [G loss: 0.690274

## Step 2: Repressilator model

In [10]:
# Repressilator CRN: model identifier plus the absolute path to its definition file.
name = "repressilator"
relative_path = "crn_models/2_repressilator.txt"
path = os.path.join(current_dir, relative_path)

In [11]:
# Single repressilator case: a longer horizon (end_time 128) sampled at 32 steps.
train_config = {
    "case_1": dict(
        end_time=128,
        n_steps=32,
        n_init_conditions=200,
        n_sims_per_init_condition=100,
        n_epochs=2,
        batch_size=256,
        n_critic=5,
    ),
}

In [12]:
# Train one GAN per repressilator case defined above.
train_per_config(name=name, path=path, train_config=train_config)

>> Initialised GanManager with model: repressilator, ID: repressilator_32_128_200_100, epochs: 2.
>> Loaded data.
>> Train shape: X (20000, 3, 32), Y (20000, 3, 1).
>> Test shape: X (200, 100, 3, 32), Y (200, 3, 1).
>> Initialised Generator and Discriminator.
[Epoch 1/2] [Batch 0/78] [D loss: 7.749721050262451] [G loss: -0.26374268531799316]
[Epoch 1/2] [Batch 5/78] [D loss: 3.654705286026001] [G loss: -0.1961423009634018]
[Epoch 1/2] [Batch 10/78] [D loss: 1.9742927551269531] [G loss: -0.052551645785570145]
[Epoch 1/2] [Batch 15/78] [D loss: 0.7699592709541321] [G loss: 0.14438873529434204]
[Epoch 1/2] [Batch 20/78] [D loss: -0.030763983726501465] [G loss: 0.23808316886425018]
[Epoch 1/2] [Batch 25/78] [D loss: -0.5037968158721924] [G loss: 0.43505150079727173]
[Epoch 1/2] [Batch 30/78] [D loss: -0.9584794044494629] [G loss: 0.6020069718360901]
[Epoch 1/2] [Batch 35/78] [D loss: -1.5083694458007812] [G loss: 0.7988519668579102]
[Epoch 1/2] [Batch 40/78] [D loss: -1.94746732711792] [G 