In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb, pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from rl import make_transition, make_transition_for_AKI, imvt

from tqdm import tqdm

if os.getcwd()[-4:] == "code":
    os.chdir('../')
    
with open(os.path.join("./code/params.yaml")) as f:
        params = yaml.safe_load(f)
    
np.random.seed(params['random_seed'])
torch.manual_seed(params['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(params['random_seed'])
    torch.cuda.manual_seed_all(params['random_seed'])

In [2]:
wandb.init(
    name="AKI_smooth_l1",
    project="AIAKI_IMV_LSTM",
    config={
        "learning_rate":params["learning_rate"],
        "epoch":params['epoch'],
        "batch_size":params["minibatch_size"],
        "n_units":params["n_units"],
        "update_freq":params["update_freq"]
    }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchanreverse[0m ([33mdahs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
train_loader, train_len = make_transition_for_AKI(params['train'],rolling_size=24,batch_size=64,istrain=True)
val_loader, val_len = make_transition_for_AKI(params['val'],rolling_size=24,batch_size=256,istrain=False)

100%|██████████| 10664/10664 [00:29<00:00, 367.12it/s]


Finished train data transition for AKI


100%|██████████| 1523/1523 [00:02<00:00, 685.24it/s]


Finished val or test data trnsition for AKI


In [4]:
device = 'cuda:0'
network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)
target_network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)

In [6]:
Model="AKI_smooth_l1_2"

epoch = params['epoch']
gamma = params['gamma']
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

update_freq = params['update_freq']
train_loss_history = []
val_loss_history = []


for i in range(epoch):
    loss_train = 0
    update_counter = 0
    for s,a,r,s2,t in tqdm(train_loader):
        s = s.to(device)
        a = a.to(device)
        r = r.to(device)
        s2 = s2.to(device)
        t = t.to(device)

        q,_,_ = network(s)
        q_pred = q.gather(1, a).squeeze()
        
        with torch.no_grad():
            q2,_,_ = target_network(s2)
            q2_net, _, _ = network(s2)

        q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze(1).detach()

        bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
        loss = F.smooth_l1_loss(q_pred, bellman_target)

        optimizer.zero_grad()
        loss.backward()
        loss_train += loss.item()
        optimizer.step()

        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    with torch.no_grad():
        loss_val = 0

        for s,a,r,s2,t in val_loader:
            s = s.to(device)
            a = a.to(device)
            r = r.to(device)
            s2 = s2.to(device)
            t = t.to(device)

            q,_,_ = network(s)
            q2,_,_ = target_network(s2)
            q2 = q2.detach()
            q_pred = q.gather(1, a).squeeze()

            q2_net,_,_ = network(s2)
            q2_net = q2_net.detach()
            q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze()

            bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
            loss = F.smooth_l1_loss(q_pred, bellman_target)
            loss_val += loss.item()
            
    wandb.log({"Iter:": i, "train:":loss_train/train_len, "val:":loss_val/val_len})
    print("Iter:", i, "train:", loss_train/train_len, "val:",loss_val/val_len)
    if not os.path.exists('checkpoint_%s'%Model):
        os.makedirs('./checkpoint_%s'%Model)
    torch.save(network.state_dict(), './checkpoint_%s/checkpoint_%i.pt'%(Model,i))
    torch.save(optimizer.state_dict(), './checkpoint_%s/optim.pth'%Model)
    train_loss_history.append(loss_train/train_len)
    val_loss_history.append(loss_val/val_len)
    
    with open('./checkpoint_%s/train_loss_%s'%(Model,Model), 'wb') as f:
        pickle.dump(train_loss_history, f)
        
    with open('./checkpoint_%s/val_loss_%s'%(Model,Model), 'wb') as f:
        pickle.dump(val_loss_history, f)

13620it [13:08, 17.27it/s]                           


Iter: 0 train: 0.00014451145154212953 val: 1.4116378735318106e-05


13620it [12:41, 17.88it/s]                           


Iter: 1 train: 0.00012278268186529095 val: 3.603637928111162e-05


13620it [10:50, 20.92it/s]                           


Iter: 2 train: 7.305037415581768e-05 val: 1.9434259478572127e-05


13620it [10:44, 21.14it/s]                           


Iter: 3 train: 6.351256104438319e-05 val: 5.710188627249799e-06


13620it [10:30, 21.60it/s]                           


Iter: 4 train: 6.037671505655819e-05 val: 8.51850695675522e-06


13620it [10:29, 21.63it/s]                           


Iter: 5 train: 5.832413098027916e-05 val: 7.941986385069587e-06


13620it [10:34, 21.48it/s]                           


Iter: 6 train: 5.605487888226544e-05 val: 6.855905037270356e-06


13620it [09:58, 22.77it/s]                           


Iter: 7 train: 5.418596927700785e-05 val: 1.0081998705121853e-05


13620it [10:05, 22.48it/s]                           


Iter: 8 train: 5.316284880102993e-05 val: 7.933275423530977e-06


13620it [10:34, 21.45it/s]                            


Iter: 9 train: 5.1754139238862085e-05 val: 1.019810400762983e-05


13620it [10:05, 22.50it/s]                           


Iter: 10 train: 5.101127825557865e-05 val: 5.568829380257447e-06


13620it [10:13, 22.19it/s]                           


Iter: 11 train: 5.0408621056874916e-05 val: 1.393994727419116e-05


13620it [10:16, 22.10it/s]                           


Iter: 12 train: 4.873236451842928e-05 val: 1.1046962147627111e-05


13620it [10:26, 21.74it/s]                           


Iter: 13 train: 4.795969634055117e-05 val: 1.3792911694180087e-05


13620it [10:23, 21.83it/s]                           


Iter: 14 train: 4.7178396697914794e-05 val: 8.442986092611418e-06


13620it [10:21, 21.92it/s]                           


Iter: 15 train: 4.664634123566788e-05 val: 7.099202022204935e-06


13620it [10:30, 21.59it/s]                           


Iter: 16 train: 4.621658292426293e-05 val: 5.426447806112589e-06


13620it [10:17, 22.06it/s]                           


Iter: 17 train: 4.681617829859161e-05 val: 4.6978614095037276e-06


13620it [10:30, 21.61it/s]                           


Iter: 18 train: 4.52167480659738e-05 val: 6.022441681608693e-06


13620it [10:36, 21.38it/s]                           


Iter: 19 train: 4.4065671680224135e-05 val: 8.434271951644426e-06


13620it [10:25, 21.79it/s]                           


Iter: 20 train: 4.2889831109630074e-05 val: 8.78349818229893e-06


13620it [10:35, 21.42it/s]                           


Iter: 21 train: 4.168288213471707e-05 val: 9.348898125178258e-06


13620it [10:37, 21.36it/s]                           


Iter: 22 train: 4.013435402823011e-05 val: 1.0659770249387627e-05


13620it [10:49, 20.98it/s]                           


Iter: 23 train: 4.019659691892571e-05 val: 5.417678619504424e-06


13620it [10:44, 21.12it/s]                           


Iter: 24 train: 3.9339720391622485e-05 val: 6.013805218594199e-06


13620it [10:44, 21.13it/s]                           


Iter: 25 train: 3.9122378677041935e-05 val: 5.937289575059563e-06


13620it [10:56, 20.73it/s]                           


Iter: 26 train: 3.710934570975838e-05 val: 5.666180603628376e-06


13620it [10:46, 21.07it/s]                           


Iter: 27 train: 3.701553405258517e-05 val: 7.424325634681477e-06


13620it [10:54, 20.82it/s]                           


Iter: 28 train: 3.4858533790163476e-05 val: 8.04128317332725e-06


13620it [10:54, 20.81it/s]                           


Iter: 29 train: 3.446622383199275e-05 val: 8.602903940337018e-06


13620it [10:38, 21.32it/s]                           


Iter: 30 train: 3.3399714299592934e-05 val: 8.770118118770857e-06


13620it [11:31, 19.68it/s]                           


Iter: 31 train: 3.2597955996895355e-05 val: 9.751126016007142e-06


13620it [10:48, 21.00it/s]                           


Iter: 32 train: 3.416519439586204e-05 val: 8.731612768439065e-06


13620it [11:10, 20.32it/s]                           


Iter: 33 train: 3.111674113992765e-05 val: 9.5106602617967e-06


13620it [11:03, 20.53it/s]                           


Iter: 34 train: 3.0400214372967132e-05 val: 8.770839590787999e-06


13620it [11:04, 20.51it/s]                           


Iter: 35 train: 2.9817044840801556e-05 val: 8.706902464904992e-06


13620it [11:09, 20.33it/s]                           


Iter: 36 train: 3.0111456927627296e-05 val: 8.119769363769727e-06


13620it [11:18, 20.08it/s]                           


Iter: 37 train: 2.8797185935414735e-05 val: 9.420519497386021e-06


13620it [11:17, 20.11it/s]                           


Iter: 38 train: 2.894597389562854e-05 val: 8.579278704459832e-06


13620it [11:45, 19.30it/s]                           


Iter: 39 train: 2.791116373842835e-05 val: 8.218680103587283e-06


13620it [11:47, 19.24it/s]                           


Iter: 40 train: 2.630169611565472e-05 val: 7.994100580420146e-06


13620it [11:54, 19.05it/s]                           


Iter: 41 train: 2.566970062150656e-05 val: 8.151645415902508e-06


13620it [11:31, 19.70it/s]                           


Iter: 42 train: 2.6469832706631683e-05 val: 8.37696035896192e-06


13620it [11:28, 19.78it/s]                           


Iter: 43 train: 2.6642658110292495e-05 val: 8.572386798834333e-06


13620it [11:30, 19.72it/s]                           


Iter: 44 train: 2.66865956951717e-05 val: 8.593470216625445e-06


13620it [11:36, 19.54it/s]                           


Iter: 45 train: 2.4558787369567834e-05 val: 5.761457245749018e-06


13620it [11:59, 18.93it/s]                           


Iter: 46 train: 2.3479082935754947e-05 val: 5.8135326408321295e-06


13620it [11:53, 19.09it/s]                           


Iter: 47 train: 2.3329229070515577e-05 val: 5.623790689986415e-06


13620it [11:59, 18.93it/s]                           


Iter: 48 train: 2.261606450832512e-05 val: 5.847532466764153e-06


13620it [11:35, 19.57it/s]                           


Iter: 49 train: 2.2139903785046354e-05 val: 5.720988385286107e-06


13620it [11:47, 19.25it/s]                           


Iter: 50 train: 2.168166666420708e-05 val: 6.392069743020769e-06


13620it [12:03, 18.83it/s]                           


Iter: 51 train: 2.1030345880040978e-05 val: 5.3135108567814485e-06


13620it [11:57, 18.98it/s]                           


Iter: 52 train: 2.0312632479423972e-05 val: 5.595175646009583e-06


13620it [12:01, 18.87it/s]                           


Iter: 53 train: 1.9752893463891254e-05 val: 5.670510228863665e-06


13620it [12:01, 18.87it/s]                           


Iter: 54 train: 1.9443793936307875e-05 val: 5.649647192605796e-06


13620it [12:01, 18.88it/s]                           


Iter: 55 train: 2.0090896740227687e-05 val: 6.01403015850899e-06


13620it [12:11, 18.63it/s]                           


Iter: 56 train: 1.904477406650906e-05 val: 5.537805091242519e-06


13620it [12:19, 18.42it/s]                           


Iter: 57 train: 1.934474447241882e-05 val: 5.4507514309185985e-06


13620it [12:13, 18.56it/s]                           


Iter: 58 train: 1.8196550233094012e-05 val: 5.455454354849154e-06


13620it [12:06, 18.74it/s]                           


Iter: 59 train: 1.8232543786103634e-05 val: 5.923858531361115e-06


13620it [12:26, 18.25it/s]                           


Iter: 60 train: 1.8891068386304055e-05 val: 5.619334459256083e-06


13620it [12:16, 18.48it/s]                           


Iter: 61 train: 1.777763667463892e-05 val: 5.616188185802522e-06


13620it [12:54, 17.59it/s]                           


Iter: 62 train: 1.700548213221962e-05 val: 5.09088898437372e-06


13620it [12:24, 18.29it/s]                           


Iter: 63 train: 1.7613170563708304e-05 val: 4.980970010559702e-06


13620it [12:24, 18.29it/s]                           


Iter: 64 train: 1.7028313749332616e-05 val: 5.364886033352985e-06


13620it [12:17, 18.46it/s]                           


Iter: 65 train: 1.5829537101250615e-05 val: 5.972001638055539e-06


13620it [12:49, 17.71it/s]                           


Iter: 66 train: 1.603312627440268e-05 val: 5.376543217276169e-06


13620it [12:22, 18.35it/s]                           


Iter: 67 train: 1.693567248028609e-05 val: 5.202787034165901e-06


13620it [12:33, 18.07it/s]                           


Iter: 68 train: 1.6754006560487863e-05 val: 5.0378008779808895e-06


13620it [12:24, 18.29it/s]                           


Iter: 69 train: 1.6694725724822515e-05 val: 4.6589077511467225e-06


13620it [12:45, 17.78it/s]                            


Iter: 70 train: 1.668137917885268e-05 val: 5.592383115552696e-06


13620it [12:40, 17.92it/s]                           


Iter: 71 train: 1.7486098029689843e-05 val: 5.56861070721811e-06


13620it [12:47, 17.74it/s]                           


Iter: 72 train: 1.699588427913312e-05 val: 5.333387464342869e-06


13620it [13:03, 17.38it/s]                           


Iter: 73 train: 1.6032728802509043e-05 val: 5.79229645470004e-06


13620it [13:05, 17.35it/s]                           


Iter: 74 train: 1.603644658752678e-05 val: 5.463513568592955e-06


13620it [12:39, 17.94it/s]                           


Iter: 75 train: 1.5172884185191675e-05 val: 5.944839040283679e-06


13620it [12:39, 17.93it/s]                           


Iter: 76 train: 1.539023313135548e-05 val: 6.395360719616764e-06


13620it [12:54, 17.59it/s]                           


Iter: 77 train: 1.6268934533470297e-05 val: 6.083675889031857e-06


13620it [13:56, 16.29it/s]                           


Iter: 78 train: 1.4127043603475985e-05 val: 5.6626977812860315e-06


13620it [14:22, 15.80it/s]                            


Iter: 79 train: 1.3748484061846783e-05 val: 5.898235558590763e-06


13620it [12:48, 17.72it/s]                           


Iter: 80 train: 1.36284409810131e-05 val: 5.200419815507098e-06


13620it [12:38, 17.96it/s]                           


Iter: 81 train: 1.4191976990640738e-05 val: 5.161565178142528e-06


13620it [13:01, 17.42it/s]                           


Iter: 82 train: 1.3769743205682241e-05 val: 5.2255389554532865e-06


13620it [12:29, 18.18it/s]                           


Iter: 83 train: 1.3581887292031408e-05 val: 5.2715991971695625e-06


13620it [12:27, 18.23it/s]                           


Iter: 84 train: 1.6507954216384744e-05 val: 4.606964369337497e-06


13620it [13:30, 16.80it/s]                            


Iter: 85 train: 1.8097987773902884e-05 val: 4.6214935016204006e-06


13620it [12:58, 17.50it/s]                           


Iter: 86 train: 1.634657013558605e-05 val: 5.088196589195055e-06


13620it [12:35, 18.03it/s]                           


Iter: 87 train: 1.4515650169030619e-05 val: 4.9480256230698304e-06


13620it [12:33, 18.06it/s]                           


Iter: 88 train: 1.3611849220325072e-05 val: 4.887294726299997e-06


13620it [12:29, 18.18it/s]                           


Iter: 89 train: 1.309727546441202e-05 val: 4.865363324864811e-06


13620it [12:47, 17.75it/s]                           


Iter: 90 train: 1.2788453235544387e-05 val: 4.997077813793055e-06


13620it [12:29, 18.17it/s]                           


Iter: 91 train: 1.467991404603119e-05 val: 5.078879825116514e-06


13620it [12:38, 17.96it/s]                           


Iter: 92 train: 1.3470076958879915e-05 val: 4.76609395204987e-06


13620it [13:01, 17.43it/s]                           


Iter: 93 train: 1.3104312724753557e-05 val: 5.707288625555485e-06


13620it [12:38, 17.96it/s]                           


Iter: 94 train: 1.4093917736620887e-05 val: 5.533997769121023e-06


13620it [12:49, 17.69it/s]                           


Iter: 95 train: 1.319866438371849e-05 val: 5.424699198877827e-06


13620it [12:49, 17.70it/s]                           


Iter: 96 train: 1.24777843316385e-05 val: 5.230859308243662e-06


13620it [12:38, 17.95it/s]                            


Iter: 97 train: 1.3461642882538612e-05 val: 5.28578295707946e-06


13620it [13:02, 17.40it/s]                           


Iter: 98 train: 1.4843685312840735e-05 val: 5.177425999407403e-06


13620it [12:54, 17.58it/s]                           


Iter: 99 train: 1.2736682134150107e-05 val: 5.103064765065012e-06


13620it [12:53, 17.60it/s]                           


Iter: 100 train: 1.1813981708941536e-05 val: 5.044090977005869e-06


13620it [12:46, 17.77it/s]                           


Iter: 101 train: 1.141656201404635e-05 val: 5.27945674913958e-06


13620it [12:25, 18.28it/s]                           


Iter: 102 train: 1.1547563263455356e-05 val: 4.5703963555449945e-06


 89%|████████▉ | 11739/13209 [10:32<01:19, 18.56it/s]


KeyboardInterrupt: 