# BEHRT MLM with new data

```This is a modified version of the MLM.ipynb notebook```

This notebook has been changed to run new benchmarking datasets


In [1]:
dataset = "local_example"

In [2]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/BEHRT-with-FastEHR/my-virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/BEHRT-with-FastEHR/my-virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/BEHRT-with-FastEHR/task


In [3]:
import sys 
sys.path.insert(0, '../')

In [4]:
from common.common import create_folder
from common.pytorch import load_model
import pytorch_pretrained_bert as Bert
from model.utils import age_vocab
from common.common import load_obj
from dataLoader.MLM import MLMLoader
from torch.utils.data import DataLoader
import pandas as pd
from model.MLM import BertForMaskedLM
from model.optimiser import adam
import sklearn.metrics as skm
import numpy as np
import torch
import time
import torch.nn as nn
import os

In [5]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings = config.get('max_position_embedding'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        
class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')

### Decide which dataset to use

In [6]:

match dataset:
    case "local_example":
        file_config = {
            'vocab':'data/local_example/token2idx',  # vocabulary idx2token, token2idx
            'data': 'data/local_example/data_train.parquet',  # formated data 
            'model_path': f'data/{dataset}/', # where to save model
            'model_name': 'local_MLM-notebook.ckpt', # model name\
            'file_name': 'local_MLM-notebook.out',  # log path
        }
    case "fastehr_example":
        file_config = {
            'vocab':'/rds/homes/g/gaddcz/Projects/FastEHR/examples/data/_built/adapted/BEHRT/T2D_hypertension/token2idx',  # vocabulary idx2token, token2idx
            'data': '/rds/homes/g/gaddcz/Projects/FastEHR/examples/data/_built/adapted/BEHRT/T2D_hypertension/dataset.parquet',  # formated data 
            'model_path': f'data/{dataset}/', # where to save model
            'model_name': 'fastehr_MLM-notebook.ckpt', # model name\
            'file_name': 'fastehr_MLM-notebook.out',  # log path
        }
    case "hypertension":
        file_config = {
            'vocab':'/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/BEHRT/token2idx',  # vocabulary idx2token, token2idx
            'data': '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/BEHRT/train_dataset.parquet',  # formated data 
            'model_path': f'data/{dataset}/', # where to save model
            'model_name': 'hypertension_MLM-notebook.ckpt', # model name\
            'file_name': 'hypertension_MLM-notebook.out',  # log path
        }
    case "cvd":
        file_config = {
            'vocab':'/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/BEHRT/token2idx',  # vocabulary idx2token, token2idx
            'data': '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/BEHRT/train_dataset.parquet',  # formated data 
            'model_path': f'data/{dataset}/', # where to save model
            'model_name': 'cvd_MLM-notebook.ckpt', # model name\
            'file_name': 'cvd_MLM-notebook.out',  # log path
        }
    
    case _:
        raise NotImplementedError
        
create_folder(file_config['model_path'])
print(dataset)

local_example


In [7]:
global_params = {
    'max_seq_len': 64,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
    'min_visit': 1,                        # Reduced to one (from 5) to be comparable
    'gradient_accumulation_steps': 1
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

train_params = {
    'batch_size': 256,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'device': 'cuda:0'
}

### Create dataset and dataloaders

In [8]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [9]:
data = pd.read_parquet(file_config['data'])
# remove patients with visits less than min visit
data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))
data = data[data['length'] >= global_params['min_visit']]
data = data.reset_index(drop=True)

In [10]:
Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')
trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)

### Define model

In [11]:
model_config = {
    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding
    'hidden_size': 288, # word embedding and seg embedding hidden size
    'seg_vocab_size': 2, # number of vocab for seg embedding
    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding
    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens
    'hidden_dropout_prob': 0.1, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 12, # number of attention heads
    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
}

In [12]:
conf = BertConfig(model_config)
model = BertForMaskedLM(conf)

In [13]:
model = model.to(train_params['device'])
optim = adam(params=list(model.named_parameters()), config=optim_param)

t_total value of -1 results in schedule not being applied


### Train model using BEHRT's experiment setup

In [14]:
def cal_acc(label, pred):
    logs = nn.LogSoftmax()
    label=label.cpu().numpy()
    ind = np.where(label!=-1)[0]
    truepred = pred.detach().cpu().numpy()
    truepred = truepred[ind]
    truelabel = label[ind]
    truepred = logs(torch.tensor(truepred))
    outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]
    precision = skm.precision_score(truelabel, outs, average='micro')
    return precision

In [15]:
def train(e, loader):
    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt= 0
    start = time.time()

    for step, batch in enumerate(loader):
        cnt +=1
        batch = tuple(t.to(train_params['device']) for t in batch)
        age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch
        loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)
        if global_params['gradient_accumulation_steps'] >1:
            loss = loss/global_params['gradient_accumulation_steps']
        loss.backward()
        
        temp_loss += loss.item()
        tr_loss += loss.item()
        
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        
        if step % 200==0:
            print("epoch: {}\t| cnt: {}\t|Loss: {}\t| precision: {:.4f}\t| time: {:.2f}".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))
            temp_loss = 0
            start = time.time()
            
        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:
            optim.step()
            optim.zero_grad()

    print("** ** * Saving fine - tuned model ** ** * ")
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    create_folder(file_config['model_path'])
    output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])

    torch.save(model_to_save.state_dict(), output_model_file)
        
    cost = time.time() - start
    return tr_loss, cost

In [16]:
data_len = 1000

f = open(os.path.join(file_config['model_path'], file_config['file_name']), "w")
f.write('{}\t{}\t{}\n'.format('epoch', 'loss', 'time'))
for e in range(50):
    loss, time_cost = train(e, trainload)
    loss = loss/data_len
    f.write('{}\t{}\t{}\n'.format(e, loss, time_cost))
f.close()    

  truepred = logs(torch.tensor(truepred))
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1485.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


epoch: 0	| cnt: 1	|Loss: 0.0012827775478363038	| precision: 0.0057	| time: 1.49
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 1	| cnt: 1	|Loss: 0.000769543170928955	| precision: 0.4615	| time: 0.38
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 2	| cnt: 1	|Loss: 0.0006995547413825989	| precision: 0.4871	| time: 0.32
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 3	| cnt: 1	|Loss: 0.0006977683901786805	| precision: 0.5125	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 4	| cnt: 1	|Loss: 0.0006403416395187378	| precision: 0.5222	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 5	| cnt: 1	|Loss: 0.0006386865377426147	| precision: 0.5067	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 6	| cnt: 1	|Loss: 0.0006288836002349853	| precision: 0.4964	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 7	| cnt: 1	|Loss: 0.0005998504161834717	| precision: 0.5248	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 8	| cnt: 1	|Loss: 0.0005783680677413941	| precision: 0.5709	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 9	| cnt: 1	|Loss: 0.0006060991883277893	| precision: 0.5017	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 10	| cnt: 1	|Loss: 0.0006409665942192078	| precision: 0.5180	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 11	| cnt: 1	|Loss: 0.0005791863799095154	| precision: 0.5798	| time: 0.41
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 12	| cnt: 1	|Loss: 0.0006020164489746094	| precision: 0.5213	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 13	| cnt: 1	|Loss: 0.0006534802317619324	| precision: 0.4695	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 14	| cnt: 1	|Loss: 0.0005927158594131469	| precision: 0.5576	| time: 0.31
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 15	| cnt: 1	|Loss: 0.0005723985433578492	| precision: 0.5594	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 16	| cnt: 1	|Loss: 0.0006252151131629944	| precision: 0.5167	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 17	| cnt: 1	|Loss: 0.0006041809320449829	| precision: 0.5152	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 18	| cnt: 1	|Loss: 0.0005607065558433532	| precision: 0.5461	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 19	| cnt: 1	|Loss: 0.0005538775324821472	| precision: 0.5318	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 20	| cnt: 1	|Loss: 0.0006020679473876953	| precision: 0.5215	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 21	| cnt: 1	|Loss: 0.0005866385102272034	| precision: 0.5258	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 22	| cnt: 1	|Loss: 0.0005903868675231934	| precision: 0.5310	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 23	| cnt: 1	|Loss: 0.0005873569250106811	| precision: 0.5385	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 24	| cnt: 1	|Loss: 0.0006153668165206909	| precision: 0.5361	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 25	| cnt: 1	|Loss: 0.0005415284633636475	| precision: 0.5382	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 26	| cnt: 1	|Loss: 0.0005852963924407959	| precision: 0.5364	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 27	| cnt: 1	|Loss: 0.0006130174398422242	| precision: 0.4966	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 28	| cnt: 1	|Loss: 0.0005821158289909362	| precision: 0.5256	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 29	| cnt: 1	|Loss: 0.0005855637192726136	| precision: 0.5051	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 30	| cnt: 1	|Loss: 0.0005689348578453064	| precision: 0.5276	| time: 0.32
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 31	| cnt: 1	|Loss: 0.0005971667766571045	| precision: 0.5417	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 32	| cnt: 1	|Loss: 0.0005990300178527832	| precision: 0.5343	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 33	| cnt: 1	|Loss: 0.0005350791215896607	| precision: 0.5580	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 34	| cnt: 1	|Loss: 0.0005607466697692871	| precision: 0.5199	| time: 0.48
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 35	| cnt: 1	|Loss: 0.0005318915843963623	| precision: 0.5696	| time: 0.46
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 36	| cnt: 1	|Loss: 0.0006212589740753173	| precision: 0.4982	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 37	| cnt: 1	|Loss: 0.00053653484582901	| precision: 0.5493	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 38	| cnt: 1	|Loss: 0.0005994047522544861	| precision: 0.5697	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 39	| cnt: 1	|Loss: 0.0006321800947189331	| precision: 0.5153	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 40	| cnt: 1	|Loss: 0.0005846552848815918	| precision: 0.5475	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 41	| cnt: 1	|Loss: 0.0005630881190299987	| precision: 0.5618	| time: 0.32
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 42	| cnt: 1	|Loss: 0.000582395851612091	| precision: 0.5150	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 43	| cnt: 1	|Loss: 0.0005837717056274414	| precision: 0.5222	| time: 0.36
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 44	| cnt: 1	|Loss: 0.0006117733120918273	| precision: 0.5229	| time: 0.32
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 45	| cnt: 1	|Loss: 0.0005559228062629699	| precision: 0.5724	| time: 0.34
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 46	| cnt: 1	|Loss: 0.0006410897374153137	| precision: 0.4904	| time: 0.33
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 47	| cnt: 1	|Loss: 0.00056022047996521	| precision: 0.5277	| time: 0.37
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 48	| cnt: 1	|Loss: 0.0005666314959526062	| precision: 0.5316	| time: 0.35
** ** * Saving fine - tuned model ** ** * 


  truepred = logs(torch.tensor(truepred))


epoch: 49	| cnt: 1	|Loss: 0.0005858151316642761	| precision: 0.5366	| time: 0.34
** ** * Saving fine - tuned model ** ** * 
