### Task 3: Next 6 or 12 months prediction

#### Fine-tuning

BERT generates embeddings for the vocabulary words in the corpus: each word (in our case, each medical diagnosis code) is mapped to a vector. Unlike Word2Vec, which assigns every word a single fixed vector, BERT produces contextual embeddings that depend on the surrounding words, and therefore gives a richer representation of each word.

The Masked Language Model (MLM) objective was used for pre-training, which we already performed in Task 2.

In this step, we add a task that predicts the diagnosis codes 6 or 12 months after a randomly selected visit
date. The model is trained on this new task, and the pretrained word embeddings are fine-tuned at the same time.

`global_params['next_x_months']` defines which prediction is performed:
- 0 : Next visit
- 6 : Next 6 months
- 12 : Next 12 months.

The age must be in months.

In [1]:
# True: use the absolute Windows paths on the author's workstation.
# False (e.g. Colab): download code/data into the working directory first.
local_mode: bool = False

In [2]:
if not local_mode:
  !mkdir commons
  !wget -P commons https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/commons/utils.py
  !wget -P commons https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/commons/__init__.py

  !mkdir models
  !wget -P models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/models/MLM.py
  !wget -P models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/models/NextXVisit.py
  !wget -P models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/models/optimizer.py
  !wget -P models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/models/BertConfig.py
  !wget -P models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/models/__init__.py

  !mkdir data
  !wget -P data https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/data/ages.pkl
  !wget -P data https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/data/concept.pkl
  !wget -P data https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/data/condition_codes.pkl
  !wget -P data https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/data/conditions.pkl

  !mkdir saved_models
  !wget -P saved_models https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/saved_models/mlm128.pt

  !mkdir images
  !wget -P images https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/images/cdm54.png
  !wget -P images https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/images/behrt_embeddings.png
  !wget -P images https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/images/behrt_model.png


--2023-05-07 22:27:50--  https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/commons/utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12014 (12K) [text/plain]
Saving to: ‘commons/utils.py’


2023-05-07 22:27:50 (88.8 MB/s) - ‘commons/utils.py’ saved [12014/12014]

--2023-05-07 22:27:50--  https://raw.githubusercontent.com/bdigafe/cs598_dlh_final/master/behrt/commons/__init__.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 0 [text/plain]
Saving to: ‘commons/__init__.py’

__init__.py             [ <=>  

In [3]:
if not local_mode:
  %pip install pytorch_pretrained_bert 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_pretrained_bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.8/123.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting boto3
  Downloading boto3-1.26.129-py3-none-any.whl (135 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.6/135.6 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Collecting s3transfer<0.7.0,>=0.6.0
  Downloading s3transfer-0.6.1-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botocore<1.30.0,>=1.29.129
  Downloading botocore-1.29.129-py3-none-any.whl (10.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.7/10.7 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmespath<2.0.0,>=0.7.1
  Down

### File and model parameters

In [4]:
import os
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import commons.utils as utils
from models import optimizer
from torch.utils.data import DataLoader
import pytorch_pretrained_bert as Bert

# Seed every random number generator the notebook uses so runs are reproducible.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Deterministic cuBLAS kernels need a fixed workspace size. The variable must be
# set before the first cuBLAS call, so set it together with the determinism switch.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# The variable Python reads is PYTHONHASHSEED. Setting it here cannot change the
# hash seed of this already-running interpreter; it only affects Python
# subprocesses started afterwards (e.g. DataLoader workers using "spawn").
os.environ['PYTHONHASHSEED'] = str(seed)

global_params = {
    'max_seq_len': 256,
    'max_age' : 110,
    'age_month' : 1,        # forwarded to utils.age_vocab (presumably: ages in months) — confirm
    'batch_size': 128,
    'num_epochs': 20,
    'min_visit': 5,
    'next_x_months' : 12,   # prediction horizon: 0 = next visit, 6 or 12 months

    'gradient_accumulation_steps': 1,
    'training_sample' : 0,  # forwarded as sample_size to utils.load_data
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

optim_config = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

# Root of the data/model folders: absolute paths on the author's machine,
# otherwise paths relative to the working directory (filled by the download cell).
project_root = 'C:/Birhanu/Education/UIL/cs598/Final/' if local_mode else ''
data_dir = project_root + 'data/'
model_dir = project_root + 'saved_models/'

file_config = {
    'vocab': data_dir + 'condition_codes.pkl',
    'data': data_dir + 'conditions.pkl',
    'ages': data_dir + 'ages.pkl',

    'model_path': model_dir,
    'model_name': 'nextxm-model',  # model name

    # MLM pretrained model path
    'pretrainModel': model_dir + 'mlm128.pt',

    # log file
    'log_file_name': 'nextxm-log',  # log path
}



In [5]:
# Fixed cuBLAS workspace so deterministic algorithms are allowed on the GPU
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")

# Report the device ('cuda' or 'cpu') selected in global_params
print(global_params["device"])

cuda


#### Load data

In [6]:
# Condition-code vocabulary (the model's "words") and its token -> index map
print(f"Conditions vocab file: {file_config['vocab']}")
vocab_conditions = utils.get_codes_vocab(file_config["vocab"])
code2idx = vocab_conditions["token2idx"]

# Age vocabulary; element 0 is the age -> index map
vocab_age = utils.age_vocab(global_params["max_age"], global_params["age_month"])
age2idx = vocab_age[0]

# Per-patient condition sequences and the matching age sequences
data_conditions, data_age_seqs = (
    utils.load_data(file_config[key], sample_size=global_params["training_sample"])
    for key in ("data", "ages")
)

# 80/20 train/test split (no separate validation set); num_months is the
# prediction horizon taken from next_x_months
train_data, test_data = utils.split_data(
    data_conditions,
    data_age_seqs,
    train_ratio=0.8,
    min_size=4,
    num_months=global_params["next_x_months"],
)

Conditions vocab file: data/condition_codes.pkl


In [7]:
# Sanity check: split sizes, then the first training patient's codes and ages
print(f"Train data size: {len(train_data)}", f"Test data size: {len(test_data)}", sep="\n")
print(*(train_data[column].iloc[0] for column in ("visit", "age")), sep="\n")

Train data size: 11268
Test data size: 1889
['35208190', 'SEP', '35208481', 'SEP', '1567866', 'SEP', '1569566', 'SEP', '1569558', '1569566', 'SEP', '1569566', 'SEP', '1569739', 'SEP', '1567750', '1569558', 'SEP', '1569558', 'SEP', '1569566', 'SEP']
['779', '779', '784', '784', '795', '795', '796', '796', '797', '797', '797', '797', '797', '797', '797', '798', '798', '798', '798', '798', '799', '799']


In [8]:
# BERT hyper-parameters for the fine-tuning model
model_config = dict(
    vocab_size=len(vocab_conditions["token2idx"]),  # diseases + special symbols
    hidden_size=288,                # width of every embedding and hidden state
    seg_vocab_size=2,               # number of segment ids
    age_vocab_size=len(vocab_age[0]),
    max_position_embedding=global_params["max_seq_len"],  # longest token sequence
    hidden_dropout_prob=0.2,
    num_hidden_layers=6,            # transformer encoder layers
    num_attention_heads=12,         # 288 / 12 = 24 dims per head
    attention_probs_dropout_prob=0.22,
    intermediate_size=512,          # feed-forward width inside each layer
    hidden_act="gelu",              # 'gelu', 'relu' or 'swish'
    initializer_range=0.02,         # std of the weight initialiser
)

# Which extra embeddings are added to the word embedding
feature_dict = dict(age=True, seg=True, posi=True)


In [9]:
class NextVisit(Dataset):
    """Dataset of (patient history, diagnoses to predict) pairs.

    Each item is built from one row of ``dataframe``, which must have the
    columns ``visit`` (code tokens with 'SEP' between visits), ``age`` (age
    tokens aligned with ``visit``), ``label`` (codes to predict) and ``pid``
    (patient id, convertible to int).

    Args:
        token2idx: vocabulary for the input codes; must contain 'PAD'.
        diag2idx: vocabulary for the label codes.
        age2idx: vocabulary for the age tokens.
        dataframe: pandas DataFrame with the columns above.
        max_len: every sequence is truncated/padded to this length.
        max_age, min_visit: accepted for backward compatibility but unused;
            no filtering by number of visits happens here.
    """

    def __init__(self, token2idx, diag2idx, age2idx, dataframe, max_len, max_age=110, min_visit=5):
        self.vocab = token2idx
        self.label_vocab = diag2idx
        self.max_len = max_len

        self.code = dataframe.visit
        self.age = dataframe.age
        self.label = dataframe.label
        self.patid = dataframe.pid

        self.age2idx = age2idx

    def __getitem__(self, index):
        """Return the tensors for the row at *position* ``index``.

        Returns:
            tuple of LongTensors: age, code, position, segment, mask, label
            (each padded to ``max_len``) and a 1-element patient-id tensor.
        """
        # Positional access: a split DataFrame need not have a 0..n-1 index,
        # and label-based lookup would raise KeyError or pick the wrong row.
        age = self.age.iloc[index]
        code = self.code.iloc[index]
        label = self.label.iloc[index]
        patid = self.patid.iloc[index]

        # Keep the most recent max_len - 1 tokens, leaving room for 'CLS'.
        age = age[(-self.max_len + 1):]
        code = code[(-self.max_len + 1):]

        if code[0] != 'SEP':
            # Prepend 'CLS', repeating the first age for it.
            code = np.append(np.array(['CLS']), code)
            age = np.append(np.array(age[0]), age)
        else:
            # The cut starts on a visit separator: turn it into 'CLS'. Work on
            # a copy so the stored sequence is never modified (slicing a numpy
            # array returns a view of the original data).
            code = list(code)
            code[0] = 'CLS'

        # Attention mask: 1 for real tokens, 0 for padding.
        mask = np.ones(self.max_len)
        mask[len(code):] = 0

        # Pad the age sequence.
        age = utils.seq_padding(age, self.max_len, token2idx=self.age2idx)

        tokens, code = utils.code2index(code, self.vocab)
        _, label = utils.code2index(label, self.label_vocab)

        # Position and segment ids are derived from the padded tokens.
        tokens = utils.seq_padding(tokens, self.max_len)
        position = utils.position_idx(tokens)
        segment = utils.index_seg(tokens)

        # Pad codes with PAD and labels with -1.
        code = utils.seq_padding(code, self.max_len, symbol=self.vocab['PAD'])
        label = utils.seq_padding(label, self.max_len, symbol=-1)

        return torch.LongTensor(age), \
               torch.LongTensor(code), \
               torch.LongTensor(position), \
               torch.LongTensor(segment), \
               torch.LongTensor(mask), \
               torch.LongTensor(label), \
               torch.LongTensor([int(patid)])

    def __len__(self):
        return len(self.code)


In [10]:
# Test batch data
from torch.utils.data import DataLoader

def inspect_batch(dataframe=None):
    """Print the tensors NextVisit produces for the first sample of a data frame.

    Debug helper. Builds a NextVisit dataset over ``dataframe`` (``train_data``
    when omitted), loads one un-shuffled sample onto the configured device and
    prints its inputs, mask, labels and patient id.
    """
    if dataframe is None:
        dataframe = train_data

    dataset = NextVisit(token2idx=code2idx, diag2idx=code2idx,
                        age2idx=age2idx, dataframe=dataframe,
                        max_len=global_params['max_seq_len'])
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
    batch = next(iter(loader))

    # NextVisit yields 7 tensors; the last one is the patient id.
    age_ids, input_ids, posi_ids, segment_ids, attMask, label, patid = (
        t.to(global_params['device']) for t in batch
    )

    # Token codes
    print(f"Input:\n {input_ids}\n")
    print(f"Age:\n{age_ids}\n")
    print(f"Positions:\n{posi_ids}\n")
    print(f"Segments:\n{segment_ids}\n")

    print(f"attMask: {attMask}")
    print(f"Labels: {label}")
    print(f"Patient id: {patid}")

#inspect_batch()
     

In [11]:
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration built from a plain dict, plus the EHR vocab sizes.

    Args:
        config: dict of hyper-parameters such as ``model_config``. ``hidden_size``
            is required (KeyError otherwise); any other missing key becomes None.
            The extra ``seg_vocab_size`` / ``age_vocab_size`` entries are stored
            as attributes for the segment and age embeddings.
    """

    def __init__(self, config):
        # Keys whose name is identical in the dict and in the base-class signature.
        same_name_keys = (
            'num_hidden_layers',
            'num_attention_heads',
            'intermediate_size',
            'hidden_act',
            'hidden_dropout_prob',
            'attention_probs_dropout_prob',
        )
        vocab_size = config.get('vocab_size')
        hidden_size = config['hidden_size']
        base_kwargs = {key: config.get(key) for key in same_name_keys}
        base_kwargs['max_position_embeddings'] = config.get('max_position_embedding')
        base_kwargs['initializer_range'] = config.get('initializer_range')

        super().__init__(
            vocab_size_or_config_json_file=vocab_size,
            hidden_size=hidden_size,
            **base_kwargs,
        )
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')


class BertEmbeddings(nn.Module):
    """Build the input embeddings from word, segment, age and position ids.

    The word embedding is always used. Age, segment and position embeddings are
    added when the matching ``feature_dict`` flag ('age', 'seg', 'posi') is true.
    The sum is layer-normalised and passed through dropout.
    """

    def __init__(self, config, feature_dict):
        super(BertEmbeddings, self).__init__()
        self.feature_dict = feature_dict

        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        self.segment_embeddings = nn.Embedding(
            config.seg_vocab_size, config.hidden_size)
        self.age_embeddings = nn.Embedding(
            config.age_vocab_size, config.hidden_size)
        # Fixed sinusoidal position table. from_pretrained is a classmethod, so
        # call it on the class instead of on a throw-away random instance;
        # freeze=True keeps the table out of training.
        # NOTE: BertPreTrainedModel.init_bert_weights re-draws the weights of
        # every nn.Embedding, so an owning model that applies it must copy this
        # table back afterwards.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(
                config.max_position_embeddings, config.hidden_size),
            freeze=True)

        self.LayerNorm = Bert.modeling.BertLayerNorm(
            config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):
        """Return the embedded sequence for ``word_ids``.

        Missing id tensors default to zeros shaped like ``word_ids``. ``age`` is
        unused (kept for backward compatibility); the ``feature_dict`` flags
        decide which embeddings are summed.
        """
        if seg_ids is None:
            seg_ids = torch.zeros_like(word_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(word_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)

        word_embed = self.word_embeddings(word_ids)
        segment_embed = self.segment_embeddings(seg_ids)
        age_embed = self.age_embeddings(age_ids)
        posi_embeddings = self.posi_embeddings(posi_ids)

        embeddings = word_embed

        if self.feature_dict['age']:
            embeddings = embeddings + age_embed
        if self.feature_dict['seg']:
            embeddings = embeddings + segment_embed
        if self.feature_dict['posi']:
            embeddings = embeddings + posi_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Return a ``(max_position_embedding, hidden_size)`` float32 sinusoidal table.

        Entry ``[pos, i]`` is ``sin(pos / 10000 ** (2 * i / hidden_size))`` for
        even ``i`` and the cosine of the same angle for odd ``i``. Computed with
        vectorised numpy instead of a Python double loop.
        """
        positions = np.arange(max_position_embedding, dtype=np.float64)[:, None]
        dims = np.arange(hidden_size)[None, :]
        angles = positions / (10000 ** (2 * dims / hidden_size))

        lookup_table = np.empty((max_position_embedding, hidden_size), dtype=np.float32)
        lookup_table[:, 0::2] = np.sin(angles[:, 0::2])
        lookup_table[:, 1::2] = np.cos(angles[:, 1::2])
        return torch.tensor(lookup_table)


class BertModel(Bert.modeling.BertPreTrainedModel):
    """BERT encoder over the EHR embeddings (word + optional age/segment/position).

    Args:
        config: BertConfig that also carries ``seg_vocab_size`` and ``age_vocab_size``.
        feature_dict: flags selecting which embeddings are summed (see BertEmbeddings).
    """

    def __init__(self, config, feature_dict):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config, feature_dict)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)
        # init_bert_weights re-draws the weights of every nn.Embedding, which also
        # overwrites the fixed sinusoidal position table. That table is frozen,
        # so training would never recover it; copy it back.
        self._restore_posi_embedding(config)

    def _restore_posi_embedding(self, config):
        """Copy the sinusoidal table back into ``embeddings.posi_embeddings``."""
        table = self.embeddings._init_posi_embedding(
            config.max_position_embeddings, config.hidden_size)
        with torch.no_grad():
            self.embeddings.posi_embeddings.weight.copy_(table)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Encode a batch of token ids.

        Missing ``age_ids``/``seg_ids``/``posi_ids`` default to zeros. A missing
        ``attention_mask`` defaults to ones (attend to every position).

        Returns:
            ``(encoded_layers, pooled_output)``: ``encoded_layers`` is the
            encoder output (only the last layer when ``output_all_encoded_layers``
            is False); ``pooled_output`` is the pooler applied to the last layer.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(input_ids)
        if seg_ids is None:
            seg_ids = torch.zeros_like(input_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(input_ids)
        # Broadcastable additive mask: 0 where attended, -10000 on padding.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(
            input_ids, age_ids, seg_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
    """BERT encoder plus a linear head producing one logit per label (multi-label).

    Args:
        config: BertConfig that also carries the EHR vocabulary sizes.
        num_labels: number of output labels (width of the multi-hot target).
        feature_dict: embedding flags forwarded to BertModel.
    """

    def __init__(self, config, num_labels, feature_dict):

        super(BertForMultiLabelPrediction, self).__init__(config)

        self.num_labels = num_labels
        self.bert = BertModel(config, feature_dict)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)
        # init_bert_weights re-draws the weights of every nn.Embedding, including
        # the frozen sinusoidal position table inside self.bert; copy it back.
        embeddings = self.bert.embeddings
        with torch.no_grad():
            embeddings.posi_embeddings.weight.copy_(
                embeddings._init_posi_embedding(
                    config.max_position_embeddings, config.hidden_size))

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
        """Return ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.

        ``labels`` is a float multi-hot tensor viewed as ``(-1, num_labels)``;
        the loss is ``nn.MultiLabelSoftMarginLoss`` over the logits.
        """
        _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = nn.MultiLabelSoftMarginLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits


### Load Data

In [12]:
# Training set: reshuffled every epoch, loaded by 2 worker processes.
# Labels share the input code vocabulary.
Dset = NextVisit(
    token2idx=code2idx,
    diag2idx=code2idx,
    age2idx=age2idx,
    dataframe=train_data,
    max_len=global_params['max_seq_len'],
)
trainload = DataLoader(Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=2)


In [13]:
# Test set: fixed order (no shuffling), loaded by 3 worker processes.
Dset = NextVisit(
    token2idx=code2idx,
    diag2idx=code2idx,
    age2idx=age2idx,
    dataframe=test_data,
    max_len=global_params['max_seq_len'],
)
testload = DataLoader(Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)



### Setup Model

In [14]:
# Vocabulary sizes that feed vocab_size / age_vocab_size of the model
print(
    f"Size of conditions vocab: {len(vocab_conditions['token2idx'])}",
    f"Size of Age vocab: {len(vocab_age[0])}",
    sep="\n",
)


Size of conditions vocab: 301
Size of Age vocab: 1322


In [15]:
# Multi-label classifier with one output per entry of the conditions vocabulary
conf = BertConfig(model_config)
model = BertForMultiLabelPrediction(
    conf,
    num_labels=len(vocab_conditions["token2idx"]),
    feature_dict=feature_dict,
)


In [16]:
# Show the module tree (Jupyter renders the cell's last expression)
model


BertForMultiLabelPrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(301, 288)
      (segment_embeddings): Embedding(2, 288)
      (age_embeddings): Embedding(1322, 288)
      (posi_embeddings): Embedding(256, 288)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=288, out_features=288, bias=True)
              (key): Linear(in_features=288, out_features=288, bias=True)
              (value): Linear(in_features=288, out_features=288, bias=True)
              (dropout): Dropout(p=0.22, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=288, out_features=288, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout

In [None]:
# Initialise from the MLM model pretrained in Task 2: copy every tensor whose
# name AND shape match this model. Anything else keeps its fresh initialisation
# (presumably the new classifier head).
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# the model is moved to global_params['device'] in the next cell.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
pretrained_dict = torch.load(file_config['pretrainModel'], map_location='cpu')
model_dict = model.state_dict()

# 1. keep only entries this model has, with a compatible shape
#    (a shape mismatch would make load_state_dict raise)
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
print(f"Loaded {len(pretrained_dict)} of {len(model_dict)} tensors from {file_config['pretrainModel']}")

# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)

# 3. load the new state dict
model.load_state_dict(model_dict)

In [None]:
# Move the parameters to the training device, then build the project's Adam
# optimizer from the (name, parameter) pairs.
model = model.to(global_params['device'])
optim = optimizer.adam(params=[*model.named_parameters()], config=optim_config)


### Evaluation Metrics

In [None]:
from sklearn.metrics._plot.roc_curve import RocCurveDisplay
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score

def precision(logits, label):
    """Sample-averaged average precision of sigmoid(logits) against multi-hot labels.

    Returns:
        (average precision, sigmoid probabilities on CPU, labels on CPU)
    """
    probs = torch.sigmoid(logits).detach().cpu()
    label = label.cpu()
    score = average_precision_score(label.numpy(), probs.numpy(), average='samples')
    return score, probs, label

def precision_test(logits, label):
    """Sample-averaged average precision and ROC AUC over a whole evaluation set.

    ``logits`` and ``label`` must be CPU tensors that do not require grad
    (``.numpy()`` is called on them directly).

    Returns:
        (average precision, ROC AUC, sigmoid probabilities, label)
    """
    probs = torch.sigmoid(logits)
    y_true, y_score = label.numpy(), probs.numpy()
    return (
        average_precision_score(y_true, y_score, average='samples'),
        roc_auc_score(y_true, y_score, average='samples'),
        probs,
        label,
    )

def auroc_test(logits, label):
    """Sample-averaged ROC AUC of sigmoid(logits) against multi-hot ``label``.

    Tensors on any device, with or without grad, are accepted; they are
    detached and copied to the CPU first. The average precision that used to be
    computed here was never returned, so that work has been dropped.
    """
    probs = torch.sigmoid(logits).detach().cpu()
    label = label.cpu()
    return roc_auc_score(label.numpy(), probs.numpy(), average='samples')


### Multi-hot Label Encoder


In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

# Multi-hot encoder over every index of the conditions vocabulary. Columns
# follow the order of the vocabulary's index values; fitting on one
# single-label sample per class just registers all of them.
indexes = list(vocab_conditions["token2idx"].values())
mlb = MultiLabelBinarizer(classes=indexes[:])
mlb.fit([[class_index] for class_index in indexes[:]])


### Train and Test

In [None]:
def train(e):
    """Run one training epoch over ``trainload``; return the mean batch loss.

    Every 100 steps it prints the mean loss since the previous progress line and
    the sample-averaged precision of the current batch.

    Args:
        e: epoch number, used only for logging.
    """
    model.train()
    device = global_params['device']
    accum_steps = global_params['gradient_accumulation_steps']

    tr_loss = 0.0
    temp_loss, temp_steps = 0.0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(trainload):
        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch
        # Label ids -> multi-hot float targets (the -1 label padding is not a
        # class of the binarizer, so it is ignored).
        targets = torch.tensor(mlb.transform(
            targets.numpy()), dtype=torch.float32)

        age_ids, input_ids, posi_ids, segment_ids, attMask, targets = (
            t.to(device)
            for t in (age_ids, input_ids, posi_ids, segment_ids, attMask, targets)
        )

        loss, logits = model(input_ids, age_ids, segment_ids,
                             posi_ids, attention_mask=attMask, labels=targets)

        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()

        temp_loss += loss.item()
        temp_steps += 1
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 100 == 0:
            prec, _, _ = precision(logits, targets)
            # Average over the steps actually accumulated since the last print
            # (1 at step 0, 100 afterwards) rather than a fixed 100.
            print("epoch: {}\t| Cnt: {}\t| Samples: {}, Loss: {}\t| precision: {}".format(
                e, nb_tr_steps, nb_tr_examples, temp_loss / temp_steps, prec))
            temp_loss, temp_steps = 0.0, 0

        if (step + 1) % accum_steps == 0:
            optim.step()
            optim.zero_grad()

    # Apply gradients left over when the batch count is not a multiple of accum_steps.
    if nb_tr_steps % accum_steps != 0:
        optim.step()
        optim.zero_grad()

    return tr_loss / nb_tr_steps if nb_tr_steps else 0.0


def evaluation():
    """Score ``model`` on ``testload``.

    Returns:
        (sample-averaged average precision, sample-averaged ROC AUC,
         sum of the per-batch losses)
    """
    model.eval()
    device = global_params['device']

    all_logits, all_targets = [], []
    loss_sum = 0
    with torch.no_grad():
        for batch in testload:
            age_ids, input_ids, posi_ids, segment_ids, attMask, label_ids, _ = batch
            targets = torch.tensor(mlb.transform(label_ids.numpy()), dtype=torch.float32).to(device)

            loss, logits = model(
                input_ids.to(device),
                age_ids.to(device),
                segment_ids.to(device),
                posi_ids.to(device),
                attention_mask=attMask.to(device),
                labels=targets,
            )

            all_targets.append(targets.cpu())
            all_logits.append(logits.cpu())
            loss_sum += loss.item()

    aps, roc, _, _ = precision_test(torch.cat(all_logits, dim=0), torch.cat(all_targets, dim=0))
    return aps, roc, loss_sum


In [None]:
import warnings
import time

# Silence library warnings to keep the training log readable.
# NOTE(review): this hides *all* warnings for the rest of the session.
warnings.filterwarnings(action='ignore')

# Fresh optimizer for fine-tuning, with the same learning rate and warm-up.
optim_config = {
    'lr': optim_config['lr'],
    'warmup_proportion': 0.1
}
optim = optimizer.adam(params=list(
    model.named_parameters()), config=optim_config)

# Best test scores so far; a checkpoint is written whenever the average
# precision improves. Starting from 0 guarantees that at least the first
# epoch's model is saved; a hard-coded threshold (previously 0.512) could skip
# saving entirely.
best_pre = 0.0
best_roc = 0
output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])

for e in range(global_params['num_epochs']):
    print(f"Epoch: {e} training started...")
    start = time.time()
    train(e)

    aps, roc, test_loss = evaluation()
    if aps > best_pre:
        # Save a trained model
        print("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        # Only create the folder and save when a model path is configured.
        if file_config['model_path']:
            utils.create_folder(file_config['model_path'])
            torch.save(model_to_save.state_dict(), output_model_file)
        best_pre = aps
        best_roc = roc

    print('precision : {}, auroc: {},'.format(aps, roc))
    print(f"Epoch: {e} completed. \t time: {time.time()-start}")

print(f"Best pre={best_pre}, auroc: {best_roc}")