# 2. Set up training config


***pseudodynamics+*** provides two ways of passing arguments when running with a script.
The first is to pass each argument on the command line as a flag (e.g. `-a` or `--arg`). The second is to save the arguments in a config file, which makes it easier to keep track of different models and improves reproducibility.
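To illustrate how the two ways can be combined, here is a minimal sketch using Python's `argparse`. It is not the pseudodynamics+ implementation: only `--config` and `-G` appear in this notebook, `--lr` is a hypothetical flag, and the rule that explicit flags override the file is an assumption of the sketch.

```python
import argparse
import json


def none_or_int(value):
    """Parse a GPU index, treating the literal string "None" as CPU (None)."""
    return None if value == "None" else int(value)


def parse_args(argv=None):
    """Hypothetical sketch: merge a JSON config file with command-line flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=None, help="path to a JSON config file")
    # default=argparse.SUPPRESS: the attribute only exists if the flag was given
    parser.add_argument("-G", dest="gpu_devices", type=none_or_int, default=argparse.SUPPRESS)
    parser.add_argument("--lr", type=float, default=argparse.SUPPRESS)  # hypothetical flag
    cli = vars(parser.parse_args(argv))

    args = {}
    config_path = cli.pop("config")
    if config_path is not None:
        with open(config_path) as f:
            args.update(json.load(f)["raw_args"])  # values from the file first
    args.update(cli)  # flags given explicitly on the command line win
    args["config"] = config_path
    return args
```

Using `argparse.SUPPRESS` as the default lets the sketch tell "flag not given" apart from "flag given as None", so `-G None` can still override a GPU index stored in the file.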

In this notebook, we will
- go through the different arguments
- create a config `json` file
- generate an `ExperimentConfig` object from the config file

In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
# On macOS, uncomment to work around the "OMP: Error #15" duplicate OpenMP runtime crash
# if sys.platform.startswith("darwin"):
#     os.environ['KMP_DUPLICATE_LIB_OK']='True'

import json
import pseudodynamics as pdp



In [5]:
# work from the project main directory so relative paths such as logs/ resolve
os.chdir(pdp.main_dir)

# Basic args

In [9]:
basic_config = {
    "config": None,                  # path to a config file; None because the args are defined in this notebook
    "dataset": "tom_pos",            # the dataset prefix of the h5ad file, i.e. "data/tom_pos.h5ad"
    "log_name": "tom_pos_fulltime",  # the name of the logging directory (logs/`log_name`)
    "progress_bar": True,            # show a progress bar during training
}
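A small hypothetical helper (not part of pseudodynamics+) that derives file locations from the basic args, following the conventions stated in the comments above (`data/<dataset>.h5ad` and `logs/<log_name>`). The `data_dir` and `log_root` parameters are assumptions added for illustration.

```python
import os


def resolve_paths(basic_config, data_dir="data", log_root="logs"):
    """Hypothetical helper: derive the dataset file and log directory from the basic args."""
    data_path = os.path.join(data_dir, f"{basic_config['dataset']}.h5ad")
    log_dir = os.path.join(log_root, basic_config["log_name"])
    return data_path, log_dir


data_path, log_dir = resolve_paths({"dataset": "tom_pos", "log_name": "tom_pos_fulltime"})
print(data_path, log_dir)  # data/tom_pos.h5ad logs/tom_pos_fulltime (on POSIX systems)
```

Since the notebook changes into `pdp.main_dir`, `os.path.exists(data_path)` is a quick way to confirm the dataset is in place before training.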

In [14]:
model_config = {

    # Model choice
    "model": "pde_params",       # pseudodynamics model; `log_pde_params` is the other option
    "time_sensitive": True,      # True: parameters depend on time and state. False: parameters are time-independent

    # Neural network
    "n_dimension": 10,           # input size and the dimension used for density estimation
    "channels": "64,64",         # hidden layer sizes, e.g. "64,64" means a 4-layer density network [input, 64, 64, output]

    "lr": 0.0003,                # learning rate
    "schedule_lr": "CyclicLR",   # learning rate scheduler

    # Dynamics equation precision
    "tol": 0.0001,               # tolerance of the NeuralODE integration (atol = rtol = tol)
    "time_scale_factor": 1,      # factor scaling the integration time of the NeuralODE; smaller factor -> longer integral

    "pretrained": None,          # pretrained model to resume from (None: train from scratch)
    "gpu_devices": 0,            # which GPU to use; set to None for CPU training

    # Loss term related weights
    "weight_intensity": None,    # weight emphasizing high-density cells: > 1 up-weights, < 1 down-weights
    "deltax_weight": 0.01,       # weight of the term informing v with local state transitions, i.e. the similarity between deltax and v
    "R_weight": 1,               # weight balancing the PDE residual loss against the data-related loss
    "growth_weight": None,       # weight regularizing the contribution of growth to the overall density gain; greater means a harder boundary
    "D_penalty": 1,              # strength of the penalty restricting diffusion
}
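The `channels` string only lists hidden layers. The sketch below shows one way to expand it into full layer sizes. It is an assumption inferred from the record loaded in the Training section (`v_channels` is `[11, 64, 64, 10]`, `g_channels` and `D_channels` are `[11, 64, 64, 1]`): a time-sensitive network appears to receive the cell state plus time as input. The function name `layer_sizes` is hypothetical, and reading v, g and D as velocity, growth and diffusion follows the comments above.

```python
def layer_sizes(channels, n_dimension, time_sensitive, out_dim):
    """Expand a channels string such as "64,64" into a full list of layer sizes.

    Sketch only: assumes a time-sensitive network takes the cell state plus
    time as input, so its input size is n_dimension + 1.
    """
    hidden = [int(c) for c in channels.split(",")]
    # Assumption: without time sensitivity the input is the cell state alone
    in_dim = n_dimension + 1 if time_sensitive else n_dimension
    return [in_dim, *hidden, out_dim]


# velocity v maps the state to a vector of the same dimension
print(layer_sizes("64,64", n_dimension=10, time_sensitive=True, out_dim=10))  # [11, 64, 64, 10]
# growth g and diffusion D have scalar outputs
print(layer_sizes("64,64", n_dimension=10, time_sensitive=True, out_dim=1))   # [11, 64, 64, 1]
```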

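An important set of args configures the dataset, i.e. which cell-state representation and which timepoints are read from the `h5ad` file: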

In [15]:
dataset_config = {
    "cellstate_key": "DM_scaled",            # obsm key used as the cell state
    "deltax_key": "Delta_DM",                # obsm key used for local cell state changes
    "timepoint_idx": [0, 1, 2, 3, 4, 6, 8],  # numeric indices of the timepoints to use

    "knn_volume": False,
    "batch_size": 50,                        # training batch size
    "bw": None,

    "norm_time": False,
}
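Before writing the config, a quick sanity check can catch typos in the dataset args. The helper below is hypothetical (not part of pseudodynamics+), and requiring strictly increasing, non-negative indices is an assumption of this sketch. It also reports which indices are skipped inside the selected range.

```python
def check_dataset_config(cfg):
    """Hypothetical sanity check for the dataset args (not part of pseudodynamics+).

    Returns the indices skipped between the first and last selected timepoint.
    """
    idx = cfg["timepoint_idx"]
    # Assumption of this sketch: indices are non-negative and strictly increasing
    if not idx or idx[0] < 0:
        raise ValueError("timepoint_idx must be a non-empty list of non-negative indices")
    if any(a >= b for a, b in zip(idx, idx[1:])):
        raise ValueError("timepoint_idx should be strictly increasing")
    if cfg["batch_size"] <= 0:
        raise ValueError("batch_size must be positive")
    return sorted(set(range(idx[0], idx[-1] + 1)) - set(idx))


print(check_dataset_config({"timepoint_idx": [0, 1, 2, 3, 4, 6, 8], "batch_size": 50}))  # [5, 7]
```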

In [16]:
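# merge the arg groups into one flat dict (later groups override duplicate keys)
raw_args = {}
raw_args.update(basic_config)
raw_args.update(model_config)
raw_args.update(dataset_config)

# the config file stores all args under the "raw_args" key
configs = {"raw_args": raw_args}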
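Because `dict.update` lets the group merged last silently win, a key defined in two groups may signal a mistake. A small hypothetical check (not part of pseudodynamics+) for such collisions:

```python
from itertools import combinations


def find_key_collisions(**groups):
    """Return {(group_a, group_b): [shared keys]} for every pair of groups that overlap."""
    collisions = {}
    for (name_a, a), (name_b, b) in combinations(groups.items(), 2):
        shared = set(a) & set(b)
        if shared:
            collisions[(name_a, name_b)] = sorted(shared)
    return collisions


print(find_key_collisions(basic={"lr": 1}, model={"lr": 2, "tol": 1e-4}))  # {('basic', 'model'): ['lr']}
```

For the three groups in this notebook, `find_key_collisions(basic=basic_config, model=model_config, dataset=dataset_config)` returns an empty dict, so the merge order does not matter here.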


In [17]:
os.makedirs('logs', exist_ok=True)  # make sure the log directory exists
with open('logs/testing_config.json', 'w') as f:
    json.dump(configs, f, indent=4)
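JSON has no `None`, `True` or tuples, so it is worth knowing how the args survive the trip to disk. This sketch uses only the standard `json` module; the helper name `roundtrips` is hypothetical.

```python
import json


def roundtrips(configs):
    """Return True if configs survives a JSON dump/load unchanged."""
    return json.loads(json.dumps(configs)) == configs


# Python None/True are written as JSON null/true and read back unchanged
print(json.dumps({"bw": None, "progress_bar": True}))  # {"bw": null, "progress_bar": true}
print(roundtrips({"raw_args": {"bw": None, "timepoint_idx": [0, 1, 2]}}))  # True
# tuples come back as lists, so this config would not round-trip exactly
print(roundtrips({"channels": (64, 64)}))  # False
```

This is one reason to keep values such as `channels` as plain strings or lists in the config.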

# Instantiate an `ExperimentConfig` object

In [18]:
test_config = pdp.ExperimentConfig(config='logs/testing_config.json')

In [20]:
test_config._get_model_config()

{'model_class': 'pde_params',
 'channels': '64,64',
 'activation_fn': None,
 'ode_tol': 0.0001,
 'growth_weight': None,
 'R_weight': 1,
 'D_penalty': 1,
 'deltax_weight': 0.01,
 'weight_intensity': None,
 'time_scale_factor': 1,
 'time_sensitive': True,
 'v_channels': None,
 'g_channels': None,
 'D_channels': None}
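Comparing this output with `model_config`, some raw args appear under different names: `model` becomes `model_class` and `tol` becomes `ode_tol`. The mapping below is inferred from these outputs only (an assumption, not the library's code), and the helper names are hypothetical. `mismatches` flags model-config values that differ from the raw args; for example, in the record loaded in the Training section, `growth_weight` and `weight_intensity` are 0 and 1 rather than None.

```python
# Assumption inferred from the outputs in this notebook, not the library's code
RENAMED = {"model": "model_class", "tol": "ode_tol"}


def rename_to_model_keys(raw_args):
    """Hypothetical helper: map raw arg names to the names used in the model config."""
    return {RENAMED.get(k, k): v for k, v in raw_args.items()}


def mismatches(raw_args, model_config):
    """Return {key: (raw value, model-config value)} where the two differ."""
    renamed = rename_to_model_keys(raw_args)
    return {k: (renamed[k], v) for k, v in model_config.items()
            if k in renamed and renamed[k] != v}


print(rename_to_model_keys({"model": "pde_params", "tol": 0.0001, "R_weight": 1}))
# {'model_class': 'pde_params', 'ode_tol': 0.0001, 'R_weight': 1}
print(mismatches({"growth_weight": None, "weight_intensity": None, "R_weight": 1},
                 {"growth_weight": 0, "weight_intensity": 1, "R_weight": 1}))
# {'growth_weight': (None, 0), 'weight_intensity': (None, 1)}
```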

# Training

We recommend training the model with the script. From the project main directory, run the following command:

```bash
python main_train.py --config logs/testing_config.json -G None
```

An experiment record is generated automatically under the `logs/<log_name>` directory. For example, the command above generates the record `logs/tom_pos_fulltime/pde_params_tsense/V0_config.json`. The record file can be used to restore both the model and the dataset.
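Only `V0_config.json` is shown here. If repeated runs in the same directory are saved as `V1_config.json`, `V2_config.json`, and so on (an assumption, not confirmed by this notebook), a hypothetical helper like the one below picks the newest record. It sorts by the numeric version, so `V10` ranks above `V2`.

```python
import os
import re


def latest_record(run_dir):
    """Hypothetical helper: return the highest-numbered V<n>_config.json in run_dir, or None."""
    pattern = re.compile(r"^V(\d+)_config\.json$")
    versions = []
    for name in os.listdir(run_dir):
        match = pattern.match(name)
        if match:
            versions.append((int(match.group(1)), name))
    if not versions:
        return None
    return os.path.join(run_dir, max(versions)[1])


# e.g. latest_record("logs/tom_pos_fulltime/pde_params_tsense")
```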

In [26]:
# load record json
v0_config = pdp.ExperimentConfig(config='logs/tom_pos_fulltime/pde_params_tsense/V0_config.json')


# check updated model config
v0_config.model_config

{'model_class': 'pde_params',
 'channels': None,
 'activation_fn': None,
 'ode_tol': 0.0001,
 'growth_weight': 0,
 'R_weight': 1,
 'D_penalty': 1,
 'deltax_weight': 0.01,
 'weight_intensity': 1,
 'time_scale_factor': 1,
 'time_sensitive': True,
 'v_channels': [11, 64, 64, 10],
 'g_channels': [11, 64, 64, 1],
 'D_channels': [11, 64, 64, 1]}

In [24]:
# locate the latest checkpoint:
v0_config.find_lastest_ckpt()

