In [1]:
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast

from tqdm import tqdm
import pandas as pd
import random

from model.gpt2 import CustomGPT2
from torcheval.metrics import BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall, BinaryConfusionMatrix, BinaryAUPRC, BinaryAccuracy

In [26]:
DATASET = 'mcmed' # penn, mcmed, physionet
CONTEXT_LENGTH = 1024
TOKENIZER_PATH = f'tokenizer/{DATASET}_tokenizer.json'
CKPT_PATH = f'checkpoints/{DATASET}/{DATASET}_gpt2.pth'
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
BATCH_SIZE = 16
THRESHOLD = 0.5

In [3]:
class SepsisDataset(torch.utils.data.Dataset):
    def __init__(self, data, context_length):
        self.samples, labels = [], []
        data['input'] = data.apply(lambda x: list(x['demo_str']) + list(x['input']), axis=1)
        data['input'] = data['input'].apply(lambda x: x[-context_length:] if len(x) > context_length else x)
        data = data[['pat_enc_csn_id', 'input', 'time', 'label']]
        
        for _, d in data.iterrows():
            label = d['label']
            self.samples.append(d)
            labels.append(label)

        self.index_map = list(range(len(self.samples)))
        random.shuffle(self.index_map)

        _, counts = torch.unique(torch.tensor(labels), return_counts=True)
        print(f"Original: {counts}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[self.index_map[idx]] 
        label = float(sample['label'])
        event = list(sample['input'])
        time = sample['time']
        return event, time, label

    @staticmethod
    def collate_fn(batch):
        events = [item[0] for item in batch]
        times = [item[1] for item in batch]
        labels = [item[2] for item in batch]
        labels = torch.tensor(labels)
        return events, times, labels

In [4]:
# TOKENIZER
tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
tokenizer.eos_token = '[EOS]'
tokenizer.sep_token = '[SEP]'
tokenizer.bos_token = '[BOS]'
tokenizer.pad_token = '[PAD]'
tokenizer.cls_token = '[CLS]'
tokenizer.mask_token = '[MASK]'

In [16]:
# LOAD MODEL
model = CustomGPT2(vocab_size=len(tokenizer), num_classes=1)
lmhead_state_dict = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=True)
load_result = model.load_state_dict(lmhead_state_dict, strict=False)
print(load_result)

<All keys matched successfully>


In [17]:
# LOAD DATA
df = pd.read_parquet(f'data/{DATASET}.parquet')

In [18]:
df = df.sample(32)

In [19]:
# CREATE DATALOADER
train = SepsisDataset(df, CONTEXT_LENGTH)
test_loader = torch.utils.data.DataLoader(train, collate_fn=SepsisDataset.collate_fn, batch_size=BATCH_SIZE, shuffle=True)
iter = tqdm(test_loader, total=len(test_loader))

Original: tensor([30,  2])


  0% 0/2 [00:00<?, ?it/s]

In [20]:
# Initialize metrics
auroc = BinaryAUROC().to(DEVICE)
auprc = BinaryAUPRC().to(DEVICE)
f1 = BinaryF1Score().to(DEVICE)
precision = BinaryPrecision().to(DEVICE)
recall = BinaryRecall().to(DEVICE)
confusion_matrix = BinaryConfusionMatrix().to(DEVICE)
accuracy = BinaryAccuracy().to(DEVICE)

In [23]:
for event, time, target in iter:
    with torch.no_grad():
        target = target.to(DEVICE).long()
        
        # GET INPUTS
        max_len = max(len(i) for i in event)
        tokens = tokenizer(event, return_tensors="pt", is_split_into_words=True, padding=True, return_attention_mask=True)
        input_ids = tokens["input_ids"].to(DEVICE)   
        attention_masks = tokens["attention_mask"].to(DEVICE) 

        times = [F.pad(torch.tensor(r, dtype=torch.float32), (0, max_len - len(r))).round(decimals=2)  for r in time]
        times = torch.stack(times, dim=0).to(DEVICE)  

        # INFERENCE
        output = model.to(DEVICE)(input_ids.to(DEVICE), times.to(DEVICE), attention_masks.to(DEVICE))
        output = torch.sigmoid(output).squeeze(-1) 

        preds = output >= THRESHOLD
        f1.update(preds, target)
        precision.update(preds, target)
        recall.update(preds, target)
        auroc.update(output, target)
        auprc.update(output, target)
        confusion_matrix.update(preds, target)
        accuracy.update(preds, target)
        
        iter.set_description(f"Running...")

In [24]:
f1 = f1.compute()
auc = auroc.compute()
auprc = auprc.compute()
precision = precision.compute()
recall = recall.compute()
conf_matrix = confusion_matrix.compute()
test_acc = accuracy.compute()

In [25]:
print(f"Acc: {test_acc:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}, AUPRC: {auprc:.4f}")
print(conf_matrix)

Acc: 0.8438, F1: 0.2857, AUC: 0.8333, AUPRC: 0.2361
tensor([[26.,  4.],
        [ 1.,  1.]], device='cuda:0')
