In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

import umap
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


In [4]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=[# Experiment setup
                             "experiment.type='zeroshot'",
                             "experiment.run_id='CR_11M_Natalia'",
                              "experiment.fine_tune_id='CVD-0shot'",
                             "experiment.train=False",
                             "experiment.test=False",
                             'experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]',
                             # Dataloader
                             "data.path_to_ds=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/",
                             "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                             "data.min_workers=12",
                             # "data.batch_size=512",
                             # Optimiser
                             "optim.limit_test_batches=null",
                             # Single-Risk specific
                            ]
                 )     

cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/"


# print(OmegaConf.to_yaml(cfg))
experiment, dm = run(cfg)
print(f"Loaded model with {sum(p.numel() for p in experiment.model.parameters())/1e6} M parameters")

wandb.finish()


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_val.pickle
I

Loaded model with 11.433294 M parameters


In [8]:
dm.val_set[0]

{'static_covariates': tensor([0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.5081]),
 'tokens': tensor([247, 261, 219, 251, 262, 247, 219, 251, 251, 260, 261, 262, 189, 231,
         235, 135, 244, 185, 238, 232, 240, 237, 239, 243, 243, 231, 235, 135,
         244, 238, 232, 240, 237, 239, 243, 231, 247, 261, 235, 101, 135, 244,
         238, 232, 240, 237, 239, 251, 243, 248, 245, 246, 227, 262, 127, 214,
         248, 180, 245, 246, 227, 261, 262, 231, 235, 226, 244, 238, 220, 232,
         240, 237, 239, 243, 236, 241, 261, 262, 231, 261, 235, 193, 226, 244,
         238, 220, 232, 240, 237, 239, 243, 236, 214, 180, 262, 241, 198, 213,
         260, 260, 255, 247, 251, 189,  97, 160, 254, 254, 126, 205, 160, 218,
         255, 255, 255, 205, 260, 160, 218, 260, 255, 205, 160, 218, 260, 160,
         255, 205, 260, 160, 218, 255, 205, 160, 218, 217, 260, 205, 160, 218,
         261, 262, 160, 231, 23

In [51]:
token_labels = []
tokens = []

for token, event in enumerate(list(dm.train_set.tokenizer._stoi.keys())):

    # only look at diagnoses
    if event.upper() == event:
        tokens.append(token)
        token_labels.append(event)
        
display(token_labels[:10])

tokens = torch.tensor(tokens).to(device)
print(tokens)

['PAD',
 'UNK',
 'ADDISONS_DISEASE',
 'CYSTICFIBROSIS',
 'SYSTEMIC_SCLEROSIS',
 'SICKLE_CELL_DISEASE_V2',
 'ADDISON_DISEASE',
 'DOWNSSYNDROME',
 'HAEMOCHROMATOSIS_V2',
 'PLASMACELL_NEOPLASM_V2']

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  16,  18,  19,  20,  21,  22,  23,  26,  28,  29,  30,  31,  32,
         33,  34,  35,  36,  37,  38,  39,  40,  41,  43,  44,  45,  46,  48,
         49,  52,  56,  57,  58,  60,  62,  65,  67,  70,  74,  75,  78,  79,
         81,  82,  83,  89,  91,  93,  94,  95,  97, 100, 104, 106, 114, 120,
        123, 126, 129, 133, 134, 136], device='cuda:0')


In [52]:
# tokens = torch.arange(265).to(device)[:50]


tkn_embedding = experiment.model.transformer.wte.dynamic_embedding_layer(tokens)
scaled_tkn_embedding = StandardScaler().fit_transform(tkn_embedding.detach().cpu())


In [53]:

plt.close()

reducer = umap.UMAP()

embedding = reducer.fit_transform(scaled_tkn_embedding)
print(embedding.shape)

plt.figure()
plt.scatter(embedding[:,0], embedding[:, 1], label=token_labels)
for i, txt in enumerate(token_labels):
    plt.annotate(txt[:10].lower(), (embedding[i,0], embedding[i,1]))

plt.savefig("token_embedding.png")


(76, 2)
