Import the required libraries

In [1]:
from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DiscreteCQL, DQN
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.dataset import Episode
from d3rlpy.dataset import MDPDataset

from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split

import import_ipynb
import numpy as np
from random import random
from create_dataset import CreateDataset
from FootballEnv import FootballEnv

importing Jupyter notebook from FootballEnv.ipynb


Helper function to create a dummy dataset

In [2]:
def create_dataset():

    dataset_maker = CreateDataset()
    dataset_maker.loadFile('data.json')

    observations, actions, rewards = dataset_maker.createEpisodeDataset()
    terminals = np.array([[0, 0, 0, 0, 1] for i in range(len(observations))])

    return MDPDataset(
        observations,
        actions,
        rewards, 
        terminals,
    )

In [3]:
dataset = create_dataset()
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)


In [4]:
# setup CQL algorithm
cql = DiscreteCQL(use_gpu=False)

env = FootballEnv()

# start training
output = cql.fit(

    train_episodes,
    eval_episodes=test_episodes,
    n_epochs=50,
    scorers={
        'environment': evaluate_on_environment(env), # evaluate with Football Env
        'advantage': discounted_sum_of_advantage_scorer, # smaller is better
        'td_error': td_error_scorer, # smaller is better
        'value_scale': average_value_estimation_scorer # smaller is better
    }
    
)

2022-02-09 18:59.37 [debug    ] RoundIterator is selected.
2022-02-09 18:59.37 [info     ] Directory is created at d3rlpy_logs\DiscreteCQL_20220209185937
2022-02-09 18:59.37 [debug    ] Building models...
2022-02-09 18:59.37 [debug    ] Models have been built.
2022-02-09 18:59.37 [info     ] Parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'bootstrap': False, 'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_reduction_type': 'min', 'target_update_interval': 8000, 'use_

Epoch 1/50: 100%|██████████| 5/5 [00:00<00:00, 74.63it/s, loss=1.59]


2022-02-09 18:59.37 [info     ] DiscreteCQL_20220209185937: epoch=1 step=5 epoch=1 metrics={'time_sample_batch': 0.00020012855529785156, 'time_algorithm_update': 0.012398958206176758, 'loss': 1.5351174354553223, 'time_step': 0.01319890022277832, 'environment': 2.6485580262271236, 'advantage': -0.2708004844428999, 'td_error': 0.08050351134640248, 'value_scale': 0.062185308003487684} step=5
2022-02-09 18:59.37 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_5.pt


Epoch 2/50: 100%|██████████| 5/5 [00:00<00:00, 83.33it/s, loss=1.49]


2022-02-09 18:59.37 [info     ] DiscreteCQL_20220209185937: epoch=2 step=10 epoch=2 metrics={'time_sample_batch': 0.00040044784545898435, 'time_algorithm_update': 0.011199092864990235, 'loss': 1.472410249710083, 'time_step': 0.012000370025634765, 'environment': 2.6480849834463083, 'advantage': -0.07488980034329944, 'td_error': 0.08331471621612756, 'value_scale': 0.043040832543435194} step=10
2022-02-09 18:59.37 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_10.pt


Epoch 3/50: 100%|██████████| 5/5 [00:00<00:00, 65.79it/s, loss=1.43]


2022-02-09 18:59.38 [info     ] DiscreteCQL_20220209185937: epoch=3 step=15 epoch=3 metrics={'time_sample_batch': 0.0, 'time_algorithm_update': 0.013999557495117188, 'loss': 1.4211612462997436, 'time_step': 0.01459951400756836, 'environment': 2.537695294578913, 'advantage': -0.07435993912358024, 'td_error': 0.09530536680406805, 'value_scale': 0.0941613120182107} step=15
2022-02-09 18:59.38 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_15.pt


Epoch 4/50: 100%|██████████| 5/5 [00:00<00:00, 35.46it/s, loss=1.36]


2022-02-09 18:59.38 [info     ] DiscreteCQL_20220209185937: epoch=4 step=20 epoch=4 metrics={'time_sample_batch': 0.0006017684936523438, 'time_algorithm_update': 0.026199865341186523, 'loss': 1.3857901573181153, 'time_step': 0.027401542663574217, 'environment': 2.7813846763412133, 'advantage': -0.10934450229968068, 'td_error': 0.1090198943107931, 'value_scale': 0.14545006475721797} step=20
2022-02-09 18:59.38 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_20.pt


Epoch 5/50: 100%|██████████| 5/5 [00:00<00:00, 57.47it/s, loss=1.37]


2022-02-09 18:59.38 [info     ] DiscreteCQL_20220209185937: epoch=5 step=25 epoch=5 metrics={'time_sample_batch': 0.00040130615234375, 'time_algorithm_update': 0.016197776794433592, 'loss': 1.3330538988113403, 'time_step': 0.01699972152709961, 'environment': 2.3578517048848724, 'advantage': -0.14009704712779664, 'td_error': 0.1212466614137971, 'value_scale': 0.18271165255767605} step=25
2022-02-09 18:59.38 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_25.pt


Epoch 6/50: 100%|██████████| 5/5 [00:00<00:00, 57.47it/s, loss=1.32]


2022-02-09 18:59.39 [info     ] DiscreteCQL_20220209185937: epoch=6 step=30 epoch=6 metrics={'time_sample_batch': 0.00019969940185546876, 'time_algorithm_update': 0.016199684143066405, 'loss': 1.3036642789840698, 'time_step': 0.01699938774108887, 'environment': 2.4069124475368895, 'advantage': -0.1673673646960656, 'td_error': 0.1323245644976936, 'value_scale': 0.20887961828460297} step=30
2022-02-09 18:59.39 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_30.pt


Epoch 7/50: 100%|██████████| 5/5 [00:00<00:00, 34.25it/s, loss=1.22]


2022-02-09 18:59.39 [info     ] DiscreteCQL_20220209185937: epoch=7 step=35 epoch=7 metrics={'time_sample_batch': 0.0003993988037109375, 'time_algorithm_update': 0.026800060272216798, 'loss': 1.2699882507324218, 'time_step': 0.02819833755493164, 'environment': 2.3906948923740305, 'advantage': -0.19111341379249655, 'td_error': 0.14044418612288703, 'value_scale': 0.2260803608999898} step=35
2022-02-09 18:59.39 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_35.pt


Epoch 8/50: 100%|██████████| 5/5 [00:00<00:00, 39.37it/s, loss=1.27]


2022-02-09 18:59.40 [info     ] DiscreteCQL_20220209185937: epoch=8 step=40 epoch=8 metrics={'time_sample_batch': 0.0004001140594482422, 'time_algorithm_update': 0.023200321197509765, 'loss': 1.2418326377868651, 'time_step': 0.024199819564819335, 'environment': 2.3939833168328315, 'advantage': -0.20854177662368678, 'td_error': 0.1453568712783806, 'value_scale': 0.2321768926922232} step=40
2022-02-09 18:59.40 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_40.pt


Epoch 9/50: 100%|██████████| 5/5 [00:00<00:00, 41.67it/s, loss=1.18]


2022-02-09 18:59.41 [info     ] DiscreteCQL_20220209185937: epoch=9 step=45 epoch=9 metrics={'time_sample_batch': 0.0006002426147460937, 'time_algorithm_update': 0.021999263763427736, 'loss': 1.2133164167404176, 'time_step': 0.023200130462646483, 'environment': 2.3979133516616797, 'advantage': -0.22761010552362848, 'td_error': 0.15022337653171766, 'value_scale': 0.23211994472270212} step=45
2022-02-09 18:59.41 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_45.pt


Epoch 10/50: 100%|██████████| 5/5 [00:00<00:00, 38.76it/s, loss=1.18]


2022-02-09 18:59.41 [info     ] DiscreteCQL_20220209185937: epoch=10 step=50 epoch=10 metrics={'time_sample_batch': 0.0005976676940917969, 'time_algorithm_update': 0.023600244522094728, 'loss': 1.1808648347854613, 'time_step': 0.024997806549072264, 'environment': 2.5239058861226296, 'advantage': -0.24599030299612878, 'td_error': 0.15397539571565252, 'value_scale': 0.22731102475275597} step=50
2022-02-09 18:59.41 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_50.pt


Epoch 11/50: 100%|██████████| 5/5 [00:00<00:00, 50.00it/s, loss=1.18]


2022-02-09 18:59.42 [info     ] DiscreteCQL_20220209185937: epoch=11 step=55 epoch=11 metrics={'time_sample_batch': 0.000200653076171875, 'time_algorithm_update': 0.018999719619750978, 'loss': 1.1543416023254394, 'time_step': 0.01959962844848633, 'environment': 2.303796781685899, 'advantage': -0.26106660341278404, 'td_error': 0.15665815479635134, 'value_scale': 0.21784257345522443} step=55
2022-02-09 18:59.42 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_55.pt


Epoch 12/50: 100%|██████████| 5/5 [00:00<00:00, 50.00it/s, loss=1.16]


2022-02-09 18:59.42 [info     ] DiscreteCQL_20220209185937: epoch=12 step=60 epoch=12 metrics={'time_sample_batch': 0.0003983497619628906, 'time_algorithm_update': 0.01800055503845215, 'loss': 1.1156686067581176, 'time_step': 0.019799327850341795, 'environment': 2.61900526045974, 'advantage': -0.27760712843973123, 'td_error': 0.16091622624444804, 'value_scale': 0.20811724519201866} step=60
2022-02-09 18:59.42 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_60.pt


Epoch 13/50: 100%|██████████| 5/5 [00:00<00:00, 40.98it/s, loss=1.11]


2022-02-09 18:59.42 [info     ] DiscreteCQL_20220209185937: epoch=13 step=65 epoch=13 metrics={'time_sample_batch': 0.00020170211791992188, 'time_algorithm_update': 0.022199344635009766, 'loss': 1.0996192932128905, 'time_step': 0.023000001907348633, 'environment': 2.5964912660022454, 'advantage': -0.2940192132780353, 'td_error': 0.16681136211463846, 'value_scale': 0.199290040996857} step=65
2022-02-09 18:59.42 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_65.pt


Epoch 14/50: 100%|██████████| 5/5 [00:00<00:00, 54.95it/s, loss=1.04]


2022-02-09 18:59.43 [info     ] DiscreteCQL_20220209185937: epoch=14 step=70 epoch=14 metrics={'time_sample_batch': 0.0004008769989013672, 'time_algorithm_update': 0.01679701805114746, 'loss': 1.0640998601913452, 'time_step': 0.01799921989440918, 'environment': 2.416475934853333, 'advantage': -0.311480436036547, 'td_error': 0.1738372278206081, 'value_scale': 0.19093265493089953} step=70
2022-02-09 18:59.43 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_70.pt


Epoch 15/50: 100%|██████████| 5/5 [00:00<00:00, 58.13it/s, loss=1.17]


2022-02-09 18:59.43 [info     ] DiscreteCQL_20220209185937: epoch=15 step=75 epoch=15 metrics={'time_sample_batch': 0.0, 'time_algorithm_update': 0.016201114654541014, 'loss': 1.0522493481636048, 'time_step': 0.016801786422729493, 'environment': 2.50861123188764, 'advantage': -0.3253743647280801, 'td_error': 0.18001059504753888, 'value_scale': 0.17837885716774812} step=75
2022-02-09 18:59.43 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_75.pt


Epoch 16/50: 100%|██████████| 5/5 [00:00<00:00, 58.82it/s, loss=1.09]


2022-02-09 18:59.44 [info     ] DiscreteCQL_20220209185937: epoch=16 step=80 epoch=16 metrics={'time_sample_batch': 0.0005993366241455079, 'time_algorithm_update': 0.015401649475097656, 'loss': 1.0329849004745484, 'time_step': 0.016800117492675782, 'environment': 2.827467884304314, 'advantage': -0.34065917623396963, 'td_error': 0.18747197493902, 'value_scale': 0.1658723036913822} step=80
2022-02-09 18:59.44 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_80.pt


Epoch 17/50: 100%|██████████| 5/5 [00:00<00:00, 56.18it/s, loss=1]


2022-02-09 18:59.44 [info     ] DiscreteCQL_20220209185937: epoch=17 step=85 epoch=17 metrics={'time_sample_batch': 0.00020036697387695312, 'time_algorithm_update': 0.01640353202819824, 'loss': 1.0072842359542846, 'time_step': 0.01740427017211914, 'environment': 2.605411953845163, 'advantage': -0.35353857048145193, 'td_error': 0.19466900670093423, 'value_scale': 0.15266565194663903} step=85
2022-02-09 18:59.44 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_85.pt


Epoch 18/50: 100%|██████████| 5/5 [00:00<00:00, 51.03it/s, loss=0.939]


2022-02-09 18:59.44 [info     ] DiscreteCQL_20220209185937: epoch=18 step=90 epoch=18 metrics={'time_sample_batch': 0.00019979476928710938, 'time_algorithm_update': 0.018398523330688477, 'loss': 0.9884159803390503, 'time_step': 0.019397687911987305, 'environment': 2.4633031142507718, 'advantage': -0.3688005416764182, 'td_error': 0.20338609984831635, 'value_scale': 0.14058895145232478} step=90
2022-02-09 18:59.44 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_90.pt


Epoch 19/50: 100%|██████████| 5/5 [00:00<00:00, 57.47it/s, loss=0.957]


2022-02-09 18:59.45 [info     ] DiscreteCQL_20220209185937: epoch=19 step=95 epoch=19 metrics={'time_sample_batch': 0.0, 'time_algorithm_update': 0.01660146713256836, 'loss': 0.9675394058227539, 'time_step': 0.01700129508972168, 'environment': 2.559111512045462, 'advantage': -0.37947049924571435, 'td_error': 0.2100964228006935, 'value_scale': 0.12208277871832252} step=95
2022-02-09 18:59.45 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_95.pt


Epoch 20/50: 100%|██████████| 5/5 [00:00<00:00, 64.93it/s, loss=0.893]


2022-02-09 18:59.45 [info     ] DiscreteCQL_20220209185937: epoch=20 step=100 epoch=20 metrics={'time_sample_batch': 0.00040063858032226565, 'time_algorithm_update': 0.014400291442871093, 'loss': 0.935209047794342, 'time_step': 0.015200185775756835, 'environment': 2.6854781061375155, 'advantage': -0.39382910474396865, 'td_error': 0.21880755636883484, 'value_scale': 0.10611687072863181} step=100
2022-02-09 18:59.45 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_100.pt


Epoch 21/50: 100%|██████████| 5/5 [00:00<00:00, 50.50it/s, loss=1]


2022-02-09 18:59.46 [info     ] DiscreteCQL_20220209185937: epoch=21 step=105 epoch=21 metrics={'time_sample_batch': 0.00040044784545898435, 'time_algorithm_update': 0.01860198974609375, 'loss': 0.9272896766662597, 'time_step': 0.01960282325744629, 'environment': 2.531906318333788, 'advantage': -0.40824639696130777, 'td_error': 0.22834092999828925, 'value_scale': 0.08774436723130445} step=105
2022-02-09 18:59.46 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_105.pt


Epoch 22/50: 100%|██████████| 5/5 [00:00<00:00, 35.97it/s, loss=0.841]


2022-02-09 18:59.46 [info     ] DiscreteCQL_20220209185937: epoch=22 step=110 epoch=22 metrics={'time_sample_batch': 0.001000690460205078, 'time_algorithm_update': 0.02540116310119629, 'loss': 0.9178928732872009, 'time_step': 0.027200984954833984, 'environment': 2.2941105553006094, 'advantage': -0.42488076999948676, 'td_error': 0.2412630412065937, 'value_scale': 0.07807979360222816} step=110
2022-02-09 18:59.46 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_110.pt


Epoch 23/50: 100%|██████████| 5/5 [00:00<00:00, 51.02it/s, loss=0.875]


2022-02-09 18:59.46 [info     ] DiscreteCQL_20220209185937: epoch=23 step=115 epoch=23 metrics={'time_sample_batch': 0.00020008087158203126, 'time_algorithm_update': 0.018599605560302733, 'loss': 0.8687935590744018, 'time_step': 0.019398880004882813, 'environment': 2.563136459473319, 'advantage': -0.4390053106302774, 'td_error': 0.25450018371699673, 'value_scale': 0.06881685888705154} step=115
2022-02-09 18:59.46 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_115.pt


Epoch 24/50: 100%|██████████| 5/5 [00:00<00:00, 41.66it/s, loss=0.846]


2022-02-09 18:59.47 [info     ] DiscreteCQL_20220209185937: epoch=24 step=120 epoch=24 metrics={'time_sample_batch': 0.0005990028381347656, 'time_algorithm_update': 0.02160158157348633, 'loss': 0.8712131261825562, 'time_step': 0.022800159454345704, 'environment': 2.3606530338757143, 'advantage': -0.455999249888132, 'td_error': 0.27001336960270095, 'value_scale': 0.06287954980507493} step=120
2022-02-09 18:59.47 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_120.pt


Epoch 25/50: 100%|██████████| 5/5 [00:00<00:00, 40.65it/s, loss=0.987]


2022-02-09 18:59.47 [info     ] DiscreteCQL_20220209185937: epoch=25 step=125 epoch=25 metrics={'time_sample_batch': 0.0003996372222900391, 'time_algorithm_update': 0.022600269317626952, 'loss': 0.8343440294265747, 'time_step': 0.023600387573242187, 'environment': 2.356552819293567, 'advantage': -0.4744851103854993, 'td_error': 0.2861025427422514, 'value_scale': 0.05708697771963974} step=125
2022-02-09 18:59.47 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_125.pt


Epoch 26/50: 100%|██████████| 5/5 [00:00<00:00, 60.97it/s, loss=0.829]


2022-02-09 18:59.48 [info     ] DiscreteCQL_20220209185937: epoch=26 step=130 epoch=26 metrics={'time_sample_batch': 0.0002009868621826172, 'time_algorithm_update': 0.015599298477172851, 'loss': 0.8542925357818604, 'time_step': 0.01640033721923828, 'environment': 2.5808341576951532, 'advantage': -0.48440292346762376, 'td_error': 0.2996251773997045, 'value_scale': 0.048023632572342954} step=130
2022-02-09 18:59.48 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_130.pt


Epoch 27/50: 100%|██████████| 5/5 [00:00<00:00, 50.00it/s, loss=0.909]


2022-02-09 18:59.48 [info     ] DiscreteCQL_20220209185937: epoch=27 step=135 epoch=27 metrics={'time_sample_batch': 0.00039944648742675783, 'time_algorithm_update': 0.018600988388061523, 'loss': 0.838256323337555, 'time_step': 0.019799184799194337, 'environment': 2.550910394185275, 'advantage': -0.49513515537066044, 'td_error': 0.3114679272569383, 'value_scale': 0.033759301993995905} step=135
2022-02-09 18:59.48 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_135.pt


Epoch 28/50: 100%|██████████| 5/5 [00:00<00:00, 47.62it/s, loss=0.839]


2022-02-09 18:59.49 [info     ] DiscreteCQL_20220209185937: epoch=28 step=140 epoch=28 metrics={'time_sample_batch': 0.0006012439727783204, 'time_algorithm_update': 0.01899876594543457, 'loss': 0.7920126080513, 'time_step': 0.020400238037109376, 'environment': 2.317340153329403, 'advantage': -0.5097893314071279, 'td_error': 0.3269030480588538, 'value_scale': 0.021848351539423067} step=140
2022-02-09 18:59.49 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_140.pt


Epoch 29/50: 100%|██████████| 5/5 [00:00<00:00, 40.32it/s, loss=0.809]


2022-02-09 18:59.49 [info     ] DiscreteCQL_20220209185937: epoch=29 step=145 epoch=29 metrics={'time_sample_batch': 0.0007991790771484375, 'time_algorithm_update': 0.022002315521240233, 'loss': 0.8081509828567505, 'time_step': 0.02400050163269043, 'environment': 2.6237699820556615, 'advantage': -0.523550950730144, 'td_error': 0.342414580732824, 'value_scale': 0.015255820394183198} step=145
2022-02-09 18:59.49 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_145.pt


Epoch 30/50: 100%|██████████| 5/5 [00:00<00:00, 53.76it/s, loss=0.86]


2022-02-09 18:59.49 [info     ] DiscreteCQL_20220209185937: epoch=30 step=150 epoch=30 metrics={'time_sample_batch': 0.0005992412567138672, 'time_algorithm_update': 0.01740102767944336, 'loss': 0.7942911863327027, 'time_step': 0.018400192260742188, 'environment': 2.4233296211254673, 'advantage': -0.5361999908100494, 'td_error': 0.3571670066069507, 'value_scale': 0.0050348408209780855} step=150
2022-02-09 18:59.49 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_150.pt


Epoch 31/50: 100%|██████████| 5/5 [00:00<00:00, 56.82it/s, loss=0.828]


2022-02-09 18:59.50 [info     ] DiscreteCQL_20220209185937: epoch=31 step=155 epoch=31 metrics={'time_sample_batch': 0.00019998550415039061, 'time_algorithm_update': 0.01659855842590332, 'loss': 0.7695378184318542, 'time_step': 0.017400121688842772, 'environment': 2.242138485332037, 'advantage': -0.5516271277350336, 'td_error': 0.373408627661572, 'value_scale': -0.0020575574599206448} step=155
2022-02-09 18:59.50 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_155.pt


Epoch 32/50: 100%|██████████| 5/5 [00:00<00:00, 42.74it/s, loss=0.673]


2022-02-09 18:59.50 [info     ] DiscreteCQL_20220209185937: epoch=32 step=160 epoch=32 metrics={'time_sample_batch': 0.0003989696502685547, 'time_algorithm_update': 0.021600818634033202, 'loss': 0.7695267915725708, 'time_step': 0.022799873352050783, 'environment': 2.429866336985332, 'advantage': -0.5710833771762668, 'td_error': 0.39432256595643206, 'value_scale': -0.003417775073709587} step=160
2022-02-09 18:59.50 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_160.pt


Epoch 33/50: 100%|██████████| 5/5 [00:00<00:00, 55.56it/s, loss=0.744]


2022-02-09 18:59.51 [info     ] DiscreteCQL_20220209185937: epoch=33 step=165 epoch=33 metrics={'time_sample_batch': 0.00040106773376464845, 'time_algorithm_update': 0.016799259185791015, 'loss': 0.7467447876930237, 'time_step': 0.017800378799438476, 'environment': 2.355121215827924, 'advantage': -0.5816635796890203, 'td_error': 0.40858772782102354, 'value_scale': -0.008489362973098954} step=165
2022-02-09 18:59.51 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_165.pt


Epoch 34/50: 100%|██████████| 5/5 [00:00<00:00, 40.33it/s, loss=0.849]


2022-02-09 18:59.51 [info     ] DiscreteCQL_20220209185937: epoch=34 step=170 epoch=34 metrics={'time_sample_batch': 0.0003983020782470703, 'time_algorithm_update': 0.023200321197509765, 'loss': 0.7260936021804809, 'time_step': 0.023998785018920898, 'environment': 2.55160522724411, 'advantage': -0.5964332093116097, 'td_error': 0.42494026510849725, 'value_scale': -0.009869503478209177} step=170
2022-02-09 18:59.51 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_170.pt


Epoch 35/50: 100%|██████████| 5/5 [00:00<00:00, 60.24it/s, loss=0.633]


2022-02-09 18:59.51 [info     ] DiscreteCQL_20220209185937: epoch=35 step=175 epoch=35 metrics={'time_sample_batch': 0.0005992412567138672, 'time_algorithm_update': 0.015401601791381836, 'loss': 0.7502642154693604, 'time_step': 0.016400718688964845, 'environment': 2.5735478416081556, 'advantage': -0.6099976613037403, 'td_error': 0.4413972315519303, 'value_scale': -0.015569684600147108} step=175
2022-02-09 18:59.51 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_175.pt


Epoch 36/50: 100%|██████████| 5/5 [00:00<00:00, 54.35it/s, loss=0.741]


2022-02-09 18:59.52 [info     ] DiscreteCQL_20220209185937: epoch=36 step=180 epoch=36 metrics={'time_sample_batch': 0.0, 'time_algorithm_update': 0.01700143814086914, 'loss': 0.741123104095459, 'time_step': 0.017801332473754882, 'environment': 2.6499370202097343, 'advantage': -0.6225680311916092, 'td_error': 0.45493489137704174, 'value_scale': -0.0266127057839185} step=180
2022-02-09 18:59.52 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_180.pt


Epoch 37/50: 100%|██████████| 5/5 [00:00<00:00, 64.94it/s, loss=0.732]


2022-02-09 18:59.52 [info     ] DiscreteCQL_20220209185937: epoch=37 step=185 epoch=37 metrics={'time_sample_batch': 0.00059967041015625, 'time_algorithm_update': 0.014199924468994141, 'loss': 0.7294121384620667, 'time_step': 0.015199232101440429, 'environment': 2.7797088264626755, 'advantage': -0.6317156360242522, 'td_error': 0.4680606285922351, 'value_scale': -0.03930478279168407} step=185
2022-02-09 18:59.52 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_185.pt


Epoch 38/50: 100%|██████████| 5/5 [00:00<00:00, 58.82it/s, loss=0.769]


2022-02-09 18:59.52 [info     ] DiscreteCQL_20220209185937: epoch=38 step=190 epoch=38 metrics={'time_sample_batch': 0.00020022392272949218, 'time_algorithm_update': 0.015801191329956055, 'loss': 0.7194191336631774, 'time_step': 0.01680026054382324, 'environment': 2.299916998464419, 'advantage': -0.6436999865632355, 'td_error': 0.4832278063108905, 'value_scale': -0.0490977797890082} step=190
2022-02-09 18:59.52 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_190.pt


Epoch 39/50: 100%|██████████| 5/5 [00:00<00:00, 58.14it/s, loss=0.723]


2022-02-09 18:59.53 [info     ] DiscreteCQL_20220209185937: epoch=39 step=195 epoch=39 metrics={'time_sample_batch': 0.0004015445709228516, 'time_algorithm_update': 0.0160006046295166, 'loss': 0.6768136143684387, 'time_step': 0.017000961303710937, 'environment': 2.644752590645898, 'advantage': -0.6645197476695229, 'td_error': 0.5024457377559184, 'value_scale': -0.04801429237704724} step=195
2022-02-09 18:59.53 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_195.pt


Epoch 40/50: 100%|██████████| 5/5 [00:00<00:00, 61.73it/s, loss=0.714]


2022-02-09 18:59.53 [info     ] DiscreteCQL_20220209185937: epoch=40 step=200 epoch=40 metrics={'time_sample_batch': 0.00019965171813964843, 'time_algorithm_update': 0.01520085334777832, 'loss': 0.7339321255683899, 'time_step': 0.01600050926208496, 'environment': 2.524488545643539, 'advantage': -0.668978622665656, 'td_error': 0.5151392164732504, 'value_scale': -0.055174327843512096} step=200
2022-02-09 18:59.53 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_200.pt


Epoch 41/50: 100%|██████████| 5/5 [00:00<00:00, 60.98it/s, loss=0.784]


2022-02-09 18:59.53 [info     ] DiscreteCQL_20220209185937: epoch=41 step=205 epoch=41 metrics={'time_sample_batch': 0.00059967041015625, 'time_algorithm_update': 0.0152008056640625, 'loss': 0.7142581701278686, 'time_step': 0.016399860382080078, 'environment': 2.2930646442995397, 'advantage': -0.6686818776329492, 'td_error': 0.5271105365383851, 'value_scale': -0.05996533673411856} step=205
2022-02-09 18:59.53 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_205.pt


Epoch 42/50: 100%|██████████| 5/5 [00:00<00:00, 36.50it/s, loss=0.722]


2022-02-09 18:59.54 [info     ] DiscreteCQL_20220209185937: epoch=42 step=210 epoch=42 metrics={'time_sample_batch': 0.00059967041015625, 'time_algorithm_update': 0.025200033187866212, 'loss': 0.7076627850532532, 'time_step': 0.026599645614624023, 'environment': 2.4649189881032836, 'advantage': -0.6770968716029248, 'td_error': 0.5418872149974128, 'value_scale': -0.05604932003188878} step=210
2022-02-09 18:59.54 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_210.pt


Epoch 43/50: 100%|██████████| 5/5 [00:00<00:00, 51.54it/s, loss=0.692]


2022-02-09 18:59.54 [info     ] DiscreteCQL_20220209185937: epoch=43 step=215 epoch=43 metrics={'time_sample_batch': 0.0006010055541992188, 'time_algorithm_update': 0.01779966354370117, 'loss': 0.7208712160587311, 'time_step': 0.018999671936035155, 'environment': 2.270579888178195, 'advantage': -0.6868361571488165, 'td_error': 0.5569573165403128, 'value_scale': -0.05055833196577927} step=215
2022-02-09 18:59.54 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_215.pt


Epoch 44/50: 100%|██████████| 5/5 [00:00<00:00, 55.56it/s, loss=0.716]


2022-02-09 18:59.55 [info     ] DiscreteCQL_20220209185937: epoch=44 step=220 epoch=44 metrics={'time_sample_batch': 0.00019931793212890625, 'time_algorithm_update': 0.017001628875732422, 'loss': 0.7174609541893006, 'time_step': 0.017800378799438476, 'environment': 2.3742181862117206, 'advantage': -0.6961373716529492, 'td_error': 0.5706833587564818, 'value_scale': -0.0550679579610005} step=220
2022-02-09 18:59.55 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_220.pt


Epoch 45/50: 100%|██████████| 5/5 [00:00<00:00, 46.30it/s, loss=0.967]


2022-02-09 18:59.55 [info     ] DiscreteCQL_20220209185937: epoch=45 step=225 epoch=45 metrics={'time_sample_batch': 0.00020017623901367188, 'time_algorithm_update': 0.019800949096679687, 'loss': 0.7088926315307618, 'time_step': 0.020600605010986327, 'environment': 2.2178371565504453, 'advantage': -0.7032900879297189, 'td_error': 0.5828470115737918, 'value_scale': -0.05383357258203129} step=225
2022-02-09 18:59.55 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_225.pt


Epoch 46/50: 100%|██████████| 5/5 [00:00<00:00, 58.14it/s, loss=0.531]


2022-02-09 18:59.55 [info     ] DiscreteCQL_20220209185937: epoch=46 step=230 epoch=46 metrics={'time_sample_batch': 0.0003997802734375, 'time_algorithm_update': 0.016200923919677736, 'loss': 0.6791933298110961, 'time_step': 0.017000818252563478, 'environment': 2.5459647501525984, 'advantage': -0.7152085772929776, 'td_error': 0.5973558807330303, 'value_scale': -0.05092676406881461} step=230
2022-02-09 18:59.55 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_230.pt


Epoch 47/50: 100%|██████████| 5/5 [00:00<00:00, 52.63it/s, loss=0.491]


2022-02-09 18:59.56 [info     ] DiscreteCQL_20220209185937: epoch=47 step=235 epoch=47 metrics={'time_sample_batch': 0.0006009101867675781, 'time_algorithm_update': 0.0175994873046875, 'loss': 0.695578682422638, 'time_step': 0.01900019645690918, 'environment': 2.7379831896440656, 'advantage': -0.7124006769580314, 'td_error': 0.6060405193533042, 'value_scale': -0.0577136316181471} step=235
2022-02-09 18:59.56 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_235.pt


Epoch 48/50: 100%|██████████| 5/5 [00:00<00:00, 57.47it/s, loss=0.755]


2022-02-09 18:59.56 [info     ] DiscreteCQL_20220209185937: epoch=48 step=240 epoch=48 metrics={'time_sample_batch': 0.0005999088287353515, 'time_algorithm_update': 0.0160006046295166, 'loss': 0.6709317147731781, 'time_step': 0.017200326919555663, 'environment': 2.4758073565326795, 'advantage': -0.7155714972054082, 'td_error': 0.6166489564885053, 'value_scale': -0.060230982451078795} step=240
2022-02-09 18:59.56 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_240.pt


Epoch 49/50: 100%|██████████| 5/5 [00:00<00:00, 62.50it/s, loss=0.885]


2022-02-09 18:59.56 [info     ] DiscreteCQL_20220209185937: epoch=49 step=245 epoch=49 metrics={'time_sample_batch': 0.0, 'time_algorithm_update': 0.015000009536743164, 'loss': 0.6932828068733216, 'time_step': 0.015600252151489257, 'environment': 2.2020624742118207, 'advantage': -0.7207827015588982, 'td_error': 0.6257488033530029, 'value_scale': -0.060410601630186044} step=245
2022-02-09 18:59.57 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_245.pt


Epoch 50/50: 100%|██████████| 5/5 [00:00<00:00, 57.47it/s, loss=0.552]


2022-02-09 18:59.57 [info     ] DiscreteCQL_20220209185937: epoch=50 step=250 epoch=50 metrics={'time_sample_batch': 0.0008016586303710938, 'time_algorithm_update': 0.015799236297607423, 'loss': 0.6844325304031372, 'time_step': 0.017000198364257812, 'environment': 2.0041963843238486, 'advantage': -0.7317800378155631, 'td_error': 0.6386047639419262, 'value_scale': -0.06021358259022236} step=250
2022-02-09 18:59.57 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220209185937\model_250.pt


In [5]:
obs = env.reset()

ds = CreateDataset()
ds.loadFile('data.json')

observations, actions, rewards = ds.createEpisodeDataset()
answers = {}
for situation in observations:

    situation = np.array([ds.PASS, ds.CARRY, ds.CARRY, ds.CARRY, ds.CARRY], dtype=np.float32)
    predictions = cql.predict([situation])[0]

    p = ds.ID_to_str[predictions]
    if not (p in answers): answers[p] = 1
    else: answers[p] += 1

    # print(ds.ID_to_str[predictions])

In [6]:
counts = {}

for item in actions:
    if not (item in counts): counts[item] = 1
    else: counts[item] += 1

counts


{2.0: 105, 0.0: 168, 3.0: 7, 1.0: 1}