In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from rl import make_transition, make_transition_for_AKI, imvt

from tqdm import tqdm

if os.getcwd()[-4:] == "code":
    os.chdir('../')

In [2]:
with open(os.path.join("./code/params.yaml")) as f:
        params = yaml.safe_load(f)

In [3]:
wandb.init(
    name="IMV_LSTM_AKI_mse",
    project="AIAKI",
    config={
        "learning_rate":params["learning_rate"],
        "epoch":params['epoch'],
        "batch_size":params["minibatch_size"],
        "n_units":params["n_units"],
        "update_freq":params["update_freq"]
    }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchanreverse[0m ([33mdahs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
train_loader, train_len = make_transition_for_AKI(params['train'],rolling_size=24,batch_size=64,istrain=True)
val_loader, val_len = make_transition_for_AKI(params['val'],rolling_size=24,batch_size=256,istrain=False)

100%|██████████| 11426/11426 [00:26<00:00, 437.63it/s]


Finished train data transition for AKI


100%|██████████| 762/762 [00:00<00:00, 952.22it/s]


Finished val or test data trnsition for AKI


In [5]:
device = 'cuda:0'
network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)
target_network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)

In [7]:
network.load_state_dict(torch.load('./checkpoint_AKI_mse/checkpoint_29.pt'))
target_network.load_state_dict(torch.load('./checkpoint_AKI_mse/checkpoint_29.pt'))

<All keys matched successfully>

In [8]:
epoch = params['epoch']
gamma = params['gamma']
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

update_freq = params['update_freq']
train_loss_history = []
val_loss_history = []


for i in range(29,epoch-29,1):
    loss_train = 0
    update_counter = 0
    for s,a,r,s2,t in tqdm(train_loader):
        s = s.to(device)
        a = a.to(device)
        r = r.to(device)
        s2 = s2.to(device)
        t = t.to(device)

        q,_,_ = network(s)
        q_pred = q.gather(1, a).squeeze()
        
        with torch.no_grad():
            q2,_,_ = target_network(s2)
            q2_net, _, _ = network(s2)

        q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze(1).detach()

        bellman_target = torch.clamp(r, max=1.0, min=0.0) + gamma * torch.clamp(q2_max.detach(), max=1.0, min=0.0)*(1-t)
        loss = F.mse_loss(q_pred, bellman_target)

        optimizer.zero_grad()
        loss.backward()
        loss_train += loss.item()
        optimizer.step()

        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    with torch.no_grad():
        loss_val = 0

        for s,a,r,s2,t in val_loader:
            s = s.to(device)
            a = a.to(device)
            r = r.to(device)
            s2 = s2.to(device)
            t = t.to(device)

            q,_,_ = network(s)
            q2,_,_ = target_network(s2)
            q2 = q2.detach()
            q_pred = q.gather(1, a).squeeze()

            q2_net,_,_ = network(s2)
            q2_net = q2_net.detach()
            q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze()

            bellman_target = torch.clamp(r, max=1.0, min=0.0) + gamma * torch.clamp(q2_max.detach(), max=1.0, min=0.0)*(1-t)
            loss = F.mse_loss(q_pred, bellman_target)
            loss_val += loss.item()
            
    wandb.log({"Iter:": i, "train:":loss_train/train_len, "val:":loss_val/val_len})
    print("Iter:", i, "train:", loss_train/train_len, "val:",loss_val/val_len)
    torch.save(network.state_dict(), './checkpoint_AKI_mse/checkpoint_%i.pt'%i)
    train_loss_history.append(loss_train/train_len) 
    val_loss_history.append(loss_val/val_len)

14517it [13:35, 17.79it/s]                           


Iter: 29 train: 1.1160536702417026e-07 val: 4.3863731016518555e-10


14517it [13:47, 17.54it/s]                           


Iter: 30 train: 1.7737118274323147e-10 val: 4.242795958312263e-10


14517it [13:52, 17.43it/s]                           


Iter: 31 train: 2.4793229137555825e-10 val: 4.308570659120074e-10


14517it [13:04, 18.51it/s]                           


Iter: 32 train: 2.0663861572354779e-10 val: 3.713760773662258e-10


14517it [12:54, 18.75it/s]                           


Iter: 33 train: 4.714907156253052e-10 val: 4.5563426151712877e-10


14517it [12:50, 18.84it/s]                           


Iter: 34 train: 5.997584156769237e-10 val: 4.1722010128963487e-10


14517it [12:51, 18.82it/s]                           


Iter: 35 train: 1.8977210436804945e-10 val: 3.73784787758261e-10


14517it [12:54, 18.75it/s]                           


Iter: 36 train: 1.035595349956195e-10 val: 4.126060373740303e-10


14517it [12:57, 18.67it/s]                           


Iter: 37 train: 2.4554897358335014e-10 val: 4.374595641814145e-10


14517it [12:58, 18.64it/s]                           


Iter: 38 train: 3.1542508130594557e-09 val: 3.654719623465584e-10


14517it [13:06, 18.46it/s]                           


Iter: 39 train: 7.269902231381445e-09 val: 4.0038380182432393e-10


14517it [13:18, 18.18it/s]                           


Iter: 40 train: 9.570137750734256e-11 val: 4.076902804388937e-10


14517it [13:15, 18.25it/s]                           


Iter: 41 train: 7.398943503046975e-11 val: 3.8002773831722594e-10


14517it [13:15, 18.25it/s]                           


Iter: 42 train: 7.26553854311308e-11 val: 3.5847833356903086e-10


14517it [13:15, 18.26it/s]                           


Iter: 43 train: 3.99211880245745e-11 val: 3.649782454568251e-10


14517it [13:15, 18.26it/s]                           


Iter: 44 train: 3.555271541715885e-11 val: 3.2136370228518204e-10


14517it [13:15, 18.24it/s]                           


Iter: 45 train: 4.413920200225494e-11 val: 3.7134416240069616e-10


14517it [13:15, 18.24it/s]                           


Iter: 46 train: 4.0646697301127486e-11 val: 3.3076315513175394e-10


14517it [13:17, 18.20it/s]                           


Iter: 47 train: 6.573027584922871e-11 val: 3.192560249609453e-10


14517it [13:17, 18.21it/s]                           


Iter: 48 train: 1.9580357280025997e-11 val: 2.982681848041331e-10


14517it [13:23, 18.06it/s]                           


Iter: 49 train: 2.119989326261442e-10 val: 7.237639635057437e-10


14517it [13:19, 18.15it/s]                           


Iter: 50 train: 3.8559057387762986e-11 val: 3.2337060514959385e-10


14517it [13:12, 18.32it/s]                           


Iter: 51 train: 5.6714333510671865e-11 val: 3.708837738132988e-10


14517it [13:08, 18.40it/s]                           


Iter: 52 train: 1.162691825542637e-09 val: 2.937872670686533e-10


14517it [13:07, 18.43it/s]                           


Iter: 53 train: 4.1297733585899626e-11 val: 2.8615589100442987e-10


14517it [13:10, 18.37it/s]                           


Iter: 54 train: 2.019263347175356e-11 val: 3.222426936594436e-10


14517it [13:12, 18.33it/s]                           


Iter: 55 train: 2.2396485892462535e-11 val: 2.964770388371066e-10


14517it [13:18, 18.19it/s]                           


Iter: 56 train: 2.1831155091172434e-11 val: 3.0139822852362823e-10


14517it [13:18, 18.18it/s]                           


Iter: 57 train: 4.3136156584558125e-11 val: 3.0029477884070885e-10


14517it [13:19, 18.17it/s]                           


Iter: 58 train: 5.7692020186367385e-11 val: 3.151158325842002e-10


14517it [13:25, 18.03it/s]                           


Iter: 59 train: 4.1097017839232843e-11 val: 2.9054400899369564e-10


14517it [13:44, 17.61it/s]                           


Iter: 60 train: 4.307669672425337e-11 val: 4.547426178439314e-10


14517it [13:15, 18.25it/s]                           


Iter: 61 train: 1.5615365377455642e-10 val: 2.923607350088459e-10


14517it [13:26, 18.00it/s]                           


Iter: 62 train: 2.2044980277504718e-11 val: 2.74176517172051e-10


14517it [13:26, 18.01it/s]                           


Iter: 63 train: 1.0346208881262875e-10 val: 2.8235814787262994e-10


14517it [13:41, 17.66it/s]                           


Iter: 64 train: 1.7106582209734056e-11 val: 2.795730177914542e-10


14517it [13:39, 17.72it/s]                           


Iter: 65 train: 2.631738661919296e-11 val: 2.7712515022669894e-10


14517it [13:42, 17.64it/s]                           


Iter: 66 train: 2.705728118356659e-11 val: 2.830120480301091e-10


14517it [13:10, 18.35it/s]                           


Iter: 67 train: 9.864863958843894e-11 val: 2.538079587231046e-10


14517it [13:26, 17.99it/s]                           


Iter: 68 train: 1.4452731624083552e-11 val: 2.760856140716477e-10


14517it [13:32, 17.86it/s]                           


Iter: 69 train: 5.197800932438625e-11 val: 2.7331908889731693e-10


14517it [14:54, 16.23it/s]                           


Iter: 70 train: 1.590146695654549e-11 val: 2.5934182342140797e-10


14517it [14:59, 16.14it/s]                           


Iter: 71 train: 1.549802198915347e-10 val: 2.598378828301773e-10


14517it [13:54, 17.39it/s]                           


Iter: 72 train: 8.911318977854043e-11 val: 2.693500352446518e-10


14517it [13:45, 17.58it/s]                           


Iter: 73 train: 2.0289062308566192e-11 val: 2.7917593804560074e-10


14517it [13:35, 17.79it/s]                           


Iter: 74 train: 1.5431349673580516e-10 val: 2.751609797165271e-10


14517it [13:50, 17.48it/s]                           


Iter: 75 train: 2.1532044540791465e-11 val: 2.383917610762693e-10


14517it [14:14, 16.99it/s]                           


Iter: 76 train: 1.3650532373945531e-11 val: 2.5329530158029764e-10


14517it [14:20, 16.86it/s]                           


Iter: 77 train: 1.2795447604077358e-11 val: 2.7174501759395784e-10


14517it [14:14, 16.99it/s]                           


Iter: 78 train: 9.854934475127055e-10 val: 2.62761563127866e-10


14517it [14:16, 16.96it/s]                           


Iter: 79 train: 1.6548938735701485e-11 val: 2.4177373047860066e-10


14517it [14:13, 17.01it/s]                           


Iter: 80 train: 2.294859291422418e-10 val: 2.2356054614202758e-10


14517it [14:12, 17.02it/s]                           


Iter: 81 train: 2.3335629347931712e-11 val: 2.3625194668367337e-10


14517it [13:39, 17.72it/s]                           


Iter: 82 train: 3.6836307850256114e-11 val: 2.3570604595873145e-10


14517it [13:42, 17.64it/s]                           


Iter: 83 train: 1.2326271425122644e-11 val: 2.5127264952644817e-10


14517it [13:10, 18.36it/s]                           


Iter: 84 train: 1.1054913207164815e-11 val: 6.885046558646217e-11


14517it [13:40, 17.70it/s]                           


Iter: 85 train: 1.0279100386746581e-11 val: 2.6431895375797386e-10


14517it [14:28, 16.71it/s]                           


Iter: 86 train: 3.2557061917588974e-10 val: 2.4261434464302713e-10


14517it [14:08, 17.11it/s]                           


Iter: 87 train: 1.4689592847190965e-11 val: 2.2674565864973192e-10


14517it [14:41, 16.46it/s]                           


Iter: 88 train: 2.374975069802663e-11 val: 2.479160024770627e-10


14517it [14:56, 16.18it/s]                           


Iter: 89 train: 5.644357807960236e-11 val: 2.1725027214930584e-10


14517it [14:34, 16.59it/s]                           


Iter: 90 train: 7.978857216910281e-10 val: 2.3322894833546156e-10


14517it [15:22, 15.74it/s]                           


Iter: 91 train: 2.5551982957423326e-11 val: 2.5122855416884597e-10


14517it [13:44, 17.61it/s]                           


Iter: 92 train: 2.140447984004075e-11 val: 1.056981418799968e-09


14517it [14:34, 16.60it/s]                           


Iter: 93 train: 6.011457139357042e-11 val: 2.3402275664385174e-10


14517it [15:32, 15.57it/s]                           


Iter: 94 train: 1.585280266422162e-11 val: 2.0559274679833988e-10


14517it [14:56, 16.20it/s]                           


Iter: 95 train: 2.144495222498613e-11 val: 2.446367933629144e-10


14517it [14:59, 16.15it/s]                           


Iter: 96 train: 1.1777704545536427e-11 val: 2.256708721507355e-10


14517it [15:12, 15.92it/s]                           


Iter: 97 train: 4.024649819172362e-11 val: 2.3239636368425853e-10


14517it [15:42, 15.40it/s]                           


Iter: 98 train: 2.1532673121627827e-11 val: 2.238479480528118e-10


14517it [15:13, 15.88it/s]                           


Iter: 99 train: 3.682893176606147e-11 val: 2.3473323545586385e-10


14517it [15:37, 15.48it/s]                           


Iter: 100 train: 8.613528533648894e-10 val: 2.0622606397476106e-10


14517it [16:29, 14.66it/s]                           


Iter: 101 train: 8.1113130501642e-09 val: 2.2976684425416162e-10


14517it [16:04, 15.05it/s]                           


Iter: 102 train: 1.4794744033886478e-11 val: 2.2408334074735265e-10


14517it [16:33, 14.61it/s]                           


Iter: 103 train: 1.3839880808405693e-11 val: 2.2270999319107817e-10


14517it [17:48, 13.58it/s]                           


Iter: 104 train: 6.317590947798625e-12 val: 2.195854400913589e-10


14517it [17:18, 13.98it/s]                           


Iter: 105 train: 9.113980373512916e-12 val: 2.221211904248961e-10


14517it [16:45, 14.44it/s]                           


Iter: 106 train: 3.108671339338541e-11 val: 2.123864440393811e-10


14517it [17:03, 14.19it/s]                           


Iter: 107 train: 1.1128218423260473e-11 val: 2.1011661510011668e-10


14517it [17:50, 13.56it/s]                           


Iter: 108 train: 7.957407097604212e-12 val: 2.1026161084705333e-10


14517it [17:06, 14.14it/s]                           


Iter: 109 train: 8.63454508795493e-12 val: 2.078374040013934e-10


14517it [17:36, 13.74it/s]                           


Iter: 110 train: 1.7422041971048448e-11 val: 1.975875095566847e-10


14517it [18:20, 13.19it/s]                           


Iter: 111 train: 5.649623005190858e-12 val: 2.183679842533758e-10


14517it [17:51, 13.54it/s]                           


Iter: 112 train: 1.1604230319008413e-11 val: 2.1389719433568233e-10


14517it [18:17, 13.23it/s]                           


Iter: 113 train: 2.7998344938845185e-09 val: 2.0282447332457127e-10


14517it [17:59, 13.45it/s]                           


Iter: 114 train: 1.3667869605968726e-11 val: 2.0589062203846732e-10


14517it [18:28, 13.10it/s]                           


Iter: 115 train: 7.569216893786145e-12 val: 2.0929030886791898e-10


14517it [18:26, 13.12it/s]                           


Iter: 116 train: 1.134750447248227e-11 val: 2.1770266826801784e-10


14517it [18:24, 13.14it/s]                           


Iter: 117 train: 5.896012564318559e-12 val: 2.5899101457777344e-10


14517it [19:25, 12.45it/s]                           


Iter: 118 train: 7.506838220711994e-12 val: 1.9458986055022797e-10


14517it [19:14, 12.57it/s]                           


Iter: 119 train: 1.3877115601670677e-10 val: 8.86385965255351e-10


14517it [17:11, 14.07it/s]                           


Iter: 120 train: 1.752790977520521e-11 val: 2.18749215759492e-10


14517it [14:58, 16.15it/s]                           


Iter: 121 train: 6.090253014232005e-12 val: 1.9344024430355188e-10


14517it [15:30, 15.60it/s]                           


Iter: 122 train: 8.729871720418041e-12 val: 1.8744113571738295e-10


14517it [13:39, 17.72it/s]                           


Iter: 123 train: 8.97192722446458e-12 val: 1.8158991857187528e-10


14517it [13:59, 17.30it/s]                            


Iter: 124 train: 3.975914682217434e-12 val: 1.9457086451114878e-10


14517it [14:23, 16.80it/s]                           


Iter: 125 train: 9.143372729456687e-12 val: 2.018880120348788e-10


14517it [14:55, 16.21it/s]                           


Iter: 126 train: 1.2383094429444518e-11 val: 1.9513989776912083e-10


14517it [14:42, 16.46it/s]                           


Iter: 127 train: 5.412392476296838e-11 val: 1.8988502230948765e-10


14517it [15:11, 15.93it/s]                           


Iter: 128 train: 7.597348762426283e-12 val: 1.8169689511484532e-10


14517it [15:32, 15.56it/s]                           


Iter: 129 train: 4.6627866524276364e-12 val: 1.782019264455305e-10


14517it [14:19, 16.89it/s]                           


Iter: 130 train: 8.877445261161524e-12 val: 1.780910807211194e-10


14517it [14:04, 17.18it/s]                           


Iter: 131 train: 5.2117711825134055e-12 val: 2.332178202877479e-10


14517it [14:27, 16.73it/s]                           


Iter: 132 train: 4.4671545171551546e-12 val: 1.8780984362145342e-10


14517it [14:23, 16.81it/s]                           


Iter: 133 train: 4.506083576177904e-12 val: 1.7627983126964444e-10


14517it [14:41, 16.47it/s]                           


Iter: 134 train: 4.642708037517002e-12 val: 1.675685620401337e-10


14517it [14:26, 16.76it/s]                           


Iter: 135 train: 4.218482773970903e-12 val: 1.7973712191738014e-10


14517it [14:29, 16.70it/s]                           


Iter: 136 train: 3.1580903999172223e-12 val: 1.7593679224386962e-10


14517it [14:30, 16.68it/s]                           


Iter: 137 train: 4.659518538198902e-12 val: 1.8834553109601314e-10


14517it [15:04, 16.05it/s]                           


Iter: 138 train: 5.796865381238303e-10 val: 1.850975064685597e-10


14517it [14:47, 16.35it/s]                           


Iter: 139 train: 5.9703575517758455e-12 val: 1.709411407068669e-10


14517it [14:50, 16.30it/s]                           


Iter: 140 train: 3.567770290642893e-12 val: 1.676889964637432e-10


14517it [14:50, 16.30it/s]                           


Iter: 141 train: 4.473842789142982e-12 val: 1.6775369348907126e-10


14517it [14:50, 16.31it/s]                           


Iter: 142 train: 3.664079038739735e-12 val: 1.7387361269818958e-10


14517it [14:51, 16.28it/s]                           


Iter: 143 train: 4.552625346825841e-12 val: 1.592824519053277e-10


14517it [14:50, 16.31it/s]                           


Iter: 144 train: 3.436860321297371e-12 val: 1.62467228820205e-10


14517it [15:00, 16.13it/s]                           


Iter: 145 train: 3.8528764313500175e-10 val: 1.8227789830697813e-10


14517it [15:05, 16.03it/s]                           


Iter: 146 train: 6.205145705379635e-12 val: 1.6538704629320766e-10


14517it [14:54, 16.23it/s]                           


Iter: 147 train: 3.6314301158659785e-12 val: 1.6604679896710373e-10


14517it [15:05, 16.03it/s]                           


Iter: 148 train: 2.0138461601535465e-11 val: 1.6862090489202892e-10


14517it [15:09, 15.96it/s]                           


Iter: 149 train: 5.796529213507528e-12 val: 1.6872430036252068e-10


14517it [15:15, 15.85it/s]                           


Iter: 150 train: 2.8973223814720087e-12 val: 1.6942543586799922e-10


14517it [15:16, 15.83it/s]                           


Iter: 151 train: 3.5155366766925294e-12 val: 1.5858351476850451e-10


14517it [15:18, 15.81it/s]                           


Iter: 152 train: 3.1352607921534953e-12 val: 1.6509004020233447e-10


14517it [15:15, 15.86it/s]                           


Iter: 153 train: 2.9768456840431207e-12 val: 1.5687360571985267e-10


14517it [15:21, 15.75it/s]                           


Iter: 154 train: 3.0131888220571678e-12 val: 1.5765687992612934e-10


14517it [15:16, 15.83it/s]                           


Iter: 155 train: 3.2244772038603707e-12 val: 1.6467572269036904e-10


14517it [15:25, 15.69it/s]                           


Iter: 156 train: 3.3327140251580332e-12 val: 1.453694898702298e-10


14517it [15:19, 15.78it/s]                           


Iter: 157 train: 7.776964510882327e-12 val: 1.524537712766536e-10


14517it [15:20, 15.78it/s]                           


Iter: 158 train: 3.1347313629494154e-12 val: 1.6541221386464802e-10


14517it [15:28, 15.64it/s]                           


Iter: 159 train: 3.099280101888871e-12 val: 1.5364025881241322e-10


14517it [15:42, 15.41it/s]                           


Iter: 160 train: 6.944938888897269e-12 val: 1.5925791074825794e-10


14517it [15:49, 15.29it/s]                           


Iter: 161 train: 5.997499168315918e-12 val: 1.551936808819828e-10


14517it [15:29, 15.62it/s]                           


Iter: 162 train: 2.421392340280662e-12 val: 1.490260660903194e-10


14517it [15:53, 15.23it/s]                           


Iter: 163 train: 7.766032784368931e-12 val: 1.4861110503910445e-10


14517it [15:54, 15.22it/s]                           


Iter: 164 train: 2.6692568374788137e-12 val: 1.4717861385023395e-10


14517it [16:01, 15.09it/s]                           


Iter: 165 train: 2.9986587409043742e-12 val: 1.5695563591741804e-10


14517it [15:49, 15.29it/s]                           


Iter: 166 train: 3.632268179599293e-12 val: 1.4765125001402881e-10


14517it [15:42, 15.40it/s]                           


Iter: 167 train: 2.9031974462023954e-12 val: 1.4705409978727155e-10


14517it [16:03, 15.07it/s]                           


Iter: 168 train: 3.453276292075074e-12 val: 1.4824659673073872e-10


14517it [16:09, 14.97it/s]                           


Iter: 169 train: 3.420369657941065e-12 val: 1.5024851165222023e-10


14517it [15:51, 15.25it/s]                           


Iter: 170 train: 2.744063774857611e-12 val: 1.4149813591223124e-10
