In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from rl import make_transition, make_transition_for_AKI, imvt

from tqdm import tqdm

def in_code_dir(path):
    """Return True when the last component of ``path`` is exactly ``code``.

    Comparing the whole directory name avoids the false positive of a plain
    suffix check (``path[-4:] == "code"``), which also matches directories
    such as ``barcode`` or ``mycode``.
    """
    return os.path.basename(os.path.normpath(path)) == "code"


# When the notebook is started from inside the ``code`` directory, step up one
# level so that relative paths such as ./code/params.yaml resolve.
if in_code_dir(os.getcwd()):
    os.chdir('../')

In [2]:
# Load the experiment hyper-parameters into ``params``.
# The path is relative to the current working directory. The file is read as
# UTF-8 explicitly so parsing does not depend on the platform's default
# encoding, and safe_load is used so the YAML cannot construct arbitrary objects.
with open("./code/params.yaml", encoding="utf-8") as f:
    params = yaml.safe_load(f)

In [3]:
# Hyper-parameters recorded with the run: W&B config key -> key in ``params``.
# NOTE(review): "batch_size" is taken from params["minibatch_size"], while the
# data loaders are created with a literal batch_size=64 — confirm they agree.
tracked_params = {
    "learning_rate": "learning_rate",
    "epoch": "epoch",
    "batch_size": "minibatch_size",
    "n_units": "n_units",
    "update_freq": "update_freq",
}
run_config = {
    config_key: params[params_key]
    for config_key, params_key in tracked_params.items()
}

# Start the Weights & Biases run that the training loop logs its losses to.
wandb.init(name="IMV_LSTM_Rescue_mse", project="AIAKI", config=run_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchanreverse[0m ([33mdahs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Both splits are built by make_transition with the same rolling_size.
shared_transition_kwargs = dict(rolling_size=24)

# Training split: batches of 64, built with istrain=True.
train_loader, train_len = make_transition(
    params['train'],
    batch_size=64,
    istrain=True,
    **shared_transition_kwargs,
)

# Validation split: larger batches of 256, built with istrain=False.
val_loader, val_len = make_transition(
    params['val'],
    batch_size=256,
    istrain=False,
    **shared_transition_kwargs,
)

100%|██████████| 11426/11426 [00:33<00:00, 336.62it/s]


Finished train data transition


100%|██████████| 762/762 [00:01<00:00, 560.76it/s]


Finished val or test data trnsition


In [5]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved to the device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def build_q_network():
    """Create one ``imvt`` Q-network on ``device`` with the sizes from ``params``."""
    return imvt(
        input_dim=params['state_dim'],
        output_dim=params['num_actions'],
        n_units=params['n_units'],
        device=device,
    ).to(device)


# Online network: the one whose parameters are handed to the optimizer.
network = build_q_network()

# Target network: used to evaluate next-state values and periodically
# overwritten with the online weights during training. Start it as an exact
# copy so the first bootstrap targets do not come from an unrelated random
# initialisation.
target_network = build_q_network()
target_network.load_state_dict(network.state_dict())

In [6]:
epoch = params['epoch']
gamma = params['gamma']
optimizer = optim.Adam(network.parameters(), lr=params['learning_rate'], amsgrad=True)

update_freq = params['update_freq']
train_loss_history = []
val_loss_history = []

# torch.save does not create missing directories; make sure the checkpoint
# folder exists up front instead of failing after a full epoch of training.
os.makedirs('./checkpoint', exist_ok=True)


def double_dqn_loss(online_net, target_net, s, a, r2, s2, t, gamma):
    """Return the MSE between Q(s, a) and the clamped double-DQN Bellman target.

    The online network selects the greedy next action and the target network
    evaluates it. The reward and the bootstrap value are each clamped to
    [0, 1], and the bootstrap term is multiplied by ``(1 - t)``.

    Args:
        online_net: network being trained; called as ``online_net(x)`` and
            expected to return a 3-tuple whose first item holds the Q-values.
        target_net: network with the same call convention, used only for
            evaluating the selected next action.
        s: batch of current states.
        a: action indices used with ``gather(1, a)`` — presumably shaped
            (batch, 1); TODO confirm.
        r2: reward tensor used for this objective.
        s2: batch of next states.
        t: multiplier input for ``(1 - t)`` — presumably a terminal flag;
            TODO confirm.
        gamma: discount factor.

    Returns:
        Scalar loss tensor. Gradients flow only through ``online_net(s)``.
    """
    q, _, _ = online_net(s)
    # squeeze(1) rather than squeeze(): a bare squeeze would also drop the
    # batch dimension when a batch holds a single sample.
    q_pred = q.gather(1, a).squeeze(1)

    # Nothing on the target side needs gradients.
    with torch.no_grad():
        q2, _, _ = target_net(s2)
        q2_net, _, _ = online_net(s2)
        greedy_next_action = torch.max(q2_net, dim=1)[1].unsqueeze(1)
        q2_max = q2.gather(1, greedy_next_action).squeeze(1)

    # NOTE(review): assumes r2 and t are 1-D (batch,) so the target lines up
    # element-wise with q_pred — confirm against make_transition.
    bellman_target = (
        torch.clamp(r2, max=1.0, min=0.0)
        + gamma * torch.clamp(q2_max, max=1.0, min=0.0) * (1 - t)
    )
    return F.mse_loss(q_pred, bellman_target)


for i in range(epoch):
    # Training mode; a no-op unless the model has mode-dependent layers
    # (e.g. dropout or batch norm).
    network.train()
    loss_train = 0
    # NOTE(review): the counter restarts every epoch, so the target network is
    # only refreshed if an epoch contains at least update_freq batches.
    update_counter = 0
    for s, a, r1, r2, r3, s2, t in tqdm(train_loader):
        # r1 and r3 are not used by this objective, so they are not copied
        # to the device.
        s, a, r2, s2, t = (x.to(device) for x in (s, a, r2, s2, t))

        loss = double_dqn_loss(network, target_network, s, a, r2, s2, t, gamma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()

        # Copy the online weights into the target network every update_freq steps.
        update_counter += 1
        if update_counter == update_freq:
            target_network.load_state_dict(network.state_dict())
            update_counter = 0

    # Evaluation mode for validation, and no autograd bookkeeping.
    network.eval()
    loss_val = 0
    with torch.no_grad():
        for s, a, r1, r2, r3, s2, t in val_loader:
            s, a, r2, s2, t = (x.to(device) for x in (s, a, r2, s2, t))
            loss_val += double_dqn_loss(network, target_network, s, a, r2, s2, t, gamma).item()

    # NOTE(review): each accumulated value is a per-batch mean loss; whether
    # train_len / val_len count batches or samples is not visible here — confirm
    # that the two normalised losses are comparable.
    train_loss = loss_train / train_len
    val_loss = loss_val / val_len

    wandb.log({"Iter:": i, "train:": train_loss, "val:": val_loss})
    print("Iter:", i, "train:", train_loss, "val:", val_loss)
    torch.save(network.state_dict(), './checkpoint/checkpoint_%i.pt' % i)
    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)

 55%|█████▌    | 8674/15689 [08:42<06:53, 16.95it/s]