# Demo Notebook:
## Transformer For Causal Language Modelling 

In [2]:
import os
from pathlib import Path
import sys

# Make the node-specific virtual environment importable. A separate venv is kept per node type,
# which is read from $BB_CPU (e.g. 'icelake').
# The project root can be overridden with $CPRD_PROJECT_ROOT (defaults to the original location).
project_root = Path(os.getenv("CPRD_PROJECT_ROOT", "/rds/homes/g/gaddcz/Projects/CPRD"))
node_type = os.getenv('BB_CPU')
venv_dir = f'{project_root}/my-virtual-env-{node_type}'
py_version = f'python{sys.version_info.major}.{sys.version_info.minor}'
venv_site_pkgs = Path(venv_dir) / 'lib' / py_version / 'site-packages'
if venv_site_pkgs.exists():
    # Prepend so packages from the node-specific venv take precedence over other installs
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Loaded environment from {venv_site_pkgs} for node-type '{node_type}'.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")


Path '/rds/homes/g/gaddcz/Projects/CPRD/my-virtual-env-icelake/lib/python3.10/site-packages' not found. Check that it exists and/or that it exists for node-type 'icelake'.


In [3]:
import pytorch_lightning 
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.benchmarks.karpathy_gpt.transformer import GPTLanguageModel
from CPRD.src.models.transformer.task_heads.causal_lm import TransformerForCausalLM
from tqdm import tqdm

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # just for if i need more informative debugging statements
!pwd
%load_ext autoreload
%autoreload 2

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/NLP


## Build configurations

In [4]:
@dataclass
class DemoConfig:
    """Model hyper-parameters chosen to mirror the architecture of Karpathy's GPT benchmark.

    Note: the parameter count is lower than the benchmark's because the causal
    language-modelling head ties its weights.
    """
    # True -> learned positional embedding, False -> positional encoding
    learn_positional_embedding: bool = True
    # Maximum context length for predictions (also used as the data module's max_seq_length)
    block_size: int = 256
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    # Forwarded to the data module as its unk_freq_threshold
    unk_freq_threshold: float = 0.0


config = DemoConfig()


@dataclass
class OptConfig:
    """Optimisation / training-loop settings."""
    batch_size: int = 16         # data-loader batch size
    eval_interval: int = 1       # run validation every `eval_interval` epochs
    learning_rate: float = 3e-4  # AdamW learning rate
    epochs: int = 20             # maximum number of training epochs


opt = OptConfig()

## Create data loader on a reduced cohort

In [5]:
# Location of the OPTIMAL master dataset (Version 2). With load=True the data module
# loads an existing dataset from this location rather than building one.
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"

# Cohort-filter options (include_measurements, drop_missing_data, include_diagnoses,
# drop_empty_dynamic) are left at the data module's defaults.
dm = FoundationalDataModule(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=4,
)

vocab_size = dm.train_set.tokenizer.vocab_size

# Report the size of each split and of the vocabulary
for split_label, split_set in (("training", dm.train_set), ("validation", dm.val_set), ("test", dm.test_set)):
    print(f"{len(split_set)} {split_label} patients")
print(f"{vocab_size} vocab elements")


INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 74.82M tokens
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 27it [00:05,  4.64it/s]
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 2it [00:00,  8.83it/s]
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 2it [00:00, 12.77it/s]

466364 training patients
17841 validation patients
22034 test patients
184 vocab elements





In [1]:
import time

# Measure how long the training set takes to produce each of its first few samples.
# perf_counter is monotonic and high-resolution, unlike time.time(), so it is the right clock for durations.
N_TIMED_SAMPLES = 5

start = time.perf_counter()
for row_idx, row in enumerate(dm.train_set):
    print(time.perf_counter() - start)
    # Restart the clock after printing so the next measurement covers only the fetch
    start = time.perf_counter()
    if row_idx + 1 >= N_TIMED_SAMPLES:
        break
print(f"{row}")

NameError: name 'dm' is not defined

In [None]:
# display(dm.train_set.tokenizer._itos)

In [None]:
# display(dm.train_set.tokenizer._stoi)

In [9]:
import polars as pl

# Raise Polars' printed-row limit to vocab_size + 1 so the tokenizer's event-count table
# is displayed in full rather than as a truncated preview.
max_rows = vocab_size + 1
pl.Config.set_tbl_rows(max_rows)
event_counts = dm.tokenizer._event_counts
display(event_counts)


EVENT,COUNT,FREQUENCY
str,u32,f64
"""UNK""",0,0.0
"""Plasma_N-termi…",23,3.0742e-07
"""SICKLE_CELL_DI…",123,2e-06
"""CYSTICFIBROSIS…",127,2e-06
"""SYSTEMIC_SCLER…",199,3e-06
"""ADDISON_DISEAS…",239,3e-06
"""DOWNSSYNDROME""",361,5e-06
"""PLASMACELL_NEO…",399,5e-06
"""HAEMOCHROMATOS…",515,7e-06
"""SJOGRENSSYNDRO…",530,7e-06


In [10]:

# Peek at one training batch: print each token id of its first sequence next to the decoded
# word, stopping after the first id 0 (presumably padding — TODO confirm against the tokenizer).
first_batch = next(iter(dm.train_dataloader()), None)
if first_batch is not None:
    batch = first_batch
    tokens = batch["tokens"][0].tolist()
    sentence = dm.decode(tokens).split(" ")
    for token_id, word in zip(tokens, sentence):
        print(str(token_id).ljust(20) + word)
        if token_id == 0:
            break

183                 Diastolic_blood_pressure_5
174                 Body_mass_index_3
182                 Systolic_blood_pressure_4
181                 O_E_-_weight_2
183                 Diastolic_blood_pressure_5
174                 Body_mass_index_3
182                 Systolic_blood_pressure_4
181                 O_E_-_weight_2
183                 Diastolic_blood_pressure_5
174                 Body_mass_index_3
182                 Systolic_blood_pressure_4
181                 O_E_-_weight_2
183                 Diastolic_blood_pressure_5
174                 Body_mass_index_3
182                 Systolic_blood_pressure_4
183                 Diastolic_blood_pressure_5
182                 Systolic_blood_pressure_4
101                 OSTEOARTHRITIS
87                  Serum_T4_level_78
108                 ASTHMA_PUSHASTHMA
160                 O_E_-_height_1
181                 O_E_-_weight_2
183                 Diastolic_blood_pressure_5
174                 Body_mass_index_3
182         

## Create models and train

In [11]:
# Build one development model per positional scheme so the two can be compared:
# learn_positional_embedding=True (positional embedding) vs False (positional encoding).
models, m_names = [], []

# Baseline model to test my changes against
# models.append(GPTLanguageModel(config, vocab_size).to(device))
# m_names.append("kaparthy benchmark")

# My development model
for use_learned_pe in (True, False):
    # A fresh per-model config, so the notebook-level `config` (used to build the data
    # module) is not silently overwritten with the last loop iteration's settings.
    model_config = DemoConfig(learn_positional_embedding=use_learned_pe)
    models.append(TransformerForCausalLM(model_config, vocab_size).to(device))
    m_names.append("pos_embedding" if use_learned_pe else "pos_encoding")

# One loss history per model, filled in by the training cell
loss_curves_train = [[] for _ in models]
loss_curves_val = [[] for _ in models]

INFO:root:Using Positional Embedding. This module uses the index position of an event within the block of events.
INFO:root:Using Positional Encoding. This module uses the index position of an event within the block of events.


In [12]:
import itertools

# Quick-demo caps: at most this many batches are used per epoch (None = the whole loader)
MAX_TRAIN_BATCHES = 51
MAX_VAL_BATCHES = 21
# Stop training after this many consecutive evaluations without a new best validation loss
PATIENCE = 2


def _mean_loss(model, loader, desc, max_batches=None, optimizer=None):
    """Return the mean per-batch loss of `model` over (at most `max_batches` of) `loader`.

    If `optimizer` is given, each batch also takes a gradient step (training); otherwise
    gradients are disabled (evaluation). The caller is responsible for train()/eval() mode.
    Note: this is a plain mean over batches, so a smaller final batch is over-weighted.

    Raises:
        ValueError: if the loader yields no batches.
    """
    n_expected = len(loader) if max_batches is None else min(len(loader), max_batches)
    total_loss, n_batches = 0.0, 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in tqdm(itertools.islice(loader, max_batches), desc=desc, total=n_expected):
            _, loss = model(batch['tokens'].to(device),
                            attention_mask=batch['attention_mask'].to(device)
                            )
            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    # Divide by the number of batches actually seen (not the last index) to avoid off-by-one / div-by-zero
    if n_batches == 0:
        raise ValueError(f"{desc}: the data loader yielded no batches")
    return total_loss / n_batches


for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"\nTraining model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters\n")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        model.train()
        epoch_loss = _mean_loss(model, dm.train_dataloader(), f"Training epoch {epoch}",
                                max_batches=MAX_TRAIN_BATCHES, optimizer=optimizer)
        loss_curves_train[m_idx].append(epoch_loss)

        # Validate (and consider early stopping) only on evaluation epochs, so a stale
        # validation loss from an earlier epoch is never counted again.
        if epoch % opt.eval_interval != 0 and epoch != opt.epochs - 1:
            continue
        model.eval()
        val_loss = _mean_loss(model, dm.val_dataloader(), f"Validation epoch {epoch}",
                              max_batches=MAX_VAL_BATCHES)
        loss_curves_val[m_idx].append(val_loss)
        print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}. Val loss {val_loss:.2f}")

        if val_loss < best_val:
            best_val, epochs_since_best = val_loss, 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= PATIENCE:
                break

    # Test trained model with a prompt
    # ----------------
    # set context: an initial diagnosis of depression
    model.eval()
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1, -1))).to(device)
    # generate: then sample the next 30 tokens
    with torch.no_grad():
        new_tokens = model.generate(tokens, max_new_tokens=30)[0].tolist()
    generated = dm.decode(new_tokens)
    print(f"\t {generated}")


Training model `pos_embedding`, with 10.8096 M parameters



Training epoch 0:   0%|          | 32/29148 [00:23<5:51:35,  1.38it/s]

KeyboardInterrupt



In [None]:
# Evaluate every trained model on the full held-out test set
for model, m_name in zip(models, m_names):

    print(f"\nTesting model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters\n")
    model = model.to(device)
    model.eval()

    test_loader = dm.test_dataloader()
    total_loss, n_batches = 0.0, 0
    # no_grad: summing loss tensors with autograd enabled would keep every batch's graph alive
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing", total=len(test_loader)):
            # evaluate the loss
            _, loss = model(batch['tokens'].to(device),
                            attention_mask=batch['attention_mask'].to(device)
                            )
            total_loss += loss.item()
            n_batches += 1
    # Mean of per-batch losses (a smaller final batch is weighted like the others); NaN if no batches
    print(total_loss / n_batches if n_batches else float("nan"))

In [None]:
# Ad-hoc inspection: shows the token ids of whichever batch a previously executed cell
# last assigned to `batch` (so the result depends on execution order).
batch['tokens']

In [None]:
from pathlib import Path

# Plot each model's training loss (dashed) and validation loss (solid) against epoch
cols = ["k", "r", "b", "y"]
fig_path = Path("figs/transformer/loss.png")

fig, ax = plt.subplots()
for m_idx, m_name in enumerate(m_names):
    # Cycle through the colours rather than raising IndexError for more than len(cols) models
    colour = cols[m_idx % len(cols)]
    # Training: one point per completed epoch
    train_curve = loss_curves_train[m_idx]
    ax.plot(np.arange(len(train_curve)), train_curve, label=f"{m_name}-train", c=colour, linestyle='dashed')
    # Validation: one point every `eval_interval` epochs (the forced final-epoch evaluation
    # is placed at the next multiple, an approximation when epochs - 1 is not a multiple)
    val_curve = loss_curves_val[m_idx]
    ax.plot(np.arange(len(val_curve)) * opt.eval_interval, val_curve, label=f"{m_name}-val", c=colour)
ax.set(xlabel="Epoch", ylabel="Loss", title="Training and validation loss")
ax.legend()
ax.grid()
# Create the output directory so savefig does not fail on a fresh checkout
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path)

# Prompt testing

## Diabetes

Model-predicted probability of type I and type II diabetes, with and without a prior diagnosis (type I, type II, or depression) placed at the start of the prompt.

keys: 

    70: 'TYPE1DM'
    31: 'TYPE2DIABETES'

In [None]:
# Vocabulary ids of the two diagnoses whose predicted probabilities are compared below
target_token1, target_token2 = (dm.tokenizer._stoi[event] for event in ("TYPE1DM", "TYPE2DIABETES"))

for target_id in (target_token1, target_token2):
    print(target_id)

Small-context comparison of diabetes risk: high vs. low BMI and diastolic blood pressure.

In [None]:
# Two short contexts that differ only in their measurement values: BMI 22.5 / diastolic BP 79.0
# (low risk) vs BMI 37.5 / diastolic BP 99.0 (high risk). Each value is written one character
# per token after its measurement's event token.
low_risk_prompt = ["Body_mass_index_3", "2", "2", ".", "5", "Diastolic_blood_pressure_5", "7", "9", ".", "0"]
high_risk_prompt = ["Body_mass_index_3", "3", "7", ".", "5", "Diastolic_blood_pressure_5", "9", "9", ".", "0"]

for risk_prompt in (low_risk_prompt, high_risk_prompt):
    print(dm.encode(risk_prompt))

In [None]:
# Prompt cases: controls, then the low-risk context preceded by a single diagnosis event
prompts, desc = [], []

desc.append("Control: Low risk")
prompts.append(low_risk_prompt)

desc.append("Control: High risk")
prompts.append(high_risk_prompt)

desc.append("Control: depression + Low risk")
prompts.append(["DEPRESSION"] + low_risk_prompt)

desc.append("Low risk context: Type 1 diagnosis in prompt")
prompts.append(["TYPE1DM"] + low_risk_prompt)

desc.append("Low risk context: Type II diagnosis in prompt")
prompts.append(["TYPE2DIABETES"] + low_risk_prompt)


def _next_token_probs(logits):
    """Softmax over the vocabulary of the logits at the final prompt position.

    Accepts (batch, vocab) or (batch, seq, vocab) logits — TODO confirm which shape
    TransformerForCausalLM returns when called without targets. Indexing a 3-D tensor as
    probs[0, token] would otherwise select a sequence position, not a vocabulary entry.
    """
    last_logits = logits.reshape(logits.shape[0], -1, logits.shape[-1])[:, -1, :]
    return torch.softmax(last_logits, dim=-1)


for model, m_name in zip(models, m_names):
    print(f"\n\n{m_name}\n" + "=" * len(m_name))
    model.eval()

    for prompt_desc, prompt in zip(desc, prompts):
        print(f"\n{prompt_desc}: \t ({','.join(prompt)}): ")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)
        with torch.no_grad():
            lgts, _ = model(encoded_prompt)
        probs = _next_token_probs(lgts)
        print(f"probability of type I diabetes {100 * probs[0, target_token1].item():.4f}%")
        print(f"probability of type II diabetes {100 * probs[0, target_token2].item():.4f}%")

# Note: adding a diagnosis (even a potentially unrelated one) at the beginning of the prompt increases the predicted probability of both diabetes types

# Appendix: model architectures

In [None]:
# Print each model's module tree beneath its name, underlined with '='
for m_name, model in zip(m_names, models):
    underline = "=" * len(m_name)
    print(f"\n\n{m_name}\n{underline}")
    print(f"\n\n{model}")

In [None]:
# Export this notebook to HTML with code cells hidden (--no-input), keeping prose and outputs.
# NOTE(review): the notebook path is relative, so this assumes the kernel's working directory
# is the folder containing Transformer.ipynb — confirm.
!jupyter nbconvert --to html --no-input Transformer.ipynb