Import the required libraries

In [2]:
from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DiscreteCQL, DQN
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.dataset import Episode
from d3rlpy.dataset import MDPDataset

from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split

import import_ipynb
import numpy as np
from random import random
from create_dataset import CreateDataset
from FootballEnv import FootballEnv

Helper function that loads the recorded data from `data.json` and wraps it in an `MDPDataset`

In [3]:
def create_dataset(path='data.json', episode_length=5):
    """Build the offline-RL dataset from a recorded data file.

    Loads ``path`` through ``CreateDataset`` and wraps the resulting
    observations, actions and rewards in an ``MDPDataset``.

    Args:
        path: File handed to ``CreateDataset.loadFile``.
        episode_length: Number of consecutive steps grouped into one
            episode; the last step of every group is flagged as terminal.

    Returns:
        An ``MDPDataset`` carrying exactly one terminal flag per step.

    Raises:
        ValueError: If ``episode_length`` is smaller than 1.
    """
    if episode_length < 1:
        raise ValueError('episode_length must be at least 1')

    dataset_maker = CreateDataset()
    dataset_maker.loadFile(path)

    observations, actions, rewards = dataset_maker.createEpisodeDataset()

    # One flag per step, with every `episode_length`-th step closing an episode.
    # NOTE(review): the earlier version built an (N, 5) array of flags, which
    # appears to have worked only because d3rlpy flattens `terminals` and
    # ignores the surplus entries. The effective pattern is unchanged here;
    # confirm that 5-step episodes match how the data was recorded.
    terminals = np.zeros(len(observations), dtype=np.float32)
    terminals[episode_length - 1::episode_length] = 1.0

    return MDPDataset(
        observations,
        actions,
        rewards,
        terminals,
    )

In [4]:
# Build the dataset once, then hold out 20% of the episodes for evaluation.
# A fixed random_state keeps the split (and therefore the reported metrics)
# identical across kernel restarts.
dataset = create_dataset()
train_episodes, test_episodes = train_test_split(
    dataset,
    test_size=0.2,
    random_state=0,
)


In [6]:
# Discrete CQL learner, kept on the CPU.
cql = DiscreteCQL(use_gpu=False)

# Environment used by the 'environment' scorer below.
env = FootballEnv()

# Train on the training episodes and score against the evaluation episodes.
output = cql.fit(
    train_episodes,
    eval_episodes=test_episodes,
    n_epochs=50,
    scorers=dict(
        environment=evaluate_on_environment(env),  # evaluate with Football Env
        advantage=discounted_sum_of_advantage_scorer,  # smaller is better
        td_error=td_error_scorer,  # smaller is better
        value_scale=average_value_estimation_scorer,  # smaller is better
    ),
)

2022-02-09 18:27.17 [debug    ] RoundIterator is selected.
2022-02-09 18:27.17 [info     ] Directory is created at d3rlpy_logs\DiscreteCQL_20220209182717
2022-02-09 18:27.17 [debug    ] Building models...
2022-02-09 18:27.17 [debug    ] Models have been built.
2022-02-09 18:27.17 [info     ] Parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'bootstrap': False, 'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_reduction_type': 'min', 'target_update_interval': 8000, 'use_

Epoch 1/50: 100%|██████████| 39/39 [00:00<00:00, 88.84it/s, loss=0.837]


2022-02-09 18:27.18 [info     ] DiscreteCQL_20220209182717: epoch=1 step=39 epoch=1 metrics={'time_sample_batch': 0.00012815304291553987, 'time_algorithm_update': 0.01071775876558744, 'loss': 0.8102536369592716, 'time_step': 0.011076682653182592, 'environment': 2.450121178648094, 'advantage': 0.0, 'td_error': 0.04359680364814267, 'value_scale': 0.4186018109321594} step=39
2022-02-09 18:27.18 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_39.pt


Epoch 2/50: 100%|██████████| 39/39 [00:00<00:00, 99.49it/s, loss=0.639] 


2022-02-09 18:27.19 [info     ] DiscreteCQL_20220209182717: epoch=2 step=78 epoch=2 metrics={'time_sample_batch': 0.00023074639149201222, 'time_algorithm_update': 0.009487231572469076, 'loss': 0.6268272124803983, 'time_step': 0.00992291401594113, 'environment': 2.391892965375231, 'advantage': 0.0, 'td_error': 0.04835817293883338, 'value_scale': 0.4408651888370514} step=78
2022-02-09 18:27.19 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_78.pt


Epoch 3/50: 100%|██████████| 39/39 [00:00<00:00, 108.94it/s, loss=0.523]


2022-02-09 18:27.20 [info     ] DiscreteCQL_20220209182717: epoch=3 step=117 epoch=3 metrics={'time_sample_batch': 0.00020503997802734375, 'time_algorithm_update': 0.008589530602479592, 'loss': 0.5118690133094788, 'time_step': 0.008948485056559244, 'environment': 2.3394252512715616, 'advantage': 0.0, 'td_error': 0.03644832080067317, 'value_scale': 0.38273920118808746} step=117
2022-02-09 18:27.20 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_117.pt


Epoch 4/50: 100%|██████████| 39/39 [00:00<00:00, 113.04it/s, loss=0.414]


2022-02-09 18:27.21 [info     ] DiscreteCQL_20220209182717: epoch=4 step=156 epoch=4 metrics={'time_sample_batch': 7.67218760955028e-05, 'time_algorithm_update': 0.008410875613872822, 'loss': 0.40373570873187137, 'time_step': 0.008717928177271133, 'environment': 2.6109199025329124, 'advantage': 0.0, 'td_error': 0.02451996064019113, 'value_scale': 0.3138415813446045} step=156
2022-02-09 18:27.21 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_156.pt


Epoch 5/50: 100%|██████████| 39/39 [00:00<00:00, 129.14it/s, loss=0.314]


2022-02-09 18:27.21 [info     ] DiscreteCQL_20220209182717: epoch=5 step=195 epoch=5 metrics={'time_sample_batch': 0.0001280674567589393, 'time_algorithm_update': 0.007333205296443059, 'loss': 0.3050049061958606, 'time_step': 0.007666459450354943, 'environment': 2.713118402102263, 'advantage': 0.0, 'td_error': 0.015200071680851579, 'value_scale': 0.24686583876609802} step=195
2022-02-09 18:27.21 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_195.pt


Epoch 6/50: 100%|██████████| 39/39 [00:00<00:00, 112.07it/s, loss=0.229]


2022-02-09 18:27.22 [info     ] DiscreteCQL_20220209182717: epoch=6 step=234 epoch=6 metrics={'time_sample_batch': 0.00020507054451184394, 'time_algorithm_update': 0.008512753706711989, 'loss': 0.2220483047839923, 'time_step': 0.008820417599800305, 'environment': 2.2105634950407516, 'advantage': 0.0, 'td_error': 0.009180523265001739, 'value_scale': 0.19140379130840302} step=234
2022-02-09 18:27.22 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_234.pt


Epoch 7/50: 100%|██████████| 39/39 [00:00<00:00, 103.72it/s, loss=0.163]


2022-02-09 18:27.23 [info     ] DiscreteCQL_20220209182717: epoch=7 step=273 epoch=7 metrics={'time_sample_batch': 0.00023015951498960837, 'time_algorithm_update': 0.009103866723867564, 'loss': 0.15770331368996546, 'time_step': 0.009512925759339944, 'environment': 2.7300564579403495, 'advantage': 0.0, 'td_error': 0.0049669050712282115, 'value_scale': 0.13988003879785538} step=273
2022-02-09 18:27.23 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_273.pt


Epoch 8/50: 100%|██████████| 39/39 [00:00<00:00, 83.16it/s, loss=0.115]


2022-02-09 18:27.24 [info     ] DiscreteCQL_20220209182717: epoch=8 step=312 epoch=8 metrics={'time_sample_batch': 0.000205076657808744, 'time_algorithm_update': 0.011358890777979141, 'loss': 0.11174011211364697, 'time_step': 0.011743288773756761, 'environment': 2.6667805816924752, 'advantage': 0.0, 'td_error': 0.002865786350923827, 'value_scale': 0.10495540872216225} step=312
2022-02-09 18:27.24 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_312.pt


Epoch 9/50: 100%|██████████| 39/39 [00:00<00:00, 132.66it/s, loss=0.0831]


2022-02-09 18:27.25 [info     ] DiscreteCQL_20220209182717: epoch=9 step=351 epoch=9 metrics={'time_sample_batch': 0.0002297193576128055, 'time_algorithm_update': 0.007077571673270984, 'loss': 0.08062018606907283, 'time_step': 0.00743568860567533, 'environment': 2.354235654260605, 'advantage': 0.0, 'td_error': 0.0017690378707015952, 'value_scale': 0.08078968524932861} step=351
2022-02-09 18:27.25 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_351.pt


Epoch 10/50: 100%|██████████| 39/39 [00:00<00:00, 126.21it/s, loss=0.0614]


2022-02-09 18:27.26 [info     ] DiscreteCQL_20220209182717: epoch=10 step=390 epoch=10 metrics={'time_sample_batch': 0.00023090533721141325, 'time_algorithm_update': 0.007384563103700295, 'loss': 0.059721987981062666, 'time_step': 0.007769236197838416, 'environment': 2.561946820371105, 'advantage': 0.0, 'td_error': 0.0012598445859666185, 'value_scale': 0.0665595643222332} step=390
2022-02-09 18:27.26 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_390.pt


Epoch 11/50: 100%|██████████| 39/39 [00:00<00:00, 125.81it/s, loss=0.0467]


2022-02-09 18:27.27 [info     ] DiscreteCQL_20220209182717: epoch=11 step=429 epoch=11 metrics={'time_sample_batch': 0.0002050032982459435, 'time_algorithm_update': 0.007538447013268104, 'loss': 0.04549629527788896, 'time_step': 0.007897499280098157, 'environment': 2.377003016029788, 'advantage': 0.0, 'td_error': 0.0008667262656203434, 'value_scale': 0.05291527882218361} step=429
2022-02-09 18:27.27 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_429.pt


Epoch 12/50: 100%|██████████| 39/39 [00:00<00:00, 125.00it/s, loss=0.0364]


2022-02-09 18:27.28 [info     ] DiscreteCQL_20220209182717: epoch=12 step=468 epoch=12 metrics={'time_sample_batch': 0.0001792602049998748, 'time_algorithm_update': 0.0075386793185503054, 'loss': 0.035573785121624284, 'time_step': 0.00784603754679362, 'environment': 2.104619728209615, 'advantage': 0.0, 'td_error': 0.0006872046829329292, 'value_scale': 0.04527231305837631} step=468
2022-02-09 18:27.28 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_468.pt


Epoch 13/50: 100%|██████████| 39/39 [00:00<00:00, 124.21it/s, loss=0.0291]


2022-02-09 18:27.28 [info     ] DiscreteCQL_20220209182717: epoch=13 step=507 epoch=13 metrics={'time_sample_batch': 0.0001793519044533754, 'time_algorithm_update': 0.007589865953494341, 'loss': 0.02846353840178404, 'time_step': 0.007897450373722957, 'environment': 2.393950774426694, 'advantage': 0.0, 'td_error': 0.000552042912147499, 'value_scale': 0.038472309708595276} step=507
2022-02-09 18:27.28 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_507.pt


Epoch 14/50: 100%|██████████| 39/39 [00:00<00:00, 127.04it/s, loss=0.0237]


2022-02-09 18:27.29 [info     ] DiscreteCQL_20220209182717: epoch=14 step=546 epoch=14 metrics={'time_sample_batch': 5.119886153783554e-05, 'time_algorithm_update': 0.007461138260670197, 'loss': 0.023232759334720098, 'time_step': 0.007743652050311749, 'environment': 2.3318205094981725, 'advantage': 0.0, 'td_error': 0.0004934303157426712, 'value_scale': 0.03509025275707245} step=546
2022-02-09 18:27.29 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_546.pt


Epoch 15/50: 100%|██████████| 39/39 [00:00<00:00, 117.83it/s, loss=0.0196]


2022-02-09 18:27.30 [info     ] DiscreteCQL_20220209182717: epoch=15 step=585 epoch=15 metrics={'time_sample_batch': 0.0001534254123003055, 'time_algorithm_update': 0.008077028470161634, 'loss': 0.019291020213411406, 'time_step': 0.00835885757055038, 'environment': 2.649708692217324, 'advantage': 0.0, 'td_error': 0.000402864079694254, 'value_scale': 0.02904122695326805} step=585
2022-02-09 18:27.30 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_585.pt


Epoch 16/50: 100%|██████████| 39/39 [00:00<00:00, 126.21it/s, loss=0.0165]


2022-02-09 18:27.31 [info     ] DiscreteCQL_20220209182717: epoch=16 step=624 epoch=16 metrics={'time_sample_batch': 0.00025656895759778144, 'time_algorithm_update': 0.007384697596232097, 'loss': 0.01625604660083086, 'time_step': 0.00784628207866962, 'environment': 2.6103919841423844, 'advantage': 0.0, 'td_error': 0.0003955849276522372, 'value_scale': 0.028494812548160553} step=624
2022-02-09 18:27.31 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_624.pt


Epoch 17/50: 100%|██████████| 39/39 [00:00<00:00, 128.71it/s, loss=0.0141]


2022-02-09 18:27.32 [info     ] DiscreteCQL_20220209182717: epoch=17 step=663 epoch=17 metrics={'time_sample_batch': 0.00015413455474071013, 'time_algorithm_update': 0.007358673291328626, 'loss': 0.013873648734237904, 'time_step': 0.007692196430304112, 'environment': 2.7593472278793567, 'advantage': 0.0, 'td_error': 0.00036334151235983825, 'value_scale': 0.02592425048351288} step=663
2022-02-09 18:27.32 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_663.pt


Epoch 18/50: 100%|██████████| 39/39 [00:00<00:00, 129.14it/s, loss=0.0121]


2022-02-09 18:27.33 [info     ] DiscreteCQL_20220209182717: epoch=18 step=702 epoch=18 metrics={'time_sample_batch': 0.00017876502795097156, 'time_algorithm_update': 0.007359840931036534, 'loss': 0.011973054458697638, 'time_step': 0.0076923431494297125, 'environment': 2.905420053531321, 'advantage': 0.0, 'td_error': 0.0003635750350241551, 'value_scale': 0.0259438157081604} step=702
2022-02-09 18:27.33 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_702.pt


Epoch 19/50: 100%|██████████| 39/39 [00:00<00:00, 123.41it/s, loss=0.0106]


2022-02-09 18:27.34 [info     ] DiscreteCQL_20220209182717: epoch=19 step=741 epoch=19 metrics={'time_sample_batch': 0.00025628163264347956, 'time_algorithm_update': 0.007564563017625075, 'loss': 0.010432502230963646, 'time_step': 0.007974465688069662, 'environment': 2.3850284931773205, 'advantage': 0.0, 'td_error': 0.00036085288999210263, 'value_scale': 0.025714166462421417} step=741
2022-02-09 18:27.34 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_741.pt


Epoch 20/50: 100%|██████████| 39/39 [00:00<00:00, 122.64it/s, loss=0.00928]


2022-02-09 18:27.34 [info     ] DiscreteCQL_20220209182717: epoch=20 step=780 epoch=20 metrics={'time_sample_batch': 0.00017938858423477565, 'time_algorithm_update': 0.007589596968430739, 'loss': 0.009167751679435754, 'time_step': 0.008025377224653196, 'environment': 2.5573339825577843, 'advantage': 0.0, 'td_error': 0.00031900534304440953, 'value_scale': 0.021849704906344414} step=780
2022-02-09 18:27.34 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_780.pt


Epoch 21/50: 100%|██████████| 39/39 [00:00<00:00, 120.74it/s, loss=0.00821]


2022-02-09 18:27.35 [info     ] DiscreteCQL_20220209182717: epoch=21 step=819 epoch=21 metrics={'time_sample_batch': 0.00017976760864257812, 'time_algorithm_update': 0.007794765325692983, 'loss': 0.008118701812166434, 'time_step': 0.008153713666475736, 'environment': 2.5345290985778823, 'advantage': 0.0, 'td_error': 0.0002933482552416322, 'value_scale': 0.019049357622861862} step=819
2022-02-09 18:27.35 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_819.pt


Epoch 22/50: 100%|██████████| 39/39 [00:00<00:00, 111.75it/s, loss=0.00732]


2022-02-09 18:27.36 [info     ] DiscreteCQL_20220209182717: epoch=22 step=858 epoch=22 metrics={'time_sample_batch': 0.00022989664322290666, 'time_algorithm_update': 0.00841065553518442, 'loss': 0.007236824156000064, 'time_step': 0.008820020235501803, 'environment': 2.5866060527487678, 'advantage': 0.0, 'td_error': 0.0002958759716324977, 'value_scale': 0.01934550516307354} step=858
2022-02-09 18:27.36 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_858.pt


Epoch 23/50: 100%|██████████| 39/39 [00:00<00:00, 116.76it/s, loss=0.00656]


2022-02-09 18:27.37 [info     ] DiscreteCQL_20220209182717: epoch=23 step=897 epoch=23 metrics={'time_sample_batch': 0.00015369439736390725, 'time_algorithm_update': 0.008076680012238331, 'loss': 0.0064901212851206464, 'time_step': 0.008410105338463416, 'environment': 2.6353221347925615, 'advantage': 0.0, 'td_error': 0.000302654006761216, 'value_scale': 0.020115533843636513} step=897
2022-02-09 18:27.37 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_897.pt


Epoch 24/50: 100%|██████████| 39/39 [00:00<00:00, 120.37it/s, loss=0.00591]


2022-02-09 18:27.38 [info     ] DiscreteCQL_20220209182717: epoch=24 step=936 epoch=24 metrics={'time_sample_batch': 0.00015383500319260819, 'time_algorithm_update': 0.007923126220703125, 'loss': 0.005852135471426523, 'time_step': 0.008230649507962741, 'environment': 2.364146759130956, 'advantage': 0.0, 'td_error': 0.00029261028917026977, 'value_scale': 0.01896195486187935} step=936
2022-02-09 18:27.38 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_936.pt


Epoch 25/50: 100%|██████████| 39/39 [00:00<00:00, 123.42it/s, loss=0.00535]


2022-02-09 18:27.39 [info     ] DiscreteCQL_20220209182717: epoch=25 step=975 epoch=25 metrics={'time_sample_batch': 0.00010292957990597456, 'time_algorithm_update': 0.007692092504256811, 'loss': 0.005300424455736692, 'time_step': 0.007974526821038662, 'environment': 2.435703243878822, 'advantage': 0.0, 'td_error': 0.00030417862383202987, 'value_scale': 0.020284214988350868} step=975
2022-02-09 18:27.39 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_975.pt


Epoch 26/50: 100%|██████████| 39/39 [00:00<00:00, 122.65it/s, loss=0.00487]


2022-02-09 18:27.39 [info     ] DiscreteCQL_20220209182717: epoch=26 step=1014 epoch=26 metrics={'time_sample_batch': 0.0002564528049566807, 'time_algorithm_update': 0.007614722618689904, 'loss': 0.004823672107588022, 'time_step': 0.008025059333214393, 'environment': 2.463471238257159, 'advantage': 0.0, 'td_error': 0.0002879286889330146, 'value_scale': 0.018396280705928802} step=1014
2022-02-09 18:27.39 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1014.pt


Epoch 27/50: 100%|██████████| 39/39 [00:00<00:00, 127.45it/s, loss=0.00445]


2022-02-09 18:27.40 [info     ] DiscreteCQL_20220209182717: epoch=27 step=1053 epoch=27 metrics={'time_sample_batch': 0.00012744390047513522, 'time_algorithm_update': 0.007487938954279973, 'loss': 0.00440770013926503, 'time_step': 0.007743321932279146, 'environment': 2.5434352417011374, 'advantage': 0.0, 'td_error': 0.00026457083208075716, 'value_scale': 0.015213698148727417} step=1053
2022-02-09 18:27.40 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1053.pt


Epoch 28/50: 100%|██████████| 39/39 [00:00<00:00, 118.18it/s, loss=0.00408]


2022-02-09 18:27.41 [info     ] DiscreteCQL_20220209182717: epoch=28 step=1092 epoch=28 metrics={'time_sample_batch': 0.00017965145600147737, 'time_algorithm_update': 0.007948961013402695, 'loss': 0.004042911152235973, 'time_step': 0.008333676900619116, 'environment': 2.61871312332471, 'advantage': 0.0, 'td_error': 0.00028928155067919903, 'value_scale': 0.018561750650405884} step=1092
2022-02-09 18:27.41 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1092.pt


Epoch 29/50: 100%|██████████| 39/39 [00:00<00:00, 123.41it/s, loss=0.00375]


2022-02-09 18:27.42 [info     ] DiscreteCQL_20220209182717: epoch=29 step=1131 epoch=29 metrics={'time_sample_batch': 0.00012823862907214044, 'time_algorithm_update': 0.007692300356351412, 'loss': 0.003720382747885126, 'time_step': 0.007974453461475862, 'environment': 2.4091116748499433, 'advantage': 0.0, 'td_error': 0.000301825995382643, 'value_scale': 0.020023195073008537} step=1131
2022-02-09 18:27.42 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1131.pt


Epoch 30/50: 100%|██████████| 39/39 [00:00<00:00, 127.45it/s, loss=0.00347]


2022-02-09 18:27.43 [info     ] DiscreteCQL_20220209182717: epoch=30 step=1170 epoch=30 metrics={'time_sample_batch': 0.00020502775143354366, 'time_algorithm_update': 0.007435982043926532, 'loss': 0.0034365508681497513, 'time_step': 0.007769321783995017, 'environment': 2.693256729143915, 'advantage': 0.0, 'td_error': 0.0002678595105480852, 'value_scale': 0.015706997364759445} step=1170
2022-02-09 18:27.43 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1170.pt


Epoch 31/50: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.0032] 


2022-02-09 18:27.44 [info     ] DiscreteCQL_20220209182717: epoch=31 step=1209 epoch=31 metrics={'time_sample_batch': 0.0002814928690592448, 'time_algorithm_update': 0.007666496130136343, 'loss': 0.0031793555483604088, 'time_step': 0.008127829967400966, 'environment': 2.656660488105994, 'advantage': 0.0, 'td_error': 0.00027500840241057034, 'value_scale': 0.01672189310193062} step=1209
2022-02-09 18:27.44 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1209.pt


Epoch 32/50: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.00297]


2022-02-09 18:27.44 [info     ] DiscreteCQL_20220209182717: epoch=32 step=1248 epoch=32 metrics={'time_sample_batch': 0.0001280980232434395, 'time_algorithm_update': 0.007794392414582081, 'loss': 0.0029515666373742698, 'time_step': 0.008101872908763396, 'environment': 2.470054145084439, 'advantage': 0.0, 'td_error': 0.00025144784509656404, 'value_scale': 0.013025157153606415} step=1248
2022-02-09 18:27.44 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1248.pt


Epoch 33/50: 100%|██████████| 39/39 [00:00<00:00, 106.56it/s, loss=0.00277]


2022-02-09 18:27.45 [info     ] DiscreteCQL_20220209182717: epoch=33 step=1287 epoch=33 metrics={'time_sample_batch': 0.00010234270340357072, 'time_algorithm_update': 0.009025977208064152, 'loss': 0.0027485196359264543, 'time_step': 0.009308209786048302, 'environment': 2.0561354339069315, 'advantage': 0.0, 'td_error': 0.00023781978667791748, 'value_scale': 0.010142712853848934} step=1287
2022-02-09 18:27.45 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1287.pt


Epoch 34/50: 100%|██████████| 39/39 [00:00<00:00, 123.42it/s, loss=0.00259]


2022-02-09 18:27.46 [info     ] DiscreteCQL_20220209182717: epoch=34 step=1326 epoch=34 metrics={'time_sample_batch': 0.00012814692961863981, 'time_algorithm_update': 0.0076667162088247445, 'loss': 0.0025675352662801743, 'time_step': 0.00797420892959986, 'environment': 2.598389546393346, 'advantage': 0.0, 'td_error': 0.0002697604512817975, 'value_scale': 0.015984125435352325} step=1326
2022-02-09 18:27.46 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1326.pt


Epoch 35/50: 100%|██████████| 39/39 [00:00<00:00, 120.37it/s, loss=0.00241]


2022-02-09 18:27.47 [info     ] DiscreteCQL_20220209182717: epoch=35 step=1365 epoch=35 metrics={'time_sample_batch': 0.00025656895759778144, 'time_algorithm_update': 0.007794490227332482, 'loss': 0.002396724945029769, 'time_step': 0.008205053133842273, 'environment': 2.278612412494257, 'advantage': 0.0, 'td_error': 0.0002919164694681342, 'value_scale': 0.01887933909893036} step=1365
2022-02-09 18:27.47 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1365.pt


Epoch 36/50: 100%|██████████| 39/39 [00:00<00:00, 126.21it/s, loss=0.00226]


2022-02-09 18:27.48 [info     ] DiscreteCQL_20220209182717: epoch=36 step=1404 epoch=36 metrics={'time_sample_batch': 0.00023086254413311297, 'time_algorithm_update': 0.007461272753201998, 'loss': 0.0022449743969795797, 'time_step': 0.007820447285970053, 'environment': 2.6252902678448704, 'advantage': 0.0, 'td_error': 0.0002635138926452285, 'value_scale': 0.015051186084747314} step=1404
2022-02-09 18:27.48 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1404.pt


Epoch 37/50: 100%|██████████| 39/39 [00:00<00:00, 126.63it/s, loss=0.00212]


2022-02-09 18:27.49 [info     ] DiscreteCQL_20220209182717: epoch=37 step=1443 epoch=37 metrics={'time_sample_batch': 0.0003076333266038161, 'time_algorithm_update': 0.0073845203106219955, 'loss': 0.002111196446304138, 'time_step': 0.007820386153001051, 'environment': 2.4138232987607635, 'advantage': 0.0, 'td_error': 0.000260580636183505, 'value_scale': 0.014589011669158936} step=1443
2022-02-09 18:27.49 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1443.pt


Epoch 38/50: 100%|██████████| 39/39 [00:00<00:00, 125.81it/s, loss=0.002]  


2022-02-09 18:27.49 [info     ] DiscreteCQL_20220209182717: epoch=38 step=1482 epoch=38 metrics={'time_sample_batch': 0.0002297071310190054, 'time_algorithm_update': 0.007385736856705103, 'loss': 0.0019828564344117274, 'time_step': 0.007769297330807417, 'environment': 2.9348084967712196, 'advantage': 0.0, 'td_error': 0.00027304632815283725, 'value_scale': 0.016450483351945877} step=1482
2022-02-09 18:27.49 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1482.pt


Epoch 39/50: 100%|██████████| 39/39 [00:00<00:00, 120.75it/s, loss=0.00188]


2022-02-09 18:27.50 [info     ] DiscreteCQL_20220209182717: epoch=39 step=1521 epoch=39 metrics={'time_sample_batch': 0.00012822640247834034, 'time_algorithm_update': 0.007820337246625852, 'loss': 0.0018673297077512895, 'time_step': 0.008153921518570337, 'environment': 2.410307665528551, 'advantage': 0.0, 'td_error': 0.0002721521139288363, 'value_scale': 0.016325118020176888} step=1521
2022-02-09 18:27.50 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1521.pt


Epoch 40/50: 100%|██████████| 39/39 [00:00<00:00, 115.39it/s, loss=0.00177]


2022-02-09 18:27.51 [info     ] DiscreteCQL_20220209182717: epoch=40 step=1560 epoch=40 metrics={'time_sample_batch': 0.00010258112198267228, 'time_algorithm_update': 0.008204973660982572, 'loss': 0.001762418709217738, 'time_step': 0.008512729253524389, 'environment': 2.4096266841427125, 'advantage': 0.0, 'td_error': 0.000259043485450583, 'value_scale': 0.014339837245643139} step=1560
2022-02-09 18:27.51 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1560.pt


Epoch 41/50: 100%|██████████| 39/39 [00:00<00:00, 130.00it/s, loss=0.00168]


2022-02-09 18:27.52 [info     ] DiscreteCQL_20220209182717: epoch=41 step=1599 epoch=41 metrics={'time_sample_batch': 0.00025656284430088143, 'time_algorithm_update': 0.00717920523423415, 'loss': 0.001665232252353468, 'time_step': 0.0075896397615090394, 'environment': 2.763512677985966, 'advantage': 0.0, 'td_error': 0.00026665074493736896, 'value_scale': 0.015527789480984211} step=1599
2022-02-09 18:27.52 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1599.pt


Epoch 42/50: 100%|██████████| 39/39 [00:00<00:00, 127.45it/s, loss=0.00158]


2022-02-09 18:27.53 [info     ] DiscreteCQL_20220209182717: epoch=42 step=1638 epoch=42 metrics={'time_sample_batch': 0.00030733377505571413, 'time_algorithm_update': 0.007308000173324194, 'loss': 0.0015755224650582443, 'time_step': 0.007743511444483048, 'environment': 2.517949811019306, 'advantage': 0.0, 'td_error': 0.0002554105525622852, 'value_scale': 0.01372955460101366} step=1638
2022-02-09 18:27.53 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1638.pt


Epoch 43/50: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.0015] 


2022-02-09 18:27.53 [info     ] DiscreteCQL_20220209182717: epoch=43 step=1677 epoch=43 metrics={'time_sample_batch': 7.688693511180389e-05, 'time_algorithm_update': 0.0077427167158860425, 'loss': 0.0014952090509140338, 'time_step': 0.008051212017352764, 'environment': 2.774349501580196, 'advantage': 0.0, 'td_error': 0.00025564456787074974, 'value_scale': 0.013769876211881638} step=1677
2022-02-09 18:27.53 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1677.pt


Epoch 44/50: 100%|██████████| 39/39 [00:00<00:00, 117.83it/s, loss=0.00142]


2022-02-09 18:27.54 [info     ] DiscreteCQL_20220209182717: epoch=44 step=1716 epoch=44 metrics={'time_sample_batch': 7.684414203350361e-05, 'time_algorithm_update': 0.008051193677462064, 'loss': 0.001416107818770867, 'time_step': 0.008358906476925582, 'environment': 2.296920386051411, 'advantage': 0.0, 'td_error': 0.00026737936906728876, 'value_scale': 0.015636049211025238} step=1716
2022-02-09 18:27.54 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1716.pt


Epoch 45/50: 100%|██████████| 39/39 [00:00<00:00, 119.27it/s, loss=0.00135]


2022-02-09 18:27.55 [info     ] DiscreteCQL_20220209182717: epoch=45 step=1755 epoch=45 metrics={'time_sample_batch': 0.00015374941703600762, 'time_algorithm_update': 0.007871780640039688, 'loss': 0.0013467967032622069, 'time_step': 0.008204894188122872, 'environment': 2.822559832005067, 'advantage': 0.0, 'td_error': 0.0002790224822124543, 'value_scale': 0.01726192981004715} step=1755
2022-02-09 18:27.55 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1755.pt


Epoch 46/50: 100%|██████████| 39/39 [00:00<00:00, 129.14it/s, loss=0.00129]


2022-02-09 18:27.56 [info     ] DiscreteCQL_20220209182717: epoch=46 step=1794 epoch=46 metrics={'time_sample_batch': 0.00025603098747057794, 'time_algorithm_update': 0.007155173864120092, 'loss': 0.0012823466408567934, 'time_step': 0.007615572366959009, 'environment': 2.4187367125750265, 'advantage': 0.0, 'td_error': 0.00028451597658296635, 'value_scale': 0.01797119341790676} step=1794
2022-02-09 18:27.56 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1794.pt


Epoch 47/50: 100%|██████████| 39/39 [00:00<00:00, 131.31it/s, loss=0.00123]


2022-02-09 18:27.57 [info     ] DiscreteCQL_20220209182717: epoch=47 step=1833 epoch=47 metrics={'time_sample_batch': 0.00012740722069373497, 'time_algorithm_update': 0.007231223277556591, 'loss': 0.00122354129771105, 'time_step': 0.007563603230011769, 'environment': 2.606089470597506, 'advantage': 0.0, 'td_error': 0.00023795368123424865, 'value_scale': 0.010176033712923527} step=1833
2022-02-09 18:27.57 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1833.pt


Epoch 48/50: 100%|██████████| 39/39 [00:00<00:00, 127.04it/s, loss=0.00117]


2022-02-09 18:27.57 [info     ] DiscreteCQL_20220209182717: epoch=48 step=1872 epoch=48 metrics={'time_sample_batch': 0.00017893620026417269, 'time_algorithm_update': 0.0073598714975210335, 'loss': 0.0011642652933891767, 'time_step': 0.00774389658218775, 'environment': 2.1993021984742134, 'advantage': 0.0, 'td_error': 0.00027236875217795387, 'value_scale': 0.01635560393333435} step=1872
2022-02-09 18:27.57 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1872.pt


Epoch 49/50: 100%|██████████| 39/39 [00:00<00:00, 121.88it/s, loss=0.00112]


2022-02-09 18:27.58 [info     ] DiscreteCQL_20220209182717: epoch=49 step=1911 epoch=49 metrics={'time_sample_batch': 0.00012773733872633713, 'time_algorithm_update': 0.00776976194137182, 'loss': 0.0011103677097707987, 'time_step': 0.008077046810052333, 'environment': 2.201296036094921, 'advantage': 0.0, 'td_error': 0.00027519004211029596, 'value_scale': 0.016746830195188522} step=1911
2022-02-09 18:27.58 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1911.pt


Epoch 50/50: 100%|██████████| 39/39 [00:00<00:00, 126.62it/s, loss=0.00107]


2022-02-09 18:27.59 [info     ] DiscreteCQL_20220209182717: epoch=50 step=1950 epoch=50 metrics={'time_sample_batch': 0.0001533214862530048, 'time_algorithm_update': 0.007410550728822365, 'loss': 0.001062715870853609, 'time_step': 0.007768692114414313, 'environment': 2.563432976098427, 'advantage': 0.0, 'td_error': 0.00026381723443291705, 'value_scale': 0.015098026022315025} step=1950
2022-02-09 18:27.59 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209182717\model_1950.pt


In [32]:
obs = env.reset()

ds = CreateDataset()
ds.loadFile('data.json')

observations, actions, rewards = ds.createEpisodeDataset()

# Predict an action for every recorded situation in a single batch.
# The loop used to overwrite each situation with one hard-coded vector
# ([PASS, CARRY, CARRY, CARRY, CARRY]), so the tally only ever reflected
# that single input rather than the dataset.
predictions = cql.predict(np.asarray(observations, dtype=np.float32))

# Count how often each action (by its readable name) is predicted.
answers = {}
for action_id in predictions:
    action_name = ds.ID_to_str[action_id]
    answers[action_name] = answers.get(action_name, 0) + 1

answers

In [38]:
# Frequency of each recorded action value in the dataset.
counts = {}

for recorded_action in actions:
    counts[recorded_action] = counts.get(recorded_action, 0) + 1

counts


{2.0: 1964}