## COLAB SETUP

In [1]:
# mount your drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
space = 'colab'

In [3]:
if space == 'colab':
    data_root = '/content/drive/MyDrive/Data'
    save_root = ''
else:
    data_root = 'C:/Users/james/Data/MIMIC/mimic-iii-clinical-database-1.4'
    save_root = 'C:/Users/james/Data/MIMIC/mimic-iii-chart-transformers'

In [None]:
!pip install x_transformers

## TENSORBOARD UTILS

In [5]:
%reload_ext tensorboard

In [72]:
from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter()

Do writing! e.g. see [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html).

Then `tensorboard --logdir runs`.

In [None]:
tensorboard --logdir drive/MyDrive/Data/logs

## PRE-PROCESSING

In [None]:
import os
import numpy as np
import pandas as pd
import pickle as pickle
import torch
from sklearn.model_selection import train_test_split


In [None]:
# paths

chartevents_path = os.path.join(data_root, "CHARTEVENTS.csv")
admissions_path = os.path.join(data_root,"ADMISSIONS.csv")
d_items_path = os.path.join(data_root, "d_items.csv")


In [None]:
# read in admissions

admissions = pd.read_csv(admissions_path,
                         parse_dates=['ADMITTIME', 'DISCHTIME'])

# extract only those charted and apply labelling logic

charted = admissions[admissions.HAS_CHARTEVENTS_DATA == 1]
charted.drop('ROW_ID', axis=1, inplace=True)
charted['HADM_IN_SEQ'] = charted.groupby('SUBJECT_ID')['ADMITTIME'].rank().astype(int)
charted = charted.sort_values(by=['SUBJECT_ID', 'HADM_IN_SEQ'])
charted['ADMITTIME_NEXT'] = charted.groupby('SUBJECT_ID')['ADMITTIME'].shift(-1)
charted['DIS2ADM'] = charted['ADMITTIME_NEXT'] - charted['DISCHTIME']
charted['READM<7'] = (charted['DIS2ADM'] < pd.Timedelta(days=7)).astype(int)
charted['READM<30'] = (charted['DIS2ADM'] < pd.Timedelta(days=30)).astype(int)
charted.set_index('HADM_ID', inplace=True)

# get hadm_ids for the first admission

first_indices = charted[charted.HADM_IN_SEQ == 1].index.to_numpy()

# split first-hadm_ids into train, val, test and check.

train_indices, surplus = train_test_split(first_indices, train_size=0.8)
val_indices, test_indices = train_test_split(surplus, test_size=0.5)
del surplus
assert set(first_indices) == set(train_indices) | set(val_indices) | set(test_indices)

# helpers


def ts_to_posix(time):
    return pd.Timestamp(time, unit='s').timestamp()


def get_admittime(hadm_id):
    time = charted.loc[hadm_id, 'ADMITTIME']
    return ts_to_posix(time)


def get_from_charted(hadm_id, label):
    return charted.loc[hadm_id, label]


# token mappings

d_items = pd.read_csv(d_items_path)

token_shift = 1
pad_token = 0

itemid2token = dict(zip(d_items['ITEMID'], range(token_shift, token_shift + len(d_items))))

# add special tokens to the dictionary
itemid2token['[PAD]'] = pad_token
#itemid2token['[BOS]'] = 1
#itemid2token['[EOS]'] = 2
#itemid2token['[SEP]'] = 3

token2itemid = {v: k for k, v in itemid2token.items()}
token2label = dict(zip(range(len(d_items)), d_items['LABEL']))

with open(os.path.join(save_root, 'mappings.pkl'), 'wb') as f:
    pickle.dump({'itemid2token': itemid2token,
                 'token2itemid': token2itemid},
                f)


def map2token(itemid):
    return itemid2token[np.int(itemid)]


def map2itemid(token):
    return str(token2itemid[token])


def map2itemidstr(tokens):
    return ' '.join(list(map(map2itemid, tokens)))


# loop through sets and generate output files

for subset in ['val', 'train', 'test']:
    print(f'Processing {subset} set data...')

    # grouper for charts

    gpdf = (pd.read_csv(chartevents_path, skiprows=0, 
                        nrows=10000000 if space != 'colab',
                        header=0,
                        usecols=['HADM_ID', 'CHARTTIME', 'ITEMID'],
                        dtype={'HADM_ID': np.int},
                        converters={'ITEMID': map2token},
                        parse_dates=['CHARTTIME'])
            .query(f'HADM_ID.isin(@{subset}_indices)')
            .groupby(by='HADM_ID')
            )

    # initialise

    tokens = dict()
    times = dict()
    times_rel = dict()
    labels = dict()

    # populate with entries

    for i in gpdf.groups:
        time_origin = get_admittime(i)
        temp = gpdf.get_group(i).sort_values(by="CHARTTIME")
        tokens[i] = np.array(temp['ITEMID'], dtype=int)
        times[i] = np.fromiter(
            map(ts_to_posix, temp['CHARTTIME']),
            dtype=np.int64
        )
        times_rel[i] = times[i] - time_origin
        labels[i] = {
            'readm_7': get_from_charted(i, 'READM<7'),
            'readm_30': get_from_charted(i, 'READM<30')
        }

    # write out charts to pickle

    save_path = os.path.join(save_root, f'{subset}_charts.pkl')

    with open(save_path, 'wb') as f:
        pickle.dump({f'{subset}_tokens': tokens,
                     f'{subset}_times': times,
                     f'{subset}_times_rel': times_rel}, f)

    del tokens, times, times_rel, gpdf

    # write out labels to pickle

    save_path = os.path.join(save_root, f'{subset}_labels.pkl')

    with open(save_path, 'wb') as f:
        pickle.dump({f'{subset}_labels': labels}, f)

    del labels

    print(f'{subset} set data processed!')



## SELF-SUPERVISED MODE



### THE MODEL

In [5]:
import os
import copy
import tqdm
import random

import numpy as np
import pandas as pd
import pickle as pickle
import torch

from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

#### Mappings, Paths and Utils

In [14]:
# paths

train_path = os.path.join(data_root, "train_charts.pkl")
val_path = os.path.join(data_root, "val_charts.pkl")
mapping_path = os.path.join(data_root, "mappings.pkl")

# misc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# token mappings:  # TODO: refactor to module where possible.

with open(mapping_path, 'rb') as f:
    mappings = pickle.load(f)
    itemid2token = mappings['itemid2token']
    token2itemid = mappings['token2itemid']
    del mappings

num_tokens = len(itemid2token)

# token mappings: decoders


def decode_token(token):
    return str(token2itemid[token])


def decode_tokens(tokens):
    return ' '.join(list(map(decode_token, tokens)))


# get data

def fetch_data(path, var_key):
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data[var_key]


trX = fetch_data(train_path, 'train_tokens')
vaX = fetch_data(val_path, 'val_tokens')

data_train = {k: torch.from_numpy(v) for k, v in trX.items()}
data_val = {k: torch.from_numpy(v) for k, v in vaX.items()}


# yield from loader

def cycle(loader):
    while True:
        for data in loader:
            yield data


#### Constants & Model

In [None]:
# constants & hyperparameters 

NUM_EPOCHS = 10
NUM_BATCHES = 1000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
CHECKPOINT_AFTER = 100
GENERATE_EVERY = 100
GENERATE_LENGTH = 100
SEQ_LEN = 200

# instantiate GPT-like decoder model


LAYER_SPEC = {'dim':100, 'depth':3, 'heads':4}  # full: 512, 6, 8

model = TransformerWrapper(
    num_tokens=num_tokens,  # Expects each val in data to be [0, num_tokens)
    max_seq_len=SEQ_LEN, 
    attn_layers=Decoder(
        dim=LAYER_SPEC['dim'],
        depth=LAYER_SPEC['depth'],
        heads=LAYER_SPEC['heads'])
)

pre_model = AutoregressiveWrapper(model)
pre_model.to(device)

#### Datasets & Dataloaders

In [10]:
class ClsSamplerDataset(Dataset):  # TODO: tidy __getitem__ method with more natural pad operations.
    def __init__(self, data, seq_len, labels=None):
        super().__init__()
        self.data = data
        self.labels = labels
        self.seq_len = seq_len
        self.lookup = dict(zip(np.arange(len(self.data)), self.data.keys()))

    def __getitem__(self, key):  # a.t.m. when data[key] shorter length than SEQ_LEN, padded with 0.
        index = self.lookup[key]
        item_len = self.data[index].size(0)
        rand_start = torch.randint(0, item_len - self.seq_len, (1,)) if item_len > self.seq_len else 0
        lenfromseq = min(item_len, self.seq_len)
        sample = torch.zeros(self.seq_len)
        sample[:lenfromseq] = self.data[index][rand_start: rand_start + lenfromseq]

        if self.labels is not None:
            label = torch.tensor(self.labels[index])
            return sample.long().to(device), label.long().to(device)
        else:
            return sample.long().to(device)

    def __len__(self):
        return len(self.data)


train_dataset = ClsSamplerDataset(data_train, SEQ_LEN)
val_dataset   = ClsSamplerDataset(data_val, SEQ_LEN)

train_loader  = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_loader    = DataLoader(val_dataset,   batch_size=BATCH_SIZE)

train_cycler  = cycle(train_loader)
val_cycler    = cycle(val_loader)


### TRAINING LOOP

In [None]:
optim = torch.optim.Adam(pre_model.parameters(), lr=LEARNING_RATE)
ckpt_path = os.path.join(save_root, "pre_model_exp1.pt")

writer = SummaryWriter(
    log_dir="runs/pre_model",
    filename_suffix='_' + '_'.join(map(str, LAYER_SPEC.values()))
    )

# training loop

best_val_loss = np.inf

for epoch in range(NUM_EPOCHS):
  for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.,
                     desc=f'epoch {epoch}:', colour='green'):
      pre_model.train()

      for __ in range(GRADIENT_ACCUMULATE_EVERY):
          loss = pre_model(next(train_loader))
          loss.backward()

      torch.nn.utils.clip_grad_norm_(pre_model.parameters(), 0.5)
      optim.step()
      optim.zero_grad()

      # parameter tracking

      writer.add_scalar('train_loss', loss.item(),
                        epoch * NUM_BATCHES + i
                        )

      # validate model

      if i % VALIDATE_EVERY == 0:
          pre_model.eval()
          with torch.no_grad():
              val_loss = pre_model(next(val_loader)).item()
              
              writer.add_scalar('val_loss', val_loss,
                        epoch * NUM_BATCHES + i
                        )

              if val_loss < best_val_loss:
                  print(f'VL: {val_loss} < BVL: {best_val_loss}')
                  best_val_loss = val_loss

                  # checkpoint model

                  if i > CHECKPOINT_AFTER:
                    print("Checkpoint saving...")
                    torch.save({
                        'train_step': i,
                        'model_state_dict': pre_model.state_dict(),
                        'LAYER_SPEC': LAYER_SPEC,
                        'SEQ_LEN': SEQ_LEN,
                        'optim_state_dict': optim.state_dict(),
                        'val_loss': val_loss
                    }, ckpt_path)
                    print("Checkpoint saved!\n")
      
      # generate sequence

      if i % GENERATE_EVERY == 0:
          pre_model.eval()
          inp = random.choice(val_dataset)[:-1]
          primer_str = decode_tokens(inp.cpu().numpy())
          print('\nprimer:', primer_str, '*' * 100, sep='\n')

          sample = pre_model.generate(inp, GENERATE_LENGTH)
          sample_str = decode_tokens(sample.cpu().numpy())
          print('output:', sample_str, '\n', sep='\n')

### EVALUATE

In [None]:
@torch.no_grad()
def evaluate(model, dataloader):
    model.eval()
    cum_loss = 0
    counter = 0
    for batch in dataloader:
        counter += 1
        batch_size = batch.shape[0]
        val_loss = model(batch).item()
        cum_loss += val_loss
    avg_loss = cum_loss/(batch_size*counter)
    return cum_loss, avg_loss

In [None]:
evaluate(pre_model, val_loader)

(2384.7820653915405, 2.3943595034051612)

Will this code work for finetuning too?

### GENERATING SEQUENCES

In [None]:
# reading d_items for interpretability 
d_items_path = os.path.join(data_root, "d_items.csv")
d_items = pd.read_csv(d_items_path, index_col='ITEMID', dtype={'ITEMID': str})

In [None]:
def decode_token(token):
    return str(token2itemid[token])

def decode_tokens(tokens):
    return ' '.join(list(map(decode_token, tokens)))

def token2label(token):
    if token == 0:
        return '[PAD]'
    else:
        itemid = token2itemid[token]
        x = d_items.loc[itemid, 'LABEL']
    return x

def tokens2labels(tokens):
    return '\n\t -> '.join(list(map(token2label, tokens)))

In [None]:
# fetch and load model state_dict

weights_path = os.path.join(data_root, 'models', 'pre_model_exp1.pt')
X = torch.load(weights_path, map_location=device)
states = X['model_state_dict']
base_states = { k[len('net.'):] if k[:len('net.')] == 'net.' else k : v for k, v in states.items()}

pre_model.load_state_dict(states)
pre_model.to(device)

In [None]:
pre_model.eval()
with torch.no_grad():
    prompt = random.choice(val_dataset)[0:18]
    #prompt = torch.cat((prompt, torch.tensor([0]).to(device)))

    sample = pre_model.generate(start_tokens=prompt, seq_len = 50, eos_token=0)
    print("prompt:\t", tokens2labels(prompt.cpu().numpy()))
    print("model:\t", tokens2labels(sample.cpu().numpy()), '\n')

In [None]:
evaluate(pre_model, train_loader)

(3122.3268125355244, 1.5635086692716698)

### TESTING - DO NOT USE

In [None]:
test_path    = os.path.join(data_root, 'test_charts.pkl')
tsX          = fetch_data(test_path, 'test_tokens')
data_test    = {k: torch.from_numpy(v) for k, v in tsX.items()}
test_dataset = ClsSamplerDataset(data_test, SEQ_LEN)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE)
test_cycler  = cycle(test_loader)

## FINE-TUNING MODE

In [85]:
class FinetuningWrapper(nn.Module):
    def __init__(self, net, num_classes, state_dict = None,
                 ignore_index = -100, pad_value = 0, weight = None):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.weight = weight  # expected to be a tensor of size = num_classes

        self.net = copy.deepcopy(net)  # deepcopy is necessary here.
        self.max_seq_len = self.net.max_seq_len

        # initialise net from pretrained
        if state_dict is not None:
            self.net.load_state_dict(state_dict)

        # define classifier head layers
        self.num_features = net.to_logits.in_features * 200
        self.net.clf1 = nn.Linear(self.num_features, num_classes, bias=True)

    def forward(self, X, Y, predict=False, **kwargs):
        Z = self.net(X, return_embeddings=True, **kwargs)
        Z = torch.flatten(Z, start_dim=1)
        logits = self.net.clf1(Z)
        print('net outputted')
        loss = F.cross_entropy(logits, Y, weight = self.weight)
        print('loss computed')
        return logits if predict else loss

In [61]:
train_lbl_path = os.path.join(data_root, "train_labels.pkl")
val_lbl_path = os.path.join(data_root, "val_labels.pkl")
FT_BATCH_SIZE = 100

# fetch labels

with open(train_lbl_path, 'rb') as f:
    X = pickle.load(f)
    train_labels_30 = {k: v['readm_30'] for k, v in  X['train_labels'].items()}
    train_labels_7 = {k: v['readm_7'] for k, v in  X['train_labels'].items()}
    del X

with open(val_lbl_path, 'rb') as f:
    X = pickle.load(f)
    val_labels_30 = {k: v['readm_30'] for k, v in  X['val_labels'].items()}
    val_labels_7 = {k: v['readm_7'] for k, v in  X['val_labels'].items()}
    del X

# helper for propensities

def propensity(di):
    x = sum(di.values()) / len(di)
    return x

# generate datasets and loaders

ft_train_dataset = ClsSamplerDataset(data_train, SEQ_LEN, labels=train_labels)
ft_val_dataset = ClsSamplerDataset(data_val, SEQ_LEN, labels=val_labels)

ft_train_loader = cycle(DataLoader(ft_train_dataset, batch_size=FT_BATCH_SIZE))
ft_val_loader = cycle(DataLoader(ft_val_dataset, batch_size=FT_BATCH_SIZE))

In [66]:
# fetch model weights

params_path = os.path.join(data_root, 'models', 'pre_model_exp1.pt')
X = torch.load(params_path, map_location=device)
states = X['model_state_dict']
base_states = { k[len('net.'):] if k[:len('net.')] == 'net.' else k : v for k, v in states.items()}

### FINETUNING LOOP

#### TRAINING

In [None]:
# propensities

p = propensity(train_labels_30)
weights = torch.tensor([p, 1-p])

# initialisation

fit_model = FinetuningWrapper(model, num_classes=2,
                              state_dict=base_states,
                              weight=weights)
fit_model.to(device)

TODO: bug in F.cross_entropy loss with weights. Got Double not Float error.

In [87]:
# set optimiser and paths

optim_ft = torch.optim.Adam(fit_model.parameters(), lr=0.001)
ckpt_ft_path = os.path.join(save_root, "fit_model_exp0.pt")
logs_path = os.path.join(save_root, "logs", "fit_model_bal0")

# training loop constants

NUM_FT_BATCHES = 100
CHECKPOINT_AFTER = 10
VALIDATE_EVERY = 2

In [88]:
writer = SummaryWriter(logs_path)

# training loop

best_val_loss = np.inf
for i in tqdm.tqdm(range(NUM_FT_BATCHES), mininterval=10., desc='fine-tuning'):
    fit_model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        X, Y = next(ft_train_loader)
        loss = fit_model(X, Y)
        loss.backward()
      
    writer.add_scalar('loss', loss.item(), i)

    print(f'tuning loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(fit_model.parameters(), 0.5)
    optim_ft.step()
    optim_ft.zero_grad()

    # validate fit_model

    if i % VALIDATE_EVERY == 0:
        fit_model.eval()
        with torch.no_grad():
            X, Y = next(ft_val_loader)
            val_loss = fit_model(X, Y).item()
            writer.add_scalar('val_loss', val_loss, i)
            print(f'validation loss: {val_loss}')
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                
                if i > CHECKPOINT_AFTER:
                    print("Saving checkpoint...\n")
                    torch.save({
                        'train_step': i,
                        'model_state_dict': fit_model.state_dict(),
                        'SEQ_LEN': SEQ_LEN,
                        'optim_state_dict': optim_ft.state_dict(),
                        'val_loss': val_loss
                    }, ckpt_ft_path)
                    print("Checkpoint saved!\n")
    #else:
    #    val_loss = np.nan
    #writer.add_scalars('loss', {'tuning_loss': loss.item(), 'val_loss': val_loss}, i)

writer.close()

fine-tuning:   0%|          | 0/100 [00:01<?, ?it/s]

net outputted





RuntimeError: ignored

#### PREDICTION

Weight loading regimes:

1. Initialise with random
2. Unsupervised pretrain
3. Load from `pretrain`
4. Finetune for `task`
5. Load from `finetune`

Here, we are in 5.


In [32]:
from sklearn.metrics import accuracy_score, balanced_accuracy_score

In [16]:
# fetch model weights

ft_model_path = os.path.join(data_root, 'models', 'fit_model_exp1.pt')
X = torch.load(ft_model_path, map_location=device)

In [19]:
ft_states = X['model_state_dict']
ft_base_states = { k[len('net.'):] if k[:len('net.')] == 'net.' else k : v for k, v in ft_states.items()}

In [None]:
# initialisation
fit_model = FinetuningWrapper(model, num_classes=2)
fit_model.load_state_dict(ft_states)
fit_model.to(device)

In [None]:
fit_model.eval()

# nums TP, FP, TN, FN

TP_tot = 0
FP_tot = 0
TN_tot = 0
FN_tot = 0

with torch.no_grad():
    for i in range(10):
        X, Y_true = next(ft_val_loader)
        logits = fit_model(X, Y_true, predict=True)
        Y_pred = torch.argmax(logits, dim=1)
        y_true, y_pred = Y_true.cpu(), Y_pred.cpu()
        TP = ((y_true == 1) & (y_pred == 1)).sum()
        FN = (y_true > y_pred).sum()
        FP = (y_true < y_pred).sum()
        TN = ((y_true == 0) & (y_pred == 0)).sum()
        print(f'TP = {TP}', f'FP = {FP}', f'TN = {TN}', f'FN = {FN}')
        #print(f'TP: {Y_true.numpy().sum()}', f'pP: {Y_pred.numpy().sum()}')
        print(accuracy_score(y_true, y_pred, normalize=True), '\n')
        print(balanced_accuracy_score(y_true, y_pred), '\n')