In [1]:
import sys 
sys.path.insert(0, '../')

In [2]:
from common.common import create_folder
from common.pytorch import load_model
import pytorch_pretrained_bert as Bert
from model.utils import age_vocab
from common.common import load_obj
from dataLoader.MLM import MLMLoader
from torch.utils.data import DataLoader
import pandas as pd
from model.MLM import BertForMaskedLM
from model.optimiser import adam
import sklearn.metrics as skm
import numpy as np
import torch
import time
import torch.nn as nn
import os

In [3]:
class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings = config.get('max_position_embedding'),
            initializer_range=config.get('initializer_range'),
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        
class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')

In [4]:
file_config = {
    'vocab':'../data/dict',  # vocabulary idx2token, token2idx
    'data': '',  # formated data 
    'model_path': '../saved_model', # where to save model
    'model_name': 'hamed', # model name
    'file_name': 'log',  # log path
}
create_folder(file_config['model_path'])

In [5]:
global_params = {
    'max_seq_len': 64,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
    'min_visit': 5,
    'gradient_accumulation_steps': 1
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

train_params = {
    'batch_size': 256,
    'use_cuda': True,
    'max_len_seq': global_params['max_seq_len'],
    'device': 'cpu'
}

In [6]:
BertVocab = load_obj(file_config['vocab'])
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [7]:
#data = pd.read_parquet(file_config['data'])
# remove patients with visits less than min visit
#data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))
#data = data[data['length'] >= global_params['min_visit']]
#data = data.reset_index(drop=True)

In [8]:
code_df = pd.read_csv("../data/codes.csv", dtype=str).T.apply(lambda x: x.dropna().tolist()).tolist()

age_df = pd.read_csv("../data/ages.csv", dtype=str).T.apply(lambda x: x.dropna().tolist()).tolist()

data = {"code": code_df, "age": age_df}

In [11]:
Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'])
trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)

In [12]:
model_config = {
    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding
    'hidden_size': 288, # word embedding and seg embedding hidden size
    'seg_vocab_size': 2, # number of vocab for seg embedding
    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding
    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens
    'hidden_dropout_prob': 0.1, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 12, # number of attention heads
    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
}

In [13]:
conf = BertConfig(model_config)
model = BertForMaskedLM(conf)

In [14]:
model = model.to(train_params['device'])
optim = adam(params=list(model.named_parameters()), config=optim_param)

t_total value of -1 results in schedule not being applied


In [15]:
def cal_acc(label, pred):
    logs = nn.LogSoftmax()
    label=label.cpu().numpy()
    ind = np.where(label!=-1)[0]
    truepred = pred.detach().cpu().numpy()
    truepred = truepred[ind]
    truelabel = label[ind]
    truepred = logs(torch.tensor(truepred))
    outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]
    precision = skm.precision_score(truelabel, outs, average='micro')
    return precision

In [16]:
def train(e, loader):
    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt= 0
    start = time.time()

    for step, batch in enumerate(loader):
        cnt +=1
        batch = tuple(t.to(train_params['device']) for t in batch)
        age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch
        loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)
        if global_params['gradient_accumulation_steps'] >1:
            loss = loss/global_params['gradient_accumulation_steps']
        loss.backward()
        
        temp_loss += loss.item()
        tr_loss += loss.item()
        
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1
        
        if step % 200==0:
            print("epoch: {}\t| cnt: {}\t|Loss: {}\t| precision: {:.4f}\t| time: {:.2f}".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))
            temp_loss = 0
            start = time.time()
            
        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:
            optim.step()
            optim.zero_grad()

    print("** ** * Saving fine - tuned model ** ** * ")
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    create_folder(file_config['model_path'])
    output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])

    torch.save(model_to_save.state_dict(), output_model_file)
        
    cost = time.time() - start
    return tr_loss, cost

In [17]:
f = open(os.path.join(file_config['model_path'], file_config['file_name']), "w")
f.write('{}\t{}\t{}\n'.format('epoch', 'loss', 'time'))
for e in range(50):
    loss, time_cost = train(e, trainload)
    loss = loss#/data_len
    f.write('{}\t{}\t{}\n'.format(e, loss, time_cost))
f.close()

  
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\python_arg_parser.cpp:1055.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


epoch: 0	| cnt: 1	|Loss: 0.0021208405494689943	| precision: 0.0065	| time: 19.47
** ** * Saving fine - tuned model ** ** * 


  


epoch: 1	| cnt: 1	|Loss: 0.0019477716684341432	| precision: 0.1445	| time: 13.68
** ** * Saving fine - tuned model ** ** * 


  


epoch: 2	| cnt: 1	|Loss: 0.0016545246839523315	| precision: 0.5000	| time: 12.12
** ** * Saving fine - tuned model ** ** * 


  


epoch: 3	| cnt: 1	|Loss: 0.0014064831733703613	| precision: 0.5338	| time: 11.74
** ** * Saving fine - tuned model ** ** * 


  


epoch: 4	| cnt: 1	|Loss: 0.0011123137474060058	| precision: 0.6131	| time: 11.56
** ** * Saving fine - tuned model ** ** * 


  


epoch: 5	| cnt: 1	|Loss: 0.001025469183921814	| precision: 0.5714	| time: 11.57
** ** * Saving fine - tuned model ** ** * 


  


epoch: 6	| cnt: 1	|Loss: 0.0009439226388931275	| precision: 0.5481	| time: 11.62
** ** * Saving fine - tuned model ** ** * 


  


epoch: 7	| cnt: 1	|Loss: 0.0008361288905143738	| precision: 0.5755	| time: 11.62
** ** * Saving fine - tuned model ** ** * 


  


epoch: 8	| cnt: 1	|Loss: 0.0007327263951301575	| precision: 0.6045	| time: 11.70
** ** * Saving fine - tuned model ** ** * 


  


epoch: 9	| cnt: 1	|Loss: 0.0007373352646827698	| precision: 0.5547	| time: 12.98
** ** * Saving fine - tuned model ** ** * 


  


epoch: 10	| cnt: 1	|Loss: 0.0006860011219978333	| precision: 0.5380	| time: 12.05
** ** * Saving fine - tuned model ** ** * 


  


epoch: 11	| cnt: 1	|Loss: 0.0006420586705207825	| precision: 0.5588	| time: 11.78
** ** * Saving fine - tuned model ** ** * 


  


epoch: 12	| cnt: 1	|Loss: 0.0006681892871856689	| precision: 0.5260	| time: 11.57
** ** * Saving fine - tuned model ** ** * 


  


epoch: 13	| cnt: 1	|Loss: 0.0006044732928276062	| precision: 0.5625	| time: 11.87
** ** * Saving fine - tuned model ** ** * 


  


epoch: 14	| cnt: 1	|Loss: 0.0006026567220687866	| precision: 0.6235	| time: 11.58
** ** * Saving fine - tuned model ** ** * 


  


epoch: 15	| cnt: 1	|Loss: 0.0006538300514221192	| precision: 0.5704	| time: 12.26
** ** * Saving fine - tuned model ** ** * 


  


epoch: 16	| cnt: 1	|Loss: 0.000535943865776062	| precision: 0.5938	| time: 11.80
** ** * Saving fine - tuned model ** ** * 


  


epoch: 17	| cnt: 1	|Loss: 0.0005655509829521179	| precision: 0.5864	| time: 12.07
** ** * Saving fine - tuned model ** ** * 


  


epoch: 18	| cnt: 1	|Loss: 0.0005537709593772889	| precision: 0.5833	| time: 11.78
** ** * Saving fine - tuned model ** ** * 


  


epoch: 19	| cnt: 1	|Loss: 0.0006084786057472229	| precision: 0.5181	| time: 12.66
** ** * Saving fine - tuned model ** ** * 


  


epoch: 20	| cnt: 1	|Loss: 0.0005315003991127014	| precision: 0.6099	| time: 11.67
** ** * Saving fine - tuned model ** ** * 


  


epoch: 21	| cnt: 1	|Loss: 0.0005871798396110535	| precision: 0.4714	| time: 11.84
** ** * Saving fine - tuned model ** ** * 


  


epoch: 22	| cnt: 1	|Loss: 0.0005244066715240479	| precision: 0.5664	| time: 11.96
** ** * Saving fine - tuned model ** ** * 


  


epoch: 23	| cnt: 1	|Loss: 0.0005276649594306946	| precision: 0.6053	| time: 12.31
** ** * Saving fine - tuned model ** ** * 


  


epoch: 24	| cnt: 1	|Loss: 0.0005376789569854736	| precision: 0.5190	| time: 11.90
** ** * Saving fine - tuned model ** ** * 


  


epoch: 25	| cnt: 1	|Loss: 0.0005155559182167054	| precision: 0.5346	| time: 11.88
** ** * Saving fine - tuned model ** ** * 


  


epoch: 26	| cnt: 1	|Loss: 0.0005239549279212951	| precision: 0.5425	| time: 11.81
** ** * Saving fine - tuned model ** ** * 


  


epoch: 27	| cnt: 1	|Loss: 0.0004982173442840577	| precision: 0.5821	| time: 11.68
** ** * Saving fine - tuned model ** ** * 


  


epoch: 28	| cnt: 1	|Loss: 0.0004697284698486328	| precision: 0.6042	| time: 11.73
** ** * Saving fine - tuned model ** ** * 


  


epoch: 29	| cnt: 1	|Loss: 0.00046311426162719727	| precision: 0.5828	| time: 11.74
** ** * Saving fine - tuned model ** ** * 


  


epoch: 30	| cnt: 1	|Loss: 0.00046492478251457215	| precision: 0.5793	| time: 11.71
** ** * Saving fine - tuned model ** ** * 


  


epoch: 31	| cnt: 1	|Loss: 0.00046083837747573853	| precision: 0.5782	| time: 11.64
** ** * Saving fine - tuned model ** ** * 


  


epoch: 32	| cnt: 1	|Loss: 0.0004800068438053131	| precision: 0.4834	| time: 11.67
** ** * Saving fine - tuned model ** ** * 


  


epoch: 33	| cnt: 1	|Loss: 0.00046178820729255674	| precision: 0.5414	| time: 11.57
** ** * Saving fine - tuned model ** ** * 
epoch: 34	| cnt: 1	|Loss: 0.0004517430365085602	| precision: 0.5411	| time: 12.96


  


** ** * Saving fine - tuned model ** ** * 


  


epoch: 35	| cnt: 1	|Loss: 0.00046816226840019226	| precision: 0.5374	| time: 12.55
** ** * Saving fine - tuned model ** ** * 


  


epoch: 36	| cnt: 1	|Loss: 0.0004447632431983948	| precision: 0.5733	| time: 11.63
** ** * Saving fine - tuned model ** ** * 


  


epoch: 37	| cnt: 1	|Loss: 0.0004382834732532501	| precision: 0.5221	| time: 11.70
** ** * Saving fine - tuned model ** ** * 


  


epoch: 38	| cnt: 1	|Loss: 0.00041979166865348813	| precision: 0.6183	| time: 11.54
** ** * Saving fine - tuned model ** ** * 


  


epoch: 39	| cnt: 1	|Loss: 0.0004017190337181091	| precision: 0.5924	| time: 11.51
** ** * Saving fine - tuned model ** ** * 


  


epoch: 40	| cnt: 1	|Loss: 0.0004682496190071106	| precision: 0.5379	| time: 11.47
** ** * Saving fine - tuned model ** ** * 


  


epoch: 41	| cnt: 1	|Loss: 0.00045691749453544617	| precision: 0.5260	| time: 11.44
** ** * Saving fine - tuned model ** ** * 


  


epoch: 42	| cnt: 1	|Loss: 0.00044196394085884095	| precision: 0.5517	| time: 11.88
** ** * Saving fine - tuned model ** ** * 


  


epoch: 43	| cnt: 1	|Loss: 0.00038607969880104063	| precision: 0.6026	| time: 11.58
** ** * Saving fine - tuned model ** ** * 


  


epoch: 44	| cnt: 1	|Loss: 0.0004116646945476532	| precision: 0.6202	| time: 11.53
** ** * Saving fine - tuned model ** ** * 


  


epoch: 45	| cnt: 1	|Loss: 0.0004062600135803223	| precision: 0.5606	| time: 11.55
** ** * Saving fine - tuned model ** ** * 


  


epoch: 46	| cnt: 1	|Loss: 0.0004019186496734619	| precision: 0.6074	| time: 11.63
** ** * Saving fine - tuned model ** ** * 


  


epoch: 47	| cnt: 1	|Loss: 0.0004021068513393402	| precision: 0.5821	| time: 11.57
** ** * Saving fine - tuned model ** ** * 


  


epoch: 48	| cnt: 1	|Loss: 0.0004256500005722046	| precision: 0.5035	| time: 11.58
** ** * Saving fine - tuned model ** ** * 


  


epoch: 49	| cnt: 1	|Loss: 0.00038992014527320864	| precision: 0.6194	| time: 11.47
** ** * Saving fine - tuned model ** ** * 
