# Dataloader tutorial

# Setup

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging

import CPRD
from CPRD.data.foundational_loader import FoundationalDataModule

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for debug errors

cuda


In [20]:
# Set GPT config to be equivalent
@dataclass
class DemoConfig:
    block_size: int = 32             # what is the maximum context length for predictions?
    # n_layer: int = 6
    # n_head: int = 6
    # n_embd: int = 384
    # pos_encoding: str = "index-embedding"                 # Manually adding later
    # bias: bool = True
    # attention_type: str = "global"    
    # dropout: float = 0.0
    unk_freq_threshold: float = 0.0

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    # eval_interval: int = 1
    # learning_rate: float = 3e-4
    # epochs: int = 10
    
opt = OptConfig()

# Data processing

Data is first extracted from CPRD using [DExtER](https://link.springer.com/article/10.1007/s10654-020-00677-6) and is available within the optimal project master dataset. Some outlier filtering has already been done in this extraction, and the ICD diagnostic codes in CPRD have been summarised and processed.

In [21]:
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/baseline
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/timeseries

total 69K
drwx--S--- 6 nobody nobody  16K Nov 13 10:25 [0m[01;34m.[0m
drwx--S--- 7 nobody nobody 4.0K Nov 10 11:51 [01;34m..[0m
drwx--S--- 2 nobody nobody 4.0K Nov 14 10:44 [01;34mbaseline[0m
-rwx------ 1 nobody nobody 4.0K Oct 23 08:26 [01;32m._.DS_Store[0m
-rwx------ 1 nobody nobody 6.1K Nov  4 13:01 [01;32m.DS_Store[0m
drwx--S--- 2 nobody nobody 4.0K Nov 13 10:27 [01;34mmetadata[0m
drwx--S--- 4 nobody nobody 4.0K Nov  4 12:57 [01;34mtimeseries[0m
drwx--S--- 2 nobody nobody 4.0K Nov 13 10:28 [01;34mzip[0m
total 30G
drwx--S--- 2 nobody nobody 4.0K Nov 14 10:44 [0m[01;34m.[0m
drwx--S--- 6 nobody nobody  16K Nov 13 10:25 [01;34m..[0m
-rwx------ 1 nobody nobody  15G Nov 10 15:43 [01;32mmasterDataOptimal_v3.csv[0m
-rwx------ 1 nobody nobody  15G Nov 14 10:27 [01;32mmasterDataOptimalWithIMD_v3.csv[0m
total 98K
drwx--S--- 4 nobody nobody 4.0K Nov  4 12:57 [0m[01;34m.[0m
drwx--S--- 6 nobody nobody  16K Nov 13 10:25 [01;34m..[0m
-rwx------ 1 nobody nobody 4.0K N

These are pre-processed further in R to obtain three .csv files which are in turn compiled into an SQLite database within the `CPRD.data` module. 

* TODO: Replace R filtering with in-built python script
* TODO: It is future work to combine these steps into DExtER

In [22]:
!ls --color -lah /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing
!ls --color -lah /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed

total 67K
drwx--S---  3 gaddcz nobody 4.0K Sep 28 11:12 [0m[01;34m.[0m
drwx--S---. 5 gaddcz nobody 4.0K Oct 24 11:21 [01;34m..[0m
drwx--S---  2 gaddcz nobody  16K Oct 18 15:48 [01;34mprocessed[0m
-rwx------  1 gaddcz nobody 7.5K Sep 28 15:22 [01;32mprocessing_diagnoses.R[0m
-rwx------. 1 gaddcz nobody  14K Oct 18 16:10 [01;32mprocessing_measurements_and_tests.R[0m
-rwx------  1 gaddcz nobody 2.2K Sep 19 11:00 [01;32mprocessing_static.R[0m
total 2.0G
drwx--S--- 2 gaddcz nobody  16K Oct 18 15:48 [0m[01;34m.[0m
drwx--S--- 3 gaddcz nobody 4.0K Sep 28 11:12 [01;34m..[0m
-rwx------ 1 gaddcz nobody 983M Oct 18 15:48 [01;32mcprd.db[0m
-rwx------ 1 gaddcz nobody 128M Oct  2 15:06 [01;32mdiagnosis_history.csv[0m
-rwx------ 1 gaddcz nobody 874M Sep 28 15:38 [01;32mmeasurements.csv[0m
-rwx------ 1 gaddcz nobody  42M Sep 20 12:06 [01;32mstatic.csv[0m


In doing this we convert each of these files into an SQL table.

# SQL tables

## Static table
<div align="justify">
The static table has one row per data owner. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* ETHNICITY, SEX, etc: The static variables which remain constant throughout a lifetime.
* INDEX_AGE / START_AGE / END_AGE: 
</div>


<div align="center">
TODO: 
</div>

* Repace missing with nans or some other unique value so these can be masked later
* Add filtering to only include events after index age



In [23]:
print(f"`Static table` docstring:\n{CPRD.data.database.build_static_db.build_static_table.__doc__}")

`Static table` docstring:

    
    Produced anonymized table:
    ┌──────────────────────┬─────┬───────────┬───────────────┬─────────────┬─────────────┬───────────────┐
    │ PRACTICE_PATIENT_ID  ┆ SEX ┆ ETHNICITY ┆ YEAR_OF_BIRTH ┆ INDEX_AGE   ┆ START_AGE   ┆ END_AGE       │
    │ ---                  ┆ --- ┆ ---       ┆ ---           ┆ ---         ┆ ---         ┆ ---           │
    │ str                  ┆ str ┆ str       ┆ str           ┆ i64 (days)  ┆ i64 (days)  ┆ i64 (days)    │
    ╞══════════════════════╪═════╪═══════════╪═══════════════╪═════════════╪═════════════╪═══════════════╡
    │ <anonymous 1>        ┆ M   ┆ WHITE     ┆ yyyy--mm-dd   ┆ dd          ┆ dd          ┆ dd            │
    │ <anonymous 2>        ┆ F   ┆ MISSING   ┆ yyyy--mm-dd   ┆ dd          ┆ dd          ┆ dd            │
    │ …                    ┆ …   ┆ …         ┆ …             ┆             ┆             ┆               │
    │ <anonymous N>        ┆ M   ┆ WHITE     ┆ yyyy--mm-dd   ┆ dd          ┆ dd  

## Diagnosis table
<div align="justify">
The diagnosis table has one row per diagnosis. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* EVENT: The categorical event
* AGE_AT_EVENT: Number if days between subject birth and the day of event
* VALUE: to be removed
* EVENT_TYPE: to be removed
</div>


<div align="center">
  TODO: 
</div>

* We could embed based on subgroups. For example each of these conditions are categorical, but they could be further sub categories. For example, diabetes can be one category and types further divide this.

In [24]:
print(f"`Diagnosis table` docstring:\n{CPRD.data.database.build_diagnosis_db.build_diagnosis_table.__doc__}")

`Diagnosis table` docstring:

    Build measurements and tests table in database

    Produced anonymized table:
    ┌──────────────────────┬───────┬──────────────┬──────────────┬────────────────────────────┐
    │ PRACTICE_PATIENT_ID  ┆ VALUE ┆ EVENT        ┆ AGE_AT_EVENT ┆ EVENT_TYPE                 │
    │ ---                  ┆ ---   ┆ ---          ┆ ---          ┆ ---                        │
    │ str                  ┆ f64   ┆ str          ┆ i64 (days)   ┆ str                        │
    ╞══════════════════════╪═══════╪══════════════╪══════════════╪════════════════════════════╡
    │ <anonymous 1>        ┆ null  ┆ HF           ┆ 11632        ┆ categorical                │
    │ <anonymous 2>        ┆ null  ┆ HF           ┆ 25635        ┆ categorical                │
    │ …                    ┆ …     ┆ …            ┆ …            ┆ …                          │
    │ <anonymous N>        ┆ null  ┆ FIBROMYALGIA ┆ 8546         ┆ categorical                │
    └──────────────────

## Measurements and tests table
<div align="justify">
The measurements table has one row per diagnosis. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* EVENT: The categorical event
* AGE_AT_EVENT: Number if days between subject birth and the day of event
* VALUE: The measurement/test record
* EVENT_TYPE: whether the value is categorical (e.g. EVENT=="smoking", value="ex-smoker"), or continuous (e.g. EVENT=="bmi", value=23.3)
</div>


<div align="center">
  TODO: 
</div>



In [25]:
print(f"`Measurements and tests table` docstring:\n{CPRD.data.database.build_measurements_and_tests_db.build_measurements_table.__doc__}")

`Measurements and tests table` docstring:
 
    Build measurements and tests table in database

    Produced anonymized table:
    ┌──────────────────────┬───────┬──────────────────┬──────────────┬───────────────────────┐
    │ PRACTICE_PATIENT_ID  ┆ VALUE ┆ EVENT            ┆ AGE_AT_EVENT ┆ EVENT_TYPE            │
    │ ---                  ┆ ---   ┆ ---              ┆ ---          ┆ ---                   │
    │ str                  ┆ f64   ┆ str              ┆ i64 (days)   ┆ str                   │
    ╞══════════════════════╪═══════╪══════════════════╪══════════════╪═══════════════════════╡
    │ <anonymous 1>        ┆ 23.3  ┆ bmi              ┆ 10254        ┆ univariate_regression │
    │ <anonymous 1>        ┆ 24.1  ┆ bmi              ┆ 11829        ┆ univariate_regression │
    │ …                    ┆ …     ┆ …                ┆ …            ┆ …                     │
    │ <anonymous N>        ┆ 0.17  ┆ eosinophil_count ┆ 12016        ┆ univariate_regression │
    └─────────────

# Filtering

We can query the SQL tables separately to find the set of patients that fit the study criteria. 

First we can connect directly to the database using sqlite3

<div align="center">
  TODO: 
</div>

* The db shouldn't really be interfaced with like this, this is mostly for demonstration and a 

In [26]:
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()

Performing a look-up directly on the tables shows what diagnoses, measurements and tests are available

In [27]:
# Check what measurements are available
cursor.execute("SELECT DISTINCT EVENT FROM measurement_table")
measurements = [_item[0] for _item in cursor.fetchall()]
print("We have access to measurements and tests:\n *" + '\n *'.join(measurements))

# Check what diagnoses are available
cursor.execute("SELECT DISTINCT EVENT FROM diagnosis_table")
diagnoses = [_item[0] for _item in cursor.fetchall()]
print("\nAdditionally, we have access to diagnoses:\n *" + '\n *'.join(diagnoses))



We have access to measurements and tests:
 *bmi
 *hydroxyvitamin2
 *hydroxyvitamin3
 *aspartate_transam
 *serum_level
 *creatinine_ratio
 *basophil_count
 *blood_calcium
 *blood_urea
 *brain_natriuretic_peptide_level
 *calcium_adjusted_level
 *calculated_LDL_cholesterol_level
 *combined_total_vitamin_D2_and_D3_level
 *corrected_serum_calcium_level
 *diastolic_blood_pressure
 *eosinophil_count

Additionally, we have access to diagnoses:
 *HF
 *AF
 *ISCHAEMICSTROKE
 *STROKEUNSPECIFIED
 *STROKE_HAEMRGIC
 *HYPERTENSION
 *MINFARCTION
 *IHD_NOMI
 *PAD_STRICT
 *VALVULARDISEASES
 *AORTICANEURYSM
 *TYPE1DM
 *TYPE2DIABETES
 *CKDSTAGE3TO5
 *DEPRESSION
 *ANXIETY
 *BIPOLAR
 *EATINGDISORDERS
 *SCHIZOPHRENIAMM
 *AUTISM
 *ALCOHOLMISUSE
 *SUBSTANCEMISUSE
 *CHRONIC_LIVER_DISEASE_ALCOHOL
 *NAFLD
 *OTHER_CHRONIC_LIVER_DISEASE_OPTIMAL
 *ULCERATIVE_COLITIS
 *CROHNS_DISEASE
 *ALL_DEMENTIA
 *PARKINSONS
 *EPILEPSY
 *ALLCA_NOBCC_VFINAL
 *LYMPHOMA_PREVALENCE
 *LEUKAEMIA_PREVALENCE
 *PLASMACELL_NEOPLASM
 *ASTHMA_

Given some subset of these we can query the database to find those subjects who have at least one entry of a list of measurements, tests or diagnoses

In [28]:
identifiers1 = CPRD.data.database.queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = CPRD.data.database.queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

In [29]:
# For now, lets take only the first 10,000
N = np.min((len(all_identifiers), 10000))
print(f"Using N={N} random samples, from the available {len(all_identifiers)}")

identifiers = random.choices(all_identifiers, k=N)

Using N=10000 random samples, from the available 117102


# Make PyTorch dataloader

We can now initialise our DL friendly datasets. This is done within the DataModule initialisation. 

Initialisation takes one required argument, which is the set of patient `identifiers` we wish to build into our dataset. The remaining arguments are optional and explained in the docstring


<div align="center">
  TODO: 
</div>

* Add option of saving and loading initialised train/test/val splits

In [30]:
print(f"`FoundationalDataModule` docstring:\n{FoundationalDataModule.__doc__}")

`FoundationalDataModule` docstring:

    PyTorch-Lightning datamodule for foundational models
    
    ARGS:
        practice_patient_id (list[str])
            List of practice patient identifiers which satisfy study criteria.
            
    KWARGS:
        batch_size (int): 
        
        unk_freq_threshold (float). 
            Value between 0 and 1, controlling at what level of frequency rare tokens (equiv. conditions/measurements 
            with this tokenizer) are mapped to the UNK token. Used to reduce vocabulary size
            
        min_workers (int):
            
        weighted_sampler (bool):
            NotImplemented. 
        load_event_stream (optional, str):
        
        save_event_stream (optional, str):
        
    


In [31]:
dm = FoundationalDataModule(identifiers=identifiers,
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold)

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")

INFO:root:Building DL-friendly representation
INFO:root:Dropping samples with no dynamic events


8623 training samples
480 validation samples
479 test samples


## Dataset

Each dataset within the DataModule is a ``cprd.data.dataset.dataset_polars.EventStreamDataset`` object. Within this class we process the SQL tables from a form in which each patient (static) or record (dianosis/measurement/test) has its own row, into a ragged form where each patient has their own row in every case. This does not change the form of the static table. However, the form of the diagnoses and measurements/tests changes - and during the same update we combine these frames to one frame which contains dynamic information.

This is all wrapped in the method ``fit()``, but we can see the separate steps below:

In [32]:
print(f"`_load_dynamic` docstring:\n{CPRD.data.dataset.dataset_polars.EventStreamDataset._load_dynamic.__doc__}")

`_load_dynamic` docstring:
    
        Load and merge dynamic tables from SQL
        
        ARGS:
            practice_patient_id (list[str])
                List of practice patient identifiers which satisfy study criteria.
            
        KWARGS:
        
        
        RETURNS:
            Polars lazy frame, of the (anonymized) form:
            ┌──────────────────────┬───────────────────────┬───────────────────────────────────┬─────────────────────────┬───────────────────────────────────┐
            │ PRACTICE_PATIENT_ID  ┆ VALUE                 ┆ EVENT                             ┆ AGE_AT_EVENT            ┆ EVENT_TYPE                        │
            │ ---                  ┆ ---                   ┆ ---                               ┆ ---                     ┆ ---                               │
            │ str                  ┆ list[f64]             ┆ list[str]                         ┆ list[i64]   (in days)   ┆ list[str]                         │
            ╞═

This is then combined into one DL friendly representation, which our DataModule uses

In [33]:
print(f"`_build_DL_representation` docstring:\n{CPRD.data.dataset.dataset_polars.EventStreamDataset._build_DL_representation.__doc__}")

`_build_DL_representation` docstring:

        Build the DL-friendly representation in polars given the list of `practice_patient_id`s which fit study criteria
                
        ARGS:
            practice_patient_id (list[str])
                List of practice patient identifiers which satisfy study criteria.
            
        KWARGS:
        
        
        RETURNS:
            Polars lazy frame, of the (anonymized) form:
            ┌──────────────────────┬─────┬─────────────┬───────────────┬──────────────────────┬─────────────────────────┬─────────────────────┬────────────────────────┐
            │ PRACTICE_PATIENT_ID  ┆ SEX ┆ ETHNICITY   ┆ YEAR_OF_BIRTH ┆ VALUE                ┆ EVENT                   ┆ AGE_AT_EVENT        ┆ EVENT_TYPE             │
            │ ---                  ┆ --- ┆ ---         ┆ ---           ┆ ---                  ┆ ---                     ┆ ---                 ┆ ---                    │
            │ str                  ┆ str ┆ str        

The ``__getitem__`` within the dataset class then retrieves rows of this table and performs the relevant processing such as tokenization. Speed ups could be obtained by pre-processing this tokenization.

As these are ragged lists for memory efficiency, we also provide a ``collate_fn`` in the DataModule. This performs padding - and by performing padding this way we can avoid excessive padding.

# Data iterators

## Using \_\_getitem__

In [34]:
# print(dm.train_set[0])
print("A single element of the dataset contains:\n  * " + '\n  * '.join(dm.train_set[0].keys()))

print(dm.train_set[0])


A single element of the dataset contains:
  * identifier
  * sex
  * ethnicity
  * year_of_birth
  * input_ids
  * input_pos
  * input_ages
  * target_ids
  * target_pos
  * target_ages
{'identifier': 'p20389_944530620389', 'sex': 'F', 'ethnicity': 'WHITE', 'year_of_birth': '1975-07-15', 'input_ids': tensor([18, 14,  4,  3, 12,  4, 13,  9,  2, 12,  2, 13,  8, 11, 12,  2, 13,  7,
         9, 12,  2, 26, 13,  8,  9, 12,  2, 13,  9,  9, 12]), 'input_pos': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]), 'input_ages': tensor([10046, 11609, 11609, 11609, 11609, 11609, 11609, 11609, 11609, 11609,
        11609, 11738, 11738, 11738, 11738, 11738, 11748, 11748, 11748, 11748,
        11748, 11826, 12161, 12161, 12161, 12161, 12161, 12392, 12392, 12392,
        12392]), 'target_ids': tensor([14,  4,  3, 12,  4, 13,  9,  2, 12,  2, 13,  8, 11, 12,  2, 13,  7,  9,
        12,  2, 26, 13,  8,  9, 12,  2, 13

## Collating into batches

In [35]:
for idx, batch in enumerate(dm.train_dataloader()):
    break
print("A sample from the dataloader batch gives:")
print("Batch Dataframe Columns:\n  * " + '\n  * '.join(batch.keys()))

print(f"\nThe position index of inputs and targets: \ninputs: {batch['input_pos'][0,:10]}  \ntargets: {batch['target_pos'][0,:10]}")
print(f"\nThe time of event (in days since birth) of event of inputs and targets: \ninputs: {batch['input_ages'][0,:10]}  \ntargets: {batch['target_ages'][0,:10]}")
print(f"\nThe shifted next-step, tokenized and padded (within batch), representation from a block of a patient's sequence for events: \ninputs: {batch['input_ids'][0,:10]} \ntargets: {batch['target_ids'][0,:10]}")
print(f"\nWhich can be decoded. E.g. first sample's first 10 block tokens: \ninputs: {dm.decode(batch['input_ids'][0,:10].tolist())}  \ntargets: {dm.decode(batch['target_ids'][0,:10].tolist())}")
print(f"\nThe attention mask ({batch['attention_mask'].shape}) for padding: \n{batch['attention_mask']}")


A sample from the dataloader batch gives:
Batch Dataframe Columns:
  * input_ids
  * target_ids
  * input_pos
  * target_pos
  * input_ages
  * target_ages
  * attention_mask

The position index of inputs and targets: 
inputs: tensor([404, 405, 406, 407, 408, 409, 410, 411, 412, 413])  
targets: tensor([405, 406, 407, 408, 409, 410, 411, 412, 413, 414])

The time of event (in days since birth) of event of inputs and targets: 
inputs: tensor([16699, 16708, 16708, 16708, 16708, 16708, 16718, 16718, 16718, 16718])  
targets: tensor([16708, 16708, 16708, 16708, 16708, 16718, 16718, 16718, 16718, 16718])

The shifted next-step, tokenized and padded (within batch), representation from a block of a patient's sequence for events: 
inputs: tensor([11, 14,  6,  5, 12,  9, 14,  6,  5, 12]) 
targets: tensor([14,  6,  5, 12,  9, 14,  6,  5, 12,  4])

Which can be decoded. E.g. first sample's first 10 block tokens: 
inputs: 9 bmi 4 3 . 7 bmi 4 3 .  
targets: bmi 4 3 . 7 bmi 4 3 . 2

The attention ma

In [36]:
vocab_size = dm.train_set.tokenizer.vocab_size

print(vocab_size)
print(dm.train_set.tokenizer._itos)

101
{0: 'PAD', 1: 'UNK', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9', 12: '.', 13: 'diastolic_blood_pressure', 14: 'bmi', 15: 'eosinophil_count', 16: 'basophil_count', 17: 'corrected_serum_calcium_level', 18: 'DEPRESSION', 19: 'serum_level', 20: 'calculated_LDL_cholesterol_level', 21: 'ANXIETY', 22: 'HYPERTENSION', 23: 'TYPE2DIABETES', 24: 'OSTEOARTHRITIS', 25: 'ASTHMA_PUSHASTHMA', 26: 'ATOPICECZEMA', 27: 'ALLERGICRHINITISCONJ', 28: 'ANY_DEAFNESS_HEARING_LOSS', 29: 'aspartate_transam', 30: 'ALLCA_NOBCC_VFINAL', 31: 'PREVALENT_IBS', 32: 'IHD_NOMI', 33: 'CKDSTAGE3TO5', 34: 'ALCOHOLMISUSE', 35: 'blood_urea', 36: 'PERIPHERAL_NEUROPATHY', 37: 'COPD', 38: 'calcium_adjusted_level', 39: 'HYPOTHYROIDISM_DRAFT_V1', 40: 'AF', 41: 'GOUT', 42: 'OSTEOPOROSIS', 43: 'HF', 44: 'PSORIASIS', 45: 'SUBSTANCEMISUSE', 46: 'MINFARCTION', 47: 'combined_total_vitamin_D2_and_D3_level', 48: 'STROKEUNSPECIFIED', 49: 'ALL_DEMENTIA', 50: 'hydroxyvitamin3', 51: 'VALVULARDISEASES',