# CAML: an attentional convolutional network to predict medical codes from clinical text

## Overview

In this project, we will implement the Convolutional Attention for Multi-Label classification (CAML) model proposed by Mullenbach et al. in the paper "[Explainable Prediction of Medical Codes from Clinical Text](https://www.aclweb.org/anthology/N18-1100/)".

Clinical notes are text documents that are created by clinicians for each patient encounter. They are typically accompanied by medical codes, which describe the diagnosis and treatment. Annotating these codes is labor intensive and error prone; furthermore, the connection between the codes and the text is not annotated, obscuring the reasons and details behind specific diagnoses and treatments. Thus, let us implement CAML, an attentional convolutional network to predict medical codes from clinical text.

<img src='img/clinical notes.png'>

Image courtesy: [link](https://www.aclweb.org/anthology/2020.acl-demos.33/)

In [1]:
import os
import csv
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

In [2]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "lib/data/"

## Dataset

For this project, we will be using the Indiana University Chest X-Ray dataset. The goal is to predict diseases using chest x-ray reports.

The data folder `DATA_PATH` contains several files:

- `train_df.csv`, `test_df.csv`: these two files contain the data used for training and testing.
    - `Report ID` refers to a unique chest x-ray report.
    - `Text` refers to the clinical report text.
    - `Label` refers to the diseases.
- `vocab.csv`: this file contains the vocabulary used in the clinical text.

In [3]:
!ls {DATA_PATH}

test_df.csv  train_df.csv  vocab.csv


For example, the first chest x-ray report in `train_df.csv` has:
- `Report ID`: 1
- `Text`: the cardiac silhouette and mediastinum size are within normal limits . there is no pulmonary edema . there is no focal consolidation . there are no xxxx of a pleural effusion . there is no evidence of pneumothorax . normal chest xxxxx .
- `Label`: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

where label is a multi-hot vector representing the following diseases:
```
normal
cardiomegaly
scoliosis / degenerative
fractures bone
pleural effusion
thickening
pneumothorax
hernia hiatal
calcinosis
emphysema / pulmonary emphysema
pneumonia / infiltrate / consolidation
pulmonary edema
pulmonary atelectasis
cicatrix
opacity
nodule / mass
airspace disease
hypoinflation / hyperdistention
catheters indwelling / surgical instruments / tube inserted / medical device
other
```

So report 1 is labeled as "normal".
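
To make the multi-hot encoding concrete, here is a minimal sketch that maps a label vector back to disease names. The list `LABEL_NAMES` and the helper `decode_labels` are hypothetical names introduced for illustration; they assume the labels are ordered exactly as listed above.

```python
# Label names, assumed to be in the same order as the list above.
LABEL_NAMES = [
    "normal",
    "cardiomegaly",
    "scoliosis / degenerative",
    "fractures bone",
    "pleural effusion",
    "thickening",
    "pneumothorax",
    "hernia hiatal",
    "calcinosis",
    "emphysema / pulmonary emphysema",
    "pneumonia / infiltrate / consolidation",
    "pulmonary edema",
    "pulmonary atelectasis",
    "cicatrix",
    "opacity",
    "nodule / mass",
    "airspace disease",
    "hypoinflation / hyperdistention",
    "catheters indwelling / surgical instruments / tube inserted / medical device",
    "other",
]


def decode_labels(multi_hot):
    """Return the names of all labels whose entry is 1 (hypothetical helper)."""
    return [name for name, flag in zip(LABEL_NAMES, multi_hot) if flag == 1]


# Label vector of report 1: only the first position is set.
report_1 = [1] + [0] * 19
print(decode_labels(report_1))  # ['normal']
```

A report with several findings would simply have more than one entry set to 1, and the helper would return several names.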

## 1 Prepare the Dataset

### 1.1 Helper Functions

To begin, let us first implement some helper functions we will use later.

In [4]:
def to_index(sequence, token2idx):
    """
    Convert the sequence of tokens to indices. If a word is unknown, map it to '<unk>'.
    
    INPUT:
        sequence (type: list of str): a sequence of tokens
        token2idx (type: dict): a dictionary mapping token to the corresponding index
    
    OUTPUT:
        indices (type: list of int): a sequence of indices
        
    EXAMPLE:
        >>> sequence = ['hello', 'world', 'unknown_word']
        >>> token2idx = {'hello': 0, 'world': 1, '<unk>': 2}
        >>> to_index(sequence, token2idx)
        [0, 1, 2]
    """
    result = []
    for word in sequence:
        if word in token2idx:
            result.append(token2idx[word])
        else:
            result.append(token2idx['<unk>'])
    return result
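
As a side note, the same lookup can be written more compactly with `dict.get`. The sketch below is an equivalent alternative, not a replacement for the helper above; the name `to_index_compact` is hypothetical.

```python
def to_index_compact(sequence, token2idx):
    """Map tokens to indices, falling back to the '<unk>' index for unknown tokens."""
    unk = token2idx['<unk>']
    # dict.get returns the fallback value when the token is not in the vocabulary
    return [token2idx.get(token, unk) for token in sequence]


sequence = ['hello', 'world', 'unknown_word']
token2idx = {'hello': 0, 'world': 1, '<unk>': 2}
print(to_index_compact(sequence, token2idx))  # [0, 1, 2]
```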

### 1.2 CustomDataset

Now, let us implement a custom dataset using the PyTorch `Dataset` class, which will characterize the key features of the dataset we want to generate.

We will use the clinical text as input and medical codes as output.

In [13]:
from torch.utils.data import Dataset

NUM_WORDS = 1253
NUM_CLASSES = 20


class CustomDataset(Dataset):
    
    def __init__(self, filename):       
        
        # read in the data files
        self.data = pd.read_csv(filename)
        
        # load word lookup
        self.idx2word, self.word2idx = self.load_lookup(f'{DATA_PATH}/vocab.csv', padding=True)
        
        assert len(self.idx2word) == len(self.word2idx) == NUM_WORDS
        
    def load_lookup(self, filename, padding=False):
        """ load lookup for word """
        idx2token = {}
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                line = line.strip()
                idx2token[i] = line
        token2idx = {w:i for i,w in idx2token.items()}
        return idx2token, token2idx
        
    def __len__(self):
        
        """
        Return the number of samples (i.e. reports).
        """
        return len(self.data)
    
    def __getitem__(self, index):
        
        """
        Generate one sample of data.
        """
        data = self.data.iloc[index]
        text = data['Text'].split(' ')
        label = data['Label']
        # convert label string to list
        label = [int(l) for l in label.strip('[]').split(', ')]
        assert len(label) == NUM_CLASSES
        text = to_index(text, self.word2idx)
        return torch.tensor(text, dtype=torch.long), torch.tensor(label, dtype=torch.float)
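
The `Label` column is stored in the CSV as a string such as `"[1, 0, 0, ...]"`, which is why `__getitem__` converts it back to a list. The sketch below compares the manual parsing used above with `ast.literal_eval` from the standard library; the variable `label_str` is an illustrative value based on report 1.

```python
import ast

# Label of report 1 as it would appear in the CSV (a string, not a list).
label_str = "[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"

# Manual parsing, as in CustomDataset.__getitem__
manual = [int(l) for l in label_str.strip('[]').split(', ')]

# Alternative: safely evaluate the string as a Python literal
parsed = ast.literal_eval(label_str)

print(manual == parsed)  # True
print(len(parsed))       # 20
```

The manual version depends on the exact `', '` separator, while `ast.literal_eval` tolerates variations in whitespace.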

### 1.3 Collate Function

The collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [3,  1,  2, 8, 5], 
  [12, 13, 6, 7, 12, 23, 11] ]
```

where the first sample has text `[3, 1, 2, 8, 5]` and the second sample has text `[12, 13, 6, 7, 12, 23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (7), where 7 is the maximum number of tokens.

```
[ [3,  1,  2, 8, 5, *0*, *0*], 
  [12, 13, 6, 7, 12, 23,  11 ] ]
```

where `*0*` indicates the padding token.

We need to pad the sequences to the same length so that we can do batch training on GPU. The padded values do not contain any information; a mask marking them would let the model ignore them, although the `collate_fn()` below returns only the padded text and the labels.
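
The padding step can be sketched on the two samples from the example using `pad_sequence`. The boolean `mask` is an illustrative addition that is not used elsewhere in this notebook, and it assumes index 0 is reserved for padding.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# The two samples from the example above, as 1-D tensors of token indices.
samples = [
    torch.tensor([3, 1, 2, 8, 5]),
    torch.tensor([12, 13, 6, 7, 12, 23, 11]),
]

# Pad to the longest sequence; padding_value defaults to 0.
padded = pad_sequence(samples, batch_first=True)
print(padded.shape)  # torch.Size([2, 7])

# Assumption: index 0 is used only for padding, so non-zero entries are real tokens.
mask = padded != 0
print(mask.sum(dim=1))  # tensor([5, 7])
```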

In [16]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(data):
    """
    OUTPUT:
        text: the padded text, shape: (batch size, max length)
        labels: the stacked labels, shape: (batch size, num classes)
    """
    text, labels = zip(*data)

    # pad the text using pad_sequence()
    text = pad_sequence(text, batch_first=True)
    
    # stack the labels
    labels = torch.stack(labels)
    
    return text, labels

In [17]:
from torch.utils.data import DataLoader

dataset = CustomDataset(f'{DATA_PATH}/train_df.csv')
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
text, labels = next(loader_iter)

All done. Now let us create the datasets and data loaders.

In [18]:
train_set = CustomDataset(f'{DATA_PATH}/train_df.csv')
test_set = CustomDataset(f'{DATA_PATH}/test_df.csv')
train_loader = DataLoader(train_set, batch_size=32, collate_fn=collate_fn, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn)

In [19]:
train_loader_iter = iter(train_loader)
text, labels = next(train_loader_iter)

In [20]:
text

tensor([[681, 676,  46,  ...,   0,   0,   0],
        [ 55,   9,   7,  ...,   0,   0,   0],
        [  4,  31,   6,  ...,   0,   0,   0],
        ...,
        [ 46, 185,  10,  ...,   0,   0,   0],
        [  4,  34, 447,  ...,   0,   0,   0],
        [  4,  34,  10,  ...,   0,   0,   0]])

In [21]:
text.shape

torch.Size([32, 129])

In [22]:
labels

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [1

In [23]:
labels.shape

torch.Size([32, 20])

## 2 Model

Next, we will implement the CAML model.

<img src='img/caml.png'>

CAML is a convolutional neural network (CNN)-based model. It employs a per-label attention mechanism, which allows the model to learn distinct document representations for each label.

In [24]:
kernel_size=10
num_filter_maps=16
embed_size=100
dropout=0.5

In [26]:
from math import floor
from torch.nn.init import xavier_uniform_

# embedding layer
embed = nn.Embedding(NUM_WORDS, embed_size, padding_idx=0)
embed_drop = nn.Dropout(p=dropout)

# initialize conv layer as in section 2.1
conv = nn.Conv1d(embed_size, num_filter_maps, kernel_size=kernel_size, padding=int(floor(kernel_size/2)))
xavier_uniform_(conv.weight)

# context vectors for computing attention as in section 2.2
U = nn.Linear(num_filter_maps, NUM_CLASSES)
xavier_uniform_(U.weight)

# final layer: create a matrix to use for the NUM_CLASSES binary classifiers as in section 2.3
final = nn.Linear(num_filter_maps, NUM_CLASSES)
xavier_uniform_(final.weight)

Parameter containing:
tensor([[ 4.6424e-02, -5.8118e-03, -1.1895e-01,  1.1350e-01,  1.1773e-01,
         -1.2843e-01, -3.2863e-01, -2.8696e-01, -1.9685e-01,  3.0065e-01,
          1.9048e-01,  4.0555e-01,  4.0449e-02,  1.2111e-01, -3.0619e-01,
          3.7543e-01],
        [-1.4240e-01,  3.8347e-01,  1.5925e-01, -1.6462e-01,  3.8513e-01,
          3.9925e-02, -4.0551e-01, -3.5628e-01,  9.6371e-02, -9.1424e-02,
         -2.7839e-01,  2.3444e-01, -1.6602e-02,  1.3793e-01, -3.8036e-01,
         -1.1662e-01],
        [ 2.9073e-01,  2.4191e-01,  2.7346e-02, -1.3438e-01,  3.6025e-01,
         -1.6374e-01,  1.3985e-02,  3.6275e-01,  3.6183e-01,  1.8094e-02,
          2.7703e-01, -8.2314e-02, -2.9007e-02,  6.2736e-02, -2.4469e-01,
         -2.9445e-01],
        [ 2.7320e-01,  2.2017e-01,  1.4782e-01, -3.4094e-01, -3.6820e-01,
          1.3770e-01,  2.0380e-01,  3.2225e-01, -2.5532e-01, -1.3278e-01,
          3.5933e-01, -3.8846e-01, -2.1118e-01,  3.4978e-01, -2.4956e-01,
         -4.5794e-02]

Step by step, we can monitor how the shape changes:

In [27]:
print(text.shape)
text = embed(text)
print(text.shape)
text = embed_drop(text)
print(text.shape)

torch.Size([32, 129])
torch.Size([32, 129, 100])
torch.Size([32, 129, 100])


In [29]:
print(text.shape)
text = text.transpose(1, 2)
print(text.shape)
text = conv(text)
print(text.shape)
text = torch.tanh(text)
print(text.shape)

torch.Size([32, 129, 100])
torch.Size([32, 100, 129])
torch.Size([32, 16, 130])
torch.Size([32, 16, 130])
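
Note that the sequence length grows from 129 to 130 after the convolution. This follows from the output-length formula documented for `nn.Conv1d`. The helper `conv1d_out_len` below is a hypothetical function written for illustration.

```python
from math import floor


def conv1d_out_len(l_in, kernel_size, padding, stride=1, dilation=1):
    """Output length of a 1-D convolution, following the formula in the nn.Conv1d docs."""
    return floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1


# Settings used in this notebook: kernel_size=10, padding=floor(10 / 2)=5
print(conv1d_out_len(129, kernel_size=10, padding=10 // 2))  # 130

# With an odd kernel size, the same padding rule preserves the length
print(conv1d_out_len(129, kernel_size=9, padding=9 // 2))    # 129
```

With `padding=floor(kernel_size / 2)`, an even kernel size adds one position to the sequence, while an odd kernel size keeps the length unchanged.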


In [30]:
text = text.transpose(1,2)
print(text.shape)

torch.Size([32, 130, 16])


In [31]:
temp = text.transpose(1,2)
print(temp.shape)
print(U.weight.shape)

torch.Size([32, 16, 130])
torch.Size([20, 16])


In [37]:
alpha = torch.matmul(U.weight, temp)
print(alpha.shape)

torch.Size([32, 20, 130])


In [38]:
alpha = F.softmax(alpha, dim=2)
print(alpha.shape)

torch.Size([32, 20, 130])
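
The softmax over `dim=2` normalizes the scores across positions, so each label gets its own distribution over the sequence. The following sketch checks this on random tensors; the tensor names `H` and `U_weight` and the small sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, num_filter_maps, num_classes = 2, 7, 16, 20

# Stand-ins for the transposed conv output and the label context vectors.
H = torch.randn(batch_size, num_filter_maps, seq_len)
U_weight = torch.randn(num_classes, num_filter_maps)

# (num_classes, num_filter_maps) x (batch, num_filter_maps, seq_len)
# broadcasts over the batch dimension -> (batch, num_classes, seq_len)
scores = torch.matmul(U_weight, H)
alpha = F.softmax(scores, dim=2)

print(alpha.shape)  # torch.Size([2, 20, 7])
# For every (sample, label) pair, the weights over positions sum to 1.
print(torch.allclose(alpha.sum(dim=2), torch.ones(batch_size, num_classes)))  # True
```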


In [39]:
v = torch.matmul(alpha, text)
print(v.shape)

torch.Size([32, 20, 16])


In [41]:
print((final.weight).shape)

torch.Size([20, 16])


In [44]:
print((final.weight.unsqueeze(0)).shape)

torch.Size([1, 20, 16])


In [51]:
temp = final.weight*v
print(temp.shape)
temp = torch.sum(temp, dim=2)
print(temp.shape)
temp = temp + final.bias
print(temp.shape)
y_hat = torch.sigmoid(temp)
print(y_hat.shape)

torch.Size([32, 20, 16])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
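
The element-wise multiply followed by a sum applies each label's weight row only to that label's own representation. Calling `final(v)` directly would instead apply every weight row to every label vector and return shape `(batch size, 20, 20)`. The sketch below, on random tensors with assumed names, shows that the multiply-and-sum is equivalent to a single `torch.einsum` call.

```python
import torch

torch.manual_seed(0)
batch_size, num_classes, num_filter_maps = 4, 20, 16

# Stand-ins for the per-label representations and the final layer's parameters.
v = torch.randn(batch_size, num_classes, num_filter_maps)
weight = torch.randn(num_classes, num_filter_maps)
bias = torch.randn(num_classes)

# Version used in the notebook: broadcasted multiply, then sum over filter maps
logits_mul = (weight * v).sum(dim=2) + bias

# Equivalent formulation: per-label dot product written with einsum
logits_einsum = torch.einsum('bck,ck->bc', v, weight) + bias

print(logits_mul.shape)  # torch.Size([4, 20])
print(torch.allclose(logits_mul, logits_einsum, atol=1e-5))  # True
```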


Now let's put together the model:

In [52]:
from math import floor
from torch.nn.init import xavier_uniform_

class CAML(nn.Module):

    def __init__(self, kernel_size=10, num_filter_maps=16, embed_size=100, dropout=0.5):
        super(CAML, self).__init__()
        
        # embedding layer
        self.embed = nn.Embedding(NUM_WORDS, embed_size, padding_idx=0)
        self.embed_drop = nn.Dropout(p=dropout)

        # initialize conv layer as in section 2.1
        self.conv = nn.Conv1d(embed_size, num_filter_maps, kernel_size=kernel_size, padding=int(floor(kernel_size/2)))
        xavier_uniform_(self.conv.weight)

        # context vectors for computing attention as in section 2.2
        self.U = nn.Linear(num_filter_maps, NUM_CLASSES)
        xavier_uniform_(self.U.weight)

        # final layer: create a matrix to use for the NUM_CLASSES binary classifiers as in section 2.3
        self.final = nn.Linear(num_filter_maps, NUM_CLASSES)
        xavier_uniform_(self.final.weight)
        
    def forward_embed(self, text):
        """
        Feed text through the embedding (self.embed) and dropout layer (self.embed_drop).
        
        INPUT: 
            text: (batch size, seq_len)
            
        OUTPUT:
            text: (batch size, seq_len, embed_size)
        """
        text = self.embed(text)
        text = self.embed_drop(text)
        return text
        
    def forward_conv(self, text):
        """
        Feed text through the convolution layer (self.conv) and tanh activation function (torch.tanh) 
        in eq (1) in the paper.
        
        INPUT:
            text: (batch size, embed_size, seq_len)
            
        OUTPUT:
            text: (batch size, num_filter_maps, seq_len), or seq_len + 1 when kernel_size is even
        """
        text = self.conv(text)
        text = torch.tanh(text)
        return text
        
    def forward_calc_atten(self, text):
        """
        Calculate the attention weights in eq (2) in the paper. 
        
        INPUT:
            text: (batch size, seq_len, num_filter_maps)

        OUTPUT:
            alpha: (batch size, num_class, seq_len), the attention weights
        """
        # (batch size, seq_len, num_filter_maps) -> (batch size, num_filter_maps, seq_len)
        text = text.transpose(1,2)
        alpha = torch.matmul(self.U.weight, text)
        alpha = F.softmax(alpha, dim=2)
        return alpha
        
    def forward_aply_atten(self, alpha, text):
        """
        Apply the attention in eq (3) in the paper.

        INPUT: 
            text: (batch size, seq_len, num_filter_maps)
            alpha: (batch size, num_class, seq_len), the attention weights
            
        OUTPUT:
            v: (batch size, num_class, num_filter_maps), vector representations for each label
        """
        v = torch.matmul(alpha, text)
        return v
    
    def forward_linear(self, v):
        """
        Apply the final linear classification in eq (5) in the paper.
        
        INPUT: 
            v: (batch size, num_class, num_filter_maps), vector representations for each label
            
        OUTPUT:
            y_hat: (batch size, num_class), label probability
        """
        # multiply `self.final.weight` and `v` element-wise (equivalent to torch.mul())
        temp = self.final.weight*v
        # sum the result over dim 2 (i.e. num_filter_maps)
        temp = torch.sum(temp, dim=2)
        # add the result with `self.final.bias`
        temp = temp + self.final.bias
        # apply sigmoid with torch.sigmoid()
        y_hat = torch.sigmoid(temp)
        return y_hat

        
    def forward(self, text):
        # 1. get embeddings and apply dropout
        text = self.forward_embed(text)
        # (batch size, seq_len, embed_size) -> (batch size, embed_size, seq_len);
        text = text.transpose(1, 2)

        # 2. apply convolution and nonlinearity (tanh)
        text = self.forward_conv(text)
        # (batch size, num_filter_maps, seq_len) -> (batch size, seq_len, num_filter_maps);
        text = text.transpose(1,2)
        
        # 3. calculate attention
        alpha = self.forward_calc_atten(text)
        
        # 4. apply attention
        v = self.forward_aply_atten(alpha, text)
           
        # 5. final layer classification
        y_hat = self.forward_linear(v)
        
        return y_hat
    
    
model = CAML()

In [53]:
model = CAML()
model.eval()

CAML(
  (embed): Embedding(1253, 100, padding_idx=0)
  (embed_drop): Dropout(p=0.5, inplace=False)
  (conv): Conv1d(100, 16, kernel_size=(10,), stride=(1,), padding=(5,))
  (U): Linear(in_features=16, out_features=20, bias=True)
  (final): Linear(in_features=16, out_features=20, bias=True)
)
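
To get a feel for the model size, the sketch below rebuilds the four parameterized layers with the sizes printed above and counts their parameters. The `layers` dictionary is only an illustrative container, not part of the model.

```python
import torch.nn as nn

# Same layer sizes as in the printed model summary above.
layers = {
    "embed": nn.Embedding(1253, 100, padding_idx=0),
    "conv": nn.Conv1d(100, 16, kernel_size=10, padding=5),
    "U": nn.Linear(16, 20),
    "final": nn.Linear(16, 20),
}

# Count weights and biases per layer.
counts = {name: sum(p.numel() for p in layer.parameters()) for name, layer in layers.items()}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name}: {n}")
print(f"total: {total}")  # total: 141996
```

Most parameters sit in the embedding table. Note that `U` is created with a bias, but the forward pass only uses `U.weight`, so that bias does not affect the output.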

## 3 Training and Inference

In [54]:
model = CAML()

In [55]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()
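
Because every label is treated as an independent binary decision, binary cross-entropy is averaged over all (sample, label) entries. The sketch below compares a manual computation with `nn.BCELoss` on made-up probabilities; the values of `y_hat` and `y` are illustrative only.

```python
import torch
import torch.nn as nn

# Made-up predicted probabilities and multi-hot targets (2 samples, 3 labels).
y_hat = torch.tensor([[0.9, 0.2, 0.1],
                      [0.3, 0.8, 0.6]])
y = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 1.0]])

# Binary cross-entropy per entry, averaged over all entries
manual = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).mean()

# nn.BCELoss uses mean reduction by default
builtin = nn.BCELoss()(y_hat, y)

print(torch.allclose(manual, builtin))  # True
```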

Now let us implement the `eval()` and `train()` functions. Note that `train()` should call `eval()` at the end of each training epoch to see the results on the validation dataset.

In [56]:
from sklearn.metrics import precision_recall_fscore_support

def eval(model, test_loader):
    
    """    
    INPUT:
        model: the CAML model
        test_loader: dataloader
        
    OUTPUT:
        precision: overall micro precision score
        recall: overall micro recall score
        f1: overall micro f1 score
    """

    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    # disable gradient tracking during evaluation
    with torch.no_grad():
        for sequences, labels in test_loader:
            y_hat = model(sequences)
            # threshold the probabilities at 0.5 to get binary predictions
            y_hat = (y_hat > 0.5).long()
            y_pred = torch.cat((y_pred, y_hat.to('cpu')), dim=0)
            y_true = torch.cat((y_true, labels.long().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return p, r, f
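
With `average='micro'`, the counts of true positives, false positives, and false negatives are pooled over all labels before the scores are computed. The plain-Python sketch below illustrates that pooling on a made-up example; `y_true` and `y_pred` here are toy values, not model outputs.

```python
# Toy multi-label ground truth and predictions (3 samples, 3 labels).
y_true = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]
y_pred = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]

# Flatten to (true, predicted) pairs so all labels are pooled together.
pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]

tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(tp, fp, fn)  # 3 1 2
print(f"p: {precision:.2f}, r: {recall:.2f}, f: {f1:.2f}")  # p: 0.75, r: 0.60, f: 0.67
```

Since the counts are pooled, frequent labels such as "normal" influence the micro-averaged scores more than rare ones.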

In [57]:
def train(model, train_loader, test_loader, n_epochs):
    """    
    INPUT:
        model: the CAML model
        train_loader: dataloader for the training set
        test_loader: dataloader used for validation after each epoch
        n_epochs: total number of epochs
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for sequences, labels in train_loader:
            optimizer.zero_grad()
            y_hat = model(sequences)
            loss = criterion(y_hat, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        
        p, r, f = eval(model, test_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(epoch+1, p, r, f))

    
# number of epochs to train the model
n_epochs = 20

train(model, train_loader, test_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.456818
Epoch: 1 	 Validation p: 0.00, r:0.00, f: 0.00
Epoch: 2 	 Training Loss: 0.261946
Epoch: 2 	 Validation p: 1.00, r:0.00, f: 0.00
Epoch: 3 	 Training Loss: 0.227357
Epoch: 3 	 Validation p: 0.88, r:0.13, f: 0.22
Epoch: 4 	 Training Loss: 0.213753
Epoch: 4 	 Validation p: 0.85, r:0.21, f: 0.34
Epoch: 5 	 Training Loss: 0.201032
Epoch: 5 	 Validation p: 0.84, r:0.23, f: 0.36
Epoch: 6 	 Training Loss: 0.190393
Epoch: 6 	 Validation p: 0.88, r:0.24, f: 0.38
Epoch: 7 	 Training Loss: 0.178920
Epoch: 7 	 Validation p: 0.90, r:0.34, f: 0.49
Epoch: 8 	 Training Loss: 0.168405
Epoch: 8 	 Validation p: 0.89, r:0.36, f: 0.52
Epoch: 9 	 Training Loss: 0.160167
Epoch: 9 	 Validation p: 0.90, r:0.39, f: 0.55
Epoch: 10 	 Training Loss: 0.153547
Epoch: 10 	 Validation p: 0.90, r:0.41, f: 0.56
Epoch: 11 	 Training Loss: 0.147389
Epoch: 11 	 Validation p: 0.90, r:0.43, f: 0.59
Epoch: 12 	 Training Loss: 0.142694
Epoch: 12 	 Validation p: 0.91, r:0.45, f: 0.60
Epoch: 13 

In [58]:
p, r, f = eval(model, test_loader)