# Demo Notebook:
## Transformer For Causal Language Modelling 

In [1]:
import pytorch_lightning 
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.benchmarks.karpathy_gpt.transformer import GPTLanguageModel
from CPRD.src.models.transformer.task_heads.causal_lm import TransformerForCausalLM

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for if i need more informative debugging statements
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark,
#      note: there will be fewer paramaters due to weight tying in the causal language modelling head
@dataclass
class DemoConfig:
    learn_positional_embedding: bool = True
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 20
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [3]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set

# Lets take only the first N for faster run-time
N = np.min((len(all_identifiers), 10000))
print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
identifiers = random.choices(all_identifiers, k=N)

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="non-tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=True,
                            include_diagnoses=True,
                            preprocess_measurements=False)
vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training, {len(dm.val_set)} validation, and {len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building polars dataset


Using N=10000 random samples, from the available 117102


INFO:root:Using measurements
INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using non-tabular tokenizer


8647 training, 481 validation, and 480 test samples
100 vocab elements


In [4]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (88, 3)
┌───────────────────────────────────┬────────┬──────────┐
│ EVENT                             ┆ counts ┆ freq     │
│ ---                               ┆ ---    ┆ ---      │
│ str                               ┆ u32    ┆ f64      │
╞═══════════════════════════════════╪════════╪══════════╡
│ UNK                               ┆ 0      ┆ 0.0      │
│ diastolic_blood_pressure          ┆ 181777 ┆ 0.422629 │
│ bmi                               ┆ 70392  ┆ 0.16366  │
│ eosinophil_count                  ┆ 66905  ┆ 0.155553 │
│ basophil_count                    ┆ 44291  ┆ 0.102976 │
│ corrected_serum_calcium_level     ┆ 12459  ┆ 0.028967 │
│ DEPRESSION                        ┆ 6910   ┆ 0.016066 │
│ serum_level                       ┆ 5513   ┆ 0.012818 │
│ calculated_LDL_cholesterol_level  ┆ 5249   ┆ 0.012204 │
│ ANXIETY                           ┆ 3812   ┆ 0.008863 │
│ HYPERTENSION                      ┆ 2697   ┆ 0.00627  │
│ TYPE2DIABETES                     ┆ 2247   ┆ 0.005224 │

## Create models and train

In [5]:
models, m_names = [], []

# Baseline model to test my changes against
# models.append(GPTLanguageModel(config, vocab_size).to(device))
# m_names.append("kaparthy benchmark")

# My development model
for pe in [True, False]:
    config = DemoConfig()
    config.learn_positional_embedding = pe
    models.append(TransformerForCausalLM(config, vocab_size).to(device))
    m_names.append(f"{'pos_embedding' if pe else 'pos_encoding'}")

loss_curves_train = [[] for _ in models]
loss_curves_val = [[] for _ in models]

INFO:root:Using Positional Embedding. This module uses the index position of an event within the block of events.
INFO:root:Using Positional Encoding. This module uses the index position of an event within the block of events.


In [6]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"\nTraining model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters\n")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss = 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            _, loss = model(batch['tokens'].to(device),
                            attention_mask=batch['attention_mask'].to(device)
                            )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss = 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, loss = model(batch['tokens'].to(device), 
                                    attention_mask=batch['attention_mask'].to(device)   
                                   )
                    val_loss += loss.item()
                val_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}. Val loss {val_loss:.2f}")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 2:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: an initial diagnosis of depression
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    # generate: then sample the next 10 tokens
    new_tokens = model.generate(tokens, max_new_tokens=30)[0].tolist()
    generated = dm.decode(new_tokens)
    print(f"\t {generated}")


Training model `pos_embedding`, with 10.777344 M parameters

Epoch 0:	Train loss 1.81. Val loss 1.56
Epoch 1:	Train loss 1.20. Val loss 1.31
Epoch 2:	Train loss 1.07. Val loss 1.20
Epoch 3:	Train loss 1.04. Val loss 1.18
Epoch 4:	Train loss 1.02. Val loss 1.17
Epoch 5:	Train loss 1.01. Val loss 1.15
Epoch 6:	Train loss 1.01. Val loss 1.16
Epoch 7:	Train loss 1.00. Val loss 1.15
Epoch 8:	Train loss 1.00. Val loss 1.15
Epoch 9:	Train loss 1.00. Val loss 1.15
Epoch 10:	Train loss 0.99. Val loss 1.15
Epoch 11:	Train loss 0.99. Val loss 1.14
Epoch 12:	Train loss 0.99. Val loss 1.15
Epoch 13:	Train loss 0.98. Val loss 1.15
	 DEPRESSION ANXIETY diastolic_blood_pressure 7 9 . 0 basophil_count 0 . 1 eosinophil_count 0 . 3 basophil_count 0 . 1 eosinophil_count 0 . 3 diastolic_blood_pressure 7 0 . 0 eosinophil_count 0 .

Training model `pos_encoding`, with 10.67904 M parameters

Epoch 0:	Train loss 2.55. Val loss 2.47
Epoch 1:	Train loss 2.10. Val loss 2.27
Epoch 2:	Train loss 1.86. Val loss 1.9

In [7]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.grid()
plt.savefig(f"figs/transformer/loss.png")

# Prompt testing

## Diabetes

Probability of type II diabetes before and after a type I diagnosis

keys: 

    70: 'TYPE1DM'
    31: 'TYPE2DIABETES'

In [8]:
target_token1 = dm.tokenizer._stoi["TYPE1DM"]
target_token2 = dm.tokenizer._stoi["TYPE2DIABETES"]

Small context comparison, high bmi and blood pressure vs low for diabetes risk

In [9]:
low_risk_prompt = ["bmi", "2", "2", ".", "5", "diastolic_blood_pressure", "7", "9", ".", "0"]
high_risk_prompt = ["bmi", "3", "7", ".", "5", "diastolic_blood_pressure", "9", "9", ".", "0"]

In [10]:
prompts, desc = [], []

desc.append("Control: Low risk")
prompts.append(low_risk_prompt)

desc.append("Control: High risk")
prompts.append(high_risk_prompt)

desc.append("Control: Low risk + depression")
prompts.append(["DEPRESSION"] + low_risk_prompt)

desc.append("Low risk context: Type 1 diagnosis in prompt")
prompts.append(["TYPE1DM"] + low_risk_prompt)

desc.append("Low risk context: Type 1I diagnosis in prompt")
prompts.append(["TYPE2DIABETES"] + low_risk_prompt)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    
    for p_idx, prompt in enumerate(prompts):
        print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        lgts, _ = model(encoded_prompt)
        print(lgts.shape)
        probs = torch.nn.functional.softmax(lgts, dim=1)
        print(f"probability of type I diabetes {100*float(probs[0, target_token1].cpu().detach().numpy()):.4f}%")
        print(f"probability of type II diabetes {100*float(probs[0, target_token2].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



pos_embedding

Control: Low risk: 	 (bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
torch.Size([9, 100])
probability of type I diabetes 0.0012%
probability of type II diabetes 0.0011%

Control: High risk: 	 (bmi,3,7,.,5,diastolic_blood_pressure,9,9,.,0): 
torch.Size([9, 100])
probability of type I diabetes 0.0012%
probability of type II diabetes 0.0011%

Control: Low risk + depression: 	 (DEPRESSION,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
torch.Size([10, 100])
probability of type I diabetes 0.0561%
probability of type II diabetes 0.1972%

Low risk context: Type 1 diagnosis in prompt: 	 (TYPE1DM,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
torch.Size([10, 100])
probability of type I diabetes 0.4271%
probability of type II diabetes 0.7650%

Low risk context: Type 1I diagnosis in prompt: 	 (TYPE2DIABETES,bmi,2,2,.,5,diastolic_blood_pressure,7,9,.,0): 
torch.Size([10, 100])
probability of type I diabetes 0.3746%
probability of type II diabetes 0.3872%


pos_encoding

Control:

# Appendix: model architectures

In [11]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



pos_embedding


TransformerForCausalLM(
  (transformer): Transformer(
    (wpe): PositionalEmbedding(
      (wpe): Embedding(256, 384)
    )
    (wte): Embedding(100, 384, padding_idx=0)
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=True)
          (acti): ReLU()
      