Google Cloud authentication and data download

In [1]:
# ! gcloud auth login
# ! gcloud auth application-default login
# ! gcloud config set project dl4h-final-project-383605

In [3]:
# ! pip install --upgrade google-api-python-client google-cloud-storage
from google.cloud import storage

import os

# Replace these values with your project and bucket as needed
project_id = "dl4h-final-project-383605"
mimic3_bucket = "mimiciii-1.4.physionet.org"

storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(mimic3_bucket)
data_folder = "./data"

# download_to_filename() writes to an existing path; it does not create missing
# directories, so make sure the destination folder is there before downloading.
os.makedirs(data_folder, exist_ok=True)

for blob in bucket.list_blobs():
  # CHARTEVENTS is deliberately not downloaded
  # (presumably because of its size -- TODO confirm).
  if "CHARTEVENTS" in blob.name:
    continue
  # Object names ending in "/" are folder placeholders, not files.
  if blob.name.endswith("/"):
    continue
  destination = os.path.join(data_folder, blob.name)
  # Object names may contain sub-paths; create the parent directory on demand.
  os.makedirs(os.path.dirname(destination), exist_ok=True)
  blob.download_to_filename(destination)

# Extract all the files
! gunzip {data_folder}/*.gz

gzip: *.gz: No such file or directory


In [5]:
! gunzip {data_folder}/*.gz

Library imports and data preparation

In [59]:
# ! pip install pyhealth
from pyhealth.datasets import MIMIC3Dataset, SampleDataset
from pyhealth.data import Patient, Visit
import pandas as pd
from pyhealth.datasets import split_by_patient, get_dataloader, BaseDataset
from pyhealth.models import BaseModel
from pyhealth.trainer import Trainer
from pyhealth.metrics.binary import binary_metrics_fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import List, Dict, Optional
from enum import Enum
from functools import reduce
from operator import mul

# Set this to the directory with all MIMIC-3 dataset files
data_root = "./data"

In [109]:
# Load the dataset
#
# Builds the PyHealth MIMIC-III dataset from the files under `data_root`,
# parsing only the diagnosis and procedure ICD tables.
# NOTE(review): dev=False appears to load the full dataset (46520 patients in
# the progress output below) rather than a reduced development subset -- confirm.

mimic3_ds = MIMIC3Dataset(
        root=data_root,
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
        dev=False
)

Parsing PATIENTS and ADMISSIONS: 100%|██████████| 46520/46520 [00:19<00:00, 2423.10it/s]
Parsing DIAGNOSES_ICD: 100%|██████████| 58929/58929 [00:03<00:00, 17375.21it/s]
Parsing PROCEDURES_ICD: 100%|██████████| 52243/52243 [00:01<00:00, 27342.92it/s]
Mapping codes: 100%|██████████| 46520/46520 [00:00<00:00, 167922.90it/s]


In [61]:
# Print dataset statistics
# stat() prints the summary and also returns it as a string, which is why the
# same text shows up a second time as the cell's output value.

mimic3_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: MIMIC3Dataset
	- Number of patients: 1000
	- Number of visits: 1295
	- Number of visits per patient: 1.2950
	- Number of events per visit in DIAGNOSES_ICD: 9.3544
	- Number of events per visit in PROCEDURES_ICD: 4.3351



'\nStatistics of base dataset (dev=True):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 1000\n\t- Number of visits: 1295\n\t- Number of visits per patient: 1.2950\n\t- Number of events per visit in DIAGNOSES_ICD: 9.3544\n\t- Number of events per visit in PROCEDURES_ICD: 4.3351\n'

In [110]:
# Find all diagnoses codes
# Find all procedure codes
# Remove diagnoses codes with fewer than MIN_DIAG_CODE_COUNT occurrences in the dataset

# Minimum number of occurrences a diagnosis code needs in order to be kept.
MIN_DIAG_CODE_COUNT = 5

all_diag_codes = []
all_proc_codes = []
for patient_id, patient in mimic3_ds.patients.items():
    for i in range(len(patient)):
        visit: Visit = patient[i]
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        all_diag_codes.extend(conditions)
        all_proc_codes.extend(procedures)

codes = pd.Series(all_diag_codes)
diag_code_counts = codes.value_counts()
filtered_diag_codes = diag_code_counts[diag_code_counts >= MIN_DIAG_CODE_COUNT].index.values
num_unique_diag_codes = len(filtered_diag_codes)

# sorted() instead of list(): iteration order of a set of strings depends on the
# per-process hash seed, which would give every procedure code a different
# index after each kernel restart (breaking reproducibility and checkpoints).
unique_proc_codes = sorted(set(all_proc_codes))
num_unique_proc_codes = len(unique_proc_codes)

In [111]:
# Give every procedure code and every retained diagnosis code a contiguous
# integer index equal to its position in the sequence it comes from.
proc_code_to_index_map = {
    code: position for position, code in enumerate(unique_proc_codes)
}
diag_code_to_index_map = {
    code: position for position, code in enumerate(filtered_diag_codes)
}

In [112]:
# Define the tasks

DIAGNOSES_KEY = "conditions"
PROCEDURES_KEY = "procedures"
INTERVAL_DAYS_KEY = "days_since_first_visit"

def flatten(l: List):
    """Concatenate the items of every element of ``l`` into one flat list.

    Only one level of nesting is removed.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def patient_level_readmission_prediction(patient, time_window=30):
    """Build at most one patient-level readmission sample.

    The label is derived from the gap between the patient's two most recent
    visits; the features are taken from every visit except the most recent one.

    Args:
        patient: a <pyhealth.data.Patient> object.
        time_window: the label is 1 when the last visit starts fewer than this
            many days after the second-to-last visit, otherwise 0.

    Returns:
        An empty list when the patient has fewer than two visits or when no
        earlier visit has both a retained diagnosis code and a procedure code.
        Otherwise a one-element list holding a dict with the keys
        ``patient_id``, ``visit_id`` (id of the most recent visit),
        ``conditions`` / ``procedures`` (one code list per kept visit),
        ``intervals`` (one ``[days_since_first_visit]`` list per kept visit)
        and ``label``.
    """
    # The label needs two visits; this also covers a patient without any visit,
    # which previously raised an IndexError.
    if len(patient) < 2:
        return []

    sorted_visits = sorted(patient, key=lambda visit: visit.encounter_time)
    first_visit = sorted_visits[0]
    second_to_last_visit = sorted_visits[-2]
    last_visit = sorted_visits[-1]

    # step 1: define label
    time_diff = (last_visit.encounter_time - second_to_last_visit.encounter_time).days
    readmission_label = 1 if time_diff < time_window else 0

    # step 2: obtain features from every visit before the last one
    visits_conditions = []
    visits_procedures = []
    visits_intervals = []
    for visit in sorted_visits[:-1]:
        # Only diagnosis codes present in the module-level filtered_diag_codes are kept.
        conditions = [c for c in visit.get_code_list(table="DIAGNOSES_ICD") if c in filtered_diag_codes]
        procedures = visit.get_code_list(table="PROCEDURES_ICD")

        # A visit is only usable when it has both kinds of codes.
        if len(conditions) == 0 or len(procedures) == 0:
            continue

        time_diff_from_first_visit = (visit.encounter_time - first_visit.encounter_time).days
        visits_conditions.append(conditions)
        visits_procedures.append(procedures)
        visits_intervals.append([time_diff_from_first_visit])

    # step 3: exclusion criteria
    # Every kept visit has non-empty conditions and procedures, so "nothing
    # kept" is the only way to end up without codes.
    if not visits_conditions:
        return []

    # step 4: assemble the sample
    return [
        {
            "patient_id": patient.patient_id,
            # Explicitly the most recent visit (the original read this from the
            # loop variable left over after the feature loop).
            "visit_id": last_visit.visit_id,
            "conditions": visits_conditions,
            "procedures": visits_procedures,
            "intervals": visits_intervals,
            "label": readmission_label,
        }
    ]

In [113]:
# Create the task datasets
# Applies patient_level_readmission_prediction to the base dataset; the cells
# below iterate the result as a collection of sample dicts.
mimic3_dxtx = mimic3_ds.set_task(task_fn=patient_level_readmission_prediction)

Generating samples for patient_level_readmission_prediction: 100%|██████████| 46520/46520 [00:08<00:00, 5227.98it/s]


In [114]:
# Get the unique visit intervals
# Every sample stores one single-element list of day offsets per visit; gather
# all offsets, de-duplicate them and give each distinct value an integer index.
all_intervals = [
    day_offset
    for sample in mimic3_dxtx
    for visit_interval in sample['intervals']
    for day_offset in visit_interval
]

unique_intervals = list(set(all_intervals))
num_unique_intervals = len(unique_intervals)

visit_interval_to_index_map = {
    visit_interval: position
    for position, visit_interval in enumerate(unique_intervals)
}

In [67]:
# def collate_codes(batch_codes, i_patient, patient, feature_key, code_to_index_map, mask_tensor):
#     for i_visit, visit_codes in enumerate(patient[feature_key]):
#         for i_code, code in enumerate(visit_codes):
#             batch_codes[i_patient][i_visit][i_code] = code_to_index_map[code]
#
#         # Set the mask for the visit codes
#         num_codes = len(visit_codes)
#         mask_tensor[i_patient][i_visit][:num_codes] = 1
#
# def collate_fn(batch):
#     batch_size = len(batch)
#
#     max_num_visits = 0
#     max_num_diagnosis_codes_per_visit = 0
#     max_num_procedure_codes_per_visit = 0
#
#     for patient in batch:
#         patient_num_visits = len(patient['conditions'])
#         if patient_num_visits > max_num_visits:
#             max_num_visits = patient_num_visits
#         for visit_conditions in patient['conditions']:
#             num_visit_diagnosis_codes = len(visit_conditions)
#             if num_visit_diagnosis_codes > max_num_diagnosis_codes_per_visit:
#                 max_num_diagnosis_codes_per_visit = num_visit_diagnosis_codes
#         for visit_procedures in patient["procedures"]:
#             num_visit_procedure_codes = len(visit_procedures)
#             if num_visit_procedure_codes > max_num_procedure_codes_per_visit:
#                 max_num_procedure_codes_per_visit = num_visit_procedure_codes
#
#     batch_procedures = torch.zeros(batch_size, max_num_visits, max_num_procedure_codes_per_visit).long()
#     batch_conditions = torch.zeros(batch_size, max_num_visits, max_num_diagnosis_codes_per_visit).long()
#     batch_intervals = torch.zeros(batch_size, max_num_visits).long()
#     batch_labels = torch.zeros(batch_size).long()
#
#     batch_visits_masks = torch.zeros(batch_size, max_num_visits).long()
#     batch_procedures_masks = torch.zeros(batch_size, max_num_visits, max_num_procedure_codes_per_visit).long()
#     batch_conditions_masks = torch.zeros(batch_size, max_num_visits, max_num_diagnosis_codes_per_visit).long()
#
#     for i_patient, patient in enumerate(batch):
#         # Collate diagnosis and procedure codes
#         collate_codes(batch_procedures, i_patient, patient, "procedures", proc_code_to_index_map, batch_procedures_masks)
#         collate_codes(batch_conditions, i_patient, patient, "conditions", diag_code_to_index_map, batch_conditions_masks)
#
#         # Get the number of visits this patient has
#         visit_intervals = flatten(patient["intervals"])
#         num_visits = len(visit_intervals)
#
#         # Set the visits mask for the patient
#         batch_visits_masks[i_patient][:num_visits] = 1
#
#         # Collate the visit intervals by setting the onehot encoding
#         batch_intervals[i_patient][:num_visits] = 1
#
#         # Set the label
#         batch_labels[i_patient] = patient["label"]
#
#     return (
#         batch_procedures,
#         batch_conditions,
#         batch_intervals,
#         batch_visits_masks,
#         batch_procedures_masks,
#         batch_conditions_masks,
#         batch_labels
#     )
#
# BATCH_SIZE = 2
# train, val, test = split_by_patient(mimic3_dxtx, [0.8, 0.1, 0.1])
#
# # obtain train/val/test dataloader, they are <torch.data.DataLoader> object
# train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
# test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
#
# _ = next(iter(train_loader))
# print(_[0])

tensor([[[427,  16, 431, 217, 258]],

        [[274,   0,   0,   0,   0]]])


In [172]:
def collate_codes(batch_codes, i_patient, visits, code_to_index_map, mask_tensor: Optional[torch.Tensor] = None):
    """Write one patient's codes into a pre-allocated index tensor, in place.

    Args:
        batch_codes: tensor indexed as [patient, visit, code]; receives the
            mapped index of every code.
        i_patient: position of the patient along the first dimension.
        visits: one sequence of codes per visit of this patient.
        code_to_index_map: mapping from code to integer index; a code missing
            from it raises KeyError.
        mask_tensor: optional tensor indexed like ``batch_codes``; for every
            visit its first ``len(codes)`` entries are set to 1.

    Returns:
        None. ``batch_codes`` and ``mask_tensor`` are modified in place.
    """
    has_mask = mask_tensor is not None
    for visit_pos, codes_in_visit in enumerate(visits):
        for code_pos, code in enumerate(codes_in_visit):
            batch_codes[i_patient, visit_pos, code_pos] = code_to_index_map[code]

        if has_mask:
            # Mark the slots of this visit that hold real codes
            mask_tensor[i_patient, visit_pos, :len(codes_in_visit)] = 1

def collate_fn(conditions, procedures, intervals):
    """Pad a batch of nested code lists into index tensors plus 0/1 masks.

    Args:
        conditions: per patient, a list of visits, each a list of diagnosis codes.
        procedures: per patient, a list of visits, each a list of procedure codes.
        intervals: per patient, a list of visits, each a one-element list
            holding the visit's day offset.

    Returns:
        A tuple ``(batch_procedures, batch_conditions, batch_intervals,
        batch_visits_masks, batch_procedures_masks, batch_conditions_masks)``
        of long tensors. Codes and intervals are translated through the
        module-level ``*_to_index_map`` dictionaries; padded slots stay 0 and
        the masks hold 1 wherever a real value was written.
    """
    batch_size = len(conditions)
    patient_range = range(batch_size)

    # Padded sizes: the visit count is taken from the conditions, the code
    # counts from the longest visit of the respective kind.
    max_num_visits = max(
        (len(conditions[p]) for p in patient_range), default=0
    )
    max_num_diagnosis_codes_per_visit = max(
        (len(visit_codes) for p in patient_range for visit_codes in conditions[p]),
        default=0,
    )
    max_num_procedure_codes_per_visit = max(
        (len(visit_codes) for p in patient_range for visit_codes in procedures[p]),
        default=0,
    )

    def _long_zeros(*shape):
        return torch.zeros(*shape).long()

    procedures_shape = (batch_size, max_num_visits, max_num_procedure_codes_per_visit)
    conditions_shape = (batch_size, max_num_visits, max_num_diagnosis_codes_per_visit)
    visits_shape = (batch_size, max_num_visits)

    batch_procedures = _long_zeros(*procedures_shape)
    batch_conditions = _long_zeros(*conditions_shape)
    batch_intervals = _long_zeros(*visits_shape)

    batch_visits_masks = _long_zeros(*visits_shape)
    batch_procedures_masks = _long_zeros(*procedures_shape)
    batch_conditions_masks = _long_zeros(*conditions_shape)

    for p in patient_range:
        # Diagnosis and procedure codes, together with their masks
        collate_codes(batch_procedures, p, procedures[p], proc_code_to_index_map, batch_procedures_masks)
        collate_codes(batch_conditions, p, conditions[p], diag_code_to_index_map, batch_conditions_masks)

        # Each visit contributes one interval; it also marks the visit as real
        for visit_pos, wrapped_interval in enumerate(intervals[p]):
            day_offset = wrapped_interval[0]
            batch_intervals[p][visit_pos] = visit_interval_to_index_map[day_offset]
            batch_visits_masks[p][visit_pos] = 1

    return (
        batch_procedures,
        batch_conditions,
        batch_intervals,
        batch_visits_masks,
        batch_procedures_masks,
        batch_conditions_masks
    )

BATCH_SIZE = 32
# Fixed seed so that the patient-level train/val/test partition is identical on
# every run; without it the reported metrics are not reproducible.
SPLIT_SEED = 42
train, val, test = split_by_patient(mimic3_dxtx, [0.8, 0.1, 0.1], seed=SPLIT_SEED)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = get_dataloader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test, batch_size=BATCH_SIZE, shuffle=False)

In [187]:
# Define the models
# Large / tiny magnitude sentinels. In the code visible in this notebook the
# only reference to them is a commented-out line in AttentionPooling.forward.
VERY_BIG_NUMBER = 1e30
VERY_SMALL_NUMBER = 1e-30
VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER

class MaskDirection(Enum):
    """Selects which (n, n) attention mask ``MaskEnc`` builds for a sequence."""
    # lower triangle (including the diagonal) filled with ones
    FORWARD = 'forward'
    # upper triangle (including the diagonal) filled with ones
    BACKWARD = 'backward'
    # ones on the diagonal only
    DIAGONAL = 'diagonal'
    # no mask is passed to the attention layer
    NONE = 'none'

# class MaskedLayerNorm(nn.Module):
#     def __init__(self, normalized_shape: int):
#         super().__init__()
#         self.gamma = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.beta = nn.parameter.Parameter(torch.randn(normalized_shape))
#         self.eps = 1e-5
#         self.normalized_shape = normalized_shape
#
#     def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
#         print(x.shape)
#         print(key_padding_mask.shape)
#         n = torch.sum(torch.ones_like(x) * (~key_padding_mask).int().unsqueeze(-1).expand(-1, -1, self.normalized_shape), dim=-1)
#         print(n.shape)
#         print(n)
#
#         # expected_val = torch.nanmean(x, dim=1)
#         # variance = torch.nansum(())

def flatten_keep(x: torch.Tensor, keep: int):
    """Collapse all leading dimensions of ``x`` into one, keeping the last ``keep``.

    Args:
        x: tensor to reshape.
        keep: number of trailing dimensions to leave untouched
            (``0 <= keep <= x.dim()``).

    Returns:
        A tensor of shape ``[prod(leading dims), *trailing dims]``. When
        ``keep == x.dim()`` there are no leading dimensions and a leading
        dimension of size 1 is added.

    Raises:
        ValueError: if ``keep`` is negative or larger than ``x.dim()``.
    """
    shape = list(x.shape)
    if not 0 <= keep <= len(shape):
        raise ValueError(
            f"keep must be between 0 and {len(shape)} for a tensor of shape {shape}, got {keep}"
        )

    split = len(shape) - keep
    # The initial value 1 makes the product well defined when no leading
    # dimension is left to collapse (reduce() on an empty list raises otherwise).
    leading = reduce(mul, shape[:split], 1)
    return torch.reshape(x, [leading] + shape[split:])


def reshape(v, ref, embedding_dim):
    """Reshape ``v`` to ``[ref.shape[0], ref.shape[1], embedding_dim]``.

    ``ref`` is only consulted for the sizes of its first two dimensions.
    """
    batch_size, n_visits = ref.shape[0], ref.shape[1]
    return v.reshape(batch_size, n_visits, embedding_dim)

class Embedding(nn.Module):
    """Trainable lookup table whose output is masked and scaled.

    The table has shape ``(vocab_size, embedding_dim)`` and is initialised from
    a normal distribution with mean 0 and standard deviation
    ``embedding_dim ** -0.5``.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        table_shape = (vocab_size, embedding_dim)
        initial_table = torch.normal(
            mean=torch.full(size=table_shape, fill_value=0.0),
            std=torch.full(size=table_shape, fill_value=embedding_dim ** -0.5),
        )
        self.weights = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor, mask):
        """Look up ``x`` and return embeddings of shape ``x.shape + [embedding_dim]``.

        Args:
            x: integer index tensor of any shape.
            mask: tensor shaped like ``x``; every embedding vector is
                multiplied by its mask entry, so entries of 0 zero it out.
        """
        flat_indices = x.reshape(-1)
        target_shape = list(x.shape) + [self.embedding_dim]
        embeddings = self.weights[flat_indices].reshape(target_shape)

        # Zero out (or weight) each vector by its mask entry
        embeddings.mul_(torch.unsqueeze(mask, -1))

        # Scale embedding by the sqrt of the hidden size
        embeddings.mul_(self.embedding_dim ** 0.5)
        return embeddings


class AttentionPooling(nn.Module):
    """Collapse dimension 1 of the input with learned, per-feature softmax weights.

    A two-layer MLP scores every position; the scores are softmax-normalised
    along dimension 1 and used to form a weighted sum of the input.
    """

    def __init__(self, embedding_size: int):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, embedding_size)
        )

    def forward(self, input):
        """Return the weighted sum of ``input`` over dimension 1.

        NaN products are treated as zero by ``nansum``, so a sequence made up
        entirely of NaNs pools to zeros.
        """
        scores = self.fc(input)
        # Todo: fixme -- padded positions are not masked out of this softmax
        weights = scores.softmax(dim=1)
        weighted = weights * input
        return weighted.nansum(dim=1)


# todo: add temporal encoding
# todo: implement layer normalization
class MaskEnc(nn.Module):
    """Multi-head self-attention followed by a position-wise feed-forward block.

    There is no residual connection and no layer normalisation: the attention
    output goes through dropout and then through a two-layer MLP.

    Args:
        embedding_dim: size of the input and output feature dimension.
        num_heads: number of attention heads.
        dropout: dropout probability used inside the attention layer, on its
            output, and inside the feed-forward block.
        batch_first: layout flag handed to ``nn.MultiheadAttention``.
        temporal_mask_direction: which mask ``_make_temporal_mask`` builds.
    """

    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            dropout: float = 0.1,
            batch_first: bool = True,
            temporal_mask_direction: MaskDirection = MaskDirection.NONE,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.batch_first = batch_first
        self.temporal_mask_direction = temporal_mask_direction

        self.attention = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=batch_first
            )
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )

    def forward(
            self,
            x: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None
    ):
        """Apply self-attention and the feed-forward block to ``x``.

        Args:
            x: input sequence tensor.
            key_padding_mask: optional mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            The output of the feed-forward block.
        """
        # The sequence axis is dim 1 for batch-first input and dim 0 otherwise.
        seq_len = x.shape[1] if self.batch_first else x.shape[0]
        # Build the mask on the same device / dtype as the input so the layer
        # also works when the model has been moved off the CPU.
        attn_mask = self._make_temporal_mask(seq_len, device=x.device, dtype=x.dtype)

        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        attn_output = self.dropout(attn_output)
        return self.fc(attn_output)

    def _make_temporal_mask(self, n: int, device=None, dtype=None) -> Optional[torch.Tensor]:
        """Return the (n, n) mask for the configured direction, or None.

        NOTE(review): the masks are float tensors of zeros and ones. According
        to the nn.MultiheadAttention documentation a float ``attn_mask`` is
        added to the attention scores; it does not block positions (that needs
        a boolean mask). As written the mask therefore only biases the scores
        by +1 -- confirm whether hard masking was intended before changing it.
        """
        direction = self.temporal_mask_direction
        if direction == MaskDirection.NONE:
            return None

        if direction == MaskDirection.FORWARD:
            return torch.tril(torch.ones(n, n, device=device, dtype=dtype))
        if direction == MaskDirection.BACKWARD:
            return torch.triu(torch.ones(n, n, device=device, dtype=dtype))
        if direction == MaskDirection.DIAGONAL:
            return torch.zeros(n, n, device=device, dtype=dtype).fill_diagonal_(1)

        # Any other value: no mask
        return None


class BiteNet(nn.Module):
    """Two-level attention network over per-visit medical codes.

    Codes are embedded, encoded and pooled inside every visit (code level);
    the resulting visit vectors are then encoded and pooled once with the
    FORWARD and once with the BACKWARD mask direction (visit level). The two
    pooled vectors are concatenated and passed through a small MLP.

    Args:
        embedding_dim: feature size used throughout the network.
        output_dim: size of the final output.
        num_heads, dropout, batch_first: forwarded to every ``MaskEnc``.
        n_mask_enc_layers: number of ``MaskEnc`` layers placed in front of the
            pooling layer of each of the three stacks.
        use_procedures: whether procedure codes are concatenated to the
            diagnosis codes of each visit.
        use_intervals: whether interval embeddings are added to the visit
            vectors.
        num_diag_codes, num_proc_codes, num_intervals: vocabulary sizes of the
            three embedding tables.
    """

    def __init__(
            self,
            embedding_dim: int = 128,
            output_dim: int = 1,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            n_mask_enc_layers: int = 2,
            use_procedures: bool = True,
            use_intervals: bool = True,
            num_diag_codes: int = num_unique_diag_codes,
            num_proc_codes: int = num_unique_proc_codes,
            num_intervals: int = num_unique_intervals,
    ):
        super().__init__()

        self.use_intervals = use_intervals
        self.use_procedures = use_procedures
        self.embedding_dim = embedding_dim

        self.diag_emb = Embedding(num_diag_codes, embedding_dim)
        self.proc_emb = Embedding(num_proc_codes, embedding_dim)
        self.interval_emb = Embedding(num_intervals, embedding_dim)

        encoder_kwargs = dict(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=batch_first,
        )

        self.code_level_attn = nn.ModuleList()
        self.visit_level_attn_forward = nn.ModuleList()
        self.visit_level_attn_backward = nn.ModuleList()
        stacks = (
            (self.code_level_attn, MaskDirection.DIAGONAL),
            (self.visit_level_attn_forward, MaskDirection.FORWARD),
            (self.visit_level_attn_backward, MaskDirection.BACKWARD),
        )

        # Encoders are created layer by layer across the three stacks, then
        # every stack is closed with its own pooling layer.
        for _ in range(n_mask_enc_layers):
            for stack, direction in stacks:
                stack.append(MaskEnc(temporal_mask_direction=direction, **encoder_kwargs))
        for stack, _ in stacks:
            stack.append(AttentionPooling(embedding_dim))

        self.fc = nn.Sequential(
            nn.Linear(2*embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_dim)
        )

    @staticmethod
    def _apply_stack(stack, x, padding_mask):
        """Run ``x`` through the encoders of ``stack`` and then its final pooling layer.

        Only the encoders receive ``padding_mask``; the last module of the
        stack is called with the tensor alone.
        """
        *encoders, pooling = stack
        for encoder in encoders:
            x = encoder(x, padding_mask)
        return pooling(x)

    def forward(
            self,
            procedures: torch.Tensor,
            conditions: torch.Tensor,
            intervals: torch.Tensor,
            visits_mask: torch.Tensor,
            procedures_mask: torch.Tensor,
            conditions_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute the network output for one padded batch.

        The ``*_mask`` arguments use 1 for real entries and 0 for padding;
        they are inverted before being handed to the attention layers as
        padding masks.
        """
        embedded_codes = self.diag_emb(conditions, conditions_mask)
        codes_mask = conditions_mask

        if self.use_procedures:
            embedded_procedures = self.proc_emb(procedures, procedures_mask)
            # Append the procedure codes of every visit to its diagnosis codes
            embedded_codes = torch.cat([embedded_codes, embedded_procedures], dim=2)
            codes_mask = torch.cat([conditions_mask, procedures_mask], dim=-1)

        codes_padding = ~(codes_mask.bool())

        # input tensor, reshape 4 dimension to 3
        flat_codes = flatten_keep(embedded_codes, 2)
        # input mask, reshape 3 dimension to 2
        flat_codes_padding = flatten_keep(codes_padding, 1)

        # Code level: one pooled vector per visit
        pooled_codes = self._apply_stack(self.code_level_attn, flat_codes, flat_codes_padding)
        code_attn = reshape(pooled_codes, embedded_codes, self.embedding_dim)

        if self.use_intervals:
            code_attn += self.interval_emb(intervals, visits_mask)

        visits_padding = ~(visits_mask.bool())

        # Visit level: forward and backward direction over the same visit vectors
        u_fw = self._apply_stack(self.visit_level_attn_forward, code_attn, visits_padding)
        u_bw = self._apply_stack(self.visit_level_attn_backward, code_attn, visits_padding)

        return self.fc(torch.cat([u_fw, u_bw], dim=-1))

class PyHealthBiteNet(BaseModel):
    """PyHealth ``BaseModel`` wrapper around ``BiteNet``.

    Args:
        dataset: sample dataset handed to ``BaseModel``.
        feature_keys: feature names registered with ``BaseModel``; a feature
            transform layer is created for each of them.
        label_key: key of the label inside a batch.
        mode: task mode handed to ``BaseModel``.
        embedding_dim, n_mask_enc_layers, use_intervals, use_procedures,
        num_heads, dropout, batch_first: forwarded to ``BiteNet``.
        **kwargs: accepted and ignored.

    NOTE(review): ``feature_keys`` does not select what ``BiteNet`` consumes.
    ``forward`` always reads "conditions", "procedures" and "intervals" from
    the batch; only ``use_procedures`` / ``use_intervals`` switch inputs off.
    """

    def __init__(
            self,
            dataset: SampleDataset,
            feature_keys: List[str],
            label_key: str,
            mode: str,
            embedding_dim: int = 128,
            n_mask_enc_layers: int = 2,
            use_intervals: bool = True,
            use_procedures: bool = True,
            num_heads: int = 4,
            dropout: float = 0.1,
            batch_first: bool = True,
            **kwargs
    ):
        super().__init__(dataset, feature_keys, label_key, mode)

        # Any BaseModel should have these attributes, as functions like add_feature_transform_layer uses them
        self.feat_tokenizers = {}
        self.embeddings = nn.ModuleDict()
        self.linear_layers = nn.ModuleDict()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embedding_dim = embedding_dim

        # self.add_feature_transform_layer will create a transformation layer for each feature
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            self.add_feature_transform_layer(
                feature_key, input_info
            )

        # final output layer
        output_size = self.get_output_size(self.label_tokenizer)
        self.bite_net = BiteNet(
            embedding_dim = embedding_dim,
            output_dim = output_size,
            num_heads = num_heads,
            dropout = dropout,
            batch_first = batch_first,
            # Previously accepted by this constructor but never passed on, so
            # BiteNet always used its own default depth.
            n_mask_enc_layers = n_mask_enc_layers,
            use_intervals=use_intervals,
            use_procedures=use_procedures
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Collate one raw batch, run ``BiteNet`` and compute the loss.

        Args:
            **kwargs: batch dict; must contain "conditions", "procedures",
                "intervals" and the label under ``self.label_key``.

        Returns:
            Dict with the keys "loss", "y_prob" and "y_true".
        """
        conditions = kwargs['conditions']
        procedures = kwargs['procedures']
        intervals = kwargs['intervals']

        # collate_fn builds its tensors on the CPU; move them to wherever the
        # network's parameters live so the model also runs on an accelerator.
        device = next(self.bite_net.parameters()).device
        # Order returned by collate_fn matches BiteNet.forward's parameters:
        # procedures, conditions, intervals, visits mask, procedures mask, conditions mask
        model_inputs = [
            tensor.to(device)
            for tensor in collate_fn(conditions, procedures, intervals)
        ]

        logits = self.bite_net(*model_inputs)

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {"loss": loss, "y_prob": y_prob, "y_true": y_true}

# Smoke test: build the model and push a single training batch through it.
# The call returns the dict with "loss", "y_prob" and "y_true" shown below.
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ['procedures', 'conditions', 'intervals'],
    label_key = "label",
    mode = "binary",
    embedding_dim=128
)

data = next(iter(train_loader))
model_dxtx(**data)

# model = BiteNet(
#     embedding_dim = 4,
#     output_dim = 1,
#     num_heads = 4,
#     dropout = 0
# )
# procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks, labels = next(iter(train_loader))
# model(procedures, conditions, intervals, visits_masks, procedures_masks, conditions_masks)

{'loss': tensor(0.7162, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'y_prob': tensor([[0.5215],
         [0.5220],
         [0.5217],
         [0.5220],
         [0.5223],
         [0.5224],
         [0.5217],
         [0.5221],
         [0.5216],
         [0.5225],
         [0.5225],
         [0.5217],
         [0.5221],
         [0.5217],
         [0.5214],
         [0.5221],
         [0.5213],
         [0.5218],
         [0.5215],
         [0.5217],
         [0.5221],
         [0.5216],
         [0.5218],
         [0.5222],
         [0.5219],
         [0.5219],
         [0.5217],
         [0.5222],
         [0.5227],
         [0.5223],
         [0.5220],
         [0.5222]], grad_fn=<SigmoidBackward0>),
 'y_true': tensor([[0.],
         [0.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.]

In [188]:
# Diagnoses + procedures variant
model_dxtx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions", "procedures"],
    label_key = "label",
    mode = "binary",
)

# Diagnoses-only variant. feature_keys on its own does not change what BiteNet
# consumes (PyHealthBiteNet.forward always collates the procedures as well), so
# the procedure branch has to be switched off explicitly; otherwise this model
# is functionally identical to model_dxtx.
model_dx = PyHealthBiteNet(
    dataset = mimic3_dxtx,
    feature_keys = ["conditions"],
    label_key = "label",
    mode = "binary",
    use_procedures = False,
)

In [189]:
# Train the diagnoses + procedures model for 5 epochs, evaluating on the
# validation loader after every epoch and tracking pr_auc as the monitored metric.
trainer_dxtx = Trainer(model=model_dxtx)
trainer_dxtx.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="pr_auc",
)

PyHealthBiteNet(
  (embeddings): ModuleDict(
    (conditions): Embedding(3383, 128, padding_idx=0)
    (procedures): Embedding(1366, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (bite_net): BiteNet(
    (diag_emb): Embedding()
    (proc_emb): Embedding()
    (interval_emb): Embedding()
    (code_level_attn): ModuleList(
      (0-1): 2 x MaskEnc(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (2): AttentionPooling(
        (fc): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_feat

Epoch 0 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-0, step-175 ---
loss: 0.5194
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 30.41it/s]
--- Eval epoch-0, step-175 ---
pr_auc: 0.1894
roc_auc: 0.5000
f1: 0.0000
loss: 0.6501
New best pr_auc score (0.1894) at epoch-0, step-175



Epoch 1 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-1, step-350 ---
loss: 0.5139
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 30.53it/s]
--- Eval epoch-1, step-350 ---
pr_auc: 0.1894
roc_auc: 0.5000
f1: 0.0000
loss: 0.6390



Epoch 2 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-2, step-525 ---
loss: 0.4871
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 29.92it/s]
--- Eval epoch-2, step-525 ---
pr_auc: 0.1894
roc_auc: 0.5000
f1: 0.0000
loss: 0.6277



Epoch 3 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-3, step-700 ---
loss: 0.4586
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 29.97it/s]
--- Eval epoch-3, step-700 ---
pr_auc: 0.1894
roc_auc: 0.5000
f1: 0.0000
loss: 0.6217



Epoch 4 / 5:   0%|          | 0/175 [00:00<?, ?it/s]

--- Train epoch-4, step-875 ---
loss: 0.4457
  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 29.85it/s]
--- Eval epoch-4, step-875 ---
pr_auc: 0.1894
roc_auc: 0.5000
f1: 0.0000
loss: 0.6183
Loaded best model


In [190]:
# Evaluate the trained model on the held-out test loader.
# option 1: use our built-in evaluation metric
score_dxtx = trainer_dxtx.evaluate(test_loader)
print (score_dxtx)

# option 2: use our pyhealth.metrics to evaluate
# inference() returns the labels, the predicted probabilities and the loss;
# only pr_auc is requested from binary_metrics_fn here.
y_true_dxtx, y_prob_dxtx, loss_dxtx = trainer_dxtx.inference(test_loader)
binary_metrics_fn(y_true_dxtx, y_prob_dxtx, metrics=["pr_auc"])

  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 27.97it/s]


{'pr_auc': 0.20373027259684362, 'roc_auc': 0.5, 'f1': 0.0, 'loss': 0.6519408307292245}


  return torch._native_multi_head_attention(
Evaluation: 100%|██████████| 22/22 [00:00<00:00, 28.00it/s]


{'pr_auc': 0.20373027259684362}