In [7]:
from emrgpt.model import TimelineBasedEmrGPT
from emrgpt.trainer import TimelineDS
import torch
import dotenv
import pandas as pd

dotenv.load_dotenv('.env', override=True)

True

In [8]:
model = TimelineBasedEmrGPT(
    n_event_types=13,
    d_model=32,
    block_size=24,
    max_len=24,
    n_head=10,
    n_layer=10,
    dropout=0.2,
).to('cuda')

model.load_state_dict(torch.load('cache/savedmodels/TimelineBasedEmrGPT.pt'))
model.eval()
model

TimelineBasedEmrGPT(
  (proj): Linear(in_features=13, out_features=32, bias=True)
  (positional_encoding): FixedPositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (blocks): Sequential(
    (0): AkDecoderBlock(
      (self_attention): AkMultiHeadAttention(
        (heads): ModuleList(
          (0-9): 10 x AkSelfAttentionHead(
            (key): Linear(in_features=32, out_features=3, bias=False)
            (query): Linear(in_features=32, out_features=3, bias=False)
            (value): Linear(in_features=32, out_features=3, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=30, out_features=32, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feedforward): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=32, bias=True)
          (3): Dropout

In [9]:
ds = TimelineDS()
ds

<emrgpt.trainer.TimelineDS at 0x7770257435c0>

In [19]:
g = model.generate(
    seed=ds.normalize(
        torch.tensor([65., 121., 82., 12., 38., 99., 120., 0., 0., 0., 0., 0., 0.], device='cuda').unsqueeze(0)
    ).unsqueeze(0)
)
g_denorm = ds.denormalize(g)
df = pd.DataFrame(data=g_denorm.detach().cpu(), columns=ds.features)
df

Unnamed: 0,heart_rate,sbp,dbp,resp_rate,temperature,spo2,glucose,norepi_eq_rate,vent_hfnc,vent_suppo2,vent_noninvasive,vent_invasive,vent_trach
0,65.0,121.0,82.0,12.0,38.0,99.0,120.0,0.0,0.0,0.0,0.0,0.0,0.0
1,98.015945,120.280678,73.730453,14.756079,37.250565,100.794189,106946.273438,0.001833,-0.036907,-0.116157,0.023932,0.061125,0.048466
2,98.398155,120.181458,78.935555,13.730887,37.080048,101.638298,123824.835938,0.003296,-0.042171,-0.156667,0.033381,0.059209,0.037101
3,100.027328,122.257576,81.01107,14.630562,36.91925,103.214905,133499.453125,0.003761,-0.051988,-0.165182,0.048639,0.056198,0.032363
4,100.645485,126.165276,83.60714,15.907503,36.759117,104.583427,146454.78125,0.003333,-0.067481,-0.173801,0.055059,0.058111,0.033732
5,100.387558,129.440048,85.951973,16.81171,36.703255,104.338867,154015.421875,0.002949,-0.08373,-0.192458,0.051596,0.06679,0.035101
6,98.106346,129.621262,86.159973,17.194729,36.824699,102.943184,148125.046875,0.003467,-0.08721,-0.215878,0.04965,0.077756,0.033969
7,95.489601,126.842888,84.980652,17.102505,36.840492,101.641464,132640.84375,0.004692,-0.083406,-0.237758,0.050986,0.085694,0.030702
8,93.417038,124.768654,82.972481,17.143862,36.612068,101.642517,113584.804688,0.005667,-0.079691,-0.252201,0.055746,0.088552,0.03455
9,92.795082,125.894005,82.330223,17.347012,36.206512,103.262779,92984.25,0.005674,-0.081459,-0.264087,0.055312,0.096997,0.050526


In [14]:
df

Unnamed: 0,heart_rate,sbp,dbp,resp_rate,temperature,spo2,glucose,norepi_eq_rate,vent_hfnc,vent_suppo2,vent_noninvasive,vent_invasive,vent_trach
0,1.0,0.1,1.0,0.13,18.200001,0.01,0.12,0.0,0.0,0.0,0.0,0.0,0.0
1,67.126625,108.805267,67.396133,18.017878,38.230335,93.500847,101155.625,0.001525,-0.020251,-0.25789,0.042337,0.120837,-0.000843
2,93.082596,110.134888,73.557442,14.87123,36.523766,102.099754,140243.859375,0.00256,-0.042195,-0.296538,0.063482,0.139033,0.016286
3,94.631401,114.932678,77.781364,15.011525,36.646286,104.216454,128890.835938,0.003842,-0.058975,-0.303178,0.069137,0.131352,0.014519
4,95.722847,120.820076,80.300064,16.363865,36.507973,105.660561,138110.25,0.003536,-0.076348,-0.310995,0.075117,0.133808,0.017847
5,96.183174,125.680359,82.420845,17.039083,36.561054,105.56292,138208.609375,0.003279,-0.090881,-0.322787,0.071857,0.149437,0.02452
6,94.726097,126.226494,83.52668,17.132236,36.753139,103.971451,128054.015625,0.00399,-0.096405,-0.339873,0.068149,0.169239,0.025107
7,91.723129,124.037872,82.914688,17.013645,36.804108,102.349396,111454.046875,0.005367,-0.091475,-0.353332,0.06941,0.183404,0.019463
8,89.918213,121.698395,81.488052,17.057571,36.637569,102.405556,88028.9375,0.006584,-0.086969,-0.360173,0.072209,0.191068,0.021647
9,90.040672,122.734245,81.071159,17.411385,36.252075,104.079338,60593.703125,0.006664,-0.087818,-0.369001,0.073519,0.205213,0.038157
