In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb, pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from rl import make_transition, make_transition_for_AKI, imvt

from tqdm import tqdm

if os.getcwd()[-4:] == "code":
    os.chdir('../')
    
with open(os.path.join("./code/params.yaml")) as f:
        params = yaml.safe_load(f)
    
np.random.seed(params['random_seed'])
torch.manual_seed(params['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(params['random_seed'])
    torch.cuda.manual_seed_all(params['random_seed'])

In [2]:
wandb.init(
    name="AKI_smooth_l1_non_aki_label",
    project="AIAKI_IMV_LSTM",
    config={
        "learning_rate":params["learning_rate"],
        "epoch":params['epoch'],
        "batch_size":params["minibatch_size"],
        "n_units":params["n_units"],
        "update_freq":params["update_freq"]
    }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchanreverse[0m ([33mdahs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
train_loader, train_len = make_transition_for_AKI('./code/train_non_aki.parquet',rolling_size=24,batch_size=64,istrain=True)
val_loader, val_len = make_transition_for_AKI('./code/val_non_aki.parquet',rolling_size=24,batch_size=256,istrain=False)

100%|██████████| 10664/10664 [01:11<00:00, 149.14it/s]


Finished train data transition for AKI


100%|██████████| 1523/1523 [00:03<00:00, 503.44it/s]


Finished val or test data trnsition for AKI


In [18]:
device = 'cuda:0'
network = imvt(input_dim=122, output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)
target_network = imvt(input_dim=122, output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)

In [19]:
Model="AKI_smooth_l1_non_aki_label"

epoch = params['epoch']
gamma = params['gamma']
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

update_freq = params['update_freq']
train_loss_history = []
val_loss_history = []

for i in range(epoch):
    loss_train = 0
    update_counter = 0
    for s,a,r,s2,t in tqdm(train_loader):
        s = s.to(device)
        a = a.to(device)
        r = r.to(device)
        s2 = s2.to(device)
        t = t.to(device)

        q,_,_ = network(s)
        q_pred = q.gather(1, a).squeeze()
        
        with torch.no_grad():
            q2,_,_ = target_network(s2)
            q2_net, _, _ = network(s2)

        q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze(1).detach()

        bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
        loss = F.smooth_l1_loss(q_pred, bellman_target)

        optimizer.zero_grad()
        loss.backward()
        loss_train += loss.item()
        optimizer.step()

        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    with torch.no_grad():
        loss_val = 0

        for s,a,r,s2,t in val_loader:
            s = s.to(device)
            a = a.to(device)
            r = r.to(device)
            s2 = s2.to(device)
            t = t.to(device)

            q,_,_ = network(s)
            q2,_,_ = target_network(s2)
            q2 = q2.detach()
            q_pred = q.gather(1, a).squeeze()

            q2_net,_,_ = network(s2)
            q2_net = q2_net.detach()
            q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze()

            bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
            loss = F.smooth_l1_loss(q_pred, bellman_target)
            loss_val += loss.item()
            
    wandb.log({"Iter:": i, "train:":loss_train/train_len, "val:":loss_val/val_len})
    print("Iter:", i, "train:", loss_train/train_len, "val:",loss_val/val_len)
    if not os.path.exists('checkpoint_%s'%Model):
        os.makedirs('./checkpoint_%s'%Model)
    torch.save(network.state_dict(), './checkpoint_%s/checkpoint_%i.pt'%(Model,i))
    torch.save(optimizer.state_dict(), './checkpoint_%s/optim.pth'%Model)
    train_loss_history.append(loss_train/train_len)
    val_loss_history.append(loss_val/val_len)
    
    with open('./checkpoint_%s/train_loss_%s'%(Model,Model), 'wb') as f:
        pickle.dump(train_loss_history, f)
        
    with open('./checkpoint_%s/val_loss_%s'%(Model,Model), 'wb') as f:
        pickle.dump(val_loss_history, f)

13620it [10:17, 22.07it/s]                           


Iter: 0 train: 4.266620782946649e-09 val: 2.0416814742910297e-06


13620it [10:03, 22.56it/s]                           


Iter: 1 train: 2.979897773429268e-12 val: 2.0417066236533477e-06


13620it [10:18, 22.03it/s]                           


Iter: 2 train: 1.761516651937986e-12 val: 2.0417207316528115e-06


13620it [09:49, 23.12it/s]                           


Iter: 3 train: 1.256619323075438e-12 val: 2.0417304045820696e-06


13620it [10:08, 22.39it/s]                           


Iter: 4 train: 9.951570785654165e-13 val: 2.0417418811612e-06


13620it [10:08, 22.38it/s]                           


Iter: 5 train: 8.16995204888966e-13 val: 2.0417448631454843e-06


13620it [09:49, 23.09it/s]                           


Iter: 6 train: 7.161664769114519e-13 val: 2.04174494926524e-06


13620it [10:05, 22.51it/s]                           


Iter: 7 train: 6.256860746063744e-13 val: 2.0417549410400782e-06


13620it [09:59, 22.71it/s]                           


Iter: 8 train: 5.482664140832463e-13 val: 2.0417633886267733e-06


13620it [09:59, 22.74it/s]                           


Iter: 9 train: 5.137325564004509e-13 val: 2.0417667353885707e-06


13620it [10:05, 22.51it/s]                           


Iter: 10 train: 4.561976048445517e-13 val: 2.0417676792235653e-06


13620it [11:01, 20.59it/s]                           


Iter: 11 train: 4.267835714028078e-13 val: 2.0417724312724915e-06


13620it [11:07, 20.41it/s]                           


Iter: 12 train: 3.913721350323313e-13 val: 2.041766333512129e-06


13620it [11:00, 20.61it/s]                           


Iter: 13 train: 3.6534536015885334e-13 val: 2.041773526324948e-06


13620it [09:56, 22.82it/s]                           


Iter: 14 train: 3.4974899749330653e-13 val: 2.0417728234657244e-06


13620it [09:58, 22.77it/s]                           


Iter: 15 train: 3.358242752050679e-13 val: 2.0417761538405835e-06


13620it [09:57, 22.78it/s]                           


Iter: 16 train: 3.0132844783079104e-13 val: 2.0417782001640006e-06


13620it [09:59, 22.71it/s]                           


Iter: 17 train: 2.938778784746944e-13 val: 2.0417787288469495e-06


13620it [10:03, 22.58it/s]                           


Iter: 18 train: 2.8064300741873827e-13 val: 2.04178243902077e-06


13620it [09:54, 22.89it/s]                           


Iter: 19 train: 2.6682746824410937e-13 val: 2.0417819106425083e-06


13620it [10:06, 22.45it/s]                           


Iter: 20 train: 2.445730756883043e-13 val: 2.0417861104686387e-06


13620it [10:10, 22.31it/s]                           


Iter: 21 train: 2.4358158675176734e-13 val: 2.041780654952632e-06


13620it [10:09, 22.36it/s]                           


Iter: 22 train: 2.3111041811582975e-13 val: 2.041779765074607e-06


13620it [10:25, 21.76it/s]                           


Iter: 23 train: 2.1865337431323702e-13 val: 2.041785331722216e-06


13620it [10:25, 21.78it/s]                           


Iter: 24 train: 2.0963593222734797e-13 val: 2.0417831910530787e-06


13620it [10:23, 21.85it/s]                           


Iter: 25 train: 2.0621035536591825e-13 val: 2.041785574158383e-06


13620it [10:20, 21.96it/s]                           


Iter: 26 train: 1.939515034323082e-13 val: 2.041785150122898e-06


13620it [10:22, 21.89it/s]                           


Iter: 27 train: 1.8841583850313662e-13 val: 2.041784013412314e-06


13620it [10:21, 21.92it/s]                           


Iter: 28 train: 1.814981034432475e-13 val: 2.041782980896944e-06


13620it [10:21, 21.92it/s]                           


Iter: 29 train: 1.7745211436211111e-13 val: 2.0417840463056303e-06


13620it [10:27, 21.71it/s]                           


Iter: 30 train: 1.6904001072914912e-13 val: 2.041787963827253e-06


13620it [10:26, 21.72it/s]                           


Iter: 31 train: 1.6266252212428187e-13 val: 2.0417843557473306e-06


13620it [10:31, 21.55it/s]                           


Iter: 32 train: 1.6060494137341647e-13 val: 2.0417827560233055e-06


13620it [10:28, 21.68it/s]                           


Iter: 33 train: 1.5655257618733956e-13 val: 2.041784270623297e-06


13620it [10:21, 21.93it/s]                           


Iter: 34 train: 1.5021957142873171e-13 val: 2.0417868493657337e-06


13620it [10:29, 21.63it/s]                           


Iter: 35 train: 1.4856667604160973e-13 val: 2.041781652362863e-06


13620it [10:35, 21.44it/s]                           


Iter: 36 train: 1.4224374646030653e-13 val: 2.0417835492994368e-06


13620it [10:36, 21.38it/s]                           


Iter: 37 train: 1.377893550053011e-13 val: 2.0417866265408446e-06


13620it [10:36, 21.41it/s]                           


Iter: 38 train: 1.320123869714621e-13 val: 2.0417881978409597e-06


13620it [10:25, 21.77it/s]                           


Iter: 39 train: 1.3104850345147495e-13 val: 2.0417867255409742e-06


13620it [10:35, 21.43it/s]                           


Iter: 40 train: 1.2910204354598352e-13 val: 2.0417861176713236e-06


13620it [10:44, 21.13it/s]                           


Iter: 41 train: 1.2507087104130068e-13 val: 2.0417855376523115e-06


13620it [10:53, 20.83it/s]                           


Iter: 42 train: 1.213907423362473e-13 val: 2.041785097241578e-06


13620it [10:38, 21.34it/s]                           


Iter: 43 train: 1.1683044428176523e-13 val: 2.041786436834044e-06


13620it [10:52, 20.88it/s]                           


Iter: 44 train: 1.1551993735270756e-13 val: 2.0417843039893756e-06


13620it [10:33, 21.50it/s]                           


Iter: 45 train: 1.127904621068549e-13 val: 2.04178379837965e-06


13620it [10:41, 21.22it/s]                           


Iter: 46 train: 1.1136779940647584e-13 val: 2.041785197999395e-06


13620it [10:46, 21.05it/s]                           


Iter: 47 train: 1.0419906069809648e-13 val: 2.0417824415551033e-06


13620it [10:42, 21.21it/s]                           


Iter: 48 train: 1.0355384732495799e-13 val: 2.0417842071739933e-06


13620it [10:39, 21.28it/s]                           


Iter: 49 train: 1.0481343498403367e-13 val: 2.041786975050524e-06


13620it [10:33, 21.49it/s]                           


Iter: 50 train: 1.0404870679465969e-13 val: 2.0417874880868603e-06


13620it [10:26, 21.75it/s]                           


Iter: 51 train: 9.831944440378406e-14 val: 2.041784432342359e-06


13620it [10:40, 21.26it/s]                           


Iter: 52 train: 9.63999726246088e-14 val: 2.0417838784251617e-06


13620it [10:40, 21.27it/s]                           


Iter: 53 train: 9.302824767210898e-14 val: 2.0417849671472286e-06


13620it [10:45, 21.10it/s]                           


Iter: 54 train: 9.18373353921421e-14 val: 2.041782805516858e-06


13620it [10:39, 21.28it/s]                           


Iter: 55 train: 9.155031547052914e-14 val: 2.041783712744275e-06


13620it [11:04, 20.49it/s]                           


Iter: 56 train: 8.948954001608672e-14 val: 2.0417854325184847e-06


13620it [10:51, 20.90it/s]                           


Iter: 57 train: 9.00943602195776e-14 val: 2.0417826890598317e-06


13620it [10:55, 20.79it/s]                           


Iter: 58 train: 8.696147405572828e-14 val: 2.0417862126994765e-06


13620it [10:53, 20.85it/s]                           


Iter: 59 train: 8.455900510759158e-14 val: 2.041782050278963e-06


13620it [10:48, 21.02it/s]                           


Iter: 60 train: 8.3865027217656e-14 val: 2.041784022322417e-06


13620it [10:49, 20.97it/s]                           


Iter: 61 train: 8.1778365162181e-14 val: 2.041784009258103e-06


13620it [10:42, 21.21it/s]                           


Iter: 62 train: 8.219328650191593e-14 val: 2.0417823940048937e-06


13620it [10:48, 21.00it/s]                           


Iter: 63 train: 8.05690615299426e-14 val: 2.0417827160892587e-06


13620it [10:58, 20.68it/s]                           


Iter: 64 train: 7.70511037202181e-14 val: 2.041784005359974e-06


13620it [10:45, 21.09it/s]                           


Iter: 65 train: 7.774910085943289e-14 val: 2.041782471024845e-06


13620it [10:55, 20.78it/s]                           


Iter: 66 train: 7.633871660160957e-14 val: 2.0417824310889257e-06


13620it [11:28, 19.78it/s]                           


Iter: 67 train: 7.336142568353184e-14 val: 2.041785304124556e-06


13620it [10:59, 20.65it/s]                           


Iter: 68 train: 7.166065047031328e-14 val: 2.0417829254393262e-06


13620it [10:49, 20.98it/s]                           


Iter: 69 train: 7.404646407855174e-14 val: 2.0417827227931522e-06


13620it [11:01, 20.58it/s]                           


Iter: 70 train: 7.419813431113872e-14 val: 2.041783796707691e-06


13620it [12:18, 18.44it/s]                           


Iter: 71 train: 7.061077197823379e-14 val: 2.041783859787454e-06


13620it [10:54, 20.82it/s]                           


Iter: 72 train: 6.885747834078683e-14 val: 2.0417837203980164e-06


13620it [11:22, 19.94it/s]                           


Iter: 73 train: 6.879848792370812e-14 val: 2.041782583515774e-06


13620it [11:22, 19.96it/s]                           


Iter: 74 train: 6.813760235935727e-14 val: 2.0417840064304353e-06


13620it [10:56, 20.76it/s]                           


Iter: 75 train: 6.730667330222727e-14 val: 2.041781134219149e-06


13620it [11:04, 20.51it/s]                           


Iter: 76 train: 6.502707704353191e-14 val: 2.041781280651522e-06


13620it [10:57, 20.72it/s]                           


Iter: 77 train: 6.470717591714811e-14 val: 2.0417819187516116e-06


13620it [11:01, 20.58it/s]                           


Iter: 78 train: 6.319919754509327e-14 val: 2.0417816202845344e-06


13620it [11:15, 20.18it/s]                           


Iter: 79 train: 6.36384012504988e-14 val: 2.041779660528407e-06


13620it [10:56, 20.76it/s]                           


Iter: 80 train: 6.313342344323131e-14 val: 2.041783576658333e-06


13620it [11:05, 20.47it/s]                           


Iter: 81 train: 6.119471443812905e-14 val: 2.0417819014118566e-06


13620it [11:04, 20.50it/s]                           


Iter: 82 train: 6.124626526949354e-14 val: 2.0417815992545312e-06


13620it [11:13, 20.22it/s]                           


Iter: 83 train: 5.911871538049632e-14 val: 2.041782071362322e-06


13620it [10:50, 20.93it/s]                           


Iter: 84 train: 5.871840174160105e-14 val: 2.041779169726956e-06


13620it [11:14, 20.19it/s]                           


Iter: 85 train: 6.077560026345447e-14 val: 2.041781156183308e-06


13620it [11:10, 20.33it/s]                           


Iter: 86 train: 5.851516125180116e-14 val: 2.0417821502736564e-06


13620it [10:58, 20.68it/s]                           


Iter: 87 train: 5.647172700343735e-14 val: 2.0417822250316175e-06


13620it [11:00, 20.62it/s]                           


Iter: 88 train: 5.590774428505022e-14 val: 2.0417796513165007e-06


13620it [11:20, 20.02it/s]                           


Iter: 89 train: 5.48776974880893e-14 val: 2.041780484027984e-06


13620it [11:09, 20.34it/s]                           


Iter: 90 train: 5.5440576158145324e-14 val: 2.041778961952785e-06


13620it [11:30, 19.71it/s]                           


Iter: 91 train: 5.5416248826401085e-14 val: 2.04177906650694e-06


13620it [11:10, 20.30it/s]                           


Iter: 92 train: 5.2634970566990665e-14 val: 2.0417796867420473e-06


  4%|▍         | 546/13194 [00:25<09:41, 21.76it/s]


KeyboardInterrupt: 