## GPT - demo on subset of CPRD

In [1]:
import pytorch_lightning 
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.gpt_pico.transformer import GPTLanguageModel
from CPRD.src.models.gpt_simple.task_heads import GPTModelForCausalLM

# TODO:
# replace boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for debug errors

cuda


## Build configurations

In [2]:
# Set GPT config to be equivalent
@dataclass
class DemoConfig:
    block_size: int = 256             # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    pos_encoding: str = None          # Manually adding later
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    tabular = False

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 10
    
opt = OptConfig()

## Demonstrate on a reduced cohort

In [3]:
from CPRD.data.database import queries

PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()

# Get a list of patients which fit a reduced set of criterion
identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

In [4]:
# Lets take only the first N for faster development
N = np.min((len(all_identifiers), 10000))
print(f"Using N={N} random samples, from the available {len(all_identifiers)}")

identifiers = random.choices(all_identifiers, k=N)

Using N=10000 random samples, from the available 117102


## Make dataloader


In [5]:
dm = FoundationalDataModule(identifiers=identifiers,
                            tabular=config.tabular,
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold)
vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building DL-friendly representation
INFO:root:Dropping samples with no dynamic events


8615 training samples
479 validation samples
479 test samples
100 vocab elements


#### Visualise a sample

In [6]:
# print(dm.train_set[0])
print("A single element of the dataset contains:\n  * " + '\n  * '.join(dm.train_set[0].keys()))

for k, v in dm.train_set[0].items():
    print(f"\n{k}: {v}")
    if k == "tokens":
        print(f"... decoding to `{dm.decode(v.tolist())}`")


A single element of the dataset contains:
  * identifier
  * in_tokens
  * in_ages
  * in_values
  * target_tokens
  * target_ages
  * target_values

identifier: p20389_944530620389

in_tokens: tensor([18, 14,  4,  3, 12,  4, 13,  9,  2, 12,  2, 13,  8, 11, 12,  2, 13,  7,
         9, 12,  2, 26, 13,  8,  9, 12,  2, 13,  9,  9, 12])

in_ages: tensor([10046, 11609, 11609, 11609, 11609, 11609, 11609, 11609, 11609, 11609,
        11609, 11738, 11738, 11738, 11738, 11738, 11748, 11748, 11748, 11748,
        11748, 11826, 12161, 12161, 12161, 12161, 12161, 12392, 12392, 12392,
        12392])

in_values: tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan])

target_tokens: tensor([14,  4,  3, 12,  4, 13,  9,  2, 12,  2, 13,  8, 11, 12,  2, 13,  7,  9,
        12,  2, 26, 13,  8,  9, 12,  2, 13,  9,  9, 12,  2])

target_ages: tensor([11609, 11609, 11609, 11609, 11609, 11609, 

#### Visualise a batch

## Create models and train

In [7]:
models = []

# Baseline model to test my changes against
#   Note: this benchmark model uses index position along the batch
models.append(GPTLanguageModel(config, vocab_size).to(device))

# My development model
# Handle positional vs. temporal encoding/embedding
# Cases: 
#     index-embedding:       use index position along the batch
#     index-encoding:        use index position along the batch
#     temporal-encoding:     use age along a patient's timeline
pos_encodings = ["index-embedding", "index-encoding", "temporal-encoding"]
for pe in pos_encodings:
    config = DemoConfig()
    config.pos_encoding = pe
    models.append(GPTModelForCausalLM(config, vocab_size).to(device))

m_names = ["kaparthy benchmark"] + pos_encodings

INFO:root:Using Positional Embedding. This module uses the index position of an event within the block of events.
INFO:root:Using Positional Encoding. This module uses the index position of an event within the block of events.
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.


In [8]:
loss_curves_train = [[] for _ in models]
loss_curves_val = [[] for _ in models]

In [9]:
for m_idx, model in enumerate(models):
    model = model.to(device)

    # print the number of parameters in the model
    print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, best_iter = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss = 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            logits, loss = model(batch['tokens'].to(device),
                                 ages=batch['ages'].to(device),
                                 targets=batch['target_tokens'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)
                                 )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss = 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, loss = model(batch['tokens'].to(device), 
                                    ages=batch['ages'].to(device), 
                                    targets=batch['target_tokens'].to(device),
                                    attention_mask=batch['attention_mask'].to(device)   
                                   )
                    val_loss += loss.item()
                val_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}. Val loss {val_loss:.2f}")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                best_iter += 1
                if best_iter > 2:
                    break
            else:
                best_val = val_loss
                best_iter = 0
                
    prompt = ["DEPRESSION"]
    context = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    fut_tokens, fut_ages = model.generate(context, max_new_tokens=30)
    fut_words = dm.decode(fut_tokens[0].tolist())
    print(fut_words)


10.815844 M parameters
Epoch 0:	Train loss 1.19. Val loss 1.12
Epoch 1:	Train loss 0.92. Val loss 0.83
Epoch 2:	Train loss 0.71. Val loss 0.73
Epoch 3:	Train loss 0.66. Val loss 0.69
Epoch 4:	Train loss 0.63. Val loss 0.68
Epoch 5:	Train loss 0.62. Val loss 0.67
Epoch 6:	Train loss 0.61. Val loss 0.66
Epoch 7:	Train loss 0.61. Val loss 0.66
Epoch 8:	Train loss 0.60. Val loss 0.66
Epoch 9:	Train loss 0.60. Val loss 0.65
DEPRESSION diastolic_blood_pressure 7 0 . 0 bmi 2 0 . 5 diastolic_blood_pressure 6 0 . 0 POLYCYSTIC_OVARIAN_SYNDROME_PCOS diastolic_blood_pressure 7 0 . 0 ALLERGICRHINITISCONJ diastolic_blood_pressure 7 0 . 0 eosinophil_count 0 .
10.777444 M parameters
Epoch 0:	Train loss 1.18. Val loss 0.91
Epoch 1:	Train loss 0.74. Val loss 0.73
Epoch 2:	Train loss 0.66. Val loss 0.69
Epoch 3:	Train loss 0.63. Val loss 0.67
Epoch 4:	Train loss 0.61. Val loss 0.66
Epoch 5:	Train loss 0.61. Val loss 0.66
Epoch 6:	Train loss 0.60. Val loss 0.65
Epoch 7:	Train loss 0.60. Val loss 0.65
Epoc



Epoch 9:	Train loss 0.62. Val loss 0.69
DEPRESSION diastolic_blood_pressure 8 blood_urea . 0 basophil_count 0 eosinophil_count 0 PSORIASIS diastolic_blood_pressure 7 POLYCYSTIC_OVARIAN_SYNDROME_PCOS bmi 1 . 0 basophil_count 0 corrected_serum_calcium_level 2 bmi 1 basophil_count 0 diastolic_blood_pressure 8 diastolic_blood_pressure 7 DEPRESSION


In [10]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig("figs/loss_curves.png")

## Prompt testing

Probability of type II diabetes before and after a type I diagnosis

keys: 

    70: 'TYPE1DM'
    31: 'TYPE2DIABETES'

Small context comparison, high bmi and blood pressure vs low for diabetes risk

In [11]:
if config.tabular:
    low_risk_prompt = ["bmi", "diastolic_blood_pressure"]
    high_risk_prompt = ["bmi", "diastolic_blood_pressure"]
    ages_in_years = [19, 20]
else:
    low_risk_prompt = ["bmi", "2", "2", ".", "5", "diastolic_blood_pressure", "7", "9", ".", "0"]
    high_risk_prompt = ["bmi", "3", "7", ".", "5", "diastolic_blood_pressure", "9", "9", ".", "0"]
    ages_in_years = [19, 19, 19, 19, 20, 20, 20, 20, 20]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

In [12]:
prompts, ages, desc = [], [], []

desc.append("Control: Low risk")
prompts.append(low_risk_prompt)
ages.append(ages_in_years)

desc.append("Control: High risk")
prompts.append(high_risk_prompt)
ages.append(ages_in_years)

desc.append("Control: Low risk + depression")
prompts.append(["DEPRESSION"] + low_risk_prompt)
ages.append([17] + ages_in_years)

desc.append("Low risk context: Type 1 diagnosis in prompt")
prompts.append(["TYPE1DM"] + low_risk_prompt)
ages.append([17] + ages_in_years)

desc.append("Low risk context: Type 1I diagnosis in prompt")
prompts.append(["TYPE2DIABETES"] + low_risk_prompt)
ages.append([17] + ages_in_years)

for model_idx in range(len(pos_encodings)+1):
    print(f"\n\nMODEL_IDX {model_idx}\n==================")
    
    for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
        print(f"\n{desc[p_idx]}: \n\t ({','.join(prompt)}): ")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        lgts, _ = models[model_idx](encoded_prompt,
                                    ages=to_days(age))
        probs = torch.nn.functional.softmax(lgts, dim=2)
        print(f"probability of type I diabetes {100*float(probs[0, 0, 70].cpu().detach().numpy()):.4f}%")
        print(f"probability of type II diabetes {100*float(probs[0, 0, 31].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



MODEL_IDX 0

Control: Low risk: 
	 (bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0002%
probability of type II diabetes 0.0011%

Control: High risk: 
	 (bmi,3,7,.,5,diastolic_blood_pressure,9,9,.,0): 
probability of type I diabetes 0.0002%
probability of type II diabetes 0.0011%

Control: Low risk + depression: 
	 (DEPRESSION,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0418%
probability of type II diabetes 1.8673%

Low risk context: Type 1 diagnosis in prompt: 
	 (TYPE1DM,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0550%
probability of type II diabetes 1.4223%

Low risk context: Type 1I diagnosis in prompt: 
	 (TYPE2DIABETES,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0864%
probability of type II diabetes 0.3458%


MODEL_IDX 1

Control: Low risk: 
	 (bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
probability of type I diabetes 0.0005%
prob