In [None]:
import torch
import numpy as np

import torch.nn as nn
import time
import math

from torch.utils.data import Dataset
import pandas as pd
from models.FineTuningModel import  *

In [None]:
def _read_headerless(csv_path):
    """Read a CSV file that has no header row into a DataFrame."""
    return pd.read_csv(csv_path, header=None)


# All tables are read without a header row; later cells index every one of
# them with the same row indices.
src = _read_headerless("data/tokenized_visits.csv")
age_data = _read_headerless("data/tokenized_age.csv")
gender_data = _read_headerless("data/tokenized_gender.csv")
race_data = _read_headerless("data/tokenized_race.csv")
ethnicity_data = _read_headerless("data/tokenized_ethnicity.csv")
time_date = _read_headerless("data/tokenized_timestamps.csv")
mask_data = _read_headerless("data/tokenized_masks.csv")
target_data = _read_headerless("data/target.csv")

In [None]:
# Split sizes: 70% of the rows for training, 10% for validation and whatever
# is left over for testing.
n_samples = len(src)
train_l = int(n_samples * 0.70)
val_l = int(n_samples * 0.10)
test_l = n_samples - (train_l + val_l)

# Number of columns in the target table.
number_output = target_data.shape[1]

In [None]:
# Sequence length is the number of columns of the visit table.
global_params = dict(max_seq_len=src.shape[1])

# Learning rates of the two optimizers (discriminator + encoder, generator).
optim_param = dict(lr_discr=3e-5, lr_gen=3e-5)

train_params = dict(
    batch_size=32,
    use_cuda=True,
    max_len_seq=global_params['max_seq_len'],
    device="cuda" if torch.cuda.is_available() else "cpu",  # fall back to CPU when CUDA is unavailable
    data_len=len(target_data),
    train_data_len=train_l,
    val_data_len=val_l,
    test_data_len=test_l,
    epochs=35,
    action='train',    # the last cell dispatches on 'train', 'eval' or 'resume'
    alpha_unsup=0.35,  # weight of the unsupervised discriminator loss in the training loss
)

model_config = dict(
    vocab_size=708,                  # number of disease + symbols for word embedding
    hidden_size=288,                 # word embedding and seg embedding hidden size
    seg_vocab_size=2,                # number of vocab for seg embedding
    age_vocab_size=111,              # number of vocab for age embedding
    gender_vocab_size=2,
    ethnicity_vocab_size=2,
    race_vocab_size=6,
    num_labels=1,
    feature_dict=708,
    max_position_embedding=train_params['max_len_seq'],  # maximum number of tokens
    hidden_dropout_prob=0.2,         # dropout rate
    num_hidden_layers=6,             # number of multi-head attention layers required
    num_attention_heads=12,          # number of attention heads
    attention_probs_dropout_prob=0.2,  # multi-head attention dropout rate
    intermediate_size=512,           # the size of the "intermediate" layer in the transformer encoder
    hidden_act='gelu',               # activation in the encoder and the pooler: "gelu", 'relu', 'swish' are supported
    initializer_range=0.02,          # parameter weight initializer range
    number_output=number_output,
)

In [None]:
from sklearn.model_selection import KFold

# KFold does not shuffle by default, so random_state is left unset.
kf = KFold(n_splits=5, random_state=None)

k = 5          # 1-based number of the fold whose split is used
few_shots = 1  # fraction of the training pool that is kept (1 keeps all of it)

# Field name -> source table. Every table is indexed with the same row indices.
_split_tables = {
    "code": src,
    "age": age_data,
    "gender": gender_data,
    "ethnicity": ethnicity_data,
    "race": race_data,
    "time": time_date,
    "labels": target_data,
    "masks": mask_data,
}


def _take_rows(row_index):
    """Return {field: rows `row_index` of that field's table} as numpy arrays."""
    return {field: table.values[row_index] for field, table in _split_tables.items()}


# Only fold k is used, so select it directly instead of slicing every table
# for the folds that would be thrown away.
train_index, test_index = list(kf.split(src))[k - 1]

# The non-test rows are divided again: the first train_l become the training
# pool, the remainder is the validation set.
val_index = train_index[train_l:]
train_index = train_index[:train_l]

# Few-shot subsampling: draw round(train_l * few_shots) rows without replacement.
amount_few_shots = round(train_l * few_shots)
train_index = train_index[np.random.choice(len(train_index), size=amount_few_shots, replace=False)]

train_data = _take_rows(train_index)
val_data = _take_rows(val_index)
test_data = _take_rows(test_index)

In [None]:
noise_size = 100  # size of the random vector fed to the generator

# The generator output is concatenated with the encoder output along the batch
# axis during training and both go through the discriminator, so these widths
# must agree. Take the width from the model config instead of repeating the
# literal (assumes the encoder output width is model_config['hidden_size'] --
# TODO confirm against BertForEHR).
hidden_size = model_config['hidden_size']
hidden_levels_d = [288, 288]  # hidden layer sizes passed to the discriminator
hidden_levels_g = [288, 288]  # hidden layer sizes passed to the generator
out_dropout_rate = 0.2

conf = BertConfig(model_config)
bert = BertForEHR(conf)
generator = Generator(
    noise_size=noise_size,
    output_size=hidden_size,
    hidden_sizes=hidden_levels_g,
    dropout_rate=out_dropout_rate,
)
discriminator = Discr(
    input_size=hidden_size,
    hidden_sizes=hidden_levels_d,
    num_labels=number_output,
    dropout_rate=out_dropout_rate,
)

In [None]:
# Put all three networks on the configured device.
bert = bert.to(train_params['device'])
generator = generator.to(train_params['device'])
discriminator = discriminator.to(train_params['device'])

# Parameter groups: the encoder is optimized together with the discriminator,
# the generator has its own optimizer.
transformer_vars = list(bert.parameters())
d_vars = transformer_vars + list(discriminator.parameters())
g_vars = list(generator.parameters())

# Optimizers
optim_disc_bert = torch.optim.Adam(params=d_vars, lr=optim_param['lr_discr'])
optim_gen = torch.optim.AdamW(params=g_vars, lr=optim_param['lr_gen'])

# Loss functions: BCELoss works on probabilities; the logits variant is left
# unreduced so that individual elements can be masked before averaging.
bce_logits_loss = nn.BCEWithLogitsLoss(reduction='none')
bce_loss = nn.BCELoss()

In [None]:
def run_epoch(e, trainload, device):
    """Train the encoder, discriminator and generator for one pass over `trainload`.

    Uses the notebook-level `bert`, `discriminator`, `generator`,
    `optim_disc_bert`, `optim_gen`, `bce_loss`, `bce_logits_loss`,
    `noise_size` and `train_params`.

    Args:
        e: epoch number (not used inside the function).
        trainload: iterable of batches; each batch unpacks into eleven tensors
            (input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
            posi_ids, segment_ids, attMask, labels, masks).
        device: device every batch tensor is moved to.

    Returns:
        (tr_loss, cost): the sum over batches of generator loss + unsupervised
        discriminator loss + supervised discriminator loss, and the elapsed
        wall-clock seconds.
    """
    tr_loss = 0
    start = time.time()
    bert.train()
    discriminator.train()
    generator.train()
    for step, batch in enumerate(trainload):
        (input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
         posi_ids, segment_ids, attMask, labels, masks) = (t.to(device) for t in batch)
        labels = torch.squeeze(labels, 1)
        masks = torch.squeeze(masks, 1)

        # The encoder is the notebook-level `bert`; the previous name `behrt`
        # is not defined anywhere in the notebook.
        output_bert = bert(input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
                           posi_ids, segment_ids, attention_mask=attMask, labels=labels)
        batch_size = output_bert.shape[0]

        # One generated sample per real sample; real rows first, fake rows second.
        output_gen = generator(torch.randn(batch_size, noise_size).to(device))
        discr_input = torch.cat([output_bert, output_gen], dim=0)
        features, logits, probs = discriminator(discr_input)

        # Split the stacked discriminator outputs back into real / fake halves.
        D_real_features, D_fake_features = torch.split(features, batch_size)
        D_real_probs, D_fake_probs = torch.split(probs, batch_size)
        # Label logits: real half only, without the last column (that column
        # is the one used for real-vs-fake below).
        logits = torch.split(logits, batch_size)[0][:, 0:-1]

        real_target = torch.ones(batch_size).to(device)
        fake_target = torch.zeros(batch_size).to(device)

        # Unsupervised discriminator loss: real samples -> 1, generated -> 0.
        discr_loss_real = bce_loss(D_real_probs[:, -1], real_target)
        discr_loss_fake = bce_loss(D_fake_probs[:, -1], fake_target)
        discr_unsupervised_loss = (discr_loss_real + discr_loss_fake) / 2

        # Supervised loss: element-wise BCE, masked, averaged per output column
        # over its unmasked entries (0.001 avoids division by zero), then
        # averaged over the columns.
        masked_lm_loss = bce_logits_loss(logits, labels)
        masked_lm_loss = torch.mul(masked_lm_loss, masks)
        masked_lm_loss = torch.div(masked_lm_loss.sum(dim=0), masks.sum(dim=0) + 0.001)
        discr_supervised_loss = torch.div(torch.sum(masked_lm_loss), masks.shape[1])

        discr_loss = discr_supervised_loss + train_params['alpha_unsup'] * discr_unsupervised_loss

        # Generator loss: fool the discriminator (generated -> 1) plus a
        # feature-matching term between mean real and mean fake features.
        g_loss_d = bce_loss(D_fake_probs[:, -1], real_target)
        g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0), 2))
        g_loss = g_loss_d + g_feat_reg

        optim_gen.zero_grad()
        optim_disc_bert.zero_grad()

        # Both losses share one graph, so the first backward must keep it alive.
        g_loss.backward(retain_graph=True)
        discr_loss.backward()

        optim_gen.step()
        optim_disc_bert.step()

        loss = g_loss.item() + discr_unsupervised_loss.item() + discr_supervised_loss.item()
        tr_loss += loss

        if step % 500 == 0:
            # Print plain floats rather than tensor reprs.
            print("Generator Loss:", g_loss.item())
            print("Discr Supervised Loss:", discr_supervised_loss.item())
            print("Discr Unsupervised Loss:", discr_unsupervised_loss.item())

            print("TOTAL LOSS", loss)
    cost = time.time() - start
    return tr_loss, cost

In [None]:
def eval(_valload, eval, device):
    """Compute the masked supervised loss of encoder + discriminator on `_valload`.

    Uses the notebook-level `bert`, `discriminator`, `generator` and
    `bce_logits_loss`. Runs under `torch.no_grad()`, so no autograd graph is
    built during evaluation.

    Args:
        _valload: iterable of batches; each batch unpacks into eleven tensors
            (input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
            posi_ids, segment_ids, attMask, labels, masks).
        eval: when truthy, preds.csv, labels.csv and masks.csv are truncated
            and each batch's label logits, labels and masks are appended.
        device: device every batch tensor is moved to.

    Returns:
        (tr_loss, cost, logits, labels, masks, tr_d_sup): the summed per-batch
        loss, the elapsed seconds, the label logits / labels / masks of the
        last batch (None if there were no batches) and the summed supervised
        loss (the same quantity as tr_loss).
    """
    # The parameter shadows this function's own name; alias it for readability.
    write_outputs = eval

    bert.eval()
    discriminator.eval()
    generator.eval()
    tr_loss = 0
    tr_d_sup = 0
    start = time.time()
    # Defined up front so an empty loader does not raise at the return statement.
    logits = labels = masks = None

    if write_outputs:
        # Start from empty files; batches are appended below.
        for path in ("preds.csv", "labels.csv", "masks.csv"):
            with open(path, 'w') as f:
                f.write('')

    with torch.no_grad():
        for batch in _valload:
            (input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
             posi_ids, segment_ids, attMask, labels, masks) = (t.to(device) for t in batch)
            labels = torch.squeeze(labels, 1)
            masks = torch.squeeze(masks, 1)

            output_bert = bert(input_ids, age_ids, gender_ids, ethnicity_ids, race_ids, time_ids,
                               posi_ids, segment_ids, attention_mask=attMask, labels=labels)

            features, logits, probs = discriminator(output_bert)

            # Only real samples are scored here; drop the last column, which
            # the training step uses for the real-vs-fake decision.
            logits = logits[:, 0:-1]

            # Masked BCE averaged over all unmasked entries of the batch
            # (0.001 avoids division by zero).
            masked_lm_loss = bce_logits_loss(logits, labels)
            masked_lm_loss = torch.mul(masked_lm_loss, masks)
            discr_supervised_loss = torch.div(masked_lm_loss.sum(), masks.sum() + 0.001)

            discr_loss = discr_supervised_loss

            tr_loss += discr_loss.item()
            tr_d_sup += discr_supervised_loss.item()

            if write_outputs:
                with open("preds.csv", 'a') as f:
                    pd.DataFrame(logits.detach().cpu().numpy()).to_csv(f, header=False)
                with open("labels.csv", 'a') as f:
                    pd.DataFrame(labels.detach().cpu().numpy()).to_csv(f, header=False)
                with open("masks.csv", 'a') as f:
                    pd.DataFrame(masks.detach().cpu().numpy()).to_csv(f, header=False)

    print("Discr Supervised Loss:", tr_d_sup)

    cost = time.time() - start
    return tr_loss, cost, logits, labels, masks, tr_d_sup

In [None]:
def train(trainload, valload, device):
    """Run the full training loop and checkpoint the best models.

    For each of `train_params["epochs"]` epochs this runs `run_epoch` on
    `trainload` and `eval` on `valload`, logs the per-batch average losses to
    stdout and to log_train.txt, and saves the state dicts of `bert`,
    `generator` and `discriminator` whenever the summed validation supervised
    loss is lower than the best seen so far.

    Args:
        trainload: training batch loader (must support `len`).
        valload: validation batch loader (must support `len`).
        device: device passed through to `run_epoch` and `eval`.

    Returns:
        (train_loss, val_loss): the per-batch average losses of the last
        epoch, or (None, None) when no epoch was run.
    """
    with open("log_train.txt", 'w') as f:
        f.write('')

    # Average over the number of batches the loaders actually yield rather
    # than re-deriving it from global split sizes; guard against empty loaders.
    n_train_batches = max(len(trainload), 1)
    n_val_batches = max(len(valload), 1)

    best_val = math.inf
    train_loss = val_loss = None
    for e in range(train_params["epochs"]):
        print("Epoch n" + str(e))
        train_loss, train_time_cost = run_epoch(e, trainload, device)
        val_loss, val_time_cost, pred, label, mask, discr_loss = eval(valload, False, device)
        train_loss = train_loss / n_train_batches
        val_loss = val_loss / n_val_batches
        print('TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
        with open("log_train.txt", 'a') as f:
            f.write("Epoch n" + str(e) + '\n TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
            f.write('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost) + '\n\n\n')
        print('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost))

        # Model selection uses the summed (not averaged) validation loss.
        if discr_loss < best_val:
            print("** ** * Saving fine - tuned model ** ** * ")
            for model, file_name in ((bert, 'bert'), (generator, 'generator'), (discriminator, 'discriminator')):
                # Unwrap models that expose the real network as `.module`.
                model_to_save = model.module if hasattr(model, 'module') else model
                save_model(model_to_save.state_dict(), file_name)
            best_val = discr_loss
    return train_loss, val_loss

In [None]:
def save_model(_model_dict, file_name):
    """Serialize `_model_dict` to the path `file_name` using `torch.save`."""
    torch.save(obj=_model_dict, f=file_name)

In [None]:
# Warm-start the encoder from the pre-training checkpoint.
model_dict = bert.state_dict()
pretrained_dict = torch.load("bert_pretrain", map_location=train_params['device'])

# Keep only the checkpoint entries whose names also exist in the current
# model; everything else in the checkpoint is ignored.
pretrained_dict = {
    key: pretrained_dict[key]
    for key in pretrained_dict
    if key in model_dict
}

# Overlay the kept entries on the model's own state and load the result, so
# parameters missing from the checkpoint keep their current values.
model_dict.update(pretrained_dict)
bert.load_state_dict(model_dict)

In [None]:
# Evaluating or resuming starts from the checkpoints written by train().
if train_params['action'] in ('eval', 'resume'):
    for _network, _checkpoint in ((bert, "bert"), (generator, "generator"), (discriminator, "discriminator")):
        _network.load_state_dict(torch.load(_checkpoint, map_location=train_params['device']))
    print("Loading succesfull")

if train_params['action'] in ('train', 'resume'):
    # Training and validation batches are both drawn in shuffled order.
    TrainDset = DataLoader(train_data, max_len=train_params['max_len_seq'], code='code')
    trainload = torch.utils.data.DataLoader(TrainDset, batch_size=train_params['batch_size'], shuffle=True)
    ValDset = DataLoader(val_data, max_len=train_params['max_len_seq'], code='code')
    valload = torch.utils.data.DataLoader(ValDset, batch_size=train_params['batch_size'], shuffle=True)
    train_loss, val_loss = train(trainload, valload, train_params['device'])

elif train_params['action'] == 'eval':
    # Test batches keep their order; eval() writes them to the CSV files.
    TestDset = DataLoader(test_data, max_len=train_params['max_len_seq'], code='code')
    testload = torch.utils.data.DataLoader(TestDset, batch_size=train_params['batch_size'], shuffle=False)
    loss, cost, pred, label, mask, discr_loss = eval(testload, True, train_params['device'])