# DuaLK: Tutorial for MIMIC-IV Dataset

This notebook provides a comprehensive pipeline for reproducing the DuaLK model results on the MIMIC-IV dataset. The workflow encompasses four critical stages: data preprocessing, laboratory-based pretraining, auxiliary laboratory layer training, and final model fine-tuning with dual knowledge integration.

## Environment Setup

First, we import the necessary dependencies and configure the computational environment. The framework leverages PyTorch for deep learning operations and PyTorch Geometric for graph neural network components.

In [1]:
import os
import pickle
import numpy as np
import random
import warnings
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.utils import shuffle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch_geometric
from torch_geometric.data import Data

# Suppress sklearn warnings for cleaner output
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# Set random seeds for reproducibility
random_seed = 66
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)

print("Environment initialized successfully.")

Environment initialized successfully.


In [2]:
print(torch.__version__)
print(torch_geometric.__version__)

2.1.2
2.5.3


### Data Preprocessing

The preprocessing stage transforms raw MIMIC-IV clinical records into fixed-shape arrays ready for training. It covers patient-admission mapping, ICD code standardization, laboratory event aggregation, and construction of the disease relationship graph.

**Key Configuration Parameters:**
- `dataset`: Target dataset identifier (mimic4)
- `train_num`: Number of patients in the training cohort (8000)
- `test_num`: Number of patients in the test cohort (1000)
- `threshold`: Minimum code frequency threshold for inclusion (0.01)
- `kg_embed_dims`: Dimensionality of knowledge graph embeddings ([2000])

The preprocessing leverages the `Mimic4Parser` class from `parse_csv_4.py`, which handles ICD-10 to ICD-9 code mapping and temporal filtering based on anchor year constraints.

In [3]:
from preprocess.parse_csv_4 import Mimic4Parser, process_lab_events, process_lab_items
from preprocess.encode import encode_code, encode_lab
from preprocess.build_dataset import split_patients, build_code_xy, build_code_xy_pretrain, build_single_lab_xy
from preprocess.build_dataset import build_heart_failure_y, get_rare_codes
from preprocess.auxiliary import load_icd2name
from preprocess.generate_graph import generate_disease_complication_edge_index, generate_disease_complication_x

In [4]:
# Configuration dictionary with dataset-specific hyperparameters
conf = {
    'mimic4': {
        'parser': Mimic4Parser,
        'train_num': 8000,
        'test_num': 1000,
        'threshold': 0.01,
    },
}

data_path = 'data'
dataset = 'mimic4'  # Critical: Must be mimic4 for this tutorial
dataset_path = os.path.join(data_path, dataset)
raw_path = os.path.join(dataset_path, 'raw')

if not os.path.exists(raw_path):
    os.makedirs(raw_path)
    print('Please place the MIMIC-IV CSV files in `data/%s/raw`' % dataset)
    raise FileNotFoundError("Raw data files not found.")

print(f"Initialized preprocessing for {dataset} dataset.")

Initialized preprocessing for mimic4 dataset.


#### Phase 1: Patient-Admission Parsing

The parser extracts each patient's admission sequence and the diagnosis codes attached to each admission. `Mimic4Parser` maps ICD-10 codes to ICD-9 so that all diagnoses share one code system, ensuring compatibility with established clinical ontologies.

In [5]:
parser = conf[dataset]['parser'](raw_path)
patient_admission, admission_codes = parser.parse()

print('Valid patients extracted: %d' % len(patient_admission))
max_admission_num = max([len(admissions) for admissions in patient_admission.values()])
max_code_num_in_a_visit = max([len(codes) for codes in admission_codes.values()])
print('Maximum admission sequence length: %d' % max_admission_num)
print('Maximum diagnostic codes per visit: %d' % max_code_num_in_a_visit)

loading ICD-10 to ICD-9 map ...
loading patients anchor year ...
parsing the csv file of admission ...
	selecting valid admission ...
		546028 in 546028 rows
		remaining 325593 rows
	325593 in 325593 rows
parsing csv file of diagnosis ...
	mapping ICD-10 to ICD-9 ...
		6364520 in 6364520 rows
	6364520 in 6364520 rows
calibrating patients by admission ...
calibrating admission by patients ...
Valid patients extracted: 140935
Maximum admission sequence length: 90
Maximum diagnostic codes per visit: 40
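
To make the two return values concrete, the sketch below builds toy data shaped the way the cell above consumes it: `patient_admission` as a dict from patient ID to a list of admissions, and `admission_codes` as a dict from admission ID to a list of diagnosis codes. The field names `adm_id` and `adm_time` and all IDs are illustrative assumptions, not the parser's confirmed schema.

```python
# Toy stand-ins for the parser output; field names and values are assumptions.
patient_admission = {
    101: [{"adm_id": 1, "adm_time": "2130-01-05"},
          {"adm_id": 2, "adm_time": "2131-03-17"}],
    102: [{"adm_id": 3, "adm_time": "2125-07-30"}],
}
admission_codes = {
    1: ["4280", "25000"],
    2: ["4280", "5849", "4019"],
    3: ["486"],
}

# Same statistics as the notebook cell, computed on the toy data.
max_admission_num = max(len(adms) for adms in patient_admission.values())
max_code_num_in_a_visit = max(len(codes) for codes in admission_codes.values())

print(max_admission_num, max_code_num_in_a_visit)
```

These two maxima later fix the padded array shape, which is why they are computed before any encoding happens.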


#### Phase 2: Laboratory Event Processing

Laboratory measurements are aggregated per admission, filtering for abnormal flags. The `process_lab_events` function constructs a mapping from admission IDs to abnormal lab item IDs, while `process_lab_items` categorizes laboratory tests into semantic groups (hematology, chemistry, blood gas).

In [6]:
admission_labs, lab_set = process_lab_events(raw_path, admission_codes, dataset)
lab_category = process_lab_items(raw_path, lab_set, dataset)
max_lab_num_in_a_visit = max([len(labs) for labs in admission_labs.values()])
print('Maximum laboratory tests per admission: %d' % max_lab_num_in_a_visit)

Maximum laboratory tests per admission: 150
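
The aggregation described above can be sketched on toy rows as follows. This is a minimal illustration, assuming each lab row carries an admission ID, an item ID, and an abnormality flag; the helper name `collect_abnormal_labs` and the row values are hypothetical and do not reproduce `process_lab_events`.

```python
from collections import defaultdict

# Toy lab rows: (hadm_id, itemid, flag). All values are made up.
lab_rows = [
    (1, 50912, "abnormal"),
    (1, 50971, ""),
    (1, 51222, "abnormal"),
    (2, 50912, "abnormal"),
    (9, 51222, "abnormal"),   # admission 9 is not in the cohort
]
valid_admissions = {1, 2}


def collect_abnormal_labs(rows, valid_admissions):
    """Map each valid admission to the sorted list of its abnormal lab items."""
    per_admission = defaultdict(set)
    for hadm_id, itemid, flag in rows:
        if hadm_id in valid_admissions and flag == "abnormal":
            per_admission[hadm_id].add(itemid)
    return {adm: sorted(items) for adm, items in per_admission.items()}


admission_labs = collect_abnormal_labs(lab_rows, valid_admissions)
# The set of all abnormal items seen in the cohort.
lab_set = {item for items in admission_labs.values() for item in items}
```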


#### Phase 3: Code Encoding and Mapping

The encoding phase transforms diagnostic codes and laboratory items into integer indices. Two code maps are generated: `code_map` covers codes seen in multi-visit patients, and `code_map_pretrain` extends it with codes that appear only in single-visit patients, so pretraining can use the larger vocabulary.

In [7]:
admission_codes_encoded, code_map, code_map_pretrain = encode_code(patient_admission, admission_codes)
admission_labs_encoded, lab_map = encode_lab(admission_labs, lab_category)

lab_num = len(lab_map['all'])
code_num = len(code_map)
code_num_pretrain = len(code_map_pretrain)

print('Unique laboratory items: %d' % lab_num)
print('Diagnostic codes (multi-visit): %d' % code_num)
print('Diagnostic codes (pretrain): %d' % code_num_pretrain)

encoding code ...
Lab counts per category: {'hematology': 2736676, 'chemistry': 1699193, 'blood gas': 295003}
Unique laboratory items: 445
Diagnostic codes (multi-visit): 7337
Diagnostic codes (pretrain): 7945
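
A minimal sketch of how two nested code maps could be built is shown below. The helper `build_code_map`, the first-seen ordering, and the offset of 1 (leaving index 0 free for padding) are assumptions for illustration; `encode_code` may assign indices differently.

```python
def build_code_map(code_lists, start_index=1, code_map=None):
    """Assign consecutive integer indices to codes in first-seen order."""
    code_map = dict(code_map) if code_map else {}
    for codes in code_lists:
        for code in codes:
            if code not in code_map:
                code_map[code] = len(code_map) + start_index
    return code_map


multi_visit_codes = [["4280", "25000"], ["4280", "5849"]]
single_visit_codes = [["486"], ["25000", "V5861"]]

# Codes from multi-visit patients get the lowest indices.
code_map = build_code_map(multi_visit_codes)
# The pretrain map keeps those indices and appends single-visit-only codes.
code_map_pretrain = build_code_map(single_visit_codes, code_map=code_map)

encoded = [[code_map_pretrain[c] for c in codes] for codes in single_visit_codes]
```

Building the larger map as an extension of the smaller one keeps indices aligned, so an index means the same code in both vocabularies.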


#### Phase 4: Knowledge Graph Initialization

We use pre-trained HAKE (Hierarchy-Aware Knowledge Embedding) representations to initialize disease node embeddings. As a first step, `load_icd2name` matches each diagnostic code in `code_map_pretrain` against the ICD9CM ontology.

In [8]:
resource_path = 'resources'
icd2name, _ = load_icd2name(resource_path, code_map_pretrain)
print('ICD9CM codes matched: %d' % len(icd2name))

There is 66 unmatched ICD9CM codes
ICD9CM codes matched: 7945


#### Phase 5: Patient Cohort Stratification

The `split_patients` function partitions patients into single-visit, pretrain, train, validation, and test cohorts based on visit frequency and random sampling with a fixed seed for reproducibility.

In [9]:
single_pids, pretrain_pids, train_pids, valid_pids, test_pids = split_patients(
    patient_admission=patient_admission,
    admission_codes=admission_codes,
    code_map=code_map,
    train_num=conf[dataset]['train_num'],
    test_num=conf[dataset]['test_num'],
    seed=42
)

print('Single-visit samples: %d' % len(single_pids))
print('Pretrain samples: %d, Train: %d, Valid: %d, Test: %d' %
      (len(pretrain_pids), len(train_pids), len(valid_pids), len(test_pids)))

splitting pretrain, train, valid, and test pids ...
There are 82377 single admission patients, 58558 multiple admission patients
	100%
Single-visit samples: 82377
Pretrain samples: 90377, Train: 8000, Valid: 49558, Test: 1000
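
The counts above are consistent with a pretrain cohort made of all single-visit patients plus the training patients (82377 + 8000 = 90377) and a validation cohort made of the remaining multi-visit patients (58558 - 8000 - 1000 = 49558). The sketch below reproduces only that bookkeeping. It is a simplified assumption: the real `split_patients` also receives `admission_codes` and `code_map`, which suggests extra logic that is not modeled here.

```python
import random


def split_patients_sketch(patient_admission, train_num, test_num, seed=42):
    """Simplified cohort split by visit count with a seeded shuffle."""
    single = [p for p, adms in patient_admission.items() if len(adms) == 1]
    multi = [p for p, adms in patient_admission.items() if len(adms) > 1]
    rng = random.Random(seed)          # local RNG keeps the split reproducible
    rng.shuffle(multi)
    train = multi[:train_num]
    test = multi[train_num:train_num + test_num]
    valid = multi[train_num + test_num:]
    pretrain = single + train          # assumption inferred from the counts
    return single, pretrain, train, valid, test


# Toy cohort: patients 0-5 have one visit, patients 6-15 have two.
toy = {pid: [0] * (1 if pid < 6 else 2) for pid in range(16)}
single, pretrain, train, valid, test = split_patients_sketch(toy, train_num=4, test_num=2)
```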


#### Phase 6: Disease Complication Graph Construction

The disease-disease interaction graph is constructed by analyzing sequential diagnosis patterns within patient trajectories. Edge weights are derived from co-occurrence statistics, capturing complication relationships.

In [10]:
tuples_disease2disease = generate_disease_complication_edge_index(
    pretrain_pids, patient_admission,
    admission_codes_encoded, code_num_pretrain,
    self_edge=True, edge_score=1
)
edge_index_disease2disease, edge_weight_disease2disease = tuples_disease2disease
print('Disease-disease edges: %d' % len(edge_index_disease2disease[0]))

Generating disease-disease complication edge index...
Vocabulary Size of codes: 7945
Disease-disease edges: 1725287
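
One plausible reading of "sequential diagnosis patterns" is to link every code in a visit to every code in the next visit and count how often each pair occurs. The sketch below implements that reading only. The function name, the pair definition, and the count-based weight are assumptions; `generate_disease_complication_edge_index` may define edges and weights differently.

```python
from collections import Counter


def complication_edges(visit_sequences, num_codes, self_edge=True):
    """Count directed (code at visit t) -> (code at visit t+1) pairs."""
    counts = Counter()
    for visits in visit_sequences:
        for prev, nxt in zip(visits, visits[1:]):
            for src in set(prev):
                for dst in set(nxt):
                    if src != dst:
                        counts[(src, dst)] += 1
    if self_edge:
        # One self-loop per code index 1..num_codes.
        for code in range(1, num_codes + 1):
            counts[(code, code)] += 1
    pairs = sorted(counts)
    edge_index = [[s for s, _ in pairs], [d for _, d in pairs]]
    edge_weight = [counts[p] for p in pairs]
    return edge_index, edge_weight


# Two toy patients, each a list of visits holding encoded codes.
visits = [[[1, 2], [2, 3]], [[1], [3]]]
edge_index, edge_weight = complication_edges(visits, num_codes=3)
```

The two-row `edge_index` layout (sources, then targets) matches the COO format that PyTorch Geometric expects once converted to a tensor.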


#### Phase 7: HAKE Embedding Initialization

HAKE embeddings encode hierarchical disease relationships in a polar coordinate system: the modulus part separates entities at different levels of the hierarchy, and the phase part distinguishes entities at the same level. The 2000-dimensional embeddings serve as the initial node features of the disease graph.

In [11]:
emb_path = os.path.join(data_path, 'emb')
kg_embed_dims = [2000]
x_disease2disease_HAKE = {}

for kg_embed_dim in kg_embed_dims:
    print('Initializing HAKE %d-dimensional embeddings ...' % kg_embed_dim)
    icd2hake = pickle.load(open(os.path.join(emb_path, f'ICD2HAKE_{kg_embed_dim}.pkl'), 'rb'))
    x_disease2disease_HAKE[kg_embed_dim] = generate_disease_complication_x(
        icd2hake, code_map_pretrain, emb_dim=kg_embed_dim
    )

Initializing HAKE 2000-dimensional embeddings ...
Cannot find: 213
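
The lookup step can be sketched as filling one row per encoded code and recording the codes that have no pre-trained vector. The zero-vector fallback, the padding row at index 0, and the helper name are assumptions made for illustration; `generate_disease_complication_x` may handle missing codes differently.

```python
def build_embedding_matrix(icd2vec, code_map, emb_dim):
    """Return (matrix, missing): one row per index, row 0 kept for padding."""
    matrix = [[0.0] * emb_dim for _ in range(len(code_map) + 1)]
    missing = []
    for code, idx in code_map.items():
        vec = icd2vec.get(code)
        if vec is None:
            missing.append(code)      # row stays all zeros in this sketch
        else:
            matrix[idx] = list(vec)
    return matrix, missing


icd2vec = {"4280": [0.1, 0.2, 0.3], "486": [0.4, 0.5, 0.6]}
code_map = {"4280": 1, "486": 2, "V5861": 3}
x, missing = build_embedding_matrix(icd2vec, code_map, emb_dim=3)
print("Cannot find:", len(missing))
```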


#### Phase 8: Dataset Construction

This phase generates tensorial datasets for different training stages:
- **Single-visit lab data**: Lab features paired with diagnosis labels, used to train the auxiliary laboratory layer
- **Pretrain data**: Multi-category lab prediction (all, hematology, chemistry, blood gas)
- **Train/Valid/Test data**: Diagnosis prediction with lab features
- **Rare disease subset**: Focuses on low-frequency diagnostic codes

In [12]:
# Single-visit laboratory data
lab_x, lab_y = build_single_lab_xy(
    single_pids, patient_admission,
    admission_codes_encoded, admission_labs_encoded['all']
)
lab_x, lab_y = shuffle(lab_x, lab_y, random_state=66)
single_train_labs_x = lab_x[:int(0.85 * lab_x.shape[0])]
single_train_labs_y = lab_y[:int(0.85 * lab_y.shape[0])]
single_test_labs_x = lab_x[int(0.85 * lab_x.shape[0]):]
single_test_labs_y = lab_y[int(0.85 * lab_y.shape[0]):]

print('Single-visit lab data shape:', lab_x.shape, lab_y.shape)

building single lab features and labels ...
Single-visit lab data shape: (82377, 445) (82377, 7945)
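
The shapes above are what a multi-hot encoding produces: one row per admission and one column per item. The sketch below shows the encoding on toy indices and the arithmetic of the 85/15 split applied in the cell. The zero-based item indices and the helper name `multi_hot` are assumptions.

```python
def multi_hot(indices, size):
    """Return a 0/1 vector of length `size` with ones at `indices`."""
    vec = [0] * size
    for idx in indices:
        vec[idx] = 1
    return vec


lab_num = 6
admission_labs = [[0, 3], [5], [1, 2, 3]]
lab_x = [multi_hot(labs, lab_num) for labs in admission_labs]

# Split sizes for the 82377 single-visit rows reported above.
n = 82377
cut = int(0.85 * n)
train_size, test_size = cut, n - cut
```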


In [13]:
# Pretrain dataset with multi-category lab predictions
pretrain_codes = build_code_xy_pretrain(
    pretrain_pids, patient_admission, admission_codes_encoded,
    admission_labs_encoded, max_admission_num, lab_map,
    max_code_num_in_a_visit, lab_category='all'
)
pretrain_codes_x, pretrain_codes_y_all, pretrain_visit_lens = pretrain_codes

_, pretrain_codes_y_hema, _ = build_code_xy_pretrain(
    pretrain_pids, patient_admission, admission_codes_encoded,
    admission_labs_encoded, max_admission_num, lab_map,
    max_code_num_in_a_visit, lab_category='hematology'
)
_, pretrain_codes_y_chem, _ = build_code_xy_pretrain(
    pretrain_pids, patient_admission, admission_codes_encoded,
    admission_labs_encoded, max_admission_num, lab_map,
    max_code_num_in_a_visit, lab_category='chemistry'
)
_, pretrain_codes_y_blood, _ = build_code_xy_pretrain(
    pretrain_pids, patient_admission, admission_codes_encoded,
    admission_labs_encoded, max_admission_num, lab_map,
    max_code_num_in_a_visit, lab_category='blood gas'
)

pretrain_codes_y = pretrain_codes_y_all, pretrain_codes_y_hema, pretrain_codes_y_chem, pretrain_codes_y_blood
print('Pretrain data shape:', pretrain_codes_x.shape)

building pretrain codes features and labels ...
	90377 / 90377
building pretrain codes features and labels ...
	90377 / 90377
building pretrain codes features and labels ...
	90377 / 90377
building pretrain codes features and labels ...
	90377 / 90377
Pretrain data shape: (90377, 90, 40)
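
The `(patients, 90, 40)` shape reflects padding every patient to the maximum number of admissions and every admission to the maximum number of codes. A minimal sketch of that padding follows, assuming 0 is the padding value and that the true visit count is returned alongside the grid; `build_code_xy_pretrain` may differ in detail.

```python
def pad_patient(visits, max_visits, max_codes, pad=0):
    """Pad one patient's visits into a max_visits x max_codes grid."""
    grid = [[pad] * max_codes for _ in range(max_visits)]
    kept = visits[:max_visits]
    for v, codes in enumerate(kept):
        for c, code in enumerate(codes[:max_codes]):
            grid[v][c] = code
    return grid, len(kept)


# One patient with two visits, padded to 3 visits of 4 codes each.
x, visit_len = pad_patient([[4, 7], [7, 9, 2]], max_visits=3, max_codes=4)
```

Keeping `visit_len` lets downstream attention layers ignore the padded rows.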


In [14]:
# Train/Valid/Test datasets
train_codes_x, train_codes_y, train_visit_lens, train_labs_x = build_code_xy(
    train_pids, patient_admission,
    admission_labs_encoded['all'], lab_num,
    admission_codes_encoded,
    max_admission_num,
    code_num, max_code_num_in_a_visit
)
valid_codes_x, valid_codes_y, valid_visit_lens, valid_labs_x = build_code_xy(
    valid_pids, patient_admission,
    admission_labs_encoded['all'], lab_num,
    admission_codes_encoded,
    max_admission_num,
    code_num, max_code_num_in_a_visit
)
test_codes_x, test_codes_y, test_visit_lens, test_labs_x = build_code_xy(
    test_pids, patient_admission,
    admission_labs_encoded['all'], lab_num,
    admission_codes_encoded,
    max_admission_num,
    code_num, max_code_num_in_a_visit
)

print('Train/Valid/Test shapes:', train_codes_y.shape, valid_codes_y.shape, test_codes_y.shape)

building train/valid/test codes features and labels ...
	8000 / 8000
building train/valid/test codes features and labels ...
	49558 / 49558
building train/valid/test codes features and labels ...
	1000 / 1000
Train/Valid/Test shapes: (8000, 7337) (49558, 7337) (1000, 7337)


In [15]:
# Rare disease subset extraction
print('Extracting rare disease codes (threshold=3)...')
codes_y_rare = get_rare_codes(train_codes_y, valid_codes_y, test_codes_y, threshold=3)
train_codes_y_r, valid_codes_y_r, test_codes_y_r = codes_y_rare
print('Rare disease shapes:', train_codes_y_r.shape, valid_codes_y_r.shape, test_codes_y_r.shape)

Extracting rare disease codes (threshold=3)...
Rare disease shapes: (8000, 4424) (49558, 4424) (1000, 4424)


In [16]:
# Heart failure outcome labels (ICD-9 code 428)
train_hf_y = build_heart_failure_y('428', train_codes_y, code_map)
valid_hf_y = build_heart_failure_y('428', valid_codes_y, code_map)
test_hf_y = build_heart_failure_y('428', test_codes_y, code_map)

building train/valid/test heart failure labels ...
building train/valid/test heart failure labels ...
building train/valid/test heart failure labels ...
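
A binary heart failure label can be derived from the multi-hot diagnosis matrix by checking the columns whose ICD-9 code starts with `428`. The sketch below is an assumption about how such a helper might work (prefix matching, zero-based column indices); it is not the source of `build_heart_failure_y`.

```python
def heart_failure_labels(prefix, codes_y, code_map):
    """Return 1 for rows that contain any code starting with `prefix`."""
    hf_columns = [idx for code, idx in code_map.items() if code.startswith(prefix)]
    return [int(any(row[idx] for idx in hf_columns)) for row in codes_y]


# Toy vocabulary: two heart failure codes and one diabetes code.
code_map = {"4280": 0, "42822": 1, "25000": 2}
codes_y = [[1, 0, 0], [0, 0, 1], [0, 1, 1]]
labels = heart_failure_labels("428", codes_y, code_map)
```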


#### Phase 9: Data Persistence

Serialized datasets are stored in structured directories for subsequent training phases. This modular approach decouples preprocessing from model training.

In [17]:
train_labs_data = (single_train_labs_x, single_train_labs_y)
test_labs_data = (single_test_labs_x, single_test_labs_y)
pretrain_codes_data = (pretrain_codes_x, pretrain_codes_y, None, pretrain_visit_lens)
train_codes_data = (train_codes_x, train_codes_y, train_labs_x, train_codes_y_r)
valid_codes_data = (valid_codes_x, valid_codes_y, valid_labs_x, valid_codes_y_r)
test_codes_data = (test_codes_x, test_codes_y, test_labs_x, test_codes_y_r)

encoded_path = os.path.join(dataset_path, 'encoded')
os.makedirs(encoded_path, exist_ok=True)

print('Persisting encoded data...')
pickle.dump(patient_admission, open(os.path.join(encoded_path, 'patient_admission.pkl'), 'wb'))
pickle.dump(admission_codes_encoded, open(os.path.join(encoded_path, 'codes_encoded.pkl'), 'wb'))
pickle.dump(admission_labs_encoded, open(os.path.join(encoded_path, 'labs_encoded.pkl'), 'wb'))
pickle.dump({
    'lab_map': lab_map,
    'code_map': code_map,
    'code_map_pretrain': code_map_pretrain
}, open(os.path.join(encoded_path, 'code_maps.pkl'), 'wb'))
pickle.dump({
    'pretrain_pids': pretrain_pids,
    'train_pids': train_pids,
    'valid_pids': valid_pids,
    'test_pids': test_pids
}, open(os.path.join(encoded_path, 'pids.pkl'), 'wb'))

standard_path = os.path.join(dataset_path, 'standard')
os.makedirs(standard_path, exist_ok=True)

print('Persisting standard datasets...')
pickle.dump(pretrain_codes_data, open(os.path.join(standard_path, 'pretrain_codes_dataset.pkl'), 'wb'))
pickle.dump({
    'train_labs_data': train_labs_data,
    'test_labs_data': test_labs_data
}, open(os.path.join(standard_path, 'labs_dataset.pkl'), 'wb'))
pickle.dump({
    'train_codes_data': train_codes_data,
    'valid_codes_data': valid_codes_data,
    'test_codes_data': test_codes_data
}, open(os.path.join(standard_path, 'codes_dataset.pkl'), 'wb'))
pickle.dump({
    'train_hf_y': train_hf_y,
    'valid_hf_y': valid_hf_y,
    'test_hf_y': test_hf_y
}, open(os.path.join(standard_path, 'heart_failure.pkl'), 'wb'))

graph_path = os.path.join(dataset_path, 'graph')
os.makedirs(graph_path, exist_ok=True)

pickle.dump({
    'edge_index': edge_index_disease2disease,
    'edge_weight': edge_weight_disease2disease,
    'x_hake_2000': x_disease2disease_HAKE[2000],
}, open(os.path.join(graph_path, 'disease2disease.pkl'), 'wb'))

print('Preprocessing completed. All datasets saved.')

Persisting encoded data...
Persisting standard datasets...
Preprocessing completed. All datasets saved.
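
When adapting the persistence step, a small pair of helpers that open files in a `with` block closes every handle deterministically. The names `save_pickle` and `load_pickle` are hypothetical and not part of the repository.

```python
import os
import pickle
import tempfile


def save_pickle(obj, path):
    """Create the parent directory if needed, then pickle `obj` to `path`."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(path):
    """Load a pickled object; the file is closed when the block exits."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip in a temporary directory so nothing is left on disk.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "encoded", "pids.pkl")
    pids = {"train_pids": [1, 2], "test_pids": [3]}
    save_pickle(pids, path)
    restored = load_pickle(path)
```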


### Laboratory Pretraining

The laboratory pretraining phase employs a multi-decoder architecture to learn category-specific laboratory prediction tasks. This stage consists of two sub-phases:

1. **Joint Pretraining**: All decoders (hematology, chemistry, blood gas) are trained simultaneously with a unified loss function, establishing shared representational foundations.
2. **Individual Decoder Refinement**: Each decoder is fine-tuned independently to capture category-specific patterns.

**Architectural Specifications:**
- `init_dim`: 2000 (HAKE embedding dimensionality)
- `GNN.type`: 'gat' (Graph Attention Network for disease graph encoding)
- `GNN.dims`: (256, 256) (two-layer GAT with 256 hidden units)
- `Decoder.dims`: (256, 128) (hierarchical decoder architecture)
- `joint_epochs`: 10 (epochs of joint training; the quick-start cell in this notebook uses 1)
- `individual_epochs`: 10 (epochs per decoder refinement; the quick-start cell in this notebook uses 1)
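
The joint phase can be pictured as one binary cross-entropy term per decoder, combined into a single scalar. The sketch below assumes the terms are simply summed; how `pretrain_model_jointly` actually combines them is not shown in this notebook. The per-head loss uses the numerically stable formulation of binary cross-entropy on logits.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, in the stable form."""
    total = 0.0
    for z, y in zip(logits, targets):
        # max(z, 0) - z*y + log(1 + exp(-|z|)) avoids overflow for large |z|.
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def joint_loss(head_logits, head_targets):
    """Sum the per-decoder losses (assumed combination rule)."""
    return sum(bce_with_logits(z, y) for z, y in zip(head_logits, head_targets))


# Three heads (hematology, chemistry, blood gas) with all-zero logits.
head_logits = [[0.0, 0.0], [0.0, 0.0, 0.0], [0.0]]
head_targets = [[1, 0], [0, 1, 1], [1]]
loss = joint_loss(head_logits, head_targets)
```

With zero logits every prediction is 0.5, so each head contributes log 2 and the total is 3 log 2.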

In [3]:
from models.model import DuaLK
from utils import PatientLabDataset, load_data
from utils import pretrain_model_jointly, pretrain_individual_decoder

model_config = {
    'init_dim': 2000,
    'GNN': {
        'type': 'gat',
        'dims': (256, 256),
        'dropout': 0.,
    },
    'Attention': {
        'dropout': 0.2,
    },
    'Decoder': {
        'dims': (256, 128),
        'dropout': 0.4,
    },
    'Classifier': {
        'dims': [256],
        'dropout': 0.4,
    }
}

data_path, dataset = 'data', 'mimic4'
pretrain = True
train_type = 'pretrain'
use_lab = False
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
code_fuse, visit_fuse = 'simple', 'simple'
gnn_type = model_config['GNN']['type']

print(f'Pretraining on device: {device}')

Pretraining on device: cuda:0


#### Loading Preprocessed Data

The `load_data` utility function retrieves graph structures, node features, and pretrain labels. Note the separation of laboratory categories for specialized decoder training.

In [4]:
train_codes_x, pretrain_codes_y, edge_index, x, edge_weight = load_data(
    pretrain, data_path, dataset, model_config['init_dim']
)
_, pretrain_codes_y_hema, pretrain_codes_y_chem, pretrain_codes_y_blood = pretrain_codes_y

x = x.float()
data = Data(x=x, edge_index=edge_index)
print('Graph node features shape:', data.x.shape)
print('Training data shape:', train_codes_x.shape)

Graph node features shape: torch.Size([7946, 2000])
Training data shape: (90377, 90, 40)


In [5]:
# Dataset construction with category-specific labels
train_dataset = PatientLabDataset(
    train_codes_x, pretrain_codes_y_hema,
    pretrain_codes_y_chem, pretrain_codes_y_blood
)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# Monitoring subset: 10,000 samples drawn from the training set (not held out)
test_indices = np.random.choice(len(train_dataset), 10000, replace=False)
test_subset = torch.utils.data.Subset(train_dataset, test_indices)
test_loader = DataLoader(test_subset, batch_size=256, shuffle=False)

#### Model Instantiation

The DuaLK model integrates graph neural encoding with attention-based visit aggregation. The `num_classes` parameter specifies output dimensions for each laboratory category decoder.

In [6]:
learning_rate = 0.001
joint_epochs = 1 # Here we set 1 epoch for quick start
individual_epochs = 1 # Here we set 1 epoch for quick start
num_classes = [
    pretrain_codes_y_hema.shape[1],
    pretrain_codes_y_chem.shape[1],
    pretrain_codes_y_blood.shape[1]
]

data = data.to(str(device))
model = DuaLK(
    model_config=model_config,
    emb_init=data.x,
    num_classes=num_classes,
    use_lab=use_lab,
    code_fuse=code_fuse,
    visit_fuse=visit_fuse,
    train_type=train_type,
    lab_weight=None,
    lab_bias=None,
    gnn_type=gnn_type
).to(device)

print(model)

criterion = torch.nn.BCEWithLogitsLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

DuaLK(
  (gnn_layer): GATLayer(
    (conv1): GATConv(2000, 256, heads=1)
    (conv2): GATConv(256, 256, heads=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (visit_attention): AttentionLayer(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (patient_attention): AttentionLayer(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (decoder1): Decoder(
    (fc1): Linear(in_features=256, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=239, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder2): Decoder(
    (fc1): Linear(in_features=256, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=191, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder3): Decoder(
    (fc1): Linear(in_features=256, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=1

#### Joint Pretraining Phase

Multi-task learning across laboratory categories enforces shared feature extraction while preserving category-specific nuances through dedicated output heads.

In [6]:
step = 'together'  # Options: 'joint', 'individual', 'together'

if step == "joint" or step == 'together':
    print('Initiating joint pretraining...')
    pretrain_model_jointly(
        model, data, train_loader, criterion,
        optimizer, joint_epochs, device, test_loader
    )
    os.makedirs('./saved', exist_ok=True)  # make sure the checkpoint directory exists
    torch.save(model.state_dict(), f'./saved/joint_pretrained_model_{joint_epochs}.pth')
    print(f'Joint pretrained model saved: joint_pretrained_model_{joint_epochs}.pth')

Initiating joint pretraining...
Epoch 1/1, Train Loss: 0.4419, Test Loss: 0.3148, F1-weighted1: 0.6341, F1-weighted2: 0.5906, F1-weighted3: 0.7932, Learning Rate: 0.00100, Train Time: 807.07s, Eval Time: 90.37s
Joint pretrained model saved: joint_pretrained_model_1.pth


#### Individual Decoder Refinement

After joint training, each decoder is optimized in isolation to improve category-specific predictive performance. This two-stage approach balances generalization and specialization.

In [7]:
if step == "individual" or step == 'together':
    model.load_state_dict(torch.load(f'./saved/joint_pretrained_model_{joint_epochs}.pth', map_location=device))
    print('Initiating individual decoder refinement...')
    decoders = ['hema', 'chem', 'blood']
    for decoder_type in decoders:
        print(f'\nRefining {decoder_type} decoder...')
        pretrain_individual_decoder(
            model, data, train_loader, criterion,
            optimizer, individual_epochs, device,
            decoder_type, test_loader
        )
    print('\nIndividual decoder weights saved.')

Initiating individual decoder refinement...

Refining hema decoder...
Epoch 1/1, Train Loss (hema): 0.0780, Test Loss: 0.0752, F1-weighted: 0.6369, Learning Rate: 0.00100, Train Time: 777.26s, Eval Time: 74.29s

Refining chem decoder...
Epoch 1/1, Train Loss (chem): 0.0668, Test Loss: 0.0637, F1-weighted: 0.5928, Learning Rate: 0.00100, Train Time: 704.37s, Eval Time: 87.76s

Refining blood decoder...
Epoch 1/1, Train Loss (blood): 0.1817, Test Loss: 0.1700, F1-weighted: 0.7946, Learning Rate: 0.00100, Train Time: 797.91s, Eval Time: 90.06s

Individual decoder weights saved.


### Auxiliary Laboratory Layer Training

This stage trains a standalone feed-forward network that maps multi-hot abnormal laboratory items to diagnostic codes using single-visit data. The learned first-layer weights initialize the laboratory integration module in the final model, providing a warm start for lab-diagnosis associations.

**Network Architecture:**
- Input dimension: Number of lab items (445 for MIMIC-IV)
- Hidden layer: 256 units with ReLU activation and 0.4 dropout
- Output dimension: Number of diagnostic codes

**Training Configuration:**
- `batch_size`: 128
- `learning_rate`: 0.001 with StepLR decay (γ=0.5 every 20 epochs)
- `num_epochs`: 100
- Evaluation metrics: Recall@20, Recall@40
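
The step decay in the configuration above has a simple closed form: the learning rate after `epoch` scheduler steps is `base_lr * gamma ** (epoch // step_size)`. The helper below is a standalone illustration of that formula and does not call PyTorch.

```python
def step_lr(base_lr, epoch, step_size=20, gamma=0.5):
    """Learning rate after `epoch` scheduler steps under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate at a few points of a 100-epoch run.
schedule = {epoch: step_lr(0.001, epoch) for epoch in (0, 19, 20, 40, 99)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:3d}: lr = {lr:.7f}")
```

Over 100 epochs the rate is halved four times, so the final epochs train at one sixteenth of the initial rate.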

In [7]:
from metrics import top_k_prec_recall

def train_and_evaluate_lab(model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Training loop for auxiliary laboratory prediction layer."""
    ks = [20, 40]
    best_test_loss = float('inf')
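    best_model_path = './saved/train_lab/lab_layer_checkpoint.pth'
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)  # directory must exist before the first checkpoint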

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        test_loss = 0
        y_true_test = []
        y_pred_test = []

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                y_true_test.append(labels.cpu().numpy())
                y_pred_test.append(outputs.cpu().numpy())

        y_true_test = np.vstack(y_true_test)
        y_pred_test = np.vstack(y_pred_test)
        y_pred_sorted_test = np.argsort(y_pred_test, axis=1)[:, ::-1]

        _, recall_at_k_test = top_k_prec_recall(y_true_test, y_pred_sorted_test, ks)
        test_loss /= len(test_loader)
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, '
              f'Test Loss: {test_loss:.4f}, '
              f'Test Recall@20: {recall_at_k_test[0]:.4f}, '
              f'Test Recall@40: {recall_at_k_test[1]:.4f}, '
              f'LR: {current_lr:.6f}')

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            torch.save({
                'linear1_weight': model[0].weight.data,
                'linear1_bias': model[0].bias.data,
                'linear2_weight': model[3].weight.data,
                'linear2_bias': model[3].bias.data,
            }, best_model_path)
            print(f'Checkpoint saved at epoch {epoch + 1}')

In [8]:
data_path = 'data'
dataset = 'mimic4'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print('Loading single-visit laboratory data...')
dataset_path = os.path.join(data_path, dataset)
standard_path = os.path.join(dataset_path, 'standard')
labs_dataset = pickle.load(open(os.path.join(standard_path, 'labs_dataset.pkl'), 'rb'))
train_labs_x, train_labs_y = labs_dataset['train_labs_data']
test_labs_x, test_labs_y = labs_dataset['test_labs_data']

train_labs_x, train_labs_y = torch.from_numpy(train_labs_x), torch.from_numpy(train_labs_y)
test_labs_x, test_labs_y = torch.from_numpy(test_labs_x), torch.from_numpy(test_labs_y)
print('Train lab shape:', train_labs_x.shape, train_labs_y.shape)
print('Test lab shape:', test_labs_x.shape, test_labs_y.shape)

item_num = train_labs_x.shape[1]
code_num = train_labs_y.shape[1]

lab_model = nn.Sequential(
    nn.Linear(item_num, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, code_num),
).to(device)

train_dataset = TensorDataset(train_labs_x.float(), train_labs_y.float())
test_dataset = TensorDataset(test_labs_x.float(), test_labs_y.float())
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

criterion = torch.nn.BCEWithLogitsLoss().to(device)
optimizer = optim.Adam(lab_model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

num_epochs = 100
print('\nTraining auxiliary laboratory layer...\n')
train_and_evaluate_lab(lab_model, train_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device)

Loading single-visit laboratory data...
Train lab shape: torch.Size([70020, 445]) torch.Size([70020, 7945])
Test lab shape: torch.Size([12357, 445]) torch.Size([12357, 7945])

Training auxiliary laboratory layer...

Epoch 1/100, Train Loss: 0.0510, Test Loss: 0.0066, Test Recall@20: 0.3652, Test Recall@40: 0.4549, LR: 0.001000
Checkpoint saved at epoch 1
Epoch 2/100, Train Loss: 0.0052, Test Loss: 0.0046, Test Recall@20: 0.3840, Test Recall@40: 0.4876, LR: 0.001000
Checkpoint saved at epoch 2
Epoch 3/100, Train Loss: 0.0044, Test Loss: 0.0042, Test Recall@20: 0.3973, Test Recall@40: 0.4997, LR: 0.001000
Checkpoint saved at epoch 3
Epoch 4/100, Train Loss: 0.0042, Test Loss: 0.0041, Test Recall@20: 0.4101, Test Recall@40: 0.5078, LR: 0.001000
Checkpoint saved at epoch 4
Epoch 5/100, Train Loss: 0.0041, Test Loss: 0.0040, Test Recall@20: 0.4206, Test Recall@40: 0.5201, LR: 0.001000
Checkpoint saved at epoch 5
Epoch 6/100, Train Loss: 0.0040, Test Loss: 0.0040, Test Recall@20: 0.4251, Tes

In [9]:
os.makedirs('./saved/train_lab', exist_ok=True)
end_model_path = './saved/train_lab/lab_layer_checkpoint_end.pth'
torch.save({
    'linear1_weight': lab_model[0].weight.data,
    'linear1_bias': lab_model[0].bias.data,
    'linear2_weight': lab_model[3].weight.data,
    'linear2_bias': lab_model[3].bias.data,
}, end_model_path)
print('Final laboratory layer checkpoint saved.')

Final laboratory layer checkpoint saved.


### Final Model Training (Fine-tuning)

The culminating training phase integrates all pretrained components into the full DuaLK architecture. This involves:

1. **Initialization**: Loading pretrained graph encoder and decoder weights
2. **Laboratory Integration**: Incorporating learned lab-diagnosis mappings
3. **End-to-End Optimization**: Fine-tuning on multi-visit diagnosis prediction

**Key Hyperparameters:**
- `train_type`: 'finetune' (enables partial parameter freezing)
- `use_lab`: True (activates laboratory feature fusion)
- `loss`: 'pos_weight' (binary cross-entropy with separate weights for positive and negative labels)
- `code_fuse/visit_fuse`: 'simple' (concatenation-based fusion)
- `epochs`: 500 with stepwise learning rate decay (the quick-start cell in this notebook uses 50)
- `batch_size`: 32

**Evaluation Metrics:**
- F1-weighted: Per-code F1 scores averaged with weights proportional to each code's support
- Recall@k (k=10,20,30,40): Top-k retrieval performance
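
Recall@k asks what fraction of a patient's true codes appear among the k highest-scored predictions. The sketch below averages that fraction over patients and skips patients with no positive labels. Both choices are assumptions, and the function is a standalone illustration rather than the project's `top_k_prec_recall`.

```python
def recall_at_k(y_true, scores, k):
    """Mean per-sample recall of true labels within the top-k scores."""
    recalls = []
    for truth, row in zip(y_true, scores):
        positives = {i for i, t in enumerate(truth) if t}
        if not positives:
            continue  # undefined recall; skipped in this sketch
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        recalls.append(len(positives.intersection(top_k)) / len(positives))
    return sum(recalls) / len(recalls) if recalls else 0.0


y_true = [[1, 0, 1, 0], [0, 1, 0, 0]]
scores = [[0.9, 0.8, 0.1, 0.2], [0.3, 0.2, 0.9, 0.1]]
r2 = recall_at_k(y_true, scores, k=2)
r3 = recall_at_k(y_true, scores, k=3)
```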

In [10]:
from utils import PatientDataset, load_data
from metrics import f1
from models.loss import WeightedBCEWithLogitsLoss

def adjust_learning_rate(optimizer, epoch, base_lr):
    """Stepwise learning rate decay schedule."""
    if epoch > 40:
        lr = 0.0001
    elif epoch > 30:
        lr = 0.0005
    else:
        lr = base_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [11]:
model_config = {
    'init_dim': 2000,
    'GNN': {
        'type': 'gat',
        'dims': (256, 256),
        'dropout': 0.,
    },
    'Attention': {
        'dropout': 0.2,
    },
    'Decoder': {
        'dims': (256, 128),
        'dropout': 0.4,
    },
    'Classifier': {
        'dims': [256],
        'dropout': 0.4,
    }
}

data_path, dataset = 'data', 'mimic4'
pretrain = False
train_type = 'finetune'
use_lab = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
code_fuse, visit_fuse = 'simple', 'simple'
loss = 'pos_weight'
gnn_type = model_config['GNN']['type']
code_range = 'all'

pos_weight, neg_weight = 1, 1
print(f'Loss configuration: pos_weight={pos_weight}, neg_weight={neg_weight}')

Loss configuration: pos_weight=1, neg_weight=1


#### Data Loading and Graph Construction

The `load_data` function retrieves preprocessed tensors and graph structures. The graph remains static during training, serving as a structural prior for disease relationships.

In [12]:
(
    train_codes_x, train_codes_y, test_codes_x, test_codes_y,
    train_labs_x, test_labs_x, edge_index, x, edge_weight,
    train_codes_y_r, test_codes_y_r
) = load_data(pretrain, data_path, dataset, model_config['init_dim'])

x = x.float()
data = Data(x=x, edge_index=edge_index)

print('Disease graph nodes:', data.x.shape[0])
print('Train samples:', train_codes_x.shape[0])
print('Test samples:', test_codes_x.shape[0])

Disease graph nodes: 7946
Train samples: 8000
Test samples: 1000


In [13]:
train_dataset = PatientDataset(train_codes_x, train_labs_x.float(), train_codes_y)
test_dataset = PatientDataset(test_codes_x, test_labs_x.float(), test_codes_y)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

#### Model Initialization with Pretrained Weights

The model is initialized with:
1. Pretrained graph encoder parameters (from laboratory pretraining)
2. Category-specific decoder weights (hematology, chemistry, blood gas)
3. Auxiliary laboratory layer weights (from train_lab phase)

For MIMIC-IV, `num_classes` lists the output sizes in order: 239 for hematology, 191 for chemistry, 15 for blood gas, and finally the total number of diagnostic codes (`train_codes_y.shape[1]`).

In [14]:
learning_rate = 0.001
epochs = 50  # 50 epochs for a quick start; the hyperparameter list above specifies 500

if dataset == 'mimic3':
    num_classes = [159, 115, 16, train_codes_y.shape[1]]
elif dataset == 'mimic4':
    num_classes = [239, 191, 15, train_codes_y.shape[1]]
else:
    raise ValueError('Invalid dataset')

if use_lab:
    print('Loading auxiliary laboratory layer weights...')
    model_path = './saved/train_lab/lab_layer_checkpoint_end.pth'
    checkpoint = torch.load(model_path, map_location=device)  # remap so a GPU-saved checkpoint loads on CPU
    lab_weight, lab_bias = checkpoint['linear1_weight'], checkpoint['linear1_bias']
else:
    lab_weight, lab_bias = None, None

data = data.to(device)
model = DuaLK(
    model_config=model_config,
    emb_init=data.x,
    num_classes=num_classes,
    use_lab=use_lab,
    code_fuse=code_fuse,
    visit_fuse=visit_fuse,
    train_type=train_type,
    lab_weight=lab_weight,
    lab_bias=lab_bias,
    gnn_type=gnn_type
).to(device)

print(model)

Loading auxiliary laboratory layer weights...
DuaLK(
  (gnn_layer): GATLayer(
    (conv1): GATConv(2000, 256, heads=1)
    (conv2): GATConv(256, 256, heads=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (visit_attention): AttentionLayer(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (patient_attention): AttentionLayer(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (lab_layer): Linear(in_features=290, out_features=256, bias=True)
  (decoder1): Decoder(
    (fc1): Linear(in_features=256, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=239, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder2): Decoder(
    (fc1): Linear(in_features=256, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=191, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder3): Decoder(

In [15]:
if train_type in ['pretrain', 'finetune']:
    print('Loading pretrained parameters...')
    model.load_partial_state_dict(
        torch.load(f'./saved/joint_pretrained_model_{joint_epochs}.pth', map_location=device)
    )
    model.load_decoder_weights(
        './saved/decoder_hema_weights.pth',
        './saved/decoder_chem_weights.pth',
        './saved/decoder_blood_weights.pth',
        device
    )
    model = model.to(device)
    print('Pretrained weights successfully loaded.')

Loading pretrained parameters...
Pretrained weights successfully loaded.


In [16]:
criterion = WeightedBCEWithLogitsLoss(pos_weight=pos_weight, neg_weight=neg_weight).to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

os.makedirs('./saved/train', exist_ok=True)
torch.save(data.x.cpu(), './saved/train/initial_embeddings.pth')

#### Training Loop

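Training uses the stepwise learning rate schedule defined in `adjust_learning_rate`: the base rate (0.001) through epoch index 30, 0.0005 for indices 31 to 40, and 0.0001 afterwards, so later epochs make smaller, finer updates. Each epoch ends with an evaluation pass over the test set, and the checkpoint file is overwritten with the latest weights.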
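
The schedule in `adjust_learning_rate` can be checked in isolation. The sketch below restates it as a pure function (`scheduled_lr` is a hypothetical name, not part of the repository) and prints the rate on either side of each boundary. Epoch indices are zero-based, as in `range(epochs)`, while the training log prints `epoch + 1`.

```python
def scheduled_lr(epoch, base_lr):
    """Learning rate for a zero-based epoch index, mirroring adjust_learning_rate."""
    if epoch > 40:
        return 0.0001
    if epoch > 30:
        return 0.0005
    return base_lr


# Inspect the rate around each boundary of the 50-epoch quick-start run.
for epoch in (0, 30, 31, 40, 41, 49):
    lr = scheduled_lr(epoch, base_lr=0.001)
    print(f"epoch index {epoch:2d} (logged as epoch {epoch + 1:2d}): lr = {lr:.4f}")
```

Because the comparisons are strict (`>`), the first drop takes effect at epoch index 31 and the second at index 41.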

In [17]:
print('\nCommencing end-to-end training...\n')
for epoch in range(epochs):
    adjust_learning_rate(optimizer, epoch, learning_rate)
    current_lr = optimizer.param_groups[0]['lr']
    epoch_loss = 0

    model.train()
    for batch in train_loader:
        patient_data, labels = batch
        patient_data = {k: v.to(device) for k, v in patient_data.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(data, patient_data)
        loss = criterion(output, labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    train_loss = epoch_loss / len(train_loader)

    # Evaluation on test set
    model.eval()
    test_loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in test_loader:
            patient_data, labels = batch
            patient_data = {k: v.to(device) for k, v in patient_data.items()}
            labels = labels.to(device)

            output = model(data, patient_data)
            loss = criterion(output, labels.float())
            test_loss += loss.item()

            y_true.append(labels.cpu().numpy())
            y_pred.append(output.cpu().numpy())

    test_loss /= len(test_loader)
    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)
    # Rank code indices by descending score for each patient
    y_pred_sorted = np.argsort(y_pred, axis=1)[:, ::-1]

    f1_weighted = f1(y_true, y_pred_sorted, metrics='weighted')
    ks = [10, 20, 30, 40] if code_range == 'all' else [5, 8, 15]
    _, recall_at_k = top_k_prec_recall(y_true, y_pred_sorted, ks)

    recall_str = ", ".join([f"Recall@{k}: {recall:.4f}" for k, recall in zip(ks, recall_at_k)])
    print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.5f}, Test Loss: {test_loss:.5f}, '
          f'F1-weighted: {f1_weighted:.4f}, {recall_str}, LR: {current_lr:.6f}')

    # Overwrite the checkpoint after every epoch, keeping only the latest weights
    torch.save(model.state_dict(), './saved/train/checkpoint.pth')

print('\nTraining completed. Final checkpoint saved.')


Commencing end-to-end training...

Epoch 1/50, Train Loss: 0.02562, Test Loss: 0.00493, F1-weighted: 0.1370, Recall@10: 0.3303, Recall@20: 0.4005, Recall@30: 0.4536, Recall@40: 0.4874, LR: 0.001000
Epoch 2/50, Train Loss: 0.00630, Test Loss: 0.00468, F1-weighted: 0.1576, Recall@10: 0.3480, Recall@20: 0.4209, Recall@30: 0.4794, Recall@40: 0.5226, LR: 0.001000
Epoch 3/50, Train Loss: 0.00598, Test Loss: 0.00455, F1-weighted: 0.1756, Recall@10: 0.3652, Recall@20: 0.4382, Recall@30: 0.4962, Recall@40: 0.5388, LR: 0.001000
Epoch 4/50, Train Loss: 0.00577, Test Loss: 0.00452, F1-weighted: 0.1727, Recall@10: 0.3638, Recall@20: 0.4337, Recall@30: 0.4998, Recall@40: 0.5491, LR: 0.001000
Epoch 5/50, Train Loss: 0.00564, Test Loss: 0.00445, F1-weighted: 0.1844, Recall@10: 0.3737, Recall@20: 0.4454, Recall@30: 0.5133, Recall@40: 0.5587, LR: 0.001000
Epoch 6/50, Train Loss: 0.00553, Test Loss: 0.00441, F1-weighted: 0.1960, Recall@10: 0.3832, Recall@20: 0.4551, Recall@30: 0.5164, Recall@40: 0.5625,

## Conclusion

This notebook demonstrates the complete DuaLK pipeline on MIMIC-IV:
1. **Preprocessing**: Raw EHR → structured tensors + disease graph
2. **Laboratory Pretraining**: Multi-category lab prediction with graph-aware encoders
3. **Auxiliary Lab Training**: Direct lab-diagnosis mapping
4. **Fine-tuning**: Integrated model with dual knowledge (graph + lab)

The modular design enables flexible experimentation with different graph encoders, fusion strategies, and pretraining objectives.