# ExMed-BERT Data Preparation for Pretraining

This notebook provides a detailed walkthrough of how patient data is prepared for pretraining a BERT-style model using the ExMed-BERT framework. It is designed for users who are new to the codebase and want to understand the data pipeline, encoding strategies, and the rationale behind each step.

We will cover:
- The purpose and structure of the main classes involved (e.g., `CodeDict`, `AgeDict`, `Patient`, `PatientDataset`)
- How medical codes and patient attributes are encoded
- How patient data is processed and masked for model input
- How a dataset is constructed and saved for pretraining


## 1. Environment Setup: Troubleshooting Missing Packages

If you see an `ImportError` (or its subclass `ModuleNotFoundError`) after creating the Conda environment, one or more packages most likely failed to install during environment creation. This can happen because of network interruptions, dependency conflicts, or platform-specific constraints.

**To resolve this, manually install the problematic package(s) using `pip`, ensuring you specify the exact version listed in the `environment.yaml` file.**

**Steps:**

1.  **Activate your Conda environment:**
    ```bash
    conda activate your_env_name
    ```
    (Replace `your_env_name` with the actual name of your Conda environment).

2.  **Identify the missing package and its version:**
    Refer to your `environment.yaml` file. Locate the `pip` section or the main `dependencies` list. For example, if you see an error related to `torch`, find its entry:
    ```yaml
    # ... other top-level keys (e.g., name, channels)
    dependencies:
      - python=3.9
      - pip
      - pip:
        - torch==1.10.0+cu113 # Example entry
        - transformers==4.12.0 # Another example
        # ...
    ```

3.  **Manually install the package with pip:**
    Use `pip install` with `==` followed by the exact version listed in `environment.yaml`. Note that Conda entries use a single `=` (e.g., `python=3.9`), whereas pip requires `==`.
    ```bash
    pip install package_name==x.y.z
    ```
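    **Example:**
    ```bash
    pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
    pip install transformers==4.12.0
    ```
    CUDA-specific PyTorch builds (versions with a `+cuXXX` suffix) are not published on PyPI, so pip must be pointed at PyTorch's own package index, as the `-f` option does above. Repeat this for any other packages causing import errors.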

Installing packages individually this way usually fixes cases where Conda's solver, or the pip step that runs during environment creation, failed on specific packages.
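Rather than discovering missing packages one import error at a time, a short script can compare the installed versions against the pip pins copied from `environment.yaml`. This is a hypothetical helper, not part of ExMed-BERT; it assumes Python 3.8+ (for `importlib.metadata`) and compares version strings literally, so a pin such as `1.10` versus an installed `1.10.0` would be reported as a mismatch.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_pin(line):
    """Split an environment.yaml entry like '- torch==1.10.0+cu113 # note' into (name, version)."""
    spec = line.split("#", 1)[0].strip().lstrip("-").strip()
    name, _, pinned = spec.partition("==")
    return name.strip(), pinned.strip()


def check_pins(pin_lines):
    """Return {name: (pinned, installed)} for pins that are missing or mismatched."""
    problems = {}
    for line in pin_lines:
        name, pinned = parse_pin(line)
        if not pinned:
            continue  # not a pip '==' pin (e.g., the Conda entry 'python=3.9')
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != pinned:
            problems[name] = (pinned, installed)
    return problems


# Paste the pip entries from environment.yaml here
pins = ["- torch==1.10.0+cu113", "- transformers==4.12.0"]
for name, (pinned, installed) in check_pins(pins).items():
    print(f"{name}: expected {pinned}, found {installed} -> pip install {name}=={pinned}")
```

Each printed line suggests the `pip install` command from step 3 for one problematic package.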

In [19]:
import sys
import os
from datetime import date

# Add the ExMed-BERT package to Python path for imports
sys.path.append('./ExMed-BERT-main')

import torch  # PyTorch is used for tensor operations and model input formatting

## 2. Encoding Classes: Dictionaries for Medical Data

ExMed-BERT uses specialized dictionary classes to encode various types of patient data into integer IDs suitable for model input. These include:
- `CodeDict`: Handles medical codes (ICD, PheWAS, RxNorm, ATC) and their mappings.
- `AgeDict`: Bins and encodes patient ages.
- `SexDict`: Encodes sex/gender.
- `StateDict`: Encodes US state information.
- `EndpointDict`: Encodes endpoint labels for classification tasks.

__Note:__ We may enhance the `CodeDict` class to broaden its capabilities beyond ICD and RxNorm to include CPT codes and allow for the integration of other relevant coding systems as needed.

Let's import these classes and briefly describe their roles.


In [20]:
from exmed_bert.data.encoding import (
    AgeDict,     # Handles age binning and encoding
    CodeDict,    # Handles medical codes (ICD, PheWAS, RxNorm, ATC)
    SexDict,     # Handles sex/gender encoding
    StateDict,   # Handles US state encoding
    DICT_DEFAULTS,  # Default tokens (e.g., PAD, UNK, MASK)
    EndpointDict # Handles endpoint label encoding (e.g., for classification tasks)
)
from exmed_bert.data.patient import Patient  # Main class for patient sequence processing
from exmed_bert.data.dataset import PatientDataset  # Dataset class for batching patients

### What do these classes do?
- **CodeDict**: Maps raw medical codes (diagnoses, drugs) to integer IDs, handles code normalization (e.g., mapping RxNorm to ATC, ICD to PheWAS), and provides decoding for interpretability.
- **AgeDict**: Bins ages (e.g., by year) and encodes them as IDs.
- **SexDict**: Encodes sex/gender as IDs.
- **StateDict**: Encodes US state abbreviations as IDs.
- **EndpointDict**: Encodes outcome labels for supervised tasks.

These dictionaries ensure that all categorical data is consistently mapped to integer IDs for model input.

## 3. Example Code Dictionaries and Mappings

For demonstration purposes, we define small example sets of codes and mappings. In our real model, these would be substantially larger and loaded from files within the `data/` directory. Given the current model state, the following files would be essential:

* **`ATC_codes`**: We would likely utilize ATC5 codes to ensure the desired level of granularity.

* **`phecodes`**: A pragmatic approach would be to replace the traditional Phecode list with a comprehensive list of ICD codes.

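* **`RxNorm_to_ATC_mapping`**: Rather than a direct RxNorm-to-ATC mapping, we may use drug names as tokens. This can be done with an identity map in which each name maps to itself (e.g., `{'FINTEPLA': 'FINTEPLA', 'BIMZELX': 'BIMZELX'}`; these are brand names, so tokenizing by molecule would use generic names such as `fenfluramine` and `bimekizumab` instead). Alternatively, an `NDC11`-to-`ATC5` map could be used. The goal is simply some route from the raw drug codes to ATC-level tokens.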

* **`ICD_to_Phecode_mapping`**: Similar to the RxNorm mapping, we can re-engineer this to function as an ICD-10 to ICD-10 map (e.g., `{'G40.81': 'G40.81', 'G40.812': 'G40.812'}`).
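A minimal sketch of the identity-map idea, assuming the raw codes are already available as a Python list (`identity_map` is a hypothetical helper, not part of ExMed-BERT). Claims data often store ICD-10 codes without the dot (e.g., `G4081` instead of `G40.81`), so apply one consistent format before building the map.

```python
def identity_map(codes):
    """Map every non-empty code to itself, e.g. {'G40.81': 'G40.81'}; duplicates collapse."""
    mapping = {}
    for code in codes:
        code = code.strip()
        if code:
            mapping[code] = code
    return mapping


icd10_codes = ["G40.81", "G40.812", "G40.81"]
icd_map = identity_map(icd10_codes)
vocab_codes = list(icd_map)  # the same codes would double as the vocabulary list
print(icd_map)  # {'G40.81': 'G40.81', 'G40.812': 'G40.812'}
```

Under this approach, the notebook would pass `icd_to_phewas_map=icd_map` and `phewas_codes=vocab_codes`, so the PheWAS-named arguments would simply carry ICD-10 codes.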

In [21]:
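# Define example ATC codes (medications). These are demo placeholders: several are not
# valid ATC codes (ATC has no E, F, K, O, Q, T, or U anatomical main group).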
atc_codes = [
    'A01AA01', 'B01AC06', 'C09AA05', 'D05AX02', 'E03AA01', 'F01BA01', 'G04BE03',
    'H02AB02', 'J01CA04', 'K01AA02', 'L01XE01', 'M01AE01', 'N02BA01', 'O01AA01',
    'P01AB01', 'Q01AA01', 'R03BA02', 'S01AA01', 'T01AA01', 'U01AA01', 'V01AA01'
    ]

# Define example PheWAS codes (diagnoses); '800' through '2000' are demo placeholders.
phewas_codes = [
    '008', '250', '401.1', '530.11', '715.2', '272.1', '585.3',
    '800', '900', '1000', '1100', '1200', '1300', '1400', '1500',
    '1600', '1700', '1800', '1900', '2000'
    ]

# RxNorm (RxCUI) to ATC mapping; the RxCUIs are placeholders as well.
# Every ATC code used as a value must also appear in `atc_codes`, since the
# vocabulary is built from that list.
rx_to_atc_map = {
    '860975': 'A01AA01', '197361': 'B01AC06', '123456': 'C09AA05', '654321': 'D05AX02',
    '789012': 'E03AA01', '345678': 'F01BA01', '987654': 'G04BE03',
    '111111': 'H02AB02', '222222': 'J01CA04', '333333': 'K01AA02', '444444': 'L01XE01',
    '555555': 'M01AE01', '666666': 'N02BA01', '777777': 'O01AA01', '888888': 'P01AB01',
    '999999': 'Q01AA01', '121212': 'R03BA02', '131313': 'S01AA01', '141414': 'T01AA01',
    '151515': 'U01AA01', '161616': 'V01AA01'
    }

# ICD-10 to PheWAS mapping (several pairs are placeholders for the demo).
# Every PheWAS code used as a value must also appear in `phewas_codes`.
icd_to_phewas_map = {
    'I10': '401.1', 'E11.9': '250', 'Z51.11': '008', 'K21.0': '530.11', 'M17.9': '715.2',
    'E78.5': '272.1', 'N18.3': '585.3',
    'A00': '800', 'B00': '900', 'C00': '1000', 'D00': '1100', 'E00': '1200',
    'F00': '1300', 'G00': '1400', 'H00': '1500', 'J00': '1600', 'K00': '1700',
    'L00': '1800', 'M00': '1900', 'N00': '2000'
    }

# States for state encoding
state_list = ['CA', 'NY', 'TX', 'FL', 'IL', 'PA', 'OH', 'GA', 'NC', 'MI', 'WA', 'OR', 'CO', 'AZ', 'MA']
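The comments above stress that every mapped-to code must exist in the corresponding code list. A quick check like the following can catch violations before `CodeDict` is built; `missing_targets` is a hypothetical helper, not part of ExMed-BERT.

```python
def missing_targets(mapping, vocabulary):
    """Return the sorted mapped-to codes that do not appear in the vocabulary list."""
    known = set(vocabulary)
    return sorted({target for target in mapping.values() if target not in known})


# Toy example: 'J01CA04' is a mapping target but was never added to the ATC list
toy_map = {"860975": "A01AA01", "222222": "J01CA04"}
print(missing_targets(toy_map, ["A01AA01", "B01AC06"]))  # ['J01CA04']
```

For the example data in this notebook, `missing_targets(rx_to_atc_map, atc_codes)` and `missing_targets(icd_to_phewas_map, phewas_codes)` should both return empty lists.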

## 4. Initializing Encoding Dictionaries

This section outlines the process of building our model's vocabulary, analogous to the tokenization phase in traditional NLP. ExMed-BERT handles patient data across five distinct "modes" or sequences:

1.  **Code Sequence:** Medical codes (currently ICD diagnoses mapped to PheWAS and RxNorm drugs mapped to ATC; CPT is a possible future addition).
2.  **Sex Sequence:** Patient sex/gender information (though currently omitted from our model).
3.  **State Sequence:** US state information (though currently omitted from our model).
4.  **Age Sequence:** Patient age information.
5.  **Visit Sequence:** Derived implicitly by ordering patient events by date. There is no dedicated dictionary class for this modality.

For each of the first four modalities, we use a dedicated dictionary class that manages its own vocabulary, which keeps the vocabularies separate across modalities.

**Important Considerations:**

* **Integer Mapping:** During this initialization, each unique medical code (or other data point) is mapped to a distinct integer ID. This integer serves as a unique identifier, or "token," for that specific data element within the model.
* **Learned Embeddings:** Later in the model's architecture, each of these unique tokens will be assigned its own learnable embedding, allowing the model to capture semantic relationships and context.
* **Reserved Tokens:** The first six integer IDs are reserved for BERT's special tokens, defined as `DICT_DEFAULTS = ["PAD", "UNK", "CLS", "SEP", "MASK", "NA"]`. These special tokens serve various purposes, such as padding sequences to a uniform length (`PAD`), handling unknown tokens (`UNK`), marking the beginning of a sequence (`CLS`), separating segments (`SEP`), masking tokens for pre-training (`MASK`), and indicating not applicable data (`NA`).
* **Modality-Specific Vocabularies:** To reiterate, the vocabularies are __*not*__ mixed together. Each dictionary class maintains a separate set of vocabulary and mappings for its respective modality, allowing ExMed-BERT to process and learn from these distinct data types independently yet cohesively within the model's architecture.
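A minimal sketch of what such a modality-specific vocabulary does, assuming only what is described above: reserved tokens take IDs 0-5 and each distinct code takes the next free ID. `build_vocab` and `encode` are hypothetical helpers, not ExMed-BERT's API. Adding ATC codes before PheWAS codes happens to be consistent with the IDs in the `input_ids` example later in this notebook, but treat that ordering as an observation rather than a guarantee.

```python
RESERVED_TOKENS = ["PAD", "UNK", "CLS", "SEP", "MASK", "NA"]  # same order as DICT_DEFAULTS


def build_vocab(*code_lists):
    """Assign IDs 0-5 to the reserved tokens, then the next free ID to each new code."""
    labels_to_id = {token: idx for idx, token in enumerate(RESERVED_TOKENS)}
    for codes in code_lists:
        for code in codes:
            if code not in labels_to_id:  # a repeated code keeps its first ID
                labels_to_id[code] = len(labels_to_id)
    ids_to_label = {idx: label for label, idx in labels_to_id.items()}
    return labels_to_id, ids_to_label


def encode(codes, labels_to_id):
    """Map codes to IDs, falling back to UNK for codes outside the vocabulary."""
    return [labels_to_id.get(code, labels_to_id["UNK"]) for code in codes]


# Toy vocabulary for one modality: two ATC codes followed by two PheWAS codes
labels_to_id, ids_to_label = build_vocab(["A01AA01", "B01AC06"], ["008", "250"])
print(encode(["CLS", "250", "A01AA01", "Z99"], labels_to_id))  # [2, 9, 6, 1]
```

Because each modality would get its own `labels_to_id`, the same integer (say `6`) can mean a drug code in the code vocabulary and a specific age in the age vocabulary without conflict.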


In [22]:
# Code dictionary for medical codes
code_dict = CodeDict(
    atc_codes=atc_codes,
    phewas_codes=phewas_codes,
    rx_to_atc_map=rx_to_atc_map,
    icd_to_phewas_map=icd_to_phewas_map
)

# Age dictionary for binning and encoding ages (by year)
# The bin size is configurable; here we use yearly bins from 0 to 90: [0, 1, 2, ..., 90]
age_dict = AgeDict(max_age=90, min_age=0, binsize=1)

# Normalize age labels to integer strings (e.g., '30.0' -> '30') for consistency,
# leaving the reserved tokens in DICT_DEFAULTS untouched. The vocab list and both
# lookup tables are updated together so they stay in sync.
def _to_int_label(label):
    return label if label in DICT_DEFAULTS else str(int(float(label)))

age_dict.vocab = [_to_int_label(age) for age in age_dict.vocab]
age_dict.labels_to_id = {_to_int_label(label): idx for label, idx in age_dict.labels_to_id.items()}
age_dict.ids_to_label = {idx: _to_int_label(label) for idx, label in age_dict.ids_to_label.items()}

# Sex dictionary
sex_dict = SexDict(sex=['MALE', 'FEMALE'])

# State dictionary
state_dict = StateDict(states=state_list)
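ExMed-BERT's exact age computation is not shown in this notebook. As a sketch under stated assumptions, assuming only the birth year is known (as in the patient records below), an event's age bin could be derived as follows. `age_bin` is hypothetical; the real `Patient` class may compute ages differently.

```python
from datetime import date


def age_bin(birth_year, event_date, binsize=1, min_age=0, max_age=90):
    """Approximate age at an event from the birth year, clipped and floored to a bin."""
    age = event_date.year - birth_year  # birth month/day unknown: may be off by one year
    age = max(min_age, min(age, max_age))  # clip into the dictionary's age range
    return min_age + (age - min_age) // binsize * binsize  # lower edge of the bin


print(age_bin(1980, date(2021, 1, 1)))             # 41
print(age_bin(1980, date(2021, 1, 1), binsize=5))  # 40
```

With yearly bins, `str(age_bin(...))` yields a label such as `'41'`, matching the integer-string form that the age labels were normalized to above.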

## 5. Example Patient Data Structure

To illustrate the input format, let's examine the structure of patient data that ExMed-BERT will process. Each patient's record will primarily consist of sequences of events, where diagnoses and drug administrations are time-stamped.

It's crucial to ensure that the following core demographic and clinical data points are included for each patient:

* **Diagnosis Codes:** A chronological list of all recorded diagnosis codes (e.g., ICD-10 codes).
* **Diagnosis Dates:** The corresponding dates for each diagnosis, maintaining a direct index-level alignment with the `Diagnosis Codes` list.
* **Drug Codes:** A chronological list of administered drug codes (e.g., RxNorm CUI or generic molecule names).
* **Drug Dates:** The corresponding dates for each drug administration, also maintaining index-level alignment with the `Drug Codes` list.
* **Birth Year:** The patient's year of birth.
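* **Sex:** The patient's sex, using the labels passed to `SexDict` (here `'MALE'` or `'FEMALE'`).
* **Patient State:** The patient's US state of residence, as a two-letter abbreviation included in `state_list` (e.g., `'CA'`).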

**Example Illustration:**

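Consider patient `10000`. Their data might appear as follows (abridged to two events of each type; the code below stores dates as `datetime.date` objects rather than strings):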

* **Diagnosis Codes:** `['A00', 'G00']`
* **Diagnosis Dates:** `['2021-01-01', '2021-01-07']`
* **Drug Codes:** `['111111', '222222']`
* **Drug Dates:** `['2021-02-01', '2021-02-02']`

This alignment must hold for every patient: the diagnosis at index `i` in `diagnoses` occurred on the date at index `i` in `diagnosis_dates`, and likewise for `drugs` and `prescription_dates`.
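Because a length mismatch between a code list and its date list would silently misalign events, it can be worth validating each record before building `Patient` objects. A minimal sketch using the dictionary keys from the example records below; `aligned_events` is a hypothetical helper, not part of ExMed-BERT.

```python
from datetime import date

PAIRED_KEYS = (("diagnoses", "diagnosis_dates"), ("drugs", "prescription_dates"))


def aligned_events(patient):
    """Check that each code list matches its date list and return date-sorted pairs."""
    result = {}
    for codes_key, dates_key in PAIRED_KEYS:
        codes, dates = patient[codes_key], patient[dates_key]
        if len(codes) != len(dates):
            raise ValueError(f"{codes_key}: {len(codes)} codes but {len(dates)} dates")
        result[codes_key] = sorted(zip(codes, dates), key=lambda pair: pair[1])
    return result


example = {
    "diagnoses": ["G00", "A00"],
    "diagnosis_dates": [date(2021, 1, 7), date(2021, 1, 1)],
    "drugs": ["111111", "222222"],
    "prescription_dates": [date(2021, 2, 1), date(2021, 2, 2)],
}
print(aligned_events(example)["diagnoses"])
# [('A00', datetime.date(2021, 1, 1)), ('G00', datetime.date(2021, 1, 7))]
```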

**Note on PLOS (Prolonged Length of Stay):**

While "Prolonged Length of Stay" (PLOS) is an important clinical outcome, this walkthrough treats it as a **downstream fine-tuning task** rather than a core part of pre-training, which focuses on learning general representations from large volumes of unlabeled medical code sequences. In fine-tuning, the pre-trained ExMed-BERT model is further trained on labeled data to predict whether a patient will have a prolonged hospital stay, reusing the medical knowledge learned during pre-training. Note, however, that Med-BERT, which ExMed-BERT extends, used PLOS prediction as an auxiliary pre-training task alongside MLM; since the example records below include a `plos` field, check whether your pre-training configuration expects it.

In [23]:
patients = [
    {
        'patient_id': 10000,
        'diagnoses': ['A00', 'B00', 'C00', 'D00', 'E00', 'F00', 'G00'],
        'diagnosis_dates': [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3), date(2021, 1, 4), date(2021, 1, 5), date(2021, 1, 6), date(2021, 1, 7)],
        'drugs': ['111111', '222222', '333333', '444444', '555555', '666666', '777777'],
        'prescription_dates': [date(2021, 2, 1), date(2021, 2, 2), date(2021, 2, 3), date(2021, 2, 4), date(2021, 2, 5), date(2021, 2, 6), date(2021, 2, 7)],
        'birth_year': 1980,
        'sex': 'MALE',
        'patient_state': 'CA',
        'plos': 1
    },
    {
        'patient_id': 10001,
        'diagnoses': ['H00', 'J00', 'K00', 'L00', 'M00', 'N00', 'I10'],
        'diagnosis_dates': [date(2021, 1, 8), date(2021, 1, 9), date(2021, 1, 10), date(2021, 1, 11), date(2021, 1, 12), date(2021, 1, 13), date(2021, 1, 14)],
        'drugs': ['888888', '999999', '121212', '131313', '141414', '151515', '161616'],
        'prescription_dates': [date(2021, 2, 8), date(2021, 2, 9), date(2021, 2, 10), date(2021, 2, 11), date(2021, 2, 12), date(2021, 2, 13), date(2021, 2, 14)],
        'birth_year': 1975,
        'sex': 'FEMALE',
        'patient_state': 'NY',
        'plos': 0
    },
    {
        'patient_id': 10002,
        'diagnoses': ['E11.9', 'Z51.11', 'K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00'],
        'diagnosis_dates': [date(2021, 3, 1), date(2021, 3, 2), date(2021, 3, 3), date(2021, 3, 4), date(2021, 3, 5), date(2021, 3, 6), date(2021, 3, 7)],
        'drugs': ['860975', '197361', '123456', '654321', '789012', '345678', '987654'],
        'prescription_dates': [date(2021, 4, 1), date(2021, 4, 2), date(2021, 4, 3), date(2021, 4, 4), date(2021, 4, 5), date(2021, 4, 6), date(2021, 4, 7)],
        'birth_year': 1990,
        'sex': 'MALE',
        'patient_state': 'TX',
        'plos': 1
    },
    {
        'patient_id': 10003,
        'diagnoses': ['B00', 'C00', 'D00', 'E00', 'F00', 'G00', 'H00'],
        'diagnosis_dates': [date(2021, 5, 1), date(2021, 5, 2), date(2021, 5, 3), date(2021, 5, 4), date(2021, 5, 5), date(2021, 5, 6), date(2021, 5, 7)],
        'drugs': ['222222', '333333', '444444', '555555', '666666', '777777', '888888'],
        'prescription_dates': [date(2021, 6, 1), date(2021, 6, 2), date(2021, 6, 3), date(2021, 6, 4), date(2021, 6, 5), date(2021, 6, 6), date(2021, 6, 7)],
        'birth_year': 1985,
        'sex': 'FEMALE',
        'patient_state': 'FL',
        'plos': 0
    },
    {
        'patient_id': 10004,
        'diagnoses': ['J00', 'K00', 'L00', 'M00', 'N00', 'I10', 'E11.9'],
        'diagnosis_dates': [date(2021, 7, 1), date(2021, 7, 2), date(2021, 7, 3), date(2021, 7, 4), date(2021, 7, 5), date(2021, 7, 6), date(2021, 7, 7)],
        'drugs': ['999999', '121212', '131313', '141414', '151515', '161616', '860975'],
        'prescription_dates': [date(2021, 8, 1), date(2021, 8, 2), date(2021, 8, 3), date(2021, 8, 4), date(2021, 8, 5), date(2021, 8, 6), date(2021, 8, 7)],
        'birth_year': 1970,
        'sex': 'MALE',
        'patient_state': 'IL',
        'plos': 1
    },
    {
        'patient_id': 10005,
        'diagnoses': ['Z51.11', 'K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00', 'B00'],
        'diagnosis_dates': [date(2021, 9, 1), date(2021, 9, 2), date(2021, 9, 3), date(2021, 9, 4), date(2021, 9, 5), date(2021, 9, 6), date(2021, 9, 7)],
        'drugs': ['197361', '123456', '654321', '789012', '345678', '987654', '222222'],
        'prescription_dates': [date(2021, 10, 1), date(2021, 10, 2), date(2021, 10, 3), date(2021, 10, 4), date(2021, 10, 5), date(2021, 10, 6), date(2021, 10, 7)],
        'birth_year': 2000,
        'sex': 'FEMALE',
        'patient_state': 'PA',
        'plos': 0
    },
    {
        'patient_id': 10006,
        'diagnoses': ['C00', 'D00', 'E00', 'F00', 'G00', 'H00', 'J00'],
        'diagnosis_dates': [date(2021, 11, 1), date(2021, 11, 2), date(2021, 11, 3), date(2021, 11, 4), date(2021, 11, 5), date(2021, 11, 6), date(2021, 11, 7)],
        'drugs': ['333333', '444444', '555555', '666666', '777777', '888888', '999999'],
        'prescription_dates': [date(2021, 12, 1), date(2021, 12, 2), date(2021, 12, 3), date(2021, 12, 4), date(2021, 12, 5), date(2021, 12, 6), date(2021, 12, 7)],
        'birth_year': 1995,
        'sex': 'MALE',
        'patient_state': 'OH',
        'plos': 1
    },
    {
        'patient_id': 10007,
        'diagnoses': ['K00', 'L00', 'M00', 'N00', 'I10', 'E11.9', 'Z51.11'],
        'diagnosis_dates': [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), date(2022, 1, 4), date(2022, 1, 5), date(2022, 1, 6), date(2022, 1, 7)],
        'drugs': ['121212', '131313', '141414', '151515', '161616', '860975', '197361'],
        'prescription_dates': [date(2022, 2, 1), date(2022, 2, 2), date(2022, 2, 3), date(2022, 2, 4), date(2022, 2, 5), date(2022, 2, 6), date(2022, 2, 7)],
        'birth_year': 1988,
        'sex': 'FEMALE',
        'patient_state': 'GA',
        'plos': 0
    },
    {
        'patient_id': 10008,
        'diagnoses': ['K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00', 'B00', 'C00'],
        'diagnosis_dates': [date(2022, 3, 1), date(2022, 3, 2), date(2022, 3, 3), date(2022, 3, 4), date(2022, 3, 5), date(2022, 3, 6), date(2022, 3, 7)],
        'drugs': ['123456', '654321', '789012', '345678', '987654', '222222', '333333'],
        'prescription_dates': [date(2022, 4, 1), date(2022, 4, 2), date(2022, 4, 3), date(2022, 4, 4), date(2022, 4, 5), date(2022, 4, 6), date(2022, 4, 7)],
        'birth_year': 1978,
        'sex': 'MALE',
        'patient_state': 'NC',
        'plos': 1
    },
    {
        'patient_id': 10009,
        'diagnoses': ['D00', 'E00', 'F00', 'G00', 'H00', 'J00', 'K00'],
        'diagnosis_dates': [date(2022, 5, 1), date(2022, 5, 2), date(2022, 5, 3), date(2022, 5, 4), date(2022, 5, 5), date(2022, 5, 6), date(2022, 5, 7)],
        'drugs': ['444444', '555555', '666666', '777777', '888888', '999999', '121212'],
        'prescription_dates': [date(2022, 6, 1), date(2022, 6, 2), date(2022, 6, 3), date(2022, 6, 4), date(2022, 6, 5), date(2022, 6, 6), date(2022, 6, 7)],
        'birth_year': 1982,
        'sex': 'FEMALE',
        'patient_state': 'MI',
        'plos': 0
    },
    {
        'patient_id': 10010,
        'diagnoses': ['L00', 'M00', 'N00', 'I10', 'E11.9', 'Z51.11', 'K21.0'],
        'diagnosis_dates': [date(2022, 7, 1), date(2022, 7, 2), date(2022, 7, 3), date(2022, 7, 4), date(2022, 7, 5), date(2022, 7, 6), date(2022, 7, 7)],
        'drugs': ['131313', '141414', '151515', '161616', '860975', '197361', '123456'],
        'prescription_dates': [date(2022, 8, 1), date(2022, 8, 2), date(2022, 8, 3), date(2022, 8, 4), date(2022, 8, 5), date(2022, 8, 6), date(2022, 8, 7)],
        'birth_year': 1992,
        'sex': 'MALE',
        'patient_state': 'WA',
        'plos': 1
    },
    {
        'patient_id': 10011,
        'diagnoses': ['M17.9', 'E78.5', 'N18.3', 'A00', 'B00', 'C00', 'D00'],
        'diagnosis_dates': [date(2022, 9, 1), date(2022, 9, 2), date(2022, 9, 3), date(2022, 9, 4), date(2022, 9, 5), date(2022, 9, 6), date(2022, 9, 7)],
        'drugs': ['654321', '789012', '345678', '987654', '222222', '333333', '444444'],
        'prescription_dates': [date(2022, 10, 1), date(2022, 10, 2), date(2022, 10, 3), date(2022, 10, 4), date(2022, 10, 5), date(2022, 10, 6), date(2022, 10, 7)],
        'birth_year': 1987,
        'sex': 'FEMALE',
        'patient_state': 'OR',
        'plos': 0
    },
    {
        'patient_id': 10012,
        'diagnoses': ['E78.5', 'N18.3', 'A00', 'B00', 'C00', 'D00', 'E00'],
        'diagnosis_dates': [date(2022, 11, 1), date(2022, 11, 2), date(2022, 11, 3), date(2022, 11, 4), date(2022, 11, 5), date(2022, 11, 6), date(2022, 11, 7)],
        'drugs': ['789012', '345678', '987654', '222222', '333333', '444444', '555555'],
        'prescription_dates': [date(2022, 12, 1), date(2022, 12, 2), date(2022, 12, 3), date(2022, 12, 4), date(2022, 12, 5), date(2022, 12, 6), date(2022, 12, 7)],
        'birth_year': 1998,
        'sex': 'MALE',
        'patient_state': 'CO',
        'plos': 1
    },
    {
        'patient_id': 10013,
        'diagnoses': ['N18.3', 'A00', 'B00', 'C00', 'D00', 'E00', 'F00'],
        'diagnosis_dates': [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3), date(2023, 1, 4), date(2023, 1, 5), date(2023, 1, 6), date(2023, 1, 7)],
        'drugs': ['345678', '987654', '222222', '333333', '444444', '555555', '666666'],
        'prescription_dates': [date(2023, 2, 1), date(2023, 2, 2), date(2023, 2, 3), date(2023, 2, 4), date(2023, 2, 5), date(2023, 2, 6), date(2023, 2, 7)],
        'birth_year': 1983,
        'sex': 'FEMALE',
        'patient_state': 'AZ',
        'plos': 0
    },
    {
        'patient_id': 10014,
        'diagnoses': ['A00', 'B00', 'C00', 'D00', 'E00', 'F00', 'G00'],
        'diagnosis_dates': [date(2023, 3, 1), date(2023, 3, 2), date(2023, 3, 3), date(2023, 3, 4), date(2023, 3, 5), date(2023, 3, 6), date(2023, 3, 7)],
        'drugs': ['987654', '222222', '333333', '444444', '555555', '666666', '777777'],
        'prescription_dates': [date(2023, 4, 1), date(2023, 4, 2), date(2023, 4, 3), date(2023, 4, 4), date(2023, 4, 5), date(2023, 4, 6), date(2023, 4, 7)],
        'birth_year': 1993,
        'sex': 'MALE',
        'patient_state': 'MA',
        'plos': 1
    },
    {
        'patient_id': 10015,
        'diagnoses': ['B00', 'C00', 'D00', 'E00', 'F00', 'G00', 'H00'],
        'diagnosis_dates': [date(2023, 5, 1), date(2023, 5, 2), date(2023, 5, 3), date(2023, 5, 4), date(2023, 5, 5), date(2023, 5, 6), date(2023, 5, 7)],
        'drugs': ['222222', '333333', '444444', '555555', '666666', '777777', '888888'],
        'prescription_dates': [date(2023, 6, 1), date(2023, 6, 2), date(2023, 6, 3), date(2023, 6, 4), date(2023, 6, 5), date(2023, 6, 6), date(2023, 6, 7)],
        'birth_year': 1986,
        'sex': 'FEMALE',
        'patient_state': 'CA',
        'plos': 0
    },
    {
        'patient_id': 10016,
        'diagnoses': ['J00', 'K00', 'L00', 'M00', 'N00', 'I10', 'E11.9'],
        'diagnosis_dates': [date(2023, 7, 1), date(2023, 7, 2), date(2023, 7, 3), date(2023, 7, 4), date(2023, 7, 5), date(2023, 7, 6), date(2023, 7, 7)],
        'drugs': ['999999', '121212', '131313', '141414', '151515', '161616', '860975'],
        'prescription_dates': [date(2023, 8, 1), date(2023, 8, 2), date(2023, 8, 3), date(2023, 8, 4), date(2023, 8, 5), date(2023, 8, 6), date(2023, 8, 7)],
        'birth_year': 1971,
        'sex': 'MALE',
        'patient_state': 'NY',
        'plos': 1
    },
    {
        'patient_id': 10017,
        'diagnoses': ['E11.9', 'Z51.11', 'K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00'],
        'diagnosis_dates': [date(2023, 9, 1), date(2023, 9, 2), date(2023, 9, 3), date(2023, 9, 4), date(2023, 9, 5), date(2023, 9, 6), date(2023, 9, 7)],
        'drugs': ['860975', '197361', '123456', '654321', '789012', '345678', '987654'],
        'prescription_dates': [date(2023, 10, 1), date(2023, 10, 2), date(2023, 10, 3), date(2023, 10, 4), date(2023, 10, 5), date(2023, 10, 6), date(2023, 10, 7)],
        'birth_year': 1991,
        'sex': 'FEMALE',
        'patient_state': 'TX',
        'plos': 0
    },
    {
        'patient_id': 10018,
        'diagnoses': ['B00', 'C00', 'D00', 'E00', 'F00', 'G00', 'H00'],
        'diagnosis_dates': [date(2023, 11, 1), date(2023, 11, 2), date(2023, 11, 3), date(2023, 11, 4), date(2023, 11, 5), date(2023, 11, 6), date(2023, 11, 7)],
        'drugs': ['222222', '333333', '444444', '555555', '666666', '777777', '888888'],
        'prescription_dates': [date(2023, 12, 1), date(2023, 12, 2), date(2023, 12, 3), date(2023, 12, 4), date(2023, 12, 5), date(2023, 12, 6), date(2023, 12, 7)],
        'birth_year': 1984,
        'sex': 'MALE',
        'patient_state': 'FL',
        'plos': 1
    },
    {
        'patient_id': 10019,
        'diagnoses': ['J00', 'K00', 'L00', 'M00', 'N00', 'I10', 'E11.9'],
        'diagnosis_dates': [date(2024, 1, 1), date(2024, 1, 2), date(2024, 1, 3), date(2024, 1, 4), date(2024, 1, 5), date(2024, 1, 6), date(2024, 1, 7)],
        'drugs': ['999999', '121212', '131313', '141414', '151515', '161616', '860975'],
        'prescription_dates': [date(2024, 2, 1), date(2024, 2, 2), date(2024, 2, 3), date(2024, 2, 4), date(2024, 2, 5), date(2024, 2, 6), date(2024, 2, 7)],
        'birth_year': 1972,
        'sex': 'FEMALE',
        'patient_state': 'IL',
        'plos': 0
    },
    {
        'patient_id': 10020,
        'diagnoses': ['Z51.11', 'K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00', 'B00'],
        'diagnosis_dates': [date(2024, 3, 1), date(2024, 3, 2), date(2024, 3, 3), date(2024, 3, 4), date(2024, 3, 5), date(2024, 3, 6), date(2024, 3, 7)],
        'drugs': ['197361', '123456', '654321', '789012', '345678', '987654', '222222'],
        'prescription_dates': [date(2024, 4, 1), date(2024, 4, 2), date(2024, 4, 3), date(2024, 4, 4), date(2024, 4, 5), date(2024, 4, 6), date(2024, 4, 7)],
        'birth_year': 2001,
        'sex': 'MALE',
        'patient_state': 'PA',
        'plos': 1
    },
    {
        'patient_id': 10021,
        'diagnoses': ['C00', 'D00', 'E00', 'F00', 'G00', 'H00', 'J00'],
        'diagnosis_dates': [date(2024, 5, 1), date(2024, 5, 2), date(2024, 5, 3), date(2024, 5, 4), date(2024, 5, 5), date(2024, 5, 6), date(2024, 5, 7)],
        'drugs': ['333333', '444444', '555555', '666666', '777777', '888888', '999999'],
        'prescription_dates': [date(2024, 6, 1), date(2024, 6, 2), date(2024, 6, 3), date(2024, 6, 4), date(2024, 6, 5), date(2024, 6, 6), date(2024, 6, 7)],
        'birth_year': 1996,
        'sex': 'FEMALE',
        'patient_state': 'OH',
        'plos': 0
    },
    {
        'patient_id': 10022,
        'diagnoses': ['K00', 'L00', 'M00', 'N00', 'I10', 'E11.9', 'Z51.11'],
        'diagnosis_dates': [date(2024, 7, 1), date(2024, 7, 2), date(2024, 7, 3), date(2024, 7, 4), date(2024, 7, 5), date(2024, 7, 6), date(2024, 7, 7)],
        'drugs': ['121212', '131313', '141414', '151515', '161616', '860975', '197361'],
        'prescription_dates': [date(2024, 8, 1), date(2024, 8, 2), date(2024, 8, 3), date(2024, 8, 4), date(2024, 8, 5), date(2024, 8, 6), date(2024, 8, 7)],
        'birth_year': 1989,
        'sex': 'MALE',
        'patient_state': 'GA',
        'plos': 1
    },
    {
        'patient_id': 10023,
        'diagnoses': ['K21.0', 'M17.9', 'E78.5', 'N18.3', 'A00', 'B00', 'C00'],
        'diagnosis_dates': [date(2024, 9, 1), date(2024, 9, 2), date(2024, 9, 3), date(2024, 9, 4), date(2024, 9, 5), date(2024, 9, 6), date(2024, 9, 7)],
        'drugs': ['123456', '654321', '789012', '345678', '987654', '222222', '333333'],
        'prescription_dates': [date(2024, 10, 1), date(2024, 10, 2), date(2024, 10, 3), date(2024, 10, 4), date(2024, 10, 5), date(2024, 10, 6), date(2024, 10, 7)],
        'birth_year': 1979,
        'sex': 'FEMALE',
        'patient_state': 'NC',
        'plos': 0
    },
    {
        'patient_id': 10024,
        'diagnoses': ['D00', 'E00', 'F00', 'G00', 'H00', 'J00', 'K00'],
        'diagnosis_dates': [date(2024, 11, 1), date(2024, 11, 2), date(2024, 11, 3), date(2024, 11, 4), date(2024, 11, 5), date(2024, 11, 6), date(2024, 11, 7)],
        'drugs': ['444444', '555555', '666666', '777777', '888888', '999999', '121212'],
        'prescription_dates': [date(2024, 12, 1), date(2024, 12, 2), date(2024, 12, 3), date(2024, 12, 4), date(2024, 12, 5), date(2024, 12, 6), date(2024, 12, 7)],
        'birth_year': 1981,
        'sex': 'MALE',
        'patient_state': 'MI',
        'plos': 1
    },
    {
        'patient_id': 10025,
        'diagnoses': ['L00', 'M00', 'N00', 'I10', 'E11.9', 'Z51.11', 'K21.0'],
        'diagnosis_dates': [date(2025, 1, 1), date(2025, 1, 2), date(2025, 1, 3), date(2025, 1, 4), date(2025, 1, 5), date(2025, 1, 6), date(2025, 1, 7)],
        'drugs': ['131313', '141414', '151515', '161616', '860975', '197361', '123456'],
        'prescription_dates': [date(2025, 2, 1), date(2025, 2, 2), date(2025, 2, 3), date(2025, 2, 4), date(2025, 2, 5), date(2025, 2, 6), date(2025, 2, 7)],
        'birth_year': 1994,
        'sex': 'FEMALE',
        'patient_state': 'WA',
        'plos': 0
    },
    {
        'patient_id': 10026,
        'diagnoses': ['M17.9', 'E78.5', 'N18.3', 'A00', 'B00', 'C00', 'D00'],
        'diagnosis_dates': [date(2025, 3, 1), date(2025, 3, 2), date(2025, 3, 3), date(2025, 3, 4), date(2025, 3, 5), date(2025, 3, 6), date(2025, 3, 7)],
        'drugs': ['654321', '789012', '345678', '987654', '222222', '333333', '444444'],
        'prescription_dates': [date(2025, 4, 1), date(2025, 4, 2), date(2025, 4, 3), date(2025, 4, 4), date(2025, 4, 5), date(2025, 4, 6), date(2025, 4, 7)],
        'birth_year': 1997,
        'sex': 'MALE',
        'patient_state': 'OR',
        'plos': 1
    },
    {
        'patient_id': 10027,
        'diagnoses': ['E78.5', 'N18.3', 'A00', 'B00', 'C00', 'D00', 'E00'],
        'diagnosis_dates': [date(2025, 5, 1), date(2025, 5, 2), date(2025, 5, 3), date(2025, 5, 4), date(2025, 5, 5), date(2025, 5, 6), date(2025, 5, 7)],
        'drugs': ['789012', '345678', '987654', '222222', '333333', '444444', '555555'],
        'prescription_dates': [date(2025, 6, 1), date(2025, 6, 2), date(2025, 6, 3), date(2025, 6, 4), date(2025, 6, 5), date(2025, 6, 6), date(2025, 6, 7)],
        'birth_year': 1999,
        'sex': 'FEMALE',
        'patient_state': 'CO',
        'plos': 0
    }
]

## 6. Patient Object: Encoding and Processing

The `Patient` class is central to preparing raw patient data for the model. It handles the encoding of various data points using predefined dictionaries (like `CodeDict`, `SexDict`, `AgeDict`, `StateDict`), and orchestrates critical preprocessing steps such as masking for Masked Language Modeling (MLM) and sequence truncation/padding. Essentially, it transforms a patient's raw medical record into a structured, numerical format that can be directly fed into a deep learning model.

This class manages patient data, incorporates vocabulary definitions, and applies various configurable model parameters to construct the input sequences. The examples below show the patient sequence *after* masking has been performed.

Let's look at some important parameters and outputs:

### Key Patient Class Parameters:

* **`max_length`**: Defines the maximum number of codes (or patient events) allowed in the sequence. Sequences longer than this will be truncated, and shorter ones will be padded.
* **`mask_drugs`**: A boolean flag that, when `True`, enables the masking of drug codes during the MLM pre-training task. This allows the model to learn to predict drug codes from their context. This can be extended to other code types if desired.
* **`drop_duplicates`**: When set to `True` (recommended), repeated occurrences of the same code on the same date (`time_point`) are kept only once, which removes redundant concurrent events.
* **`converted_codes`**: This boolean flag indicates whether the input diagnosis and drug codes provided to the `Patient` object have already been converted to their target format (e.g., PheWAS, ATC).
    * If `True`, the class skips its internal conversion functions, assuming codes are pre-processed.
    * If `False`, it signals that input codes may require transformation based on `convert_icd_to_phewas` and `convert_rxcui_to_atc` flags.
* **`convert_icd_to_phewas`**: When `True` (and `converted_codes` is `False`), ICD diagnosis codes will be mapped to PheWAS codes during processing. If `False`, ICD codes remain as-is.
* **`convert_rxcui_to_atc`**: When `True` (and `converted_codes` is `False`), RxCUI drug codes will be transformed into ATC codes. If `False`, RxCUI codes remain in their original format.
* **`dynamic_masking`**: Setting this to `True` is recommended because different codes are then masked on each iteration (or epoch). This prevents the model from memorizing a fixed set of masked tokens and improves generalization. Note that the cells in this notebook use `dynamic_masking=False`; see *Dynamic Masking Troubleshooting* at the end.
* **`age_usage`**: Specifies the unit for age representation (e.g., "year", "months", "decimal"). This aligns with how age values are defined in the `AgeDict`.
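The interplay of `drop_duplicates`, `max_length`, truncation, and padding can be sketched in plain Python. This is a minimal illustration under stated assumptions, not ExMed-BERT's implementation: the helper `build_sequence` and the toy vocabulary are hypothetical, the special-token IDs (`PAD` = `0`, `CLS` = `2`) are taken from the decoded examples below, positions are assigned per distinct date, and truncation here simply keeps the earliest tokens.

```python
from datetime import date

PAD_ID, CLS_ID = 0, 2  # assumed IDs, matching the decoded examples in this section


def build_sequence(events, vocab, max_length, drop_duplicates=True):
    """Hypothetical helper: turn (date, code) events into padded model inputs."""
    events = sorted(events)  # chronological order; ties broken by code string
    if drop_duplicates:
        # keep each (date, code) pair only once, preserving order
        events = list(dict.fromkeys(events))

    # position = 1-based rank of the event's date, so same-day codes share a position
    distinct_dates = sorted({d for d, _ in events})
    date_rank = {d: rank for rank, d in enumerate(distinct_dates, start=1)}

    input_ids = [CLS_ID] + [vocab[code] for _, code in events]
    position_ids = [1] + [date_rank[d] for d, _ in events]  # [CLS] shares position 1

    # truncate to max_length (this sketch keeps the earliest tokens), then pad
    input_ids = input_ids[:max_length]
    position_ids = position_ids[:max_length]
    n_real = len(input_ids)
    n_pad = max_length - n_real
    return {
        "input_ids": input_ids + [PAD_ID] * n_pad,
        "position_ids": position_ids + [0] * n_pad,
        "attention_mask": [1] * n_real + [0] * n_pad,
    }


# Toy vocabulary; the ID for "C10AA05" is made up for this example
vocab = {"272.1": 32, "585.3": 33, "C10AA05": 20}
events = [
    (date(2020, 1, 5), "272.1"),
    (date(2020, 1, 5), "272.1"),    # same code on the same date: dropped
    (date(2020, 1, 5), "C10AA05"),
    (date(2020, 3, 1), "585.3"),
]
out = build_sequence(events, vocab, max_length=8)
print(out["input_ids"])       # [2, 32, 20, 33, 0, 0, 0, 0]
print(out["position_ids"])    # [1, 1, 1, 2, 0, 0, 0, 0]
print(out["attention_mask"])  # [1, 1, 1, 1, 0, 0, 0, 0]
```

With `drop_duplicates=False` the repeated `272.1` would stay in the sequence and consume one of the `max_length` slots.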

### Model Input Modalities (from `patient.get_patient_data()`):

After initializing and processing a `Patient` object, calling `model_input = patient.get_patient_data()` retrieves a dictionary containing various sequences (modalities) ready for model consumption.

* **`input_ids`**: A tensor containing the numerical IDs of the PheWAS and ATC codes in the patient's sequence. IDs `0` to `5` are reserved for special BERT tokens; in the examples below, `[PAD]` = `0`, `[CLS]` = `2`, and `[MASK]` = `4`. The trailing `0`s are `[PAD]` tokens that bring every patient sequence up to `max_length`.
    * **Example from data:** `tensor([ 2, 31, 32, 33, 34, 4, 36, 37, 9, 10, 11, 12, 32, 15, 16, 0, ...])`
    * **Decoded Example:** `['CLS', '715.2', '272.1', '585.3', '800', 'MASK', '1000', '1100', 'D05AX02', 'E03AA01', 'F01BA01', 'G04BE03', '272.1', 'K01AA02', 'L01XE01', 'PAD', ...]`

* **`entity_ids`**: A tensor indicating the type of `input_id` code at each index. This helps the model differentiate between code categories.
    * `default`: Used for special BERT tokens (e.g., `[CLS]`, `[PAD]`).
    * `atc`: Denotes an ATC (drug) code.
    * `phewas`: Denotes a PheWAS (diagnosis) code.
    * **Example from data:** `tensor([0, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 0, ...])`
    * **Decoded Example:** `['default', 'phewas', 'phewas', 'phewas', 'phewas', 'phewas', 'phewas', 'phewas', 'atc', 'atc', 'atc', 'atc', 'atc', 'atc', 'atc', 'default', ...]`

* **`sex_ids`**: A tensor representing the patient's biological sex, constant across all active tokens in the sequence.
    * **Example from data:** `tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, ...])`
    * **Decoded Example:** `['MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'MALE', 'PAD', ...]`

* **`state_ids`**: A tensor holding the patient's state of residence, constant across all active tokens.
    * **Example from data:** `tensor([13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, ...])`
    * **Decoded Example:** `['OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'OR', 'PAD', ...]`

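* **`age_ids`**: A tensor encoding the patient's age, binned according to the `AgeDict` (see `age_usage`). In this example the value is the same for every active token. Note that the IDs are vocabulary indices, not ages: ID `30` decodes to age `28`.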
    * **Example from data:** `tensor([30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 0, ...])`
    * **Decoded Example:** `['28', '28', '28', '28', '28', '28', '28', '28', '28', '28', '28', '28', '28', '28', '28', 'PAD', ...]`

* **`attention_mask`**: A critical tensor for transformer models, indicating which parts of the input sequence are "real" data and which are padding.
    * `1`: Indicates an active, non-padded token that the model should attend to.
    * `0`: Indicates a padded token that the model should ignore during attention calculations.
    * **Example from data:** `tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ...])`

* **`position_ids`**: This tensor provides positional information to the model, enabling it to understand the chronological order of events within the sequence.
    * If two or more medical events share the same `position_id`, they occurred on the same day or within the same visit. In the example below, the `[CLS]` token shares position `1` with the first event, and padding positions are `0`.
    * **Example from data:** `tensor([ 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, ...])`

* **`code_labels`**: This tensor is specifically used for the Masked Language Model (MLM) pre-training task. It holds the *original IDs* of the tokens that were selected for masking in `input_ids`.
    * `-100`: Marks positions that are ignored when calculating the loss: tokens that were not selected for masking, special tokens, and padding. This follows the PyTorch convention, where `-100` is the default `ignore_index` of `torch.nn.CrossEntropyLoss`.
    * For example, if `input_ids` has `4` (the `[MASK]` token) at index 5, then `code_labels` at index 5 contains the original numerical ID of the masked code (here `35`, which decodes to `'900'`). In the standard BERT recipe, a token selected for masking is replaced by `[MASK]` 80% of the time, by a random vocabulary token 10% of the time, and left unchanged 10% of the time. In every case, `code_labels` holds the ID of the *original* token.
    * **Example from data:** `tensor([-100, -100, -100, -100, -100, 35, -100, -100, -100, -100, -100, -100, 14, -100, -100, -100, ...])`
    * **Decoded Example:** `[None, None, None, None, None, '900', None, None, None, None, None, None, 'J01CA04', None, None, None, ...]` (at index 12, the original `J01CA04` was replaced by the random token `272.1`, which is why `input_ids` shows `32` there).
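The BERT-style masking that produces `input_ids` and `code_labels` can be sketched as follows. This is a hypothetical illustration, not ExMed-BERT's code: `mask_tokens` is an invented helper, the reserved IDs `0` to `5` and `[MASK]` = `4` come from the examples above, and the 80/10/10 split between `[MASK]`, random token, and unchanged follows the original BERT recipe, which ExMed-BERT may implement differently. The parameter names mirror `masked_lm_prob`, `max_masked`, and `min_unmasked` from the code cell below.

```python
import random

MASK_ID = 4                  # assumed [MASK] ID, as in the decoded example above
SPECIAL_IDS = set(range(6))  # reserved IDs 0-5 are never masked
IGNORE_INDEX = -100          # label for positions excluded from the MLM loss


def mask_tokens(input_ids, vocab_ids, masked_lm_prob=0.15, max_masked=20,
                min_unmasked=1, rng=None):
    """Hypothetical BERT-style masking; returns (masked_ids, code_labels)."""
    rng = rng or random.Random()
    candidates = [i for i, tok in enumerate(input_ids) if tok not in SPECIAL_IDS]

    # number of targets: ~masked_lm_prob of the codes, at least 1, at most max_masked,
    # while always leaving min_unmasked codes visible
    n_mask = max(1, round(len(candidates) * masked_lm_prob))
    n_mask = min(n_mask, max_masked, max(0, len(candidates) - min_unmasked))

    masked_ids = list(input_ids)
    code_labels = [IGNORE_INDEX] * len(input_ids)
    for i in rng.sample(candidates, n_mask):
        code_labels[i] = input_ids[i]  # the model must predict the original ID
        roll = rng.random()
        if roll < 0.8:
            masked_ids[i] = MASK_ID  # 80%: replace with [MASK]
        elif roll < 0.9:
            masked_ids[i] = rng.choice(vocab_ids)  # 10%: replace with a random code
        # remaining 10%: keep the original token unchanged
    return masked_ids, code_labels


# Unmasked version of the example sequence above, plus padding
ids = [2, 31, 32, 33, 34, 35, 36, 37, 9, 10, 11, 12, 14, 15, 16, 0, 0, 0]
masked, labels = mask_tokens(ids, vocab_ids=list(range(6, 40)), rng=random.Random(0))
print(sum(label != IGNORE_INDEX for label in labels))  # 2 targets (15% of 14 codes, rounded)
```

When labels like these reach `torch.nn.CrossEntropyLoss`, the `-100` entries are skipped automatically because that is its default `ignore_index`.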

In [None]:
for idx, patient_data in enumerate(patients, 1):
    try:
        # endpoint_labels are not passed: they are not needed for MLM pretraining
        patient = Patient(
            patient_id=patient_data['patient_id'],            # Unique identifier for tracking and referencing individual patients in the dataset
            diagnoses=patient_data['diagnoses'],            # List of ICD diagnosis codes that will be converted to PheWAS for standardization
            drugs=patient_data['drugs'],                    # List of RxNorm drug codes that will be converted to ATC for standardization
            diagnosis_dates=patient_data['diagnosis_dates'], # Timestamps for diagnoses to maintain temporal order in sequence modeling
            prescription_dates=patient_data['prescription_dates'], # Timestamps for prescriptions to maintain temporal order
            birth_year=patient_data['birth_year'],           # Used to calculate patient age for age-based feature encoding
            sex=patient_data['sex'],                         # Demographic feature encoded as MALE/FEMALE for patient characterization
            patient_state=patient_data['patient_state'],     # Geographic feature for potential regional health pattern analysis
            max_length=50,                                   # Maximum sequence length - truncates or pads sequences for consistent model input
            code_embed=code_dict,                            # Handles conversion and encoding of medical codes (ICD→PheWAS, RxNorm→ATC)
            sex_embed=sex_dict,                              # Converts sex categories to numerical embeddings for model input
            age_embed=age_dict,                              # Bins and encodes patient ages into discrete categories
            state_embed=state_dict,                          # Converts state information into numerical embeddings
            mask_drugs=True,                                 # Enables drug code masking for MLM pretraining task
            delete_temporary_variables=True,                 # Cleans up memory by removing intermediate processing variables
            split_sequence=True,                             # Splits long patient sequences into manageable chunks if needed
            drop_duplicates=True,                            # Removes redundant codes to prevent sequence bias
            converted_codes=False,                           # Indicates if codes are in raw form (False) or already converted to standard format
            convert_icd_to_phewas=True,                      # Enables automatic conversion of ICD codes to PheWAS for standardization
            convert_rxcui_to_atc=True,                       # Enables automatic conversion of RxNorm to ATC for drug standardization
            keep_min_unmasked=1,                             # Ensures at least one token remains unmasked for context in MLM
            max_masked_tokens=20,                            # Limits masked tokens to prevent too much information loss
            masked_lm_prob=0.15,                             # Probability of masking each token, following BERT's approach
            truncate='right',                                # Specifies to remove older events when truncating long sequences
            index_date=None,                                 # Optional reference date for temporal alignment of patient histories
                                                            # Remove had_plos and endpoint_labels for pretraining - only use MLM
            dynamic_masking=False,                           # When False, uses static masks; True generates new masks each epoch
            min_observations=5,                              # Minimum required events for valid patient sequence
            age_usage='year',                                # Specifies granularity of age binning (year vs month)
            use_cls=True,                                    # Adds classification token at sequence start like BERT
            use_sep=False,                                   # If True, adds separator tokens between visits (disabled here)
            valid_patient=True,                              # Internal flag for tracking patient data validity
            num_visits=None,                                 # Tracks number of unique clinical visits (set internally)
            combined_length=None,                            # Total length of patient sequence before processing
            unpadded_length=None                             # Original sequence length before padding to max_length
        )

        # Get encoded data for model input with evaluate=True for pretraining (enables code_labels)
        model_input = patient.get_patient_data(
            evaluate=True,      # Set to True for pretraining to get code_labels for MLM
            mask_dynamically=False,
            code_embed=code_dict,
            min_unmasked=1,
            max_masked=20,
            masked_lm_prob=0.15,
            mask_drugs=True
        )

        for k, v in model_input.items():
            print(f'  {k}: {v}')
        if 'input_ids' in model_input:
            print('  input_ids (codes):', code_dict.decode(model_input['input_ids']))
        if 'sex_ids' in model_input:
            print('  sex_ids:', sex_dict.decode(model_input['sex_ids']))
        if 'state_ids' in model_input:
            print('  state_ids:', state_dict.decode(model_input['state_ids']))
        if 'age_ids' in model_input:
            print('  age_ids:', age_dict.decode(model_input['age_ids']))
        if 'entity_ids' in model_input:
            print('  entity_ids:', [code_dict.ids_to_entity.get(x.item(), 'UNK') for x in model_input['entity_ids']])
        if 'code_labels' in model_input:
            print('  code_labels:', code_dict.decode(model_input['code_labels']))
        df = patient.to_df(
            code_embed=code_dict,
            age_embed=age_dict,
            sex_embed=sex_dict,
            state_embed=state_dict,
            dynamic_masking=False,
            mask_drugs=True,
            min_unmasked=1,
            max_masked=20,
            masked_lm_prob=0.15
        )
        print(df)
    except Exception as e:
        print(f'Error processing patient {idx}: {e}')

  input_ids: tensor([ 2, 34, 35, 36, 37,  4, 39,  4, 13, 14, 15, 16, 17, 18, 19,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
  entity_ids: tensor([0, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  sex_ids: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  attention_mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])
  position_ids: tensor([ 1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0

### What happens in this step?
- Each patient's data is encoded into integer IDs using the dictionaries.
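- Masking is applied to a subset of codes for masked language modeling (MLM) pretraining: roughly `masked_lm_prob` (15%) of the codes, at most `max_masked` (20) per patient, while keeping at least `min_unmasked` (1) code visible. The original IDs of the masked codes are stored in `code_labels`.
- The encoded data is ready for input to a BERT-style model.
- The decoded output and the DataFrame view from `patient.to_df(...)` make the encoding easy to inspect and debug.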
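For debugging, it can help to list only the MLM targets: what the model sees at each masked position and what it must predict. A minimal sketch, assuming plain Python lists (call `.tolist()` on the tensors first); `describe_targets` is a hypothetical helper, and the ID-to-code mapping is copied from the decoded examples in the modality list above.

```python
IGNORE_INDEX = -100


def describe_targets(input_ids, code_labels, id_to_token):
    """Hypothetical helper: (position, token shown to the model, token to predict)."""
    rows = []
    for pos, (shown, label) in enumerate(zip(input_ids, code_labels)):
        if label == IGNORE_INDEX:
            continue  # not an MLM target
        rows.append((pos, id_to_token.get(shown, "UNK"), id_to_token.get(label, "UNK")))
    return rows


# IDs and codes taken from the example tensors in Section 6
id_to_token = {
    0: "PAD", 2: "CLS", 4: "MASK",
    31: "715.2", 32: "272.1", 33: "585.3", 34: "800", 35: "900", 36: "1000", 37: "1100",
    9: "D05AX02", 10: "E03AA01", 11: "F01BA01", 12: "G04BE03",
    14: "J01CA04", 15: "K01AA02", 16: "L01XE01",
}
input_ids = [2, 31, 32, 33, 34, 4, 36, 37, 9, 10, 11, 12, 32, 15, 16, 0]
code_labels = [-100] * 16
code_labels[5], code_labels[12] = 35, 14

for row in describe_targets(input_ids, code_labels, id_to_token):
    print(row)
# (5, 'MASK', '900')
# (12, '272.1', 'J01CA04')
```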

## 7. Pretraining Focus: Masked Language Modeling Only

For pretraining, we focus exclusively on masked language modeling (MLM). We do NOT include endpoint labels like PLOS since those are designed for supervised fine-tuning tasks, not self-supervised pretraining.

The key differences for pretraining setup:
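- **No endpoint labels**: We do not pass `endpoint_labels` or `had_plos` to `Patient`. Note that `get_patient_data()` still returns a `plos_label` key (see the outputs below); the MLM objective itself relies only on `code_labels`.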
- **Enable code labels**: With `evaluate=True` and no endpoint labels, `code_labels` are automatically generated for MLM
- **Static or dynamic masking**: We can use either static masks (applied once) or dynamic masks (applied each epoch)
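The difference between static and dynamic masking comes down to *when* the mask is drawn. The toy class below is a hypothetical stand-in, not `PatientDataset`: with static masking each sequence is masked once at construction, so every epoch sees the same masked positions; with dynamic masking a new mask is drawn every time an item is accessed. The single-token `mask_one` function is a deliberately simplified masking rule.

```python
import random

MASK_ID = 4  # assumed [MASK] ID, as in the Section 6 examples


def mask_one(seq, rng):
    """Simplified masking rule: replace one random position with [MASK]."""
    i = rng.randrange(len(seq))
    return seq[:i] + [MASK_ID] + seq[i + 1:]


class ToyMaskedDataset:
    """Hypothetical dataset illustrating static vs. dynamic masking."""

    def __init__(self, sequences, mask_fn, dynamic_masking=False, seed=0):
        self.sequences = sequences
        self.mask_fn = mask_fn
        self.dynamic_masking = dynamic_masking
        self.rng = random.Random(seed)
        # static masking: draw every mask once, up front
        self.static_masks = (
            None if dynamic_masking else [mask_fn(s, self.rng) for s in sequences]
        )

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        if self.dynamic_masking:
            return self.mask_fn(self.sequences[idx], self.rng)  # fresh mask per access
        return self.static_masks[idx]  # identical mask on every access


seqs = [[31, 32, 33, 34, 35, 36, 37, 9, 10, 11]]
static_ds = ToyMaskedDataset(seqs, mask_one)
dynamic_ds = ToyMaskedDataset(seqs, mask_one, dynamic_masking=True)
print(len({tuple(static_ds[0]) for _ in range(20)}))   # 1 distinct mask
print(len({tuple(dynamic_ds[0]) for _ in range(20)}))  # usually several distinct masks
```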

## 8. Creating Patient Objects for Pretraining

We now create `Patient` objects for each patient, configured specifically for pretraining with masked language modeling. Note that we exclude endpoint labels and focus purely on MLM objectives.

In [29]:
patient_objs = []
for patient_data in patients:  # In practice, `patients` would be loaded from a large JSON file of patient records in the required format
    patient_obj = Patient(
        patient_id=patient_data['patient_id'],            # Unique identifier for tracking and referencing individual patients in the dataset
        diagnoses=patient_data['diagnoses'],            # List of ICD diagnosis codes that will be converted to PheWAS for standardization
        drugs=patient_data['drugs'],                    # List of RxNorm drug codes that will be converted to ATC for standardization
        diagnosis_dates=patient_data['diagnosis_dates'], # Timestamps for diagnoses to maintain temporal order in sequence modeling
        prescription_dates=patient_data['prescription_dates'], # Timestamps for prescriptions to maintain temporal order
        birth_year=patient_data['birth_year'],           # Used to calculate patient age for age-based feature encoding
        sex=patient_data['sex'],                         # Demographic feature encoded as MALE/FEMALE for patient characterization
        patient_state=patient_data['patient_state'],     # Geographic feature for potential regional health pattern analysis
        max_length=50,                                   # Maximum sequence length - truncates or pads sequences for consistent model input
        code_embed=code_dict,                            # Handles conversion and encoding of medical codes (ICD→PheWAS, RxNorm→ATC)
        sex_embed=sex_dict,                              # Converts sex categories to numerical embeddings for model input
        age_embed=age_dict,                              # Bins and encodes patient ages into discrete categories
        state_embed=state_dict,                          # Converts state information into numerical embeddings
        mask_drugs=True,                                 # Enables drug code masking for MLM pretraining task
        delete_temporary_variables=True,                 # Cleans up memory by removing intermediate processing variables
        split_sequence=True,                             # Splits long patient sequences into manageable chunks if needed
        drop_duplicates=True,                            # Removes redundant codes to prevent sequence bias
        converted_codes=False,                           # Indicates if codes are in raw form (False) or already converted to standard format
        convert_icd_to_phewas=True,                      # Enables automatic conversion of ICD codes to PheWAS for standardization
        convert_rxcui_to_atc=True,                       # Enables automatic conversion of RxNorm to ATC for drug standardization
        keep_min_unmasked=1,                             # Ensures at least one token remains unmasked for context in MLM
        max_masked_tokens=20,                            # Limits masked tokens to prevent too much information loss
        masked_lm_prob=0.15,                             # Probability of masking each token, following BERT's approach
        truncate='right',                                # Specifies to remove older events when truncating long sequences
        index_date=None,                                 # Optional reference date for temporal alignment of patient histories
                                                        # Remove endpoint-related parameters for pure pretraining setup
        dynamic_masking=False,                           # When False, uses static masks; True generates new masks each epoch
        min_observations=5,                              # Minimum required events for valid patient sequence
        age_usage='year',                                # Specifies granularity of age binning (year vs month)
        use_cls=True,                                    # Adds classification token at sequence start like BERT
        use_sep=True,                                    # Adds separator tokens between visits for temporal segmentation
        valid_patient=True,                              # Internal flag for tracking patient data validity
        num_visits=None,                                 # Tracks number of unique clinical visits (set internally)
        combined_length=None,                            # Total length of patient sequence before processing
        unpadded_length=None                             # Original sequence length before padding to max_length
    )
    patient_objs.append(patient_obj)
    
    # Get model input for pretraining - code_labels should be available
    model_input = patient_obj.get_patient_data(
        evaluate=True,        # True for pretraining to get MLM labels
        mask_dynamically=False,
        code_embed=code_dict,
        min_unmasked=1,
        max_masked=20,
        masked_lm_prob=0.15,
        mask_drugs=True
    )
    print(f"Patient {patient_data['patient_id']} model input keys: {list(model_input.keys())}")
    # Verify code_labels are present for MLM pretraining
    if 'code_labels' in model_input:
        print("✓ Code labels available for MLM pretraining")
    else:
        print("✗ No code labels - check masking configuration")

Patient 10000 model input keys: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels available for MLM pretraining
Patient 10001 model input keys: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels available for MLM pretraining
Patient 10002 model input keys: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels available for MLM pretraining
Patient 10003 model input keys: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels available for MLM pretraining
Patient 10004 model input keys: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels available for MLM pr
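The check in the cell above only confirms that a `code_labels` key exists. A slightly stronger check, sketched below as a hypothetical helper, counts the actual MLM targets (entries other than `-100`) and verifies they fall between 1 and `max_masked`. The lower bound of 1 is an assumption about what a useful pretraining example looks like, not a rule enforced by ExMed-BERT. For a tensor, pass `model_input['code_labels'].tolist()`.

```python
IGNORE_INDEX = -100


def count_mlm_targets(code_labels, max_masked=20):
    """Hypothetical check: return the number of MLM targets, or raise if out of range."""
    n_targets = sum(1 for label in code_labels if label != IGNORE_INDEX)
    if not 1 <= n_targets <= max_masked:
        raise ValueError(f"expected 1..{max_masked} MLM targets, found {n_targets}")
    return n_targets


print(count_mlm_targets([-100, -100, 35, -100, 14, -100]))  # 2
```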

## 9. Creating and Saving the PatientDataset

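The `PatientDataset` class collects multiple `Patient` objects into a single indexable dataset for model training: `len(patient_dataset)` gives the number of patients, and `patient_dataset[i]` returns the model-input dictionary for patient `i`. It applies the configured masking settings and can save the dataset to disk for later use.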
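Because every patient is padded to the same `max_length`, per-patient dictionaries like those returned by `patient_dataset[i]` can be grouped into a batch position by position. The sketch below is a hypothetical pure-Python collate function; in a PyTorch training loop this role is usually played by the `collate_fn` of `torch.utils.data.DataLoader`, which would typically stack tensors with `torch.stack`. It keeps only the keys it is given, so extra entries such as `plos_label` are dropped.

```python
def collate(samples, keys):
    """Hypothetical collate: turn a list of per-patient dicts into one dict of rows."""
    batch = {}
    for key in keys:
        rows = [list(sample[key]) for sample in samples]
        if len({len(row) for row in rows}) != 1:
            raise ValueError(f"'{key}' is not padded to a common length")
        batch[key] = rows  # shape: (batch_size, max_length)
    return batch


samples = [
    {"input_ids": [2, 31, 4, 0], "attention_mask": [1, 1, 1, 0], "plos_label": 0},
    {"input_ids": [2, 9, 10, 11], "attention_mask": [1, 1, 1, 1], "plos_label": 1},
]
batch = collate(samples, keys=("input_ids", "attention_mask"))
print(sorted(batch))          # ['attention_mask', 'input_ids']
print(batch["input_ids"][1])  # [2, 9, 10, 11]
```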

In [30]:
# Create the PatientDataset for pretraining (no endpoint_dict: MLM only)
patient_dataset = PatientDataset(
    code_embed=code_dict,            # CodeDict: maps medical codes to integer IDs and handles code normalization
    age_embed=age_dict,              # AgeDict: bins and encodes patient ages as integer IDs
    sex_embed=sex_dict,              # SexDict: encodes sex/gender as integer IDs
    state_embed=state_dict,          # StateDict: encodes US state information as integer IDs
    endpoint_dict=None,              # No endpoint_dict for pretraining - we only need MLM
    patient_paths=None,              # List of file paths for patient objects (None if patients are loaded in RAM)
    max_length=50,                   # Maximum sequence length for each patient (truncates or pads sequences)
    do_eval=True,                    # Indicates if the dataset is for evaluation (affects patient output)
    mask_substances=True,            # Whether to mask substances (drugs) as well as diagnoses for MLM
    dataset_path=None,               # Path to dataset directory (used if saving/loading patients from disk)
    patients=patient_objs,           # List of Patient objects loaded in RAM
    dynamic_masking=False,           # If True, masking is done dynamically each epoch; if False, static masks
    min_unmasked=1,                  # Minimum number of unmasked tokens per patient sequence
    max_masked=20,                   # Maximum number of masked tokens per patient sequence
    masked_lm_prob=0.15              # Probability of masking each token (for MLM pretraining)
)

print("PatientDataset created successfully for pretraining")
print(f"Dataset length: {len(patient_dataset)}")
print("Sample patient data:")
sample_data = patient_dataset[0]
print(f"Keys in sample data: {list(sample_data.keys())}")
if 'code_labels' in sample_data:
    print("✓ Code labels present - ready for MLM pretraining")
else:
    print("✗ No code labels found")

PatientDataset created successfully for pretraining
Dataset length: 28
Sample patient data:
Keys in sample data: ['input_ids', 'entity_ids', 'sex_ids', 'attention_mask', 'position_ids', 'state_ids', 'age_ids', 'plos_label', 'code_labels']
✓ Code labels present - ready for MLM pretraining


## 10. Splitting and Saving Train/Validation Datasets for Pretraining

For pretraining, we split the patient dataset into train, validation, and test sets (80% / 10% / 10%) using simple random splitting. No stratification is needed because there are no endpoint labels. The resulting datasets are ready for masked language modeling pretraining.
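The two chained `train_test_split` calls in the next cell give an 80/10/10 split. The arithmetic can be checked by hand, assuming scikit-learn's behavior for a float `test_size` (the held-out count is rounded up and the remainder goes to training). The sketch also includes a hypothetical `assert_disjoint` helper to confirm that no patient lands in more than one split; with the notebook's data you could pass, for example, `[patients[i]['patient_id'] for i in train_indices]` and likewise for the other index lists.

```python
import math


def split_sizes(n, test_size=0.2, second_test_size=0.5):
    """Sizes produced by two chained train_test_split calls (assumes ceil rounding)."""
    n_temp = math.ceil(test_size * n)              # held out by the first split
    n_test = math.ceil(second_test_size * n_temp)  # second split: temp -> (val, test)
    return n - n_temp, n_temp - n_test, n_test


def assert_disjoint(*splits):
    """Hypothetical check: raise if any ID appears in more than one split."""
    seen = set()
    for split in splits:
        overlap = seen & set(split)
        if overlap:
            raise ValueError(f"IDs in more than one split: {sorted(overlap)}")
        seen |= set(split)


print(split_sizes(28))  # (22, 3, 3)
assert_disjoint([10000, 10001], [10002], [10003])  # passes silently
```

The `(22, 3, 3)` result matches the dataset sizes printed after the split below.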

In [32]:
# Use simple split without stratification since we don't have endpoint labels for pretraining
from sklearn.model_selection import train_test_split

# Get indices for splitting
indices = list(range(len(patient_dataset)))

# Split indices into train/val/test
train_indices, temp_indices = train_test_split(indices, test_size=0.2, random_state=42)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

# Create subset datasets
train_patients = [patient_objs[i] for i in train_indices]
val_patients = [patient_objs[i] for i in val_indices] 
test_patients = [patient_objs[i] for i in test_indices]

# Create separate datasets
train_dataset = PatientDataset(
    code_embed=code_dict,
    age_embed=age_dict,
    sex_embed=sex_dict,
    state_embed=state_dict,
    endpoint_dict=None,  # No endpoints for pretraining
    patients=train_patients,
    do_eval=True,
    max_length=50,
    dynamic_masking=False,
    min_unmasked=1,
    max_masked=20,
    masked_lm_prob=0.15
)

val_dataset = PatientDataset(
    code_embed=code_dict,
    age_embed=age_dict,
    sex_embed=sex_dict,
    state_embed=state_dict,
    endpoint_dict=None,  # No endpoints for pretraining
    patients=val_patients,
    do_eval=True,
    max_length=50,
    dynamic_masking=False,
    min_unmasked=1,
    max_masked=20,
    masked_lm_prob=0.15
)

test_dataset = PatientDataset(
    code_embed=code_dict,
    age_embed=age_dict,
    sex_embed=sex_dict,
    state_embed=state_dict,
    endpoint_dict=None,  # No endpoints for pretraining
    patients=test_patients,
    do_eval=True,
    max_length=50,
    dynamic_masking=False,
    min_unmasked=1,
    max_masked=20,
    masked_lm_prob=0.15
)

# Create directory if it doesn't exist
import os
os.makedirs('pretrain_stuff', exist_ok=True)

# Save the datasets
train_output_path = 'pretrain_stuff/demo_train_patient_dataset.pt'
val_output_path = 'pretrain_stuff/demo_val_patient_dataset.pt'
test_output_path = 'pretrain_stuff/demo_test_patient_dataset.pt'

train_dataset.save_dataset(path=train_output_path, with_patients=True, do_copy=True)
val_dataset.save_dataset(path=val_output_path, with_patients=True, do_copy=True)
test_dataset.save_dataset(path=test_output_path, with_patients=True, do_copy=True)

print(f'✓ Train dataset saved to {train_output_path} ({len(train_dataset)} patients)')
print(f'✓ Validation dataset saved to {val_output_path} ({len(val_dataset)} patients)')
print(f'✓ Test dataset saved to {test_output_path} ({len(test_dataset)} patients)')
print("Datasets are ready for MLM pretraining!")

✓ Train dataset saved to pretrain_stuff/demo_train_patient_dataset.pt (22 patients)
✓ Validation dataset saved to pretrain_stuff/demo_val_patient_dataset.pt (3 patients)
✓ Test dataset saved to pretrain_stuff/demo_test_patient_dataset.pt (3 patients)
Datasets are ready for MLM pretraining!


## Running the Pretraining Script

To execute the pretraining process, please refer to the `pretrain_example.py` file.

## Environment Considerations

This project's development environment was set up using **Conda**. However, for deployment or running the model within a Databricks environment, it is highly recommended to create the environment from the **`poetry.lock` file**. Because the lock file pins the exact version of every dependency, this gives more consistent and reproducible installs in production or distributed settings like Databricks.

## Model Output and Logging

Training logs and run metadata are stored in the `outputs/` and `mlruns/` directories (`mlruns/` is the default local tracking directory of MLflow). Both directories are created automatically the first time the model is run.

The primary output file, containing the **learned model parameters (model weights)**, is saved within the `outputs/` directory.

## Dynamic Masking Troubleshooting

Dynamic masking (`dynamic_masking=True`) is important for the pretraining run, but enabling it has previously led to errors, and the cells in this notebook use static masking (`dynamic_masking=False`). If you encounter issues when enabling it, the cause has not been resolved yet and will require further investigation.