In [None]:
import torch
import numpy as np

import torch.nn as nn
import time
import math
import sklearn.metrics as skm

from torch.utils.data import Dataset
import pandas as pd
from models.PreTrainingModel import  *

In [None]:
def _read_headerless(csv_path):
    """Read ``csv_path`` as a DataFrame, treating the first row as data."""
    return pd.read_csv(csv_path, header=None)


# Tokenized input tables; each file is read without a header row.
# NOTE(review): the tables are presumably row-aligned with each other — confirm.
src = _read_headerless("data/tokenized_visits.csv")
age_data = _read_headerless("data/tokenized_age.csv")
gender_data = _read_headerless("data/tokenized_gender.csv")
race_data = _read_headerless("data/tokenized_race.csv")
ethnicity_data = _read_headerless("data/tokenized_ethnicity.csv")
time_date = _read_headerless("data/tokenized_timestamps.csv")

# The vocabulary file is the only one whose first row is used as the header.
vocab_df = pd.read_csv("data/vocab.csv", header=0)

In [None]:
# Column 'Unnamed: 0' maps each row index of the vocabulary frame to its entry.
# Flip that mapping so the entries become keys and the row indices the values.
index_to_entry = vocab_df['Unnamed: 0'].to_dict()
vocab = {entry: index for index, entry in index_to_entry.items()}
# 'MASK' gets the first id after the largest id already present.
vocab['MASK'] = int(max(vocab.values())) + 1

# Split sizes: 70 % train, 10 % validation, whatever is left goes to test.
n_rows = len(src)
train_l = int(n_rows * 0.70)
val_l = int(n_rows * 0.10)
test_l = n_rows - val_l - train_l

In [None]:
# The maximum sequence length is the column count of the tokenized visit table.
global_params = {
    "max_seq_len": src.shape[1],
}

# Learning rates. Only "lr_discr" is read by the optimizer cell shown in this
# notebook; "lr_gen" is not referenced in the visible cells.
optim_param = {
    "lr_discr": 3e-5,
    "lr_gen": 3e-5,
}

# Run-level settings shared by the data loaders and the training loop.
train_params = {
    "batch_size": 32,
    "use_cuda": True,
    "max_len_seq": global_params["max_seq_len"],
    # Fall back to the CPU when no CUDA device is available.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "data_len": len(src),
    "train_data_len": train_l,
    "val_data_len": val_l,
    "test_data_len": test_l,
    "epochs": 35,
    # One of 'train', 'resume' or 'eval'; selects the branch run in the last cell.
    "action": "train",
}

# Hyper-parameters handed to BertConfig.
model_config = {
    "vocab_size": 708,  # disease codes plus special symbols (word embedding)
    "hidden_size": 288,  # width of the word and segment embeddings
    "seg_vocab_size": 2,  # vocabulary size of the segment embedding
    "age_vocab_size": 111,  # vocabulary size of the age embedding
    "gender_vocab_size": 2,
    "ethnicity_vocab_size": 2,
    "race_vocab_size": 6,
    "num_labels": 1,
    "feature_dict": 708,
    "max_position_embedding": train_params["max_len_seq"],  # maximum number of tokens
    "hidden_dropout_prob": 0.2,  # dropout rate
    "num_hidden_layers": 6,  # number of multi-head attention layers
    "num_attention_heads": 12,  # attention heads per layer
    "attention_probs_dropout_prob": 0.2,  # dropout inside multi-head attention
    "intermediate_size": 512,  # size of the encoder's "intermediate" layer
    "hidden_act": "gelu",  # encoder/pooler activation: "gelu", "relu" or "swish"
    "initializer_range": 0.02,  # range used when initializing weights
    "number_output": 1,
}

In [None]:
def _three_way_split(frame):
    """Slice ``frame.values`` row-wise into (train, validation, test) arrays.

    The first ``train_l`` rows are the training part, the next ``val_l`` rows
    the validation part, and every remaining row the test part.
    """
    rows = frame.values
    val_stop = train_l + val_l
    return rows[:train_l], rows[train_l:val_stop], rows[val_stop:]


train_code, val_code, test_code = _three_way_split(src)
train_age, val_age, test_age = _three_way_split(age_data)
train_gender, val_gender, test_gender = _three_way_split(gender_data)
train_time, val_time, test_time = _three_way_split(time_date)

# Bundle the per-feature arrays of each split under the keys the dataset reads.
train_data = {
    "code": train_code,
    "age": train_age,
    "gender": train_gender,
    "time": train_time,
}
val_data = {
    "code": val_code,
    "age": val_age,
    "gender": val_gender,
    "time": val_time,
}
test_data = {
    "code": test_code,
    "age": test_age,
    "gender": test_gender,
    "time": test_time,
}

In [None]:
# Build the masked-LM model from the config dict and move it to the run device.
conf = BertConfig(model_config)
bert = BertForMLM(conf).to(train_params['device'])

# A single Adam optimizer over every model parameter, using the 'lr_discr' rate.
bert_vars = list(bert.parameters())
optim = torch.optim.Adam(bert_vars, lr=optim_param['lr_discr'])

In [None]:
def run_epoch(e, trainload, device):
    """Run one optimisation pass over ``trainload`` with the global ``bert``/``optim``.

    Args:
        e: Epoch index. Accepted for the caller's convenience; not used here.
        trainload: Iterable yielding batches of eight tensors in the order
            (input ids, age ids, gender ids, time ids, position ids,
            segment ids, attention mask, masked-LM labels).
        device: Device every tensor of the batch is moved to.

    Returns:
        ``(summed_loss, seconds)``: the sum of the per-batch losses (not
        averaged) and the wall-clock duration of the pass.
    """
    began = time.time()
    running_loss = 0
    bert.train()
    for batch_idx, raw_batch in enumerate(trainload):
        optim.zero_grad()
        (input_ids, age_ids, gender_ids, time_ids,
         posi_ids, segment_ids, attMask, masked_label) = [
            tensor.to(device) for tensor in raw_batch
        ]

        loss, pred, label = bert(
            input_ids,
            age_ids,
            gender_ids,
            time_ids,
            posi_ids,
            segment_ids,
            attention_mask=attMask,
            masked_lm_labels=masked_label,
        )
        loss.backward()
        running_loss += loss.item()
        # Report the masked-token accuracy on every 500th batch, starting with the first.
        if batch_idx % 500 == 0:
            print(cal_acc(label, pred))
        optim.step()
        # Drop the references so the tensors can be freed before the next batch.
        del loss, pred, label
    return running_loss, time.time() - began

In [None]:
def train(trainload, valload, device):
    """Train for ``train_params["epochs"]`` epochs, checkpointing on best validation loss.

    Each epoch runs ``run_epoch`` then ``eval``, divides both summed losses by
    the batch count derived from ``train_params``, prints them and appends
    them to ``log_pre_train.txt``. The model's state dict is saved as
    ``bert_pretrain`` whenever the validation loss beats the best seen so far.

    Args:
        trainload: Loader passed to ``run_epoch``.
        valload: Loader passed to ``eval``.
        device: Device forwarded to both calls.

    Returns:
        ``(train_loss, val_loss)`` of the last epoch, both per-batch averages.
    """

    def _per_batch(total, length_key):
        # Average over ceil(n_samples / batch_size) batches, as configured in train_params.
        return total / math.ceil((train_params[length_key] / train_params['batch_size']))

    # Start every call from an empty log file.
    with open("log_pre_train.txt", 'w'):
        pass

    lowest_val = math.inf
    for e in range(train_params["epochs"]):
        print("Epoch n" + str(e))
        train_loss, train_time_cost = run_epoch(e, trainload, device)
        val_loss, val_time_cost, _, _ = eval(valload, device)
        train_loss = _per_batch(train_loss, "train_data_len")
        val_loss = _per_batch(val_loss, "val_data_len")

        print('TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
        with open("log_pre_train.txt", 'a') as log:
            log.write(
                "Epoch n" + str(e)
                + '\n TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost)
                + '\n\n\n'
            )
            log.write('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost) + '\n\n\n')
        print('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost))

        if val_loss < lowest_val:
            print("** ** * Saving pre - trained model ** ** * ")
            # Save the object behind a `.module` attribute when there is one.
            model_to_save = getattr(bert, 'module', bert)
            save_model(model_to_save.state_dict(), 'bert_pretrain')
            lowest_val = val_loss
    return train_loss, val_loss

In [None]:
def eval(_valload, device):
    """Compute the summed masked-LM loss of the global ``bert`` over ``_valload``.

    The model is switched to evaluation mode and every forward pass runs under
    ``torch.no_grad()``, so no autograd graph is built for the batches.

    NOTE: this function shadows the built-in ``eval`` inside the notebook; the
    name is kept because other cells call it.

    Args:
        _valload: Iterable yielding batches of eight tensors in the order
            (input ids, age ids, gender ids, time ids, position ids,
            segment ids, attention mask, masked-LM labels).
        device: Device every tensor of the batch is moved to.

    Returns:
        ``(summed_loss, seconds, pred, label)``. ``summed_loss`` is the sum of
        the per-batch losses (not averaged). ``pred`` and ``label`` are the
        model outputs of the last batch only, or ``None`` when ``_valload``
        yields no batches.
    """
    tr_loss = 0
    # Defined up front so an empty loader returns None instead of raising
    # UnboundLocalError at the return statement.
    pred, label = None, None
    start = time.time()
    bert.eval()
    # Gradients are never needed here; disabling them avoids building a graph
    # for every validation batch.
    with torch.no_grad():
        for batch in _valload:
            (input_ids, age_ids, gender_ids, time_ids,
             posi_ids, segment_ids, attMask, masked_label) = [
                t.to(device) for t in batch
            ]

            loss, pred, label = bert(
                input_ids,
                age_ids,
                gender_ids,
                time_ids,
                posi_ids,
                segment_ids,
                attention_mask=attMask,
                masked_lm_labels=masked_label,
            )
            tr_loss += loss.item()

    cost = time.time() - start
    return tr_loss, cost, pred, label

In [None]:
def cal_acc(label, pred):
    """Micro-averaged precision of the arg-max predictions on labelled positions.

    Positions whose label equals -1 are ignored. The remaining rows of
    ``pred`` are passed through a log-softmax over dimension 1 and reduced
    to a class index per row by arg-max.

    Args:
        label: Tensor of target ids, with -1 marking positions to skip.
            NOTE(review): indexing uses only the first axis of the mask, so
            ``label`` is presumably 1-D — confirm against the model output.
        pred: Tensor of scores with one row per label entry.

    Returns:
        The value of ``skm.precision_score(..., average='micro')``.
    """
    label_np = label.cpu().numpy()
    keep = np.nonzero(label_np != -1)[0]

    kept_scores = pred.detach().cpu().numpy()[keep]
    kept_labels = label_np[keep]

    log_probs = torch.log_softmax(torch.tensor(kept_scores), dim=1)
    outs = list(log_probs.numpy().argmax(axis=1))
    return skm.precision_score(kept_labels, outs, average='micro')

In [None]:
def save_model(_model_dict, file_name):
    """Serialize ``_model_dict`` to ``file_name`` via ``torch.save``.

    Args:
        _model_dict: Object to serialize; the training loop passes a state dict.
        file_name: Destination handed straight to ``torch.save``.
    """
    torch.save(obj=_model_dict, f=file_name)

In [None]:
# Entry point: what runs is selected by train_params['action'].
#   'train'  -> train from the freshly initialised model
#   'resume' -> reload the saved weights, then continue training
#   'eval'   -> reload the saved weights, then score the test split
# NOTE(review): `DataLoader` here is not torch's loader (that one is spelled
# out as torch.utils.data.DataLoader below); presumably it is the dataset class
# exported by models.PreTrainingModel — confirm.
if train_params['action'] in ('resume', 'eval'):
    bert.load_state_dict(torch.load("bert_pretrain", map_location=train_params['device']))
    print("Loading Successful")

if train_params['action'] in ('train', 'resume'):
    TrainDset = DataLoader(train_data, vocab, max_len=train_params['max_len_seq'], code='code')
    trainload = torch.utils.data.DataLoader(dataset=TrainDset, batch_size=train_params['batch_size'], shuffle=True)
    ValDset = DataLoader(val_data, vocab, max_len=train_params['max_len_seq'], code='code')
    valload = torch.utils.data.DataLoader(dataset=ValDset, batch_size=train_params['batch_size'], shuffle=True)
    train_loss, val_loss = train(trainload, valload, train_params['device'])
elif train_params['action'] == 'eval':
    TestDset = DataLoader(test_data, vocab, max_len=train_params['max_len_seq'], code='code')
    # Use the configured batch size, like the train/validation loaders above.
    testload = torch.utils.data.DataLoader(dataset=TestDset, batch_size=train_params['batch_size'], shuffle=True)
    # eval() takes the target device as its second positional argument;
    # calling it with the loader alone raises TypeError.
    loss, cost, pred, label = eval(testload, train_params['device'])