In [1]:
from utils.dataset import EHRDataset
from model.tokenizer import EHRTokenizer
import pytorch_pretrained_bert as Bert
from torch.utils.data import DataLoader
from model.model import *
from utils.config import BertConfig

In [2]:
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

In [3]:
def adam(params, config=None):
    """Build a ``BertAdam`` optimizer that exempts biases and LayerNorm from weight decay.

    Parameters
    ----------
    params : iterable of (name, parameter) pairs
        Typically ``model.named_parameters()``. The iterable is materialised
        once, so a generator is accepted as well as a list.
    config : dict, optional
        Recognised keys are ``'lr'``, ``'warmup_proportion'``,
        ``'weight_decay'`` and ``'t_total'``. Any key that is missing falls
        back to its default (3e-5, 0.1, 0.01 and -1 respectively).

    Returns
    -------
    The ``Bert.optimization.BertAdam`` instance.
    """
    settings = {
        'lr': 3e-5,
        'warmup_proportion': 0.1,
        'weight_decay': 0.01,
        # -1 is BertAdam's own default; it leaves the warmup schedule unapplied.
        't_total': -1,
    }
    if config is not None:
        settings.update(config)

    # Materialise once: the pairs are scanned twice below, which would leave
    # the second group empty if `params` were a one-shot generator.
    named_params = list(params)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    def _is_no_decay(name):
        # A parameter is exempt when its name contains any of the markers.
        return any(marker in name for marker in no_decay)

    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in named_params if not _is_no_decay(n)],
            # Honour the configured value instead of a hard-coded 0.01.
            'weight_decay': settings['weight_decay'],
        },
        {
            'params': [p for n, p in named_params if _is_no_decay(n)],
            'weight_decay': 0.0,
        },
    ]

    optim = Bert.optimization.BertAdam(optimizer_grouped_parameters,
                                       lr=settings['lr'],
                                       warmup=settings['warmup_proportion'],
                                       t_total=settings['t_total'])
    return optim

In [4]:
# Parquet file loaded with pd.read_parquet in a later cell; the path is relative,
# so it resolves against the notebook's working directory.
path = 'processing/dataframe.parquet'

In [5]:
class TrainConfig(object):
    """Expose selected entries of a settings mapping as instance attributes."""

    def __init__(self, config):
        """Copy each known key from ``config`` onto the instance.

        ``config`` must provide ``.get``; a key that is absent yields ``None``
        for the corresponding attribute. Keys not listed here are ignored.
        """
        known_keys = (
            'batch_size',
            'use_cuda',
            'max_len_seq',
            'train_loader_workers',
            'test_loader_workers',
            'device',
            'output_dir',
            'output_name',
            'best_name',
        )
        for key in known_keys:
            setattr(self, key, config.get(key))

In [36]:
# Notebook-wide settings. In the cells shown, only 'max_seq_len' and
# 'gradient_accumulation_steps' are read back.
global_params = dict(
    max_seq_len=32,
    max_age=110,
    month=1,
    age_symbol=None,
    min_visit=5,
    gradient_accumulation_steps=1,
)

# Optimizer settings handed to adam().
optim_param = dict(
    lr=3e-5,
    warmup_proportion=0.1,
    weight_decay=0.01,
)

# Data-loading / device settings. 'max_len_seq' mirrors
# global_params['max_seq_len'] so both stay in sync.
train_params = dict(
    batch_size=10,
    use_cuda=True,
    max_len_seq=global_params['max_seq_len'],
    device='cuda:0',  # device string passed to .to() for the model and each batch
)

In [37]:
# Load the parquet file at `path` into a DataFrame.
data = pd.read_parquet(path)

In [38]:
# Number of rows in the loaded frame.
len(data)

107704

In [39]:
# Preview the first rows to check the column layout (subject_id, icd_code, age).
data.head()

Unnamed: 0,subject_id,icd_code,age
0,10028314,"[Z3800, P2912, Z23, Q620, Z051, Z412, P284, P9...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,10052351,"[R0789, F10129, SEP]","[56, 56, 56]"
2,10092012,"[Z051, Z23, Z3800, SEP]","[0, 0, 0, 0]"
3,10092020,"[Z87891, Z8546, Z7901, G4089, I4820, Z8673, E8...","[69, 69, 69, 69, 69, 69, 69, 69, 69]"
4,10126895,"[Z30430, O80, Z3A39, Z370, SEP]","[24, 24, 24, 24, 24]"


In [40]:
# Store the length of each row's code sequence and age sequence so the two
# can be compared side by side.
data['code_len'] = data['icd_code'].apply(len)
data['age_len'] = data['age'].apply(len)

In [41]:
# Show the frame with the two new length columns (the rendering is truncated).
data

Unnamed: 0,subject_id,icd_code,age,code_len,age_len
0,10028314,"[Z3800, P2912, Z23, Q620, Z051, Z412, P284, P9...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",15,15
1,10052351,"[R0789, F10129, SEP]","[56, 56, 56]",3,3
2,10092012,"[Z051, Z23, Z3800, SEP]","[0, 0, 0, 0]",4,4
3,10092020,"[Z87891, Z8546, Z7901, G4089, I4820, Z8673, E8...","[69, 69, 69, 69, 69, 69, 69, 69, 69]",9,9
4,10126895,"[Z30430, O80, Z3A39, Z370, SEP]","[24, 24, 24, 24, 24]",5,5
...,...,...,...,...,...
107699,19837828,"[Z3800, Z23, SEP]","[0, 0, 0]",3,3
107700,19910693,"[I2510, D638, Z8619, J440, I739, I10, R0902, Z...","[64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 6...",48,48
107701,19963063,"[O081, D62, Z3A01, O99011, K661, O00102, SEP]","[35, 35, 35, 35, 35, 35, 35]",7,7
107702,19979982,"[Y92239, I70201, M109, Z006, N400, I120, I2510...","[83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 8...",14,14


In [42]:
# Select the rows whose code sequence has fewer than 3 entries.
# NOTE(review): `traindata` is not referenced by the later cells shown here
# (the dataset is built from `MLMdata`) — confirm whether it is still needed.
traindata = data[data['code_len'] < 3]

In [43]:
# Inspect the short-sequence subset selected above.
traindata

Unnamed: 0,subject_id,icd_code,age,code_len,age_len
64,11187691,"[R0789, SEP]","[63, 63]",2,2
117,12252082,"[Z3801, SEP]","[0, 0]",2,2
194,13511025,"[J45901, SEP]","[51, 51]",2,2
266,14762362,"[Z3800, SEP]","[0, 0]",2,2
271,14840492,"[F10129, SEP]","[19, 19]",2,2
...,...,...,...,...,...
107195,19779819,"[F10129, SEP]","[23, 23]",2,2
107490,15663814,"[M5117, SEP]","[52, 52]",2,2
107636,18711078,"[G610, SEP]","[33, 33]",2,2
107671,19484642,"[F10129, SEP]","[50, 50]",2,2


In [44]:
# Create the project tokenizer with default arguments; its 'code' and 'age'
# vocabularies are queried below to size the model's embeddings.
tokenizer = EHRTokenizer()

In [45]:
from sklearn.model_selection import train_test_split

In [46]:
# Use the first half of the rows (by position) for masked-LM training.
half_point = len(data) // 2
MLMdata = data.iloc[:half_point]

In [47]:
# Wrap the masked-LM rows in the project dataset, then batch them with
# shuffling enabled.
Dset = EHRDataset(
    MLMdata,
    max_len=train_params['max_len_seq'],
    tokenizer=tokenizer,
)
trainload = DataLoader(
    dataset=Dset,
    batch_size=train_params['batch_size'],
    shuffle=True,
)

In [48]:
# Keyword settings for BertConfig. The two vocabulary sizes come from the
# tokenizer; the sequence limit comes from train_params.
model_config = {
    'vocab_size': len(tokenizer.getVoc('code').keys()), # number of disease codes plus special symbols (word embedding)
    'hidden_size': 50, # hidden size of the word and segment embeddings
    'seg_vocab_size': 2, # vocabulary size of the segment embedding
    'age_vocab_size': len(tokenizer.getVoc('age').keys()), # vocabulary size of the age embedding
    'max_position_embeddings': train_params['max_len_seq'], # maximum number of tokens per sequence
    'hidden_dropout_prob': 0.1, # dropout rate
    'num_hidden_layers': 2, # number of multi-head attention layers
    'num_attention_heads': 2, # number of attention heads
    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate
    'intermediate_size': 50, # size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # non-linear activation in the encoder and the pooler; "gelu", "relu" and "swish" are supported
    'initializer_range': 0.02, # parameter weight initializer range
}

In [49]:
# Build the project's BertConfig from the keyword settings in model_config.
conf = BertConfig(**model_config)

In [50]:
# Create the masked-LM model from the config. The JSON shown under this cell
# is presumably the resolved config echoed during construction — TODO confirm.
model = BertForMaskedLM(conf)

{
  "age_vocab_size": 78,
  "attention_probs_dropout_prob": 0.1,
  "graph": false,
  "graph_heads": 4,
  "graph_hidden_size": 75,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 50,
  "initializer_range": 0.02,
  "intermediate_size": 50,
  "max_position_embeddings": 32,
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "seg_vocab_size": 2,
  "type_vocab_size": 2,
  "vocab_size": 17662
}



In [51]:
# Number of batches per epoch.
len(trainload)

5386

In [52]:
# Move the model to the configured device, then build the optimizer from its
# named parameters.
# NOTE(review): the message printed under this cell ("t_total value of -1
# results in schedule not being applied") indicates that the warmup schedule is
# inactive because no t_total reaches BertAdam.
model = model.to(train_params['device'])
optim = adam(params=list(model.named_parameters()), config=optim_param)

t_total value of -1 results in schedule not being applied


In [53]:
# Confirm which device the model's first parameter lives on.
next(model.parameters()).device

device(type='cuda', index=0)

In [56]:
# Masked-LM training loop. It runs a fixed number of epochs, logs the batch
# loss every 200 steps and prints the mean batch loss after each epoch.
NUM_EPOCHS = 10
accum_steps = global_params['gradient_accumulation_steps']
train_device = train_params['device']
num_batches = len(trainload)

for epoch in range(NUM_EPOCHS):
    loss_ = 0
    for step, batch in enumerate(trainload):
        # Every tensor of the batch has to be on the same device as the model.
        batch = tuple(t.to(train_device) for t in batch)
        age_ids, input_ids, posi_ids, segment_ids, attMask, labels = batch
        loss, pred, label = model(input_ids, age_ids=age_ids, seg_ids=segment_ids, posi_ids=posi_ids,
                                  attention_mask=attMask, masked_lm_labels=labels)

        # Scale before backward so the accumulated gradient is the mean over
        # the accumulation window rather than the sum; a no-op when
        # accum_steps == 1. The unscaled loss is kept for logging.
        (loss / accum_steps).backward()

        if step % 200 == 0:
            print("step: {}, len(trainload): {}, epoch: {}, loss: {} ".format(step, num_batches, epoch, loss.item()))

        loss_ += loss.item()

        # Also step on the final batch so gradients from a partial window are
        # applied instead of leaking into the next epoch.
        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            optim.step()
            optim.zero_grad()

    print("Average loss {} after epoch {}".format(loss_ / num_batches, epoch))

step: 0, len(trainload): 5386, epoch: 0, loss: 0.28796258568763733 
step: 200, len(trainload): 5386, epoch: 0, loss: 0.8747056126594543 
step: 400, len(trainload): 5386, epoch: 0, loss: 0.5510188341140747 
step: 600, len(trainload): 5386, epoch: 0, loss: 0.47993749380111694 
step: 800, len(trainload): 5386, epoch: 0, loss: 0.7583121657371521 
step: 1000, len(trainload): 5386, epoch: 0, loss: 0.3105211853981018 
step: 1200, len(trainload): 5386, epoch: 0, loss: 0.7760886549949646 
step: 1400, len(trainload): 5386, epoch: 0, loss: 1.696117877960205 
step: 1600, len(trainload): 5386, epoch: 0, loss: 0.45019176602363586 
step: 1800, len(trainload): 5386, epoch: 0, loss: 0.21406041085720062 
step: 2000, len(trainload): 5386, epoch: 0, loss: 0.15439435839653015 
step: 2200, len(trainload): 5386, epoch: 0, loss: 0.03685026615858078 
step: 2400, len(trainload): 5386, epoch: 0, loss: 0.2389022260904312 
step: 2600, len(trainload): 5386, epoch: 0, loss: 0.2813948392868042 
step: 2800, len(trainl

In [78]:
# Check whether PyTorch can see a CUDA device.
# NOTE(review): `torch` is not imported explicitly in the cells shown; it
# presumably arrives via `from model.model import *` — confirm.
torch.cuda.is_available()

True

In [65]:
# Destination file for the saved weights, relative to the working directory.
PATH = 'checkpoint'

In [66]:
# Save only the model's state_dict; the optimizer state is not saved here.
torch.save(model.state_dict(), PATH)

In [None]:
# To restore the checkpoint later, rebuild the model and load the saved weights:
#model = BertForMaskedLM(conf)
#model.load_state_dict(torch.load(PATH))
#model.eval()