In [1]:
import contextlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import roc_auc_score, f1_score
from torch.utils.data import DataLoader
from torchinfo import summary

In [2]:
def _load_npy(path, unwrap=False):
    """Load a ``.npy`` file; with ``unwrap=True`` return the single object stored in it."""
    # allow_pickle=True lets np.load unpickle arbitrary Python objects, so these
    # files must come from a trusted source.
    loaded = np.load(path, allow_pickle=True)
    return loaded.item() if unwrap else loaded


# Dataset splits.
training_data = _load_npy('data/training_data.npy')
test_data = _load_npy('data/test_data.npy')
val_data = _load_npy('data/val_data.npy')

# Vocabulary lookups for the notes (stored as pickled dicts, hence unwrap=True).
word_to_ix = _load_npy('data/word_to_ix.npy', unwrap=True)  # words (in notes) to index
ix_to_word = _load_npy('data/ix_to_word.npy', unwrap=True)  # index to words (in notes). not strictly needed for model

# Wiki knowledge source.
wikivec = _load_npy('data/newwikivec.npy')  # wiki article embeddings (# codes with wiki articles, vocab size)
wikivoc = _load_npy('data/wikivoc.npy', unwrap=True)  # ICD-9 codes with wiki articles. not strictly needed for model

In [3]:
# Sizes derived from the loaded data, plus the hyperparameters shared by later cells.
n_wiki, n_vocab = wikivec.shape  # rows / columns of the wiki matrix
n_words = len(word_to_ix)
n_embedding = 100
batch_size = 32
test_batch_size = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `wikivec` is rebound from a NumPy array to a float32 tensor on DEVICE.
# torch.as_tensor accepts either form, so re-running this cell is safe; the legacy
# torch.FloatTensor constructor is CPU-only and would reject the tensor on a second run.
wikivec = torch.as_tensor(wikivec, dtype=torch.float32, device=DEVICE) # wikivec is a model input for KSI

In [4]:
def collate_fn(block):
    """Collate ``(note, embedding, label)`` samples into padded batch tensors.

    Args:
        block: sequence of samples; each sample unpacks into a sequence of note
            tokens, an embedding vector and a label vector.

    Returns:
        Tuple ``(mat, embeddings, labels)``:
            mat: int64 tensor of shape (len(block), longest note) holding the
                ``word_to_ix`` index of every token. Positions past the end of
                a note and tokens missing from ``word_to_ix`` are 0.
            embeddings: float32 tensor stacking the samples' embedding vectors.
            labels: float32 tensor stacking the samples' label vectors.
    """
    n_rows = len(block)
    max_words = max((len(note) for note, _, _ in block), default=0)
    # Explicit int64 keeps the index dtype identical on every platform.
    mat = np.zeros((n_rows, max_words), dtype=np.int64)
    for row, (note, _, _) in enumerate(block):
        # Only visit the tokens that exist; the zero-initialised matrix already
        # provides the padding, so no exception handling is needed.
        ids = [word_to_ix.get(token, 0) for token in note]
        if ids:
            mat[row, :len(ids)] = ids
    mat = torch.from_numpy(mat)
    embeddings = torch.FloatTensor(np.array([x for _, x, _ in block]))
    labels = torch.FloatTensor(np.array([y for _, _, y in block]))
    return mat, embeddings, labels

In [5]:
# One DataLoader per split, all batching through collate_fn. No shuffle argument
# is passed, so every loader keeps DataLoader's default ordering.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, collate_fn=collate_fn, batch_size=size)
    for split, size in (
        (training_data, batch_size),
        (val_data, test_batch_size),
        (test_data, test_batch_size),
    )
)

In [6]:
class KSI(nn.Module):
    """Scores every (note, wiki article) pair from their element-wise product.

    Args:
        n_ksi_embedding: width of the embedding and attention layers.
        n_vocab: size of the last dimension of the note and wiki vectors.
    """

    def __init__(self, n_ksi_embedding, n_vocab):
        super().__init__()
        self.ksi_embedding = nn.Linear(n_vocab, n_ksi_embedding)
        self.ksi_attention = nn.Linear(n_ksi_embedding, n_ksi_embedding)
        self.ksi_output = nn.Linear(n_ksi_embedding, 1)

    def forward_ksi(self, notevec, wikivec):
        """Return a (n_notes, n_codes) tensor of unnormalised scores.

        Args:
            notevec: (n_notes, n_vocab) tensor, one row per note.
            wikivec: (n_codes, n_vocab) tensor, one row per wiki article.
        """
        with torch.profiler.record_function("KSI Forward"):
            n_notes = notevec.size(0)
            n_articles = wikivec.size(0)
            # Pair every note with every article: (n_notes, n_codes, n_vocab).
            note_grid = notevec.unsqueeze(1).expand(n_notes, n_articles, -1)
            wiki_grid = wikivec.unsqueeze(0)

            interaction = torch.mul(wiki_grid, note_grid)
            embedded = self.ksi_embedding(interaction)
            # Sigmoid gate in (0, 1) re-weights each embedding dimension.
            gate = torch.sigmoid(self.ksi_attention(embedded))
            gated = torch.mul(gate, embedded)
            pair_scores = self.ksi_output(gated)
            result = pair_scores.squeeze(2)

        return result


class CNN(nn.Module):
    """Multi-kernel 1-D CNN over word embeddings with an optional KSI add-on.

    Args:
        n_words: vocabulary size; the embedding table has ``n_words + 1`` rows.
        n_wiki: number of output classes (one score per wiki-backed code).
        n_embedding: dimensionality of the word embeddings.
        ksi: optional module exposing ``forward_ksi(notevec, wikivec)`` whose
            output is added to the CNN logits.
        n_filters: output channels of each convolution. Defaults to 100, the
            value that used to be hard-coded.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, n_words, n_wiki, n_embedding, ksi=None, n_filters=100, **kwargs):
        super().__init__(**kwargs)
        self.ksi = ksi
        self.word_embeddings = nn.Embedding(n_words+1, n_embedding)
        self.dropout_embedding = nn.Dropout(p=0.2)
        self.conv1 = nn.Conv1d(n_embedding, n_filters, 3)
        self.conv2 = nn.Conv1d(n_embedding, n_filters, 4)
        self.conv3 = nn.Conv1d(n_embedding, n_filters, 5)
        # The classifier sees the three pooled conv outputs concatenated, so its
        # input width depends on the filter count, not on the embedding size.
        self.output = nn.Linear(n_filters*3, n_wiki)

    def forward(self, note, notevec=None, wikivec=None):
        """Return per-class probabilities of shape (batch_size, n_wiki).

        Args:
            note: (batch_size, n) integer tensor of word indices.
            notevec: note vectors passed to ``ksi.forward_ksi``; required when
                a KSI module is attached, ignored otherwise.
            wikivec: wiki vectors passed to ``ksi.forward_ksi``; required when
                a KSI module is attached, ignored otherwise.

        Raises:
            ValueError: if a KSI module is attached but ``notevec`` or
                ``wikivec`` is missing.
        """
        with torch.profiler.record_function("CNN Embedding"):
            embeddings = self.word_embeddings(note) # (batch_size, n, n_embedding)
            embeddings = self.dropout_embedding(embeddings)
            embeddings = embeddings.permute(0, 2, 1) # (batch_size, n_embedding, n)

        with torch.profiler.record_function("CNN Forward"):
            pooled = []
            for conv in (self.conv1, self.conv2, self.conv3):
                activation = F.relu(conv(embeddings))
                # Max over the whole sequence axis -> (batch_size, n_filters, 1).
                pooled.append(F.max_pool1d(activation, activation.shape[2]))
            combined = torch.cat(pooled, 1).squeeze(2)

            out = self.output(combined)
        if self.ksi is not None:
            if notevec is None or wikivec is None:
                raise ValueError("notevec and wikivec are required when a KSI module is attached")
            out = out + self.ksi.forward_ksi(notevec, wikivec)

        scores = torch.sigmoid(out)
        return scores

In [7]:
def train(model, dataloader, loss_function, optimizer, wikivec=None, profiler=None, scheduler=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module called as ``model(note, embeddings, wikivec)``.
        dataloader: iterable yielding ``(note, embeddings, labels)`` batches.
        loss_function: callable applied to ``(scores, labels)``.
        optimizer: optimizer that owns the model parameters.
        wikivec: optional tensor forwarded unchanged to the model.
        profiler: optional profiler; ``profiler.step()`` is called per batch.
        scheduler: optional per-batch LR scheduler; ``scheduler.step()`` is
            called after every ``optimizer.step()``.
    """
    # Make sure dropout is active even if the model was left in eval mode.
    model.train()
    for note, embeddings, labels in dataloader:
        optimizer.zero_grad()
        note = note.to(DEVICE)
        embeddings = embeddings.to(DEVICE)
        labels = labels.to(DEVICE)
        scores = model(note, embeddings, wikivec)
        loss = loss_function(scores, labels)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if profiler is not None:
            profiler.step()

        
def test(model, dataloader, wikivec=None, threshold=0.5, k=10):
    y = []
    yhat = []
    recall = []
    for data in dataloader:
        note, embeddings, labels = data
        note = note.to(DEVICE)
        embeddings = embeddings.to(DEVICE)
        out = model(note, embeddings, wikivec).cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()
        y.append(labels)
        yhat.append(out)
        
    y = np.concatenate(y)
    yhat = np.concatenate(yhat)
    preds = np.array(yhat > threshold, dtype=float)
    for i in range(yhat.shape[0]):
        n_labels = int(y[i, :].sum())
        topk = max(k, n_labels)
        ind_topk = np.argpartition(yhat[i, :], -topk)[-topk:]
        recall.append(y[i, ind_topk].sum() / n_labels if n_labels > 0 else np.nan)
        
    mask = np.sum(y, axis=0) > 0 # mask out classes without both positive and negative examples

    recall = np.nanmean(recall)
    micro_f1 = f1_score(y[:, mask], preds[:, mask], average='micro')
    macro_f1 = f1_score(y[:, mask], preds[:, mask], average='macro')
    micro_auc = roc_auc_score(y[:, mask], yhat[:, mask], average='micro')
    macro_auc = roc_auc_score(y[:, mask], yhat[:, mask], average='macro')
    return recall, micro_f1, macro_f1, micro_auc, macro_auc

In [8]:
# Padded note length used as the sequence dimension for the torchinfo summaries.
# The commented-out code below derives it as the rounded mean padded width of
# the training batches; presumably 2455 is the cached result of that
# computation — TODO confirm if the data changes.
# note_lengths = []
# for data in train_dataloader:
#     n, _, _ = data
#     note_lengths.append(n.shape[1])
# avg_note_size = np.round(np.array(note_lengths).mean()).astype(int)

avg_note_size = 2455

In [9]:
# Run configuration shared by both training cells.
n_epochs = 25  # number of train/evaluate passes per model
save = True  # when True, each trained model is written to disk with torch.save
profile = False  # when True, training runs inside torch.profiler and traces go to ./log

In [10]:
# Baseline: plain CNN without a KSI module, created directly on the target device.
base_model = CNN(n_words, n_wiki, n_embedding).to(DEVICE)

# Summarise the model for one batch of word indices plus one batch of note vectors.
base_summary = summary(
    base_model,
    input_size=[(batch_size, avg_note_size), (batch_size, n_vocab)],
    dtypes=[torch.int, torch.float],
)

base_summary

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      --                        --
├─Embedding: 1-1                         [32, 2455, 100]           4,796,200
├─Dropout: 1-2                           [32, 2455, 100]           --
├─Conv1d: 1-3                            [32, 100, 2453]           30,100
├─Conv1d: 1-4                            [32, 100, 2452]           40,100
├─Conv1d: 1-5                            [32, 100, 2451]           50,100
├─Linear: 1-6                            [32, 344]                 103,544
Total params: 5,020,044
Trainable params: 5,020,044
Non-trainable params: 0
Total mult-adds (G): 9.60
Input size (MB): 1.87
Forward/backward pass size (MB): 251.25
Params size (MB): 20.08
Estimated Total Size (MB): 273.20

In [11]:
loss_function = nn.BCELoss()
optimizer = optim.Adam(base_model.parameters())
# NOTE(review): this scheduler is never passed to train() below, so its step()
# is never called from this cell. Constructing OneCycleLR still resets the
# optimizer's starting learning rate (see the OneCycleLR docs) — confirm
# whether that side effect is intended.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dataloader), epochs=n_epochs)

if profile:
    # `activities` replaces the deprecated `use_cuda` flag; only request CUDA
    # tracing when a GPU is actually present.
    profiler_activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        profiler_activities.append(torch.profiler.ProfilerActivity.CUDA)
    profiler_ctx = torch.profiler.profile(
        activities=profiler_activities,
        profile_memory=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/CNN'),
    )
else:
    # nullcontext yields None, which train() treats as "no profiler".
    profiler_ctx = contextlib.nullcontext()

# A single epoch loop serves both the profiled and the plain run.
with profiler_ctx as prof_base:
    for epoch in range(n_epochs):
        train(base_model, train_dataloader, loss_function, optimizer, profiler=prof_base)
        t_recall_at_k, t_micro_f1, t_macro_f1, t_micro_auc, t_macro_auc = test(base_model, train_dataloader)
        v_recall_at_k, v_micro_f1, v_macro_f1, v_micro_auc, v_macro_auc = test(base_model, val_dataloader)
        print(f'Epoch: {epoch+1:03d}, Train Recall@10: {t_recall_at_k:.4f}, Val Recall@10: {v_recall_at_k:.4f}' + 
            f', Train Micro F1: {t_micro_f1:.4f}, Val Micro F1: {v_micro_f1:.4f}' +
            f', Train Macro F1: {t_macro_f1:.4f}, Val Macro F1: {v_macro_f1:.4f}' +
            f', Train Micro AUC: {t_micro_auc:.4f}, Val Micro AUC: {v_micro_auc:.4f}' +
            f', Train Macro AUC: {t_macro_auc:.4f}, Val Macro AUC: {v_macro_auc:.4f}')

if save:
    torch.save(base_model, 'CNN_model.pt')

  warn("use_cuda is deprecated, use activities argument instead")


Epoch: 001, Train Recall@10: 0.6880, Val Recall@10: 0.6757, Train Micro F1: 0.5601, Val Micro F1: 0.5495, Train Macro F1: 0.0755, Val Macro F1: 0.0879, Train Micro AUC: 0.9591, Val Micro AUC: 0.9468, Train Macro AUC: 0.7347, Val Macro AUC: 0.7171
Epoch: 002, Train Recall@10: 0.7511, Val Recall@10: 0.7339, Train Micro F1: 0.5996, Val Micro F1: 0.5784, Train Macro F1: 0.1173, Val Macro F1: 0.1239, Train Micro AUC: 0.9736, Val Micro AUC: 0.9608, Train Macro AUC: 0.8792, Val Macro AUC: 0.7826
Epoch: 003, Train Recall@10: 0.7649, Val Recall@10: 0.7382, Train Micro F1: 0.6216, Val Micro F1: 0.5919, Train Macro F1: 0.1398, Val Macro F1: 0.1399, Train Micro AUC: 0.9767, Val Micro AUC: 0.9620, Train Macro AUC: 0.9211, Val Macro AUC: 0.7950
Epoch: 004, Train Recall@10: 0.7753, Val Recall@10: 0.7382, Train Micro F1: 0.6369, Val Micro F1: 0.6008, Train Macro F1: 0.1930, Val Macro F1: 0.1525, Train Micro AUC: 0.9789, Val Micro AUC: 0.9617, Train Macro AUC: 0.9366, Val Macro AUC: 0.7875
Epoch: 005, 

In [None]:
# prof_base holds profiler results only when the baseline training cell ran with
# profile=True; skip the table otherwise so a fresh Run All does not fail here.
if profile:
    print(prof_base.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         9.35%       11.082s         9.35%       11.082s     411.981us       76.853s        51.15%       76.853s       2.857ms           0 b           0 b           0 b           0 

In [None]:
# Test-split metrics for the baseline CNN. The baseline has no KSI module, so
# wikivec is not passed, matching the baseline calls in the training cell.
tt_recall_at_k, tt_micro_f1, tt_macro_f1, tt_micro_auc, tt_macro_auc = test(base_model, test_dataloader)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

  recall.append(y[i, ind_topk].sum() / n_labels)


Test Recall@10: nan, Test Micro F1: 0.5383, Test Macro F1: 0.0838, Test Micro AUC: 0.9483, Test Macro AUC: 0.6998


In [None]:
# KSI-augmented model: the same CNN with a KSI module attached, both on DEVICE.
ksi = KSI(n_embedding, n_vocab).to(DEVICE)
model = CNN(n_words, n_wiki, n_embedding, ksi=ksi).to(DEVICE)

# Inputs for the summary: word indices, note vectors and the wiki matrix.
ksi_summary = summary(
    model,
    input_size=[
        (batch_size, avg_note_size),
        (batch_size, n_vocab),
        (n_wiki, n_vocab),
    ],
    dtypes=[torch.int, torch.float, torch.float],
)

ksi_summary

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      --                        --
├─KSI: 1-1                               --                        --
│    └─Linear: 2-1                       --                        (recursive)
│    └─Linear: 2-2                       --                        (recursive)
│    └─Linear: 2-3                       --                        (recursive)
├─Embedding: 1-2                         [32, 2455, 100]           4,796,200
├─Dropout: 1-3                           [32, 2455, 100]           --
├─Conv1d: 1-4                            [32, 100, 2453]           30,100
├─Conv1d: 1-5                            [32, 100, 2452]           40,100
├─Conv1d: 1-6                            [32, 100, 2451]           50,100
├─Linear: 1-7                            [32, 344]                 103,544
├─KSI: 1-1                               --                        --
│    └─Linear: 2-4                

In [None]:
loss_function = nn.BCELoss()
optimizer = optim.Adam(model.parameters())
# NOTE(review): this scheduler is never passed to train() below, so its step()
# is never called from this cell. Constructing OneCycleLR still resets the
# optimizer's starting learning rate (see the OneCycleLR docs) — confirm
# whether that side effect is intended.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dataloader), epochs=n_epochs)

if profile:
    # `activities` replaces the deprecated `use_cuda` flag; only request CUDA
    # tracing when a GPU is actually present.
    profiler_activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        profiler_activities.append(torch.profiler.ProfilerActivity.CUDA)
    profiler_ctx = torch.profiler.profile(
        activities=profiler_activities,
        profile_memory=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/CNN_KSI'),
    )
else:
    # nullcontext yields None, which train() treats as "no profiler".
    profiler_ctx = contextlib.nullcontext()

# A single epoch loop serves both the profiled and the plain run.
with profiler_ctx as prof_ksi:
    for epoch in range(n_epochs):
        train(model, train_dataloader, loss_function, optimizer, wikivec, profiler=prof_ksi)
        t_recall_at_k, t_micro_f1, t_macro_f1, t_micro_auc, t_macro_auc = test(model, train_dataloader, wikivec)
        v_recall_at_k, v_micro_f1, v_macro_f1, v_micro_auc, v_macro_auc = test(model, val_dataloader, wikivec)
        print(f'Epoch: {epoch+1:03d}, Train Micro F1: {t_micro_f1:.4f}, Val Micro F1: {v_micro_f1:.4f}' +
            f', Train Macro F1: {t_macro_f1:.4f}, Val Macro F1: {v_macro_f1:.4f}')

if save:
    torch.save(model, 'KSI_CNN_model.pt')

  warn("use_cuda is deprecated, use activities argument instead")
  recall.append(y[i, ind_topk].sum() / n_labels)
  recall.append(y[i, ind_topk].sum() / n_labels)


Epoch: 001, Train Micro F1: 0.5746, Val Micro F1: 0.5591, Train Macro F1: 0.1366, Val Macro F1: 0.1489


In [None]:
# prof_ksi holds profiler results only when the KSI training cell ran with
# profile=True; skip the table otherwise so a fresh Run All does not fail here.
if profile:
    print(prof_ksi.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_        14.48%       20.872s        14.48%       20.872s     778.471us       80.707s        45.09%       80.707s       3.010ms           0 b           0 b           0 b           0 

In [None]:
# Test-split metrics for the KSI-augmented model.
tt_recall_at_k, tt_micro_f1, tt_macro_f1, tt_micro_auc, tt_macro_auc = test(model, test_dataloader, wikivec)
print(
    f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}'
    f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}'
)

  recall.append(y[i, ind_topk].sum() / n_labels)


Test Recall@10: nan, Test Micro F1: 0.5551, Test Macro F1: 0.1381, Test Micro AUC: 0.9656, Test Macro AUC: 0.8174
