Library imports and data preparation

In [70]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset, SampleDataset
from pyhealth.data import Patient, Visit, Event
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import BaseModel, RNN, RETAIN, Deepr
from pyhealth.trainer import Trainer
from pyhealth.medcode import InnerMap
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional
from enum import Enum
from functools import reduce
from operator import mul
from tqdm import tqdm
from pyhealth.metrics.binary import binary_metrics_fn
from sklearn.metrics import precision_score
import numpy as np
import pandas as pd

# Set this to the directory with all MIMIC-3 dataset files
data_root = "data"

In [2]:
# Load the dataset

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=False
)

In [3]:
# Print dataset statistics

mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n'

In [4]:
# Find all diagnoses codes
# Remove diagnoses codes with fewer than 5 occurences in the dataset

all_diagnosis_codes = []
for patient_id, patient in mimic3_ds.patients.items():
  for i in range(len(patient)):
    visit: Visit = patient[i]
    conditions = visit.get_code_list(table="DIAGNOSES_ICD")
    all_diagnosis_codes.extend(conditions)

codes = pd.Series(all_diagnosis_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts > 4].index.values
n_unique_diag_codes = len(filtered_diag_codes)

In [5]:
MIN_N_VISITS_PER_PATIENT = 2

# Filter Dataset to requirements specified in paper

filtered_patients = {}
for patient_id, patient in tqdm(mimic3_ds.patients.items()):

    filtered_patient: Patient = Patient(
        patient_id=patient.patient_id,
        birth_datetime=patient.birth_datetime,
        death_datetime=patient.death_datetime,
        gender=patient.gender,
        ethnicity=patient.ethnicity
    )

    for i_visit, visit in enumerate(patient):
        filtered_visit: Visit = Visit(
            visit_id=visit.visit_id,
            patient_id=visit.patient_id,
            encounter_time=visit.encounter_time,
            discharge_time=visit.discharge_time,
            discharge_status=visit.discharge_status
        )

        diagnoses_codes = visit.get_code_list("DIAGNOSES_ICD")
        procedures_codes = visit.get_code_list("PROCEDURES_ICD")
        prescriptions_codes = visit.get_code_list("PRESCRIPTIONS")

        if len(diagnoses_codes) > 0:
            diagnosis_events = visit.event_list_dict["DIAGNOSES_ICD"]
            for i_event in range(len(diagnosis_events) - 1, -1, -1):
                event: Event = diagnosis_events[i_event]
                if event.code not in filtered_diag_codes:
                    diagnosis_events.pop(i_event) # Remove the diagnosis code with fewer than the cutoff occurrences

            if len(diagnosis_events) == 0: continue # Don't include visits with no diagnoses

            filtered_visit.set_event_list("DIAGNOSES_ICD", diagnosis_events)
        else:
            continue # Don't include visits with no diagnoses

        if len(procedures_codes) > 0:
           filtered_visit.set_event_list("PROCEDURES_ICD", visit.event_list_dict["PROCEDURES_ICD"])

        if len(prescriptions_codes) > 0:
            filtered_visit.set_event_list("PRESCRIPTIONS", visit.event_list_dict["PRESCRIPTIONS"])

        filtered_patient.add_visit(filtered_visit)

    if len(filtered_patient.visits) >= MIN_N_VISITS_PER_PATIENT:
        filtered_patients[patient_id] = filtered_patient


100%|██████████| 46520/46520 [01:09<00:00, 666.78it/s]


In [6]:
mimic3_ds.patients = filtered_patients
mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 7496
	- Number of visits: 19905
	- Number of visits per patient: 2.6554
	- Number of events per visit in DIAGNOSES_ICD: 12.9735
	- Number of events per visit in PROCEDURES_ICD: 4.0975



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 7496\n\t- Number of visits: 19905\n\t- Number of visits per patient: 2.6554\n\t- Number of events per visit in DIAGNOSES_ICD: 12.9735\n\t- Number of events per visit in PROCEDURES_ICD: 4.0975\n'

In [7]:
# Get the dataset span for encoding visit intervals
# The lookup table for interval encoding will be of dimensions mxd,
# where m is the length of the dataset span and d is the embedding
# dimension. The lookup table needs to be as large as the dataset
# time span because it could be possible for a patient to have his/her
# first visit on the earliest day in the dataset and the last visit on
# the latest day in the dataset. Thus, the index for this patient's visit
# interval would be (last_visit.time - first_visit.time).days = m = dataset span.

max_patient_span_days: int = 0

for patient_id, patient in tqdm(mimic3_ds.patients.items()):

    sorted_visits = sorted(patient, key=lambda v: v.encounter_time)
    patient_span_days = (sorted_visits[-1].encounter_time - sorted_visits[0].encounter_time).days
    if patient_span_days > max_patient_span_days:
        max_patient_span_days = patient_span_days

print(f"Max span (days) of a single patient's visits: {max_patient_span_days}")

100%|██████████| 7496/7496 [00:00<00:00, 277654.66it/s]

Max span (days) of a single patient's visits: 4221





In [8]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

icd9cm = InnerMap.load("ICD9CM")

def flatten(l: List):
    return [item for sublist in l for item in sublist]

def patient_level_readmission_prediction(patient, time_window: int = 30, max_length_visits: int = None):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []

    # # if the patient only has one visit, we drop it
    # if len(patient) <= 1:
    #     return []

    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)

    # Clip the patient visits to the most recent max_length_visits + 1 if max_length_visits is not None
    if max_length_visits is not None:
        n_visits = len(sorted_visits)
        if n_visits > max_length_visits + 1:
            sorted_visits = sorted_visits[n_visits - (max_length_visits + 1):]

    feature_visits: List[Visit] = sorted_visits[:-1]
    last_visit: Visit = sorted_visits[-1]
    second_to_last_visit: Visit = feature_visits[-1]
    first_visit: Visit = feature_visits[0]

    # step 1 a: define readmission label
    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff <= time_window else 0

    # step 1 b: define diagnosis prediction label
    diagnosis_label = list(set([icd9cm.get_ancestors(code)[1] for code in last_visit.get_code_list("DIAGNOSES_ICD")]))

    # step 2: obtain features
    visits_diagnoses = []
    visits_procedures = []
    visits_intervals = []
    for idx, visit in enumerate(feature_visits):
        diagnoses = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        time_diff_from_first_visit = (visit.encounter_time - first_visit.encounter_time).days

        # Exclude visits that are missing either diagnoses or procedures.
        # BiteNet can handle missing procedures, but other PyHealth models like RNN
        # require all features have a length greater than 0.
        if len(diagnoses) * len(procedures) == 0:
            continue

        visits_diagnoses.append(diagnoses)
        visits_procedures.append(procedures)
        visits_intervals.append([str(time_diff_from_first_visit)])

    unique_diagnoses = list(set(flatten(visits_diagnoses)))

    # step 3: exclusion criteria
    if len(unique_diagnoses) == 0:
        return []

    # step 4: assemble the sample
    samples.append(
        {
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "diagnoses": visits_diagnoses,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "readmission_label": readmission_label,
            "diagnosis_label": diagnosis_label
        }
    )
    return samples

In [44]:
# Define the dataset and data loaders
# Create the task datasets
dataset = mimic3_ds.set_task(task_fn=patient_level_readmission_prediction)
print(len(dataset))

Generating samples for patient_level_readmission_prediction: 100%|██████████| 7496/7496 [00:04<00:00, 1855.50it/s]


6930


In [10]:
# Define the models
VERY_BIG_NUMBER = 1e30
VERY_SMALL_NUMBER = 1e-30
VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

class MaskDirection(Enum):
    FORWARD = 'forward'
    BACKWARD = 'backward'
    DIAGONAL = 'diagonal'
    NONE = 'none'

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: int):
        super().__init__()
        self.scale = nn.parameter.Parameter(torch.ones(normalized_shape, dtype=torch.float32, device=device))
        self.bias = nn.parameter.Parameter(torch.zeros(normalized_shape, dtype=torch.float32, device=device))
        self.normalized_shape = normalized_shape

    def forward(self, x: torch.Tensor, eps=1e-5):
        mean = torch.mean(x, dim=-1, keepdim=True)
        variance = torch.mean(torch.square(x - mean), dim=-1, keepdim=True)
        norm_x = (x - mean) * torch.rsqrt(variance + eps)
        return norm_x * self.scale + self.bias

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, keep: int):
        fixed_shape = list(x.size())
        start = len(fixed_shape) - keep
        left = reduce(mul, [fixed_shape[i] or x.shape[i] for i in range(start)])
        out_shape = [left] + [fixed_shape[i] or x.shape[i] for i in range(start, len(fixed_shape))]
        return torch.reshape(x, out_shape)


class Unflatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, v: torch.Tensor, ref: torch.Tensor, embedding_dim):
        batch_size = ref.shape[0]
        n_visits = ref.shape[1]
        out = torch.reshape(v, [batch_size, n_visits, embedding_dim])
        return out


class AttentionPooling(nn.Module):
    def __init__(self, embedding_size: int):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, embedding_size)
        )

    def forward(self, inputs):
        x, mask = inputs
        x = self.fc(x)
        x[~mask] = VERY_NEGATIVE_NUMBER
        soft = F.softmax(x, dim=1)
        x[~mask] = 0
        attn_output = torch.sum(soft * x, 1)
        return attn_output


class MultiHeadAttention(nn.Module):
    def __init__(self, direction, dropout, n_units, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.direction = direction
        self.n_units = n_units
        self.q_linear = nn.Linear(n_units, n_units, bias=False)
        self.k_linear = nn.Linear(n_units, n_units, bias=False)
        self.v_linear = nn.Linear(n_units, n_units, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):

        # because of self-attention, queries and keys is equal to inputs
        input_tensor, input_mask = inputs
        queries = input_tensor
        keys = input_tensor

        # Linear projections
        Q = self.q_linear(queries)  # (N, L_q, d)
        K = self.k_linear(keys)  # (N, L_k, d)
        V = self.v_linear(keys)  # (N, L_k, d)

        # Split and concat
        assert self.n_units % self.n_heads == 0
        Q_ = torch.cat(torch.split(Q, self.n_units // self.n_heads, dim=2), dim=0)  # (h*N, L_q, d/h)
        K_ = torch.cat(torch.split(K, self.n_units // self.n_heads, dim=2), dim=0)  # (h*N, L_k, d/h)
        V_ = torch.cat(torch.split(V, self.n_units // self.n_heads, dim=2), dim=0)  # (h*N, L_k, d/h)

        # Multiplication
        outputs = torch.matmul(Q_, torch.permute(K_, [0, 2, 1]))  # (h*N, L_q, L_k)

        # Scale
        outputs = outputs / (list(K_.shape)[-1] ** 0.5)  # (h*N, L_q, L_k)

        # Key Masking
        key_masks = torch.sign(torch.sum(torch.abs(K_), dim=-1))  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = torch.tile(key_masks, [1, list(Q_.shape)[1], 1])  # (h*N, T_q, T_k)

        # Apply masks to outputs
        paddings = torch.ones_like(outputs, device=device) * (-2 ** 32 + 1)  # exp mask
        outputs = torch.where(key_masks == 0, paddings, outputs)  # (h*N, T_q, T_k)

        n_visits = list(input_tensor.shape)[1]
        sw_indices = torch.arange(0, n_visits, dtype=torch.int32, device=device)
        sw_col, sw_row = torch.meshgrid(sw_indices, sw_indices)
        if self.direction == MaskDirection.DIAGONAL:
            # shape of (n_visits, n_visits)
            attention_mask = (torch.diag(- torch.ones([n_visits], dtype=torch.int32, device=device)) + 1).bool()
        elif self.direction == MaskDirection.FORWARD:
            attention_mask = torch.greater(sw_row, sw_col)  # shape of (n_visits, n_visits)
        else: # MaskDirection.BACKWARD
            attention_mask = torch.greater(sw_col, sw_row)  # shape of (n_visits, n_visits)
        adder = (1.0 - attention_mask.type(outputs.dtype)) * -10000.0
        outputs += adder

        # softmax
        outputs = F.softmax(outputs, -1)  # (h*N, T_q, T_k)

        # Query Masking
        query_masks = torch.sign(torch.sum(torch.abs(Q_), dim=-1))  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = torch.tile(query_masks, [1, 1, list(K_.shape)[1]])  # (h*N, T_q, T_k)

        # Apply masks to outputs
        outputs = outputs * query_masks

        # Dropouts
        outputs = self.dropout(outputs)
        # Weighted sum
        outputs = torch.matmul(outputs, V_)  # ( h*N, T_q, C/h)

        # Restore shape
        outputs = torch.cat(torch.split(outputs, outputs.shape[0] // self.n_heads, dim=0), dim=2)  # (N, L_q, d)

        # input padding
        val_mask = torch.unsqueeze(input_mask, -1)
        outputs = torch.multiply(outputs, val_mask.float())

        return outputs


class PrePostProcessingWrapper(nn.Module):
  """Wrapper class that applies layer pre-processing and post-processing."""

  def __init__(self, module: nn.Module, normalized_shape: int):
    super().__init__()
    self.module = module
    self.layer_norm = LayerNorm(normalized_shape)


  def forward(self, inputs):
    """Calls wrapped layer with same parameters."""

    x, mask = inputs
    # Preprocessing: apply layer normalization
    y = self.layer_norm(x)
    # Get layer output
    try:
        y = self.module((y, mask))
    except:
        y = self.module(y)
    # Postprocessing: residual connection
    return x + y


class MaskEnc(nn.Module):
    def __init__(
            self,
            embedding_dim: int,
            n_heads: int,
            dropout: float = 0.1,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = PrePostProcessingWrapper(
            module=MultiHeadAttention(
                direction=temporal_mask_direction,
                dropout=dropout,
                n_units=embedding_dim,
                n_heads=n_heads
            ),
            normalized_shape=embedding_dim
        )

        self.fc = PrePostProcessingWrapper(
            module=nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(embedding_dim, embedding_dim)
            ),
            normalized_shape=embedding_dim
        )

        self.output_normalization = LayerNorm(embedding_dim)

    def forward(self, inputs):
        x, mask = inputs

        out = self.attention((x, mask))
        out = self.fc((out, mask))
        out = self.output_normalization(out)
        return out, mask

    def _make_temporal_mask(self, n: int) -> Optional[torch.Tensor]:
        if self.temporal_mask_direction == MaskDirection.NONE:
            return None
        if self.temporal_mask_direction == MaskDirection.FORWARD:
            return torch.tril(torch.full((n, n), -10000, device=device)).fill_diagonal_(0).float()
        if self.temporal_mask_direction == MaskDirection.BACKWARD:
            return torch.triu(torch.full((n, n), -10000, device=device)).fill_diagonal_(0).float()
        if self.temporal_mask_direction == MaskDirection.DIAGONAL:
            return torch.zeros(n, n, device=device).fill_diagonal_(-10000).float()


class BiteNet(nn.Module):
    def __init__(
            self,
            embedding_dim: int = 128,
            n_heads: int = 4,
            dropout: float = 0.1,
            n_mask_enc_layers: int = 2,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim

        self.flatten = Flatten()
        self.unflatten = Unflatten()

        def _make_mask_enc_block(temporal_mask_direction: MaskDirection = MaskDirection.NONE):
            return MaskEnc(
                embedding_dim = embedding_dim,
                n_heads = n_heads,
                dropout = dropout,
                temporal_mask_direction = temporal_mask_direction,
            )

        self.code_attn = nn.Sequential()
        self.visit_attn_fw = nn.Sequential()
        self.visit_attn_bw = nn.Sequential()
        for _ in range(n_mask_enc_layers):
            self.code_attn.append(_make_mask_enc_block(MaskDirection.DIAGONAL))
            self.visit_attn_fw.append(_make_mask_enc_block(MaskDirection.FORWARD))
            self.visit_attn_bw.append(_make_mask_enc_block(MaskDirection.BACKWARD))

        # Attention pooling layers
        self.code_attn.append(AttentionPooling(embedding_dim))
        self.visit_attn_fw.append(AttentionPooling(embedding_dim))
        self.visit_attn_bw.append(AttentionPooling(embedding_dim))

        self.fc = nn.Sequential(
            nn.Linear(2*embedding_dim, embedding_dim),
            nn.ReLU()
        )

    def forward(
            self,
            embedded_codes: torch.Tensor,
            codes_mask: torch.Tensor,
            visits_mask: torch.Tensor,
            embedded_intervals: torch.Tensor = None,
    ) -> torch.Tensor:

        # input tensor, reshape 4 dimension to 3
        flattened_codes = self.flatten(embedded_codes, 2)

        # input mask, reshape 3 dimension to 2
        flattened_codes_mask = self.flatten(codes_mask, 1)

        code_attn = self.code_attn((flattened_codes, flattened_codes_mask))
        code_attn = self.unflatten(code_attn, embedded_codes, self.embedding_dim)

        if embedded_intervals is not None:
            code_attn += embedded_intervals

        u_fw = self.visit_attn_fw((code_attn, visits_mask))
        u_bw = self.visit_attn_bw((code_attn, visits_mask))
        u_bi = torch.cat([u_fw, u_bw], dim=-1)

        s = self.fc(u_bi)
        return s

class PyHealthBiteNet(BaseModel):
    def __init__(
            self,
            dataset: SampleDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            embedding_dim: int = 128,
            n_mask_enc_layers: int = 2,
            n_heads: int = 4,
            dropout: float = 0.1,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # self.add_feature_transform_layer will create a transformation layer for each feature
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            self.add_feature_transform_layer(
                feature_key, input_info, special_tokens=["<pad>", "<unk>"]
            )

        # final output layer
        output_size = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim = embedding_dim,
            n_heads = n_heads,
            dropout = dropout,
            n_mask_enc_layers=n_mask_enc_layers,
        )

        self.fc = nn.Linear(self.embedding_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:

        embeddings = []
        masks = []
        intervals_embeddings = None
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]

            # each patient's feature is represented by [[code1, code2],[code3]]
            assert input_info["dim"] == 3 and input_info["type"] == str
            feature_vals = kwargs[feature_key]

            x = self.feat_tokenizers[feature_key].batch_encode_3d(feature_vals, truncation=(False, False))
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            pad_idx = self.feat_tokenizers[feature_key].vocabulary("<pad>")

            # Create the mask
            mask = (x != pad_idx).long()
            embeds = self.embeddings[feature_key](x)

            if feature_key == "intervals":
                intervals_embeddings = embeds
            else:
                embeddings.append(embeds)
                masks.append(mask)

        code_embeddings = torch.cat(embeddings, dim=2)
        codes_mask = torch.cat(masks, dim=2)
        visits_mask = torch.where(torch.sum(codes_mask, dim=-1) != 0, 1, 0)

        output = self.bite_net(code_embeddings, codes_mask, visits_mask, intervals_embeddings.squeeze(2))
        logits = self.fc(output)

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

Device: cuda


In [86]:
BATCH_SIZE = 32
KS = list(range(5, 31, 5))
N_TRIALS = 10

In [87]:
def train_and_inference(model, train_loader, val_loader, test_loader, lr=0.001, monitor="pr_auc", optim = torch.optim.Adam):
    trainer = Trainer(model=model, device=device)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=10,
        monitor=monitor,
        optimizer_class=optim,
        optimizer_params = {"lr" : lr},
        load_best_model_at_last=True
    )

    return trainer.inference(test_loader)

def precision_at_k(Y_true, Y_prob):

    Y_pred = (Y_prob > 0.5).astype(int)
    desc_idx = np.flip(np.argsort(Y_prob, axis=-1), axis=-1)

    Y_true = np.take(Y_true, desc_idx)
    Y_pred = np.take(Y_pred, desc_idx)

    precisions = [
        [
            precision_score(y_true_sample[:k], y_pred_sample[:k])
            for y_true_sample, y_pred_sample in zip(Y_true, Y_pred)
        ]
        for k in KS
    ]

    precisions = np.asarray(precisions)
    precisions = np.mean(precisions, axis=1)
    precisions = {
        str(k): p for k, p in zip(KS, precisions.tolist())
    }
    return precisions

In [88]:
metrics_df = pd.DataFrame(columns=['trial', 'model_name', 'feature_set', 'pr_auc', 'roc_auc', 'f1', '5', '10', '15', '20', '25', '30'])

In [None]:
# DxTx with interval embeddings

for trial in range(N_TRIALS):

    train, val, test = split_by_patient(dataset, [0.8, 0.1, 0.1])

    train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

    model = PyHealthBiteNet(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures", "intervals"],
        label_key = "readmission_label",
        mode = "binary",
        embedding_dim=128,
        n_mask_enc_layers=2
    ).to(device)

    y_true, y_prob, _ = train_and_inference(
        model,
        train_loader,
        val_loader,
        test_loader,
        lr=0.001
    )
    binary_metrics = binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc", "f1"])


    model = PyHealthBiteNet(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures", "intervals"],
        label_key = "diagnosis_label",
        mode = "multilabel",
        embedding_dim=128,
        n_mask_enc_layers=2
    ).to(device)

    y_true, y_prob, _ = train_and_inference(
        model,
        train_loader,
        val_loader,
        test_loader,
        lr=0.001,
        monitor="pr_auc_samples",
        optim=torch.optim.RMSprop
    )
    precisions = precision_at_k(y_true, y_prob)

    data = binary_metrics | precisions
    row = data | {
        "trial": trial,
        "model_name": "bitenet",
        "feature_set": "dxtx"
    }
    row = {
        k: [v] for k, v in row.items()
    }

    metrics_df = pd.concat([metrics_df, pd.DataFrame.from_dict(row)], ignore_index=True)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
    (intervals): Embedding(1756, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): PrePostProcessingWrapper(
          (module): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=False)
            (k_linear): Linear(in_features=128, out_features=128, bias=False)
            (v_linear): Linear(in_features=128, out_features=128, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer_norm): LayerNorm()
        )
        (fc): PrePostProcessingWrapper(
          (module): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.1, inp

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5226
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.44it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.1908
roc_auc: 0.4781
f1: 0.0000
loss: 0.4992
New best pr_auc score (0.1908) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.5172
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.11it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.1911
roc_auc: 0.4786
f1: 0.0000
loss: 0.4894
New best pr_auc score (0.1911) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.5147
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 56.70it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.2955
roc_auc: 0.5761
f1: 0.0000
loss: 0.4885
New best pr_auc score (0.2955) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.5071
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 55.70it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.3319
roc_auc: 0.5488
f1: 0.1918
loss: 0.4633
New best pr_auc score (0.3319) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.4969
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.59it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3521
roc_auc: 0.5442
f1: 0.2384
loss: 0.4608
New best pr_auc score (0.3521) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.4996
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.14it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3648
roc_auc: 0.6061
f1: 0.2041
loss: 0.4499
New best pr_auc score (0.3648) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.5032
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.82it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.2110
roc_auc: 0.5416
f1: 0.0000
loss: 0.4859



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.5100
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.14it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.2285
roc_auc: 0.5510
f1: 0.0000
loss: 0.4858



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.5098
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.35it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.1925
roc_auc: 0.4798
f1: 0.0000
loss: 0.4863



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.5097
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.98it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.2038
roc_auc: 0.4697
f1: 0.0000
loss: 0.4880
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.27it/s]
PyHealthBiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
    (intervals): Embedding(1756, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): PrePostProcessingWrapper(
          (module): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=False)
            (k_linear): Linear(in_features=128, out_features=128, bias=False)
            (v_linear): Linear(in_features=128, out_features=128, bias=False)
            (dropout): Dropout(p=0.1, inp

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1125
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 55.00it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.3146
loss: 0.0919
New best pr_auc_samples score (0.3146) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0889
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 54.59it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.3209
loss: 0.0861
New best pr_auc_samples score (0.3209) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0860
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 56.56it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.3275
loss: 0.0851
New best pr_auc_samples score (0.3275) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0848
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 56.41it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.3357
loss: 0.0845
New best pr_auc_samples score (0.3357) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0839
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.74it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.3477
loss: 0.0828
New best pr_auc_samples score (0.3477) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0830
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 56.85it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.3707
loss: 0.0831
New best pr_auc_samples score (0.3707) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0819
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.20it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.3864
loss: 0.0811
New best pr_auc_samples score (0.3864) at epoch-6, step-1218



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0812
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.90it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.3904
loss: 0.0811
New best pr_auc_samples score (0.3904) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0805
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 55.42it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.3948
loss: 0.0809
New best pr_auc_samples score (0.3948) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0800
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 55.00it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.3928
loss: 0.0815
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.35it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, ms

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5246
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.67it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.2277
roc_auc: 0.5261
f1: 0.0000
loss: 0.5235
New best pr_auc score (0.2277) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.5150
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 55.56it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.2262
roc_auc: 0.4935
f1: 0.0000
loss: 0.5183



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.5107
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.14it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.2288
roc_auc: 0.4913
f1: 0.0000
loss: 0.5178
New best pr_auc score (0.2288) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.5107
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.83it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.2381
roc_auc: 0.5174
f1: 0.0000
loss: 0.5185
New best pr_auc score (0.2381) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.5024
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.95it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.2274
roc_auc: 0.4998
f1: 0.0000
loss: 0.5261



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.4998
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.62it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.2416
roc_auc: 0.5224
f1: 0.0263
loss: 0.5243
New best pr_auc score (0.2416) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.4898
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.11it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.2608
roc_auc: 0.5376
f1: 0.0506
loss: 0.5357
New best pr_auc score (0.2608) at epoch-6, step-1218



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

In [79]:
print(metrics_df)

  trial model_name feature_set    pr_auc   roc_auc        f1    5   10   15   
0     0    bitenet        dxtx  0.395678  0.614033  0.294118  0.0  0.0  0.0  \

    20   25   30  
0  0.0  0.0  0.0  


In [19]:
model.eval()
data = next(iter(train_loader))
print(model(**data))

{'loss': tensor(0.5145, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'y_prob': tensor([[0.1884],
        [0.1629],
        [0.1886],
        [0.1886],
        [0.1884],
        [0.1884],
        [0.1885],
        [0.1884],
        [0.1648],
        [0.1885],
        [0.1756],
        [0.1885],
        [0.1688],
        [0.1632],
        [0.1885],
        [0.1885],
        [0.1884],
        [0.1885],
        [0.9666],
        [0.0794],
        [0.1711],
        [0.1885],
        [0.1885],
        [0.1885],
        [0.1075],
        [0.1883],
        [0.1886],
        [0.1884],
        [0.1689],
        [0.1884],
        [0.2081],
        [0.2081]], device='cuda:0', grad_fn=<SigmoidBackward0>), 'y_true': tensor([[1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
       

In [147]:
# DxTx without interval embeddings

model = PyHealthBiteNet(
    dataset = dataset,
    feature_keys = ["diagnoses", "procedures"],
    label_key = "readmission_label",
    mode = "binary",
    n_interval_embeddings=max_patient_span_days,
    embedding_dim=128,
    n_mask_enc_layers=2
).to(device)

train_and_evaluate_readm(
    model,
    train_loader,
    val_loader,
    test_loader,
    lr=0.0001
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): PrePostProcessingWrapper(
          (module): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=False)
            (k_linear): Linear(in_features=128, out_features=128, bias=False)
            (v_linear): Linear(in_features=128, out_features=128, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer_norm): LayerNorm()
        )
        (fc): PrePostProcessingWrapper(
          (module): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=128, 

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5248
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 62.32it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.2094
roc_auc: 0.4681
f1: 0.0000
loss: 0.5548
New best pr_auc score (0.2094) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.5188
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.97it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.2180
roc_auc: 0.4851
f1: 0.0000
loss: 0.5712
New best pr_auc score (0.2180) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.5201
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.90it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.2476
roc_auc: 0.5423
f1: 0.0000
loss: 0.5551
New best pr_auc score (0.2476) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.5181
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.94it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.2368
roc_auc: 0.4975
f1: 0.0000
loss: 0.5688



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.5117
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 60.11it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.2617
roc_auc: 0.5465
f1: 0.0000
loss: 0.5415
New best pr_auc score (0.2617) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.5022
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.97it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3836
roc_auc: 0.5864
f1: 0.2404
loss: 0.5308
New best pr_auc score (0.3836) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.4880
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.28it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3905
roc_auc: 0.5981
f1: 0.2444
loss: 0.5069
New best pr_auc score (0.3905) at epoch-6, step-1218



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.4819
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 62.86it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.4057
roc_auc: 0.6100
f1: 0.2637
loss: 0.5065
New best pr_auc score (0.4057) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.4715
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.80it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.4033
roc_auc: 0.6140
f1: 0.2500
loss: 0.5142



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.4670
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 61.28it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.4023
roc_auc: 0.6072
f1: 0.2618
loss: 0.5057
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 57.89it/s]

{'pr_auc': 0.3521874464960856, 'roc_auc': 0.5945835982199618, 'f1': 0.19653179190751446, 'loss': 0.49657205966385926}





In [145]:
model.train()
data = next(iter(train_loader))
print(model(**data))

{'loss': tensor(0.6983, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'y_prob': tensor([[0.5025],
        [0.5035],
        [0.5022],
        [0.5026],
        [0.5046],
        [0.5077],
        [0.5033],
        [0.5042],
        [0.5040],
        [0.5024],
        [0.5042],
        [0.5033],
        [0.5024],
        [0.5034],
        [0.5037],
        [0.5047],
        [0.5036],
        [0.5035],
        [0.5042],
        [0.5053],
        [0.5039],
        [0.5037],
        [0.5060],
        [0.5030],
        [0.5043],
        [0.5043],
        [0.5046],
        [0.5029],
        [0.5031],
        [0.5029],
        [0.5070],
        [0.5070]], device='cuda:0', grad_fn=<SigmoidBackward0>), 'y_true': tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
       

In [129]:
train_and_evaluate_readm(
    PyHealthBiteNet(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures", "intervals"],
        label_key = "diagnosis_label",
        mode = "multilabel",
        n_interval_embeddings=max_patient_span_days,
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.0001,
    monitor="pr_auc_samples"
)

PyHealthBiteNet(
  (interval_embedding): Embedding(4221, 128)
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict(
    (intervals): Linear(in_features=1, out_features=128, bias=True)
  )
  (bite_net): BiteNet(
    (flatten): Flatten()
    (unflatten): Unflatten()
    (code_attn): Sequential(
      (0): MaskEnc(
        (attention): MultiHeadAttention(
          (q_linear): Linear(in_features=128, out_features=128, bias=False)
          (k_linear): Linear(in_features=128, out_features=128, bias=False)
          (v_linear): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=128, out_features=128, bias=True)
        )
 

Using visit interval embeddings


Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1735
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.82it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.3058
loss: 0.1235
New best pr_auc_samples score (0.3058) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.1225
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.67it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.3082
loss: 0.1216
New best pr_auc_samples score (0.3082) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.1202
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.36it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.3119
loss: 0.1198
New best pr_auc_samples score (0.3119) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.1188
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.82it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.3177
loss: 0.1179
New best pr_auc_samples score (0.3177) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.1165
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.51it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.3201
loss: 0.1154
New best pr_auc_samples score (0.3201) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.1143
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.95it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.3244
loss: 0.1127
New best pr_auc_samples score (0.3244) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.1115
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.46it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.3263
loss: 0.1102
New best pr_auc_samples score (0.3263) at epoch-6, step-1218



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.1088
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 59.95it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.3295
loss: 0.1075
New best pr_auc_samples score (0.3295) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.1060
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.67it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.3554
loss: 0.1046
New best pr_auc_samples score (0.3554) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.1030
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 58.05it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.3580
loss: 0.1018
New best pr_auc_samples score (0.3580) at epoch-9, step-1740
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 54.32it/s]


{'pr_auc_samples': 0.3462034725410505, 'loss': 0.10405040837147018}


In [130]:
train_and_evaluate_readm(
    RNN(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "readmission_label",
        mode = "binary",
        embedding_dim=128,
        dropout=0.1
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001
)

RNN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (diagnoses): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x00000205CDD7ECD0>
Monitor: pr_auc
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.4974
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 149.66it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.4121
roc_auc: 0.6199
f1: 0.2500
loss: 0.5142
New best pr_auc score (0.4121) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.4109
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 158.27it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.4283
roc_auc: 0.6278
f1: 0.2500
loss: 0.5245
New best pr_auc score (0.4283) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.3374
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 158.27it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.4144
roc_auc: 0.6105
f1: 0.2609
loss: 0.5536



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.2528
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 154.93it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.4217
roc_auc: 0.6113
f1: 0.2586
loss: 0.6107



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.1814
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 151.73it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3882
roc_auc: 0.5719
f1: 0.2661
loss: 0.7234



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.1266
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 153.85it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3974
roc_auc: 0.5835
f1: 0.3125
loss: 0.7652



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0886
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 151.73it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3971
roc_auc: 0.5769
f1: 0.2881
loss: 0.8607



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0617
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 150.69it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.4137
roc_auc: 0.5971
f1: 0.3083
loss: 0.8973



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0426
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 150.69it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.4145
roc_auc: 0.5898
f1: 0.3154
loss: 1.0103



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0324
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 150.68it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.4106
roc_auc: 0.5906
f1: 0.3033
loss: 1.1282
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 169.23it/s]

{'pr_auc': 0.4088044614253502, 'roc_auc': 0.6401017164653529, 'f1': 0.2573099415204678, 'loss': 0.4801513755863363}





In [131]:
train_and_evaluate_readm(
    RNN(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "diagnosis_label",
        mode = "multilabel",
        embedding_dim=128,
        dropout=0.1
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001,
    monitor="pr_auc_samples"
)

RNN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (diagnoses): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=465, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x00000205CDD7ECD0>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1068
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 165.42it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.3980
loss: 0.0834
New best pr_auc_samples score (0.3980) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0816
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 162.97it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.4224
loss: 0.0797
New best pr_auc_samples score (0.4224) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0780
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 134.15it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.4421
loss: 0.0778
New best pr_auc_samples score (0.4421) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0752
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 130.18it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.4508
loss: 0.0767
New best pr_auc_samples score (0.4508) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0730
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 148.65it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.4543
loss: 0.0768
New best pr_auc_samples score (0.4543) at epoch-4, step-870



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0708
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 134.97it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.4601
loss: 0.0762
New best pr_auc_samples score (0.4601) at epoch-5, step-1044



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0690
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 135.80it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.4616
loss: 0.0762
New best pr_auc_samples score (0.4616) at epoch-6, step-1218



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0673
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 141.03it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.4633
loss: 0.0764
New best pr_auc_samples score (0.4633) at epoch-7, step-1392



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0657
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 136.65it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.4635
loss: 0.0766
New best pr_auc_samples score (0.4635) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0642
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 131.74it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.4627
loss: 0.0771
Loaded best model
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 144.74it/s]


{'pr_auc_samples': 0.4410093371163336, 'loss': 0.08005004884167151}


In [27]:
train_and_evaluate_readm(
    RETAIN(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "readmission_label",
        mode = "binary",
        embedding_dim=128,
        dropout=0.1
    ).to(device),
    train_loader,
    val_loader,
    test_loader
)

RETAIN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (diagnoses): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 't

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5355
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 84.29it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.3553
roc_auc: 0.5890
f1: 0.1775
loss: 0.5163
New best pr_auc score (0.3553) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.3683
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.88it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.3871
roc_auc: 0.6180
f1: 0.2737
loss: 0.5194
New best pr_auc score (0.3871) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.2285
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 77.19it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.3715
roc_auc: 0.6024
f1: 0.2526
loss: 0.6153



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.1225
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.01it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.3835
roc_auc: 0.5939
f1: 0.3097
loss: 0.6973



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0688
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 79.42it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3848
roc_auc: 0.6013
f1: 0.3117
loss: 0.8054



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0419
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.39it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3802
roc_auc: 0.5960
f1: 0.3084
loss: 0.9247



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0313
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.29it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3778
roc_auc: 0.6086
f1: 0.3070
loss: 0.9865



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0186
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 79.71it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.3776
roc_auc: 0.5983
f1: 0.2857
loss: 1.1027



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0181
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 73.83it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.3936
roc_auc: 0.6030
f1: 0.3004
loss: 1.1505
New best pr_auc score (0.3936) at epoch-8, step-1566



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0158
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.29it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.3926
roc_auc: 0.6023
f1: 0.3064
loss: 1.2158
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 82.09it/s]

{'pr_auc': 0.3606033694496517, 'roc_auc': 0.5797932330827067, 'f1': 0.29729729729729726, 'loss': 1.2300627353516491}





In [46]:
train_and_evaluate_readm(
    RETAIN(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "diagnosis_label",
        mode = "multilabel",
        embedding_dim=128,
        dropout=0.1
    ).to(device),
    train_loader,
    val_loader,
    test_loader,
    lr=0.001,
    monitor="pr_auc_samples"
)

RETAIN(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3374, 128, padding_idx=0)
    (procedures): Embedding(1362, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (diagnoses): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=465, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 

Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.1180
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 81.18it/s]
--- Eval epoch-0, step-174 ---
pr_auc_samples: 0.4010
loss: 0.0930
New best pr_auc_samples score (0.4010) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.0811
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 75.34it/s]
--- Eval epoch-1, step-348 ---
pr_auc_samples: 0.4227
loss: 0.0878
New best pr_auc_samples score (0.4227) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.0747
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 70.29it/s]
--- Eval epoch-2, step-522 ---
pr_auc_samples: 0.4304
loss: 0.0862
New best pr_auc_samples score (0.4304) at epoch-2, step-522



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.0707
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.57it/s]
--- Eval epoch-3, step-696 ---
pr_auc_samples: 0.4357
loss: 0.0863
New best pr_auc_samples score (0.4357) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0674
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 78.02it/s]
--- Eval epoch-4, step-870 ---
pr_auc_samples: 0.4313
loss: 0.0869



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0648
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 77.74it/s]
--- Eval epoch-5, step-1044 ---
pr_auc_samples: 0.4356
loss: 0.0875



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0626
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 67.69it/s]
--- Eval epoch-6, step-1218 ---
pr_auc_samples: 0.4325
loss: 0.0880



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0604
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.39it/s]
--- Eval epoch-7, step-1392 ---
pr_auc_samples: 0.4332
loss: 0.0890



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0586
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 76.66it/s]
--- Eval epoch-8, step-1566 ---
pr_auc_samples: 0.4300
loss: 0.0904



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0569
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 80.88it/s]
--- Eval epoch-9, step-1740 ---
pr_auc_samples: 0.4272
loss: 0.0914
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 68.54it/s]


{'pr_auc_samples': 0.4249831201075734, 'loss': 0.09086060625585643}


In [29]:
train_and_evaluate_readm(
    Deepr(
        dataset = dataset,
        feature_keys = ["diagnoses", "procedures"],
        label_key = "readmission_label",
        mode = "binary",
        embedding_dim=128
    ).to(device),
    train_loader,
    val_loader,
    test_loader
)

Deepr(
  (embeddings): ModuleDict(
    (diagnoses): Embedding(3375, 128, padding_idx=0)
    (procedures): Embedding(1363, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (cnn): ModuleDict(
    (diagnoses): DeeprLayer(
      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
    )
    (procedures): DeeprLayer(
      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
    )
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.rmsprop.RMSprop'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x00000211A794E700>
Monitor: pr_auc
Monitor criterion: max
Epochs: 10



Epoch 0 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-0, step-174 ---
loss: 0.5101
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 183.33it/s]
--- Eval epoch-0, step-174 ---
pr_auc: 0.4078
roc_auc: 0.6345
f1: 0.1863
loss: 0.4984
New best pr_auc score (0.4078) at epoch-0, step-174



Epoch 1 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-1, step-348 ---
loss: 0.3764
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 268.29it/s]
--- Eval epoch-1, step-348 ---
pr_auc: 0.4154
roc_auc: 0.6332
f1: 0.2755
loss: 0.4897
New best pr_auc score (0.4154) at epoch-1, step-348



Epoch 2 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-2, step-522 ---
loss: 0.2655
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 258.82it/s]
--- Eval epoch-2, step-522 ---
pr_auc: 0.3905
roc_auc: 0.6049
f1: 0.3258
loss: 0.5340



Epoch 3 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-3, step-696 ---
loss: 0.1544
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 249.96it/s]
--- Eval epoch-3, step-696 ---
pr_auc: 0.4250
roc_auc: 0.6310
f1: 0.3577
loss: 0.5696
New best pr_auc score (0.4250) at epoch-3, step-696



Epoch 4 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-4, step-870 ---
loss: 0.0791
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 244.44it/s]
--- Eval epoch-4, step-870 ---
pr_auc: 0.3979
roc_auc: 0.6118
f1: 0.2922
loss: 0.6220



Epoch 5 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-5, step-1044 ---
loss: 0.0389
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 261.87it/s]
--- Eval epoch-5, step-1044 ---
pr_auc: 0.3990
roc_auc: 0.6011
f1: 0.2842
loss: 0.7871



Epoch 6 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-6, step-1218 ---
loss: 0.0175
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.79it/s]
--- Eval epoch-6, step-1218 ---
pr_auc: 0.3948
roc_auc: 0.6088
f1: 0.2737
loss: 0.8633



Epoch 7 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-7, step-1392 ---
loss: 0.0168
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.82it/s]
--- Eval epoch-7, step-1392 ---
pr_auc: 0.4105
roc_auc: 0.6164
f1: 0.2932
loss: 0.9251



Epoch 8 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-8, step-1566 ---
loss: 0.0055
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 255.81it/s]
--- Eval epoch-8, step-1566 ---
pr_auc: 0.3957
roc_auc: 0.6097
f1: 0.3136
loss: 0.8593



Epoch 9 / 10:   0%|          | 0/174 [00:00<?, ?it/s]

--- Train epoch-9, step-1740 ---
loss: 0.0027
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 252.87it/s]
--- Eval epoch-9, step-1740 ---
pr_auc: 0.4183
roc_auc: 0.6189
f1: 0.3349
loss: 0.9303
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 224.49it/s]

{'pr_auc': 0.39479580130583247, 'roc_auc': 0.6421455424274973, 'f1': 0.303030303030303, 'loss': 0.8376922133294019}



