Import the required libraries

In [1]:
from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DiscreteCQL, DQN
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.dataset import Episode
from d3rlpy.ope import DiscreteFQE
from d3rlpy.dataset import MDPDataset

from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split

import import_ipynb
import numpy as np
from random import random
from create_dataset import CreateDataset
from FootballEnv import FootballEnv

from view import Visualiser

# metrics to evaluate with
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer
from d3rlpy.metrics.scorer import soft_opc_scorer


importing Jupyter notebook from FootballEnv.ipynb


Helper function to create a dummy dataset

In [3]:
def create_dataset():

    dataset_maker = CreateDataset()
    # dataset_maker.loadFile('data.json')
    dataset_maker.loadFilesFromDir('events/*.json')
    observations, actions, rewards = dataset_maker.createEpisodeDataset()
    terminals = np.array([ 0 if (i+1) % dataset_maker.lim == 0 else 1 for i in range(len(actions)) ])
    return MDPDataset(
        observations,
        actions,
        rewards, 
        terminals,
    ), observations

In [4]:
dataset, observations = create_dataset()
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2, shuffle=False)


In [5]:
len(train_episodes)

189371

In [6]:
# setup CQL algorithm
cql = DiscreteCQL(use_gpu=False, batch_size=32,)

# env = FootballEnv(observations)
# env.counter = 0

# start training
output = cql.fit(

    train_episodes,
    # eval_episodes=test_episodes,
    n_epochs=25,
    
    scorers={
        # 'environment': evaluate_on_environment(env), # evaluate with Football Env
        'advantage': discounted_sum_of_advantage_scorer, # smaller is better
        'td_error': td_error_scorer, # smaller is better
        'value_scale': average_value_estimation_scorer # smaller is better
    }
    
)

2022-02-10 22:35.10 [debug    ] RoundIterator is selected.
2022-02-10 22:35.10 [info     ] Directory is created at d3rlpy_logs\DiscreteCQL_20220210223510
2022-02-10 22:35.10 [debug    ] Building models...
2022-02-10 22:35.10 [debug    ] Models have been built.
2022-02-10 22:35.10 [info     ] Parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'bootstrap': False, 'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_reduction_type': 'min', 'target_update_interval': 8000, 'use_

Epoch 1/25: 100%|██████████| 1479/1479 [00:16<00:00, 89.31it/s, loss=0.913]

2022-02-10 22:35.27 [info     ] DiscreteCQL_20220210223510: epoch=1 step=1479 epoch=1 metrics={'time_sample_batch': 0.00024837735055506917, 'time_algorithm_update': 0.01050935652386586, 'loss': 0.9124909420727555, 'time_step': 0.011013241213991321} step=1479
2022-02-10 22:35.27 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_1479.pt



Epoch 2/25: 100%|██████████| 1479/1479 [00:20<00:00, 70.90it/s, loss=0.842]

2022-02-10 22:35.48 [info     ] DiscreteCQL_20220210223510: epoch=2 step=2958 epoch=2 metrics={'time_sample_batch': 0.00027618608094293415, 'time_algorithm_update': 0.013266327898594235, 'loss': 0.8422214590552048, 'time_step': 0.013844084788045824} step=2958
2022-02-10 22:35.48 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_2958.pt



Epoch 3/25: 100%|██████████| 1479/1479 [00:22<00:00, 65.43it/s, loss=0.82] 

2022-02-10 22:36.10 [info     ] DiscreteCQL_20220210223510: epoch=3 step=4437 epoch=3 metrics={'time_sample_batch': 0.00030936076078485846, 'time_algorithm_update': 0.014392450421625413, 'loss': 0.8201054230060345, 'time_step': 0.015002262600375802} step=4437
2022-02-10 22:36.10 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_4437.pt



Epoch 4/25: 100%|██████████| 1479/1479 [00:21<00:00, 68.68it/s, loss=0.806]

2022-02-10 22:36.32 [info     ] DiscreteCQL_20220210223510: epoch=4 step=5916 epoch=4 metrics={'time_sample_batch': 0.0002975189822843705, 'time_algorithm_update': 0.013697276976239126, 'loss': 0.8060149976415679, 'time_step': 0.01429640851849394} step=5916
2022-02-10 22:36.32 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_5916.pt



Epoch 5/25: 100%|██████████| 1479/1479 [00:18<00:00, 81.40it/s, loss=0.796]

2022-02-10 22:36.50 [info     ] DiscreteCQL_20220210223510: epoch=5 step=7395 epoch=5 metrics={'time_sample_batch': 0.0002759013972304978, 'time_algorithm_update': 0.01152301221541894, 'loss': 0.7952425727737844, 'time_step': 0.012071344480069607} step=7395
2022-02-10 22:36.50 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_7395.pt



Epoch 6/25: 100%|██████████| 1479/1479 [00:19<00:00, 75.19it/s, loss=0.788]

2022-02-10 22:37.10 [info     ] DiscreteCQL_20220210223510: epoch=6 step=8874 epoch=6 metrics={'time_sample_batch': 0.00027870745010840237, 'time_algorithm_update': 0.012505972038821, 'loss': 0.7879107959009329, 'time_step': 0.013061870506601288} step=8874
2022-02-10 22:37.10 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_8874.pt



Epoch 7/25: 100%|██████████| 1479/1479 [00:18<00:00, 81.71it/s, loss=0.781] 

2022-02-10 22:37.28 [info     ] DiscreteCQL_20220210223510: epoch=7 step=10353 epoch=7 metrics={'time_sample_batch': 0.0002913686211968371, 'time_algorithm_update': 0.011481201108683598, 'loss': 0.7809068537628917, 'time_step': 0.01203758705299073} step=10353
2022-02-10 22:37.28 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_10353.pt



Epoch 8/25: 100%|██████████| 1479/1479 [00:17<00:00, 82.71it/s, loss=0.774]

2022-02-10 22:37.46 [info     ] DiscreteCQL_20220210223510: epoch=8 step=11832 epoch=8 metrics={'time_sample_batch': 0.00026707765296796757, 'time_algorithm_update': 0.011343597962777303, 'loss': 0.7744346448825452, 'time_step': 0.011889575058096884} step=11832
2022-02-10 22:37.46 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_11832.pt



Epoch 9/25: 100%|██████████| 1479/1479 [00:20<00:00, 72.46it/s, loss=0.771] 

2022-02-10 22:38.06 [info     ] DiscreteCQL_20220210223510: epoch=9 step=13311 epoch=9 metrics={'time_sample_batch': 0.00030150245862526216, 'time_algorithm_update': 0.012971442966061399, 'loss': 0.7705032294106048, 'time_step': 0.013554860483237905} step=13311
2022-02-10 22:38.06 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_13311.pt



Epoch 10/25: 100%|██████████| 1479/1479 [00:16<00:00, 92.28it/s, loss=0.766] 

2022-02-10 22:38.22 [info     ] DiscreteCQL_20220210223510: epoch=10 step=14790 epoch=10 metrics={'time_sample_batch': 0.00025283427880050526, 'time_algorithm_update': 0.010150696314050186, 'loss': 0.7664786595964529, 'time_step': 0.010650309620071215} step=14790
2022-02-10 22:38.22 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_14790.pt



Epoch 11/25: 100%|██████████| 1479/1479 [00:18<00:00, 79.54it/s, loss=0.763]

2022-02-10 22:38.41 [info     ] DiscreteCQL_20220210223510: epoch=11 step=16269 epoch=11 metrics={'time_sample_batch': 0.00026706298353544676, 'time_algorithm_update': 0.011814503189355155, 'loss': 0.762570316320823, 'time_step': 0.012358577288510914} step=16269
2022-02-10 22:38.41 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_16269.pt



Epoch 12/25: 100%|██████████| 1479/1479 [00:17<00:00, 83.25it/s, loss=0.759] 

2022-02-10 22:38.59 [info     ] DiscreteCQL_20220210223510: epoch=12 step=17748 epoch=12 metrics={'time_sample_batch': 0.00026697238769943926, 'time_algorithm_update': 0.011273074746534903, 'loss': 0.7590705197634449, 'time_step': 0.011806825756462464} step=17748
2022-02-10 22:38.59 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_17748.pt



Epoch 13/25: 100%|██████████| 1479/1479 [00:17<00:00, 85.65it/s, loss=0.756] 

2022-02-10 22:39.16 [info     ] DiscreteCQL_20220210223510: epoch=13 step=19227 epoch=13 metrics={'time_sample_batch': 0.0002507668560304055, 'time_algorithm_update': 0.010973243152394015, 'loss': 0.7558905270310815, 'time_step': 0.01148477077806536} step=19227
2022-02-10 22:39.16 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_19227.pt



Epoch 14/25: 100%|██████████| 1479/1479 [00:17<00:00, 85.17it/s, loss=0.753] 

2022-02-10 22:39.33 [info     ] DiscreteCQL_20220210223510: epoch=14 step=20706 epoch=14 metrics={'time_sample_batch': 0.0002663398288729386, 'time_algorithm_update': 0.01100406443613296, 'loss': 0.7528524690983343, 'time_step': 0.011533448468840872} step=20706





2022-02-10 22:39.33 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_20706.pt


Epoch 15/25: 100%|██████████| 1479/1479 [00:17<00:00, 85.41it/s, loss=0.751] 

2022-02-10 22:39.51 [info     ] DiscreteCQL_20220210223510: epoch=15 step=22185 epoch=15 metrics={'time_sample_batch': 0.00026770408609737165, 'time_algorithm_update': 0.010963162511808797, 'loss': 0.7509245917835777, 'time_step': 0.011502663294474283} step=22185
2022-02-10 22:39.51 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_22185.pt



Epoch 16/25: 100%|██████████| 1479/1479 [00:17<00:00, 85.43it/s, loss=0.749]

2022-02-10 22:40.08 [info     ] DiscreteCQL_20220210223510: epoch=16 step=23664 epoch=16 metrics={'time_sample_batch': 0.0002706129862054769, 'time_algorithm_update': 0.010953262579287928, 'loss': 0.7488821311686956, 'time_step': 0.011491643810207735} step=23664
2022-02-10 22:40.08 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_23664.pt



Epoch 17/25: 100%|██████████| 1479/1479 [00:17<00:00, 86.48it/s, loss=0.746] 


2022-02-10 22:40.25 [info     ] DiscreteCQL_20220210223510: epoch=17 step=25143 epoch=17 metrics={'time_sample_batch': 0.0002650223203895058, 'time_algorithm_update': 0.01085420763919771, 'loss': 0.7459660753597674, 'time_step': 0.011379512763652067} step=25143
2022-02-10 22:40.25 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_25143.pt


Epoch 18/25: 100%|██████████| 1479/1479 [00:15<00:00, 93.48it/s, loss=0.743] 


2022-02-10 22:40.41 [info     ] DiscreteCQL_20220210223510: epoch=18 step=26622 epoch=18 metrics={'time_sample_batch': 0.0002602974734973714, 'time_algorithm_update': 0.009999189125355069, 'loss': 0.7435592926983257, 'time_step': 0.010508689628895108} step=26622
2022-02-10 22:40.41 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_26622.pt


Epoch 19/25: 100%|██████████| 1479/1479 [00:16<00:00, 88.72it/s, loss=0.742] 

2022-02-10 22:40.58 [info     ] DiscreteCQL_20220210223510: epoch=19 step=28101 epoch=19 metrics={'time_sample_batch': 0.0002549408737915122, 'time_algorithm_update': 0.010589593806083014, 'loss': 0.7418182923876006, 'time_step': 0.011089981851584206} step=28101
2022-02-10 22:40.58 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_28101.pt



Epoch 20/25: 100%|██████████| 1479/1479 [00:15<00:00, 93.42it/s, loss=0.74]  


2022-02-10 22:41.13 [info     ] DiscreteCQL_20220210223510: epoch=20 step=29580 epoch=20 metrics={'time_sample_batch': 0.0002366110148504024, 'time_algorithm_update': 0.01003821513554468, 'loss': 0.7397585918445535, 'time_step': 0.010520894919157512} step=29580
2022-02-10 22:41.13 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_29580.pt


Epoch 21/25: 100%|██████████| 1479/1479 [00:17<00:00, 85.97it/s, loss=0.738] 

2022-02-10 22:41.31 [info     ] DiscreteCQL_20220210223510: epoch=21 step=31059 epoch=21 metrics={'time_sample_batch': 0.00026775502610480647, 'time_algorithm_update': 0.010899668532984736, 'loss': 0.7374060257533823, 'time_step': 0.011431127081375173} step=31059
2022-02-10 22:41.31 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_31059.pt



Epoch 22/25: 100%|██████████| 1479/1479 [00:16<00:00, 92.02it/s, loss=0.736] 

2022-02-10 22:41.47 [info     ] DiscreteCQL_20220210223510: epoch=22 step=32538 epoch=22 metrics={'time_sample_batch': 0.0002521630313607629, 'time_algorithm_update': 0.010181592234572177, 'loss': 0.7357750963004075, 'time_step': 0.010690739220908721} step=32538
2022-02-10 22:41.47 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_32538.pt



Epoch 23/25: 100%|██████████| 1479/1479 [00:15<00:00, 97.80it/s, loss=0.734] 

2022-02-10 22:42.02 [info     ] DiscreteCQL_20220210223510: epoch=23 step=34017 epoch=23 metrics={'time_sample_batch': 0.0002417957726324467, 'time_algorithm_update': 0.009567757246057434, 'loss': 0.7336826150364131, 'time_step': 0.010049253641712738} step=34017
2022-02-10 22:42.02 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_34017.pt



Epoch 24/25: 100%|██████████| 1479/1479 [00:18<00:00, 78.96it/s, loss=0.733]

2022-02-10 22:42.21 [info     ] DiscreteCQL_20220210223510: epoch=24 step=35496 epoch=24 metrics={'time_sample_batch': 0.0002907247781914743, 'time_algorithm_update': 0.011910193832118903, 'loss': 0.7326045680521629, 'time_step': 0.012458543506645528} step=35496
2022-02-10 22:42.21 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_35496.pt



Epoch 25/25: 100%|██████████| 1479/1479 [00:19<00:00, 76.73it/s, loss=0.731]

2022-02-10 22:42.40 [info     ] DiscreteCQL_20220210223510: epoch=25 step=36975 epoch=25 metrics={'time_sample_batch': 0.0002695959592848875, 'time_algorithm_update': 0.012247019800646552, 'loss': 0.7306796889638804, 'time_step': 0.01279920496756842} step=36975
2022-02-10 22:42.40 [info     ] Model parameters are saved to d3rlpy_logs\DiscreteCQL_20220210223510\model_36975.pt





In [7]:
output

[(1,
  {'time_sample_batch': 0.0010013580322265625,
   'time_algorithm_update': 0.015999555587768555,
   'loss': 2.0833652019500732,
   'time_step': 0.018000125885009766}),
 (2,
  {'time_sample_batch': 0.0,
   'time_algorithm_update': 0.00800013542175293,
   'loss': 1.9851281642913818,
   'time_step': 0.008998632431030273}),
 (3,
  {'time_sample_batch': 0.0,
   'time_algorithm_update': 0.010997533798217773,
   'loss': 2.0170738697052,
   'time_step': 0.011996984481811523}),
 (4,
  {'time_sample_batch': 0.0009968280792236328,
   'time_algorithm_update': 0.004998683929443359,
   'loss': 1.900363564491272,
   'time_step': 0.007997274398803711}),
 (5,
  {'time_sample_batch': 0.0,
   'time_algorithm_update': 0.008998870849609375,
   'loss': 1.7809127569198608,
   'time_step': 0.009999752044677734}),
 (6,
  {'time_sample_batch': 0.0,
   'time_algorithm_update': 0.007996320724487305,
   'loss': 1.7211077213287354,
   'time_step': 0.009997844696044922}),
 (7,
  {'time_sample_batch': 0.0,
   't

In [None]:
# TODO: MAKE THIS WORK!

# off-policy evaluation algorithm
fqe = DiscreteFQE(algo=cql)

# train estimators to evaluate the trained policy
fqe.fit(test_episodes,
   eval_episodes=test_episodes,
   n_epochs=50,
   scorers={
      'init_value': initial_state_value_estimation_scorer,
      'soft_opc': soft_opc_scorer(return_threshold=600)
   }
)

Load Saved Model

In [2]:
path =  "C:\\Users\\micha\\Documents\\Masters\\Football-RL\\d3rlpy_logs\\DiscreteCQL_20220210223510\\"
m2 = DiscreteCQL.from_json(f'{path}params.json')

# ready to load
m2.load_model(f'{path}model_36975.pt')



Visualise predictions

In [4]:
ds = CreateDataset()
ds.loadFile('data.json')

visualiser = Visualiser()


observations, actions, rewards = ds.createEpisodeDataset()
answers = {}
for situation in observations:

    # print(situation)
    predictions = m2.predict([situation])[0]

    p = ds.ID_to_str[predictions]
    if not (p in answers): answers[p] = 1
    else: answers[p] += 1

    if(p == "shot"):
        visualiser.visualise_sequence(situation, 3, predictions)
    
    if(p == "clearance"):
        visualiser.visualise_sequence(situation, 3, predictions)

print("finished")
answers

finished


{'carry': 126, 'pass': 151, 'shot': 3, 'clearance': 1}

In [9]:
counts = {}

for item in actions:
    if not (item in counts): counts[item] = 1
    else: counts[item] += 1

counts


{2.0: 105, 0.0: 168, 3.0: 7, 1.0: 1}