In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
torch.manual_seed(1)
from sklearn.metrics import roc_auc_score, f1_score

In [None]:
# NOTE: allow_pickle=True unpickles arbitrary Python objects while loading;
# only point these calls at files you trust.

# Dataset splits.
training_data = np.load('data/training_data.npy', allow_pickle=True)
val_data = np.load('data/val_data.npy', allow_pickle=True)
test_data = np.load('data/test_data.npy', allow_pickle=True)

# Vocabulary of the notes: word -> index, plus the inverse mapping
# (the inverse is not strictly needed for the model).
word_to_ix = np.load('data/word_to_ix.npy', allow_pickle=True).item()
ix_to_word = np.load('data/ix_to_word.npy', allow_pickle=True).item()

# Wiki article embeddings, shape (# codes with wiki articles, vocab size).
wikivec = np.load('data/newwikivec.npy', allow_pickle=True)
# ICD-9 codes with wiki articles (not strictly needed for the model).
wikivoc = np.load('data/wikivoc.npy', allow_pickle=True).item()

In [None]:
# Model / batching configuration.
n_wiki, n_vocab = wikivec.shape  # (# codes with wiki articles, wiki vocab size)
n_words = len(word_to_ix)  # size of the note vocabulary
n_embedding = 100
batch_size = 32
test_batch_size = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# wikivec is a model input for KSI. torch.as_tensor accepts the NumPy array
# loaded from disk as well as an already converted tensor, so this cell can be
# re-run safely; the result is always float32 on DEVICE.
wikivec = torch.as_tensor(wikivec, dtype=torch.float32, device=DEVICE)

In [None]:
def collate_fn(block):
    """Collate samples into padded word indices, note vectors and labels.

    Each sample in ``block`` is a ``(words, note_vector, labels)`` triple.
    Words are looked up in the module-level ``word_to_ix``; words missing from
    it and padding positions both get index 0.

    Args:
        block: non-empty list of samples produced by the DataLoader.

    Returns:
        mat: integer tensor of shape (len(block), longest note in block)
            holding word indices.
        embeddings: FloatTensor stacking the second element of every sample.
        labels: FloatTensor stacking the third element of every sample.
    """
    max_words = max(len(sample[0]) for sample in block)
    mat = np.zeros((len(block), max_words), dtype=int)
    for row, sample in enumerate(block):
        # Walk only the words that exist; the remaining columns stay 0 (padding).
        for col, word in enumerate(sample[0]):
            mat[row, col] = word_to_ix.get(word, 0)
    mat = torch.from_numpy(mat)
    embeddings = torch.FloatTensor(np.array([x for _, x, _ in block]))
    labels = torch.FloatTensor(np.array([y for _, _, y in block]))
    return mat, embeddings, labels

In [None]:
# One DataLoader per split, all batching through collate_fn.
# NOTE(review): no shuffle argument is passed, so the training split is
# iterated in stored order every epoch - confirm this is intended.
train_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    val_data,
    batch_size=test_batch_size,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=test_batch_size,
    collate_fn=collate_fn,
)

In [None]:
class KSI(nn.Module):
    """KSI module: produces one score per (note, wiki document) pair.

    Args:
        n_ksi_embedding: width of the embedding and attention layers.
        n_vocab: size of the last dimension of the note and wiki vectors.
    """

    def __init__(self, n_ksi_embedding, n_vocab):
        super().__init__()
        # Layers are created in this order so seeded initialisation is stable.
        self.ksi_embedding = nn.Linear(n_vocab, n_ksi_embedding)
        self.ksi_attention = nn.Linear(n_ksi_embedding, n_ksi_embedding)
        self.ksi_output = nn.Linear(n_ksi_embedding, 1)

    def forward_ksi(self, notevec, wikivec):
        """Score every note against every wiki document.

        Args:
            notevec: 2-D tensor (batch, n_vocab) of note vectors.
            wikivec: tensor whose first dimension indexes wiki documents;
                expected to be (n_docs, n_vocab).

        Returns:
            Tensor of raw (pre-sigmoid) scores, (batch, n_docs) for the
            shapes above.
        """
        batch = notevec.shape[0]
        n_docs = wikivec.shape[0]

        # Pair each note with each document, then take the element-wise product.
        note_per_doc = notevec.unsqueeze(1).expand(batch, n_docs, -1)
        interaction = torch.mul(wikivec.unsqueeze(0), note_per_doc)

        embedded = self.ksi_embedding(interaction)
        # Sigmoid gate computed from the embedding re-weights the embedding itself.
        gate = torch.sigmoid(self.ksi_attention(embedded))
        gated = torch.mul(gate, embedded)

        return self.ksi_output(gated).squeeze(2)


class CNN(nn.Module):
    """Text CNN over note word indices, optionally combined with a KSI module.

    Three parallel 1-D convolutions (kernel sizes 3, 4 and 5) are max-pooled
    over the sequence, concatenated and projected to one score per output.

    Args:
        n_words: vocabulary size; the embedding table has ``n_words + 1`` rows.
        n_wiki: number of output scores per note.
        n_embedding: word embedding width.
        ksi: optional module exposing ``forward_ksi(notevec, wikivec)``; its
            output is added to the CNN logits before the sigmoid.
        n_filters: output channels of each convolution (default 100).
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, n_words, n_wiki, n_embedding, ksi=None, n_filters=100, **kwargs):
        super().__init__(**kwargs)
        self.ksi = ksi
        self.word_embeddings = nn.Embedding(n_words+1, n_embedding)
        self.dropout_embedding = nn.Dropout(p=0.2)
        self.conv1 = nn.Conv1d(n_embedding, n_filters, 3)
        self.conv2 = nn.Conv1d(n_embedding, n_filters, 4)
        self.conv3 = nn.Conv1d(n_embedding, n_filters, 5)
        # The linear layer consumes the three pooled conv outputs, so its input
        # width depends on the filter count, not on the embedding width.
        self.output = nn.Linear(n_filters*3, n_wiki)

    def forward(self, note, notevec=None, wikivec=None):
        """Return sigmoid scores of shape (batch_size, n_wiki).

        Args:
            note: integer tensor (batch_size, n) of word indices; n must be at
                least 5, the largest kernel size.
            notevec: note vectors passed to the KSI module (used only when a
                KSI module was supplied).
            wikivec: wiki vectors passed to the KSI module (same condition).
        """
        embeddings = self.word_embeddings(note) # (batch_size, n, n_embedding)
        embeddings = self.dropout_embedding(embeddings)
        embeddings = embeddings.permute(0, 2, 1) # (batch_size, n_embedding, n)

        pooled = []
        for conv in (self.conv1, self.conv2, self.conv3):
            activation = F.relu(conv(embeddings))
            # Global max-pool over the whole sequence -> (batch_size, n_filters, 1)
            pooled.append(F.max_pool1d(activation, activation.shape[2]))
        combined = torch.cat(pooled, 1).squeeze(2)

        out = self.output(combined)
        if self.ksi is not None:
            out = out + self.ksi.forward_ksi(notevec, wikivec)

        return torch.sigmoid(out)

In [None]:
def train(model, dataloader, loss_function, optimizer, wikivec=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module called as ``model(note, embeddings, wikivec)``.
        dataloader: iterable of ``(note, embeddings, labels)`` batches.
        loss_function: callable ``loss_function(scores, labels)`` returning a
            scalar tensor.
        optimizer: optimizer over the model's parameters.
        wikivec: optional wiki vectors forwarded to the model unchanged.
    """
    # Explicitly enable training mode so dropout is active even if the model
    # was left in eval mode by earlier code.
    model.train()
    for note, embeddings, labels in dataloader:
        optimizer.zero_grad()
        note = note.to(DEVICE)
        embeddings = embeddings.to(DEVICE)
        labels = labels.to(DEVICE)
        scores = model(note, embeddings, wikivec)
        loss = loss_function(scores, labels)
        loss.backward()
        optimizer.step()

        
def test(model, dataloader, wikivec=None, threshold=0.5):
    """Evaluate ``model`` on ``dataloader`` and return averaged batch metrics.

    The model is switched to eval mode (dropout disabled) and run without
    gradient tracking; its previous training/eval mode is restored afterwards.
    Metrics are computed per batch and averaged, weighting each batch by the
    number of notes that contributed to it.

    NOTE(review): label/score arrays are transposed to (n_codes, n_notes)
    before the sklearn calls, so sklearn's 'macro' average runs over notes
    rather than over codes - confirm this is the intended definition.

    Args:
        model: module called as ``model(note, embeddings, wikivec)``.
        dataloader: iterable of ``(note, embeddings, labels)`` batches.
        wikivec: optional wiki vectors forwarded to the model unchanged.
        threshold: score above which a code counts as predicted.

    Returns:
        ``(None, micro_f1, macro_f1, micro_auc, macro_auc)``; the first slot
        is a placeholder for recall@k, which is not implemented yet.
    """
    micro_f1 = []
    macro_f1 = []
    micro_auc = []
    macro_auc = []
    weights = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for note, embeddings, labels in dataloader:
                note = note.to(DEVICE)
                embeddings = embeddings.to(DEVICE)
                out = model(note, embeddings, wikivec).cpu().numpy()
                pred = np.array(out > threshold, dtype=float)
                labels = labels.cpu().numpy()

                # exclude rows with no labels, which break sklearn metrics
                labeled_rows = np.sum(labels, axis=1) > 0
                n_labeled = int(labeled_rows.sum())
                if n_labeled == 0:
                    # Nothing to score in this batch.
                    continue
                filtered_labels = labels[labeled_rows].T
                filtered_pred = pred[labeled_rows].T
                filtered_scores = out[labeled_rows].T

                #TODO: recall @ k (metric used in the paper)
                micro_f1.append(f1_score(filtered_labels, filtered_pred, average='micro'))
                macro_f1.append(f1_score(filtered_labels, filtered_pred, average='macro'))
                micro_auc.append(roc_auc_score(filtered_labels, filtered_scores, average='micro'))
                macro_auc.append(roc_auc_score(filtered_labels, filtered_scores, average='macro'))
                # Weight by the notes actually scored (len(data) was always 3,
                # the length of the batch tuple).
                weights.append(n_labeled)
    finally:
        model.train(was_training)

    micro_f1 = np.average(micro_f1, weights=weights)
    macro_f1 = np.average(macro_f1, weights=weights)
    micro_auc = np.average(micro_auc, weights=weights)
    macro_auc = np.average(macro_auc, weights=weights)
    return None, micro_f1, macro_f1, micro_auc, macro_auc

In [None]:
# Baseline: CNN on the note text only (no KSI module attached).
base_model = CNN(n_words, n_wiki, n_embedding).to(DEVICE)
loss_function = nn.BCELoss()
optimizer = optim.Adam(base_model.parameters())

for epoch in range(1):
    train(base_model, train_dataloader, loss_function, optimizer)

    # Metrics on the training split, then on the validation split.
    (t_recall_at_k, t_micro_f1, t_macro_f1,
     t_micro_auc, t_macro_auc) = test(base_model, train_dataloader)
    (v_recall_at_k, v_micro_f1, v_macro_f1,
     v_micro_auc, v_macro_auc) = test(base_model, val_dataloader)

    print(
        f'Epoch: {epoch+1:03d}, Train Micro F1: {t_micro_f1:.4f}, Val Micro F1: {v_micro_f1:.4f}'
        f', Train Macro F1: {t_macro_f1:.4f}, Val Macro F1: {v_macro_f1:.4f}'
    )

# Persists the whole module object (not just its state_dict).
torch.save(base_model, 'CNN_model.pt')

In [None]:
# CNN with the KSI module attached; wikivec is passed through to the model.
ksi = KSI(n_embedding, n_vocab).to(DEVICE)
model = CNN(n_words, n_wiki, n_embedding, ksi=ksi).to(DEVICE)
loss_function = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(1):
    train(model, train_dataloader, loss_function, optimizer, wikivec)

    # Metrics on the training split, then on the validation split.
    (t_recall_at_k, t_micro_f1, t_macro_f1,
     t_micro_auc, t_macro_auc) = test(model, train_dataloader, wikivec)
    (v_recall_at_k, v_micro_f1, v_macro_f1,
     v_micro_auc, v_macro_auc) = test(model, val_dataloader, wikivec)

    print(
        f'Epoch: {epoch+1:03d}, Train Micro F1: {t_micro_f1:.4f}, Val Micro F1: {v_micro_f1:.4f}'
        f', Train Macro F1: {t_macro_f1:.4f}, Val Macro F1: {v_macro_f1:.4f}'
    )

# Persists the whole module object (not just its state_dict).
torch.save(model, 'KSI_CNN_model.pt')
