Library imports and data preparation

In [21]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset, SampleDataset
from pyhealth.data import Patient, Visit, Event
import pandas as pd
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import BaseModel, RNN, RETAIN, Deepr
from pyhealth.trainer import Trainer
from pyhealth.medcode import InnerMap
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional
from enum import Enum
from functools import reduce
from operator import mul
from tqdm import tqdm
import numpy as np

# Directory that holds the MIMIC-3 dataset files; change this to your local copy.
# It is passed as `root` when the MIMIC3Dataset is constructed.
data_root = "data"

In [2]:
# Load the dataset
# Builds the PyHealth MIMIC-3 dataset from `data_root`, reading the diagnosis,
# procedure and prescription tables.
# NOTE(review): dev=False presumably loads the full dataset rather than a small
# development subset -- confirm against the PyHealth documentation.

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        dev=False
)

In [3]:
# Print dataset statistics.
# stat() prints the report and also returns it as a string; the trailing
# semicolon stops Jupyter from echoing that string as a second copy of the report.

mimic3_ds.stat();


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711
	- Number of events per visit in PRESCRIPTIONS: 70.4013



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n\t- Number of events per visit in PRESCRIPTIONS: 70.4013\n'

In [4]:
# Collect every diagnosis code recorded on any visit of any patient, then keep
# only the codes that occur at least 5 times in the whole dataset.

all_diagnosis_codes = []
for patient_id, patient in mimic3_ds.patients.items():
    for visit in patient:
        all_diagnosis_codes.extend(visit.get_code_list(table="DIAGNOSES_ICD"))

codes = pd.Series(all_diagnosis_codes)
diag_code_counts = codes.value_counts()  # code -> number of occurrences
filtered_diag_codes = diag_code_counts[diag_code_counts >= 5].index.values
num_unique_diag_codes = len(filtered_diag_codes)

In [5]:
MIN_NUM_VISITS_PER_PATIENT = 2

# Filter the dataset to the requirements specified in the paper:
#   * drop diagnosis events whose code is below the occurrence cutoff,
#   * drop visits that are left without any diagnosis,
#   * keep only patients with at least MIN_NUM_VISITS_PER_PATIENT remaining visits.

# Hash-based membership test; checking `code in <numpy array>` would rescan the
# whole array for every single event.
kept_diag_codes = set(filtered_diag_codes)

filtered_patients = {}
for patient_id, patient in tqdm(mimic3_ds.patients.items()):

    filtered_patient: Patient = Patient(
        patient_id=patient.patient_id,
        birth_datetime=patient.birth_datetime,
        death_datetime=patient.death_datetime,
        gender=patient.gender,
        ethnicity=patient.ethnicity
    )

    for visit in patient:
        if len(visit.get_code_list("DIAGNOSES_ICD")) == 0:
            continue  # Don't include visits with no diagnoses

        # Build a new list instead of popping from the visit's own event list,
        # so the source visit object is not modified by this cell.
        diagnosis_events = [
            event
            for event in visit.event_list_dict["DIAGNOSES_ICD"]
            if event.code in kept_diag_codes
        ]
        if len(diagnosis_events) == 0:
            continue  # Every diagnosis was below the cutoff: drop the visit

        filtered_visit: Visit = Visit(
            visit_id=visit.visit_id,
            patient_id=visit.patient_id,
            encounter_time=visit.encounter_time,
            discharge_time=visit.discharge_time,
            discharge_status=visit.discharge_status
        )
        filtered_visit.set_event_list("DIAGNOSES_ICD", diagnosis_events)

        # Procedures and prescriptions are copied over unfiltered when present
        for table in ("PROCEDURES_ICD", "PRESCRIPTIONS"):
            if len(visit.get_code_list(table)) > 0:
                filtered_visit.set_event_list(table, visit.event_list_dict[table])

        filtered_patient.add_visit(filtered_visit)

    if len(filtered_patient.visits) >= MIN_NUM_VISITS_PER_PATIENT:
        filtered_patients[patient_id] = filtered_patient


100%|██████████| 46520/46520 [01:17<00:00, 597.46it/s]


In [6]:
# Swap the filtered cohort into the dataset object so that every later cell
# works on it. NOTE: this replaces `mimic3_ds.patients` in place, so cells that
# read `mimic3_ds.patients` see the filtered cohort if they are re-run afterwards.
mimic3_ds.patients = filtered_patients
# Trailing semicolon: stat() also returns the report string, which Jupyter
# would otherwise echo as a duplicate of the printed output.
mimic3_ds.stat();


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 7496
	- Number of visits: 19905
	- Number of visits per patient: 2.6554
	- Number of events per visit in DIAGNOSES_ICD: 12.9735
	- Number of events per visit in PROCEDURES_ICD: 4.0975
	- Number of events per visit in PRESCRIPTIONS: 82.0433



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 7496\n\t- Number of visits: 19905\n\t- Number of visits per patient: 2.6554\n\t- Number of events per visit in DIAGNOSES_ICD: 12.9735\n\t- Number of events per visit in PROCEDURES_ICD: 4.0975\n\t- Number of events per visit in PRESCRIPTIONS: 82.0433\n'

In [37]:
# Define the tasks

# NOTE(review): these three key constants are not referenced in the visible
# cells; the task function emits its samples under the literal keys
# "diagnoses", "procedures" and "intervals". Confirm whether they are still needed.
DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

# ICD-9-CM code map; the task function calls get_ancestors() on it to derive
# the diagnosis-prediction label from the codes of the last visit.
icd9cm = InnerMap.load("ICD9CM")

def flatten(l: List):
    """Concatenate the items of every sub-iterable of ``l`` into one list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def patient_level_readmission_prediction(patient, time_window: int = 30, max_length_visits: Optional[int] = None):
    """
    Build at most one patient-level sample for readmission / diagnosis prediction.

    The visits are sorted by encounter time. The last visit supplies the labels
    and all earlier visits supply the features.

    Args:
        patient: a <pyhealth.data.Patient> object (iterated to obtain its visits).
        time_window: the readmission label is 1 when the number of days between
            the encounter times of the last two visits is <= time_window, else 0.
        max_length_visits: if not None, only the most recent
            ``max_length_visits + 1`` visits are used (features plus label visit).

    Returns:
        A list holding a single sample dict, or an empty list when the patient
        has fewer than two usable visits or no feature visit carries both
        diagnoses and procedures.
    """
    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)

    # Clip the patient visits to the most recent max_length_visits + 1 if max_length_visits is not None
    if max_length_visits is not None:
        n_visits = len(sorted_visits)
        if n_visits > max_length_visits + 1:
            sorted_visits = sorted_visits[n_visits - (max_length_visits + 1):]

    # At least one feature visit plus the label visit is required; without this
    # guard the indexing below raises IndexError for short visit histories.
    if len(sorted_visits) < 2:
        return []

    feature_visits = sorted_visits[:-1]
    last_visit = sorted_visits[-1]
    second_to_last_visit = feature_visits[-1]
    first_visit = feature_visits[0]

    # step 1 a: define readmission label
    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff <= time_window else 0

    # step 1 b: define diagnosis prediction label
    # Each code of the last visit is replaced by entry [1] of its ancestor list, then de-duplicated.
    # NOTE(review): presumably a coarser ICD-9-CM category -- confirm the ordering
    # returned by get_ancestors and that every code has at least two ancestors.
    diagnosis_label = list({icd9cm.get_ancestors(code)[1] for code in last_visit.get_code_list("DIAGNOSES_ICD")})

    # step 2: obtain features
    visits_diagnoses = []
    visits_procedures = []
    visits_intervals = []
    for visit in feature_visits:
        diagnoses = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")

        # Exclude visits that are missing either diagnoses or procedures.
        # BiteNet can handle missing procedures, but other PyHealth models like RNN
        # require all features have a length greater than 0.
        if len(diagnoses) == 0 or len(procedures) == 0:
            continue

        days_since_first_visit = (visit.encounter_time - first_visit.encounter_time).days

        visits_diagnoses.append(diagnoses)
        visits_procedures.append(procedures)
        # Stored as a one-element list of str so it has the same nesting as the code features
        visits_intervals.append([str(days_since_first_visit)])

    # step 3: exclusion criteria
    # Only visits with a non-empty diagnosis list are appended above, so an
    # empty list here means there is no usable feature visit at all.
    if len(visits_diagnoses) == 0:
        return []

    # step 4: assemble the sample
    return [
        {
            "patient_id": patient.patient_id,
            # Id of the most recent feature visit, named explicitly instead of
            # relying on the loop variable left over from step 2.
            "visit_id": second_to_last_visit.visit_id,
            "diagnoses": visits_diagnoses,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "readmission_label": readmission_label,
            "diagnosis_label": diagnosis_label
        }
    ]

In [38]:
# Define the models
# Large-magnitude constants used for masking: VERY_NEGATIVE_NUMBER is written
# into masked positions before a softmax in AttentionPooling.
# NOTE(review): VERY_SMALL_NUMBER and VERY_POSITIVE_NUMBER are not referenced in
# the visible cells -- confirm whether they are still needed.
VERY_BIG_NUMBER = 1e30
VERY_SMALL_NUMBER = 1e-30
VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER

# Global compute device: CUDA when available, otherwise CPU. The model classes
# and the training helper below read this module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

class MaskDirection(Enum):
    """Selects the kind of position mask a MultiHeadAttention layer applies to its attention scores."""
    FORWARD = 'forward'
    BACKWARD = 'backward'
    DIAGONAL = 'diagonal'
    NONE = 'none'

class MaskedLayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and bias.

    Despite the name, no mask is taken or applied by this module.
    """

    def __init__(self, normalized_shape: int):
        super().__init__()
        # Parameters are created directly on the module-level `device`.
        self.scale = nn.Parameter(torch.ones(normalized_shape, dtype=torch.float32, device=device))
        self.bias = nn.Parameter(torch.zeros(normalized_shape, dtype=torch.float32, device=device))
        self.normalized_shape = normalized_shape

    def forward(self, x: torch.Tensor, eps=1e-5):
        """Normalise ``x`` along its last dimension, then apply scale and bias.

        Args:
            x: input tensor; statistics are computed over ``dim=-1``.
            eps: constant added to the variance before the reciprocal square root.

        Returns:
            Tensor with the same shape as ``x``.
        """
        centered = x - torch.mean(x, dim=-1, keepdim=True)
        # Population variance (mean of squared deviations, no Bessel correction)
        variance = torch.mean(torch.square(centered), dim=-1, keepdim=True)
        return centered * torch.rsqrt(variance + eps) * self.scale + self.bias

class Flatten(nn.Module):
    """Merge the leading dimensions of a tensor while keeping the trailing ones."""

    def forward(self, x: torch.Tensor, keep: int):
        """Collapse all but the last ``keep`` dimensions of ``x`` into one.

        Args:
            x: input tensor.
            keep: number of trailing dimensions to leave untouched; must be
                smaller than ``x.dim()`` (otherwise the product below has no
                operands and ``reduce`` raises TypeError).

        Returns:
            ``x`` reshaped to ``[prod(leading dims), *trailing dims]``.
        """
        dims = list(x.size())
        n_leading = len(dims) - keep
        # max(..., 0) keeps the operand list empty when keep > x.dim()
        merged = reduce(mul, dims[:max(n_leading, 0)])
        return torch.reshape(x, [merged] + dims[n_leading:])


class Unflatten(nn.Module):
    """Reshape a flattened tensor back to three dimensions using a reference tensor."""

    def forward(self, v: torch.Tensor, ref: torch.Tensor, embedding_dim):
        """Reshape ``v`` to ``[ref.shape[0], ref.shape[1], embedding_dim]``.

        Args:
            v: tensor to reshape.
            ref: tensor whose first two dimension sizes give the target leading shape.
            embedding_dim: size of the last dimension of the result.
        """
        leading_shape = list(ref.shape[:2])
        return torch.reshape(v, leading_shape + [embedding_dim])


class AttentionPooling(nn.Module):
    """Pool a sequence into one vector using softmax weights taken over dim 1."""

    def __init__(self, embedding_size: int):
        super().__init__()
        # Two-layer feed-forward map that produces the per-position scores
        self.fc = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, embedding_size)
        )

    def forward(self, inputs):
        """Pool over dim 1.

        Args:
            inputs: tuple ``(x, mask)``; ``mask`` indexes the positions of ``x``
                to exclude from the pooling.
                # NOTE(review): assumes mask is a bool tensor, True = padding -- confirm with callers.

        Returns:
            Sum over dim 1 of ``softmax(scores) * scores`` with masked positions zeroed.
        """
        features, padding = inputs
        scores = self.fc(features)
        # Masked positions get a hugely negative score so their softmax weight vanishes
        scores[padding] = VERY_NEGATIVE_NUMBER
        weights = F.softmax(scores, dim=1)
        # Reset masked positions to zero so they contribute nothing to the weighted sum
        scores[padding] = 0
        return torch.sum(weights * scores, 1)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional position mask.

    Queries, keys and values are all bias-free linear projections of the same
    input. ``direction`` (a MaskDirection) selects which query/key position
    pairs are left unpenalised:

    * DIAGONAL: every pair except a position with itself,
    * FORWARD:  key index greater than query index,
    * BACKWARD: query index greater than key index,
    * NONE:     no position mask is applied.
    """

    def __init__(self, direction, dropout, num_units, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.direction = direction
        self.num_units = num_units
        self.q_linear = nn.Linear(num_units, num_units, bias=False)
        self.k_linear = nn.Linear(num_units, num_units, bias=False)
        self.v_linear = nn.Linear(num_units, num_units, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _position_mask(self, n_positions: int, mask_device) -> Optional[torch.Tensor]:
        """Return a bool (n_positions, n_positions) matrix indexed [query, key], or None.

        True marks pairs that are left unpenalised. Built with broadcasting on
        the device of the attention scores, which avoids both the
        ``torch.meshgrid`` indexing warning and any dependence on a global device.
        """
        if self.direction == MaskDirection.NONE:
            return None
        positions = torch.arange(0, n_positions, dtype=torch.int32, device=mask_device)
        query_idx = positions.unsqueeze(1)  # (n, 1)
        key_idx = positions.unsqueeze(0)  # (1, n)
        if self.direction == MaskDirection.DIAGONAL:
            return query_idx != key_idx
        if self.direction == MaskDirection.FORWARD:
            return key_idx > query_idx
        # MaskDirection.BACKWARD
        return query_idx > key_idx

    def forward(self, inputs):
        """Apply masked multi-head self-attention.

        Args:
            inputs: tuple ``(input_tensor, input_mask)``. ``input_tensor`` is
                3-D with features on dim 2; ``input_mask`` is a bool tensor over
                the first two dimensions whose True entries are zeroed in the
                output.
                # NOTE(review): assumes True marks padding -- confirm with callers.

        Returns:
            Tensor with the same shape as ``input_tensor``.
        """
        # because of self-attention, queries and keys is equal to inputs
        input_tensor, input_mask = inputs

        # Linear projections
        Q = self.q_linear(input_tensor)  # (N, L_q, d)
        K = self.k_linear(input_tensor)  # (N, L_k, d)
        V = self.v_linear(input_tensor)  # (N, L_k, d)

        # Split and concat
        assert self.num_units % self.num_heads == 0
        head_dim = self.num_units // self.num_heads
        Q_ = torch.cat(torch.split(Q, head_dim, dim=2), dim=0)  # (h*N, L_q, d/h)
        K_ = torch.cat(torch.split(K, head_dim, dim=2), dim=0)  # (h*N, L_k, d/h)
        V_ = torch.cat(torch.split(V, head_dim, dim=2), dim=0)  # (h*N, L_k, d/h)

        # Multiplication
        outputs = torch.matmul(Q_, torch.permute(K_, [0, 2, 1]))  # (h*N, L_q, L_k)

        # Scale
        outputs = outputs / (head_dim ** 0.5)  # (h*N, L_q, L_k)

        # Key Masking: a key whose projected vector is all zeros is treated as padding
        key_masks = torch.sign(torch.sum(torch.abs(K_), dim=-1))  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k), broadcast over queries

        # Apply masks to outputs
        paddings = torch.ones_like(outputs) * (-2 ** 32 + 1)  # exp mask
        outputs = torch.where(key_masks == 0, paddings, outputs)  # (h*N, T_q, T_k)

        # Position mask: pairs marked False receive a large negative offset
        attention_mask = self._position_mask(input_tensor.shape[1], outputs.device)
        if attention_mask is not None:
            adder = (1.0 - attention_mask.type(outputs.dtype)) * -10000.0
            outputs = outputs + adder

        # softmax
        outputs = F.softmax(outputs, -1)  # (h*N, T_q, T_k)

        # Query Masking: a query whose projected vector is all zeros is treated as padding
        query_masks = torch.sign(torch.sum(torch.abs(Q_), dim=-1))  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1), broadcast over keys

        # Apply masks to outputs
        outputs = outputs * query_masks

        # Dropouts
        outputs = self.dropout(outputs)
        # Weighted sum
        outputs = torch.matmul(outputs, V_)  # ( h*N, T_q, C/h)

        # Restore shape
        outputs = torch.cat(torch.split(outputs, outputs.shape[0] // self.num_heads, dim=0), dim=2)  # (N, L_q, d)

        # input padding
        val_mask = torch.unsqueeze(input_mask, -1)
        outputs = torch.multiply(outputs, (~val_mask).float())

        return outputs


class MaskEnc(nn.Module):
    """Encoder block: masked self-attention and a feed-forward net, each followed
    by a residual connection and layer normalisation.

    Takes and returns an ``(x, key_padding_mask)`` tuple so that several blocks
    can be chained inside an ``nn.Sequential``.
    """

    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = MultiHeadAttention(
            direction=temporal_mask_direction,
            dropout=dropout,
            num_units=embedding_dim,
            num_heads=num_heads,
        )

        # Position-wise feed-forward network
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim),
        )

        self.layer_norm1 = MaskedLayerNorm(embedding_dim)
        self.layer_norm2 = MaskedLayerNorm(embedding_dim)

    def forward(self, inputs):
        """Run the block.

        Args:
            inputs: tuple ``(x, key_padding_mask)``; positions where the bool
                mask is True are zeroed in the output.

        Returns:
            ``(out, key_padding_mask)`` with the mask passed through unchanged.
        """
        x, key_padding_mask = inputs
        # 1.0 where the mask is False, 0.0 where it is True; broadcast over the feature dim
        keep = (~key_padding_mask.unsqueeze(-1)).float()

        attended = self.layer_norm1(x + self.attention((x, key_padding_mask)))
        hidden = self.fc(attended) * keep
        out = self.layer_norm2(hidden + attended) * keep

        return out, key_padding_mask

    def _make_temporal_mask(self, n: int) -> Optional[torch.Tensor]:
        """Build an additive (n, n) float mask for this block's direction.

        Holds -10000 strictly below the diagonal for FORWARD, strictly above it
        for BACKWARD and on the diagonal for DIAGONAL, with 0 elsewhere; returns
        None for NONE. This helper is not called by ``forward``.
        """
        direction = self.temporal_mask_direction
        if direction == MaskDirection.NONE:
            return None
        elif direction == MaskDirection.FORWARD:
            return torch.tril(torch.full((n, n), -10000, device=device)).fill_diagonal_(0).float()
        elif direction == MaskDirection.BACKWARD:
            return torch.triu(torch.full((n, n), -10000, device=device)).fill_diagonal_(0).float()
        elif direction == MaskDirection.DIAGONAL:
            return torch.zeros(n, n, device=device).fill_diagonal_(-10000).float()
        return None


class BiteNet(nn.Module):
    """Two-level attention encoder: codes are pooled into one vector per visit,
    then the visits are pooled with a forward-masked and a backward-masked stack
    whose outputs are concatenated and projected back to ``embedding_dim``.

    ``use_procedures`` is stored but not read inside this class.
    """

    def __init__(
            self,
            embedding_dim: int = 128,
            num_heads: int = 4,
            dropout: float = 0.1,
            n_mask_enc_layers: int = 2,
            use_procedures: bool = True,
            use_intervals: bool = True,
    ):
        super().__init__()

        self.use_intervals = use_intervals
        self.use_procedures = use_procedures
        self.embedding_dim = embedding_dim

        self.flatten = Flatten()
        self.unflatten = Unflatten()

        # One stack per mask direction, in the order: code-level, visit-forward, visit-backward
        directions = (MaskDirection.DIAGONAL, MaskDirection.FORWARD, MaskDirection.BACKWARD)
        stacks = [nn.Sequential() for _ in directions]

        # The blocks are created layer by layer across the three stacks, so the
        # modules (and their random initialisation) are created in a fixed interleaved order.
        for _ in range(n_mask_enc_layers):
            for stack, direction in zip(stacks, directions):
                stack.append(
                    MaskEnc(
                        embedding_dim=embedding_dim,
                        num_heads=num_heads,
                        dropout=dropout,
                        temporal_mask_direction=direction,
                    )
                )

        # Attention pooling layers
        for stack in stacks:
            stack.append(AttentionPooling(embedding_dim))

        self.code_attn, self.visit_attn_fw, self.visit_attn_bw = stacks

        self.fc = nn.Sequential(
            nn.Linear(2*embedding_dim, embedding_dim),
            nn.ReLU()
        )

    def forward(
            self,
            embedded_codes: torch.Tensor,
            embedded_intervals: torch.Tensor,
            codes_mask: torch.Tensor,
            visits_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode a batch of patients.

        Args:
            embedded_codes: 4-D tensor of code embeddings (flattened to 3-D here).
            embedded_intervals: tensor added to the per-visit representation
                when ``use_intervals`` is set.
            codes_mask: 3-D mask over codes; nonzero entries are kept.
            visits_mask: mask over visits; nonzero entries are kept.

        Returns:
            One vector per patient, produced by the final linear layer and ReLU.
        """
        # The attention layers expect True = padding, so both masks are inverted
        code_padding = ~(codes_mask.bool())

        # input tensor, reshape 4 dimension to 3
        flat_codes = self.flatten(embedded_codes, 2)

        # input mask, reshape 3 dimension to 2
        flat_code_padding = self.flatten(code_padding, 1)

        visit_repr = self.code_attn((flat_codes, flat_code_padding))
        visit_repr = self.unflatten(visit_repr, embedded_codes, self.embedding_dim)

        if self.use_intervals:
            visit_repr += embedded_intervals

        visit_padding = ~(visits_mask.bool())

        u_fw = self.visit_attn_fw((visit_repr, visit_padding))
        u_bw = self.visit_attn_bw((visit_repr, visit_padding))

        return self.fc(torch.cat([u_fw, u_bw], dim=-1))

class PyHealthBiteNet(BaseModel):
    """PyHealth model wrapper around BiteNet.

    Tokenises and embeds each feature, feeds the embeddings to BiteNet and maps
    the patient vector to logits with a final linear layer. ``forward`` reads
    the features under the literal keys ``'diagnoses'``, ``'intervals'`` and
    (when ``use_procedures`` is set) ``'procedures'``, so those keys must be
    present in ``feature_keys``.
    """

    def __init__(
            self,
            dataset: SampleDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            embedding_dim: int = 128,
            n_mask_enc_layers: int = 2,
            use_intervals: bool = True,
            use_procedures: bool = True,
            num_heads: int = 4,
            dropout: float = 0.1,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        self.use_intervals = use_intervals
        self.use_procedures = use_procedures

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # One transformation (tokenizer + embedding) layer per feature
        for key in self.feature_keys:
            self.add_feature_transform_layer(
                key, self.dataset.input_info[key], special_tokens=["<pad>", "<unk>"]
            )

        n_outputs = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            use_intervals=use_intervals,
            use_procedures=use_procedures,
            n_mask_enc_layers=n_mask_enc_layers,
        )

        # final output layer
        self.fc = nn.Linear(self.embedding_dim, n_outputs)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Compute loss and predictions for one batch.

        Args:
            **kwargs: batch dict holding every key in ``feature_keys`` plus the label key.

        Returns:
            Dict with ``"loss"``, ``"y_prob"`` and ``"y_true"``.
        """
        embeddings = {}
        masks = {}
        for key in self.feature_keys:
            info = self.dataset.input_info[key]

            # each patient's feature is represented by [[code1, code2],[code3]]
            assert info["dim"] == 3 and info["type"] == str

            tokenizer = self.feat_tokenizers[key]
            token_ids = tokenizer.batch_encode_3d(kwargs[key], truncation=(False, False))
            token_ids = torch.tensor(token_ids, dtype=torch.long, device=self.device)

            # 1 for real tokens, 0 for padding
            masks[key] = (token_ids != tokenizer.vocabulary("<pad>")).long()
            embeddings[key] = self.embeddings[key](token_ids)

        embedded_codes = embeddings['diagnoses']
        codes_mask = masks['diagnoses']
        if self.use_procedures:
            # Procedure codes are appended to the diagnosis codes along dim 2
            embedded_codes = torch.cat((embedded_codes, embeddings['procedures']), dim=2)
            codes_mask = torch.cat((codes_mask, masks['procedures']), dim=2)

        # The singleton token dimension of the intervals feature is squeezed away;
        # its mask also serves as the visit-level mask.
        patient_repr = self.bite_net(
            embedded_codes,
            embeddings['intervals'].squeeze(2),
            codes_mask,
            masks['intervals'].squeeze(-1),
        )
        logits = self.fc(patient_repr)

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

Device: cuda


In [48]:
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, lr=0.001, monitor="pr_auc", epochs=10):
    """Train ``model`` with RMSprop, reload the best checkpoint and print its test scores.

    Args:
        model: the model to train (already moved to the target device).
        train_loader: dataloader used for training.
        val_loader: dataloader used to compute the monitored metric after each epoch.
        test_loader: dataloader used for the final evaluation.
        lr: learning rate passed to the optimizer.
        monitor: name of the validation metric that selects the best model. It
            must match a key of the evaluation scores, which use underscores
            (e.g. "pr_auc", "roc_auc", "f1", "pr_auc_samples").
        epochs: number of training epochs.
    """
    trainer = Trainer(model=model, device=device)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=epochs,
        monitor=monitor,
        optimizer_class=torch.optim.RMSprop,
        optimizer_params={"lr": lr},
        load_best_model_at_last=True,
    )

    # Run inference and evaluate with the trainer's built-in evaluation metrics
    score = trainer.evaluate(test_loader)
    print (score)

In [40]:
# Define the dataset and data loaders
# Create the task datasets
dataset = mimic3_ds.set_task(task_fn=patient_level_readmission_prediction)
print(len(dataset))

BATCH_SIZE = 32
# Fixed seed so the patient-level train/val/test split (and therefore the
# reported metrics) is the same after a kernel restart.
SPLIT_SEED = 42
train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1], seed=SPLIT_SEED)

# Only the training loader is shuffled
train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

Generating samples for patient_level_readmission_prediction: 100%|██████████| 7496/7496 [00:03<00:00, 2133.79it/s]


6930


In [34]:
# Train and evaluate BiteNet on the binary readmission label, using the
# diagnoses, procedures and visit-interval features.
train_and_evaluate_model(
    PyHealthBiteNet(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures", "intervals"],
        label_key = "readmission_label",
        mode = "binary",
        use_procedures=True,
        use_intervals=True
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
    (intervals): Embedding(1756, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): MultiHeadAttention(
          (q_linear): Linear(in_features=128, out_features=128, bias=False)
          (k_linear): Linear(in_features=128, out_features=128, bias=False)
          (v_linear): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=128, out_features=128, bias=True)
        )
        (layer_norm1): MaskedLayerNorm()
        (layer_norm2): 

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
--- Train epoch-0, step-174 ---
loss: 0.5046
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.28it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.3108
roc_auc: 0.5620
f1: 0.1702
loss: 0.4618
New best pr_auc score (0.3108) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.4783
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.28it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.3058
roc_auc: 0.5599
f1: 0.1690
loss: 0.4603



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.4686
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.67it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.3055
roc_auc: 0.5528
f1: 0.1633
loss: 0.4658



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.4593
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.67it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.3014
roc_auc: 0.5379
f1: 0.1830
loss: 0.4869



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.4469
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 64.71it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3028
roc_auc: 0.5477
f1: 0.1690
loss: 0.4781



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.4365
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.67it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.2966
roc_auc: 0.5378
f1: 0.1911
loss: 0.4867



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.4254
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 64.90it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.2915
roc_auc: 0.5354
f1: 0.1988
loss: 0.4995



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.4098
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 65.67it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.2950
roc_auc: 0.5357
f1: 0.1899
loss: 0.5138



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.3956
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.36it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.2928
roc_auc: 0.5387
f1: 0.2025
loss: 0.5227



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.3826
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.97it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.2890
roc_auc: 0.5348
f1: 0.2172
loss: 0.6436
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.44it/s]

{'pr_auc': 0.3729449988885468, 'roc_auc': 0.5987539732994278, 'f1': 0.3254237288135593, 'loss': 0.6153426658023488}





In [49]:
# Train and evaluate BiteNet on the multilabel diagnosis label. The interval
# embeddings are switched off here (use_intervals=False) and the best model is
# selected on the "pr_auc_samples" validation metric.
train_and_evaluate_model(
    PyHealthBiteNet(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures", "intervals"],
        label_key = "diagnosis_label",
        mode = "multilabel",
        use_procedures=True,
        use_intervals=False
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001,
    monitor="pr_auc_samples"
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
    (intervals): Embedding(1756, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): MultiHeadAttention(
          (q_linear): Linear(in_features=128, out_features=128, bias=False)
          (k_linear): Linear(in_features=128, out_features=128, bias=False)
          (v_linear): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=128, out_features=128, bias=True)
        )
        (layer_norm1): MaskedLayerNorm()
        (layer_norm2): 

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.0946
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.97it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.3041
loss: 0.0903
New best pr_auc_samples score (0.3041) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0862
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.62it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.2999
loss: 0.0939



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0855
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.44it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.3417
loss: 0.0907
New best pr_auc_samples score (0.3417) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0843
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.11it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.3517
loss: 0.0894
New best pr_auc_samples score (0.3517) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0836
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.36it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.3591
loss: 0.0872
New best pr_auc_samples score (0.3591) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0829
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.77it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.3699
loss: 0.0895
New best pr_auc_samples score (0.3699) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0825
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.62it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.3655
loss: 0.0866



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0820
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.94it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.3740
loss: 0.0887
New best pr_auc_samples score (0.3740) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0815
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.11it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.3743
loss: 0.0855
New best pr_auc_samples score (0.3743) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0803
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.77it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.3776
loss: 0.0868
New best pr_auc_samples score (0.3776) at epoch-9, step-1740
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 56.12it/s]


{'pr_auc_samples': 0.3748846788378517, 'loss': 0.08555971424687993}


In [25]:
# Train and evaluate PyHealth's RNN model on the same binary readmission label,
# using only the diagnoses and procedures features.
train_and_evaluate_model(
    RNN(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "readmission_label",
        mode = "binary",
        embedding_dim=128,
        dropout=0.1
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001
)

RNN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (diagnoses): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x00000211A794E700>
Monitor: pr_auc
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5060
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 147.66it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.4103
roc_auc: 0.6276
f1: 0.2275
loss: 0.4788
New best pr_auc score (0.4103) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.4188
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 141.03it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.4115
roc_auc: 0.6165
f1: 0.2712
loss: 0.4967
New best pr_auc score (0.4115) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.3401
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 140.13it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.4139
roc_auc: 0.6251
f1: 0.3021
loss: 0.5159
New best pr_auc score (0.4139) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.2591
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 147.65it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.3980
roc_auc: 0.6171
f1: 0.2817
loss: 0.5742



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.1863
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 131.74it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.4049
roc_auc: 0.6146
f1: 0.3529
loss: 0.6299



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.1322
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 120.22it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3986
roc_auc: 0.6215
f1: 0.2909
loss: 0.6772



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0958
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 146.67it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3949
roc_auc: 0.6128
f1: 0.3291
loss: 0.7314



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0665
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 144.74it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.3848
roc_auc: 0.6095
f1: 0.2715
loss: 0.8290



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0481
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 147.65it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.3936
roc_auc: 0.6116
f1: 0.3226
loss: 0.8718



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0369
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 148.65it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.3827
roc_auc: 0.6051
f1: 0.2918
loss: 0.9339
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 161.77it/s]

{'pr_auc': 0.3738644161210829, 'roc_auc': 0.6120972073039742, 'f1': 0.3047619047619048, 'loss': 0.8288188739256426}





In [45]:
# PyHealth RNN on the multilabel diagnosis-prediction task.
# Model selection is driven by the sample-averaged PR-AUC ("pr_auc_samples").
# NOTE(review): reuses the same `dataset` / loaders names as the binary
# readmission cells -- confirm the samples also carry "diagnosis_label".
train_and_evaluate_model(
    RNN(
        dataset=dataset,
        feature_keys=["diagnoses", "procedures"],
        label_key="diagnosis_label",
        mode="multilabel",
        embedding_dim=128,
        dropout=0.1,
    ).to(device),
    train_loader, val_loader, test_loader,
    lr=0.001,
    monitor="pr_auc_samples",
)

RNN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (diagnoses): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=465, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x000001D43AC050D0>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1048
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 150.69it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.3754
loss: 0.0889
New best pr_auc_samples score (0.3754) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0818
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 141.03it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.4102
loss: 0.0844
New best pr_auc_samples score (0.4102) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0776
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 146.67it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.4215
loss: 0.0829
New best pr_auc_samples score (0.4215) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0749
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 142.86it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.4342
loss: 0.0819
New best pr_auc_samples score (0.4342) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0726
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 139.24it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.4377
loss: 0.0816
New best pr_auc_samples score (0.4377) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0705
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 143.79it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.4433
loss: 0.0815
New best pr_auc_samples score (0.4433) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0687
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 121.55it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.4422
loss: 0.0816



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0669
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 134.15it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.4442
loss: 0.0817
New best pr_auc_samples score (0.4442) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0654
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 133.34it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.4445
loss: 0.0819
New best pr_auc_samples score (0.4445) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0638
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 119.57it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.4427
loss: 0.0824
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 140.13it/s]


{'pr_auc_samples': 0.44439526194302753, 'loss': 0.08051219142296097}


In [27]:
# PyHealth RETAIN on the binary readmission task, using the "diagnoses" and
# "procedures" code sequences as input features.
#
# The learning rate is passed explicitly so this run is configured the same
# way as the RNN readmission cell (which passes lr=0.001) instead of relying
# on whatever default `train_and_evaluate_model` applies.
# NOTE(review): the default is presumably 0.001 already (the Deepr run that
# omits `lr` logs "Optimizer params: {'lr': 0.001}") -- confirm against the
# helper's definition.
train_and_evaluate_model(
    RETAIN(
        dataset=dataset,
        feature_keys=["diagnoses", "procedures"],
        label_key="readmission_label",
        mode="binary",
        embedding_dim=128,
        dropout=0.1,
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001,
)

RETAIN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (diagnoses): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 't

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5355
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 84.29it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.3553
roc_auc: 0.5890
f1: 0.1775
loss: 0.5163
New best pr_auc score (0.3553) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.3683
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.88it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.3871
roc_auc: 0.6180
f1: 0.2737
loss: 0.5194
New best pr_auc score (0.3871) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.2285
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 77.19it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.3715
roc_auc: 0.6024
f1: 0.2526
loss: 0.6153



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.1225
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.01it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.3835
roc_auc: 0.5939
f1: 0.3097
loss: 0.6973



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0688
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 79.42it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3848
roc_auc: 0.6013
f1: 0.3117
loss: 0.8054



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0419
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.39it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3802
roc_auc: 0.5960
f1: 0.3084
loss: 0.9247



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0313
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.29it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3778
roc_auc: 0.6086
f1: 0.3070
loss: 0.9865



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0186
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 79.71it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.3776
roc_auc: 0.5983
f1: 0.2857
loss: 1.1027



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0181
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 73.83it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.3936
roc_auc: 0.6030
f1: 0.3004
loss: 1.1505
New best pr_auc score (0.3936) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0158
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.29it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.3926
roc_auc: 0.6023
f1: 0.3064
loss: 1.2158
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 82.09it/s]

{'pr_auc': 0.3606033694496517, 'roc_auc': 0.5797932330827067, 'f1': 0.29729729729729726, 'loss': 1.2300627353516491}





In [46]:
# PyHealth RETAIN on the multilabel diagnosis-prediction task.
# Model selection is driven by the sample-averaged PR-AUC ("pr_auc_samples").
# NOTE(review): reuses the same `dataset` / loaders names as the binary
# readmission cells -- confirm the samples also carry "diagnosis_label".
train_and_evaluate_model(
    RETAIN(
        dataset=dataset,
        feature_keys=["diagnoses", "procedures"],
        label_key="diagnosis_label",
        mode="multilabel",
        embedding_dim=128,
        dropout=0.1,
    ).to(device),
    train_loader, val_loader, test_loader,
    lr=0.001,
    monitor="pr_auc_samples",
)

RETAIN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (diagnoses): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=465, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1180
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 81.18it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.4010
loss: 0.0930
New best pr_auc_samples score (0.4010) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0811
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 75.34it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.4227
loss: 0.0878
New best pr_auc_samples score (0.4227) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0747
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 70.29it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.4304
loss: 0.0862
New best pr_auc_samples score (0.4304) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0707
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.57it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.4357
loss: 0.0863
New best pr_auc_samples score (0.4357) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0674
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.02it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.4313
loss: 0.0869



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0648
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 77.74it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.4356
loss: 0.0875



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0626
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 67.69it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.4325
loss: 0.0880



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0604
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.39it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.4332
loss: 0.0890



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0586
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.66it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.4300
loss: 0.0904



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0569
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.88it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.4272
loss: 0.0914
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 68.54it/s]


{'pr_auc_samples': 0.4249831201075734, 'loss': 0.09086060625585643}


In [29]:
# PyHealth Deepr (1-D convolution over code embeddings) on the binary
# readmission task, using the "diagnoses" and "procedures" code sequences.
#
# The learning rate is passed explicitly so this run is configured the same
# way as the RNN readmission cell (lr=0.001) rather than depending on the
# default of `train_and_evaluate_model`. The logged optimizer parameters for
# this cell already show lr=0.001, so the run itself is unchanged.
train_and_evaluate_model(
    Deepr(
        dataset=dataset,
        feature_keys=["diagnoses", "procedures"],
        label_key="readmission_label",
        mode="binary",
        embedding_dim=128,
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001,
)

Deepr(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3375, 128, padding_idx=0)
    (procedures): Embedding(1363, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (cnn): ModuleDict(
    (diagnoses): DeeprLayer(
      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
    )
    (procedures): DeeprLayer(
      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x00000211A794E700>
Monitor: pr_auc
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5101
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 183.33it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.4078
roc_auc: 0.6345
f1: 0.1863
loss: 0.4984
New best pr_auc score (0.4078) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.3764
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 268.29it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.4154
roc_auc: 0.6332
f1: 0.2755
loss: 0.4897
New best pr_auc score (0.4154) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.2655
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 258.82it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.3905
roc_auc: 0.6049
f1: 0.3258
loss: 0.5340



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.1544
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 249.96it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.4250
roc_auc: 0.6310
f1: 0.3577
loss: 0.5696
New best pr_auc score (0.4250) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0791
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 244.44it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3979
roc_auc: 0.6118
f1: 0.2922
loss: 0.6220



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0389
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 261.87it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3990
roc_auc: 0.6011
f1: 0.2842
loss: 0.7871



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0175
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.79it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3948
roc_auc: 0.6088
f1: 0.2737
loss: 0.8633



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0168
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.82it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.4105
roc_auc: 0.6164
f1: 0.2932
loss: 0.9251



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0055
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.81it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.3957
roc_auc: 0.6097
f1: 0.3136
loss: 0.8593



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0027
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 252.87it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.4183
roc_auc: 0.6189
f1: 0.3349
loss: 0.9303
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 224.49it/s]

{'pr_auc': 0.39479580130583247, 'roc_auc': 0.6421455424274973, 'f1': 0.303030303030303, 'loss': 0.8376922133294019}



