In [None]:
import ast
import os

import dill
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class LearnablePositionalEncoding(nn.Module):
    """Add a learned absolute position embedding to every token of a sequence.

    Args:
        d_model: Embedding width; must equal the last dimension of the input.
        dropout: Dropout probability applied after the positions are added.
        max_len: Number of learnable positions; a sequence longer than this
            makes the embedding lookup raise an index error.
        batch_first: If True (the default) the input is ``(batch, seq, d_model)``
            and positions run along dim 1. Pass False for ``(seq, batch, d_model)``
            inputs, the layout ``nn.TransformerEncoderLayer`` uses by default.
    """

    def __init__(self, d_model, dropout=0, max_len=1000, batch_first=True):
        super(LearnablePositionalEncoding, self).__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        self.embeddings = nn.Embedding(max_len, d_model)

        initrange = 0.1
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        """Return ``dropout(x + position_embedding)``, same shape as ``x``."""
        seq_dim = 1 if self.batch_first else 0
        pos = torch.arange(x.size(seq_dim), device=x.device, dtype=torch.long)
        # Add a singleton batch axis so the looked-up embeddings broadcast over the batch.
        pos = pos.unsqueeze(0) if self.batch_first else pos.unsqueeze(1)
        x = x + self.embeddings(pos).expand_as(x)
        # The dropout module was previously constructed but never applied.
        return self.dropout(x)

class PatientEncoder(nn.Module):
    """Encode admissions (a disease-code list plus a procedure-code list) into vectors.

    ``args.patient_seperate`` selects one of two architectures:

    * ``False`` (unified): the disease tokens (prefixed by CLS) and procedure
      tokens (prefixed by SEP) are concatenated into one sequence, summed with
      segment embeddings and run through a single Transformer encoder. The
      output at the first position is the admission representation.
    * otherwise (separate): each code list, prefixed by its special token, is
      encoded by its own Transformer of width ``embed_dim // 2`` and the two
      first-position outputs are concatenated.

    The chosen method is exposed as ``self.patient_encoder``. It takes a list of
    admissions, each ``[disease_indices, procedure_indices]``, and returns a
    ``(B, dim)`` tensor.

    Args:
        args: Config object; reads ``embed_dim``, ``cuda``, ``patient_seperate``,
            ``nhead``, ``dropout`` and ``encoder_layers``.
        voc_size: Sequence whose entries 0 and 1 are the disease and procedure
            vocabulary sizes.

    NOTE(review): the Transformer layers use PyTorch's default seq-first layout
    and every admission is fed as a ``(seq, 1, dim)`` tensor, but
    ``LearnablePositionalEncoding`` indexes positions along dim 1
    (``x.size(1)`` == 1 here), so every token receives the position-0
    embedding. This is left unchanged so outputs stay consistent with
    checkpoints trained with this code; confirm before changing it.
    """

    def __init__(self, args, voc_size):
        super(PatientEncoder, self).__init__()
        self.args = args
        self.voc_size = voc_size
        self.emb_dim = args.embed_dim
        # Fall back to CPU so the model can be built on a machine without CUDA;
        # with CUDA available this is cuda:{args.cuda}, as before.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(args.cuda))
        else:
            self.device = torch.device('cpu')

        # Row indices into ``special_embeddings``. These are plain tensors (not
        # registered buffers), so they are not in the state dict and are not
        # moved by ``Module.to``.
        self.special_tokens = {
            'CLS': torch.LongTensor([0]).to(self.device),
            'SEP': torch.LongTensor([1]).to(self.device),
        }

        self.segment_embedding = nn.Embedding(2, self.emb_dim)

        if args.patient_seperate == False:
            # Unified encoder: one Transformer over [CLS, diseases, SEP, procedures].
            self.embeddings = nn.ModuleList(
                [nn.Embedding(voc_size[i], self.emb_dim) for i in range(2)])
            self.special_embeddings = nn.Embedding(2, self.emb_dim)
            self.transformer_visit = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.emb_dim, nhead=args.nhead, dropout=args.dropout),
                num_layers=args.encoder_layers
            )
            self.positional_embedding_layer_disease = LearnablePositionalEncoding(d_model=args.embed_dim)
            self.positional_embedding_layer_procedure = LearnablePositionalEncoding(d_model=args.embed_dim)
            self.patient_encoder = self.patient_encoder_unified
        else:
            # Separate encoders of half width, one per code type.
            self.embeddings = nn.ModuleList(
                [nn.Embedding(voc_size[i], self.emb_dim // 2) for i in range(2)])
            self.special_embeddings = nn.Embedding(2, self.emb_dim // 2)
            self.transformer_disease = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.emb_dim // 2, nhead=args.nhead, dropout=args.dropout),
                num_layers=args.encoder_layers
            )
            self.transformer_procedure = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.emb_dim // 2, nhead=args.nhead, dropout=args.dropout),
                num_layers=args.encoder_layers
            )

            # Constructed (so it is part of the state dict) but not applied in
            # patient_encoder_seperate.
            self.patient_layer = nn.Sequential(
                nn.Linear(self.emb_dim, self.emb_dim),
                nn.ReLU(),
                nn.Linear(self.emb_dim, self.emb_dim),
            )

            self.positional_embedding_layer_disease = LearnablePositionalEncoding(d_model=args.embed_dim // 2)
            self.positional_embedding_layer_procedure = LearnablePositionalEncoding(d_model=args.embed_dim // 2)

            self.patient_encoder = self.patient_encoder_seperate

    def _embed_codes(self, codes, table, special_token):
        """Embed ``codes`` with ``self.embeddings[table]`` and prepend a special token.

        Args:
            codes: Sequence of integer vocabulary indices (may be empty).
            table: 0 for the disease table, 1 for the procedure table.
            special_token: Key into ``self.special_tokens`` ('CLS' or 'SEP').

        Returns:
            Tensor of shape ``(len(codes) + 1, 1, width)``, special token first.
        """
        idx = torch.LongTensor(codes).unsqueeze(dim=1).to(self.device)  # (n, 1)
        code_embedding = self.embeddings[table](idx)  # (n, 1, width)
        special = self.special_embeddings(self.special_tokens[special_token]).unsqueeze(dim=1)  # (1, 1, width)
        return torch.cat((special, code_embedding), dim=0)  # (n+1, 1, width)

    def patient_encoder_seperate(self, batch_visits):
        """Encode each admission with two half-width Transformers and concatenate.

        Args:
            batch_visits: List of admissions, each ``[disease_indices, procedure_indices]``.

        Returns:
            Tensor of shape ``(B, 2 * (emb_dim // 2))``.
        """
        batch_disease_repr, batch_procedure_repr = [], []
        for adm in batch_visits:
            diseases, procedures = adm[0], adm[1]

            # Diseases are prefixed with CLS, procedures with SEP.
            disease_embedding = self._embed_codes(diseases, 0, 'CLS')  # (n+1, 1, dim/2)
            procedure_embedding = self._embed_codes(procedures, 1, 'SEP')  # (m+1, 1, dim/2)

            disease_embedding = self.positional_embedding_layer_disease(disease_embedding)
            procedure_embedding = self.positional_embedding_layer_procedure(procedure_embedding)

            # ``[0]`` keeps only the output at the first (special-token) position.
            disease_representation = self.transformer_disease(disease_embedding)[0]  # (1, dim/2)
            procedure_representation = self.transformer_procedure(procedure_embedding)[0]  # (1, dim/2)

            # NOTE(review): after ``[0]`` a single row remains, so this mean only
            # drops the leading axis -> (dim/2,). If mean-pooling over all tokens
            # was intended, the ``[0]`` above should go; TODO confirm against training.
            disease_representation = disease_representation.mean(dim=0)
            procedure_representation = procedure_representation.mean(dim=0)

            batch_disease_repr.append(disease_representation.reshape(1, 1, -1))  # (1, 1, dim/2)
            batch_procedure_repr.append(procedure_representation.reshape(1, 1, -1))  # (1, 1, dim/2)

        batch_disease_repr = torch.cat(batch_disease_repr, dim=1)  # (1, B, dim/2)
        batch_procedure_repr = torch.cat(batch_procedure_repr, dim=1)  # (1, B, dim/2)

        batch_repr = torch.cat((batch_disease_repr, batch_procedure_repr), dim=-1)  # (1, B, dim)
        # ``self.patient_layer`` is not applied here.
        return batch_repr.squeeze(dim=0)  # (B, dim)

    def patient_encoder_unified(self, batch_visits):
        """Encode each admission as one [CLS, diseases, SEP, procedures] sequence.

        Args:
            batch_visits: List of admissions, each ``[disease_indices, procedure_indices]``.

        Returns:
            Tensor of shape ``(B, emb_dim)``: the encoder output at the CLS position.
        """
        batch_repr = []
        for adm in batch_visits:
            diseases, procedures = adm[0], adm[1]

            disease_embedding = self._embed_codes(diseases, 0, 'CLS')  # (n+1, 1, dim)
            procedure_embedding = self._embed_codes(procedures, 1, 'SEP')  # (m+1, 1, dim)

            disease_embedding = self.positional_embedding_layer_disease(disease_embedding)
            procedure_embedding = self.positional_embedding_layer_procedure(procedure_embedding)

            combined_embedding = torch.cat((disease_embedding, procedure_embedding), dim=0)  # (n+m+2, 1, dim)

            # Segment 0 covers CLS, the n diseases and the SEP token that opens the
            # procedure block (n+2 tokens); segment 1 covers the m procedures.
            segments = torch.tensor(
                [0] * (len(diseases) + 2) + [1] * len(procedures), device=self.device)
            input_embedding = combined_embedding + self.segment_embedding(segments).unsqueeze(dim=1)

            # Keep only the output at position 0 (the CLS token).
            visit_representation = self.transformer_visit(input_embedding)[0]  # (1, dim)
            batch_repr.append(visit_representation.reshape(1, 1, -1))  # (1, 1, dim)

        return torch.cat(batch_repr, dim=1).squeeze(dim=0)  # (B, dim)

class RareMed(PatientEncoder):
    """RAREMed model: the patient encoder plus three linear heads.

    ``forward`` returns only the patient representation. The heads
    (``cls_mask``, ``cls_nsp``, ``cls_final``) are registered as submodules, so
    they appear in the state dict, but ``forward`` does not call them.
    """

    def __init__(self, args, voc_size):
        super(RareMed, self).__init__(args, voc_size)

        # Re-draw the embedding tables before any head is created.
        self.init_weights()

        width = self.emb_dim
        n_diag, n_proc, n_out = self.voc_size[0], self.voc_size[1], self.voc_size[2]
        # Logits over the joint disease + procedure vocabulary
        # (presumably for the masked-code pretraining objective).
        self.cls_mask = nn.Linear(width, n_diag + n_proc)
        # Single logit (presumably for the NSP pretraining objective).
        self.cls_nsp = nn.Linear(width, 1)
        # Logits over voc_size[2] entries.
        self.cls_final = nn.Linear(width, n_out)

    def forward(self, input):
        """Map a batch of admissions to its ``(B, dim)`` patient representation."""
        return self.patient_encoder(input)

    def init_weights(self):
        """Re-initialise every embedding table uniformly in [-0.1, 0.1]."""
        bound = 0.1
        # Order matters for reproducibility: each call draws from the global RNG.
        tables = (
            self.embeddings[0],        # disease codes
            self.embeddings[1],        # procedure codes
            self.segment_embedding,
            self.special_embeddings,
        )
        for table in tables:
            table.weight.data.uniform_(-bound, bound)

class Args:
    """Hyper-parameters used to build RAREMed in this notebook.

    The architecture fields should match the configuration the checkpoint was
    trained with. A different ``embed_dim``, ``encoder_layers`` or
    ``patient_seperate`` makes ``load_state_dict`` fail on shape/key
    mismatches, and a different ``nhead`` silently changes the computation.

    Any default can be overridden by keyword, e.g. ``Args(cuda=1)``. An unknown
    field name raises ``TypeError``.
    """

    def __init__(self, **overrides):
        # Run / bookkeeping
        self.note = ''
        self.model_name = 'RAREMed'
        self.dataset = 'mimic-iii'
        self.early_stop = 10
        self.test = False
        self.log_dir_prefix = None
        self.pretrain_prefix = None
        self.cuda = 0  # GPU index used for 'cuda:{cuda}'

        # Architecture (must agree with the checkpoint)
        self.patient_seperate = False
        self.seg_rel_emb = True
        self.embed_dim = 512
        self.encoder_layers = 3
        self.nhead = 4
        self.adapter_dim = 128

        # Pretraining
        self.pretrain_nsp = False
        self.pretrain_mask = False
        self.pretrain_epochs = 20
        self.mask_prob = 0

        # Optimisation / loss weighting
        self.batch_size = 1
        self.lr = 1e-5
        self.dropout = 0.3
        self.weight_decay = 0.1
        self.weight_multi = 0.005
        self.weight_ddi = 0.1

        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(f"Args got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


In [None]:
# Preprocessed MIMIC-III artefacts (vocabulary pickle and admissions table).
DATA_OUTPUT_DIR = 'data/data_process/output/mimic-iii'
voc_path = DATA_OUTPUT_DIR + '/voc_final.pkl'
data_path = DATA_OUTPUT_DIR + '/data.csv'

# TODO: replace with the path of your trained RAREMed checkpoint.
model_path = 'Your/trained/RAREMed/model/path'

# Where the {HADM_ID: patient embedding} dictionary is written.
pat_embed_path = 'data/save_embedding/pat_embed_raremed.pkl'

In [None]:
# Usage: build the model with the training configuration and load the trained weights.
args = Args()

# The vocabularies are unpickled with dill, so only load files you trust.
with open(voc_path, 'rb') as voc_file:
    voc = dill.load(voc_file)
diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

def add_word(word, voc):
    """Register ``word`` at the next free index of ``voc`` (mutates and returns ``voc``)."""
    voc.word2idx[word] = len(voc.word2idx)
    voc.idx2word[len(voc.idx2word)] = word
    return voc

# Append a '[MASK]' token to the disease and procedure vocabularies. The
# vocabularies are reloaded from disk above, so re-running this cell does not
# append the token twice.
add_word('[MASK]', diag_voc)
add_word('[MASK]', pro_voc)
voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

model = RareMed(args, voc_size)
# map_location='cpu' lets a checkpoint saved on any GPU be loaded here; the model
# is moved to its own target device afterwards. torch.load unpickles, so the
# checkpoint must come from a trusted source.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
model.to(model.device)
model.eval();

In [None]:
# Encode every admission in data.csv and save a {HADM_ID: embedding} dictionary.
data = pd.read_csv(data_path)

pat_reps = dict()

# Inference only: no_grad avoids building an autograd graph for every admission.
with torch.no_grad():
    rows = data[['HADM_ID', 'ICD_CODE', 'PRO_CODE']].itertuples(index=False, name=None)
    for hadm_id, icd_field, pro_field in tqdm(rows, total=len(data)):
        # The code columns are string-encoded Python literals (presumably lists
        # of vocabulary indices). literal_eval parses them without executing
        # arbitrary expressions, unlike eval().
        diags = ast.literal_eval(icd_field)
        pros = ast.literal_eval(pro_field)
        patient_repr = model([[diags, pros]])  # batch of one admission -> (1, dim)
        pat_reps[int(hadm_id)] = patient_repr[0].detach().cpu()

# Create the output directory if needed and close the file deterministically.
out_dir = os.path.dirname(pat_embed_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(pat_embed_path, 'wb') as out_file:
    dill.dump(pat_reps, out_file)
    