In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from rl import make_transition, make_transition_for_AKI, imvt

from tqdm import tqdm

if os.getcwd()[-4:] == "code":
    os.chdir('../')

In [2]:
with open(os.path.join("./code/params.yaml")) as f:
        params = yaml.safe_load(f)

In [3]:
wandb.init(
    name="IMV_LSTM_Rescue_mse",
    project="AIAKI",
    config={
        "learning_rate":params["learning_rate"],
        "epoch":params['epoch'],
        "batch_size":params["minibatch_size"],
        "n_units":params["n_units"],
        "update_freq":params["update_freq"]
    }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchanreverse[0m ([33mdahs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
train_loader, train_len = make_transition(params['train'],rolling_size=24,batch_size=64,istrain=True)
val_loader, val_len = make_transition(params['val'],rolling_size=24,batch_size=256,istrain=False)

100%|██████████| 11426/11426 [00:33<00:00, 336.62it/s]


Finished train data transition


100%|██████████| 762/762 [00:01<00:00, 560.76it/s]


Finished val or test data trnsition


In [5]:
device = 'cuda:0'
network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)
target_network = imvt(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=params['n_units'], device=device).to(device)

In [6]:
epoch = params['epoch']
gamma = params['gamma']
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

update_freq = params['update_freq']
train_loss_history = []
val_loss_history = []


for i in range(epoch):
    loss_train = 0
    update_counter = 0
    for s,a,r1,r2,r3,s2,t in tqdm(train_loader):
        s = s.to(device)
        a = a.to(device)
        r1 = r1.to(device)
        r2 = r2.to(device)
        r3 = r3.to(device)
        s2 = s2.to(device)
        t = t.to(device)

        q,_,_ = network(s)
        q_pred = q.gather(1, a).squeeze()
        
        with torch.no_grad():
            q2,_,_ = target_network(s2)
            q2_net, _, _ = network(s2)

        q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze(1).detach()

        bellman_target = torch.clamp(r2, max=1.0, min=0.0) + gamma * torch.clamp(q2_max.detach(), max=1.0, min=0.0)*(1-t)
        loss = F.mse_loss(q_pred, bellman_target)

        optimizer.zero_grad()
        loss.backward()
        loss_train += loss.item()
        optimizer.step()

        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    with torch.no_grad():
        loss_val = 0

        for s,a,r1,r2,r3,s2,t in val_loader:
            s = s.to(device)
            a = a.to(device)
            r1 = r1.to(device)
            r2 = r2.to(device)
            r3 = r3.to(device)
            s2 = s2.to(device)
            t = t.to(device)

            q,_,_ = network(s)
            q2,_,_ = target_network(s2)
            q2 = q2.detach()
            q_pred = q.gather(1, a).squeeze()

            q2_net,_,_ = network(s2)
            q2_net = q2_net.detach()
            q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze()

            bellman_target = torch.clamp(r2, max=1.0, min=0.0) + gamma * torch.clamp(q2_max.detach(), max=1.0, min=0.0)*(1-t)
            loss = F.mse_loss(q_pred, bellman_target)
            loss_val += loss.item()
            
    wandb.log({"Iter:": i, "train:":loss_train/train_len, "val:":loss_val/val_len})
    print("Iter:", i, "train:", loss_train/train_len, "val:",loss_val/val_len)
    torch.save(network.state_dict(), './checkpoint/checkpoint_%i.pt'%i)
    train_loss_history.append(loss_train/train_len)
    val_loss_history.append(loss_val/val_len)

16189it [16:14, 16.62it/s]                           


Iter: 0 train: 0.0002855096365250812 val: 2.3586637110160236e-05


16189it [16:16, 16.58it/s]                           


Iter: 1 train: 0.0002450447550823554 val: 2.3443511609918735e-05


16189it [16:16, 16.58it/s]                           


Iter: 2 train: 0.00023031934322699342 val: 3.156938346178348e-05


16189it [16:47, 16.07it/s]                           


Iter: 3 train: 0.00021542570887615511 val: 3.157295938092786e-05


16189it [17:05, 15.79it/s]                           


Iter: 4 train: 0.00019718511108760594 val: 3.1756627027849306e-05


16189it [16:45, 16.10it/s]                           


Iter: 5 train: 0.00017900316519949341 val: 2.9535606163798576e-05


16189it [16:22, 16.47it/s]                           


Iter: 6 train: 0.00015631833202125483 val: 4.233422124801393e-05


16189it [16:16, 16.59it/s]                           


Iter: 7 train: 0.0001454985340846843 val: 4.204546731151999e-05


16189it [16:22, 16.47it/s]                           


Iter: 8 train: 0.00012991264283027822 val: 2.4053139626870353e-05


16189it [16:17, 16.57it/s]                           


Iter: 9 train: 0.00012103094734706586 val: 1.878235209099789e-05


16189it [16:15, 16.60it/s]                           


Iter: 10 train: 0.0001153549600499465 val: 2.0656799828932347e-05


16189it [16:16, 16.58it/s]                           


Iter: 11 train: 9.863189863035571e-05 val: 2.040922857139964e-05


16189it [16:16, 16.58it/s]                           


Iter: 12 train: 9.25493563923447e-05 val: 1.3716144554997679e-05


16189it [16:16, 16.58it/s]                           


Iter: 13 train: 8.909602576777866e-05 val: 1.9685429792006843e-05


16189it [16:17, 16.56it/s]                           


Iter: 14 train: 7.589250231893781e-05 val: 1.2541092044917514e-05


16189it [16:16, 16.58it/s]                           


Iter: 15 train: 5.5844580065363014e-05 val: 1.2381740364901255e-05


16189it [16:17, 16.56it/s]                           


Iter: 16 train: 4.704106477714931e-05 val: 1.018204481228847e-05


16189it [16:16, 16.58it/s]                           


Iter: 17 train: 4.3505255218378736e-05 val: 1.257931061838537e-05


16189it [16:17, 16.57it/s]                           


Iter: 18 train: 3.847381597059141e-05 val: 8.921383921110233e-06


16189it [16:17, 16.57it/s]                           


Iter: 19 train: 3.150969069006574e-05 val: 9.224313958766127e-06


16189it [16:17, 16.56it/s]                           


Iter: 20 train: 3.0139679303697807e-05 val: 8.410668414128145e-06


16189it [16:15, 16.59it/s]                           


Iter: 21 train: 3.0144781536352143e-05 val: 6.90413046788368e-06


16189it [16:31, 16.33it/s]                           


Iter: 22 train: 2.6080654465281622e-05 val: 6.649876937466622e-06


16189it [16:21, 16.49it/s]                           


Iter: 23 train: 1.8686006115481426e-05 val: 4.87180747412227e-06


16189it [16:12, 16.64it/s]                           


Iter: 24 train: 1.889712425197381e-05 val: 6.43288942949912e-06


16189it [16:15, 16.60it/s]                           


Iter: 25 train: 1.5908691747716868e-05 val: 5.1981832544189095e-06


16189it [16:15, 16.59it/s]                           


Iter: 26 train: 1.6968805731544897e-05 val: 4.827376030957835e-06


16189it [16:15, 16.59it/s]                           


Iter: 27 train: 1.4480156878158434e-05 val: 4.359338395129572e-06


16189it [16:15, 16.59it/s]                           


Iter: 28 train: 1.2679012485989048e-05 val: 5.0447039803888565e-06


16189it [16:16, 16.58it/s]                           


Iter: 29 train: 1.1045041309332738e-05 val: 4.559793034870031e-06


16189it [16:16, 16.58it/s]                           


Iter: 30 train: 1.0795111519950996e-05 val: 5.093316440064272e-06


16189it [16:15, 16.60it/s]                           


Iter: 31 train: 1.1629269325175453e-05 val: 4.256537501672073e-06


16189it [16:16, 16.58it/s]                           


Iter: 32 train: 1.1427643176588382e-05 val: 4.523588930215289e-06


16189it [16:14, 16.62it/s]                           


Iter: 33 train: 1.0019675044918094e-05 val: 3.669816878491046e-06


16189it [16:17, 16.57it/s]                           


Iter: 34 train: 2.0637086024324754e-05 val: 5.456018986245096e-06


16189it [16:15, 16.60it/s]                           


Iter: 35 train: 1.577924720174173e-05 val: 4.383297652712805e-06


16189it [16:15, 16.59it/s]                           


Iter: 36 train: 2.148751236121293e-05 val: 4.93773254956467e-06


16189it [16:17, 16.57it/s]                           


Iter: 37 train: 1.322610389458776e-05 val: 5.930036246366626e-06


16189it [16:15, 16.59it/s]                           


Iter: 38 train: 1.0435206924414402e-05 val: 4.67560468014226e-06


16189it [16:16, 16.58it/s]                           


Iter: 39 train: 9.741238992042008e-06 val: 4.215332174845692e-06


16189it [16:25, 16.43it/s]                           


Iter: 40 train: 1.242479387691529e-05 val: 3.9054030556969e-06


16189it [16:27, 16.40it/s]                           


Iter: 41 train: 9.119531886126834e-06 val: 3.974975145671456e-06


16189it [16:23, 16.46it/s]                           


Iter: 42 train: 8.107623592506817e-06 val: 3.523247391145073e-06


16189it [16:17, 16.56it/s]                           


Iter: 43 train: 7.712112194444426e-06 val: 3.6976381639908808e-06


16189it [16:14, 16.62it/s]                           


Iter: 44 train: 8.66023785637861e-06 val: 4.164480954643635e-06


16189it [16:15, 16.59it/s]                           


Iter: 45 train: 1.109849701992842e-05 val: 4.209014429456512e-06


16189it [16:16, 16.57it/s]                           


Iter: 46 train: 1.0132523255837562e-05 val: 5.93341124272853e-06


16189it [16:23, 16.46it/s]                           


Iter: 47 train: 9.225977081522806e-06 val: 3.948689004574967e-06


16189it [16:16, 16.58it/s]                           


Iter: 48 train: 6.397679087064962e-06 val: 3.565931751891975e-06


16189it [16:23, 16.46it/s]                           


Iter: 49 train: 6.230484117520121e-06 val: 3.378236255586305e-06


16189it [16:17, 16.57it/s]                           


Iter: 50 train: 6.522333339547041e-06 val: 3.441349496608192e-06


16189it [16:10, 16.68it/s]                           


Iter: 51 train: 5.818723320249188e-06 val: 3.806451586540024e-06


16189it [16:12, 16.66it/s]                           


Iter: 52 train: 5.778382762225275e-06 val: 3.1893916194968463e-06


16189it [16:11, 16.67it/s]                           


Iter: 53 train: 5.0223824113529885e-06 val: 2.997443128988322e-06


16189it [16:16, 16.58it/s]                           


Iter: 54 train: 4.784772074803851e-06 val: 3.141050769566198e-06


16189it [16:15, 16.60it/s]                           


Iter: 55 train: 4.762040476838123e-06 val: 3.1553882950945283e-06


16189it [16:20, 16.50it/s]                           


Iter: 56 train: 4.889951752133121e-06 val: 2.990501210172107e-06


16189it [16:33, 16.29it/s]                           


Iter: 57 train: 5.003694411956608e-06 val: 3.1662756871551486e-06


16189it [16:34, 16.27it/s]                           


Iter: 58 train: 5.081360627965277e-06 val: 2.8771200188023985e-06


16189it [16:31, 16.32it/s]                           


Iter: 59 train: 4.4598202595661365e-06 val: 2.9484998287063734e-06


16189it [16:27, 16.40it/s]                           


Iter: 60 train: 5.959525111676271e-06 val: 3.1776123236628086e-06


16189it [16:35, 16.27it/s]                           


Iter: 61 train: 4.624200081223789e-06 val: 3.0086986825061227e-06


16189it [16:37, 16.22it/s]                           


Iter: 62 train: 5.498979080218896e-06 val: 3.030605774700926e-06


16189it [16:44, 16.11it/s]                           


Iter: 63 train: 4.869237497653558e-06 val: 2.9247881220526086e-06


16189it [16:31, 16.33it/s]                           


Iter: 64 train: 4.869674498749262e-06 val: 3.1305975990426236e-06


16189it [16:28, 16.37it/s]                           


Iter: 65 train: 5.731863541017525e-06 val: 3.3710074113271448e-06


16189it [16:22, 16.48it/s]                           


Iter: 66 train: 4.895253072337112e-06 val: 3.119002898653191e-06


16189it [16:23, 16.46it/s]                           


Iter: 67 train: 1.376251825024985e-05 val: 3.351525595980499e-06


16189it [16:15, 16.59it/s]                           


Iter: 68 train: 8.724992723067568e-06 val: 3.4973904401400245e-06


16189it [16:19, 16.53it/s]                           


Iter: 69 train: 6.029581394893593e-06 val: 2.7859318622692587e-06


16189it [16:30, 16.35it/s]                           


Iter: 70 train: 4.117541572340794e-06 val: 2.7116449355390463e-06


16189it [16:31, 16.33it/s]                           


Iter: 71 train: 3.4296794343851002e-06 val: 2.639699706848055e-06


16189it [16:24, 16.44it/s]                           


Iter: 72 train: 2.737358358880181e-06 val: 2.48042442876402e-06


16189it [16:26, 16.42it/s]                           


Iter: 73 train: 2.5813994658787662e-06 val: 2.438931280518047e-06


16189it [16:35, 16.26it/s]                           


Iter: 74 train: 2.7724569977999463e-06 val: 2.4493274793686095e-06


16189it [16:37, 16.22it/s]                           


Iter: 75 train: 2.5356473539211395e-06 val: 2.437861434036706e-06


16189it [16:40, 16.18it/s]                           


Iter: 76 train: 2.408880695222505e-06 val: 2.3745797897532913e-06


16189it [16:53, 15.97it/s]                           


Iter: 77 train: 2.3180989513090675e-06 val: 2.350308944826346e-06


16189it [16:20, 16.50it/s]                           


Iter: 78 train: 2.2292154252105048e-06 val: 2.3770498199359073e-06


16189it [16:19, 16.54it/s]                           


Iter: 79 train: 2.2751066286928293e-06 val: 2.3538749599230943e-06


16189it [16:13, 16.62it/s]                           


Iter: 80 train: 3.4906308683549257e-06 val: 2.7863708562717365e-06


16189it [16:07, 16.73it/s]                           


Iter: 81 train: 2.4319226376360136e-06 val: 2.2762341281114713e-06


16189it [16:04, 16.78it/s]                           


Iter: 82 train: 2.0868770210367886e-06 val: 2.2819004669101546e-06


16189it [16:05, 16.76it/s]                           


Iter: 83 train: 2.209448965627771e-06 val: 2.293571055442748e-06


16189it [16:11, 16.66it/s]                           


Iter: 84 train: 2.086417601543414e-06 val: 2.3308136550564038e-06


16189it [16:12, 16.65it/s]                           
