<a href="https://colab.research.google.com/github/mingmcs/pyhealth/blob/week6/Tutorial_4_pyhealth_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Preparation**
- Install PyHealth (the run below installs `pyhealth-1.1.2`).

In [None]:
!pip install pyhealth

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyhealth
  Downloading pyhealth-1.1.2-py2.py3-none-any.whl (106 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.5/106.5 KB 4.2 MB/s eta 0:00:00
Collecting rdkit>=2022.03.4
  Downloading rdkit-2022.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 29.3/29.3 MB 24.8 MB/s eta 0:00:00
Installing collected packages: rdkit, pyhealth
Successfully installed pyhealth-1.1.2 rdkit-2022.9.3


### **Instruction on [pyhealth.trainer.Trainer](https://pyhealth.readthedocs.io/en/latest/api/trainer.html)**
- **[README]**: The `Trainer` class is the package's training handler (similar to [pytorch-lightning.Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)). We use it to train both ML and DL models. Its arguments and functionality are listed below.

- **[Arguments]**: 
To initialize a trainer instance, specify the following arguments (only `model` is required):
  - `model`: the `pyhealth.models` model to train
  - `checkpoint_path`: path to a saved checkpoint to load the model from
  - `metrics`: list of metrics to compute during evaluation, for example `["pr_auc", "roc_auc"]`
  - `device`: device to use, for example `"cpu"` or `"cuda"`
  - `enable_logging`: whether to enable logging
  - `output_path`: output directory
  - `exp_name`: experiment/task name

- **[Functionality]**:
  - `Trainer.train()`: calling `.train()` starts training the DL or ML model. Its arguments are:
    - `train_dataloader`: train data loader
    - `val_dataloader`: validation data loader
    - `epochs`: number of epochs to train the model
    - `optimizer_class`: optimizer, such as `torch.optim.Adam`
    - `optimizer_params`: optimizer parameters, including
      - `lr`: learning rate
      - `weight_decay`: weight decay
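    - `max_grad_norm`: maximum gradient norm used for gradient clipping
    - `monitor`: name of the validation metric used to select the best model (for example `"roc_auc"`); default is `None`
    - `monitor_criterion`: `"max"` if a larger monitored value is better, `"min"` if a smaller one is; default is `"max"`
    - `load_best_model_at_last`: whether to load the best model (according to `monitor`) after the last epoch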
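How `monitor`, `monitor_criterion`, and `load_best_model_at_last` interact can be illustrated with a small, framework-free sketch. It assumes the trainer compares the monitored validation metric after every epoch and keeps the epoch with the strictly best value; `select_best_epoch` is a hypothetical helper, not part of PyHealth. The scores are the first five validation `roc_auc` and `loss` values from the training log in this tutorial.

```python
from typing import List, Optional, Tuple


def select_best_epoch(
    scores: List[float], criterion: str = "max"
) -> Tuple[Optional[int], Optional[float]]:
    """Return (epoch, score) of the best monitored value.

    Hypothetical helper illustrating `monitor` / `monitor_criterion`;
    it is not PyHealth's implementation.
    """
    if criterion not in ("max", "min"):
        raise ValueError('criterion must be "max" or "min"')
    best_epoch, best_score = None, None
    for epoch, score in enumerate(scores):
        if best_score is None:
            improved = True
        elif criterion == "max":
            improved = score > best_score  # larger is better, e.g. roc_auc
        else:
            improved = score < best_score  # smaller is better, e.g. loss
        if improved:
            # A trainer would typically save a "best" checkpoint at this point.
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Validation values for epochs 0-4, copied from the training log.
val_roc_auc = [0.6045, 0.6276, 0.6302, 0.6191, 0.6084]
val_loss = [0.3530, 0.3483, 0.3469, 0.3498, 0.3562]

print(select_best_epoch(val_roc_auc, "max"))  # (2, 0.6302)
print(select_best_epoch(val_loss, "min"))     # (2, 0.3469)
```

With `load_best_model_at_last=True`, the weights from the selected epoch would then be restored after the final epoch, which is consistent with the `Loaded best model` line at the end of the training log.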

### **Step 1 & 2 & 3: Prepare datasets, task, and model**
- Example: We use **MIMIC-III dataset** and **RETAIN** model for **readmission prediction** task. Refer to [Tutorial 1](https://colab.research.google.com/drive/18kbzEQAj1FMs_J9rTGX8eCoxnWdx4Ltn?usp=sharing), [Tutorial 2](https://colab.research.google.com/drive/1r7MYQR_5yCJGpK_9I9-A10HmpupZuIN-?usp=sharing), and [Tutorial 3](https://colab.research.google.com/drive/1LcXZlu7ZUuqepf269X3FhXuhHeRvaJX5?usp=sharing).

In [None]:
# load dataset
from pyhealth.datasets import MIMIC3Dataset
mimic3dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    dev=False, 
)


Parsing PATIENTS and ADMISSIONS: 100%|██████████| 49993/49993 [01:16<00:00, 654.63it/s]
Parsing DIAGNOSES_ICD: 100%|██████████| 52354/52354 [00:10<00:00, 5023.15it/s]
Parsing PROCEDURES_ICD: 100%|██████████| 45769/45769 [00:06<00:00, 6623.57it/s]
Parsing PRESCRIPTIONS: 100%|██████████| 50710/50710 [01:35<00:00, 533.06it/s]
Mapping codes: 100%|██████████| 49993/49993 [00:01<00:00, 38089.70it/s]


In [None]:
from pyhealth.tasks import readmission_prediction_mimic3_fn
from pyhealth.datasets import split_by_patient, get_dataloader

# set task
dataset = mimic3dataset.set_task(task_fn=readmission_prediction_mimic3_fn)

# dataset split
train_ds, val_ds, test_ds = split_by_patient(dataset, [0.8, 0.1, 0.1])

# obtain train/val/test dataloaders; they are torch.utils.data.DataLoader objects
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

Generating samples for readmission_prediction_mimic3_fn: 100%|██████████| 49993/49993 [00:00<00:00, 191070.25it/s]
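`split_by_patient` splits at the patient level, so all samples from one patient land in the same split and the test set contains only unseen patients. The idea can be sketched without PyHealth; `split_by_patient_sketch` and the toy sample dicts (with a `"patient_id"` key) are hypothetical and only illustrate the technique, not the library's code.

```python
import random
from typing import Dict, List, Sequence, Tuple


def split_by_patient_sketch(
    samples: List[Dict], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples so that each patient appears in exactly one split.

    Illustrative sketch only, not PyHealth's implementation; assumes each
    sample is a dict with a "patient_id" key.
    """
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError("expected three ratios that sum to 1")
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)  # shuffle patients, not individual samples
    n = len(patient_ids)
    cut1 = int(n * ratios[0])
    cut2 = int(n * (ratios[0] + ratios[1]))
    groups = [set(patient_ids[:cut1]), set(patient_ids[cut1:cut2]), set(patient_ids[cut2:])]
    train, val, test = ([s for s in samples if s["patient_id"] in g] for g in groups)
    return train, val, test


# Toy data: 10 hypothetical patients with 1 to 3 visits each.
samples = [
    {"patient_id": f"p{i}", "visit_id": f"v{i}_{j}", "label": (i + j) % 2}
    for i in range(10)
    for j in range(1 + i % 3)
]
train, val, test = split_by_patient_sketch(samples, [0.8, 0.1, 0.1])
print(len(train), len(val), len(test))  # sample counts depend on the seed
```

Splitting by visit instead could put visits of the same patient into both the training and test sets, which tends to make test scores optimistic.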


In [None]:
# use RETAIN model
from pyhealth.models import RETAIN

model = RETAIN(
    dataset=dataset,
    # check dataset.samples[0] for the available feature keys and the label key
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
)
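RETAIN builds, for each feature key, a context vector from the visit embeddings using two attention terms: a scalar weight per visit ($\alpha_t$, a softmax over time) and a per-dimension weight ($\beta_t$, a `tanh`). This is consistent with the layer shapes printed in the training output: each `RETAINLayer` has an `alpha_li` with 1 output and a `beta_li` with 128 outputs, and `fc` takes 256 = 2 × 128 inputs, one context vector per feature key. The pure-Python sketch below shows only the final combination step $c = \sum_t \alpha_t \, (\beta_t \odot v_t)$; the attention scores are passed in directly instead of being produced by the two GRUs, and `retain_context` is a hypothetical name.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def retain_context(
    v: List[List[float]], alpha_logits: List[float], beta_pre: List[List[float]]
) -> List[float]:
    """RETAIN-style context vector c = sum_t alpha_t * (beta_t * v_t), elementwise.

    v:            T visit embeddings, each of length D
    alpha_logits: T scalar visit scores (e.g. from a linear layer with 1 output)
    beta_pre:     T vectors of length D (per-dimension scores before tanh)
    """
    alpha = softmax(alpha_logits)  # visit weights that sum to 1 over time
    beta = [[math.tanh(b) for b in row] for row in beta_pre]  # weights in (-1, 1)
    dim = len(v[0])
    return [sum(alpha[t] * beta[t][d] * v[t][d] for t in range(len(v))) for d in range(dim)]


# Two visits with 2-dimensional embeddings.
v = [[1.0, 2.0], [3.0, 4.0]]
c = retain_context(v, alpha_logits=[0.0, 0.0], beta_pre=[[0.0, 100.0], [100.0, 0.0]])
print(c)  # [1.5, 1.0]
```

With equal visit scores each visit gets weight 0.5, and the saturated `tanh` values act like a mask that keeps dimension 0 from the second visit and dimension 1 from the first.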

### **How to use the Trainer**
- We first initialize the trainer with only the `model`, so the other arguments keep their defaults: no `metrics` list is passed (`Metrics: None` in the log) and the run uses the CPU (`Device: cpu`). To control logging or choose a device explicitly, set `enable_logging`, `output_path` (for example `"../output"`), and `device`.
- We then train `model` on `train_loader`, use `val_loader` for hold-out validation, and monitor `roc_auc` on the validation set to select the best model. Training runs for 50 epochs.
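The step counter in the training log (`step-55`, `step-110`, and so on) advances by the number of training batches per epoch. Assuming the data loaders keep PyTorch's default `drop_last=False`, an epoch has `ceil(N / batch_size)` batches, so the logged batch counts bound the split sizes. The helper names below are illustrative, not PyHealth functions.

```python
import math
from typing import Tuple


def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(n_samples / batch_size)  # the incomplete last batch is kept


def sample_count_range(n_batches: int, batch_size: int) -> Tuple[int, int]:
    """Smallest and largest N that yield n_batches batches when drop_last=False."""
    return (n_batches - 1) * batch_size + 1, n_batches * batch_size


print(sample_count_range(55, 32))   # training loader: (1729, 1760)
print(sample_count_range(7, 32))    # validation loader: (193, 224)
print(batches_per_epoch(1745, 32))  # 55
```

Here 55 training batches of size 32 imply between 1,729 and 1,760 training samples, and 50 epochs give the final `step-2750` (50 × 55) at the end of the log.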

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="roc_auc",
)

RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(2528, 128, padding_idx=0)
    (procedures): Embedding(817, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cpu

Training:
Batch size: 32
Optimizer: <class 't

Epoch 0 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-0, step-55 ---
loss: 0.4712
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 62.86it/s]
--- Eval epoch-0, step-55 ---
pr_auc: 0.1755
roc_auc: 0.6045
f1: 0.0000
loss: 0.3530
New best roc_auc score (0.6045) at epoch-0, step-55



Epoch 1 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-1, step-110 ---
loss: 0.3831
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 80.47it/s]
--- Eval epoch-1, step-110 ---
pr_auc: 0.1541
roc_auc: 0.6276
f1: 0.0000
loss: 0.3483
New best roc_auc score (0.6276) at epoch-1, step-110



Epoch 2 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-2, step-165 ---
loss: 0.3508
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 128.09it/s]
--- Eval epoch-2, step-165 ---
pr_auc: 0.1596
roc_auc: 0.6302
f1: 0.0000
loss: 0.3469
New best roc_auc score (0.6302) at epoch-2, step-165



Epoch 3 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-3, step-220 ---
loss: 0.2999
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 137.06it/s]
--- Eval epoch-3, step-220 ---
pr_auc: 0.1527
roc_auc: 0.6191
f1: 0.0000
loss: 0.3498



Epoch 4 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-4, step-275 ---
loss: 0.2745
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 147.28it/s]
--- Eval epoch-4, step-275 ---
pr_auc: 0.1679
roc_auc: 0.6084
f1: 0.0800
loss: 0.3562



Epoch 5 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-5, step-330 ---
loss: 0.2385
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 76.39it/s]
--- Eval epoch-5, step-330 ---
pr_auc: 0.1575
roc_auc: 0.6210
f1: 0.0000
loss: 0.3512



Epoch 6 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-6, step-385 ---
loss: 0.2162
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 74.77it/s]
--- Eval epoch-6, step-385 ---
pr_auc: 0.1328
roc_auc: 0.5875
f1: 0.0000
loss: 0.3606



Epoch 7 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-7, step-440 ---
loss: 0.1925
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 77.54it/s]
--- Eval epoch-7, step-440 ---
pr_auc: 0.1325
roc_auc: 0.5817
f1: 0.0000
loss: 0.3646



Epoch 8 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-8, step-495 ---
loss: 0.1742
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 99.16it/s]
--- Eval epoch-8, step-495 ---
pr_auc: 0.1284
roc_auc: 0.5778
f1: 0.0000
loss: 0.3698



Epoch 9 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-9, step-550 ---
loss: 0.1437
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 118.93it/s]
--- Eval epoch-9, step-550 ---
pr_auc: 0.1264
roc_auc: 0.5684
f1: 0.0000
loss: 0.3820



Epoch 10 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-10, step-605 ---
loss: 0.1325
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 121.81it/s]
--- Eval epoch-10, step-605 ---
pr_auc: 0.1181
roc_auc: 0.5588
f1: 0.0000
loss: 0.3930



Epoch 11 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-11, step-660 ---
loss: 0.1228
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 151.09it/s]
--- Eval epoch-11, step-660 ---
pr_auc: 0.1246
roc_auc: 0.5709
f1: 0.0000
loss: 0.3985



Epoch 12 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-12, step-715 ---
loss: 0.1092
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 142.53it/s]
--- Eval epoch-12, step-715 ---
pr_auc: 0.1390
roc_auc: 0.5909
f1: 0.0000
loss: 0.3833



Epoch 13 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-13, step-770 ---
loss: 0.0965
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 140.40it/s]
--- Eval epoch-13, step-770 ---
pr_auc: 0.1227
roc_auc: 0.5758
f1: 0.0000
loss: 0.3986



Epoch 14 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-14, step-825 ---
loss: 0.0881
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 123.27it/s]
--- Eval epoch-14, step-825 ---
pr_auc: 0.1192
roc_auc: 0.5799
f1: 0.0000
loss: 0.4152



Epoch 15 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-15, step-880 ---
loss: 0.0860
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 128.93it/s]
--- Eval epoch-15, step-880 ---
pr_auc: 0.1260
roc_auc: 0.5806
f1: 0.0000
loss: 0.4117



Epoch 16 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-16, step-935 ---
loss: 0.0759
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 145.28it/s]
--- Eval epoch-16, step-935 ---
pr_auc: 0.1266
roc_auc: 0.5762
f1: 0.0000
loss: 0.4243



Epoch 17 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-17, step-990 ---
loss: 0.0634
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 134.40it/s]
--- Eval epoch-17, step-990 ---
pr_auc: 0.1341
roc_auc: 0.5725
f1: 0.0000
loss: 0.4267



Epoch 18 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-18, step-1045 ---
loss: 0.0785
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 114.93it/s]
--- Eval epoch-18, step-1045 ---
pr_auc: 0.1328
roc_auc: 0.5840
f1: 0.0000
loss: 0.4380



Epoch 19 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-19, step-1100 ---
loss: 0.0661
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 135.71it/s]
--- Eval epoch-19, step-1100 ---
pr_auc: 0.1317
roc_auc: 0.5923
f1: 0.0000
loss: 0.4427



Epoch 20 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-20, step-1155 ---
loss: 0.0605
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 104.38it/s]
--- Eval epoch-20, step-1155 ---
pr_auc: 0.1347
roc_auc: 0.5810
f1: 0.0000
loss: 0.4665



Epoch 21 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-21, step-1210 ---
loss: 0.0529
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 134.14it/s]
--- Eval epoch-21, step-1210 ---
pr_auc: 0.1261
roc_auc: 0.5597
f1: 0.0000
loss: 0.4827



Epoch 22 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-22, step-1265 ---
loss: 0.0578
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 114.57it/s]
--- Eval epoch-22, step-1265 ---
pr_auc: 0.1176
roc_auc: 0.5425
f1: 0.0000
loss: 0.5167



Epoch 23 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-23, step-1320 ---
loss: 0.0541
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 119.46it/s]
--- Eval epoch-23, step-1320 ---
pr_auc: 0.1224
roc_auc: 0.5503
f1: 0.0000
loss: 0.5219



Epoch 24 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-24, step-1375 ---
loss: 0.0529
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 112.32it/s]
--- Eval epoch-24, step-1375 ---
pr_auc: 0.1187
roc_auc: 0.5436
f1: 0.0000
loss: 0.5185



Epoch 25 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-25, step-1430 ---
loss: 0.0528
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 123.80it/s]
--- Eval epoch-25, step-1430 ---
pr_auc: 0.1230
roc_auc: 0.5420
f1: 0.0000
loss: 0.5356



Epoch 26 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-26, step-1485 ---
loss: 0.0492
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 139.70it/s]
--- Eval epoch-26, step-1485 ---
pr_auc: 0.1212
roc_auc: 0.5549
f1: 0.0000
loss: 0.5208



Epoch 27 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-27, step-1540 ---
loss: 0.0543
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 133.72it/s]
--- Eval epoch-27, step-1540 ---
pr_auc: 0.1240
roc_auc: 0.5491
f1: 0.0000
loss: 0.5497



Epoch 28 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-28, step-1595 ---
loss: 0.0438
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 132.17it/s]
--- Eval epoch-28, step-1595 ---
pr_auc: 0.1177
roc_auc: 0.5523
f1: 0.0000
loss: 0.5423



Epoch 29 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-29, step-1650 ---
loss: 0.0472
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 140.17it/s]
--- Eval epoch-29, step-1650 ---
pr_auc: 0.1120
roc_auc: 0.5510
f1: 0.0000
loss: 0.5627



Epoch 30 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-30, step-1705 ---
loss: 0.0282
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 120.80it/s]
--- Eval epoch-30, step-1705 ---
pr_auc: 0.1168
roc_auc: 0.5514
f1: 0.0000
loss: 0.5622



Epoch 31 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-31, step-1760 ---
loss: 0.0363
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 99.27it/s]
--- Eval epoch-31, step-1760 ---
pr_auc: 0.1169
roc_auc: 0.5553
f1: 0.0000
loss: 0.5590



Epoch 32 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-32, step-1815 ---
loss: 0.0274
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 124.81it/s]
--- Eval epoch-32, step-1815 ---
pr_auc: 0.1195
roc_auc: 0.5592
f1: 0.0000
loss: 0.5840



Epoch 33 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-33, step-1870 ---
loss: 0.0293
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 122.86it/s]
--- Eval epoch-33, step-1870 ---
pr_auc: 0.1199
roc_auc: 0.5507
f1: 0.0000
loss: 0.6053



Epoch 34 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-34, step-1925 ---
loss: 0.0411
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 124.18it/s]
--- Eval epoch-34, step-1925 ---
pr_auc: 0.1219
roc_auc: 0.5553
f1: 0.0000
loss: 0.5838



Epoch 35 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-35, step-1980 ---
loss: 0.0401
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 133.27it/s]
--- Eval epoch-35, step-1980 ---
pr_auc: 0.1217
roc_auc: 0.5592
f1: 0.0000
loss: 0.6109



Epoch 36 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-36, step-2035 ---
loss: 0.0436
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 148.84it/s]
--- Eval epoch-36, step-2035 ---
pr_auc: 0.1218
roc_auc: 0.5640
f1: 0.0000
loss: 0.5927



Epoch 37 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-37, step-2090 ---
loss: 0.0280
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 148.14it/s]
--- Eval epoch-37, step-2090 ---
pr_auc: 0.1349
roc_auc: 0.5574
f1: 0.0000
loss: 0.6210



Epoch 38 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-38, step-2145 ---
loss: 0.0294
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 127.98it/s]
--- Eval epoch-38, step-2145 ---
pr_auc: 0.1331
roc_auc: 0.5732
f1: 0.1379
loss: 0.5991



Epoch 39 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-39, step-2200 ---
loss: 0.0401
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 124.16it/s]
--- Eval epoch-39, step-2200 ---
pr_auc: 0.1386
roc_auc: 0.5893
f1: 0.0645
loss: 0.6114



Epoch 40 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-40, step-2255 ---
loss: 0.0382
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 154.90it/s]
--- Eval epoch-40, step-2255 ---
pr_auc: 0.1324
roc_auc: 0.5893
f1: 0.0667
loss: 0.5719



Epoch 41 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-41, step-2310 ---
loss: 0.0255
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 123.04it/s]
--- Eval epoch-41, step-2310 ---
pr_auc: 0.1464
roc_auc: 0.6028
f1: 0.0690
loss: 0.5664



Epoch 42 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-42, step-2365 ---
loss: 0.0437
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 127.35it/s]
--- Eval epoch-42, step-2365 ---
pr_auc: 0.1460
roc_auc: 0.6017
f1: 0.0690
loss: 0.5929



Epoch 43 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-43, step-2420 ---
loss: 0.0317
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 101.15it/s]
--- Eval epoch-43, step-2420 ---
pr_auc: 0.1459
roc_auc: 0.6019
f1: 0.0667
loss: 0.5975



Epoch 44 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-44, step-2475 ---
loss: 0.0320
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 139.46it/s]
--- Eval epoch-44, step-2475 ---
pr_auc: 0.1498
roc_auc: 0.6019
f1: 0.0800
loss: 0.6473



Epoch 45 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-45, step-2530 ---
loss: 0.0297
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 111.15it/s]
--- Eval epoch-45, step-2530 ---
pr_auc: 0.1430
roc_auc: 0.5971
f1: 0.0625
loss: 0.6159



Epoch 46 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-46, step-2585 ---
loss: 0.0215
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 130.00it/s]
--- Eval epoch-46, step-2585 ---
pr_auc: 0.1388
roc_auc: 0.5964
f1: 0.0588
loss: 0.6202



Epoch 47 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-47, step-2640 ---
loss: 0.0273
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 124.30it/s]
--- Eval epoch-47, step-2640 ---
pr_auc: 0.1377
roc_auc: 0.5962
f1: 0.0645
loss: 0.6133



Epoch 48 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-48, step-2695 ---
loss: 0.0388
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 142.49it/s]
--- Eval epoch-48, step-2695 ---
pr_auc: 0.1456
roc_auc: 0.6019
f1: 0.0645
loss: 0.6024



Epoch 49 / 50:   0%|          | 0/55 [00:00<?, ?it/s]

--- Train epoch-49, step-2750 ---
loss: 0.0333
Evaluation: 100%|██████████| 7/7 [00:00<00:00, 139.44it/s]
--- Eval epoch-49, step-2750 ---
pr_auc: 0.1467
roc_auc: 0.6159
f1: 0.0741
loss: 0.6072
Loaded best model
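In the log, `f1` stays at 0.0000 for most epochs while `roc_auc` and `pr_auc` are nonzero. F1 is computed from hard predictions at a decision threshold, whereas the other two depend only on how the predicted probabilities rank the samples, and F1 is 0 whenever there are no true positives, for example when no predicted probability exceeds the threshold. A minimal sketch, assuming a 0.5 threshold as in the `y_prob > 0.5` line of the evaluation cell; `f1_at_threshold` is a hypothetical helper and the data are toy values.

```python
from typing import List


def f1_at_threshold(y_true: List[int], y_prob: List[float], threshold: float = 0.5) -> float:
    """F1 score of hard predictions obtained by thresholding probabilities."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is 0 (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Toy labels with few positives; every probability is below 0.5.
y_true = [0, 0, 0, 1, 0, 0, 1]
y_prob = [0.10, 0.20, 0.15, 0.40, 0.05, 0.35, 0.30]

print(f1_at_threshold(y_true, y_prob, 0.5))   # 0.0
print(f1_at_threshold(y_true, y_prob, 0.25))
```

Lowering the threshold (here to 0.25) yields a nonzero F1 without changing `roc_auc` or `pr_auc`, because those two metrics depend only on the ranking of the probabilities.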


### **Evaluation**

In [None]:
# evaluation option 1: use the trainer's built-in evaluation
result = trainer.evaluate(test_loader)
print('\n', result)

# evaluation option 2: use pyhealth.metrics
from pyhealth.metrics.binary import binary_metrics_fn
y_true, y_prob, loss = trainer.inference(test_loader)
result = binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc"])
print('\n', result)

# evaluation option 3: use sklearn.metrics
from sklearn.metrics import average_precision_score, roc_auc_score
# hard 0/1 labels at a 0.5 threshold, needed only for threshold-based metrics such as F1 (not used below)
y_pred = (y_prob > 0.5).astype('int')
print(
    '\n',
    'roc_auc', roc_auc_score(y_true, y_prob),
    'pr_auc:', average_precision_score(y_true, y_prob)
)

Evaluation: 100%|██████████| 7/7 [00:00<00:00, 131.55it/s]



 {'pr_auc': 0.1511461740187427, 'roc_auc': 0.5009398496240602, 'f1': 0.0, 'loss': 0.43073106024946484}


Evaluation: 100%|██████████| 7/7 [00:00<00:00, 125.87it/s]



 {'pr_auc': 0.1511461740187427, 'roc_auc': 0.5009398496240602}

 roc_auc 0.5009398496240602 pr_auc: 0.1511461740187427
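All three options report the same two numbers. To make their meaning concrete, here is a pure-Python sketch: ROC-AUC as the fraction of (positive, negative) pairs in which the positive gets the higher score (ties count one half), and average precision (the `pr_auc` above, computed by option 3 with `average_precision_score`) as $\sum_n (R_n - R_{n-1}) P_n$ over decreasing score thresholds. The function names are hypothetical, and the quadratic ROC-AUC loop is meant for illustration only. For reference, randomly ordered scores give a ROC-AUC near 0.5, so the test ROC-AUC of about 0.50 reported above is close to chance level.

```python
from typing import List


def roc_auc(y_true: List[int], y_score: List[float]) -> float:
    """Probability that a random positive is scored above a random negative."""
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))


def average_precision(y_true: List[int], y_score: List[float]) -> float:
    """AP = sum_n (R_n - R_{n-1}) * P_n over distinct thresholds, high to low."""
    pairs = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(y_true)
    tp = fp = 0
    prev_recall, ap, i = 0.0, 0.0, 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Samples tied at the same score enter together (one threshold).
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(y_true, y_score))            # 0.75
print(average_precision(y_true, y_score))  # about 0.8333
```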


If you find this tutorial useful, please star ⭐ (and fork or watch) the repository at https://github.com/sunlabuiuc/PyHealth.

Thanks very much for your support!