In [1]:
import torch.nn as nn
import torch.nn.functional as F
import pytorch_pretrained_bert as Bert
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import import_ipynb
from cvd_utils import *

importing Jupyter notebook from cvd_utils.ipynb


In [2]:
class BertEmbeddings(nn.Module):
    """Sum of word, segment, age, gender, ethnicity, race and position embeddings.

    The summed embedding is passed through layer normalisation and dropout.
    """

    def __init__(self, config):
        """Create one lookup table per input stream.

        Reads from ``config``: ``vocab_size``, ``seg_vocab_size``,
        ``age_vocab_size``, ``gender_vocab_size``, ``ethnicity_vocab_size``,
        ``race_vocab_size``, ``max_position_embeddings``, ``hidden_size`` and
        ``hidden_dropout_prob``.
        """
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
        self.gender_embeddings = nn.Embedding(config.gender_vocab_size, config.hidden_size)
        self.ethnicity_embeddings = nn.Embedding(config.ethnicity_vocab_size, config.hidden_size)
        self.race_embeddings = nn.Embedding(config.race_vocab_size, config.hidden_size)
        # `from_pretrained` is a classmethod, so call it on the class instead of
        # first allocating (and randomly initialising) an instance that is
        # thrown away immediately. The resulting table is frozen by default.
        # NOTE(review): a later `module.apply(init_bert_weights)` by an owning
        # model may re-initialise every nn.Embedding, including this table --
        # confirm against the pytorch_pretrained_bert implementation.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, gender_ids=None, ethni_ids=None, race_ids=None, seg_ids=None, posi_ids=None, age=True):
        """Embed every id stream, sum them, then apply LayerNorm and dropout.

        Args:
            word_ids: integer index tensor for the word/code vocabulary.
            age_ids, gender_ids, ethni_ids, race_ids, seg_ids, posi_ids:
                optional integer index tensors; any stream left as ``None`` is
                replaced by an all-zero tensor shaped like ``word_ids``.
                The embedded streams are added element-wise, so they must be
                broadcast-compatible with ``word_ids``.
            age: when False the age embedding is left out of the sum.

        Returns:
            The normalised, dropout-regularised sum of the embeddings.
        """
        # Every optional stream falls back to index 0. Previously gender,
        # ethnicity and race were advertised as optional but crashed on None.
        if seg_ids is None:
            seg_ids = torch.zeros_like(word_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(word_ids)
        if gender_ids is None:
            gender_ids = torch.zeros_like(word_ids)
        if ethni_ids is None:
            ethni_ids = torch.zeros_like(word_ids)
        if race_ids is None:
            race_ids = torch.zeros_like(word_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)

        word_embed = self.word_embeddings(word_ids)
        segment_embed = self.segment_embeddings(seg_ids)
        age_embed = self.age_embeddings(age_ids)
        gender_embed = self.gender_embeddings(gender_ids)
        ethnicity_embed = self.ethnicity_embeddings(ethni_ids)
        race_embed = self.race_embeddings(race_ids)
        posi_embeddings = self.posi_embeddings(posi_ids)

        # Keep the original left-to-right summation order so floating-point
        # results are unchanged.
        embeddings = word_embed + segment_embed
        if age:
            embeddings = embeddings + age_embed
        embeddings = embeddings + gender_embed + ethnicity_embed + race_embed + posi_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Build the fixed sinusoidal position table.

        Entry ``[pos, idx]`` is ``sin(pos / 10000 ** (2 * idx / hidden_size))``
        for even ``idx`` and ``cos`` of the same argument for odd ``idx``.
        Note that the raw column index ``idx`` is used in the exponent for both
        parities; this is kept as-is so existing checkpoints stay compatible.

        Returns:
            A float32 tensor of shape ``(max_position_embedding, hidden_size)``.
        """
        positions = np.arange(max_position_embedding, dtype=np.float64)[:, None]
        dims = np.arange(hidden_size, dtype=np.float64)[None, :]
        angles = positions / np.power(10000.0, 2.0 * dims / hidden_size)

        # even columns -> sine, odd columns -> cosine
        lookup_table = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles)).astype(np.float32)

        return torch.tensor(lookup_table)

In [3]:
class BertModel(Bert.modeling.BertPreTrainedModel):
    """Encoder stack fed by `BertEmbeddings`, followed by a pooler."""

    def __init__(self, config):
        """Build embeddings, encoder and pooler, then run BERT weight init."""
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        # NOTE(review): if `init_bert_weights` re-initialises every
        # nn.Embedding, it would also overwrite the fixed position table built
        # in BertEmbeddings -- confirm against the library implementation.
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, gender_ids=None, ethni_ids=None, race_ids=None, seg_ids=None, posi_ids=None, attention_mask=None,
                output_all_encoded_layers=True):
        """Encode a batch.

        Any id tensor left as ``None`` becomes an all-zero tensor shaped like
        ``input_ids``; a missing ``attention_mask`` becomes all ones.

        Returns:
            ``(encoded, pooled_output)`` where ``encoded`` is everything the
            encoder returned when ``output_all_encoded_layers`` is true, and
            only its last element otherwise.
        """

        def _or_zeros(ids):
            # Fresh zero tensor per missing stream, mirroring input_ids.
            return torch.zeros_like(input_ids) if ids is None else ids

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        age_ids = _or_zeros(age_ids)
        gender_ids = _or_zeros(gender_ids)
        ethni_ids = _or_zeros(ethni_ids)
        race_ids = _or_zeros(race_ids)
        seg_ids = _or_zeros(seg_ids)
        posi_ids = _or_zeros(posi_ids)

        # Insert two singleton axes (at positions 1 and 2) so the mask can
        # broadcast against the attention scores, and cast it to the parameter
        # dtype (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)

        # 1 -> 0.0 (attend), 0 -> -10000.0 (masked). Added to the raw scores
        # before the softmax, this effectively removes masked positions.
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedded = self.embeddings(input_ids, age_ids, gender_ids, ethni_ids, race_ids, seg_ids, posi_ids)
        all_layers = self.encoder(embedded,
                                  additive_mask,
                                  output_all_encoded_layers=output_all_encoded_layers)
        last_layer = all_layers[-1]
        pooled_output = self.pooler(last_layer)

        if output_all_encoded_layers:
            return all_layers, pooled_output
        return last_layer, pooled_output

In [4]:
class ListModule(object):
    """List-like view over sub-modules registered on a host module.

    Each appended module is attached to ``module`` through ``add_module``
    under the name ``prefix + <running index>``; indexing reads it back with
    ``getattr``.
    """

    def __init__(self, module, prefix, *args):
        """Remember the host and prefix, then append every module in ``args``."""
        self.module = module
        self.prefix = prefix
        self.num_module = 0
        for extra in args:
            self.append(extra)

    def append(self, new_module):
        """Register ``new_module`` on the host under the next free index.

        Raises:
            ValueError: if ``new_module`` is not an ``nn.Module``.
        """
        if not isinstance(new_module, nn.Module):
            raise ValueError('Not a Module')
        attr_name = self.prefix + str(self.num_module)
        self.module.add_module(attr_name, new_module)
        self.num_module += 1

    def __len__(self):
        """Number of modules appended so far."""
        return self.num_module

    def __getitem__(self, i):
        """Return the ``i``-th appended module.

        Raises:
            IndexError: if ``i`` is negative or not below ``len(self)``.
        """
        if 0 <= i < self.num_module:
            return getattr(self.module, self.prefix + str(i))
        raise IndexError('Out of bound')

In [5]:
class BertForMTR(Bert.modeling.BertPreTrainedModel):
    """BERT encoder with a shared maxout trunk and 11 single-output regression heads.

    The pooled encoder output goes through three maxout stages (5 linear
    pieces each, width 100) and is then fed to the heads ``target1`` ..
    ``target11``; their outputs are concatenated along dim 1.
    """

    # Heads are registered as attributes target1 .. target<_NUM_TARGETS>.
    _NUM_TARGETS = 11

    def __init__(self, config):
        """Build encoder, shared maxout layers and target heads.

        Sub-modules are registered in the same order and under the same names
        as before, so state_dict keys and seeded initialisation are unchanged.
        """
        super(BertForMTR, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropoutMTR = nn.Dropout(0.3)
        # Not used by forward(); kept so the parameter set (and therefore any
        # saved state_dict) stays the same.
        self.final = nn.Linear(config.hidden_size, config.number_output)
        self.fc_shared = ListModule(self, "fc1_")
        self.fc_shared1 = ListModule(self, "fc2_")
        self.fc_shared2 = ListModule(self, "fc3_")
        for _ in range(5):
            self.fc_shared.append(nn.Linear(config.hidden_size, 100))
            self.fc_shared1.append(nn.Linear(100, 100))
            self.fc_shared2.append(nn.Linear(100, 100))

        # One identical head per target instead of eleven copy-pasted blocks.
        for k in range(1, self._NUM_TARGETS + 1):
            setattr(self, 'target%d' % k, self._make_target_head())

        self.leaky = nn.LeakyReLU()
        self.apply(self.init_bert_weights)

    @staticmethod
    def _make_target_head():
        """Return a fresh 100 -> 100 -> 100 -> 1 head with LeakyReLU and dropout 0.3."""
        return nn.Sequential(
            nn.Linear(100, 100),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(100, 100),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(100, 1)
        )

    def maxout(self, x, layer_list):
        """Element-wise maximum over ``layer(x)`` for every layer in ``layer_list``.

        ``layer_list`` must support ``len()`` and integer indexing and hold at
        least one layer.
        """
        max_output = layer_list[0](x)
        # Start at index 1: the previous `enumerate(layer_list, start=1)` only
        # shifted the counter and evaluated layer 0 a second time.
        for i in range(1, len(layer_list)):
            max_output = torch.max(max_output, layer_list[i](x))
        return max_output

    def forward(self, input_ids, age_ids=None, gender_ids=None, ethni_ids=None, race_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, masked_lm_labels=None, target_mask=None):
        """Predict all targets and, when labels are given, the masked MSE loss.

        Args:
            input_ids .. attention_mask: forwarded unchanged to ``BertModel``.
            masked_lm_labels: optional regression targets compared element-wise
                with the concatenated head outputs.
            target_mask: optional weights multiplied into the per-element
                squared error; when omitted every target counts.

        Returns:
            Without labels: the concatenated head outputs (one column per
            target). Previously this branch returned the shared trunk
            activation instead of the predictions.
            With labels: ``(loss, predictions, masked_lm_labels)``. The loss
            sums the masked squared errors over dim 0 for each target, divides
            by ``target_mask.sum(dim=0) + 0.001`` and averages over targets.
        """
        _, pooled_output = self.bert(input_ids, age_ids, gender_ids, ethni_ids, race_ids, seg_ids, posi_ids, attention_mask,
                                     output_all_encoded_layers=False)
        shared = self.dropout(pooled_output)

        # Three maxout stages, each followed by LeakyReLU and dropout.
        for pieces in (self.fc_shared, self.fc_shared1, self.fc_shared2):
            shared = self.maxout(shared, pieces)
            shared = self.leaky(shared)
            shared = self.dropoutMTR(shared)

        head_outputs = [getattr(self, 'target%d' % k)(shared)
                        for k in range(1, self._NUM_TARGETS + 1)]
        prediction_scores = torch.cat(head_outputs, dim=1)

        if masked_lm_labels is None:
            return prediction_scores

        if target_mask is None:
            # No mask supplied: weight every target equally instead of
            # failing inside torch.mul.
            target_mask = torch.ones_like(masked_lm_labels)

        loss_fct = nn.MSELoss(reduction='none')
        masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
        masked_lm_loss = torch.mul(masked_lm_loss, target_mask)
        # The 0.001 keeps the division finite for targets with no unmasked entry.
        masked_lm_loss = torch.mean((masked_lm_loss).sum(dim=0)/(target_mask.sum(dim=0) + 0.001))
        return masked_lm_loss, prediction_scores, masked_lm_labels

In [6]:
class BertConfig(Bert.modeling.BertConfig):
    """Library BERT config extended with the extra vocabulary/output sizes.

    Built from a mapping: ``hidden_size`` is mandatory (subscript access),
    every other entry is read with ``.get`` and is ``None`` when absent.
    """

    def __init__(self, config):
        base_kwargs = dict(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            # The mapping key is singular ('max_position_embedding') while the
            # library keyword is plural.
            max_position_embeddings=config.get('max_position_embedding'),
            initializer_range=config.get('initializer_range'),
        )
        super(BertConfig, self).__init__(**base_kwargs)

        # Extra fields are copied verbatim under the same name as their key.
        extra_keys = (
            'seg_vocab_size',
            'age_vocab_size',
            'gender_vocab_size',
            'ethnicity_vocab_size',
            'race_vocab_size',
            'number_output',
            'number_static',
        )
        for key in extra_keys:
            setattr(self, key, config.get(key))
        
class TrainConfig(object):
    """Plain attribute holder for training settings.

    Each attribute is read from the ``config`` mapping with ``.get`` under the
    same name, so a missing key yields ``None``.
    """

    def __init__(self, config):
        field_names = (
            'batch_size',
            'use_cuda',
            'max_len_seq',
            'train_loader_workers',
            'test_loader_workers',
            'device',
            'output_dir',
            'output_name',
            'best_name',
        )
        for field in field_names:
            setattr(self, field, config.get(field))

In [7]:
class CVDLoader(Dataset):
    """Dataset yielding padded id sequences, masks and regression labels.

    Columns are looked up on ``dataframe`` by name; ``gender``, ``ethnicity``
    and ``race`` use fixed column names, the others are configurable.
    """

    def __init__(self, dataframe, max_len, code='code', age='age', labels='labels'):
        """Keep the padding length and a reference to each column."""
        self.max_len = max_len
        self.code = dataframe[code]
        self.age = dataframe[age]
        self.labels = dataframe[labels]
        self.gender = dataframe["gender"]
        self.ethnicity = dataframe["ethnicity"]
        self.race = dataframe["race"]

    def __getitem__(self, index):
        """
        return: age, code, gender, ethnicity, race, position, segmentation,
                mask, label, target_mask
        (the first seven as LongTensor, the last three as FloatTensor)
        """
        # NOTE(review): columns are indexed with `column[index]`; for a pandas
        # Series that is label-based, so this presumably expects a default
        # 0..n-1 index -- confirm against the caller.
        columns = (self.age, self.code, self.labels, self.gender, self.ethnicity, self.race)
        age, code, label, gender, ethnicity, race = (column[index] for column in columns)

        # attention mask: 1 for the first len(code) slots, 0 for the rest
        mask = np.ones(self.max_len)
        mask[len(code):] = 0

        # 1 where the label is non-zero, 0 elsewhere
        target_mask = (label != 0).astype(int)

        # seq_padding / position_idx / index_seg come from cvd_utils
        age, gender, ethnicity, race = (
            seq_padding(sequence, self.max_len)
            for sequence in (age, gender, ethnicity, race)
        )

        # position and segment codes are derived from the padded code sequence
        code = seq_padding(code, self.max_len)
        position = position_idx(code)
        segment = index_seg(code)

        long_fields = (age, code, gender, ethnicity, race, position, segment)
        float_fields = (mask, label, target_mask)
        return tuple(torch.LongTensor(field) for field in long_fields) + \
            tuple(torch.FloatTensor(field) for field in float_fields)

    def __len__(self):
        """Number of entries in the code column."""
        return len(self.code)

In [8]:
def adam(params, config=None):
    """Create a ``BertAdam`` optimiser with weight decay disabled for biases and LayerNorm.

    Args:
        params: iterable of ``(name, parameter)`` pairs. It may be a one-shot
            iterator; it is materialised once internally.
        config: optional mapping with ``'lr'`` and ``'warmup_proportion'``
            (both required when a mapping is given) and ``'weight_decay'``
            (optional, 0.01 when absent). Defaults to lr 3e-5, warmup 0.1,
            weight decay 0.01.

    Returns:
        The configured ``Bert.optimization.BertAdam`` instance.
    """
    if config is None:
        config = {
            'lr': 3e-5,
            'warmup_proportion': 0.1,
            'weight_decay': 0.01
        }
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # `params` is traversed twice below. With a one-shot iterator the second
    # pass would be empty and the no-decay parameters would silently never be
    # optimised, so take a snapshot first.
    named_params = list(params)

    # Honour the configured weight decay (it used to be ignored in favour of a
    # hard-coded 0.01); fall back to 0.01 for configs that omit the key.
    weight_decay = config.get('weight_decay', 0.01)

    def _skips_decay(name):
        # substring match against the no-decay markers
        return any(nd in name for nd in no_decay)

    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not _skips_decay(n)], 'weight_decay': weight_decay},
        {'params': [p for n, p in named_params if _skips_decay(n)], 'weight_decay': 0}
    ]

    optim = Bert.optimization.BertAdam(optimizer_grouped_parameters,
                                       lr=config['lr'],
                                       warmup=config['warmup_proportion'])
    return optim