### Load dataset and showcase aspects of what is actually loaded.

In [2]:
import os 
from src.mimic3 import MIMIC3_ICD9
from pathlib import Path
directory = "/srv/local/data/jw3/physionet.org/files"
mimiciii_directory = os.path.join(directory, "mimic-iii-clinical-database-1.4")
stored_processed_path = Path("/srv/local/data/jw3/physionet.org/files/files/data/mimiciii_clean") # where to store the processed .feather files for use and fast loading
print(mimiciii_directory)
dataset = MIMIC3_ICD9(mimiciii_directory, processed_path=stored_processed_path)
dataset.info()
   

/srv/local/data/jw3/physionet.org/files/mimic-iii-clinical-database-1.4
Loading The Processed Data
Data Loaded Successfully!


{'num_classes': 3681,
 'num_examples': 52712,
 'num_train_tokens': array(58074899),
 'average_tokens_per_example': 1101.7396228562757,
 'num_train_examples': 38428,
 'num_val_examples': 5548,
 'num_test_examples': 8736,
 'num_train_classes': 3681,
 'num_val_classes': 3676,
 'num_test_classes': 3681,
 'average_classes_per_example': 15.584838366975262}

### Get access to a training set sample.

In [3]:
training_set = dataset.get_training_set()
training_set[0]

("admission date:  [**2101-10-20**]     discharge date:  [**2101-10-31**]\n\ndate of birth:   [**2025-4-11**]     sex:  m\n\nservice:  medicine\n\nchief complaint:  admitted from rehabilitation for\nhypotension (systolic blood pressure to the 70s) and\ndecreased urine output.\n\nhistory of present illness:  the patient is a 76-year-old\nmale who had been hospitalized at the [**hospital1 190**] from [**10-11**] through [**10-19**] of [**2101**]\nafter undergoing a left femoral-at bypass graft and was\nsubsequently discharged to a rehabilitation facility.\n\non [**2101-10-20**], he presented again to the [**hospital1 346**] after being found to have a systolic\nblood pressure in the 70s and no urine output for 17 hours.\na foley catheter placed at the rehabilitation facility\nyielded 100 cc of murky/brown urine.  there may also have\nbeen purulent discharge at the penile meatus at this time.\n\non presentation to the emergency department, the patient was\nwithout subjective complaints.  

In [4]:
print(len(dataset.df))

52712


In [5]:
dataset.df.head()

#,_id,text,target,num_words,num_targets,icd9_diag,icd9_proc,split
0,145834,'admission date: [**2101-10-20**] discharge...,"""['38.93', '99.62', '96.6', '89.64', '96.72', '9...",2162,15,"""['584.9', '427.5', '038.9', '410.71', '682.6', ...","""['38.93', '99.62', '96.6', '89.64', '96.72', '9...",train
1,185777,'admission date: [**2191-3-16**] discharge ...,"""['38.93', '33.23', '88.72', '571.5', '042', '13...",1518,12,"""['571.5', '042', '136.3', '276.3', 'E931.7', '0...","['38.93', '33.23', '88.72']",train
2,107064,'admission date: [**2175-5-30**] discharg...,"""['39.57', '55.69', '38.06', '00.91', '99.04', '...",1084,13,"""['276.6', '285.9', '275.3', '403.91', '276.7', ...","['39.57', '55.69', '38.06', '00.91', '99.04']",test
3,150750,'admission date: [**2149-11-9**] discharg...,"""['96.72', '96.04', '507.0', '584.9', '276.5', '...",1488,8,"""['507.0', '584.9', '276.5', '401.9', '428.0', '...","['96.72', '96.04']",train
4,184167,'admission date: [**2103-6-28**] discharg...,"""['99.83', '99.15', '96.6', '765.15', 'V30.00', ...",577,8,"['765.15', 'V30.00', 'V29.0', '765.25', '774.2']","['99.83', '99.15', '96.6']",train
5,194540,'admission date: [**2178-4-16**] d...,"['92.29', '01.13', '01.59', '99.25', '191.3']",1876,5,['191.3'],"['92.29', '01.13', '01.59', '99.25']",test
6,112213,'admission date: [**2104-8-7**] discharge d...,"""['38.93', '54.59', '53.51', '96.71', '54.12', '...",729,18,"""['427.5', '553.21', '568.0', '998.11', '157.0',...","""['38.93', '54.59', '53.51', '96.71', '54.12', '...",train
7,143045,'admission date: [**2167-1-8**] discharge...,"""['36.12', '37.61', '88.72', '36.15', '39.61', '...",1305,10,"['411.1', '272.0', '414.01', '250.00', '401.9']","['36.12', '37.61', '88.72', '36.15', '39.61']",train
8,161087,'admission date: [**2135-5-9**] di...,"""['38.93', '37.31', '88.72', '785.51', '719.46',...",573,11,"""['785.51', '719.46', '458.9', '311', '423.9', '...","['38.93', '37.31', '88.72']",test
9,194023,'admission date: [**2134-12-27**] ...,"""['35.71', '39.61', '88.72', '272.4', 'V12.59', ...",633,7,"['272.4', 'V12.59', '458.29', '745.5']","['35.71', '39.61', '88.72']",test


### Useful helper functions to convert dataset to a PyTorch dataloader

In [17]:
# Convert all classes into multi_hot vectors
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Any, Tuple, List, Sequence

def chunk_text(text: str, chunk_size: int) -> List[str]:
    """Chunk text into smaller pieces i.e words.

    Args:
        text (str): Text to chunk.
        chunk_size (int): Size of each chunk.

    Returns:
        List[str]: List of chunks.
    """
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def get_target_2_index_mimiciii(dataset : MIMIC3_ICD9):
    all_targets = dataset.all_targets()
    target2index = {}
    for i, target in enumerate(all_targets):
        target2index[target] = i 
    return target2index 

def get_index2target(target2index):
    index2target = {}
    for target, index in target2index.items():
        index2target[index] = target
    return index2target

def targets_to_multihot(targets, target2index):
    multihot = torch.zeros(len(target2index))
    for target in targets:
        multihot[target2index[target]] = 1
    return multihot 


# What the paper uses
#  add_special_token: true
#   padding: false
#   use_fast: true
#   do_lower_case: true
#   max_length: null
#   truncation: false


class MIMIC3_Torch(Dataset):    
    def __init__(self, training_set, tokenizer, target2index, chunk_size=512):
        self.training_set = training_set
        self.tokenizer = tokenizer 
        self.target2index = target2index
        self.chunk_size = chunk_size
    def __len__(self):
        return len(self.training_set)
    
    # convert from (text str, targets str) to (token_ids tensor(int), attn_masks tensor(int), torch.tensor(int) )
    def __getitem__(self, index):
        text, targets = self.training_set[index]
        tokenized = self.tokenizer(text, add_special_tokens=True, padding=False, max_length=None)
        token_ids = torch.tensor(tokenized["input_ids"])
        attn_mask = torch.tensor(tokenized["attention_mask"])
        targets = targets_to_multihot(targets, self.target2index)
        return token_ids, attn_mask, targets
    # writing a collate function
    def seq2batch(
            self, sequence: Sequence[torch.Tensor], chunk_size: int = 0
        ) -> torch.Tensor:
            """Batch a sequence of vectors of different lengths. Use the pad_index to pad the vectors.

            Args:
                sequences (Sequence[torch.Tensor]): A sequence of sequences.

            Returns:
                torch.Tensor: A batched tensor.
            """
            if chunk_size == 0:
                return torch.nn.utils.rnn.pad_sequence(
                    sequence, batch_first=True, padding_value=self.tokenizer.pad_token_id
                )
            sequence = list(sequence)
            batch_size = len(sequence)
            max_length = max([len(x) for x in sequence])
            if max_length % chunk_size != 0:
                max_length = max_length + (chunk_size - max_length % chunk_size)

            # pad first sequence to the desired length
            sequence[0] = torch.nn.functional.pad(
                sequence[0],
                (0, max_length - len(sequence[0])),
                value=self.tokenizer.pad_token_id,
            )
            return (
                torch.nn.utils.rnn.pad_sequence(
                    sequence, batch_first=True, padding_value=self.tokenizer.pad_token_id
                )
                .contiguous()
                .view((batch_size, -1, chunk_size))
            )

    def collate_fn(self, batch):

        batch_token_ids = []
        batch_attn_masks = []
        batch_targets = []
        for triplet in batch:
            token_ids, attn_mask, targets = triplet 
            batch_token_ids.append(token_ids)
            batch_attn_masks.append(attn_mask)
            batch_targets.append(targets)

        token_ids = self.seq2batch(batch_token_ids, self.chunk_size)
        attn_mask = self.seq2batch(batch_attn_masks, self.chunk_size)
        targets = self.seq2batch(batch_targets, self.chunk_size)
        return token_ids, attn_mask, targets

In [7]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Load some form of a medical RoBERTa model.
tokenizer = AutoTokenizer.from_pretrained("pminervini/RoBERTa-base-PM-M3-Voc-hf", do_lower_case=True, use_fast=True)
model = AutoModelForMaskedLM.from_pretrained("pminervini/RoBERTa-base-PM-M3-Voc-hf")

Some weights of the model checkpoint at pminervini/RoBERTa-base-PM-M3-Voc-hf were not used when initializing RobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
t2index = get_target_2_index_mimiciii(dataset)
torch_dataset = MIMIC3_Torch(dataset.get_training_set(), tokenizer, t2index)

In [9]:
torch_dataset[0]

(tensor([    0, 43917,  4379,  ...,  4893,    64,     2]),
 tensor([1, 1, 1,  ..., 1, 1, 1]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

In [19]:

dataloader = DataLoader(torch_dataset, batch_size=2, collate_fn=torch_dataset.collate_fn)
for token_id, attn_mask, targets in dataloader:
    # print(batch)
    # token_id, attn_mask, targets = batch 
    print(token_id.shape)
    print(targets.shape)
    print(attn_mask.shape)

torch.Size([2, 8, 512])
torch.Size([2, 8, 512])
torch.Size([2, 8, 512])
torch.Size([2, 5, 512])
torch.Size([2, 8, 512])
torch.Size([2, 5, 512])
torch.Size([2, 5, 512])
torch.Size([2, 8, 512])
torch.Size([2, 5, 512])
torch.Size([2, 12, 512])
torch.Size([2, 8, 512])
torch.Size([2, 12, 512])
torch.Size([2, 8, 512])
torch.Size([2, 8, 512])
torch.Size([2, 8, 512])
torch.Size([2, 9, 512])
torch.Size([2, 8, 512])
torch.Size([2, 9, 512])
torch.Size([2, 4, 512])
torch.Size([2, 8, 512])
torch.Size([2, 4, 512])
torch.Size([2, 5, 512])
torch.Size([2, 8, 512])
torch.Size([2, 5, 512])
torch.Size([2, 7, 512])
torch.Size([2, 8, 512])
torch.Size([2, 7, 512])
torch.Size([2, 7, 512])
torch.Size([2, 8, 512])
torch.Size([2, 7, 512])
torch.Size([2, 12, 512])
torch.Size([2, 8, 512])
torch.Size([2, 12, 512])
torch.Size([2, 14, 512])
torch.Size([2, 8, 512])
torch.Size([2, 14, 512])
torch.Size([2, 4, 512])
torch.Size([2, 8, 512])
torch.Size([2, 4, 512])
torch.Size([2, 7, 512])
torch.Size([2, 8, 512])
torch.Size

KeyboardInterrupt: 