## CPRD GPT

In [1]:
import pytorch_lightning 
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.gpt_pico.transformer import GPTLanguageModel
from CPRD.src.models.gpt_simple.task_heads import GPTModelForCausalLM

# TODO:
# mask padding tokens
# replace boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for debug errors

cuda


## Build configurations

In [2]:
# Set GPT config to be equivalent
@dataclass
class DemoConfig:
    block_size: int = 256             # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    pos_encoding: str = "index-embedding"                 # Manually adding later
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    

config = DemoConfig()

# optim hyperparameters
batch_size = 64
eval_interval = 5
learning_rate = 3e-4
epochs = 15

## Demonstrate on a reduced cohort

In [3]:
from CPRD.data.database import queries

PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()

# # Check what measurements are available
# cursor.execute("SELECT DISTINCT * FROM measurement_table")
# measurements = cursor.fetchall()
# print(measurements)

# Check what diagnoses are available
# cursor.execute("SELECT DISTINCT * FROM diagnosis_table")
# diagnoses = cursor.fetchall()
# print(diagnoses)

# Get a list of patients which fit a reduced set of criterion
identifiers1 = queries.query_measurement(["bmi", "hydroxyvitamin2", "hydroxyvitamin3"], cursor)         #  
identifiers2 = queries.query_diagnosis([ "FIBROMYALGIA", "HF"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

print(identifiers[:10])
print(len(identifiers))

['p20485_2544331320485', 'p20426_632238320426', 'p20415_270000020415', 'p20655_1200089720655', 'p20508_940887220508', 'p20485_2547686820485', 'p20524_6918058520524', 'p20758_2523011320758', 'p20495_2858389320495', 'p20508_940152320508']
16327


## Make dataloader


In [4]:
dm = FoundationalDataModule(identifiers=identifiers, batch_size=batch_size, max_seq_length=config.block_size, unk_freq_threshold=0)

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")

Building DL-friendly representation
Dropping samples with no temporal events
14694 training samples
817 validation samples
816 test samples


#### Visualise a batch

In [5]:
for idx, batch in enumerate(dm.train_dataloader()):
    break
print("A sample from the dataloader batch gives:")
print(f"\nThe position index of inputs and targets: \ninputs: {batch['input_pos'][0,:10]}  \ntargets: {batch['target_pos'][0,:10]}")
print(f"\nThe time of event (in days since birth) of event of inputs and targets: \ninputs: {batch['input_ages'][0,:10]}  \ntargets: {batch['target_ages'][0,:10]}")
print(f"\nThe shifted next-step, tokenized and padded (within batch), representation from a block of a patient's sequence for events: \ninputs: {batch['input_ids'][0,:10]} \ntargets: {batch['target_ids'][0,:10]}")
print(f"\nWhich can be decoded. E.g. first sample's first 10 block tokens: \ninputs: {dm.decode(batch['input_ids'][0,:10].tolist())}  \ntargets: {dm.decode(batch['target_ids'][0,:10].tolist())}")
print(f"\nThe attention mask ({batch['attention_mask'].shape}) for padding: \n{batch['attention_mask']}")


A sample from the dataloader batch gives:

The position index of inputs and targets: 
inputs: tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])  
targets: tensor([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

The time of event (in days since birth) of event of inputs and targets: 
inputs: tensor([20639, 23842, 23842, 23842, 23842, 23842, 23849, 23849, 23849, 23849])  
targets: tensor([23842, 23842, 23842, 23842, 23842, 23849, 23849, 23849, 23849, 23849])

The shifted next-step tokenized and padded (within batch) representation from a block of a patient's sequence for events: 
inputs: tensor([22, 13, 11,  3, 12,  2, 13, 10,  2, 12]) 
targets: tensor([13, 11,  3, 12,  2, 13, 10,  2, 12,  2])

Which can be decoded. E.g. first sample's first 10 block tokens: 
inputs: OSTEOARTHRITIS diastolic_blood_pressure 9 1 . 0 diastolic_blood_pressure 8 0 .  
targets: diastolic_blood_pressure 9 1 . 0 diastolic_blood_pressure 8 0 . 0

The attention mask for padding (shapetorch.Size([64, 256])): 
tensor([[1, 

In [6]:
vocab_size = dm.train_set.tokenizer.vocab_size

print(vocab_size)
print(dm.train_set.tokenizer._itos)

101
{0: 'PAD', 1: 'UNK', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9', 12: '.', 13: 'diastolic_blood_pressure', 14: 'eosinophil_count', 15: 'bmi', 16: 'basophil_count', 17: 'corrected_serum_calcium_level', 18: 'serum_level', 19: 'calculated_LDL_cholesterol_level', 20: 'HF', 21: 'HYPERTENSION', 22: 'OSTEOARTHRITIS', 23: 'IHD_NOMI', 24: 'aspartate_transam', 25: 'DEPRESSION', 26: 'AF', 27: 'CKDSTAGE3TO5', 28: 'ANY_DEAFNESS_HEARING_LOSS', 29: 'ASTHMA_PUSHASTHMA', 30: 'ANXIETY', 31: 'TYPE2DIABETES', 32: 'ATOPICECZEMA', 33: 'blood_urea', 34: 'MINFARCTION', 35: 'FIBROMYALGIA', 36: 'ALLCA_NOBCC_VFINAL', 37: 'COPD', 38: 'calcium_adjusted_level', 39: 'VALVULARDISEASES', 40: 'ALLERGICRHINITISCONJ', 41: 'GOUT', 42: 'HYPOTHYROIDISM_DRAFT_V1', 43: 'PERIPHERAL_NEUROPATHY', 44: 'OSTEOPOROSIS', 45: 'combined_total_vitamin_D2_and_D3_level', 46: 'PREVALENT_IBS', 47: 'STROKEUNSPECIFIED', 48: 'PAD_STRICT', 49: 'ALL_DEMENTIA', 50: 'ALCOHOLMISUSE', 51: 'PSORIASIS', 52: 'hy

In [14]:
models = []

# Baseline model to test my changes against
models.append(GPTLanguageModel(config, vocab_size).to(device))

# My development model
pos_encodings = ["index-embedding", "index-encoding", "temporal-encoding"]
for pe in pos_encodings:
    config = DemoConfig()
    config.pos_encoding = pe
    models.append(GPTModelForCausalLM(config, vocab_size).to(device))

m_names = ["kaparthy benchmark"] + pos_encodings

In [15]:
loss_curves_train = [[] for _ in models]
loss_curves_val = [[] for _ in models]

In [16]:
for m_idx, model in enumerate(models):
    model = model.to(device)

    # print the number of parameters in the model
    print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        epoch_loss = 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            logits, loss = model(batch['input_ids'].to(device),
                                 positions=batch['input_pos'].to(device),
                                 ages=batch['input_ages'].to(device),
                                 targets=batch['target_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)
                                 )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)

        # every once in a while evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % eval_interval == 0 or epoch == epochs - 1:
                val_loss = 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, loss = model(batch['input_ids'].to(device), 
                                    positions=batch['input_pos'].to(device),
                                    ages=batch['input_ages'].to(device), 
                                    targets=batch['target_ids'].to(device),
                                    attention_mask=batch['attention_mask'].to(device)   
                                   )
                    val_loss += loss.item()
                val_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}. Val loss {val_loss:.2f}")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
                
    prompt = ["DEPRESSION"]
    context = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    fut_tokens, fut_positions, fut_ages = model.generate(context, max_new_tokens=30)
    fut_words = dm.decode(fut_tokens[0].tolist())
    print(fut_words)


10.816613 M parameters
Epoch 0:	Train loss 1.50. Val loss 1.21
Epoch 5:	Train loss 0.85. Val loss 0.92
Epoch 10:	Train loss 0.83. Val loss 0.90
Epoch 14:	Train loss 0.83. Val loss 0.90
DEPRESSION HYPOTHYROIDISM_DRAFT_V1 bmi 3 7 . 3 diastolic_blood_pressure 8 6 . 0 basophil_count 0 . 1 calculated_LDL_cholesterol_level 4 . 6 corrected_serum_calcium_level 2 . 3 3 eosinophil_count 0 . 5 diastolic_blood_pressure 7
10.777829 M parameters
Epoch 0:	Train loss 1.36. Val loss 1.04
Epoch 5:	Train loss 0.84. Val loss 0.90
Epoch 10:	Train loss 0.83. Val loss 0.90
Epoch 14:	Train loss 0.82. Val loss 0.90
DEPRESSION SUBSTANCEMISUSE bmi 1 8 . 8 diastolic_blood_pressure 8 0 . 0 diastolic_blood_pressure 7 0 . 0 OSTEOPOROSIS bmi 2 0 . 8 eosinophil_count 0 . 2 diastolic_blood_pressure 8 1 .
10.679525 M parameters
Epoch 0:	Train loss 1.91. Val loss 1.50
Epoch 5:	Train loss 0.97. Val loss 1.05
Epoch 10:	Train loss 0.88. Val loss 0.95
Epoch 14:	Train loss 0.86. Val loss 0.92
DEPRESSION bmi 2 3 . 8 diastolic_

                                but this head has no way of sampling age at next event.
                                Using 50 days as intervals


Epoch 14:	Train loss 0.81. Val loss 0.93
DEPRESSION ANXIETY ENDOMETRIOSIS_ADENOMYOSIS_V2 diastolic_blood_pressure 6 2 2 calculated_LDL_cholesterol_level 2 corrected_serum_calcium_level 2 basophil_count 0 bmi 3 2 5 FIBROMYALGIA bmi 2 eosinophil_count 0 diastolic_blood_pressure 8 basophil_count 0 diastolic_blood_pressure 7 eosinophil_count 0 diastolic_blood_pressure


In [17]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()
cols = ["k", "r", "b", "y"]
for m_idx, _ in enumerate(models):
    plt.plot(np.linspace(0,epochs,len(loss_curves_train[m_idx])), loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    plt.plot(np.linspace(0,epochs,len(loss_curves_val[m_idx])), loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])

    # plt.plot(np.arange(v.shape[0]), v, label=f"{m_idx}-val", c=cols[m_idx])
plt.legend()
plt.savefig("figs/loss_curves.png")

In [11]:
# prompt = "bmi 1 8 . 6 bmi 3 0 . 6"
# context = torch.from_numpy(np.array(encode(prompt)).reshape((1,-1)))
# print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))