Google Cloud authentication and data download

In [1]:
# ! gcloud auth login
# ! gcloud auth application-default login
# ! gcloud config set project dl4h-final-project-383605

In [3]:
# ! pip install --upgrade google-api-python-client google-cloud-storage
from google.cloud import storage

# Replace these values with your project and bucket as needed
project_id = "dl4h-final-project-383605"
mimic3_bucket = "mimiciii-1.4.physionet.org"

storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(mimic3_bucket)
data_folder = "./data"
for blob in bucket.list_blobs():
  if "CHARTEVENTS" in blob.name:
    continue
  blob.download_to_filename(f"{data_folder}/{blob.name}")

# Extract all the files
! gunzip {data_folder}/*.gz

gzip: *.gz: No such file or directory


In [5]:
! gunzip {data_folder}/*.gz

Library imports and data preparation

In [1]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit
import pandas as pd
from pyhealth.datasets import split_by_patient, get_dataloader, BaseDataset
from pyhealth.models import BaseModel
from pyhealth.trainer import Trainer
from pyhealth.metrics.binary import binary_metrics_fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import List, Dict, Optional
from enum import Enum
from functools import reduce
from operator import mul

# Set this to the directory with all MIMIC-3 dataset files
data_root = "./data"

  from tqdm.autonotebook import trange


In [2]:
# Load the dataset
# Only the diagnosis and procedure tables are loaded from data_root.
# dev=True is what the statistics output labels "base dataset (dev=True)",
# reporting 1000 patients.

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=True
)

In [3]:
# Print dataset statistics
# stat() prints the summary and also returns it as a string; because the call is
# the last expression of the cell, the returned string is echoed a second time
# in the recorded output.

mimic3_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: MIMIC3Dataset
	- Number of patients: 1000
	- Number of visits: 1295
	- Number of visits per patient: 1.2950
	- Number of events per visit in DIAGNOSES_ICD: 9.3544
	- Number of events per visit in PROCEDURES_ICD: 4.3351



'\nStatistics of base dataset (dev=True):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 1000\n\t- Number of visits: 1295\n\t- Number of visits per patient: 1.2950\n\t- Number of events per visit in DIAGNOSES_ICD: 9.3544\n\t- Number of events per visit in PROCEDURES_ICD: 4.3351\n'

In [4]:
# Find all diagnosis codes
# Find all procedure codes
# Remove diagnosis codes with fewer than 5 occurrences in the dataset

all_diag_codes = []
all_proc_codes = []
for patient in mimic3_ds.patients.values():
  for i in range(len(patient)):
    visit: Visit = patient[i]
    all_diag_codes.extend(visit.get_code_list(table="DIAGNOSES_ICD"))
    all_proc_codes.extend(visit.get_code_list(table="PROCEDURES_ICD"))

# Count every diagnosis code and keep those seen at least 5 times (count > 4).
codes = pd.Series(all_diag_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts > 4].index.values
num_unique_diag_codes = len(filtered_diag_codes)

# sorted() instead of list(set(...)): set iteration order for strings can differ
# between kernel sessions, which would give every run a different code -> index
# assignment downstream.
unique_proc_codes = sorted(set(all_proc_codes))
num_unique_proc_codes = len(unique_proc_codes)

In [5]:
# Map every procedure code / retained diagnosis code to a contiguous integer
# index (0 .. n-1), following the order of the corresponding code list.
# enumerate() replaces the hand-maintained counter, so no generic `index`
# variable is left behind in the notebook namespace.
proc_code_to_index_map = {
    code: position for position, code in enumerate(unique_proc_codes)
}
diag_code_to_index_map = {
    code: position for position, code in enumerate(filtered_diag_codes)
}

In [6]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

def flatten(l: List):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list.

    Only one level of nesting is removed; the order of items is preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def patient_level_readmission_prediction(patient, time_window=30):
    """Build at most one readmission sample for a patient.

    The label is derived from the gap between the patient's last two visits
    (ordered by ``encounter_time``); the features come from every visit except
    the last one.

    Args:
        patient: a <pyhealth.data.Patient> object (anything that supports
            ``len()``, iteration over visits and ``patient_id``).
        time_window: number of days; the label is 1 when the last visit starts
            fewer than ``time_window`` days after the second-to-last visit.

    Returns:
        A list with a single sample dict (keys ``patient_id``, ``visit_id``,
        ``conditions``, ``procedures``, ``intervals``, ``label``), or an empty
        list when the patient has fewer than two visits or no earlier visit has
        both a retained diagnosis code and a procedure code.
    """
    # With fewer than two visits there is nothing to predict from / nothing to
    # label. (`< 2` also covers a patient without any visit, which previously
    # raised IndexError.)
    if len(patient) < 2:
        return []

    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)
    first_visit = sorted_visits[0]
    second_to_last_visit = sorted_visits[-2]
    last_visit = sorted_visits[-1]

    # step 1: define label
    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff < time_window else 0

    # step 2: obtain features
    # A set gives constant-time membership checks; `filtered_diag_codes` is the
    # notebook-level collection of retained diagnosis codes.
    allowed_diag_codes = set(filtered_diag_codes)

    visits_conditions = []
    visits_procedures = []
    visits_intervals = []
    for visit in sorted_visits[:-1]:  # the last visit only provides the label
        conditions = [
            c for c in visit.get_code_list(table="DIAGNOSES_ICD")
            if c in allowed_diag_codes
        ]
        procedures = visit.get_code_list(table="PROCEDURES_ICD")

        # Visits lacking either kind of code are dropped from the sequence.
        if len(conditions) == 0 or len(procedures) == 0:
            continue

        days_since_first_visit = (visit.encounter_time - first_visit.encounter_time).days
        visits_conditions.append(conditions)
        visits_procedures.append(procedures)
        visits_intervals.append([days_since_first_visit])

    # step 3: exclusion criteria
    # Every kept visit has both code types, so "no kept visit" is the only way
    # to end up without conditions or procedures.
    if not visits_conditions:
        return []

    # step 4: assemble the sample
    return [
        {
            "patient_id": patient.patient_id,
            # The sample is identified by the visit that defines the label.
            "visit_id": last_visit.visit_id,
            "conditions": visits_conditions,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "label": readmission_label,
        }
    ]

In [7]:
# Create the task datasets
# patient_level_readmission_prediction is passed as the task function; the progress
# bar in the output shows it being run over the 1000 patients of the dev dataset.
# NOTE(review): presumably patients for which the task function returns [] contribute
# no samples - confirm against the PyHealth set_task documentation.
mimic3_dxtx = mimic3_ds.set_task(task_fn=patient_level_readmission_prediction)

Generating samples for patient_level_readmission_prediction: 100%|██████████| 1000/1000 [00:00<00:00, 30153.16it/s]


In [8]:
# Get the unique visit intervals
# Every sample stores, per kept visit, a one-element list with the number of days
# since the patient's first visit; flatten() turns that into a plain list of days.
all_intervals = [
    days
    for sample in mimic3_dxtx
    for days in flatten(sample['intervals'])
]

# sorted() makes the vocabulary order explicit instead of relying on set
# iteration order, so index i always refers to the i-th smallest interval.
unique_intervals = sorted(set(all_intervals))
num_unique_intervals = len(unique_intervals)

# Interval (days) -> contiguous index, used to look up interval embeddings.
visit_interval_to_index_map = {
    visit_interval: position
    for position, visit_interval in enumerate(unique_intervals)
}

In [9]:
def collate_codes(batch_codes, i_patient, patient, feature_key, code_to_index_map, mask_tensor):
    """Write one patient's code indices and code mask into the batch tensors in place.

    Args:
        batch_codes: tensor indexed as [patient][visit][code]; receives the indices.
        i_patient: position of the patient inside the batch.
        patient: sample dict; ``patient[feature_key]`` is a list of visits, each a
            list of codes.
        feature_key: key of the code lists to collate.
        code_to_index_map: mapping from code to integer index.
        mask_tensor: tensor indexed like ``batch_codes``; set to 1 at every position
            that holds a real code.

    Returns:
        None. ``batch_codes`` and ``mask_tensor`` are modified in place.
    """
    patient_codes = batch_codes[i_patient]
    patient_mask = mask_tensor[i_patient]

    for visit_pos, codes_in_visit in enumerate(patient[feature_key]):
        code_indices = [code_to_index_map[code] for code in codes_in_visit]
        for code_pos, code_index in enumerate(code_indices):
            patient_codes[visit_pos][code_pos] = code_index

        # Mark the leading len(codes_in_visit) slots of this visit as real codes.
        patient_mask[visit_pos][:len(codes_in_visit)] = 1

def collate_fn(batch):
    """Pad a list of patient samples into dense index tensors plus masks.

    Args:
        batch: list of sample dicts with keys ``conditions``, ``procedures``
            (lists of visits, each a list of codes), ``intervals`` (list of
            one-element lists of days) and ``label``.

    Returns:
        A tuple of long tensors, in this order:
        ``(procedures, conditions, intervals, visits_mask, procedures_mask,
        conditions_mask, labels)`` where, with B = batch size, V = max visits,
        P / D = max procedure / diagnosis codes per visit:
        procedures and procedures_mask are (B, V, P); conditions and
        conditions_mask are (B, V, D); intervals and visits_mask are (B, V);
        labels is (B,). Masks hold 1 at real entries and 0 at padding; padded
        index slots hold 0.
    """
    batch_size = len(batch)

    # Padded sizes; default=0 keeps an empty batch (or empty visit lists) valid.
    max_num_visits = max((len(p['conditions']) for p in batch), default=0)
    max_num_diagnosis_codes_per_visit = max(
        (len(visit) for p in batch for visit in p['conditions']), default=0
    )
    max_num_procedure_codes_per_visit = max(
        (len(visit) for p in batch for visit in p["procedures"]), default=0
    )

    batch_procedures = torch.zeros(batch_size, max_num_visits, max_num_procedure_codes_per_visit).long()
    batch_conditions = torch.zeros(batch_size, max_num_visits, max_num_diagnosis_codes_per_visit).long()
    batch_intervals = torch.zeros(batch_size, max_num_visits).long()
    batch_labels = torch.zeros(batch_size).long()

    batch_visits_masks = torch.zeros(batch_size, max_num_visits).long()
    batch_procedures_masks = torch.zeros(batch_size, max_num_visits, max_num_procedure_codes_per_visit).long()
    batch_conditions_masks = torch.zeros(batch_size, max_num_visits, max_num_diagnosis_codes_per_visit).long()

    for i_patient, patient in enumerate(batch):
        # Collate diagnosis and procedure codes (fills the tensors in place)
        collate_codes(batch_procedures, i_patient, patient, "procedures", proc_code_to_index_map, batch_procedures_masks)
        collate_codes(batch_conditions, i_patient, patient, "conditions", diag_code_to_index_map, batch_conditions_masks)

        # One interval (days since first visit) per kept visit
        visit_intervals = [days for visit in patient["intervals"] for days in visit]
        num_visits = len(visit_intervals)

        # Set the visits mask for the patient
        batch_visits_masks[i_patient, :num_visits] = 1

        # Store the vocabulary index of every visit interval. Previously the
        # indices were computed and then discarded, and the constant 1 was
        # written instead, so all visits looked identical to the model.
        if num_visits:
            batch_intervals[i_patient, :num_visits] = torch.tensor(
                [visit_interval_to_index_map[interval] for interval in visit_intervals],
                dtype=torch.long,
            )

        # Set the label
        batch_labels[i_patient] = patient["label"]

    return (
        batch_procedures,
        batch_conditions,
        batch_intervals,
        batch_visits_masks,
        batch_procedures_masks,
        batch_conditions_masks,
        batch_labels
    )

BATCH_SIZE = 2
# 80% / 10% / 10% split; all samples of one patient stay in the same subset.
train, val, test = split_by_patient(mimic3_dxtx, [0.8, 0.1, 0.1])

# Build the train/val/test loaders; they are torch.utils.data.DataLoader objects
# that use collate_fn to pad each batch. Only the training loader shuffles.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Sanity check: show the padded procedure indices of one training batch.
# A named variable is used because assigning to `_` would shadow IPython's
# "last output" variable for the rest of the session.
sample_batch = next(iter(train_loader))
print(sample_batch[0])

tensor([[[258, 499,   0,   0,   0,   0,   0,   0,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0]],

        [[118, 174, 520, 383, 200,  70, 126, 317, 274, 143],
         [448, 480, 381, 274,   0,   0,   0,   0,   0,   0]]])


In [23]:
# Define the models
# Extreme-magnitude constants shared by the model code in this cell.
# NOTE(review): presumably intended as finite stand-ins for +/- infinity when
# masking attention scores - confirm against their call sites.
VERY_BIG_NUMBER = 1e30
VERY_SMALL_NUMBER = 1e-30
VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER

class MaskDirection(Enum):
    """Selects which temporal attention mask MaskEnc builds in _make_temporal_mask."""
    # Mask derived from a lower-triangular (torch.tril) pattern.
    FORWARD = 'forward'
    # Mask derived from an upper-triangular (torch.triu) pattern.
    BACKWARD = 'backward'
    # Mask derived from the main diagonal.
    DIAGONAL = 'diagonal'
    # No attention mask is applied.
    NONE = 'none'

# class MaskedLayerNorm(nn.Module):
#     def __init__(self, normalized_shape: int):
#         super().__init__()
#         self.gamma = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.beta = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.eps = 1e-5
#         self.normalized_shape = normalized_shape
#
#     def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
#         print(x.shape)
#         print(key_padding_mask.shape)
#         n = torch.sum(torch.ones_like(x) * (~key_padding_mask).int().unsqueeze(-1).expand(-1, -1, self.normalized_shape), dim=-1)
#         print(n.shape)
#         print(n)
#
#         # expected_val = torch.nanmean(x, dim=1)
#         # variance = torch.nansum(())


class AttentionPooling(nn.Module):
    """Collapse dimension 1 of the input into a weighted sum.

    A two-layer MLP scores every element; the scores are soft-maxed over
    dimension 1 and used as weights for a NaN-ignoring sum of the input over
    that same dimension.
    """

    def __init__(self, embedding_size: int):
        super().__init__()
        hidden_layer = nn.Linear(embedding_size, embedding_size)
        score_layer = nn.Linear(embedding_size, embedding_size)
        self.fc = nn.Sequential(hidden_layer, nn.ReLU(), score_layer)

    def forward(self, input):
        """Return the attention-weighted sum of ``input`` over dimension 1."""
        scores = self.fc(input)
        weights = scores.softmax(dim=1)
        # nansum: NaN products are treated as zero instead of poisoning the sum.
        pooled = (weights * input).nansum(dim=1)
        return pooled


# todo: add temporal encoding
# todo: implement layer normalization
class MaskEnc(nn.Module):
    """Masked multi-head self-attention block followed by attention pooling.

    Pipeline: self-attention over ``x`` (with an optional temporal mask and an
    optional key padding mask) -> dropout -> two-layer feed-forward ->
    AttentionPooling, which removes dimension 1.

    Args:
        embedding_dim: size of the last dimension of the input and the output.
        num_heads: number of attention heads.
        dropout: dropout probability used in attention, after attention and in
            the feed-forward block.
        batch_first: forwarded to nn.MultiheadAttention.
        temporal_mask_direction: which temporal mask to apply (see
            ``_make_temporal_mask``).
    """

    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            batch_first: bool = True,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=batch_first
            )
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )

        self.attention_pooling = AttentionPooling(embedding_dim)

        # TODO: residual connections + masked layer normalization are not implemented yet.

    def forward(
            self,
            x: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None
    ):
        """Encode ``x`` and pool it over dimension 1.

        Args:
            x: input embeddings; dimension 1 is used as the sequence length
                (this matches ``batch_first=True``).
            key_padding_mask: optional mask forwarded to the attention layer;
                True marks key positions to ignore.

        Returns:
            The pooled tensor produced by AttentionPooling.
        """
        # Build the mask on the input's device/dtype so the module also works off-CPU.
        attn_mask = self._make_temporal_mask(x.shape[1], device=x.device, dtype=x.dtype)

        # nn.MultiheadAttention deprecates mixing a float attn_mask with a bool
        # key_padding_mask; convert the padding mask to its additive equivalent
        # (-inf at ignored keys) when a float temporal mask is in use.
        if (
            attn_mask is not None
            and key_padding_mask is not None
            and key_padding_mask.dtype == torch.bool
        ):
            key_padding_mask = torch.zeros_like(
                key_padding_mask, dtype=x.dtype
            ).masked_fill(key_padding_mask, float("-inf"))

        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        attn_output = self.dropout(attn_output)
        out = self.fc(attn_output)
        out = self.attention_pooling(out)
        return out

    def _make_temporal_mask(
            self,
            n: int,
            device: Optional[torch.device] = None,
            dtype: torch.dtype = torch.float32,
    ) -> Optional[torch.Tensor]:
        """Build the additive (n, n) attention mask for the configured direction.

        A float ``attn_mask`` is *added* to the attention scores, so a 0/1
        pattern (as returned previously) does not mask anything. The pattern is
        therefore converted to an additive mask: 0 where attention is allowed and
        a very large negative value where it is blocked. A finite value is used
        instead of -inf so that a row whose positions are all blocked yields
        uniform weights rather than NaN.

        NOTE(review): entries that were 1 in the original pattern are treated as
        *blocked* (FORWARD blocks the lower triangle incl. the diagonal, BACKWARD
        the upper triangle incl. the diagonal, DIAGONAL blocks self-attention).
        Confirm this is the intended convention.

        Args:
            n: sequence length.
            device: device for the mask (defaults to PyTorch's default device).
            dtype: floating dtype of the mask.

        Returns:
            ``None`` for MaskDirection.NONE, otherwise an (n, n) float tensor.

        Raises:
            ValueError: if ``temporal_mask_direction`` is not a MaskDirection member.
        """
        direction = self.temporal_mask_direction
        if direction == MaskDirection.NONE:
            return None

        ones = torch.ones(n, n, device=device)
        if direction == MaskDirection.FORWARD:
            blocked = torch.tril(ones) > 0
        elif direction == MaskDirection.BACKWARD:
            blocked = torch.triu(ones) > 0
        elif direction == MaskDirection.DIAGONAL:
            blocked = torch.eye(n, device=device) > 0
        else:
            raise ValueError(f"Unsupported temporal mask direction: {direction!r}")

        # Clamp to the dtype's range so reduced-precision dtypes do not overflow to -inf.
        blocked_score = max(VERY_NEGATIVE_NUMBER, torch.finfo(dtype).min)
        return torch.zeros(n, n, dtype=dtype, device=device).masked_fill(blocked, blocked_score)

class BiteNet(nn.Module):
    """Two-level attention network over visits and the codes inside each visit.

    Codes of every visit are embedded and pooled by a code-level MaskEnc into
    one vector per visit; the visit vectors are then encoded by a forward and a
    backward MaskEnc, whose pooled outputs are concatenated and fed to a small
    feed-forward head.

    Args:
        embedding_dim: size of all embeddings and hidden representations.
        output_dim: size of the final output per patient.
        num_heads: attention heads in every MaskEnc block.
        dropout: dropout probability in every MaskEnc block.
        batch_first: forwarded to every MaskEnc block.
        use_procedure_codes: when True, procedure codes are appended to the
            diagnosis codes of each visit.
        num_diag_codes / num_proc_codes / num_intervals: vocabulary sizes of the
            three embedding tables (defaults are the notebook-level counts at
            the time this class is defined).
    """

    def __init__(
            self,
            embedding_dim: int = 128,
            output_dim: int = 1,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            use_procedure_codes: bool = True,
            num_diag_codes: int = num_unique_diag_codes,
            num_proc_codes: int = num_unique_proc_codes,
            num_intervals: int = num_unique_intervals,
    ):
        super().__init__()

        self.use_procedure_codes = use_procedure_codes
        self.embedding_dim = embedding_dim

        # Embedding tables (the interval table is created but not used in forward yet).
        self.diag_emb = nn.Embedding(num_diag_codes, embedding_dim)
        self.proc_emb = nn.Embedding(num_proc_codes, embedding_dim)
        self.interval_emb = nn.Embedding(num_intervals, embedding_dim)

        # All encoder blocks share these settings and differ only in mask direction.
        shared_block_kwargs = dict(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=batch_first,
        )
        self.code_level_attn = MaskEnc(
            temporal_mask_direction=MaskDirection.DIAGONAL, **shared_block_kwargs
        )
        self.visit_level_attn_forward = MaskEnc(
            temporal_mask_direction=MaskDirection.FORWARD, **shared_block_kwargs
        )
        self.visit_level_attn_backward = MaskEnc(
            temporal_mask_direction=MaskDirection.BACKWARD, **shared_block_kwargs
        )

        # Head applied to the concatenated forward/backward visit encodings.
        self.fc = nn.Sequential(
            nn.Linear(2*embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_dim)
        )

    def flatten(self, x: torch.Tensor, keep: int):
        """Merge all but the last ``keep`` dimensions of ``x`` into one leading dimension."""
        dims = list(x.size())
        split = len(dims) - keep
        merged = reduce(mul, [dims[i] for i in range(split)])
        kept = [dims[i] for i in range(split, len(dims))]
        return x.reshape([merged] + kept)

    def reshape(self, v, ref, embedding_dim):
        """Reshape ``v`` to (ref.shape[0], ref.shape[1], embedding_dim)."""
        target_shape = [ref.shape[0], ref.shape[1], embedding_dim]
        return torch.reshape(v, target_shape)

    def forward(
            self,
            procedures: torch.Tensor,
            conditions: torch.Tensor,
            intervals: torch.Tensor,
            visits_mask: torch.Tensor,
            procedures_mask: torch.Tensor,
            conditions_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute one output vector per patient.

        Inputs are assumed to follow the layout produced by ``collate_fn``:
        code tensors and code masks are (batch, visits, codes), ``visits_mask``
        is (batch, visits), and masks hold 1 at real entries. ``intervals`` is
        accepted but currently unused.
        """
        # Embed the codes and gather the matching masks along the code axis.
        code_embeddings = [self.diag_emb(conditions)]
        code_masks = [conditions_mask]
        if self.use_procedure_codes:
            code_embeddings.append(self.proc_emb(procedures))
            code_masks.append(procedures_mask)

        if len(code_embeddings) > 1:
            embedded_codes = torch.cat(code_embeddings, dim=2)
            codes_mask = torch.cat(code_masks, dim=-1)
        else:
            embedded_codes = code_embeddings[0]
            codes_mask = code_masks[0]

        # Incoming masks mark real entries; attention expects True at padding.
        codes_padding = torch.logical_not(codes_mask.bool())
        visits_padding = torch.logical_not(visits_mask.bool())

        # Treat every visit as its own sequence of codes:
        # 4-D embeddings -> 3-D, 3-D mask -> 2-D.
        codes_per_visit = self.flatten(embedded_codes, 2)
        padding_per_visit = self.flatten(codes_padding, 1)

        # Code level: one pooled vector per visit, regrouped per patient.
        visit_vectors = self.code_level_attn(codes_per_visit, padding_per_visit)
        visit_vectors = self.reshape(visit_vectors, embedded_codes, self.embedding_dim)

        # Visit level: forward and backward encodings, concatenated.
        encoded_forward = self.visit_level_attn_forward(visit_vectors, visits_padding)
        encoded_backward = self.visit_level_attn_backward(visit_vectors, visits_padding)
        bidirectional = torch.cat([encoded_forward, encoded_backward], dim=-1)

        return self.fc(bidirectional)

class PyHealthBiteNet(BaseModel):
    """PyHealth BaseModel wrapper around BiteNet.

    Args:
        dataset: task dataset the model is built for.
        feature_keys: sample keys used as input features.
        label_key: sample key holding the label.
        mode: PyHealth task mode (e.g. "binary").
        use_interval_emb: accepted for configuration; not used by the code below.
        use_procedure_codes: forwarded to BiteNet.
        embedding_dim: embedding / hidden size.
        num_heads: attention heads per BiteNet encoder block.
        dropout: dropout probability inside BiteNet.
        batch_first: forwarded to BiteNet.
        **kwargs: accepted and ignored.
    """

    def __init__(
            self,
            dataset: BaseDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            use_interval_emb: bool = True,
            use_procedure_codes: bool = True,
            embedding_dim: int = 128,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # self.add_feature_transform_layer will create a transformation layer for each feature
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            self.add_feature_transform_layer(
                feature_key, input_info
            )

        # final output layer
        output_size = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim = embedding_dim,
            output_dim = output_size,
            num_heads = num_heads,
            dropout = dropout,
            batch_first = batch_first,
            # Previously accepted by this constructor but never passed on.
            use_procedure_codes = use_procedure_codes,
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Tokenize and embed every feature, run BiteNet and compute the loss.

        Args:
            **kwargs: one entry per feature key plus the label key.

        Returns:
            Dict with ``loss``, ``y_prob`` and ``y_true``.
        """
        patient_emb = []
        patient_mask = []
        for feature_key in self.feature_keys:
            tokenizer = self.feat_tokenizers[feature_key]
            feature_vals = kwargs[feature_key]

            x = tokenizer.batch_encode_2d(feature_vals, padding=True, truncation=False)
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            pad_idx = tokenizer.vocabulary("<pad>")

            # True at real tokens, False at padding
            mask = (x != pad_idx)

            embeds = self.embeddings[feature_key](x)
            patient_emb.append(embeds)
            patient_mask.append(mask)

        # Concatenate the per-feature embeddings / masks along dimension 1;
        # the mask is inverted so that True marks padding.
        patient_emb = torch.cat(patient_emb, dim=1)
        patient_mask = ~(torch.cat(patient_mask, dim=1))

        # NOTE(review): BiteNet.forward is declared with six tensor arguments
        # (procedures, conditions, intervals, visits_mask, procedures_mask,
        # conditions_mask); this two-argument call does not match that signature
        # and needs to be reconciled before this wrapper can run end to end.
        logits = self.bite_net(patient_emb, patient_mask)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

# Build the wrapper model on the task dataset using all three feature keys.
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ['procedures', 'conditions', 'intervals'],
    label_key = "label",
    mode = "binary",
)

# NOTE(review): train_loader is built with collate_fn, which returns a tuple of
# tensors, so the ** unpacking below fails with the "must be a mapping, not tuple"
# TypeError recorded in this cell's output.
data = next(iter(train_loader))
model_dxtx(**data)

# model = BiteNet(
#     embedding_dim = 4,
#     output_dim = 1,
#     num_heads = 4,
#     dropout = 0
# )
# procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks, labels = next(iter(train_loader))
# model(procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks)

TypeError: PyHealthBiteNet object argument after ** must be a mapping, not tuple

In [143]:
# Variant using diagnosis and procedure codes as features.
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "procedures"],
    label_key = "label",
    mode = "binary",
)

# Variant using diagnosis codes only.
model_dx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions"],
    label_key = "label",
    mode = "binary",
)

In [19]:
# Run the diagnosis-only model on a single training batch.
# NOTE(review): the recorded output lists 32 probabilities while BATCH_SIZE is 2 in
# this notebook, so it appears to come from an earlier kernel state - re-run to confirm.
data = next(iter(train_loader))
model_dx(**data)

{'loss': tensor(0.7363, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'y_prob': tensor([[0.4940],
         [0.5260],
         [0.5713],
         [0.6283],
         [0.4630],
         [0.4646],
         [0.5993],
         [0.5617],
         [0.5722],
         [0.5628],
         [0.5436],
         [0.4776],
         [0.6207],
         [0.5105],
         [0.6244],
         [0.4756],
         [0.5805],
         [0.5534],
         [0.6040],
         [0.5931],
         [0.5287],
         [0.5496],
         [0.5750],
         [0.4966],
         [0.5697],
         [0.4620],
         [0.5815],
         [0.5287],
         [0.6660],
         [0.5615],
         [0.5513],
         [0.5093]], grad_fn=<SigmoidBackward0>),
 'y_true': tensor([[0.],
         [0.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]

In [20]:
# Train the diagnosis + procedure model for 5 epochs, evaluating on val_loader
# and tracking pr_auc as the monitored metric.
# NOTE(review): in the recorded run roc_auc stays at 0.5000 and pr_auc/f1 do not
# change between epochs, and the run ends with a KeyboardInterrupt.
trainer_dxtx = Trainer(model=model_dxtx)
trainer_dxtx.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="pr_auc",
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (conditions): Embedding(3329, 128, padding_idx=0)
    (procedures): Embedding(1333, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (code_level_attn): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (fc): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0, inplace=False)
    )
    (visit_level_attn_forward): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (fc): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0, inplace=False)
    )
    (visit_level_attn_backward): MaskEnc(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, o

Epoch 0 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-0, step-268 ---
loss: 0.6895
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 123.56it/s]
--- Eval epoch-0, step-268 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6888
New best pr_auc score (0.5514) at epoch-0, step-268



Epoch 1 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-1, step-536 ---
loss: 0.6861
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 122.29it/s]
--- Eval epoch-1, step-536 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6882



Epoch 2 / 5:   0%|          | 0/268 [00:00<?, ?it/s]

--- Train epoch-2, step-804 ---
loss: 0.6853
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 35/35 [00:00<00:00, 122.83it/s]
--- Eval epoch-2, step-804 ---
pr_auc: 0.5514
roc_auc: 0.5000
f1: 0.7109
loss: 0.6883



Epoch 3 / 5:   0%|          | 0/268 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [66]:
# option 1: use the trainer's built-in evaluation on the test loader
# (the recorded output is a dict with pr_auc, roc_auc, f1 and loss)
score_dxtx = trainer_dxtx.evaluate(test_loader)
print (score_dxtx)

# option 2: run inference, then compute the selected metrics with
# pyhealth.metrics; the returned dict is displayed as the cell's result
y_true_dxtx, y_prob_dxtx, loss_dxtx = trainer_dxtx.inference(test_loader)
binary_metrics_fn(y_true_dxtx, y_prob_dxtx, metrics=["pr_auc"])

Evaluation: 100%|██████████| 33/33 [00:00<00:00, 367.20it/s]


{'pr_auc': 0.726784508178901, 'roc_auc': 0.685080036129632, 'f1': 0.7090909090909092, 'loss': 0.6331828191424861}


Evaluation: 100%|██████████| 33/33 [00:00<00:00, 363.51it/s]


{'pr_auc': 0.726784508178901}