## A tutorial for running MDITRE code on a pre-processed dataset

In this tutorial we show how to run MDITRE on the dataset from David et al. (2014) and how to interpret its results.

### Import packages
First we import the webbrowser package, which we use to view the PDF files that MDITRE writes as output. We then import two objects from the mditre package: Trainer, the main driver object, and parse, a function that returns the default input arguments for the Trainer object.

In [1]:
import webbrowser
from mditre.trainer_model import Trainer, parse

In [2]:
# Parse default args
args = parse()
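The Trainer logs args as Namespace(...) when it is initialized, so args is an argparse.Namespace, and individual defaults can be overridden by plain attribute assignment before the Trainer is built. The sketch below builds a Namespace with a few of the logged defaults as a stand-in for parse(), so it runs without MDITRE installed. The helper override_args is hypothetical, and the sketch assumes that Trainer reads these attributes when it is constructed.

```python
import argparse
import copy


def override_args(args: argparse.Namespace, **overrides) -> argparse.Namespace:
    """Return a copy of args with selected fields replaced (hypothetical helper).

    Unknown names raise AttributeError, so a typo such as 'epoch' instead of
    'epochs' is caught instead of silently adding a new, unused attribute.
    """
    new_args = copy.deepcopy(args)
    for name, value in overrides.items():
        if not hasattr(new_args, name):
            raise AttributeError(f"unknown argument: {name}")
        setattr(new_args, name, value)
    return new_args


# Stand-in for `args = parse()`: a few of the defaults printed in the log.
args = argparse.Namespace(
    data="./datasets/david_agg_filtered.pickle",
    data_name="David",
    epochs=2000,
    seed=42,
    batch_size=128,
)

# Example: a shorter run with a different seed.
quick_args = override_args(args, epochs=500, seed=0)
print(quick_args.epochs, quick_args.seed)  # 500 0
print(args.epochs)  # 2000 (the original is left unchanged)
```

In the notebook, quick_args would then be passed to Trainer in place of args.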

### Initialize Trainer object
Initializing the Trainer object with the default arguments produces the output below. It lists every default argument passed to the Trainer object and shows that the model will run on a GPU (device cuda:0).

In [3]:
# initialize trainer object
trainer_obj = Trainer(args)

Directory  ./logs/David  already exists
Directory  ./logs/David/seed_42  already exists
Directory  ./logs/David/seed_42/rank_0  already exists
12/09 07:36:07 PM | Namespace(batch_size=128, cv_type='None', data='./datasets/david_agg_filtered.pickle', data_name='David', deterministic=True, distributed=False, epochs=2000, inner_cv=False, kfolds=5, local_rank=0, lr_alpha=0.001, lr_beta=0.001, lr_bias=0.001, lr_eta=0.001, lr_fc=0.001, lr_kappa=0.001, lr_mu=0.01, lr_slope=0.0001, lr_thresh=0.001, lr_time=0.01, max_k_bc=100, max_k_otu=100, max_k_slope=10000.0, max_k_thresh=1000, max_k_time=10, min_k_bc=1, min_k_otu=10, min_k_slope=1000.0, min_k_thresh=100, min_k_time=1, n_d=10, rank=0, save_as_csv=True, seed=42, verbose=False, w_var=100000.0, workers=0, world_size=1, z_mean=0, z_r_mean=0, z_r_var=1, z_var=1)
12/09 07:36:07 PM | Using device: cuda:0
12/09 07:36:07 PM | Trainer initialized!


### Load preprocessed dataset
Next we load the dataset to run the model on. Here the pre-processed dataset is stored at './datasets/david_agg_filtered.pickle', the path given under 'data' in the default arguments printed above.

In [4]:
# load dataset
trainer_obj.load_data()

12/09 07:36:17 PM | Dataset: David Variables: 308, Otus: 185,                Subjects: 20, Total samples: 233
12/09 07:36:17 PM | Outcomes: (array([0., 1.]), array([10, 10]))
12/09 07:36:17 PM | Exp start: -5.0 Exp end: 10.0
12/09 07:36:17 PM | Loaded and preprocessed dataset!
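The 'Outcomes' line looks like the output of NumPy's np.unique(..., return_counts=True): the distinct outcome labels and how many subjects have each. The counts (10 and 10) add up to the 20 subjects, so the two classes are balanced and always predicting a single class would give 50% accuracy. That chance level is a useful reference when reading the classification metrics printed during training. A small standard-library sketch of the same check, on a label list constructed to match the log (hypothetical data, one label per subject):

```python
from collections import Counter


def class_balance(labels):
    """Return sorted (label, count) pairs and the majority-class accuracy."""
    counts = Counter(labels)
    majority_acc = max(counts.values()) / len(labels)
    return sorted(counts.items()), majority_acc


# One label per subject, matching 'Outcomes: (array([0., 1.]), array([10, 10]))'
labels = [0.0] * 10 + [1.0] * 10
pairs, baseline = class_balance(labels)
print(pairs)     # [(0.0, 10), (1.0, 10)]
print(baseline)  # 0.5
```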


### Run the training procedure
Next we run the training procedure on the loaded dataset. The model then prints the classification metrics to the console.

In [None]:
# run train loop
trainer_obj.train_loop()

12/09 07:36:28 PM | Initialized priors!
12/09 07:36:28 PM | Rules: 10 Detectors: 10
12/09 07:36:28 PM | Using cross-validation type: None
12/09 07:36:28 PM | Initializing model!
12/09 07:36:43 PM | Model training started!
Directory  ./logs/David/seed_42/rank_0/fold_0  already exists


### Examine model outputs
The model saves the training losses and the final learned rules as PDF files at the following locations:

In [7]:
rules_path = './logs/David/seed_42/rank_0/fold_0/rules.pdf'
loss_path = './logs/David/seed_42/rank_0/fold_0/losses.pdf'
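The hard-coded paths above follow the directory layout reported in the log messages: ./logs/<data_name>/seed_<seed>/rank_<rank>/fold_<k>. The sketch below assembles the same paths from the arguments, so they stay correct if data_name or seed is changed. The helper output_paths is hypothetical, and the layout is inferred from the logged directory names rather than taken from MDITRE's documentation.

```python
import argparse
import os


def output_paths(args, fold=0, log_root="./logs"):
    """Build the rules/loss PDF paths for one fold (hypothetical helper).

    The layout mirrors the directories printed in the training log.
    """
    fold_dir = os.path.join(
        log_root,
        args.data_name,
        f"seed_{args.seed}",
        f"rank_{args.rank}",
        f"fold_{fold}",
    )
    return os.path.join(fold_dir, "rules.pdf"), os.path.join(fold_dir, "losses.pdf")


# Stand-in for the parsed defaults (data_name='David', seed=42, rank=0).
args = argparse.Namespace(data_name="David", seed=42, rank=0)
rules_path, loss_path = output_paths(args)
print(rules_path)  # ./logs/David/seed_42/rank_0/fold_0/rules.pdf
print(loss_path)   # ./logs/David/seed_42/rank_0/fold_0/losses.pdf
```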

#### Examine model convergence
Executing the following cell opens the PDF file containing all of the losses optimized during training. To confirm that the model has converged, check the 'Train CE Loss' plot: the loss should be close to 0 and should have flattened out by the end of training.

In [11]:
webbrowser.open_new(loss_path)

True
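Inspecting the plot by eye is the workflow used here. If the per-epoch training loss values are available as a list of numbers (how to extract them from the run's outputs is not covered in this tutorial), the two criteria "close to 0" and "flattened out" can be written as a simple heuristic check. The function looks_converged and its window, tol and flat_tol values are hypothetical choices for illustration, not part of MDITRE.

```python
import math
from statistics import mean


def looks_converged(losses, window=100, tol=0.05, flat_tol=1e-3):
    """Heuristic check that a loss curve is near 0 and has flattened out.

    - near 0: the mean of the last `window` values is below `tol`
    - flat: that mean differs from the mean of the preceding `window`
      values by less than `flat_tol`
    Returns False when fewer than 2 * window values are available.
    """
    if len(losses) < 2 * window:
        return False
    last = mean(losses[-window:])
    prev = mean(losses[-2 * window:-window])
    return last < tol and abs(prev - last) < flat_tol


# Synthetic curve for illustration only (not MDITRE output): 2000 epochs,
# matching the default 'epochs' value.
decaying = [math.exp(-epoch / 100) for epoch in range(2000)]
print(looks_converged(decaying))        # True
print(looks_converged(decaying[:300]))  # False: still well above 0 and falling
```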

#### Examine learned rules
Executing the following cell opens the PDF file containing the rules learned by the model. In this case the model learns 2 rules, each with 1 detector. For each rule, the title gives its English description, and the plot shows the abundances over time together with the selected clade of bacteria.

For example, Rule 3 Detector 0 reads: TRUE for Plant diet if the average slope of selected taxa between days 1 to 9 is greater than -0.0023. This rule captures the pattern that, in subjects on the plant-based diet, the abundance of the selected clade of taxa (shown on the right of the figure) changes at a rate greater than -0.0023 per day between days 1 and 9; that is, it increases, stays flat, or decreases by less than 0.0023 per day. In each plot, the red dashed line shows the median slope across all subjects and the black dashed line shows the learned threshold of -0.0023.

In [12]:
webbrowser.open_new(rules_path)

True
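To make the English reading of a slope rule concrete, the sketch below evaluates a detector like Rule 3 Detector 0 for one subject: keep the aggregated abundance of the selected taxa between days 1 and 9, fit a least-squares slope against time, and compare it with the threshold -0.0023. This is a simplified, hard-threshold interpretation of the rule text, not MDITRE's implementation (MDITRE is trained with differentiable versions of these operations). The function names and the two subjects' abundance values are hypothetical.

```python
def window_slope(days, abundances, start, end):
    """Least-squares slope (abundance per day) of the points with start <= day <= end."""
    pts = [(t, a) for t, a in zip(days, abundances) if start <= t <= end]
    if len(pts) < 2:
        raise ValueError("need at least two time points in the window")
    t_mean = sum(t for t, _ in pts) / len(pts)
    a_mean = sum(a for _, a in pts) / len(pts)
    num = sum((t - t_mean) * (a - a_mean) for t, a in pts)
    den = sum((t - t_mean) ** 2 for t, _ in pts)
    if den == 0:
        raise ValueError("all points in the window share the same day")
    return num / den


def slope_detector(days, abundances, start=1, end=9, threshold=-0.0023):
    """Hard-threshold reading of 'average slope between days start and end > threshold'."""
    return window_slope(days, abundances, start, end) > threshold


# Hypothetical subjects, sampled on the same days (only days 1-9 enter the fit).
days = [-4, -2, 0, 1, 3, 5, 7, 9, 10]
falling_fast = [0.10, 0.10, 0.11, 0.12, 0.11, 0.10, 0.09, 0.08, 0.08]
falling_slow = [0.10, 0.10, 0.11, 0.12, 0.118, 0.116, 0.114, 0.112, 0.11]

print(round(window_slope(days, falling_fast, 1, 9), 4))  # -0.005
print(slope_detector(days, falling_fast))                # False
print(round(window_slope(days, falling_slow, 1, 9), 4))  # -0.001
print(slope_detector(days, falling_slow))                # True
```

The second subject shows why the threshold is read as "decreases by less than 0.0023 per day" rather than "increases": its abundance falls, yet the detector still evaluates to TRUE.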