# BEHRT for survival

```This is a modified version of the NextVisit-6month.ipynb script```

This notebook has been changed to consider a survival downstream task. Where possible I have kept the original code structure.

We benchmark different datasets, considering two experimental setups for each:
- **Scratch**.
      This is the Scratch Fine-tuning (SFT) set-up presented in SurvivEHR
- **Pre-trained**.
      This differs from the Full Fine-tuning (FFT) set-up presented in SurvivEHR.
      Here we pre-train with the MLM pre-training objective, but only on the samples in the given fine-tuning dataset.
      This is because the BEHRT code does not scale (in terms of either compute or memory) to the number of pre-training samples used in SurvivEHR.

The CPRD data (hypertension and cardiovascular disease outcome datasets) used in this notebook are not shared. We instead share an example simulated dataset.

In [None]:
# Which dataset configuration to run. Must be one of the names handled in the
# "File Parameters" cell: "local_example", "fastehr_example", "hypertension" or "cvd".
dataset = "cvd"
# True: initialise from the MLM pre-trained checkpoint named in file_config.
# False: the checkpoint entry is set to None and the model is trained from scratch.
from_pretrained = True

# Notebook results:

### Cardiovascular disease

**Scratch**

```
Test results after: 7 epochs	
| Loss: 0.5792801863022102	
| Ctd: 0.5975980840896481	
| IBS: 0.03388068110585833	
| INBLL: 0.14627500543686311	
 
```

**Semi pre-trained**

```
Test results after: 5 epochs	
| Loss: 0.580362102957245	
| Ctd: 0.609454119345055	
| IBS: 0.03381188359111024	
| INBLL: 0.14602788332436595	
```

In [None]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/BEHRT-with-FastEHR/my-virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

In [None]:
import sys 
sys.path.insert(0, '../')

from common.common import create_folder,load_obj
import pandas as pd
import numpy as np
import os
import sklearn.metrics as skm
import math
import random
import time
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import pytorch_pretrained_bert as Bert
from pycox.evaluation import EvalSurv

from model.utils import age_vocab
from model import optimiser
from dataLoader.Survival import Survival
from dataLoader.utils import seq_padding,code2index, position_idx, index_seg
from DeSurv.src.classes import ODESurvSingle, ODESurvMultiple


# File Parameters

In [None]:
# Select the file configuration for the chosen dataset.
#
# Keys shared by every configuration:
#   vocab          token vocabulary (idx2token, token2idx)
#   train/test/val formatted data splits (None when the split is not available)
#   model_path     directory where models are saved
#   pretrainModel  file name of the MLM pre-trained checkpoint
#   model_name     file name for the fine-tuned model
#   file_name      log path
#   freeze_backbone  whether the BEHRT encoder is frozen (True) or trained (False)
# Survival-specific keys:
#   event_code       vocabulary tokens treated as outcome events
#   competing_risks  whether each event code is modelled as a separate cause

# Common root of the CPRD fine-tuning datasets.
cprd_root = '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel'

if dataset == "local_example":
    file_config = {
        'vocab': 'data/local_example/token2idx',
        'train': 'data/local_example/data_train.parquet',
        'test': 'data/local_example/data_test.parquet',
        'val': 'data/local_example/data_val.parquet',
        'model_path': f'data/{dataset}/',
        'pretrainModel': 'local_MLM-notebook.ckpt',
        'model_name': 'local_MLMSurv-notebook.ckpt',
        'file_name': 'local_MLMSurv-notebook.out',
        'event_code': ["DEATH"],
        "competing_risks": False,
        'freeze_backbone': False,
    }
elif dataset == "fastehr_example":
    # NOTE(review): this entry defines neither 'event_code' nor 'competing_risks',
    # both of which are read later in the notebook (label construction, model set-up).
    fastehr_dir = '/rds/homes/g/gaddcz/Projects/FastEHR/examples/data/_built/adapted/BEHRT/T2D_hypertension'
    file_config = {
        'vocab': f'{fastehr_dir}/token2idx',
        'train': f'{fastehr_dir}/dataset.parquet',
        'test': None,
        'val': None,
        'model_path': f'data/{dataset}/',
        'pretrainModel': 'fastehr_MLM-notebook.ckpt',
        'model_name': 'fastehr_MLMSurv-notebook.ckpt',
        'file_name': 'fastehr_MLMSurv-notebook.out',
        'freeze_backbone': False,
    }
elif dataset == "hypertension":
    split_dir = f'{cprd_root}/FineTune_Hypertension/BEHRT'
    file_config = {
        'vocab': f'{split_dir}/token2idx',
        'train': f'{split_dir}/train_dataset.parquet',
        'test': f'{split_dir}/test_dataset.parquet',
        'val': f'{split_dir}/val_dataset.parquet',
        'model_path': f'data/{dataset}/',
        'pretrainModel': 'hypertension_MLM-notebook.ckpt',
        'model_name': 'hypertension_MLMSurv-notebook.ckpt',
        'file_name': 'hypertension_MLMSurv-notebook.out',
        "event_code": ["HYPERTENSION"],
        "competing_risks": False,
        'freeze_backbone': False,
    }
elif dataset == "cvd":
    split_dir = f'{cprd_root}/FineTune_CVD/BEHRT'
    file_config = {
        'vocab': f'{split_dir}/token2idx',
        'train': f'{split_dir}/train_dataset.parquet',
        'test': f'{split_dir}/test_dataset.parquet',
        'val': f'{split_dir}/val_dataset.parquet',
        'model_path': f'data/{dataset}/',
        'pretrainModel': 'cvd_MLM-notebook.ckpt',
        'model_name': 'cvd_MLMSurv-notebook.ckpt',
        'file_name': 'cvd_MLMSurv-notebook.out',
        "event_code": ["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"],
        "competing_risks": True,
        'freeze_backbone': False,
    }
else:
    raise NotImplementedError

# Training from scratch: drop the checkpoint so that no weights are loaded later.
if from_pretrained is False:
    file_config["pretrainModel"] = None

print(dataset)

In [None]:
# Optimiser hyper-parameters (also read by BertForSurvival for the DeSurv head's lr).
optim_config = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}


# Run-wide settings shared by data loading, model construction and training.
global_params = {
    'batch_size': 64,
    'gradient_accumulation_steps': 1,
    # Use the first GPU when one is visible, otherwise fall back to the CPU so the
    # notebook (e.g. the shared example dataset) still runs on CPU-only machines.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    'output_dir': file_config["model_path"],  # output dir
    # NOTE(review): this is the '.out' log file name, not file_config['model_name'];
    # confirm this is the intended name for the saved model.
    'best_name': file_config["file_name"],
    'save_model': True,
    # Changed from the shipped value, which was a bug. This value also sets the size of
    # the position-embedding table (model_config['max_position_embedding']), so a
    # checkpoint trained with a different value cannot be loaded.
    'max_len_seq': 64,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
    # Reduced to one (from 5) to be comparable with the modified pre-training.
    'min_visit': 1
}

pretrainModel = file_config["pretrainModel"]  # MLM pretrained model file name (None when training from scratch)

In [None]:
# Set up the output directory (file_config['model_path']) via the project helper
# imported from common.common.
create_folder(global_params['output_dir'])

In [None]:
# Token vocabulary; later cells index it with the 'token2idx' key.
BertVocab = load_obj(file_config['vocab'])
# Age vocabulary built from the configured maximum age, month granularity and symbol;
# the second return value of age_vocab is not needed here.
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [None]:
# Not needed for new task

# def format_label_vocab(token2idx):
#     token2idx = token2idx.copy()
#     del token2idx['PAD']
#     del token2idx['SEP']
#     del token2idx['CLS']
#     del token2idx['MASK']
#     token = list(token2idx.keys())
#     labelVocab = {}
#     for i,x in enumerate(token):
#         labelVocab[x] = i
#     return labelVocab

# Vocab_diag = format_label_vocab(BertVocab['token2idx'])

In [None]:
# Architecture hyper-parameters consumed by BertConfig.
model_config = dict(
    vocab_size=len(BertVocab['token2idx']),            # number of disease + symbols for word embedding
    hidden_size=288,                                   # word embedding and seg embedding hidden size
    seg_vocab_size=2,                                  # number of vocab for seg embedding
    age_vocab_size=len(ageVocab),                      # number of vocab for age embedding
    max_position_embedding=global_params['max_len_seq'],  # maximum number of tokens
    hidden_dropout_prob=0.2,                           # dropout rate
    num_hidden_layers=6,                               # number of multi-head attention layers required
    num_attention_heads=12,                            # number of attention heads
    attention_probs_dropout_prob=0.22,                 # multi-head attention dropout rate
    intermediate_size=512,                             # the size of the "intermediate" layer in the transformer encoder
    hidden_act='gelu',                                 # encoder/pooler activation: "gelu", 'relu', 'swish' are supported
    initializer_range=0.02,                            # parameter weight initializer range
)

# Which auxiliary embeddings BertEmbeddings adds to the word embedding.
feature_dict = dict.fromkeys(('age', 'seg', 'posi'), True)

# Set Up Model

In [None]:
# Data pre-processing class block

#
#
# Removed as this is handled in FastEHR
#
#

In [None]:
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration built from a plain dict, extended with BEHRT vocab sizes.

    Args:
        config: mapping of hyper-parameters. 'hidden_size' is required (missing key
            raises KeyError); every other entry is read with ``.get`` and is
            therefore passed on as None when absent.
    """

    def __init__(self, config):
        read = config.get
        super().__init__(
            vocab_size_or_config_json_file=read('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=read('num_hidden_layers'),
            num_attention_heads=read('num_attention_heads'),
            intermediate_size=read('intermediate_size'),
            hidden_act=read('hidden_act'),
            hidden_dropout_prob=read('hidden_dropout_prob'),
            attention_probs_dropout_prob=read('attention_probs_dropout_prob'),
            # note the singular key name used by this notebook's model_config
            max_position_embeddings=read('max_position_embedding'),
            initializer_range=read('initializer_range'),
        )
        # Sizes of the two extra embedding tables used by BertEmbeddings.
        self.seg_vocab_size = read('seg_vocab_size')
        self.age_vocab_size = read('age_vocab_size')

class BertEmbeddings(nn.Module):
    """Combine word, age, segment and position embeddings into one input embedding.

    The word embedding is always used; the age, segment and position embeddings are
    added only when the matching flag in ``feature_dict`` ('age', 'seg', 'posi') is
    truthy. The sum is passed through LayerNorm and dropout.
    """

    def __init__(self, config, feature_dict):
        super().__init__()
        hidden = config.hidden_size
        n_positions = config.max_position_embeddings
        self.feature_dict = feature_dict

        self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, hidden)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, hidden)
        # Position table is created from the fixed sinusoidal lookup table below.
        sinusoid_table = self._init_posi_embedding(n_positions, hidden)
        self.posi_embeddings = nn.Embedding(n_positions, hidden).from_pretrained(embeddings=sinusoid_table)

        self.LayerNorm = Bert.modeling.BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):
        """Embed a batch of token ids.

        Any of ``age_ids``, ``seg_ids`` and ``posi_ids`` left as None is replaced by
        zeros shaped like ``word_ids``. The ``age`` argument is accepted but not used.
        """
        seg_ids = torch.zeros_like(word_ids) if seg_ids is None else seg_ids
        age_ids = torch.zeros_like(word_ids) if age_ids is None else age_ids
        posi_ids = torch.zeros_like(word_ids) if posi_ids is None else posi_ids

        # All four lookups are performed, whether or not the feature is enabled.
        total = self.word_embeddings(word_ids)
        seg_part = self.segment_embeddings(seg_ids)
        age_part = self.age_embeddings(age_ids)
        posi_part = self.posi_embeddings(posi_ids)

        # Add the optional parts in a fixed order: age, then segment, then position.
        for flag, part in (('age', age_part), ('seg', seg_part), ('posi', posi_part)):
            if self.feature_dict[flag]:
                total = total + part

        return self.dropout(self.LayerNorm(total))

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Build the (max_position_embedding, hidden_size) float32 sinusoidal table.

        Entry [pos, idx] is sin(angle) for even idx and cos(angle) for odd idx, with
        angle = pos / 10000 ** (2 * idx / hidden_size).
        """
        table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        for pos in range(max_position_embedding):
            for idx in np.arange(hidden_size):
                angle = pos / (10000 ** (2 * idx / hidden_size))
                table[pos, idx] = np.sin(angle) if idx % 2 == 0 else np.cos(angle)

        return torch.tensor(table)


class BertModel(Bert.modeling.BertPreTrainedModel):
    """BEHRT backbone: custom embeddings followed by the stock BERT encoder and pooler."""

    def __init__(self, config, feature_dict):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config, feature_dict)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Encode a batch of sequences.

        Missing ``age_ids``/``seg_ids``/``posi_ids`` default to zeros and a missing
        ``attention_mask`` defaults to ones, all shaped like ``input_ids``.

        Returns:
            (encoded_layers, pooled_output). ``encoded_layers`` is the encoder's full
            output when ``output_all_encoded_layers`` is true, otherwise only its last
            element; ``pooled_output`` is the pooler applied to that last element.
        """
        attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        age_ids = torch.zeros_like(input_ids) if age_ids is None else age_ids
        seg_ids = torch.zeros_like(input_ids) if seg_ids is None else seg_ids
        posi_ids = torch.zeros_like(input_ids) if posi_ids is None else posi_ids

        # Insert two singleton axes after the batch axis, cast to the parameter dtype
        # (fp16 compatibility), then turn 1/0 into an additive mask of 0/-10000.
        param_dtype = next(self.parameters()).dtype
        additive_mask = attention_mask[:, None, None].to(dtype=param_dtype)
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedded = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
        encoded_layers = self.encoder(embedded,
                                      additive_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        last_layer = encoded_layers[-1]
        pooled_output = self.pooler(last_layer)
        if output_all_encoded_layers:
            return encoded_layers, pooled_output
        return last_layer, pooled_output


# Previous task wrapper

# class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
#     def __init__(self, config, num_labels, feature_dict):
#         super(BertForMultiLabelPrediction, self).__init__(config)
#         self.num_labels = num_labels
#         self.bert = BertModel(config, feature_dict)
#         self.dropout = nn.Dropout(config.hidden_dropout_prob)
#         self.classifier = nn.Linear(config.hidden_size, num_labels)
#         self.apply(self.init_bert_weights)

#     def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
#         _, pooled_output = self.bert(input_ids, age_ids ,seg_ids, posi_ids, attention_mask,
#                                      output_all_encoded_layers=False)
#         pooled_output = self.dropout(pooled_output)
#         logits = self.classifier(pooled_output)

#         if labels is not None:
#             loss_fct = nn.MultiLabelSoftMarginLoss()
#             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
#             return loss, logits
#         else:
#             return logits

# New task wrapper

class BertForSurvival(Bert.modeling.BertPreTrainedModel):
    """BEHRT encoder whose pooled output is fed to a DeSurv survival head.

    The head is ``ODESurvMultiple`` when ``competing_risks`` is true and
    ``ODESurvSingle`` otherwise. Its learning rate and device are taken from the
    notebook-level ``optim_config`` and ``global_params`` dictionaries.

    Attributes set here:
        t_eval: 1000 evenly spaced points on [0, 1] used as the prediction time grid.
        device: device attribute reported by the DeSurv network.
    """

    # Number of (patient, time) rows handed to the DeSurv head per predict call;
    # predicting the full grid at once does not fit in memory.
    _PRED_BATCH_SIZE = 512

    def __init__(self,
                 config,
                 feature_dict,
                 competing_risks=False,
                 num_risks=None
                ):
        """Build the encoder and the survival head.

        Args:
            config: BertConfig instance.
            feature_dict: flags selecting the auxiliary embeddings (see BertEmbeddings).
            competing_risks: use the multi-cause head instead of the single-risk head.
            num_risks: number of causes; required when ``competing_risks`` is true.

        Raises:
            ValueError: if ``competing_risks`` is true and ``num_risks`` is None.
        """
        super().__init__(config)
        self.bert = BertModel(config, feature_dict)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.competing_risks = competing_risks
        self.num_risks = num_risks
        self.t_eval = np.linspace(0, 1, 1000)

        # DeSurv is given "cpu" or "gpu" rather than a torch device string.
        desurv_device = "cpu" if global_params["device"] == "cpu" else "gpu"

        if competing_risks:
            if num_risks is None:
                raise ValueError("num_risks must be given when competing_risks is True")
            print(f"Using ODESurvMultiple with causes: {num_risks}")
            self.desurv_model = ODESurvMultiple(
                lr=optim_config['lr'],
                cov_dim=config.hidden_size,
                hidden_dim=32,
                num_risks=num_risks,
                device=desurv_device,
            )
            self.device = self.desurv_model.odenet.device

        else:
            print(f"Using ODESurvSingle for union of given events")
            self.desurv_model = ODESurvSingle(
                lr=optim_config['lr'],
                cov_dim=config.hidden_size,
                hidden_dim=32,
                device=desurv_device,
            )
            self.device = self.desurv_model.net.device

        self.apply(self.init_bert_weights)

    def forward(self,
                target_label, target_time, patient_id,
                input_ids,
                age_ids=None,
                seg_ids=None,
                posi_ids=None,
                attention_mask=None,
               ):
        """Return the DeSurv head's forward output for a batch (used as the loss).

        The pooled encoder output (with dropout) is sorted by ``target_time`` together
        with the times and labels before being passed to the head. ``patient_id`` is
        accepted for interface compatibility but not used.
        """
        pooled_output = self.dropout(
            self._pool(input_ids, age_ids, seg_ids, posi_ids, attention_mask)
        )

        # Forward DeSurv model for survival prediction using pooled_output as input features
        argsort_t = torch.argsort(target_time)
        x_ = pooled_output[argsort_t, :].to(global_params["device"])
        t_ = target_time[argsort_t].to(global_params["device"])
        k_ = target_label[argsort_t].to(global_params["device"])

        return self.desurv_model.forward(x_, t_, k_)

    def predict(self, *args, **kwargs):
        """Dispatch to ``predict_cr`` or ``predict_sr`` depending on the head in use."""
        if self.competing_risks:
            return self.predict_cr(*args, **kwargs)
        return self.predict_sr(*args, **kwargs)

    def predict_cr(self,
                   input_ids,
                   age_ids=None,
                   seg_ids=None,
                   posi_ids=None,
                   attention_mask=None,
                  ):
        """Predict over ``t_eval`` with the competing-risks head.

        Returns:
            (preds, pis): two lists of numpy arrays of shape (batch, len(t_eval)), one
            array per index of the last axis of the head's two prediction outputs.
        """
        # No dropout is applied at prediction time.
        pooled_output = self._pool(input_ids, age_ids, seg_ids, posi_ids, attention_mask)
        H_test, t_test = self._expand_over_time_grid(pooled_output)

        pred, pi = [], []
        for H_batch, t_batch in self._prediction_batches(H_test, t_test):
            _pred, _pi = self.desurv_model.predict(H_batch, t_batch)
            pred.append(_pred)
            pi.append(_pi)

        grid_shape = (pooled_output.shape[0], self.t_eval.size, -1)
        pred = torch.concat(pred).reshape(grid_shape).cpu().detach().numpy()
        pi = torch.concat(pi).reshape(grid_shape).cpu().detach().numpy()
        preds = [pred[:, :, _i] for _i in range(pred.shape[-1])]
        pis = [pi[:, :, _i] for _i in range(pi.shape[-1])]

        return preds, pis

    def predict_sr(self,
                   input_ids,
                   age_ids=None,
                   seg_ids=None,
                   posi_ids=None,
                   attention_mask=None,
                  ):
        """Predict over ``t_eval`` with the single-risk head.

        Returns:
            (pred, None) where ``pred`` is a one-element list holding a numpy array of
            shape (batch, len(t_eval)); the second item mirrors ``predict_cr``.
        """
        # No dropout is applied at prediction time.
        pooled_output = self._pool(input_ids, age_ids, seg_ids, posi_ids, attention_mask)
        H_test, t_test = self._expand_over_time_grid(pooled_output)

        pred = [
            self.desurv_model.predict(H_batch, t_batch)
            for H_batch, t_batch in self._prediction_batches(H_test, t_test)
        ]
        grid_shape = (pooled_output.shape[0], self.t_eval.size)
        pred = [torch.concat(pred).reshape(grid_shape).cpu().detach().numpy()]

        return pred, None

    def _pool(self, input_ids, age_ids, seg_ids, posi_ids, attention_mask):
        """Run the encoder and return only its pooled output."""
        _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                                     output_all_encoded_layers=False)
        return pooled_output

    def _expand_over_time_grid(self, pooled_output):
        """Pair every row of ``pooled_output`` with every point of ``t_eval``.

        Returns (H_test, t_test): each row repeated len(t_eval) times, and the time
        grid concatenated once per row, both float32 on ``self.device``.
        """
        t_test = torch.tensor(np.concatenate([self.t_eval] * pooled_output.shape[0], 0),
                              dtype=torch.float32, device=self.device)
        H_test = pooled_output.repeat_interleave(self.t_eval.size, 0).to(self.device, torch.float32)
        return H_test, t_test

    def _prediction_batches(self, H_test, t_test):
        """Yield aligned (features, times) chunks of at most _PRED_BATCH_SIZE rows."""
        return zip(torch.split(H_test, self._PRED_BATCH_SIZE),
                   torch.split(t_test, self._PRED_BATCH_SIZE))

# Load Data

In [None]:
# Added code that is used for the survival target labels
# Map every vocabulary token to its survival label (0 = not an outcome event).
if not file_config["competing_risks"]:
    # single-risk: any token listed in event_code becomes 1, everything else 0
    label2idx = {token: int(token in file_config["event_code"]) for token in BertVocab['token2idx']}
    print(f"Using a single-risk model with union of events: {file_config['event_code']}")
else:
    # competing risks: the k-th entry of event_code becomes the integer label k (1..K),
    # everything else 0
    event_pos = {code: rank for rank, code in enumerate(file_config["event_code"], start=1)}
    label2idx = {token: event_pos.get(token, 0) for token in BertVocab['token2idx']}
    print(f"Using a competing risks model with causes: {file_config['event_code']}")

# Report for sanity checking
from collections import defaultdict
def group_keys_by_value(d, *, sort_keys=False):
    """Invert a mapping into ``{value: [keys having that value]}``.

    Groups appear in the order their value is first seen in ``d``; keys within a
    group keep the iteration order of ``d`` unless ``sort_keys`` is true, in which
    case each group's keys are sorted.
    """
    buckets = {}
    for key, value in d.items():
        buckets.setdefault(value, []).append(key)
    if sort_keys:
        buckets = {value: sorted(members) for value, members in buckets.items()}
    return buckets
    
# Sanity report: which vocabulary tokens end up under which outcome label.
groups = group_keys_by_value(label2idx, sort_keys=True)

print(f"These are combined to outcomes with label2idx mapping:")
for v, keys in groups.items():
    # Label 0 is the right-censored group. Test the label itself, not the position
    # in the dict, since the first vocabulary token could be an event code.
    if v == 0:
        print(f"{v}: {len(keys)} keys are considered right-censored events")
    else:
        print(f"{v}: {keys}")

In [None]:
# data = pd.read_parquet(file_config['train']).reset_index(drop=True)
# data['label'] = data.label.apply(lambda x: list(set(x)))
# Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])
# trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)

# from dataLoader.MLM import MLMLoader
# Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=global_params['max_len_seq'], code='caliber_id')
# trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)

In [None]:
# Training data

data = pd.read_parquet(file_config['train'])
# Count the 'SEP' tokens in each patient's code sequence and drop patients that have
# fewer than min_visit of them.
data['length'] = data['caliber_id'].apply(lambda codes: sum(1 for code in codes if code == 'SEP'))
enough_visits = data['length'] >= global_params['min_visit']
data = data[enough_visits].reset_index(drop=True)
# Patient id is the row position after filtering.
data["patid"] = data.index

Dset = Survival(
    BertVocab['token2idx'],
    label2idx,
    ageVocab,
    data,
    max_len=global_params['max_len_seq'],
    code='caliber_id',
    label="target_event",
    label_time="target_time",
)
# Training batches are reshuffled every epoch.
trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)


In [None]:
# data = pd.read_parquet(file_config['test']).reset_index(drop=True)
# data['label'] = data.label.apply(lambda x: list(set(x)))
# Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])
# testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)

In [None]:
# Testing data
def _load_eval_split(parquet_path):
    """Read one held-out split and wrap it in an unshuffled DataLoader.

    Returns the filtered dataframe, the Survival dataset and the loader.
    """
    frame = pd.read_parquet(parquet_path)
    # Count the 'SEP' tokens per patient and drop patients with fewer than min_visit.
    frame['length'] = frame['caliber_id'].apply(lambda codes: sum(1 for code in codes if code == 'SEP'))
    frame = frame[frame['length'] >= global_params['min_visit']].reset_index(drop=True)
    frame["patid"] = frame.index

    split_dataset = Survival(
        BertVocab['token2idx'],
        label2idx,
        ageVocab,
        frame,
        max_len=global_params['max_len_seq'],
        code='caliber_id',
        label="target_event",
        label_time="target_time",
    )
    loader = DataLoader(dataset=split_dataset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)
    return frame, split_dataset, loader


# Testing data (skipped when the configuration provides no test split)
if file_config["test"] is not None:
    data, Dset, testload = _load_eval_split(file_config['test'])

# Added validation data (skipped when the configuration provides no validation split)
if file_config["val"] is not None:
    data, Dset, valload = _load_eval_split(file_config['val'])

# Set Up Model

In [None]:
# Build the architecture config, then the survival model on top of it. The number of
# risks is the number of configured event codes (only used by the competing-risks head).
conf = BertConfig(model_config)
model = BertForSurvival(
    config=conf,
    feature_dict=feature_dict,
    competing_risks=file_config["competing_risks"],
    num_risks=len(file_config["event_code"]),
)


In [None]:
print(pretrainModel)

if pretrainModel is not None:
    # Load the MLM pre-trained weights and copy over every tensor whose name also
    # exists in the survival model. map_location="cpu" lets a checkpoint saved on a
    # GPU be read on a machine without one; the model is moved to its device later.
    pretrained_dict = torch.load(file_config["model_path"] + pretrainModel, map_location="cpu")
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # Report the overlap: an empty overlap would otherwise silently leave the model
    # at its random initialisation.
    print(f"Loading {len(pretrained_dict)} of {len(model_dict)} model tensors from the pre-trained checkpoint.")
    if not pretrained_dict:
        print("Warning: no checkpoint keys matched the model; it remains randomly initialised.")
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

In [None]:
# Move the model to the device configured in global_params.
model = model.to(global_params['device'])
# optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)              # Removed: the original code overwrites this again later

# Evaluation Metrics

In [None]:
# import sklearn
# def precision(logits, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
#     label, output=label.cpu(), output.detach().cpu()
#     tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')
#     return tempprc, output, label

# def precision_test(logits, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
#     tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')
# #     roc = sklearn.metrics.roc_auc_score()
#     return tempprc, output, label

# def auroc_test(logits, label):
#     sig = nn.Sigmoid()
#     output=sig(logits)
#     tempprc= sklearn.metrics.roc_auc_score(label.numpy(),output.numpy(), average='samples')
# #     roc = sklearn.metrics.roc_auc_score()
#     return tempprc

# Multi-hot Label Encoder

In [None]:
# from sklearn.preprocessing import MultiLabelBinarizer
# mlb = MultiLabelBinarizer(classes=list(Vocab_diag.values()))
# mlb.fit([[each] for each in list(Vocab_diag.values())])

# Train and Test

In [None]:
# for step, batch in enumerate(trainload):
#     age_ids, input_ids, posi_ids, segment_ids, attMask, target_label, target_time, patient_id = batch
#     # age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch
#     # loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)

#     target_time = target_time.squeeze(-1).to(global_params['device'])         # [64]
#     target_label = target_label.squeeze(-1).to(global_params['device'])         # [64]
    
#     loss = model.desurv_model(torch.ones(target_time.shape[0], 288).to(global_params['device']), target_time, target_label)
#     if type(loss) is not np.int:
#         print(target_time)
#         print(target_label)
#         print(loss)

In [None]:
def train(e):
    """Run one training epoch over ``trainload``.

    Uses the notebook-level ``model``, ``optim``, ``trainload`` and ``global_params``.
    The model's forward output is treated as the loss; it is divided by
    ``gradient_accumulation_steps`` when that is greater than one, and the optimiser
    steps every ``gradient_accumulation_steps`` batches.

    Args:
        e: epoch number, used only in the progress message.
    """
    model.train()
    device = global_params['device']
    accumulation_steps = global_params['gradient_accumulation_steps']
    log_every = 2000

    # Loss accumulated since the last progress message, and how many steps it covers,
    # so the printed value is a true mean (the first report covers a single step).
    window_loss = 0.0
    window_steps = 0

    for step, batch in enumerate(trainload):
        age_ids, input_ids, posi_ids, segment_ids, attMask, target_label, target_time, patient_id = batch

        # Targets arrive with a trailing singleton axis; drop it.
        patient_id = patient_id.squeeze(-1).to(device)
        target_time = target_time.squeeze(-1).to(device)
        target_label = target_label.squeeze(-1).to(device)

        age_ids = age_ids.to(device)
        input_ids = input_ids.to(device)
        posi_ids = posi_ids.to(device)
        segment_ids = segment_ids.to(device)
        attMask = attMask.to(device)

        loss = model(target_label,
                     target_time,
                     patient_id,
                     input_ids,
                     age_ids,
                     segment_ids,
                     posi_ids,
                     attention_mask=attMask
                     )

        if accumulation_steps > 1:
            loss = loss / accumulation_steps
        loss.backward()

        window_loss += loss.item()
        window_steps += 1

        if step % log_every == 0:
            print("epoch: {}\t| Cnt: {}\t| Loss: {}\t".format(e, step + 1, window_loss / window_steps))
            window_loss = 0.0
            window_steps = 0

        if (step + 1) % accumulation_steps == 0:
            optim.step()
            optim.zero_grad()

def validation(loader):
    """Return the per-sample loss of the notebook-level ``model`` over ``loader``.

    The model is put in eval mode and each batch's forward output (treated as the
    batch loss) is summed, then divided by the total number of samples seen.
    Gradients are disabled because nothing is back-propagated here.

    Args:
        loader: iterable of batches laid out like those of ``trainload``.

    Returns:
        Summed loss divided by the number of samples.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    device = global_params['device']
    val_loss = 0
    val_samples = 0
    with torch.no_grad():
        for batch in loader:
            age_ids, input_ids, posi_ids, segment_ids, attMask, target_label, target_time, patient_id = batch

            # Targets arrive with a trailing singleton axis; drop it.
            patient_id = patient_id.squeeze(-1).to(device)
            target_time = target_time.squeeze(-1).to(device)
            target_label = target_label.squeeze(-1).to(device)

            age_ids = age_ids.to(device)
            input_ids = input_ids.to(device)
            posi_ids = posi_ids.to(device)
            segment_ids = segment_ids.to(device)
            attMask = attMask.to(device)

            loss = model(target_label,
                         target_time,
                         patient_id,
                         input_ids,
                         age_ids,
                         segment_ids,
                         posi_ids,
                         attention_mask=attMask
                         )

            val_samples += target_label.shape[0]
            val_loss += loss.item()
    val_loss /= val_samples
    return val_loss

def evaluation(loader):
    """Compute batch-averaged survival metrics of the global ``model``.

    For every batch the per-outcome CDFs returned by ``model.predict`` are
    summed into a single CDF and the labels are collapsed into one event
    indicator (label ``k + 1`` is matched to the ``k``-th predicted outcome).
    The resulting survival curves ``1 - cdf`` are scored with ``EvalSurv``
    using a Kaplan-Meier censoring estimate (``censor_surv='km'``).

    Args:
        loader: iterable yielding 8-tuples
            ``(age_ids, input_ids, posi_ids, segment_ids, attMask,
            target_label, target_time, patient_id)`` of tensors.

    Returns:
        Tuple ``(ctd, ibs, inbll)``: the unweighted means, over the batches
        that could be scored, of the time-dependent concordance, integrated
        Brier score and integrated negative binomial log-likelihood. Each is
        ``nan`` when no batch could be scored.
    """
    model.eval()
    device = global_params['device']

    ctd, ibs, inbll = [], [], []
    skipped = 0

    with torch.no_grad():
        for batch in loader:
            # The patient id is not needed for prediction or scoring.
            age_ids, input_ids, posi_ids, segment_ids, attMask, target_label, target_time, _patient_id = batch

            # Targets are only consumed by the NumPy-based metrics, so they are
            # converted directly instead of being moved to the device and back.
            target_time = target_time.squeeze(-1).cpu().numpy()
            target_label = target_label.squeeze(-1).cpu().numpy()

            pred_surv_CDFs, _ = model.predict(
                input_ids.to(device),
                age_ids.to(device),
                segment_ids.to(device),
                posi_ids.to(device),
                attention_mask=attMask.to(device)
            )

            # Calculate the metrics by combining outcomes
            cdf = np.zeros_like(pred_surv_CDFs[0])
            lbls = np.zeros(target_label.shape)
            for outcome, outcome_cdf in enumerate(pred_surv_CDFs, start=1):
                lbls += (target_label == outcome)
                cdf += outcome_cdf

            try:
                surv = pd.DataFrame(np.transpose((1 - cdf)), index=model.t_eval)
                ev = EvalSurv(surv, target_time, lbls, censor_surv='km')

                time_grid = np.linspace(start=0, stop=model.t_eval.max(), num=300)
                batch_ctd = ev.concordance_td()
                batch_ibs = ev.integrated_brier_score(time_grid)
                batch_inbll = ev.integrated_nbll(time_grid)
            except Exception:
                # Best effort, as before: a batch whose metrics cannot be
                # computed (e.g. no comparable pairs) is left out of the
                # average, the same way as SurvivEHR's callback. Only
                # `Exception` is caught so that KeyboardInterrupt/SystemExit
                # still stop the run.
                skipped += 1
            else:
                # Append only once all three metrics exist, keeping the three
                # lists aligned batch-for-batch.
                ctd.append(batch_ctd)
                ibs.append(batch_ibs)
                inbll.append(batch_inbll)

    if skipped:
        print(f"evaluation: skipped {skipped} batch(es) whose metrics could not be computed")

    ctd = np.mean(ctd)
    ibs = np.mean(ibs)
    inbll = np.mean(inbll)

    return ctd, ibs, inbll


# Note: modification

In the original code, the test loader is used to perform early stopping. This is bad practice and leads to over-reporting of performance.

Ideally this would be done with a validation split but, in keeping with the original code, I have swapped it to the training split.
I also added a `max_wait` limit to the early stopping.

In [None]:
import warnings

# Suppress every warning for the rest of the session.
warnings.filterwarnings(action='ignore')

# Settings passed to the project's Adam wrapper.
optim_config = {'lr': 3e-6, 'warmup_proportion': 0.1}

# With `freeze_backbone` set, only the parameters of `model.desurv_model` are
# handed to the optimiser; otherwise all of the model's named parameters are.
optim = optimiser.adam(
    params=list(
        (model.desurv_model if file_config["freeze_backbone"] else model).named_parameters()
    ),
    config=optim_config,
)


In [None]:
max_wait = 3     # patience: consecutive epochs without improvement before stopping
max_epochs = 50  # upper bound on the number of training epochs

# Early stopping compares validation losses, so a validation split is required.
# Previously a missing split only surfaced as a TypeError on `None < best_pre`
# after a full epoch of training; fail before training starts instead.
if file_config["val"] is None:
    raise ValueError("Early stopping requires a validation split, but file_config['val'] is None.")

best_pre = np.inf
best_state = None  # in-memory copy of the best weights seen so far
wait = 0
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])

for e in range(max_epochs):

    # Train for an epoch
    train(e)

    # Validation loss drives early stopping; the metrics are reported only.
    # Both are computed on the validation loader, so both depend on the
    # validation split (the metrics were previously gated on the test split).
    val_loss = validation(valload)
    ctd, ibs, inbll = evaluation(valload)

    # Report
    print("epoch: {}\t| Loss: {}\t| Ctd: {}\t| IBS: {}\t| INBLL: {}\t".format(e, val_loss, ctd, ibs, inbll ))

    if val_loss < best_pre:
        # Save a trained model
        print("** ** * Saving fine - tuned model ** ** * ")
        wait = 0
        best_pre = val_loss
        # Keep the best weights in memory so they can be restored even when
        # `save_model` is off (otherwise a missing or stale file would be loaded).
        best_state = {name: tensor.detach().cpu().clone()
                      for name, tensor in model_to_save.state_dict().items()}
        create_folder(global_params['output_dir'])
        if global_params['save_model']:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        print("** ** * No improvement in model   ** ** * ")
        wait += 1
        if wait >= max_wait:
            print("** ** * Stopping training       ** ** * ")
            break

# Runs after early stopping AND after reaching `max_epochs`, so test results
# are always reported. Weights are restored into `model_to_save`, the same
# object whose state dict was captured, so the keys match for wrapped models.
if best_state is not None:
    model_to_save.load_state_dict(best_state)

if file_config["test"] is not None:
    tst_loss = validation(testload)
    ctd, ibs, inbll = evaluation(testload)
    print("Test results after: {} epochs\t| Loss: {}\t| Ctd: {}\t| IBS: {}\t| INBLL: {}\t".format(e, tst_loss, ctd, ibs, inbll ))
else:
    print("No test split configured; skipping test evaluation.")
