In [1]:
import pandas as pd
import numpy as np
import os, yaml, wandb, pickle, optuna, gc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import TensorDataset, DataLoader, Sampler

from rl import make_transition, imvt, imvt2, CustomSampler
from make_plot import show_AUROC, plot_alpha, plot_beta, make_transition_test, make_betas, make_transition_test_for_AKI

from tqdm import tqdm
from sklearn.metrics import roc_curve, roc_auc_score

if os.getcwd()[-4:] == "code":
    os.chdir('../')
    
with open(os.path.join("./code/params.yaml")) as f:
        params = yaml.safe_load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
np.random.seed(params['random_seed'])
torch.manual_seed(params['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(params['random_seed'])
    torch.cuda.manual_seed_all(params['random_seed'])

In [3]:
train = './code/train_4hrs_mean.parquet'
val = './code/val_4hrs_mean.parquet'
target = -1

train_data = make_transition(train,"r:AKI_stage3",target,rolling_size=6)
val_data = make_transition(val,"r:AKI_stage3",target,rolling_size=6)

100%|██████████| 30730/30730 [00:48<00:00, 633.97it/s]
100%|██████████| 4390/4390 [00:04<00:00, 969.24it/s] 


In [4]:
val_transition = make_transition_test_for_AKI(val,rolling_size=6)

['m:AKI_stage1', 'm:AKI_stage2']


100%|██████████| 4390/4390 [00:05<00:00, 791.84it/s]


In [8]:
def objective(trial):
    batch_size = trial.suggest_categorical("batch_size",[32,64])
    n_units = trial.suggest_categorical("n_units",[2,4,8,16,32,64])
    
    lr = trial.suggest_categorical("learning_rate",[1e-6,5e-6,1e-5,5e-5,1e-4])
    lr_decay = trial.suggest_categorical("lr_decay",[0.75,0.8,0.85,0.9,0.95,1])
    lr_step = trial.suggest_categorical("lr_step",[2,5,10])

    ns = trial.suggest_categorical("negative_sampling",[2,4,6,8])
    
    loss = trial.suggest_categorical("loss",['smooth_l1','mse'])
    
    update_freq = trial.suggest_categorical("update_freq",[2,4,8,16,32])
    
    epochs = 50

    wandb.init(
        project='IMV_LSTM_AKI_new', name=f'trial-{trial.number}', reinit=True,
        config={
        "batch_size":batch_size,
        "n_units":n_units,
        "learning_rate":lr,
        "lr_decay":lr_decay,
        "lr_step":lr_step,
        "ns":ns,
        "loss":loss,
        "update_freq":update_freq
    })

    auroc = train(batch_size,n_units,lr,lr_decay,lr_step,ns,loss,epochs,update_freq)

    return auroc

In [10]:
def train(batch_size,n_units,lr,lr_decay,lr_step,ns,loss_type,epochs,update_freq=2):
    network = imvt2(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=n_units, device=device).to(device)
    target_network = imvt2(input_dim=params['state_dim'], output_dim=params['num_actions'], n_units=n_units, device=device).to(device)
    gamma = 1.0
    patience = 5
    best_loss = 1e6

    optimizer = optim.Adam(network.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay)

    num_workers = 4

    sampler = CustomSampler(train_data,batch_size,ns=ns,target=target)
    train_loader = DataLoader(train_data,batch_sampler=sampler,num_workers=num_workers)
    val_loader = DataLoader(val_data,batch_size=256,shuffle=False)
    

    for epoch in range(epochs):
        train_loss = 0
        update_counter = 0
        for s,a,r,s2,t in tqdm(train_loader):            
            s = s.to(device)
            a = a.to(device)
            r = r.to(device)
            s2 = s2.to(device)
            t = t.to(device)

            q,_,_ = network(s)
            q_pred = q.gather(1, a).squeeze()
            
            with torch.no_grad():
                q2,_,_ = target_network(s2)
                q2_net,_,_ = network(s2)

            q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze(1).detach()

            bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
            if loss_type == "l1":loss = F.l1_loss(q_pred, bellman_target)
            elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
            elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)

            optimizer.zero_grad()
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            update_counter += 1
            if update_counter == update_freq:
                target_network.load_state_dict(network.state_dict())
                update_counter = 0

        with torch.no_grad():
            val_loss = 0
            for s,a,r,s2,t in val_loader:
                s = s.to(device)
                a = a.to(device)
                r = r.to(device)
                s2 = s2.to(device)
                t = t.to(device)

                q,_,_ = network(s)
                q2,_,_ = target_network(s2)
                q2 = q2.detach()
                q_pred = q.gather(1, a).squeeze()

                q2_net,_,_ = network(s2)
                q2_net = q2_net.detach()
                q2_max = q2.gather(1, torch.max(q2_net,dim=1)[1].unsqueeze(1)).squeeze()

                bellman_target = torch.clamp(r, max=0.0, min=-1.0) + gamma * torch.clamp(q2_max.detach(), max=0.0, min=-1.0)*(1-t)
                if loss_type == "l1":loss = F.l1_loss(q_pred, bellman_target)
                elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
                elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
                val_loss += loss.item()

            q_value = []
            aki1 = []
            aki2 = []
            reward = []
            for s,a,r,m1,m2 in val_transition:
                s = s.to(device)
                q,_,_ = network(s.to(device))
                aki1.append(m1.detach().cpu().numpy())
                aki2.append(m2.detach().cpu().numpy())
                q_value.append(q.detach().cpu().numpy())
                reward.append(r.detach().cpu().numpy())
            aki1 = 1 - np.concatenate(aki1,axis=0)
            aki2 = 1 - np.concatenate(aki2,axis=0)
            q_value = 1 + np.concatenate(q_value,axis=0)
            reward  = 1 + np.concatenate(reward,axis=0)
            
            q_max = q_value.max(axis=1)
            q_median = np.median(q_value, axis=1)
            
            auroc      = roc_auc_score(reward,q_max)
            auroc_med  = roc_auc_score(reward,q_median)
            auroc1_max = roc_auc_score(aki1,q_max)
            auroc1_med = roc_auc_score(aki1,q_median)
            auroc2_max = roc_auc_score(aki2,q_max)
            auroc2_med = roc_auc_score(aki2,q_median)

        
        if (epoch%lr_step ==0):
            scheduler.step()
        
        if val_loss < best_loss:
            best_loss = val_loss
            counters = 0
        else :
            counters += 1

        wandb.log({"Iter:": epoch, "train:":train_loss, "val:":val_loss, "AUROC":auroc, "AUROC_median":auroc_med,"AUROC_stage_1_max":auroc1_max, "AUROC_stage_1_median":auroc1_med,"AUROC_stage_2_max":auroc2_max,"AUROC_stage_2_median":auroc2_med,"counters":counters})

        if (counters > patience)&(epoch>=20):
            break

    gc.collect()
    torch.cuda.empty_cache()
    return auroc

In [11]:
device = 'cuda:0'
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=1000)

best_params = study.best_params
best_loss = study.best_value

print("Best Hyperparameters:", best_params)
print("Best Validation Loss:", best_loss)

[I 2024-03-20 17:44:56,937] A new study created in memory with name: no-name-c7d65d17-43b0-447f-b412-0a110b64015a


0,1
AUROC,▁
AUROC_median,▁
AUROC_stage_1_max,▁
AUROC_stage_1_median,▁
AUROC_stage_2_max,▁
AUROC_stage_2_median,▁
Iter:,▁
counters,▁
train:,▁
val:,▁

0,1
AUROC,0.67563
AUROC_median,0.67608
AUROC_stage_1_max,0.63127
AUROC_stage_1_median,0.63285
AUROC_stage_2_max,0.66404
AUROC_stage_2_median,0.66548
Iter:,0.0
counters,0.0
train:,1735.88978
val:,21.72088


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:07, 91.97it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [03:59, 94.96it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:09, 91.17it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:08, 91.60it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,███████████████████▂▁
AUROC_median,█████▂▁▂▃█▄▅█████████
AUROC_stage_1_max,███████████████████▂▁
AUROC_stage_1_median,█████▂▁▃▄█▅▅█████████
AUROC_stage_2_max,███████████████████▂▁
AUROC_stage_2_median,█████▂▁▂▃█▄▅█████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇██
train:,█▇▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val:,█▆▄▂▁▁▂▂▂▂▂▂▂▂▂▁▁▁▂▂▂

0,1
AUROC,0.40212
AUROC_median,0.82857
AUROC_stage_1_max,0.4862
AUROC_stage_1_median,0.80269
AUROC_stage_2_max,0.43047
AUROC_stage_2_median,0.81642
Iter:,20.0
counters,16.0
train:,2540.16299
val:,15.612


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:52<00:00, 91.74it/s]
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:50<00:00, 92.61it/s]
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:55<00:00, 88.75it/s]
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:51<00:00, 92.25it/s]
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.m

0,1
AUROC,██▆▃▃▂▄▄▄▄▃▄▄▄▄▂▄▁▁▁▄
AUROC_median,██▆▃▃▃▄▄▄▄▃▄▄▄▄▂▄▁▁▂▄
AUROC_stage_1_max,▄▆▃▁▃▃▅▄▅▇▄▅▆▇▆▅▆▆▄▄█
AUROC_stage_1_median,▄▆▃▁▃▃▅▄▅▇▄▅▆▇▆▄▆▅▄▅█
AUROC_stage_2_max,▇█▅▁▂▂▄▄▃▄▃▃▄▄▄▂▄▂▁▂▄
AUROC_stage_2_median,██▅▂▃▂▄▄▄▄▃▄▄▅▄▂▄▁▁▂▄
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▆▆▆▇▇▇▇▇▇▇▇█████████
val:,▁▄▄▄▂▆▆▅▅▅▅▅▄▅▅▇▇▇█▆█

0,1
AUROC,0.72733
AUROC_median,0.72762
AUROC_stage_1_max,0.70576
AUROC_stage_1_median,0.70547
AUROC_stage_2_max,0.71986
AUROC_stage_2_median,0.71983
Iter:,20.0
counters,20.0
train:,399.57151
val:,16.76094


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
24502it [04:16, 95.44it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
24502it [04:12, 97.22it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
24502it [04:18, 94.91it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
24502it [04:16, 95.58it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,█▁▂▂▂▂▂▂▂▂▂▃▂▂▂▃▃▂▂▂▃
AUROC_median,█▁▁▂▂▂▂▂▂▂▂▂▃▂▂▂▃▂▂▂▂
AUROC_stage_1_max,█▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_1_median,█▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_max,█▁▁▂▂▂▂▂▂▂▂▃▂▃▂▃▃▃▃▃▃
AUROC_stage_2_median,█▁▁▂▂▂▂▂▂▂▂▂▃▂▂▃▃▂▃▂▂
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▁▁▂▁▂▂▃▃▄▄▅▅▆▆▇▇█
train:,▁███████████▇▇▇▇▇▇▇▇▇
val:,█▅▆▆▅▃▇▁▄▄▄▅▅▆▅▅▄▄▂▆▄

0,1
AUROC,0.63564
AUROC_median,0.6393
AUROC_stage_1_max,0.54933
AUROC_stage_1_median,0.55197
AUROC_stage_2_max,0.61058
AUROC_stage_2_median,0.61422
Iter:,20.0
counters,13.0
train:,4733.25492
val:,36.19268


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:36, 88.52it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:34, 89.13it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:31, 90.09it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:28, 91.40it/s]         

0,1
AUROC,██▇▆▃▁▁▁▁▂▂▂▂▂▂▁▁▂▂▁▂
AUROC_median,███▆▄▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_1_max,██▇▆▃▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
AUROC_stage_1_median,██▇▅▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
AUROC_stage_2_max,██▇▆▃▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_median,██▇▅▃▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▆███████████████████
val:,▁██████████████▇█████

0,1
AUROC,0.63411
AUROC_median,0.63543
AUROC_stage_1_max,0.53602
AUROC_stage_1_median,0.53657
AUROC_stage_2_max,0.60706
AUROC_stage_2_median,0.60754
Iter:,20.0
counters,20.0
train:,2362.89299
val:,18.53776


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:49, 96.99it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:46, 99.23it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:46, 99.38it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:47, 98.59it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = 

0,1
AUROC,█▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
AUROC_median,▇████▃▂▂▂▁████████████████
AUROC_stage_1_max,█▇▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
AUROC_stage_1_median,▆████▄▂▂▂▁████████████████
AUROC_stage_2_max,█▆▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
AUROC_stage_2_median,▇████▃▂▂▂▁████████████████
Iter:,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
counters,▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇█
train:,█▆▅▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▂
val:,▇█▇▆▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂

0,1
AUROC,0.82198
AUROC_median,0.82723
AUROC_stage_1_max,0.79575
AUROC_stage_1_median,0.80029
AUROC_stage_2_max,0.81464
AUROC_stage_2_median,0.81856
Iter:,25.0
counters,6.0
train:,625.82069
val:,12.59354


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:54<00:00, 89.85it/s]
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:49<00:00, 93.51it/s] 
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:49<00:00, 93.55it/s]
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:53<0

0,1
AUROC,▃██▄▃▂▂▂▂▄▁▂▂▂▃▃▁▂▂▂▁
AUROC_median,▃██▄▃▂▂▂▂▄▁▂▂▂▃▃▁▂▃▂▂
AUROC_stage_1_max,▁▇█▄▄▃▅▆▅▇▄▅▆▆▇▇▆▇██▇
AUROC_stage_1_median,▁▇█▅▄▃▅▅▅▆▄▅▅▆▇▇▆▇██▇
AUROC_stage_2_max,▁██▄▃▁▂▂▂▄▁▂▂▂▃▄▁▂▃▃▂
AUROC_stage_2_median,▂██▄▃▁▂▂▂▄▁▂▂▂▃▄▂▂▃▃▂
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▂▄▇▇████▇█▇▇▇▇▇▇▇▇▇▇
val:,▁▂▄▅▆▅▇▆██▇▆▇█▆▇▆▆█▇▇

0,1
AUROC,0.71334
AUROC_median,0.71358
AUROC_stage_1_max,0.68238
AUROC_stage_1_median,0.68125
AUROC_stage_2_max,0.70303
AUROC_stage_2_median,0.70277
Iter:,20.0
counters,20.0
train:,205.42261
val:,8.23274


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10983it [01:56, 94.33it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10983it [01:54, 95.62it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10983it [01:58, 93.00it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10983it [01:56, 94.18it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.

0,1
AUROC,▇██▇▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▂
AUROC_median,▁▁███████████████████
AUROC_stage_1_max,▆██▇▇▆▆▄▄▄▄▃▃▃▂▂▁▁▁▁▁
AUROC_stage_1_median,▁▁███████████████████
AUROC_stage_2_max,▇██▇▆▆▆▅▄▃▃▃▂▂▂▁▁▁▁▁▁
AUROC_stage_2_median,▁▁███████████████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▆▃▂▂▁▁▁▁▁▁▁▂▂▂▃▃▄▄▅▇█
val:,▁▂▃▂▂▂▂▃▃▃▃▃▃▃▄▄▅▅▆▇█

0,1
AUROC,0.81741
AUROC_median,0.81874
AUROC_stage_1_max,0.79106
AUROC_stage_1_median,0.79158
AUROC_stage_2_max,0.80358
AUROC_stage_2_median,0.80444
Iter:,20.0
counters,20.0
train:,1026.33052
val:,16.77335


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
21235it [03:35, 98.45it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
21235it [03:37, 97.48it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
21235it [03:37, 97.58it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
21235it [03:32, 99.77it/s]         

0,1
AUROC,███▇▇▇▇▇▆▆▆▆▆▆▅▄▂▂▁▁▁
AUROC_median,█▅█▇▇▇▇▆▆▆▆▆▆▆▅▄▃▂▁▁▁
AUROC_stage_1_max,█████▇▇▇▇▇▇▆▆▆▅▃▂▁▁▁▁
AUROC_stage_1_median,█▇███▇▇▇▇▇▆▆▆▆▅▄▂▂▁▁▁
AUROC_stage_2_max,███▇▇▇▇▇▆▆▆▆▆▆▅▄▂▁▁▁▁
AUROC_stage_2_median,█▆▇▇▇▇▇▆▆▆▆▆▆▅▅▄▂▂▁▁▁
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train:,▃▁▁▁▁▁▁▂▂▃▃▄▅▆▇██████
val:,▁▁▁▁▁▁▁▂▂▃▄▅▆▇███████

0,1
AUROC,0.55968
AUROC_median,0.56335
AUROC_stage_1_max,0.48888
AUROC_stage_1_median,0.49153
AUROC_stage_2_max,0.53054
AUROC_stage_2_median,0.53398
Iter:,20.0
counters,18.0
train:,1047.69327
val:,13.62354


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:52<00:00, 91.74it/s]
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:51<00:00, 92.52it/s] 
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:56<00:00, 88.39it/s] 
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
100%|█████████▉| 10275/10291 [01:51<

0,1
AUROC,▇█▇▅▅▅▇▆█▄▄▃▄▁▄▄▅▅▃▄▅
AUROC_median,▇█▇▄▅▅▇▆█▄▄▃▄▁▄▃▄▅▃▄▅
AUROC_stage_1_max,▁▄▃▂▄▃▅▆█▅▅▄▅▂▄▄▆▇▄▆▇
AUROC_stage_1_median,▁▄▃▂▄▄▅▆█▅▅▄▅▃▅▄▇▇▅▇▇
AUROC_stage_2_max,▆▇▆▄▅▄▆▆█▄▄▄▄▁▄▄▅▅▃▅▅
AUROC_stage_2_median,▅▇▆▃▄▄▇▅█▄▄▃▄▁▄▃▅▅▃▅▅
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▄▇████▇▇▇▇▇▇█▇▇▇▇▇▇▇
val:,▁▃▄▆▇▆▅▅▅▆▅▆▅▅▇▇▇▇██▆

0,1
AUROC,0.72249
AUROC_median,0.7221
AUROC_stage_1_max,0.69687
AUROC_stage_1_median,0.69524
AUROC_stage_2_max,0.71537
AUROC_stage_2_median,0.71493
Iter:,20.0
counters,20.0
train:,204.12713
val:,8.24343


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [08:01, 50.84it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:21, 93.57it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:26, 92.02it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
24502it [04:22, 93.31it/s]         

0,1
AUROC,█▇▇███▇▆▅▅▄▄▃▃▃▂▂▂▁▁▁
AUROC_median,█▇▇███▇▆▆▅▄▄▃▃▃▂▂▂▁▁▁
AUROC_stage_1_max,█████▇▇▆▅▅▄▄▃▃▃▂▂▂▁▁▁
AUROC_stage_1_median,█▇▇██▇▇▆▅▅▄▄▃▃▃▂▂▂▁▁▁
AUROC_stage_2_max,█▇▇██▇▆▆▅▄▄▃▃▃▂▂▂▂▁▁▁
AUROC_stage_2_median,█▇▇▇▇▇▆▆▅▄▄▄▃▃▂▂▂▂▁▁▁
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▂▁▂▆█████████████████
val:,▁▁▃██████████████████

0,1
AUROC,0.73345
AUROC_median,0.75428
AUROC_stage_1_max,0.69156
AUROC_stage_1_median,0.71443
AUROC_stage_2_max,0.70664
AUROC_stage_2_median,0.72882
Iter:,20.0
counters,20.0
train:,2376.50715
val:,18.59695


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:47, 98.71it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:47, 98.58it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:51, 94.90it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [01:47, 98.44it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = 

0,1
AUROC,████▇▇▇▇▇▇▇▇▆▅▃▂▂▁▁▁▁
AUROC_median,█▁▇█████████▇▇▆▆▅▅▅▅▅
AUROC_stage_1_max,█████▇▇▇▇█▇▇▆▄▃▂▂▁▁▁▁
AUROC_stage_1_median,█▁▇████████▇▇▆▅▄▄▄▃▃▃
AUROC_stage_2_max,████▇▇▇▇▇▇▇▆▅▄▃▂▁▁▁▁▁
AUROC_stage_2_median,█▁▇████████▇▇▆▆▅▅▄▄▄▄
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train:,▃▁▁▁▁▂▃▄▅▆▇█████████▇
val:,▁▁▁▁▂▂▃▄▅▇███████████

0,1
AUROC,0.64362
AUROC_median,0.60743
AUROC_stage_1_max,0.54523
AUROC_stage_1_median,0.5318
AUROC_stage_2_max,0.60498
AUROC_stage_2_median,0.57821
Iter:,20.0
counters,18.0
train:,1049.33424
val:,27.47546


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:00, 94.03it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 94.01it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:03, 91.95it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:00, 94.20it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.

0,1
AUROC,▁▅▆▇▇▇▇▇▇▇█▇▇▇▇▇▇▇███
AUROC_median,██████▁▂█████████████
AUROC_stage_1_max,▁▅▆▇▇██▇▇▇█▇▇▇▇██████
AUROC_stage_1_median,██████▁▂█████████████
AUROC_stage_2_max,▁▅▆▇▇▇▇▇▇▇██▇▇▇▇▇▇▇▇▇
AUROC_stage_2_median,██████▁▂█████████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▂▁▁▂▂▃▄▄▅▅▆▇▇█
train:,▄▃▃▂▂▂▁▁▁▁▂▂▂▃▃▄▄▅▆▇█
val:,▂▂▂▂▂▂▂▂▁▁▂▂▃▃▃▄▄▅▆▇█

0,1
AUROC,0.83196
AUROC_median,0.83128
AUROC_stage_1_max,0.80281
AUROC_stage_1_median,0.80235
AUROC_stage_2_max,0.81934
AUROC_stage_2_median,0.8184
Iter:,20.0
counters,11.0
train:,1376.76554
val:,19.5188


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:54, 99.71it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 93.63it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.75it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:54, 99.67it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,▁▂▃▄▅▅▆▆▆▇▇▇▇▇▇▆██████
AUROC_median,▁▁▁▁▁▁▁▁▂▅▆▇█████▄▃▃▃▃
AUROC_stage_1_max,▁▃▄▅▆▆▇▇▇██████▅██████
AUROC_stage_1_median,▁▁▁▁▁▁▁▁▂▅▆▇▇████▅▃▃▃▄
AUROC_stage_2_max,▁▂▃▄▄▅▅▅▆▆▆▆▆▆▇▆██████
AUROC_stage_2_median,▁▁▁▁▁▁▁▁▁▅▆▇█████▄▃▃▃▃
Iter:,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇█
train:,█▇▇▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁▁
val:,█▇▇▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▂▂▂▂

0,1
AUROC,0.84699
AUROC_median,0.36603
AUROC_stage_1_max,0.81015
AUROC_stage_1_median,0.42067
AUROC_stage_2_max,0.83525
AUROC_stage_2_median,0.3835
Iter:,21.0
counters,6.0
train:,1300.75508
val:,16.28523


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:59, 95.14it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:54, 99.67it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:54, 98.95it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:59, 95.16it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,██▁▁▂▂▂▂▂▂▂▂▁▁▂▂▁▁▂▂▂
AUROC_median,██▁▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_1_max,██▁▁▂▂▂▂▁▁▁▂▁▁▁▁▁▁▂▁▁
AUROC_stage_1_median,██▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂
AUROC_stage_2_max,██▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_median,██▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▃
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▄██████████▇▇▇▇▇▇▇▇▇
val:,▁████████████████████

0,1
AUROC,0.59257
AUROC_median,0.61988
AUROC_stage_1_max,0.51116
AUROC_stage_1_median,0.52658
AUROC_stage_2_max,0.57402
AUROC_stage_2_median,0.59787
Iter:,20.0
counters,20.0
train:,1718.59745
val:,31.04056


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:03, 92.19it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 93.63it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:07, 89.21it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:02, 93.00it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,████▇▆▅▃▂▁▁▁▂▂▂▂▂▃▃▃▃
AUROC_median,████▇▆▅▄▂▁▁▁▂▂▂▂▂▂▂▂▂
AUROC_stage_1_max,████▇▆▅▃▂▁▁▁▁▂▂▂▂▂▂▂▂
AUROC_stage_1_median,████▇▆▅▃▂▁▁▁▁▁▂▂▂▂▂▂▂
AUROC_stage_2_max,████▇▆▅▃▂▁▁▁▂▂▂▂▂▃▃▃▃
AUROC_stage_2_median,████▇▆▅▃▂▁▁▁▂▂▂▂▂▂▂▂▂
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▁▁▃██████████████████
val:,▁▂▆██████████████████

0,1
AUROC,0.62945
AUROC_median,0.62297
AUROC_stage_1_max,0.52743
AUROC_stage_1_median,0.524
AUROC_stage_2_max,0.59689
AUROC_stage_2_median,0.59174
Iter:,20.0
counters,20.0
train:,1719.74869
val:,30.89289


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 93.53it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 95.69it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.58it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:03, 91.78it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,██████▅▃▂▁▁▁▂▃▃▄▅▆███
AUROC_median,▂▂▂▂▂▂▂▂▁▆▃█▇▂▁▁▁▁▁▁▁
AUROC_stage_1_max,██████▆▄▂▁▁▁▃▃▄▅▆▇███
AUROC_stage_1_median,▂▂▂▂▂▂▂▂▁▇▂██▄▁▁▁▁▁▁▁
AUROC_stage_2_max,██████▅▃▂▁▁▁▂▃▄▄▅▇███
AUROC_stage_2_median,▂▂▂▂▂▂▂▂▁▆▃█▇▂▁▁▁▁▁▁▁
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▁▂▂▃▃▄▄▅▅▆▆▇▇█
train:,█▇▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁
val:,█▆▅▃▂▁▁▁▁▂▃▄▄▃▃▃▃▃▃▂▂

0,1
AUROC,0.843
AUROC_median,0.19583
AUROC_stage_1_max,0.80822
AUROC_stage_1_median,0.25011
AUROC_stage_2_max,0.83148
AUROC_stage_2_median,0.21332
Iter:,20.0
counters,13.0
train:,1277.71873
val:,15.40704


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:21, 80.62it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:03, 92.39it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:05, 90.31it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:06, 90.25it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.m

0,1
AUROC,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇█████
AUROC_median,██████████████▄▂▂▂▁▁▁
AUROC_stage_1_max,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██████
AUROC_stage_1_median,██████████████▅▃▃▂▂▁▁
AUROC_stage_2_max,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██████
AUROC_stage_2_median,██████████████▅▂▂▂▁▁▁
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▄▅▆▆▇█
train:,█▇▇▆▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁
val:,█▇▆▅▄▃▂▂▂▁▁▁▁▁▂▂▂▂▂▂▂

0,1
AUROC,0.8424
AUROC_median,0.44788
AUROC_stage_1_max,0.80929
AUROC_stage_1_median,0.52136
AUROC_stage_2_max,0.8307
AUROC_stage_2_median,0.4654
Iter:,20.0
counters,9.0
train:,1295.35789
val:,15.71725


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:06, 89.98it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:54, 99.01it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:56, 97.99it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:00, 94.04it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,▁▄▆▇████████▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁
AUROC_median,███████████████▅▂▁▁▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▅
AUROC_stage_1_max,▁▄▆▇▇████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
AUROC_stage_1_median,███████████████▆▂▂▁▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▄
AUROC_stage_2_max,▂▅▇▇████████▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
AUROC_stage_2_median,███████████████▆▂▁▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▅
Iter:,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
counters,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train:,█▇▇▆▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val:,█▇▇▆▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
AUROC,0.83684
AUROC_median,0.61116
AUROC_stage_1_max,0.80549
AUROC_stage_1_median,0.57372
AUROC_stage_2_max,0.82644
AUROC_stage_2_median,0.60542
Iter:,49.0
counters,0.0
train:,1258.36848
val:,15.52313


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 96.33it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 96.85it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 93.26it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:59, 95.57it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,█▁▁▂▂▂▂▂▂▃▂▃▂▃▃▃▃▄▃▃▃
AUROC_median,█▁▁▂▂▂▂▂▂▃▂▃▃▃▃▃▃▃▃▃▃
AUROC_stage_1_max,█▁▁▂▂▂▂▂▂▂▂▂▂▂▂▃▂▃▃▃▂
AUROC_stage_1_median,█▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_max,█▁▂▂▂▂▂▂▂▃▂▃▂▃▃▃▃▄▄▄▃
AUROC_stage_2_median,█▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▂▁▁▂▁▁▂▃▄▁▂▃▁▂▃▄▅▆▇█
train:,▁████████████████████
val:,██▅▅▅▄▃▄▃▄▂▃▃▁▄▃▂▃▂▃▃

0,1
AUROC,0.62289
AUROC_median,0.62235
AUROC_stage_1_max,0.5377
AUROC_stage_1_median,0.53649
AUROC_stage_2_max,0.60325
AUROC_stage_2_median,0.60141
Iter:,20.0
counters,7.0
train:,1715.4899
val:,31.02093


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.59it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:02, 92.75it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:00, 94.71it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.45it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,▁▁▁▂▄████▇▇▇▇██▃▃▆▇▇█
AUROC_median,▁█▁▅▂▂▁▂▂▁▁▂▂▂▁▂▃▇▇▇▇
AUROC_stage_1_max,▁▁▁▂▅█████▇████▄▄▆▇▇█
AUROC_stage_1_median,▁█▁▄▂▂▁▁▁▁▂▂▂▂▂▂▄████
AUROC_stage_2_max,▁▁▁▂▅████▇▇▇███▃▃▆▇██
AUROC_stage_2_median,▁█▁▅▂▂▁▂▂▁▁▂▂▂▁▂▄▇█▇▇
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▁▁▂▂▃▃▄▅▅▆▆▇▇█
train:,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂
val:,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂

0,1
AUROC,0.77481
AUROC_median,0.76064
AUROC_stage_1_max,0.76064
AUROC_stage_1_median,0.76198
AUROC_stage_2_max,0.77174
AUROC_stage_2_median,0.76515
Iter:,20.0
counters,12.0
train:,1388.11117
val:,19.61919


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 95.76it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:59, 94.92it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.91it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 96.19it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,█▇▂▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_median,█▇▁▁▂▂▂▂▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃
AUROC_stage_1_max,█▇▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▂▁▁▂▂▂
AUROC_stage_1_median,█▇▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_max,█▇▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_median,█▇▁▁▂▂▂▂▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃
Iter:,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▂▃▁▂▁▂▁▁▂▃▅▁▂▁▂▃▅▆▇█
train:,▁███████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val:,█▆▄▄▄▂▃▂▃▁▁▃▃▂▁▂▁▂▂▂▃▂▂

0,1
AUROC,0.6291
AUROC_median,0.63982
AUROC_stage_1_max,0.53501
AUROC_stage_1_median,0.54564
AUROC_stage_2_max,0.60533
AUROC_stage_2_median,0.61845
Iter:,22.0
counters,6.0
train:,1712.53536
val:,31.03166


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:56, 97.47it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:02, 92.80it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 96.55it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 97.21it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▄▆█▇▆▆
AUROC_median,████▃▁▂█████████████▆▄▃▃▂▃
AUROC_stage_1_max,▁▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇█▇▆▆
AUROC_stage_1_median,████▃▁▂█████████████▇▅▄▃▃▃
AUROC_stage_2_max,▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▆██▇█
AUROC_stage_2_median,████▃▁▂█████████████▆▄▃▃▃▃
Iter:,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
counters,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇█
train:,██▇▇▆▆▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
val:,█▇▇▆▆▅▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
AUROC,0.15409
AUROC_median,0.39152
AUROC_stage_1_max,0.19011
AUROC_stage_1_median,0.45011
AUROC_stage_2_max,0.16591
AUROC_stage_2_median,0.40969
Iter:,25.0
counters,6.0
train:,1343.48876
val:,16.77529


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 96.46it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 93.35it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 96.90it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:00, 94.23it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,████████████████████████▇▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁
AUROC_median,▁▁▂▃▄▄▃▃▃▃▃▃▃▅████████▇▇▃▃▃▃▃▃▃▃▃▃▅▆▄▄▄▄
AUROC_stage_1_max,████████████████████████▇▆▅▅▃▃▃▃▂▂▂▂▁▁▁▁
AUROC_stage_1_median,▁▁▂▄▄▅▄▄▄▄▄▃▃▅████████▇▇▄▃▃▃▃▃▃▃▃▃▄▆▅▅▅▅
AUROC_stage_2_max,████████████████████████▇▆▅▄▃▃▃▂▂▂▂▂▁▁▁▁
AUROC_stage_2_median,▁▁▂▃▄▄▄▃▃▃▃▃▃▅███████▇▇▇▃▃▃▃▃▃▃▃▃▃▅▆▄▄▄▄
Iter:,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
counters,▁▁▁▂▃▄▄▅▆▇▇█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train:,█▇▇▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val:,▆▅▅▆▇▇██▇▇▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
AUROC,0.24246
AUROC_median,0.42567
AUROC_stage_1_max,0.282
AUROC_stage_1_median,0.46834
AUROC_stage_2_max,0.2519
AUROC_stage_2_median,0.42618
Iter:,49.0
counters,0.0
train:,1265.56635
val:,15.62509


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 96.58it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 96.05it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.13it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 95.66it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = 

0,1
AUROC,████████████████████████████▅▂▁
AUROC_median,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▇█▃▂▁▁▁▃▅▆▇███
AUROC_stage_1_max,████████████████████████████▆▃▁
AUROC_stage_1_median,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▆█▅▂▂▁▁▂▄▅▇███
AUROC_stage_2_max,████████████████████████████▆▃▁
AUROC_stage_2_median,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▆█▃▂▂▁▁▂▅▆▇███
Iter:,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
counters,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇█
train:,██▇▇▇▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val:,██▇▇▆▆▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
AUROC,0.41938
AUROC_median,0.84747
AUROC_stage_1_max,0.50194
AUROC_stage_1_median,0.81027
AUROC_stage_2_max,0.43617
AUROC_stage_2_median,0.83553
Iter:,30.0
counters,6.0
train:,1273.02438
val:,13.95972


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:55, 98.38it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:57, 97.13it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 96.13it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:56, 97.76it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F

0,1
AUROC,▁▂▄▆▇▇███████████████
AUROC_median,█████▃▂▁▂▇███████████
AUROC_stage_1_max,▁▃▅▆▇▇███████████████
AUROC_stage_1_median,█████▃▂▁▂▇███████████
AUROC_stage_2_max,▁▂▅▆▇▇███████████████
AUROC_stage_2_median,█████▃▂▁▂▇███████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train:,█▇▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
val:,▃▂▁▁▂▃▅▇█▇▇▆▆▅▅▄▄▃▃▃▂

0,1
AUROC,0.83402
AUROC_median,0.83286
AUROC_stage_1_max,0.80286
AUROC_stage_1_median,0.80152
AUROC_stage_2_max,0.82467
AUROC_stage_2_median,0.82354
Iter:,20.0
counters,18.0
train:,1292.98982
val:,16.11269


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:03, 92.03it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:04, 91.29it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:01, 94.01it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:06, 90.07it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.

0,1
AUROC,▁█▇▂▅█▆▆▆▆▆▆▆▆▆▆▆▅▅▄▄
AUROC_median,█▁█▄▇▇▇██████████▇▇▇▇
AUROC_stage_1_max,▁██▄▆█▇▇▇▇▇▇▇▇▇▆▆▅▅▅▄
AUROC_stage_1_median,▇▁█▄▇▇▇▇██████▇▇▇▇▇▇▆
AUROC_stage_2_max,▁█▇▂▅█▆▆▆▆▇▇▇▇▆▆▅▅▅▄▄
AUROC_stage_2_median,█▁█▄▇▇▇▇███████▇▇▇▇▇▇
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
train:,█▂▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▃▃▃▃
val:,▃▂▁▁▁▂▂▃▄▅▆▇█████████

0,1
AUROC,0.78877
AUROC_median,0.78939
AUROC_stage_1_max,0.74315
AUROC_stage_1_median,0.74468
AUROC_stage_2_max,0.76312
AUROC_stage_2_median,0.76409
Iter:,20.0
counters,17.0
train:,1735.34549
val:,31.65863


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:54, 90.10it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:46, 92.78it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:56, 89.41it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:54, 90.27it/s]         

0,1
AUROC,▁▁▁▁▁▁▁▁▁▂▅▆▇▇███████
AUROC_median,███▆▁▄▆▇█▇▇██████████
AUROC_stage_1_max,▁▁▁▁▁▁▁▁▁▂▅▆▆▇███████
AUROC_stage_1_median,███▇▁▄▆▇▇▇▇▇█████████
AUROC_stage_2_max,▁▁▁▁▁▁▁▁▁▂▅▆▇▇███████
AUROC_stage_2_median,███▆▁▄▆▇▇▇▇▇█████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
train:,█▇▆▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val:,█▆▄▃▁▁▁▂▄▆▇▇▇▇▆▆▆▆▆▆▆

0,1
AUROC,0.83895
AUROC_median,0.83833
AUROC_stage_1_max,0.80248
AUROC_stage_1_median,0.8055
AUROC_stage_2_max,0.82903
AUROC_stage_2_median,0.82552
Iter:,20.0
counters,15.0
train:,2647.46933
val:,14.25334


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:06, 90.14it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:59, 95.19it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:58, 96.15it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:05, 90.29it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = 

0,1
AUROC,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▄▅▇█
AUROC_median,▇▇▇▇▇██████████▁▁▂▄███
AUROC_stage_1_max,▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▄▆▇█
AUROC_stage_1_median,▇▇▇▇▇██████▇▇██▅▁▅▆▇██
AUROC_stage_2_max,▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▄▅▇█
AUROC_stage_2_median,▇▇▇▇▇██████████▃▁▃▅███
Iter:,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇█
train:,██▇▇▆▆▅▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁
val:,█▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁

0,1
AUROC,0.33003
AUROC_median,0.84937
AUROC_stage_1_max,0.37553
AUROC_stage_1_median,0.81639
AUROC_stage_2_max,0.35157
AUROC_stage_2_median,0.83647
Iter:,21.0
counters,6.0
train:,1281.56567
val:,15.79334


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [02:07, 83.44it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [02:06, 83.70it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [02:02, 86.75it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
10617it [02:07, 83.08it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.

0,1
AUROC,██████▁▁█████████▇▇▇▇
AUROC_median,██▅▁▅▇████████▇▇▇▇▇▇▇
AUROC_stage_1_max,██████▂▁█████████████
AUROC_stage_1_median,██▆▁▄▅██████████▇▇▇▇▇
AUROC_stage_2_max,██████▁▁███████▇▇▇▇▇▇
AUROC_stage_2_median,██▅▁▅▇███████▇▇▇▇▇▇▇▇
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
train:,▂▁▁▁▁▁▁▁▂▂▂▂▃▃▄▄▅▆▆▇█
val:,▁▁▁▁▁▁▁▁▂▂▂▂▃▃▄▄▅▆▆▇█

0,1
AUROC,0.81381
AUROC_median,0.81131
AUROC_stage_1_max,0.78796
AUROC_stage_1_median,0.78598
AUROC_stage_2_max,0.80091
AUROC_stage_2_median,0.79816
Iter:,20.0
counters,20.0
train:,705.02831
val:,15.49961


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:16, 88.66it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:05, 92.57it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:15, 89.12it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
22752it [04:15, 89.09it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,███████████████████▅▁
AUROC_median,▇▁▅▂▂▂▁▁▁▁▁▅▇████████
AUROC_stage_1_max,███████████████████▅▁
AUROC_stage_1_median,▇▂▄▃▂▂▁▁▁▁▂▃▇████████
AUROC_stage_2_max,███████████████████▅▁
AUROC_stage_2_median,▇▁▅▂▂▂▁▁▁▁▂▅▇████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train:,█▆▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂
val:,▃▁▂▇█▇▇▇▇▆▆▆▆▆▆▆▆▆▇▇▇

0,1
AUROC,0.6007
AUROC_median,0.8357
AUROC_stage_1_max,0.66076
AUROC_stage_1_median,0.80545
AUROC_stage_2_max,0.6086
AUROC_stage_2_median,0.82457
Iter:,20.0
counters,19.0
train:,2531.16955
val:,15.42855


  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:53, 100.02it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:52, 100.99it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [02:02, 93.18it/s]                            
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
11376it [01:53, 100.35it/s]                           
  elif loss_type == "mse":loss = F.mse_loss(q_pred, bellman_target)
  elif loss_type == "mse":loss =

0,1
AUROC,███▄▁████████████▇▇▇▇
AUROC_median,▁▁▂█████████████▇▇▇▇▇
AUROC_stage_1_max,███▅▁███████████▇▇▇▇▆
AUROC_stage_1_median,▁▂▃███████████▇▇▇▇▇▆▆
AUROC_stage_2_max,███▄▁██████████▇▇▇▇▇▇
AUROC_stage_2_median,▁▁▂██████████▇▇▇▇▇▇▆▆
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train:,█▃▁▁▁▁▂▂▃▄▅▆▆▆▆▆▆▆▆▆▆
val:,▂▁▁▁▂▂▂▃▅▆███████████

0,1
AUROC,0.7428
AUROC_median,0.76646
AUROC_stage_1_max,0.67072
AUROC_stage_1_median,0.70738
AUROC_stage_2_max,0.70322
AUROC_stage_2_median,0.73142
Iter:,20.0
counters,19.0
train:,1733.65545
val:,31.56914


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:42, 93.99it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:35, 96.29it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:45, 93.10it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:42, 94.09it/s]         

0,1
AUROC,▁▃▅▆▇████████████████
AUROC_median,▁█████▃▁▅▇███████████
AUROC_stage_1_max,▁▄▅▆▇▇███████████████
AUROC_stage_1_median,▁█████▄▁▄▆███████████
AUROC_stage_2_max,▁▄▅▆▇████████████████
AUROC_stage_2_median,▁█████▄▁▅▆███████████
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
train:,█▇▆▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val:,▃▂▁▁▁▂▃▅█████████████

0,1
AUROC,0.80047
AUROC_median,0.82035
AUROC_stage_1_max,0.78473
AUROC_stage_1_median,0.79948
AUROC_stage_2_max,0.79006
AUROC_stage_2_median,0.81254
Iter:,20.0
counters,17.0
train:,2602.3797
val:,14.55438


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:33, 97.04it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:43, 93.58it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:46, 92.78it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:45, 92.93it/s]         

0,1
AUROC,▁▁▁▂▂▄▆▇█████████████
AUROC_median,██▂▁▁▁▁█████████████▇
AUROC_stage_1_max,▁▁▁▂▃▄▅▆▇████████████
AUROC_stage_1_median,██▂▁▁▁▁██████████████
AUROC_stage_2_max,▁▁▁▂▃▄▆▇█████████████
AUROC_stage_2_median,██▂▁▁▁▁█████████████▇
Iter:,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
counters,▁▁▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train:,█▇▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val:,▂▁▁▂▃▆█████▇▇▇▇▇▇▇▇▇▇

0,1
AUROC,0.82747
AUROC_median,0.77634
AUROC_stage_1_max,0.79935
AUROC_stage_1_median,0.78252
AUROC_stage_2_max,0.81694
AUROC_stage_2_median,0.76596
Iter:,20.0
counters,18.0
train:,2567.06602
val:,14.16981


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:43, 93.74it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:41, 94.45it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:44, 93.26it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:39, 94.92it/s]         

0,1
AUROC,█▁▂▂▃▃▃▃▃▃▂▃▂▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_median,█▁▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_1_max,█▁▂▂▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▂▁
AUROC_stage_1_median,█▁▂▃▃▃▂▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁
AUROC_stage_2_max,█▁▂▂▃▃▃▃▃▃▃▃▂▃▃▃▃▂▃▂▂▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂
AUROC_stage_2_median,█▁▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Iter:,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
counters,▁▁▁▁▂▃▃▃▄▄▅▆▆▆▇██▁▁▂▃▁▁▂▂▁▁▁▂▂▁▂▂▃▃▁▂▂▃▃
train:,▁█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val:,▇▆▄█▄▇▆▄▆▄▆▅▄▆▅▅▇▃▆▃▄▂▅▅▃▃▂▃▃▂▂▂▂▂▂▁▂▂▂▁

0,1
AUROC,0.59745
AUROC_median,0.58196
AUROC_stage_1_max,0.52443
AUROC_stage_1_median,0.51458
AUROC_stage_2_max,0.58006
AUROC_stage_2_median,0.56481
Iter:,47.0
counters,6.0
train:,3006.85825
val:,22.63474


  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:40, 94.51it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:36, 95.94it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:43, 93.73it/s]                            
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
  elif loss_type == "smooth_l1":loss = F.smooth_l1_loss(q_pred, bellman_target)
26544it [04:55, 89.79it/s]         