In [1]:
from bertviz import model_view

In [2]:
def load_model(path, model, map_location=None):
    """Load saved weights from ``path`` into ``model`` and return ``model``.

    Only checkpoint entries whose key also exists in ``model.state_dict()``
    are copied; every other entry of the model keeps the value it already had.

    Args:
        path: Location of the saved state dict, passed to ``torch.load``.
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour; pass ``'cpu'`` to open a checkpoint that
            was saved from a GPU on a machine without one.

    Returns:
        The same ``model`` object, updated in place.

    Warns:
        RuntimeWarning: if no checkpoint key matches the model, in which case
            the model is returned with its weights unchanged.
    """
    # Local import so the function does not depend on notebook cell order.
    import warnings

    pretrained_dict = torch.load(path, map_location=map_location)
    model_dict = model.state_dict()

    # 1. keep only the checkpoint entries the model actually has
    matched = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    if not matched:
        # Nothing would be loaded; make that visible instead of silently
        # handing back an unmodified model.
        warnings.warn(
            "load_model: no key in the checkpoint matches the model's "
            "state dict; the model's weights were left unchanged",
            RuntimeWarning,
        )

    # 2. overwrite the matching entries in the model's current state dict
    model_dict.update(matched)
    # 3. push the merged state dict back into the model
    model.load_state_dict(model_dict)
    return model

In [19]:
import sys
# Add the parent directory to the module search path; presumably this is what
# lets the project-local `utils` and `model` imports below resolve — confirm.
sys.path.insert(1, '../')
from utils.dataset import EHRDatasetCodePrediction
from model.tokenizer import EHRTokenizer
import torch
# NOTE(review): second import from `utils.dataset`; could be merged with the one above.
from utils.dataset import EHRDataset
from torch import nn
import matplotlib.pyplot as plt
import pytorch_pretrained_bert as Bert
from torch.utils.data import DataLoader
# NOTE(review): star import — the names it brings in are not visible from this cell.
from model.model import *
import seaborn as sns
# NOTE(review): a class named `BertConfig` is also defined in another cell of this
# notebook; whichever of the two runs last owns the name.
from utils.config import BertConfig
# NOTE(review): second star import; any name it shares with `model.model` is overridden here.
from model.model2 import *
from model.trainer import PatientTrajectoryPredictor
import pytorch_lightning as pl
import pandas as pd
import warnings
# Suppress every warning from this point on.
warnings.filterwarnings('ignore')

In [4]:
# Build the project tokenizer for the 'ccsr' task; its vocabularies are read
# through `tokenizer.getVoc(...)` when the model config is assembled.
tokenizer = EHRTokenizer(task='ccsr')

In [5]:
# Read the dataset with pandas' parquet reader.
# NOTE(review): the path has no file extension and ends in an underscore —
# confirm it names the intended parquet file or directory.
path = '../processing/readmission_data_ccsr_'
data = pd.read_parquet(path)

In [6]:
# Hyper-parameters consumed by BertConfig when the model is built.
model_config = dict(
    # --- embedding tables ---
    vocab_size=len(tokenizer.getVoc('code').keys()),      # diseases + special symbols (word embedding)
    hidden_size=300,                                      # width of the word and segment embeddings
    seg_vocab_size=2,                                     # number of segment ids
    age_vocab_size=len(tokenizer.getVoc('age').keys()),   # number of age ids
    gender_vocab_size=3,
    max_position_embeddings=32,                           # maximum number of tokens
    # --- encoder ---
    hidden_dropout_prob=0.1,                              # dropout rate
    num_hidden_layers=6,                                  # number of multi-head attention layers
    num_attention_heads=12,                               # attention heads per layer
    attention_probs_dropout_prob=0.1,                     # dropout rate inside multi-head attention
    intermediate_size=300,                                # size of the encoder's "intermediate" layer
    hidden_act='gelu',                                    # encoder/pooler activation: "gelu", "relu" or "swish"
    initializer_range=0.02,                               # range used to initialise parameter weights
)

In [7]:
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration built from a plain dict, with extra vocabulary sizes.

    The standard BERT fields are forwarded to ``Bert.modeling.BertConfig``;
    the segment, age and gender vocabulary sizes are stored as additional
    attributes on the instance.

    Args:
        config: Mapping of hyper-parameter names to values. ``'hidden_size'``
            is read with ``[]`` and therefore must be present (``KeyError``
            otherwise); every other entry is read with ``.get`` and is passed
            on as ``None`` when absent.
    """

    def __init__(self, config):
        # Zero-argument super() is bound to this class object itself, so it
        # keeps working even if the module-level name `BertConfig` is later
        # rebound (the notebook also imports a `BertConfig` from utils.config).
        super().__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings=config.get('max_position_embeddings'),
            initializer_range=config.get('initializer_range'),
        )
        # Extra fields that are not arguments of the parent constructor call above.
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')
        self.gender_vocab_size = config.get('gender_vocab_size')

### Visualize attention for Masked Language Modeling

In [8]:
# Build the masked-LM model from the config and load the saved weights into it.
bert_config = BertConfig(model_config)
model = BertForMaskedLM(bert_config)
PATH = "../saved_models/MLM/deep_notsuffled"
# Modules start in training mode, which would leave the dropout rates set in
# model_config active. Switch to eval mode so the attention being visualised
# is deterministic. `.eval()` returns the module itself.
model = load_model(PATH, model).eval()

In [32]:
# Store the length of each row's `hadm_id` entry in a new `len` column.
data['len'] = data['hadm_id'].apply(len)

In [40]:
# Keep the rows of subject 10215056 whose `len` (number of `hadm_id` entries) is 3.
patient = data.loc[(data['len'] == 3) & (data['subject_id'] == 10215056)]

In [41]:
# Show the selected patient row(s) as the cell output.
patient

Unnamed: 0,subject_id,label,icd_code,ccsr,age,alcohol_abuse,tobacco_abuse,ndc,hadm_id,gender,len
20,10215056,"[0, 0, 0]","[[M545, M179, R791, E119, K219, F341, Z86718, ...","[[MUS038, MUS006, SYM017, END002, DIG004, MBD0...","[58.0, 62.0, 65.0]","[0, 0, 0]","[0, 1, 0]","[[56017275, 51079081120, 93521193, 45064165, 8...","[26393320, 26394582, 20975745]",F,3


In [63]:
#patientd = EHRDataset(patient, max_len=64, tokenizer=tokenizer) 

In [64]:
#loader = torch.utils.data.DataLoader(patientd, batch_size=1, shuffle=True, num_workers=4)

In [65]:
#batch = next(iter(loader))
#print(batch)
#age_ids, gender_ids, input_ids, posi_ids, segment_ids, attMask, masked_label, _ = batch

[tensor([[16, 16, 16,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  4,  4,  4,  4]]), tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), tensor([[517,  40, 519,  80,  64,  11, 140,  23,  16, 518,  51,  86, 518,  89,
          33,  39,  74, 518,  10, 142, 117, 143, 144,  11,  68, 145,  38,  39,
          39, 180,  25,  25, 518,  34, 146,  38, 519, 518,  75,  11, 141, 518,
          80,  25,  39,  39,  74,  33,  27, 139,  16,  25,  78,  25,  31, 518,
         518,  86,  19, 148, 149,  93,  40, 519]]), tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 

In [58]:
# NOTE(review): in the visible part of this notebook the only assignment of
# `age_ids` is inside a commented-out cell, so this cell fails on a fresh
# kernel unless `age_ids` is defined elsewhere — confirm.
age_ids

tensor([[4, 4, 4]])