# Dataloader tutorial

# Setup

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging

import CPRD
from CPRD.data.foundational_loader import FoundationalDataModule

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for debug errors

cuda


In [2]:
# Set GPT config to be equivalent
@dataclass
class DemoConfig:
    tokenizer: str = "tabular"
    unk_freq_threshold: float = 0.0
    block_size: int = 64             # what is the maximum context length for predictions?
    # n_layer: int = 6
    # n_head: int = 6
    # n_embd: int = 384
    # pos_encoding: str = "index-embedding"                 # Manually adding later
    # bias: bool = True
    # attention_type: str = "global"    
    # dropout: float = 0.0

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    # eval_interval: int = 1
    # learning_rate: float = 3e-4
    # epochs: int = 10
    
opt = OptConfig()

# Data processing

Data is first extracted from CPRD using [DExtER](https://link.springer.com/article/10.1007/s10654-020-00677-6) and is available within the optimal project master dataset. Some outlier filtering has already been done in this extraction, and the ICD diagnostic codes in CPRD have been summarised and processed.

In [3]:
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/baseline
!ls --color -lah /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/timeseries

total 70K
drwx--S--- 7 nobody nobody  16K Nov 29 10:15 [0m[01;34m.[0m
drwx--S--- 8 nobody nobody 4.0K Nov 29 10:16 [01;34m..[0m
drwx--S--- 2 nobody nobody 4.0K Nov 14 10:44 [01;34mbaseline[0m
-rwx------ 1 nobody nobody 4.0K Oct 23 08:26 [01;32m._.DS_Store[0m
-rwx------ 1 nobody nobody 6.1K Nov  4 13:01 [01;32m.DS_Store[0m
drwx--S--- 3 nobody nobody 4.0K Nov 29 10:15 [01;34mHES[0m
drwx--S--- 2 nobody nobody 4.0K Nov 13 10:27 [01;34mmetadata[0m
drwx--S--- 4 nobody nobody 4.0K Jan 10 11:04 [01;34mtimeseries[0m
drwx--S--- 2 nobody nobody 4.0K Jan 10 10:36 [01;34mzip[0m
total 30G
drwx--S--- 2 nobody nobody 4.0K Nov 14 10:44 [0m[01;34m.[0m
drwx--S--- 7 nobody nobody  16K Nov 29 10:15 [01;34m..[0m
-rwx------ 1 nobody nobody  15G Nov 10 15:43 [01;32mmasterDataOptimal_v3.csv[0m
-rwx------ 1 nobody nobody  15G Nov 14 10:27 [01;32mmasterDataOptimalWithIMD_v3.csv[0m
total 98K
drwx--S--- 4 nobody nobody 4.0K Jan 10 11:04 [0m[01;34m.[0m
drwx--S--- 7 nobody nobody  16K 

These are pre-processed further in R to obtain three .csv files which are in turn compiled into an SQLite database within the `CPRD.data` module. 

* TODO: Replace R filtering with in-built python script
* TODO: It is future work to combine these steps into DExtER

In [4]:
!ls --color -lah /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing
!ls --color -lah /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed

total 67K
drwx--S---  3 gaddcz nobody 4.0K Sep 28 11:12 [0m[01;34m.[0m
drwx--S---. 5 gaddcz nobody 4.0K Oct 24 11:21 [01;34m..[0m
drwx--S---  2 gaddcz nobody  16K Oct 18 15:48 [01;34mprocessed[0m
-rwx------  1 gaddcz nobody 7.5K Sep 28 15:22 [01;32mprocessing_diagnoses.R[0m
-rwx------. 1 gaddcz nobody  14K Oct 18 16:10 [01;32mprocessing_measurements_and_tests.R[0m
-rwx------  1 gaddcz nobody 2.2K Sep 19 11:00 [01;32mprocessing_static.R[0m
total 2.0G
drwx--S--- 2 gaddcz nobody  16K Oct 18 15:48 [0m[01;34m.[0m
drwx--S--- 3 gaddcz nobody 4.0K Sep 28 11:12 [01;34m..[0m
-rwx------ 1 gaddcz nobody 983M Oct 18 15:48 [01;32mcprd.db[0m
-rwx------ 1 gaddcz nobody 128M Oct  2 15:06 [01;32mdiagnosis_history.csv[0m
-rwx------ 1 gaddcz nobody 874M Sep 28 15:38 [01;32mmeasurements.csv[0m
-rwx------ 1 gaddcz nobody  42M Sep 20 12:06 [01;32mstatic.csv[0m


In doing this we convert each of these files into an SQL table.

# SQL tables

## Static table
<div align="justify">
The static table has one row per data owner. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* ETHNICITY, SEX, etc: The static variables which remain constant throughout a lifetime.
* INDEX_AGE / START_AGE / END_AGE: 
</div>


<div align="center">
TODO: 
</div>

* Repace missing with nans or some other unique value so these can be masked later
* Add filtering to only include events after index age



In [5]:
print(f"`Static table` docstring:\n{CPRD.data.database.build_static_db.build_static_table.__doc__}")

`Static table` docstring:

    
    Produced anonymized table:
    ┌──────────────────────┬─────┬───────────┬───────────────┬─────────────┬─────────────┬───────────────┐
    │ PRACTICE_PATIENT_ID  ┆ SEX ┆ ETHNICITY ┆ YEAR_OF_BIRTH ┆ INDEX_AGE   ┆ START_AGE   ┆ END_AGE       │
    │ ---                  ┆ --- ┆ ---       ┆ ---           ┆ ---         ┆ ---         ┆ ---           │
    │ str                  ┆ str ┆ str       ┆ str           ┆ i64 (days)  ┆ i64 (days)  ┆ i64 (days)    │
    ╞══════════════════════╪═════╪═══════════╪═══════════════╪═════════════╪═════════════╪═══════════════╡
    │ <anonymous 1>        ┆ M   ┆ WHITE     ┆ yyyy--mm-dd   ┆ dd          ┆ dd          ┆ dd            │
    │ <anonymous 2>        ┆ F   ┆ MISSING   ┆ yyyy--mm-dd   ┆ dd          ┆ dd          ┆ dd            │
    │ …                    ┆ …   ┆ …         ┆ …             ┆             ┆             ┆               │
    │ <anonymous N>        ┆ M   ┆ WHITE     ┆ yyyy--mm-dd   ┆ dd          ┆ dd  

## Diagnosis table
<div align="justify">
The diagnosis table has one row per diagnosis. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* EVENT: The categorical event
* AGE_AT_EVENT: Number if days between subject birth and the day of event
* VALUE: to be removed
* EVENT_TYPE: to be removed
</div>


<div align="center">
  TODO: 
</div>

* We could embed based on subgroups. For example each of these conditions are categorical, but they could be further sub categories. For example, diabetes can be one category and types further divide this.

In [6]:
print(f"`Diagnosis table` docstring:\n{CPRD.data.database.build_diagnosis_db.build_diagnosis_table.__doc__}")

`Diagnosis table` docstring:

    Build measurements and tests table in database

    Produced anonymized table:
    ┌──────────────────────┬───────┬──────────────┬──────────────┬────────────────────────────┐
    │ PRACTICE_PATIENT_ID  ┆ VALUE ┆ EVENT        ┆ AGE_AT_EVENT ┆ EVENT_TYPE                 │
    │ ---                  ┆ ---   ┆ ---          ┆ ---          ┆ ---                        │
    │ str                  ┆ f64   ┆ str          ┆ i64 (days)   ┆ str                        │
    ╞══════════════════════╪═══════╪══════════════╪══════════════╪════════════════════════════╡
    │ <anonymous 1>        ┆ null  ┆ HF           ┆ 11632        ┆ categorical                │
    │ <anonymous 2>        ┆ null  ┆ HF           ┆ 25635        ┆ categorical                │
    │ …                    ┆ …     ┆ …            ┆ …            ┆ …                          │
    │ <anonymous N>        ┆ null  ┆ FIBROMYALGIA ┆ 8546         ┆ categorical                │
    └──────────────────

## Measurements and tests table
<div align="justify">
The measurements table has one row per diagnosis. 

* PRACTICE_PATIENT_ID: The data owner identifier,
* EVENT: The categorical event
* AGE_AT_EVENT: Number if days between subject birth and the day of event
* VALUE: The measurement/test record
* EVENT_TYPE: whether the value is categorical (e.g. EVENT=="smoking", value="ex-smoker"), or continuous (e.g. EVENT=="bmi", value=23.3)
</div>


<div align="center">
  TODO: 
</div>



In [7]:
print(f"`Measurements and tests table` docstring:\n{CPRD.data.database.build_measurements_and_tests_db.build_measurements_table.__doc__}")

`Measurements and tests table` docstring:
 
    Build measurements and tests table in database

    Produced anonymized table:
    ┌──────────────────────┬───────┬──────────────────┬──────────────┬───────────────────────┐
    │ PRACTICE_PATIENT_ID  ┆ VALUE ┆ EVENT            ┆ AGE_AT_EVENT ┆ EVENT_TYPE            │
    │ ---                  ┆ ---   ┆ ---              ┆ ---          ┆ ---                   │
    │ str                  ┆ f64   ┆ str              ┆ i64 (days)   ┆ str                   │
    ╞══════════════════════╪═══════╪══════════════════╪══════════════╪═══════════════════════╡
    │ <anonymous 1>        ┆ 23.3  ┆ bmi              ┆ 10254        ┆ univariate_regression │
    │ <anonymous 1>        ┆ 24.1  ┆ bmi              ┆ 11829        ┆ univariate_regression │
    │ …                    ┆ …     ┆ …                ┆ …            ┆ …                     │
    │ <anonymous N>        ┆ 0.17  ┆ eosinophil_count ┆ 12016        ┆ univariate_regression │
    └─────────────

# Filtering

We can query the SQL tables separately to find the set of patients that fit the study criteria. 

First we can connect directly to the database using sqlite3

<div align="center">
  TODO: 
</div>

* The db shouldn't really be interfaced with like this, this is mostly for demonstration and a 

In [8]:
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()

Performing a look-up directly on the tables shows what diagnoses, measurements and tests are available

In [9]:
# Check what measurements are available
cursor.execute("SELECT DISTINCT EVENT FROM measurement_table")
measurements = [_item[0] for _item in cursor.fetchall()]
print("We have access to measurements and tests:\n * " + '\n * '.join(measurements))

# Check what diagnoses are available
cursor.execute("SELECT DISTINCT EVENT FROM diagnosis_table")
diagnoses = [_item[0] for _item in cursor.fetchall()]
print("\nAdditionally, we have access to diagnoses:\n * " + '\n * '.join(diagnoses))

We have access to measurements and tests:
 * bmi
 * hydroxyvitamin2
 * hydroxyvitamin3
 * aspartate_transam
 * serum_level
 * creatinine_ratio
 * basophil_count
 * blood_calcium
 * blood_urea
 * brain_natriuretic_peptide_level
 * calcium_adjusted_level
 * calculated_LDL_cholesterol_level
 * combined_total_vitamin_D2_and_D3_level
 * corrected_serum_calcium_level
 * diastolic_blood_pressure
 * eosinophil_count

Additionally, we have access to diagnoses:
 * HF
 * AF
 * ISCHAEMICSTROKE
 * STROKEUNSPECIFIED
 * STROKE_HAEMRGIC
 * HYPERTENSION
 * MINFARCTION
 * IHD_NOMI
 * PAD_STRICT
 * VALVULARDISEASES
 * AORTICANEURYSM
 * TYPE1DM
 * TYPE2DIABETES
 * CKDSTAGE3TO5
 * DEPRESSION
 * ANXIETY
 * BIPOLAR
 * EATINGDISORDERS
 * SCHIZOPHRENIAMM
 * AUTISM
 * ALCOHOLMISUSE
 * SUBSTANCEMISUSE
 * CHRONIC_LIVER_DISEASE_ALCOHOL
 * NAFLD
 * OTHER_CHRONIC_LIVER_DISEASE_OPTIMAL
 * ULCERATIVE_COLITIS
 * CROHNS_DISEASE
 * ALL_DEMENTIA
 * PARKINSONS
 * EPILEPSY
 * ALLCA_NOBCC_VFINAL
 * LYMPHOMA_PREVALENCE
 * LEU

Given some subset of these we can query the database to find those subjects who have at least one entry of a list of measurements, tests or diagnoses

In [10]:
identifiers1 = CPRD.data.database.queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = CPRD.data.database.queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

In [11]:
# For now, lets take only the first 10,000
N = np.min((len(all_identifiers), 10000))
print(f"Using N={N} random samples, from the available {len(all_identifiers)}")

identifiers = random.choices(all_identifiers, k=N)

Using N=10000 random samples, from the available 117102


# Make PyTorch dataloader

We can now initialise our DL friendly datasets. This is done within the DataModule initialisation. 

Initialisation takes one required argument, which is the set of patient `identifiers` we wish to build into our dataset. The remaining arguments are optional and explained in the docstring


<div align="center">
  TODO: 
</div>

* Add option of saving and loading initialised train/test/val splits

In [12]:
print(f"`FoundationalDataModule` docstring:\n{FoundationalDataModule.__doc__}")

`FoundationalDataModule` docstring:

    PyTorch-Lightning datamodule for foundational models
    
    ARGS:
        practice_patient_id (list[str])
            List of practice patient identifiers which satisfy study criteria.
            
    KWARGS:
        batch_size (int): 
        
        unk_freq_threshold (float). 
            Value between 0 and 1, controlling at what level of frequency rare tokens (equiv. conditions/measurements 
            with this tokenizer) are mapped to the UNK token. Used to reduce vocabulary size
            
        min_workers (int):

        load_event_stream (optional, str):
        
        save_event_stream (optional, str):
        
    


In [13]:
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer=config.tokenizer,
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold)

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{dm.tokenizer.vocab_size} dictionary elements")

print("\nDictionaries")
print(dm.tokenizer._itos)
# print(dm.tokenizer._stoi)

INFO:root:Building polars dataset
INFO:root:Using measurements
INFO:root:Using test/measurement standardisation method: normalise
INFO:root:Removing measurement and test outliers. Using three deviations from mean as cutoff
INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


8603 training samples
478 validation samples
478 test samples
90 dictionary elements

Dictionaries
{0: 'PAD', 1: 'UNK', 2: 'diastolic_blood_pressure', 3: 'bmi', 4: 'eosinophil_count', 5: 'basophil_count', 6: 'corrected_serum_calcium_level', 7: 'DEPRESSION', 8: 'serum_level', 9: 'calculated_LDL_cholesterol_level', 10: 'ANXIETY', 11: 'HYPERTENSION', 12: 'TYPE2DIABETES', 13: 'OSTEOARTHRITIS', 14: 'ASTHMA_PUSHASTHMA', 15: 'ATOPICECZEMA', 16: 'ALLERGICRHINITISCONJ', 17: 'aspartate_transam', 18: 'ANY_DEAFNESS_HEARING_LOSS', 19: 'PREVALENT_IBS', 20: 'IHD_NOMI', 21: 'ALCOHOLMISUSE', 22: 'ALLCA_NOBCC_VFINAL', 23: 'CKDSTAGE3TO5', 24: 'blood_urea', 25: 'PERIPHERAL_NEUROPATHY', 26: 'HYPOTHYROIDISM_DRAFT_V1', 27: 'COPD', 28: 'calcium_adjusted_level', 29: 'combined_total_vitamin_D2_and_D3_level', 30: 'AF', 31: 'HF', 32: 'PSORIASIS', 33: 'OSTEOPOROSIS', 34: 'GOUT', 35: 'SUBSTANCEMISUSE', 36: 'MINFARCTION', 37: 'STROKEUNSPECIFIED', 38: 'ALL_DEMENTIA', 39: 'hydroxyvitamin3', 40: 'hydroxyvitamin2', 41: 

## TODO: Standardisation documentation

In [14]:
dm.standardisation_dict

{'corrected_serum_calcium_level': (2.319731275014289, 0.12185859347306796),
 'basophil_count': (0.07133528604493906, 0.11177124586927319),
 'brain_natriuretic_peptide_level': (144.51785714285714, 284.01419945753855),
 'eosinophil_count': (0.2226722643918027, 0.1935410331107238),
 'calculated_LDL_cholesterol_level': (2.633089972735546, 1.0715495962742911),
 'calcium_adjusted_level': (2.3207865168539263, 0.1035893083134164),
 'aspartate_transam': (26.487053020961774, 15.683939299536211),
 'creatinine_ratio': (4.900059523809522, 8.418335029392356),
 'blood_calcium': (2.312631578947369, 0.1467742826876682),
 'hydroxyvitamin2': (3.263071895424836, 2.8410069286407222),
 'serum_level': (26.785693929894556, 18.686021657832168),
 'hydroxyvitamin3': (51.030952380952385, 31.920175536084418),
 'combined_total_vitamin_D2_and_D3_level': (56.50543130990417,
  30.3484062588742),
 'blood_urea': (6.375120772946863, 3.656222849264762),
 'diastolic_blood_pressure': (78.87509075567591, 11.749617841797654),

## Dataset

Each dataset within the DataModule is a ``cprd.data.dataset.dataset_polars.EventStreamDataset`` object. Within this class we process the SQL tables from a form in which each patient (static) or record (dianosis/measurement/test) has its own row, into a ragged form where each patient has their own row in every case. This does not change the form of the static table. However, the form of the diagnoses and measurements/tests changes - and during the same update we combine these frames to one frame which contains dynamic information.

This is all wrapped in the method ``fit()``, but we can see the separate steps below:

In [15]:
print(f"`_load_dynamic` docstring:\n{CPRD.data.dataset.dataset_polars.EventStreamDataset._load_dynamic.__doc__}")

`_load_dynamic` docstring:
    
        Load and merge dynamic tables from SQL
        
        ARGS:
            practice_patient_id (list[str])
                List of practice patient identifiers which satisfy study criteria.
            
        KWARGS:
            include_measurements: bool = True,
                Flag of whether to include measurements and tests in polars dataframe
            include_diagnoses: bool = True
                Flag of whether to include diagnoses in polars dataframe
        
        RETURNS:
            Polars lazy frame, of the form:
            ┌──────────────────────┬───────────────────────┬───────────────────────────────────┬─────────────────────────┬───────────────────────────────────┐
            │ PRACTICE_PATIENT_ID  ┆ VALUE                 ┆ EVENT                             ┆ AGE_AT_EVENT            ┆ EVENT_TYPE                        │
            │ ---                  ┆ ---                   ┆ ---                               ┆ ---     

This is then combined into one DL friendly representation, which our DataModule uses

In [16]:
print(f"`_build_DL_representation` docstring:\n{CPRD.data.dataset.dataset_polars.EventStreamDataset._build_DL_representation.__doc__}")

`_build_DL_representation` docstring:

        Build the DL-friendly representation in polars given the list of `practice_patient_id`s which fit study criteria
                
        ARGS:
            practice_patient_ids (list[str])
                List of practice patient identifiers which satisfy study criteria.
            
        KWARGS:
        
        
        RETURNS:
            Polars lazy frame, of the (anonymized) form:
            ┌──────────────────────┬─────┬─────────────┬───────────────┬──────────────────────┬─────────────────────────┬─────────────────────┬────────────────────────┐
            │ PRACTICE_PATIENT_ID  ┆ SEX ┆ ETHNICITY   ┆ YEAR_OF_BIRTH ┆ VALUE                ┆ EVENT                   ┆ AGE_AT_EVENT        ┆ EVENT_TYPE             │
            │ ---                  ┆ --- ┆ ---         ┆ ---           ┆ ---                  ┆ ---                     ┆ ---                 ┆ ---                    │
            │ str                  ┆ str ┆ str       

The ``__getitem__`` within the dataset class then retrieves rows of this table and performs the relevant processing such as tokenization. Speed ups could be obtained by pre-processing this tokenization.

As these are ragged lists for memory efficiency, we also provide a ``collate_fn`` in the DataModule. This performs padding - and by performing padding this way we can avoid excessive padding.

# Data iterators

## Using \_\_getitem__

In [17]:
# print(dm.train_set[0])
print("A single element of the training dataset contains:\n  * " + '\n  * '.join(dm.train_set[0].keys()))

sample = np.random.randint(len(dm.train_set))

for k, v in dm.train_set[sample].items():
    print(f"\n{k}: {v}")
    if k == "tokens":
        print(f"... decoding to `{dm.decode(v.tolist())}`")


A single element of the training dataset contains:
  * identifier
  * tokens
  * ages
  * values

identifier: p20692_1246521620692

tokens: tensor([ 7, 10,  3,  5,  4,  2,  2,  5,  4,  5,  4,  4,  2,  2,  5,  4,  5,  4,
         2,  5,  4,  5,  4])
... decoding to `DEPRESSION ANXIETY bmi basophil_count eosinophil_count diastolic_blood_pressure diastolic_blood_pressure basophil_count eosinophil_count basophil_count eosinophil_count eosinophil_count diastolic_blood_pressure diastolic_blood_pressure basophil_count eosinophil_count basophil_count eosinophil_count diastolic_blood_pressure basophil_count eosinophil_count basophil_count eosinophil_count`

ages: tensor([12896, 12896, 13047, 13355, 13355, 13374, 13653, 13658, 13658, 13782,
        13782, 13806, 14057, 14097, 14136, 14136, 14345, 14345, 14442, 14696,
        14696, 14854, 14854])

values: tensor([    nan,     nan, -1.1811, -0.3698,  0.0379, -0.3298, -1.1809, -0.2804,
         0.1929, -0.1014,  0.3995,  0.1929, -0.9256, -1.6064, 

## Collating into batches

In [34]:
for idx, batch in enumerate(dm.train_dataloader()):
    break
print("A sample from the dataloader batch gives:\n  * " + '\n  * '.join(batch.keys()))

k = 10
print(f"\nFor example, a single sample from a collated batch gives (viewing only first {k} elements of each padded sequence):")

# ages
print(f"\n* The time of event (in days since birth) of events:\n\t {batch['ages'][0, :k]}")

# tokens
print(f"\n* The tokenized and padded (within batch) block of a patient's sequence of events:\n\t {batch['tokens'][0,:k]}")
print(f"... which can be decoded. E.g. the above tokens decode to:\n\t {dm.decode(batch['tokens'][0,:k].tolist())}`")

# values
print(f"\n* When we use a {config.tokenizer} tokenizer, the corresponding values are given as:\n\t {batch['values'][0,:k]}")

# attention mask
#   B = batch size, L = Longest sequence length within batch
print(f"\n* The attention mask tells us which values are padded entries, so we can then mask them alongside the self-attention masking" + 
      f"\n\t This is of shape (B={batch['attention_mask'].shape[0]} * L={batch['attention_mask'].shape[1]})" +
      f"\n{batch['attention_mask']}")



A sample from the dataloader batch gives:
  * tokens
  * ages
  * values
  * attention_mask

For example, a single sample from a collated batch gives (viewing only first 10 elements of each padded sequence):

* The time of event (in days since birth) of events:
	 tensor([ 7630,  8532, 10069, 10682, 10880, 11021, 11601, 14146, 15031, 15114])

* The tokenized and padded (within batch) block of a patient's sequence of events:
	 tensor([35,  3,  3, 10,  4, 15,  7,  3, 21,  2])
... which can be decoded. E.g. the above tokens decode to:
	 SUBSTANCEMISUSE bmi bmi ANXIETY eosinophil_count ATOPICECZEMA DEPRESSION bmi ALCOHOLMISUSE diastolic_blood_pressure`

* When we use a tabular tokenizer, the corresponding values are given as:
	 tensor([    nan, -0.8156, -1.0546,     nan, -0.1171,     nan,     nan, -0.7031,
            nan,  1.1170])

* The attention mask tells us which values are padded entries, so we can then mask them alongside the self-attention masking
	 This is of shape (B=64 * L=64)
t

# Tabular vs. non-tabular tokenization

We can also use a non-tabular tokenizer.

In [35]:
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="non-tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold)

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{dm.tokenizer.vocab_size} dictionary elements")

print("\nDictionaries")
print(dm.tokenizer._itos)
# print(dm.tokenizer._stoi)

INFO:root:Building polars dataset
INFO:root:Using measurements
INFO:root:Using test/measurement standardisation method: normalise
INFO:root:Removing measurement and test outliers. Using three deviations from mean as cutoff
INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using non-tabular tokenizer


8603 training samples
478 validation samples
478 test samples
101 dictionary elements

Dictionaries
{0: 'PAD', 1: 'UNK', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9', 12: '.', 13: 'diastolic_blood_pressure', 14: 'bmi', 15: 'eosinophil_count', 16: 'basophil_count', 17: 'corrected_serum_calcium_level', 18: 'DEPRESSION', 19: 'serum_level', 20: 'calculated_LDL_cholesterol_level', 21: 'ANXIETY', 22: 'HYPERTENSION', 23: 'TYPE2DIABETES', 24: 'OSTEOARTHRITIS', 25: 'ASTHMA_PUSHASTHMA', 26: 'ATOPICECZEMA', 27: 'ALLERGICRHINITISCONJ', 28: 'aspartate_transam', 29: 'ANY_DEAFNESS_HEARING_LOSS', 30: 'PREVALENT_IBS', 31: 'IHD_NOMI', 32: 'ALCOHOLMISUSE', 33: 'ALLCA_NOBCC_VFINAL', 34: 'CKDSTAGE3TO5', 35: 'blood_urea', 36: 'PERIPHERAL_NEUROPATHY', 37: 'HYPOTHYROIDISM_DRAFT_V1', 38: 'calcium_adjusted_level', 39: 'COPD', 40: 'combined_total_vitamin_D2_and_D3_level', 41: 'AF', 42: 'PSORIASIS', 43: 'HF', 44: 'GOUT', 45: 'OSTEOPOROSIS', 46: 'SUBSTANCEMISUSE', 47: 'MINFARCTI

In [37]:
for idx, batch in enumerate(dm.train_dataloader()):
    break
print("A sample from the dataloader batch gives:\n  * " + '\n  * '.join(batch.keys()))

k = 10
print(f"\nRepeating the above example, a single sample from a collated batch gives (viewing only first {k} elements of each sequence):")

# ages
print(f"\n* The time of event (in days since birth) of events:\n\t {batch['ages'][0, :k]}")

# tokens
print(f"\n* The tokenized and padded (within batch) block of a patient's sequence of events:\n\t {batch['tokens'][0,:k]}")
print(f"... which can be decoded. E.g. the above tokens decode to:\n\t {dm.decode(batch['tokens'][0,:k].tolist())}`")

# values
print(f"\n* When we use a {config.tokenizer} tokenizer, the corresponding values are given as:\n\t {batch['values'][0,:k]}")

# attention mask
#   B = batch size, L = Longest sequence length within batch
print(f"\n* The attention mask tells us which values are padded entries, so we can then mask them alongside the self-attention masking" + 
      f"\n\t This is of shape (B={batch['attention_mask'].shape[0]} * L={batch['attention_mask'].shape[1]})" +
      f"\n{batch['attention_mask']}")



A sample from the dataloader batch gives:
  * tokens
  * ages
  * values
  * attention_mask

Repeating the above example, a single sample from a collated batch gives (viewing only first 10 elements of each sequence):

* The time of event (in days since birth) of events:
	 tensor([7004, 7009, 7009, 7009, 7009, 7009, 7009, 7009, 7009, 7009])

* The tokenized and padded (within batch) block of a patient's sequence of events:
	 tensor([ 6, 15,  2, 12,  4, 11,  8,  4,  2,  6])
... which can be decoded. E.g. the above tokens decode to:
	 4 eosinophil_count 0 . 2 9 6 2 0 4`

* When we use a tabular tokenizer, the corresponding values are given as:
	 tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

* The attention mask tells us which values are padded entries, so we can then mask them alongside the self-attention masking
	 This is of shape (B=64 * L=64)
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1