In [1]:
from emrgpt.trainer import calculate_utility_reintubation, BLOCK_SIZE, DL_WORKERS, BATCH_SIZE
from torch.utils.data import DataLoader, random_split
from emrgpt.data import ReintubationDS, TimelineDS
from emrgpt.model import TimelineBasedEmrGPT
import torch

*** EMR GPT ***


In [2]:
# Seed torch's global RNG before anything else: random_split draws from it when
# no generator is passed, so the seed pins which items land in each partition.
torch.manual_seed(42)
ds = TimelineDS(BLOCK_SIZE)

# 90% / 10% train / validation partition of the timeline dataset.
train_ds, val_ds = random_split(ds, [0.9, 0.1])

# Translate the held-out subset's positions back into stay ids, then build the
# reintubation dataset from exactly those validation stays.
validation_stay_ids = [val_ds.dataset.stay_ids[pos] for pos in val_ds.indices]
reintubation_validation_ds = ReintubationDS(
    tlds=ds,
    stay_ids=validation_stay_ids,
)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=DL_WORKERS)
val_dl = DataLoader(
    val_ds,
    batch_size=512,
    num_workers=DL_WORKERS,
)
reintubation_dl = DataLoader(
    reintubation_validation_ds,
    batch_size=32,
    num_workers=DL_WORKERS,
)

In [3]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell no longer crashes on a machine without CUDA.
# NOTE(review): helpers imported from emrgpt.trainer may still assume CUDA
# internally — confirm before relying on the CPU path end to end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture hyperparameters for the checkpoint loaded below; they must match
# the ones used at training time or load_state_dict will reject the weights.
# NOTE(review): block_size / max_len are the literal 24 here while the dataset
# cell uses the imported BLOCK_SIZE — confirm the two agree.
model = TimelineBasedEmrGPT(
    n_event_types=13,
    d_model=32,
    block_size=24,
    max_len=24,
    n_head=10,
    n_layer=10,
    dropout=0.2,
).to(device)

# map_location remaps tensors that were saved from a CUDA device onto `device`;
# without it torch.load cannot deserialize such a checkpoint on a CPU-only host.
model.load_state_dict(
    torch.load('cache/TimelineBasedEmrGPT.pt', map_location=device)
)
model.eval()  # evaluation mode: turns the Dropout layers into no-ops
model

TimelineBasedEmrGPT(
  (proj): Linear(in_features=13, out_features=32, bias=True)
  (positional_encoding): FixedPositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (blocks): Sequential(
    (0): AkDecoderBlock(
      (self_attention): AkMultiHeadAttention(
        (heads): ModuleList(
          (0-9): 10 x AkSelfAttentionHead(
            (key): Linear(in_features=32, out_features=3, bias=False)
            (query): Linear(in_features=32, out_features=3, bias=False)
            (value): Linear(in_features=32, out_features=3, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=30, out_features=32, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feedforward): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=32, bias=True)
          (3): Dropout

In [4]:
# Score the loaded model on the reintubation validation loader and show the
# returned value as the cell's output.
reintubation_utility = calculate_utility_reintubation(model, reintubation_dl)
reintubation_utility

np.float64(0.6375719672326107)