# ExMed-BERT Data Preparation for Pretraining

This notebook provides a detailed walkthrough of how patient data is prepared for pretraining a BERT-style model using the ExMed-BERT framework. It is designed for users who are new to the codebase and want to understand the data pipeline, encoding strategies, and the rationale behind each step.

We will cover:
- The purpose and structure of the main classes involved (e.g., `CodeDict`, `AgeDict`, `Patient`, `PatientDataset`)
- How medical codes and patient attributes are encoded
- How patient data is processed and masked for model input
- How a dataset is constructed and saved for pretraining

## 1. Imports and Setup

We begin by importing the necessary modules and ensuring the ExMed-BERT package is available for import. This includes adding the package directory to the Python path and importing PyTorch for tensor operations.

In [1]:
import sys
import os
from datetime import date

# Add the ExMed-BERT package to Python path for imports
sys.path.append('./ExMed-BERT-main')

import torch  # PyTorch is used for tensor operations and model input formatting

## 2. Encoding Classes: Dictionaries for Medical Data

ExMed-BERT uses specialized dictionary classes to encode various types of patient data into integer IDs suitable for model input. These include:
- `CodeDict`: Handles medical codes (ICD, PheWAS, RxNorm, ATC) and their mappings.
- `AgeDict`: Bins and encodes patient ages.
- `SexDict`: Encodes sex/gender.
- `StateDict`: Encodes US state information.
- `EndpointDict`: Encodes endpoint labels for classification tasks.

Let's import these classes and briefly describe their roles.

In [2]:
from exmed_bert.data.encoding import (
    AgeDict,     # Handles age binning and encoding
    CodeDict,    # Handles medical codes (ICD, PheWAS, RxNorm, ATC)
    SexDict,     # Handles sex/gender encoding
    StateDict,   # Handles US state encoding
    DICT_DEFAULTS,  # Default tokens (e.g., PAD, UNK, MASK)
    EndpointDict # Handles endpoint label encoding (e.g., for classification tasks)
)
from exmed_bert.data.patient import Patient  # Main class for patient sequence processing
from exmed_bert.data.dataset import PatientDataset  # Dataset class for batching patients

### What do these classes do?
- **CodeDict**: Maps raw medical codes (diagnoses, drugs) to integer IDs, handles code normalization (e.g., mapping RxNorm to ATC, ICD to PheWAS), and provides decoding for interpretability.
- **AgeDict**: Bins ages (e.g., by year) and encodes them as IDs.
- **SexDict**: Encodes sex/gender as IDs.
- **StateDict**: Encodes US state abbreviations as IDs.
- **EndpointDict**: Encodes outcome labels for supervised tasks.

These dictionaries ensure that all categorical data is consistently mapped to integer IDs for model input.
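
Here is a minimal sketch of a label-to-ID vocabulary that makes the idea concrete. The `SimpleVocab` class, the special-token list (`PAD`, `UNK`, `CLS`, `SEP`, `MASK`) and that token order are hypothetical. They are not the internals of `CodeDict` or the other ExMed-BERT dictionaries, whose IDs may differ.

```python
# Hypothetical sketch: this is NOT ExMed-BERT's CodeDict/SexDict implementation.
# The special-token names and their order are assumptions for illustration.
SPECIAL_TOKENS = ["PAD", "UNK", "CLS", "SEP", "MASK"]


class SimpleVocab:
    """Bidirectional mapping between string labels and integer IDs."""

    def __init__(self, labels):
        # Special tokens take the lowest IDs, so PAD is always 0.
        self.ids_to_label = list(SPECIAL_TOKENS)
        for label in labels:
            if label not in self.ids_to_label:  # ignore duplicate labels
                self.ids_to_label.append(label)
        self.labels_to_id = {label: i for i, label in enumerate(self.ids_to_label)}

    def encode(self, labels):
        # Labels outside the vocabulary fall back to the UNK ID.
        unk_id = self.labels_to_id["UNK"]
        return [self.labels_to_id.get(label, unk_id) for label in labels]

    def decode(self, ids):
        return [self.ids_to_label[i] for i in ids]


sex_vocab = SimpleVocab(["MALE", "FEMALE"])
print(sex_vocab.encode(["FEMALE", "MALE", "OTHER"]))  # [6, 5, 1]
print(sex_vocab.decode([6, 5]))  # ['FEMALE', 'MALE']
```

Reserving the lowest IDs for special tokens keeps padding at ID 0, no matter which codes are added later.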

## 3. Example Code Dictionaries and Mappings

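For demonstration, we define small example sets of codes and mappings. In real applications, these would be much larger and loaded from files. Some of the RxNorm IDs below are placeholder values, and the code mappings are illustrative rather than clinically validated.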

In [3]:
# ATC (Anatomical Therapeutic Chemical) codes for medications
atc_codes = [
    'A01AA01', 'B01AC06', 'C09AA05', 'D05AX02', 'E03AA01', 'F01BA01', 'G04BE03'
]

# PheWAS (Phenome Wide Association Study) codes for diagnoses
phewas_codes = [
    '008', '250', '401.1', '530.11', '715.2', '272.1', '585.3'
]

# Mapping from RxNorm (prescription) to ATC codes
rx_to_atc_map = {
    '860975': 'A01AA01',
    '197361': 'B01AC06',
    '123456': 'C09AA05',
    '654321': 'D05AX02',
    '789012': 'E03AA01',
    '345678': 'F01BA01',
    '987654': 'G04BE03'
}

# Mapping from ICD-10 to PheWAS codes
icd_to_phewas_map = {
    'I10': '401.1',
    'E11.9': '250',
    'Z51.11': '008',
    'K21.0': '530.11',
    'M17.9': '715.2',
    'E78.5': '272.1',
    'N18.3': '585.3'
}

# List of US states for state encoding
state_list = ['CA', 'NY', 'TX']
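
You can picture the mapping step that `CodeDict` performs as a lookup from a raw code into the target vocabulary. The sketch below uses a hypothetical `normalize_events` helper, which is not part of ExMed-BERT. It keeps each code paired with its date, so dropping an unmapped code also drops that date. This notebook does not show whether the real `CodeDict` drops unmapped codes, keeps them, or maps them to `UNK`.

```python
from datetime import date

# Hypothetical helper, not an ExMed-BERT function. Codes and dates are kept as
# pairs so that dropping an unmapped code also drops its date.
def normalize_events(codes, dates, mapping):
    """Translate raw codes through `mapping`, discarding codes with no entry."""
    return [(mapping[code], d) for code, d in zip(codes, dates) if code in mapping]


icd_to_phewas = {"I10": "401.1", "E11.9": "250"}  # subset of the demo map
events = normalize_events(
    ["I10", "E11.9", "J45.909"],  # J45.909 has no entry in this demo map
    [date(2021, 3, 15), date(2022, 3, 15), date(2022, 4, 1)],
    icd_to_phewas,
)
print(events)
```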

## 4. Initializing Encoding Dictionaries

We now create the encoding dictionary objects. These will be used to map all patient data to integer IDs.

In [4]:
# Code dictionary for medical codes
code_dict = CodeDict(
    atc_codes=atc_codes,
    phewas_codes=phewas_codes,
    rx_to_atc_map=rx_to_atc_map,
    icd_to_phewas_map=icd_to_phewas_map
)

# Age dictionary for binning and encoding ages (by year)
age_dict = AgeDict(max_age=90, min_age=0, binsize=1)
# Normalize numeric age labels to integer strings (e.g., '61.0' -> '61');
# special tokens listed in DICT_DEFAULTS are left unchanged.
def _to_int_label(label):
    return label if label in DICT_DEFAULTS else str(int(float(label)))

age_dict.vocab = [_to_int_label(age) for age in age_dict.vocab]
age_dict.labels_to_id = {_to_int_label(label): idx for label, idx in age_dict.labels_to_id.items()}
age_dict.ids_to_label = {idx: _to_int_label(label) for idx, label in age_dict.ids_to_label.items()}

# Sex dictionary
sex_dict = SexDict(sex=['MALE', 'FEMALE'])

# State dictionary
state_dict = StateDict(states=state_list)
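
`AgeDict(max_age=90, min_age=0, binsize=1)` bins ages into one-year buckets. Below is a minimal sketch of year-resolution binning. The `age_bin` function is hypothetical. It assumes age is the event year minus the birth year, clipped to `[min_age, max_age]`. `AgeDict` may handle out-of-range ages differently.

```python
# Hypothetical sketch of year-based age binning; AgeDict's real logic may differ
# (e.g., in how it handles ages outside [min_age, max_age]).
def age_bin(birth_year, event_year, min_age=0, max_age=90, binsize=1):
    """Return the lower edge of the age bin for an event, as an integer."""
    age = event_year - birth_year              # year resolution only
    age = max(min_age, min(age, max_age))      # clip into [min_age, max_age]
    return min_age + ((age - min_age) // binsize) * binsize


print(age_bin(1960, 2021))               # 61
print(age_bin(1960, 2021, binsize=5))    # 60
print(age_bin(1920, 2025))               # 90 (clipped)
label = str(age_bin(2004, 2021))         # "17": an integer-string label
```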

## 5. Example Patient Data

We create example patients with expanded medical histories. Each patient is represented as a dictionary with diagnoses, drugs, dates, and demographic information.

In [None]:
patients = [
    {
        'patient_id': 12345,
        'diagnoses': ['I10', 'E11.9', 'Z51.11', 'K21.0', 'M17.9', 'E78.5', 'N18.3'],
        'diagnosis_dates': [
            date(2021, 3, 15), date(2022, 3, 15), date(2023, 4, 20),
            date(2024, 5, 1), date(2025, 5, 10), date(2025, 6, 5), date(2025, 7, 12)
        ],
        'drugs': ['860975', '197361', '197361', '654321', '789012', '345678', '987654'],
        'prescription_dates': [
            date(2021, 3, 16), date(2022, 3, 16), date(2023, 4, 21),
            date(2024, 5, 2), date(2025, 5, 11), date(2025, 6, 6), date(2025, 7, 13)
        ],
        'birth_year': 2004,
        'sex': 'MALE',
        'patient_state': 'CA',
        'plos': 1
    },
    {
        'patient_id': 67890,
        'diagnoses': ['E11.9', 'I10', 'K21.0', 'M17.9', 'E78.5', 'N18.3', 'Z51.11'],
        'diagnosis_dates': [
            date(2021, 5, 1), date(2022, 5, 15), date(2023, 5, 20),
            date(2024, 6, 1), date(2025, 6, 10), date(2025, 7, 5), date(2025, 8, 12)
        ],
        'drugs': ['123456', '860975', '654321', '789012', '345678', '987654', '197361'],
        'prescription_dates': [
            date(2021, 5, 2), date(2022, 5, 16), date(2023, 5, 21),
            date(2024, 6, 2), date(2025, 6, 11), date(2025, 7, 6), date(2025, 8, 13)
        ],
        'birth_year': 1960,
        'sex': 'FEMALE',
        'patient_state': 'NY',
        'plos': 0
    }
]
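
Each record stores diagnoses and prescriptions in separate lists. Before encoding, they must be merged into one time-ordered sequence. The hypothetical `build_timeline` below sketches one way to do this. It assumes one visit per distinct calendar date and treats exact duplicates as redundant, in the spirit of `drop_duplicates=True`. The `Patient` class may group events differently.

```python
from datetime import date

# Hypothetical sketch, not the Patient class's actual implementation.
def build_timeline(patient):
    """Merge diagnoses and drugs into (visit_index, code, kind) tuples in date order."""
    diagnoses = zip(patient["diagnoses"], patient["diagnosis_dates"])
    drugs = zip(patient["drugs"], patient["prescription_dates"])
    # A set removes exact duplicate (date, code, kind) events.
    events = {(d, code, "diagnosis") for code, d in diagnoses}
    events |= {(d, code, "drug") for code, d in drugs}
    # Sort by date, then code, then kind, for a deterministic order.
    events = sorted(events)
    # Assumption: all events on the same calendar date form one visit.
    visit_dates = sorted({d for d, _, _ in events})
    visit_index = {d: i for i, d in enumerate(visit_dates, start=1)}
    return [(visit_index[d], code, kind) for d, code, kind in events]


toy = {
    "diagnoses": ["I10", "E11.9", "I10"],
    "diagnosis_dates": [date(2021, 3, 15), date(2022, 3, 15), date(2021, 3, 15)],
    "drugs": ["860975"],
    "prescription_dates": [date(2021, 3, 16)],
}
print(build_timeline(toy))
# [(1, 'I10', 'diagnosis'), (2, '860975', 'drug'), (3, 'E11.9', 'diagnosis')]
```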

## 6. Patient Object: Encoding and Processing

The `Patient` class takes a patient's raw data and encodes it using the dictionaries above. It also handles masking (for masked language modeling), sequence splitting, and other preprocessing steps.
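
Part of this preprocessing is fitting every sequence to `max_length`. The sketch below covers truncation, a leading classification token, padding and the matching attention mask. Several names are assumptions made for this illustration: `pad_and_mask`, `PAD_ID = 0`, `CLS_ID = 2` and the meaning of `truncate`. The `Patient` source defines the exact convention.

```python
# Hypothetical sketch; the IDs below are assumptions, not ExMed-BERT's values.
PAD_ID = 0
CLS_ID = 2


def pad_and_mask(token_ids, max_length, use_cls=True, truncate="right"):
    """Truncate, prepend CLS, pad to max_length, and build the attention mask."""
    ids = list(token_ids)
    budget = max_length - 1 if use_cls else max_length  # room left after CLS
    if len(ids) > budget:
        # In this sketch "right" cuts from the end (keeps the earliest events)
        # and "left" cuts from the start (keeps the most recent ones).
        ids = ids[:budget] if truncate == "right" else ids[len(ids) - budget:]
    if use_cls:
        ids = [CLS_ID] + ids
    n_pad = max_length - len(ids)
    attention_mask = [1] * len(ids) + [0] * n_pad  # 1 = real token, 0 = padding
    return ids + [PAD_ID] * n_pad, attention_mask


print(pad_and_mask([15, 6, 14], max_length=6))
# ([2, 15, 6, 14, 0, 0], [1, 1, 1, 1, 0, 0])
print(pad_and_mask([10, 11, 12, 13, 14], max_length=4, truncate="left"))
# ([2, 12, 13, 14], [1, 1, 1, 1])
```

The attention mask lets the model ignore padded positions. That is why the `attention_mask` tensor printed below switches to 0 exactly where `input_ids` becomes padding.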

Let's process each patient and inspect the encoded and decoded outputs.

In [25]:
for idx, patient_data in enumerate(patients, 1):
    try:
        patient = Patient(
            patient_id=patient_data['patient_id'],            # Unique identifier for tracking and referencing individual patients in the dataset
            diagnoses=patient_data['diagnoses'],            # List of ICD diagnosis codes that will be converted to PheWAS for standardization
            drugs=patient_data['drugs'],                    # List of RxNorm drug codes that will be converted to ATC for standardization
            diagnosis_dates=patient_data['diagnosis_dates'], # Timestamps for diagnoses to maintain temporal order in sequence modeling
            prescription_dates=patient_data['prescription_dates'], # Timestamps for prescriptions to maintain temporal order
            birth_year=patient_data['birth_year'],           # Used to calculate patient age for age-based feature encoding
            sex=patient_data['sex'],                         # Demographic feature encoded as MALE/FEMALE for patient characterization
            patient_state=patient_data['patient_state'],     # Geographic feature for potential regional health pattern analysis
            max_length=50,                                   # Maximum sequence length - truncates or pads sequences for consistent model input
            code_embed=code_dict,                            # Handles conversion and encoding of medical codes (ICD→PheWAS, RxNorm→ATC)
            sex_embed=sex_dict,                              # Converts sex categories to numerical embeddings for model input
            age_embed=age_dict,                              # Bins and encodes patient ages into discrete categories
            state_embed=state_dict,                          # Converts state information into numerical embeddings
            mask_drugs=True,                                 # Enables drug code masking for MLM pretraining task
            delete_temporary_variables=True,                 # Cleans up memory by removing intermediate processing variables
            split_sequence=True,                             # Splits long patient sequences into manageable chunks if needed
            drop_duplicates=True,                            # Removes redundant codes to prevent sequence bias
            converted_codes=False,                           # Indicates if codes are in raw form (False) or already converted to standard format
            convert_icd_to_phewas=True,                      # Enables automatic conversion of ICD codes to PheWAS for standardization
            convert_rxcui_to_atc=True,                       # Enables automatic conversion of RxNorm to ATC for drug standardization
            keep_min_unmasked=1,                             # Ensures at least one token remains unmasked for context in MLM
            max_masked_tokens=20,                            # Limits masked tokens to prevent too much information loss
            masked_lm_prob=0.15,                             # Probability of masking each token, following BERT's approach
            truncate='right',                                # Side of the sequence that is cut when it exceeds max_length
            index_date=None,                                 # Optional reference date for temporal alignment of patient histories
            had_plos=bool(patient_data['plos']),             # Prolonged Length of Stay label taken from the patient record
            endpoint_labels=patient_data.get('endpoint_labels', None), # Additional outcome labels for multi-task learning
            dynamic_masking=False,                           # When False, uses static masks; True generates new masks each epoch
            min_observations=5,                              # Minimum required events for valid patient sequence
            age_usage='year',                                # Specifies granularity of age binning (year vs month)
            use_cls=True,                                    # Adds classification token at sequence start like BERT
            use_sep=False,                                   # Whether to add separator tokens between visits (disabled here)
            valid_patient=True,                              # Internal flag for tracking patient data validity
            num_visits=None,                                 # Tracks number of unique clinical visits (set internally)
            combined_length=None,                            # Total length of patient sequence before processing
            unpadded_length=None                             # Original sequence length before padding to max_length
        )

        # Get encoded data for model input
        model_input = patient.get_patient_data(
            evaluate=True,
            mask_dynamically=False,
            min_unmasked=1,
            max_masked=20,
            masked_lm_prob=0.15,
            mask_drugs=True
        )

        for k, v in model_input.items():
            print(f'  {k}: {v}')
        if 'input_ids' in model_input:
            print('  input_ids (codes):', code_dict.decode(model_input['input_ids']))
        if 'sex_ids' in model_input:
            print('  sex_ids:', sex_dict.decode(model_input['sex_ids']))
        if 'state_ids' in model_input:
            print('  state_ids:', state_dict.decode(model_input['state_ids']))
        if 'age_ids' in model_input:
            print('  age_ids:', age_dict.decode(model_input['age_ids']))
        if 'entity_ids' in model_input:
            print('  entity_ids:', [code_dict.ids_to_entity.get(x.item(), 'UNK') for x in model_input['entity_ids']])
        if 'code_labels' in model_input:
            print('  code_labels:', code_dict.decode(model_input['code_labels']))
        df = patient.to_df(
            code_embed=code_dict,
            age_embed=age_dict,
            sex_embed=sex_dict,
            state_embed=state_dict,
            dynamic_masking=True,
            mask_drugs=True,
            min_unmasked=1,
            max_masked=20,
            masked_lm_prob=0.15
        )
        print(df)
    except Exception as e:
        print(f'Error processing patient {idx}: {e}')

  input_ids: tensor([ 2, 15,  6, 14,  7, 13,  4,  4,  9, 17, 10, 18, 11, 19, 12,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
  entity_ids: tensor([0, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  sex_ids: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  attention_mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  position_ids: tensor([ 1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0

### What happens in this step?
- Each patient's data is encoded into integer IDs using the dictionaries.
- Masking is applied to some codes for masked language modeling (MLM) pretraining.
- The encoded data is ready for input to a BERT-style model.
- The decoded output and DataFrame view help with interpretability and debugging.
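
The masking step can be sketched with the standard BERT recipe. About `masked_lm_prob` of the regular tokens are selected. Of those, 80% are replaced with `MASK`, 10% with a random token, and 10% are left unchanged. The function name, the assumed special-token IDs and the 80/10/10 split are illustrative. ExMed-BERT's own masking, controlled by `masked_lm_prob`, `max_masked_tokens` and `keep_min_unmasked` above, may differ in detail. Labels use -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so unmasked positions do not contribute to the loss.

```python
import random

# Hypothetical sketch of standard BERT-style masking; IDs are assumptions.
MASK_ID = 4
NUM_SPECIAL = 5        # assume IDs 0..4 are PAD/UNK/CLS/SEP/MASK
IGNORE_INDEX = -100    # label value skipped by the loss


def mask_tokens(ids, vocab_size, masked_lm_prob=0.15, max_masked=20,
                min_unmasked=1, seed=None):
    """Return (masked_inputs, labels); labels are IGNORE_INDEX where nothing is masked."""
    rng = random.Random(seed)
    # Special tokens (CLS, SEP, PAD, ...) are never selected for masking.
    candidates = [i for i, t in enumerate(ids) if t >= NUM_SPECIAL]
    n_mask = max(1, round(len(candidates) * masked_lm_prob))
    # Respect the cap and leave at least `min_unmasked` regular tokens visible.
    n_mask = min(n_mask, max_masked, max(0, len(candidates) - min_unmasked))
    inputs, labels = list(ids), [IGNORE_INDEX] * len(ids)
    for pos in rng.sample(candidates, n_mask):
        labels[pos] = ids[pos]                 # target: the original token
        r = rng.random()
        if r < 0.8:
            inputs[pos] = MASK_ID              # 80%: replace with MASK
        elif r < 0.9:
            inputs[pos] = rng.randrange(NUM_SPECIAL, vocab_size)  # 10%: random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels


ids = [2] + list(range(5, 19)) + [0] * 5   # CLS, 14 codes, 5 PAD
inputs, labels = mask_tokens(ids, vocab_size=30, seed=0)
print(inputs)
print(labels)
```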

## 7. Adding Endpoint Labels for Pretraining

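For supervised pretraining, we may also want to predict clinical endpoints such as Prolonged Length of Stay ('plos'). We define an `EndpointDict` for the endpoint names and attach an endpoint label tensor to each patient record. The `EndpointDict` is later passed to `PatientDataset`.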

In [7]:
endpoint_labels = ['plos']
endpoint_dict = EndpointDict(endpoint_labels)

# Use each patient's recorded 'plos' value as its endpoint label
for patient_data in patients:
    patient_data['endpoint_labels'] = torch.LongTensor([patient_data['plos']])
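
The cell above attaches a single `plos` label to each patient. With more than one endpoint, a common layout is one 0/1 entry per endpoint in a fixed order. The sketch below assumes that layout. `endpoint_vector` and the second endpoint `readmission` are hypothetical, and treating a missing field as 0 is an assumption.

```python
# Hypothetical sketch: build one 0/1 label per endpoint from patient records.
# "readmission" is an invented second endpoint, used only to show the layout.
endpoint_names = ["plos", "readmission"]


def endpoint_vector(record, names):
    """Labels in the order of `names`; a missing field is treated as 0 (an assumption)."""
    return [int(record.get(name, 0)) for name in names]


records = [
    {"patient_id": 12345, "plos": 1},
    {"patient_id": 67890, "plos": 0, "readmission": 1},
]
labels = [endpoint_vector(r, endpoint_names) for r in records]
print(labels)  # [[1, 0], [0, 1]]
```

Wrapping each inner list in `torch.LongTensor` gives the same tensor type used above.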

## 8. Creating Patient Objects with Endpoint Labels

We now create `Patient` objects for each patient, including the endpoint labels. These objects are ready for batching into a dataset.

In [None]:
patient_objs = []
for patient_data in patients:
    patient_obj = Patient(
        patient_id=patient_data['patient_id'],            # Unique identifier for tracking and referencing individual patients in the dataset
        diagnoses=patient_data['diagnoses'],            # List of ICD diagnosis codes that will be converted to PheWAS for standardization
        drugs=patient_data['drugs'],                    # List of RxNorm drug codes that will be converted to ATC for standardization
        diagnosis_dates=patient_data['diagnosis_dates'], # Timestamps for diagnoses to maintain temporal order in sequence modeling
        prescription_dates=patient_data['prescription_dates'], # Timestamps for prescriptions to maintain temporal order
        birth_year=patient_data['birth_year'],           # Used to calculate patient age for age-based feature encoding
        sex=patient_data['sex'],                         # Demographic feature encoded as MALE/FEMALE for patient characterization
        patient_state=patient_data['patient_state'],     # Geographic feature for potential regional health pattern analysis
        max_length=50,                                   # Maximum sequence length - truncates or pads sequences for consistent model input
        code_embed=code_dict,                            # Handles conversion and encoding of medical codes (ICD→PheWAS, RxNorm→ATC)
        sex_embed=sex_dict,                              # Converts sex categories to numerical embeddings for model input
        age_embed=age_dict,                              # Bins and encodes patient ages into discrete categories
        state_embed=state_dict,                          # Converts state information into numerical embeddings
        mask_drugs=True,                                 # Enables drug code masking for MLM pretraining task
        delete_temporary_variables=True,                 # Cleans up memory by removing intermediate processing variables
        split_sequence=True,                             # Splits long patient sequences into manageable chunks if needed
        drop_duplicates=True,                            # Removes redundant codes to prevent sequence bias
        converted_codes=False,                           # Indicates if codes are in raw form (False) or already converted to standard format
        convert_icd_to_phewas=True,                      # Enables automatic conversion of ICD codes to PheWAS for standardization
        convert_rxcui_to_atc=True,                       # Enables automatic conversion of RxNorm to ATC for drug standardization
        keep_min_unmasked=1,                             # Ensures at least one token remains unmasked for context in MLM
        max_masked_tokens=20,                            # Limits masked tokens to prevent too much information loss
        masked_lm_prob=0.15,                             # Probability of masking each token, following BERT's approach
        truncate='right',                                # Side of the sequence that is cut when it exceeds max_length
        index_date=None,                                 # Optional reference date for temporal alignment of patient histories
        had_plos=None,                                   # No separate PLOS label here; 'plos' is passed via endpoint_labels instead
        endpoint_labels=patient_data.get('endpoint_labels', None), # Additional outcome labels for multi-task learning
        dynamic_masking=False,                           # When False, uses static masks; True generates new masks each epoch
        min_observations=5,                              # Minimum required events for valid patient sequence
        age_usage='year',                                # Specifies granularity of age binning (year vs month)
        use_cls=True,                                    # Adds classification token at sequence start like BERT
        use_sep=True,                                    # Adds separator tokens between visits for temporal segmentation
        valid_patient=True,                              # Internal flag for tracking patient data validity
        num_visits=None,                                 # Tracks number of unique clinical visits (set internally)
        combined_length=None,                            # Total length of patient sequence before processing
        unpadded_length=None                             # Original sequence length before padding to max_length
    )
    patient_objs.append(patient_obj)

      codes     sex position entity_ids code_label state  age
0       CLS  FEMALE        1    default       None    NY   61
1       250  FEMALE        1     phewas       None    NY   61
2       SEP  FEMALE        1    default       None    NY   61
3   C09AA05  FEMALE        2        atc       None    NY   61
4       SEP  FEMALE        2    default       None    NY   61
5     401.1  FEMALE        3     phewas      401.1    NY   62
6       SEP  FEMALE        3    default       None    NY   62
7   A01AA01  FEMALE        4        atc       None    NY   62
8       SEP  FEMALE        4    default       None    NY   62
9    530.11  FEMALE        5     phewas       None    NY   63
10      SEP  FEMALE        5    default       None    NY   63
11  D05AX02  FEMALE        6        atc       None    NY   63
12      SEP  FEMALE        6    default       None    NY   63
13    715.2  FEMALE        7     phewas       None    NY   64
14      SEP  FEMALE        7    default       None    NY   64
15     M

## 9. Creating and Saving the PatientDataset

The `PatientDataset` class batches multiple `Patient` objects and prepares them for model training. It handles masking, batching, and can save the dataset to disk for later use.
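
Training usually consumes a dataset of encoded patients in mini-batches. The sketch below shows the idea in plain Python, assuming every example is already padded to the same length. `ToyPatientDataset`, `collate` and `iterate_batches` are hypothetical and are not the `PatientDataset` API.

```python
# Hypothetical stand-in for a dataset of pre-encoded patients; not PatientDataset's API.
class ToyPatientDataset:
    def __init__(self, examples):
        self.examples = examples  # list of dicts of equal-length (padded) lists

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def collate(batch):
    """Turn a list of per-patient dicts into one dict of per-field lists."""
    return {key: [example[key] for example in batch] for key in batch[0]}


def iterate_batches(dataset, batch_size):
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield collate([dataset[i] for i in range(start, stop)])


dataset = ToyPatientDataset([
    {"input_ids": [2, 15, 6, 0], "attention_mask": [1, 1, 1, 0]},
    {"input_ids": [2, 9, 0, 0], "attention_mask": [1, 1, 0, 0]},
    {"input_ids": [2, 7, 8, 11], "attention_mask": [1, 1, 1, 1]},
])
for batch in iterate_batches(dataset, batch_size=2):
    print(batch["input_ids"])
# [[2, 15, 6, 0], [2, 9, 0, 0]]
# [[2, 7, 8, 11]]
```

In PyTorch, `torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate)` accepts any object with `__len__` and `__getitem__` and plays the role of `iterate_batches`.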

In [None]:
patient_dataset = PatientDataset(
    code_embed=code_dict,
    age_embed=age_dict,
    sex_embed=sex_dict,
    state_embed=state_dict,
    endpoint_dict=endpoint_dict,
    patient_paths=None,
    max_length=50,
    do_eval=True,
    mask_substances=True,
    dataset_path=None,
    patients=patient_objs,
    dynamic_masking=False,
    min_unmasked=1,
    max_masked=20,
    masked_lm_prob=0.15
)

print(patient_dataset)

(Patient: 12345; Valid: True; Visits: 14, 'na')


In [None]:
output_path = 'demo_patient_dataset.pt'
patient_dataset.save_dataset(path=output_path, with_patients=False, do_copy=True)
print(f'PatientDataset saved to {output_path}. You can now use this file for pretraining.')
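
Saved encoded data is only useful if the same vocabularies are available at load time, because the integer IDs mean nothing without them. The sketch below stores both in one file using `pickle`. `save_encoded` and `load_encoded` are hypothetical, and this notebook does not show the file format written by `save_dataset`.

```python
import os
import pickle
import tempfile


# Hypothetical sketch; not the format used by PatientDataset.save_dataset.
def save_encoded(path, examples, vocab):
    """Store encoded examples together with the vocabulary needed to decode them."""
    with open(path, "wb") as f:
        pickle.dump({"examples": examples, "vocab": vocab}, f)


def load_encoded(path):
    # Only unpickle files you trust: pickle can execute arbitrary code on load.
    with open(path, "rb") as f:
        return pickle.load(f)


vocab = {"PAD": 0, "CLS": 2, "401.1": 15}
examples = [{"input_ids": [2, 15, 0], "attention_mask": [1, 1, 0]}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_dataset.pkl")
    save_encoded(path, examples, vocab)
    restored = load_encoded(path)
print(restored == {"examples": examples, "vocab": vocab})  # True
```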

## 10. Summary

- We defined encoding dictionaries for all categorical data.
- We created example patients and encoded their data.
- We processed and masked the data for BERT-style pretraining.
- We batched patients into a dataset and saved it for model training.

This workflow maps all patient data to integer IDs in a consistent way and prepares it for use in ExMed-BERT or similar models.