## GPT - demo on subset of CPRD

In [1]:
import pytorch_lightning 
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.gpt_pico.transformer import GPTLanguageModel
from CPRD.src.models.gpt_simple.task_heads import GPTModelForCausalLM

# TODO:
# replace boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for debug errors

cuda


## Build configurations

In [2]:
# Set GPT config to be equivalent
@dataclass
class DemoConfig:
    block_size: int = 256             # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    pos_encoding: str = "index-embedding"                 # Manually adding later
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 10
    
opt = OptConfig()

## Demonstrate on a reduced cohort

In [3]:
from CPRD.data.database import queries

PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()

# # Check what measurements are available
# cursor.execute("SELECT DISTINCT * FROM measurement_table")
# measurements = cursor.fetchall()
# print(measurements)

# Check what diagnoses are available
# cursor.execute("SELECT DISTINCT * FROM diagnosis_table")
# diagnoses = cursor.fetchall()
# print(diagnoses)

# Get a list of patients which fit a reduced set of criterion
identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

In [4]:
# Lets take only the first N for faster development
N = np.min((len(all_identifiers), 10000))
print(f"Using N={N} random samples, from the available {len(all_identifiers)}")

identifiers = random.choices(all_identifiers, k=N)

Using N=10000 random samples, from the available 117102


## Make dataloader


In [5]:
dm = FoundationalDataModule(identifiers=identifiers,
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold)

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")

INFO:root:Building DL-friendly representation
INFO:root:Dropping samples with no dynamic events


8640 training samples
481 validation samples
480 test samples


#### Visualise a sample

In [6]:
# print(dm.train_set[0])
print("A single element of the dataset contains:\n  * " + '\n  * '.join(dm.train_set[0].keys()))
print(dm.train_set[0])

A single element of the dataset contains:
  * identifier
  * sex
  * ethnicity
  * year_of_birth
  * input_ids
  * input_pos
  * input_ages
  * target_ids
  * target_pos
  * target_ages
{'identifier': 'p20389_944538320389', 'sex': 'M', 'ethnicity': 'WHITE', 'year_of_birth': '1972-07-15', 'input_ids': tensor([18, 14,  4,  6, 12,  5, 13,  8,  7, 12,  2, 45, 14,  4,  9, 12,  4, 13,
         9, 10, 12,  2, 15,  2, 12,  7, 17,  4, 12,  3, 10, 15,  2, 12,  5, 13,
         9,  4, 12,  2, 14,  4, 10, 12,  5, 16,  2, 12,  3, 15,  2, 12,  7, 14,
         5,  3, 12,  5, 13,  9, 11, 12]), 'input_pos': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61]), 'input_ages': tensor([12861, 12990, 12990, 12990, 12990, 12990, 12990, 12990, 12990, 12990,
        12990, 12990, 13242, 

#### Visualise a batch

In [10]:
for idx, batch in enumerate(dm.train_dataloader()):
    break
print("A sample from the dataloader batch gives:")
print("Batch Dataframe Columns:\n  * " + '\n  * '.join(batch.keys()))
print(f"\nThe position index of inputs and targets: \ninputs: {batch['input_pos'][0,:10]}  \ntargets: {batch['target_pos'][0,:10]}")
print(f"\nThe time of event (in days since birth) of event of inputs and targets: \ninputs: {batch['input_ages'][0,:10]}  \ntargets: {batch['target_ages'][0,:10]}")
print(f"\nThe shifted next-step, tokenized and padded (within batch), representation from a block of a patient's sequence for events: \ninputs: {batch['input_ids'][0,:10]} \ntargets: {batch['target_ids'][0,:10]}")
print(f"\nWhich can be decoded. E.g. first sample's first 10 block tokens: \ninputs: {dm.decode(batch['input_ids'][0,:10].tolist())}  \ntargets: {dm.decode(batch['target_ids'][0,:10].tolist())}")
print(f"\nThe attention mask ({batch['attention_mask'].shape}) for padding: \n{batch['attention_mask']}")


A sample from the dataloader batch gives:
Batch Dataframe Columns:
  * input_ids
  * target_ids
  * input_pos
  * target_pos
  * input_ages
  * target_ages
  * attention_mask

The position index of inputs and targets: 
inputs: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  
targets: tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

The time of event (in days since birth) of event of inputs and targets: 
inputs: tensor([8034, 8117, 8117, 8117, 8117, 8117, 8117, 8117, 8117, 8117])  
targets: tensor([8117, 8117, 8117, 8117, 8117, 8117, 8117, 8117, 8117, 8117])

The shifted next-step, tokenized and padded (within batch), representation from a block of a patient's sequence for events: 
inputs: tensor([18, 14,  4,  5, 12,  7, 13,  9,  2, 12]) 
targets: tensor([14,  4,  5, 12,  7, 13,  9,  2, 12,  2])

Which can be decoded. E.g. first sample's first 10 block tokens: 
inputs: DEPRESSION bmi 2 3 . 5 diastolic_blood_pressure 7 0 .  
targets: bmi 2 3 . 5 diastolic_blood_pressure 7 0 . 0

The attention m

In [11]:
vocab_size = dm.train_set.tokenizer.vocab_size

print(vocab_size)
print(dm.train_set.tokenizer._itos)

100
{0: 'PAD', 1: 'UNK', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9', 12: '.', 13: 'diastolic_blood_pressure', 14: 'bmi', 15: 'eosinophil_count', 16: 'basophil_count', 17: 'corrected_serum_calcium_level', 18: 'DEPRESSION', 19: 'serum_level', 20: 'calculated_LDL_cholesterol_level', 21: 'ANXIETY', 22: 'HYPERTENSION', 23: 'TYPE2DIABETES', 24: 'ASTHMA_PUSHASTHMA', 25: 'OSTEOARTHRITIS', 26: 'ATOPICECZEMA', 27: 'ANY_DEAFNESS_HEARING_LOSS', 28: 'ALLERGICRHINITISCONJ', 29: 'aspartate_transam', 30: 'PREVALENT_IBS', 31: 'ALLCA_NOBCC_VFINAL', 32: 'ALCOHOLMISUSE', 33: 'IHD_NOMI', 34: 'CKDSTAGE3TO5', 35: 'blood_urea', 36: 'PERIPHERAL_NEUROPATHY', 37: 'calcium_adjusted_level', 38: 'HYPOTHYROIDISM_DRAFT_V1', 39: 'COPD', 40: 'PSORIASIS', 41: 'AF', 42: 'combined_total_vitamin_D2_and_D3_level', 43: 'OSTEOPOROSIS', 44: 'HF', 45: 'SUBSTANCEMISUSE', 46: 'GOUT', 47: 'MINFARCTION', 48: 'STROKEUNSPECIFIED', 49: 'ALL_DEMENTIA', 50: 'hydroxyvitamin3', 51: 'hydroxyvitamin2', 

## Create models and train

In [12]:
models = []

# Baseline model to test my changes against
#   Note: this benchmark model uses index position along the batch
models.append(GPTLanguageModel(config, vocab_size).to(device))

# My development model
# Handle positional vs. temporal encoding/embedding
# Cases: 
#     index-embedding:       use index position along the batch
#     index-encoding:        use index position along the batch
#     temporal-encoding:     use age along a patient's timeline
pos_encodings = ["index-embedding", "index-encoding", "temporal-encoding"]
for pe in pos_encodings:
    config = DemoConfig()
    config.pos_encoding = pe
    models.append(GPTModelForCausalLM(config, vocab_size).to(device))

m_names = ["kaparthy benchmark"] + pos_encodings

INFO:root:Using Positional Embedding. This module uses the index position of an event within the block of events.
INFO:root:Using Positional Encoding. This module uses the index position of an event within the block of events.
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.


In [13]:
loss_curves_train = [[] for _ in models]
loss_curves_val = [[] for _ in models]

In [14]:
for m_idx, model in enumerate(models):
    model = model.to(device)

    # print the number of parameters in the model
    print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, best_iter = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss = 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            logits, loss = model(batch['input_ids'].to(device),
                                 ages=batch['input_ages'].to(device),
                                 targets=batch['target_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)
                                 )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss = 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, loss = model(batch['input_ids'].to(device), 
                                    ages=batch['input_ages'].to(device), 
                                    targets=batch['target_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device)   
                                   )
                    val_loss += loss.item()
                val_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}. Val loss {val_loss:.2f}")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                best_iter += 1
                if best_iter > 2:
                    break
            else:
                best_val = val_loss
                best_iter = 0
                
    prompt = ["DEPRESSION"]
    context = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    fut_tokens, fut_ages = model.generate(context, max_new_tokens=30)
    fut_words = dm.decode(fut_tokens[0].tolist())
    print(fut_words)


10.815844 M parameters
Epoch 0:	Train loss 1.20. Val loss 1.21
Epoch 1:	Train loss 0.92. Val loss 0.88
Epoch 2:	Train loss 0.71. Val loss 0.78
Epoch 3:	Train loss 0.66. Val loss 0.76
Epoch 4:	Train loss 0.64. Val loss 0.73
Epoch 5:	Train loss 0.63. Val loss 0.72
Epoch 6:	Train loss 0.62. Val loss 0.71
Epoch 7:	Train loss 0.61. Val loss 0.71
Epoch 8:	Train loss 0.61. Val loss 0.71
Epoch 9:	Train loss 0.61. Val loss 0.70
DEPRESSION diastolic_blood_pressure 7 0 . 0 bmi 2 0 . 5 diastolic_blood_pressure 6 0 . 0 POLYCYSTIC_OVARIAN_SYNDROME_PCOS diastolic_blood_pressure 7 0 . 0 ANY_DEAFNESS_HEARING_LOSS diastolic_blood_pressure 7 0 . 0 eosinophil_count 0 .
10.777444 M parameters
Epoch 0:	Train loss 1.21. Val loss 1.00
Epoch 1:	Train loss 0.74. Val loss 0.78
Epoch 2:	Train loss 0.66. Val loss 0.73
Epoch 3:	Train loss 0.63. Val loss 0.71
Epoch 4:	Train loss 0.61. Val loss 0.71
Epoch 5:	Train loss 0.61. Val loss 0.70
Epoch 6:	Train loss 0.61. Val loss 0.70
Epoch 7:	Train loss 0.60. Val loss 0.70

                                but this head has no way of sampling age at next event.
                                Using 50 days as intervals


Epoch 9:	Train loss 0.63. Val loss 0.75
DEPRESSION diastolic_blood_pressure 8 PERIPHERAL_NEUROPATHY diastolic_blood_pressure 6 basophil_count 0 COPD diastolic_blood_pressure 4 diastolic_blood_pressure 7 POLYCYSTIC_OVARIAN_SYNDROME_PCOS diastolic_blood_pressure 6 diastolic_blood_pressure 7 basophil_count 0 diastolic_blood_pressure 7 bmi 1 basophil_count 0 diastolic_blood_pressure 8 diastolic_blood_pressure 7 diastolic_blood_pressure


In [15]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig("figs/loss_curves.png")

## Prompt testing

Probability of type II diabetes before and after a type I diagnosis

keys: 

    70: 'TYPE1DM'
    31: 'TYPE2DIABETES'

Small context comparison, high bmi and blood pressure vs low for diabetes risk

In [16]:
low_risk_prompt = ["bmi", "2", "2", ".", "5", "diastolic_blood_pressure", "7", "9", ".", "0"]
high_risk_prompt = ["bmi", "3", "7", ".", "5", "diastolic_blood_pressure", "9", "9", ".", "0"]
ages_in_years = [18, 19, 19, 19, 19, 20, 20, 20, 20, 20]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

In [17]:
prompts, ages, desc = [], [], []

desc.append("Control: Low risk")
prompts.append(low_risk_prompt)
ages.append(ages_in_years)

desc.append("Control: High risk")
prompts.append(high_risk_prompt)
ages.append(ages_in_years)

desc.append("Control: Low risk + depression")
prompts.append(["DEPRESSION"] + low_risk_prompt)
ages.append([17] + ages_in_years)

desc.append("Low risk context: Type 1 diagnosis in prompt")
prompts.append(["TYPE1DM"] + low_risk_prompt)
ages.append([17] + ages_in_years)

desc.append("Low risk context: Type 1I diagnosis in prompt")
prompts.append(["TYPE2DIABETES"] + low_risk_prompt)
ages.append([17] + ages_in_years)

for model_idx in range(len(pos_encodings)+1):
    print(f"\n\nMODEL_IDX {model_idx}\n==================")
    
    for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
        print(f"\n{desc[p_idx]}: \n\t ({','.join(prompt)}): ")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        lgts, _ = models[model_idx](encoded_prompt,
                                    ages=to_days(age))
        probs = torch.nn.functional.softmax(lgts, dim=2)
        print(f"probability of type I diabetes {100*float(probs[0, 0, 70].cpu().detach().numpy()):.4f}%")
        print(f"probability of type II diabetes {100*float(probs[0, 0, 31].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



MODEL_IDX 0

Control: Low risk: 
	 (bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0003%
probability of type II diabetes 0.0008%

Control: High risk: 
	 (bmi,3,7,.,5,diastolic_blood_pressure,9,9,.,0): 
probability of type I diabetes 0.0003%
probability of type II diabetes 0.0008%

Control: Low risk + depression: 
	 (DEPRESSION,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.2829%
probability of type II diabetes 0.7801%

Low risk context: Type 1 diagnosis in prompt: 
	 (TYPE1DM,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.1048%
probability of type II diabetes 0.4623%

Low risk context: Type 1I diagnosis in prompt: 
	 (TYPE2DIABETES,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.1897%
probability of type II diabetes 1.2209%


MODEL_IDX 1

Control: Low risk: 
	 (bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0010%
prob