# Sample Workflow for d3rlpy Experiments

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
plt.style.use('matplotlibrc')

from Python.data_sampler import *

## Building an MDPDataset

We first read in a large batch of samples from the file. As `d3rlpy` wants it in the form (observations, actions, rewards, terminal flags), we go ahead and do that. Here's a helper function to get a dataset from a list of chunks of your choosing.

In [2]:
def get_dataset(chunks : list, batch_size=30000, 
                path="collected_data/rl_det_small.txt") -> d3rlpy.dataset.MDPDataset :
    random.seed(0)
    samples = DataSampler(path_to_data=path)
    states = []
    actions = []
    rewards = []
    next_states = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        [statesChunk, actionsChunk, rewardsChunk, nextStatesChunk] = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
        next_states.append(nextStatesChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    terminals = np.zeros(len(states))
    terminals[::100] = 1 #episode length 100, change if necessary
    print(states.shape)
    dataset = d3rlpy.dataset.MDPDataset(states.numpy(), 
                                        actions.numpy(), 
                                        rewards.numpy(), terminals)
    return dataset, states.numpy(), actions.numpy(), rewards.numpy()

We can build the dataset from there, just like this, and split into train and test sets.

In [3]:
dataset, states, actions, rewards = get_dataset([i for i in range(998)], path="collected_data/rl_deterministic.txt")

start
[ 0.00000000e+00  7.95731469e+08 -7.69891077e-02  4.00000469e-03
  5.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.50295370e-01 -2.41931634e-01  6.00000000e-01]
Read chunk # 1 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.47410892e-01 -8.79999531e-03
  2.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.86362318e-02  3.75821673e-01 -6.00000000e-01]
Read chunk # 2 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.03891077e-02 -1.41999953e-02
 -2.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.78778459e-03 -1.34615461e-02  4.84073546e-02]
Read chunk # 3 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -8.17891077e-02 -1.19999531e-03
  7.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.09713430e-01 -2.63658359e-01  6.00000000e-01]
Read chunk # 4 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.24891077e-02 -1.35999953e-02
 -4.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.2331

[ 0.00000000e+00  7.95731469e+08 -1.24891077e-02  6.40000469e-03
  1.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.89056008e-02 -4.28655416e-02  1.34176685e-01]
Read chunk # 41 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.35110892e-01  6.00004692e-04
  3.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  8.62448158e-02  3.37467545e-01 -6.00000000e-01]
Read chunk # 42 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.25889108e-01  8.60000469e-03
  4.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.04897050e-02 -3.61685289e-01  6.00000000e-01]
Read chunk # 43 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.54108923e-02  9.20000469e-03
  7.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.79504108e-01  7.04170084e-02 -3.91330857e-01]
Read chunk # 44 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.00989108e-01 -1.45999953e-02
 -3.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.005237

[ 0.00000000e+00  7.95731469e+08 -1.41789108e-01 -1.99995308e-04
  6.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.89125937e-01 -4.18590168e-01  6.00000000e-01]
Read chunk # 82 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  6.10108923e-02  1.32000047e-02
  7.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.44237123e-01  1.16216092e-01 -5.22459650e-01]
Read chunk # 83 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.11710892e-01 -1.49999953e-02
 -7.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.39860767e-01  3.42562745e-01 -6.00000000e-01]
Read chunk # 84 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.99891077e-02 -1.91999953e-02
  3.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.62080602e-01 -7.23826969e-02  1.20921962e-01]
Read chunk # 85 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.06108923e-02 -1.19999531e-03
 -9.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.611164

[ 0.00000000e+00  7.95731469e+08  4.60108923e-02  1.68000047e-02
  6.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.05018950e-01  8.23167968e-02 -3.72379504e-01]
Read chunk # 121 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -3.10891077e-02 -1.15999953e-02
  3.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.40217605e-01 -1.03693925e-01  2.47887223e-01]
Read chunk # 122 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -3.82891077e-02  1.34000047e-02
 -3.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.47363548e-01 -7.89246193e-02  3.94918266e-01]
Read chunk # 123 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -5.68891077e-02 -5.79999531e-03
 -1.00000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.53932923e-01 -8.56160479e-02  5.04761879e-01]
Read chunk # 124 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.37889108e-01  1.74000047e-02
 -7.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.82

[ 0.00000000e+00  7.95731469e+08 -1.15389108e-01  2.20000469e-03
 -3.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -9.45911398e-02 -2.84914456e-01  6.00000000e-01]
Read chunk # 161 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -4.69891077e-02 -1.81999953e-02
  9.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.99492982e-02 -1.30307889e-01  3.73407266e-01]
Read chunk # 162 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.09891077e-02 -8.79999531e-03
  5.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.95571791e-01 -6.59640509e-02  7.13476936e-02]
Read chunk # 163 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  5.77108923e-02 -1.45999953e-02
  6.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.51832656e-01  1.08126650e-01 -5.81577960e-01]
Read chunk # 164 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.33891077e-02  1.40000047e-02
  8.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.93

[ 0.00000000e+00  7.95731469e+08 -8.74891077e-02  2.20000469e-03
 -5.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.69370000e-01 -1.93671880e-01  6.00000000e-01]
Read chunk # 199 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.09891077e-02  5.80000469e-03
 -9.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.88906207e-01 -1.25511701e-01  6.00000000e-01]
Read chunk # 200 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.15891077e-02 -1.31999953e-02
  4.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.74265428e-01 -5.85292123e-02  6.27081420e-02]
Read chunk # 201 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -9.61891077e-02  7.80000469e-03
 -6.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.02680172e-01 -2.14119747e-01  6.00000000e-01]
Read chunk # 202 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.44489108e-01  1.82000047e-02
 -8.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.07

[ 0.00000000e+00  7.95731469e+08 -1.49289108e-01  1.60000047e-02
  7.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.41429974e-01 -4.42306935e-01  6.00000000e-01]
Read chunk # 238 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  3.58108923e-02 -5.99999531e-03
  9.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.75852264e-01  3.66308726e-02 -3.51685115e-01]
Read chunk # 239 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  2.74108923e-02 -8.59999531e-03
 -5.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.23724754e-01  1.09783209e-01 -2.82515006e-01]
Read chunk # 240 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.10892292e-04  1.08000047e-02
  3.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.29815510e-02 -1.83911271e-02  2.92631536e-02]
Read chunk # 241 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  5.74108923e-02  1.22000047e-02
 -2.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.17

[ 0.00000000e+00  7.95731469e+08  5.50108923e-02  1.84000047e-02
 -3.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.72752564e-01  1.69296388e-01 -4.50311226e-01]
Read chunk # 276 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.07589108e-01 -1.31999953e-02
  8.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.76740162e-01 -3.36453228e-01  6.00000000e-01]
Read chunk # 277 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.32110892e-01  1.34000047e-02
 -7.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.58147044e-01  3.97843289e-01 -6.00000000e-01]
Read chunk # 278 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.35710892e-01 -5.39999531e-03
 -1.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.56451566e-02  3.70635422e-01 -6.00000000e-01]
Read chunk # 279 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  5.68108923e-02  1.48000047e-02
 -8.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.94

[ 0.00000000e+00  7.95731469e+08  4.45108923e-02  1.22000047e-02
  6.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.37128742e-01  7.51243906e-02 -3.73354781e-01]
Read chunk # 314 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.39108923e-02  1.94000047e-02
  3.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.71485853e-02  1.54051255e-02 -6.76553142e-02]
Read chunk # 315 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -6.88891077e-02  1.18000047e-02
  5.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.11020855e-01 -2.18560342e-01  6.00000000e-01]
Read chunk # 316 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.57108923e-02  9.40000469e-03
 -2.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.06557388e-01  5.81938537e-02 -1.16497217e-01]
Read chunk # 317 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  3.10108923e-02 -7.99995308e-04
  5.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.52

[ 0.00000000e+00  7.95731469e+08 -5.32891077e-02  1.78000047e-02
  2.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.96492673e-03 -1.55360094e-01  5.47576363e-01]
Read chunk # 352 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -8.08891077e-02  6.20000469e-03
  2.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.73954685e-02 -2.30340975e-01  6.00000000e-01]
Read chunk # 353 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.89107708e-04 -1.21999953e-02
 -5.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.12202455e-01  3.57922913e-02 -3.37745665e-02]
Read chunk # 354 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  2.56108923e-02  5.60000469e-03
 -1.34198363e-08  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.22318228e-02  6.82834559e-02 -2.20138865e-01]
Read chunk # 355 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  7.61089229e-03 -9.99999531e-03
  2.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.80

[ 0.00000000e+00  7.95731469e+08 -7.33891077e-02 -1.45999953e-02
 -8.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.74683803e-01 -1.39597501e-01  6.00000000e-01]
Read chunk # 390 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.00610892e-01  7.20000469e-03
  6.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.62517758e-01  2.22312232e-01 -6.00000000e-01]
Read chunk # 391 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.15108923e-02  1.52000047e-02
 -7.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.62523399e-01  1.57412713e-01 -3.35991594e-01]
Read chunk # 392 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  3.16108923e-02  6.00004692e-04
  2.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.13185292e-02  6.93401154e-02 -2.91642802e-01]
Read chunk # 393 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -2.86891077e-02  1.22000047e-02
  2.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.72

[ 0.00000000e+00  7.95731469e+08  6.11089229e-03 -5.99995308e-04
 -5.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.52714913e-01  5.27651707e-02 -6.00947981e-02]
Read chunk # 429 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.20410892e-01 -7.99999531e-03
 -9.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.17503119e-01  3.77189337e-01 -6.00000000e-01]
Read chunk # 430 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.75108923e-02 -2.39999531e-03
 -1.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.81679499e-02  1.32689870e-01 -4.48098247e-01]
Read chunk # 431 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  5.38108923e-02  1.88000047e-02
  4.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.55352025e-02  1.12627946e-01 -4.37943901e-01]
Read chunk # 432 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.05789108e-01  2.40000469e-03
  9.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.61

[ 0.00000000e+00  7.95731469e+08  1.12010892e-01  2.00000469e-03
  8.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.41322929e-01  2.38949753e-01 -6.00000000e-01]
Read chunk # 468 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  9.19108923e-02 -1.39999531e-03
  4.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.44037326e-01  2.11531680e-01 -6.00000000e-01]
Read chunk # 469 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.21891077e-02 -7.59999531e-03
  3.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.21568144e-01 -2.11826571e-01  6.00000000e-01]
Read chunk # 470 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -2.92891077e-02  9.20000469e-03
  1.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.88680390e-02 -8.99071375e-02  2.98296406e-01]
Read chunk # 471 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -2.44891077e-02 -1.35999953e-02
 -9.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.14

[ 0.00000000e+00  7.95731469e+08  1.06010892e-01 -8.79999531e-03
 -9.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.14327144e-01  3.39077641e-01 -6.00000000e-01]
Read chunk # 506 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.21010892e-01  1.58000047e-02
 -6.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.37209506e-01  3.61376160e-01 -6.00000000e-01]
Read chunk # 507 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.17789108e-01  1.98000047e-02
  9.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.92814191e-01 -3.74405301e-01  6.00000000e-01]
Read chunk # 508 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -9.13891077e-02  1.54000047e-02
  9.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.10282052e-01 -3.04533858e-01  6.00000000e-01]
Read chunk # 509 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.69891077e-02  1.26000047e-02
 -5.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.94

start
[ 0.00000000e+00  7.95731469e+08 -5.28910771e-03 -1.41999953e-02
  6.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.57606626e-02 -1.80093054e-02  1.32436882e-03]
Read chunk # 544 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.75108923e-02  1.66000047e-02
 -1.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.86710623e-02  4.74901143e-02 -1.09913189e-01]
Read chunk # 545 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  6.13108923e-02 -1.47999953e-02
 -7.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.46193936e-01  2.10460783e-01 -6.00000000e-01]
Read chunk # 546 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.45389108e-01 -2.19999531e-03
 -3.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.54320412e-02 -3.62380361e-01  6.00000000e-01]
Read chunk # 547 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.43589108e-01  3.40000469e-03
 -7.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08 -1.48891077e-02 -1.39999953e-02
 -3.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.68951766e-02 -1.50596502e-02  9.05956532e-02]
Read chunk # 583 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -3.16891077e-02  8.00000469e-03
 -7.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.45017977e-01 -3.37437933e-02  3.16586179e-01]
Read chunk # 584 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  3.40108923e-02  1.38000047e-02
 -8.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.87430905e-01  1.44652234e-01 -2.71263322e-01]
Read chunk # 585 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -3.55891077e-02  1.68000047e-02
 -2.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.22087217e-01 -8.08015022e-02  3.80948268e-01]
Read chunk # 586 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  6.07108923e-02  1.90000047e-02
 -2.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.39

start
[ 0.00000000e+00  7.95731469e+08 -8.58910771e-03 -1.09999953e-02
  3.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.29526861e-01 -4.22109373e-02  4.21016312e-02]
Read chunk # 622 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.07810892e-01 -1.41999953e-02
 -3.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.05620082e-02  3.08394787e-01 -6.00000000e-01]
Read chunk # 623 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.15389108e-01  6.40000469e-03
 -8.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.69131489e-01 -2.48178665e-01  6.00000000e-01]
Read chunk # 624 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.42989108e-01  1.46000047e-02
  7.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.55296692e-01 -4.27566531e-01  6.00000000e-01]
Read chunk # 625 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.26489108e-01 -1.39999531e-03
  7.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08 -5.86891077e-02 -1.71999953e-02
  4.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.95684476e-01 -1.84475194e-01  4.84643612e-01]
Read chunk # 662 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -2.68891077e-02 -8.59999531e-03
  7.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.55708720e-01 -1.22224442e-01  2.18780313e-01]
Read chunk # 663 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.26710892e-01  2.00004692e-04
 -6.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.74115181e-02  3.39726249e-01 -6.00000000e-01]
Read chunk # 664 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -6.28891077e-02 -1.19999531e-03
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.47268491e-01 -1.07296309e-01  5.74976841e-01]
Read chunk # 665 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.30891077e-02  1.42000047e-02
 -9.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.16

start
[ 0.00000000e+00  7.95731469e+08 -1.35189108e-01  1.34000047e-02
 -6.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.41529519e-01 -3.13471999e-01  6.00000000e-01]
Read chunk # 702 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.45310892e-01 -8.19999531e-03
  3.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.23950125e-01  3.63818842e-01 -6.00000000e-01]
Read chunk # 703 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.47489108e-01  9.40000469e-03
  6.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.51014241e-01 -4.33676048e-01  6.00000000e-01]
Read chunk # 704 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.16510892e-01 -1.45999953e-02
  9.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.18302754e-01  2.48281708e-01 -6.00000000e-01]
Read chunk # 705 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  6.43108923e-02 -4.99999531e-03
  3.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08 -1.23489108e-01  1.54000047e-02
 -2.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.19298848e-01 -3.12797161e-01  6.00000000e-01]
Read chunk # 741 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.06891077e-02  7.20000469e-03
 -8.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.66768289e-01  2.76361518e-02  1.20137111e-01]
Read chunk # 742 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -3.40891077e-02  1.52000047e-02
  6.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.22449249e-01 -1.32257464e-01  3.61944431e-01]
Read chunk # 743 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.37510892e-01 -1.81999953e-02
 -8.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.00967251e-02  3.69598996e-01 -6.00000000e-01]
Read chunk # 744 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.35489108e-01  1.82000047e-02
 -1.34198363e-08  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.22

[ 0.00000000e+00  7.95731469e+08  4.09108923e-02  3.20000469e-03
  7.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.86706394e-01  6.23740287e-02 -3.69121673e-01]
Read chunk # 780 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.20789108e-01  6.80000469e-03
 -6.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.93171029e-01 -2.80516203e-01  6.00000000e-01]
Read chunk # 781 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  8.50108923e-02 -1.61999953e-02
  6.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  8.37005993e-02  2.20982790e-01 -6.00000000e-01]
Read chunk # 782 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.07210892e-01  1.58000047e-02
 -8.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.09218779e-01  3.41609128e-01 -6.00000000e-01]
Read chunk # 783 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.20789108e-01  4.40000469e-03
  9.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.53

[ 0.00000000e+00  7.95731469e+08 -4.30891077e-02  4.69199998e-09
 -4.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.35709783e-01 -8.19612047e-02  3.96050998e-01]
Read chunk # 819 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -2.95891077e-02 -1.87999953e-02
 -6.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.16466125e-01 -3.33417381e-02  2.10837734e-01]
Read chunk # 820 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.42310892e-01 -8.59999531e-03
 -1.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.84804326e-02  3.89392258e-01 -6.00000000e-01]
Read chunk # 821 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -6.48910771e-03  4.40000469e-03
  1.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.57671880e-02 -2.44077179e-02  7.23400618e-02]
Read chunk # 822 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  9.58108923e-02  5.80000469e-03
  8.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.06

[ 0.00000000e+00  7.95731469e+08  1.09010892e-01  1.94000047e-02
  7.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.22392906e-01  2.42610592e-01 -6.00000000e-01]
Read chunk # 859 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.02489108e-01 -1.39999531e-03
  4.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.32958976e-01 -3.00398270e-01  6.00000000e-01]
Read chunk # 860 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.20489108e-01 -2.79999531e-03
  5.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.60673631e-01 -3.53193791e-01  6.00000000e-01]
Read chunk # 861 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.49810892e-01  1.86000047e-02
 -5.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.26168719e-01  4.32443653e-01 -6.00000000e-01]
Read chunk # 862 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.18108923e-02 -1.65999953e-02
  5.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.15

[ 0.00000000e+00  7.95731469e+08 -1.34589108e-01 -6.39999531e-03
  7.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.75644967e-02 -3.60865066e-01  6.00000000e-01]
Read chunk # 897 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  1.38110892e-01  2.00004692e-04
 -9.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.61135209e-01  4.26612914e-01 -6.00000000e-01]
Read chunk # 898 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  4.87108923e-02  2.00004692e-04
  2.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.67544535e-02  1.11375317e-01 -4.50798258e-01]
Read chunk # 899 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.10889108e-01  8.80000469e-03
 -1.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.09403578e-02 -2.84605327e-01  6.00000000e-01]
Read chunk # 900 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  5.92108923e-02 -1.17999953e-02
 -1.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.30

[ 0.00000000e+00  7.95731469e+08 -8.62891077e-02 -1.99999953e-02
  2.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.40330290e-01 -2.42054911e-01  6.00000000e-01]
Read chunk # 935 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  9.79108923e-02  4.00004692e-04
 -8.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.25924567e-01  3.11839424e-01 -6.00000000e-01]
Read chunk # 936 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  7.87108923e-02  1.30000047e-02
 -7.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.09967005e-02  2.13331749e-01 -6.00000000e-01]
Read chunk # 937 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -9.46891077e-02 -1.59999531e-03
 -2.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.45789734e-02 -2.35929281e-01  6.00000000e-01]
Read chunk # 938 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  6.58108923e-02 -6.59999531e-03
  8.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.53

[ 0.00000000e+00  7.95731469e+08 -2.28910771e-03  1.08000047e-02
  1.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.10143476e-03 -1.45807816e-02  5.41894402e-02]
Read chunk # 974 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -9.55891077e-02  7.00000469e-03
 -7.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.21660896e-01 -2.07375859e-01  6.00000000e-01]
Read chunk # 975 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -4.39891077e-02  1.18000047e-02
  2.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.07028205e-02 -1.33968728e-01  4.42384526e-01]
Read chunk # 976 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.18689108e-01  1.36000047e-02
  1.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.36940919e-03 -3.25872765e-01  6.00000000e-01]
Read chunk # 977 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -4.12891077e-02  1.36000047e-02
  9.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.06

In [4]:
print("The behavior policy value statistics are:")
dataset.compute_stats()['return']

The behavior policy value statistics are:


{'mean': -4.2979746,
 'std': 2.9650762,
 'min': -19.356209,
 'max': 0.0,
 'histogram': (array([   21,    43,    95,   379,  1195,  1979,  2612,  3300,  4137,
          5446,  7025,  8735, 11719, 18487, 25193, 33395, 40758, 54770,
         46735, 11121]),
  array([-19.356209  , -18.3884    , -17.420588  , -16.452778  ,
         -15.484967  , -14.517157  , -13.549346  , -12.581535  ,
         -11.613726  , -10.645915  ,  -9.678104  ,  -8.710294  ,
          -7.7424836 ,  -6.774673  ,  -5.806863  ,  -4.839052  ,
          -3.8712418 ,  -2.9034314 ,  -1.9356209 ,  -0.96781045,
           0.        ], dtype=float32))}

In [5]:
# plt.plot(states[:,3])

In [6]:
# plt.plot(actions)

In [7]:
# plt.plot(model.predict(np.array(states)))

In [8]:
from sklearn.model_selection import train_test_split
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

## Setting up an Algorithm

In [9]:
from d3rlpy.algos import CQL
from d3rlpy.models.encoders import VectorEncoderFactory

from d3rlpy.preprocessing import MinMaxActionScaler
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)
#cql = CQL(action_scaler=action_scaler)

actor_encoder = VectorEncoderFactory(hidden_units=[12, 24, 36, 24, 12],
                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)
critic_encoder = VectorEncoderFactory(hidden_units=[12, 24, 24, 12],
                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)

model = CQL(q_func_factory='qr', #qr -> quantile regression q function, but you don't have to use this
            reward_scaler='standard',
            actor_encoder_factory = actor_encoder,
            critic_encoder_factory = critic_encoder,
            action_scaler=action_scaler,
          actor_learning_rate=1e-5, 
          critic_learning_rate=0.0003, 
            use_gpu=True) #change it to true if you have one
model.build_with_dataset(dataset)

In [10]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer

# calculate metrics with test dataset
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

-0.019474013344166884


In [11]:
#%load_ext tensorboard
#%tensorboard --logdir runs

In [12]:
model.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10, 
        tensorboard_dir='runs',
        scorers={
            'td_error': td_error_scorer,
            'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer
        })

2022-04-15 19:48.22 [debug    ] RoundIterator is selected.
2022-04-15 19:48.22 [info     ] Directory is created at d3rlpy_logs/CQL_20220415194822
2022-04-15 19:48.22 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-15 19:48.22 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-15 19:48.23 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220415194822/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'vector', 'params': {'hidden_units': [12, 24, 36, 24, 12], 'activation': 'relu', 'use_batch_norm': True, 'dropout_rate': 0.2, 'use_dense': False}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': Fals

Epoch 1/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 20:27.40 [info     ] CQL_20220415194822: epoch=1 step=85741 epoch=1 metrics={'time_sample_batch': 0.00031275933899537264, 'time_algorithm_update': 0.020723903086432908, 'temp_loss': 0.5277301647774326, 'temp': 0.1243415923763854, 'alpha_loss': -19313.581439918868, 'alpha': 1318.5249205411626, 'critic_loss': 17695.369889204558, 'actor_loss': 2.185401217090781, 'time_step': 0.02453802335468208, 'td_error': 1.3116502787876172, 'init_value': -5.9981889724731445, 'ave_value': -5.998183777200564} step=85741
2022-04-15 20:27.40 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_85741.pt


Epoch 2/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 21:07.40 [info     ] CQL_20220415194822: epoch=2 step=171482 epoch=2 metrics={'time_sample_batch': 0.0003043265410616692, 'time_algorithm_update': 0.02055568481845553, 'temp_loss': 0.00015883790537005648, 'temp': 0.0008094566549346069, 'alpha_loss': -8375901.8132385975, 'alpha': 640866.8451333335, 'critic_loss': 7871029.914287469, 'actor_loss': 10.528318059189855, 'time_step': 0.025112815303733306, 'td_error': 2.360244169918384, 'init_value': -11.755678176879883, 'ave_value': -11.755678390894605} step=171482
2022-04-15 21:07.40 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_171482.pt


Epoch 3/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 21:47.49 [info     ] CQL_20220415194822: epoch=3 step=257223 epoch=3 metrics={'time_sample_batch': 0.00030402676114102704, 'time_algorithm_update': 0.020606310450087106, 'temp_loss': 0.00015607679701247522, 'temp': 0.00033537459650869155, 'alpha_loss': -13383317.610349774, 'alpha': 1001080.3125, 'critic_loss': 12327216.323042652, 'actor_loss': 11.624818662523152, 'time_step': 0.025162534040486042, 'td_error': 2.2134307715695707, 'init_value': -11.188474655151367, 'ave_value': -11.188474011703462} step=257223
2022-04-15 21:47.49 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_257223.pt


Epoch 4/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 22:28.13 [info     ] CQL_20220415194822: epoch=4 step=342964 epoch=4 metrics={'time_sample_batch': 0.0003065771539428083, 'time_algorithm_update': 0.020708899227647796, 'temp_loss': -2.1391722713341628e-05, 'temp': 0.00015364132093558856, 'alpha_loss': -12327037.557189675, 'alpha': 1001080.3125, 'critic_loss': 12140400.615971355, 'actor_loss': 12.278944450913466, 'time_step': 0.02535528280262828, 'td_error': 2.846814049539306, 'init_value': -13.689773559570312, 'ave_value': -13.68977322760292} step=342964
2022-04-15 22:28.13 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_342964.pt


Epoch 5/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 23:08.26 [info     ] CQL_20220415194822: epoch=5 step=428705 epoch=5 metrics={'time_sample_batch': 0.00030470813146989313, 'time_algorithm_update': 0.020629897705840523, 'temp_loss': -4.2312726479041704e-06, 'temp': 0.00014953832779034188, 'alpha_loss': -12505671.227988943, 'alpha': 1001080.3125, 'critic_loss': 12171208.153940355, 'actor_loss': 14.992480897295586, 'time_step': 0.02524293660964616, 'td_error': 3.144021573664416, 'init_value': -14.715094566345215, 'ave_value': -14.71509383888025} step=428705
2022-04-15 23:08.26 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_428705.pt


Epoch 6/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-15 23:48.40 [info     ] CQL_20220415194822: epoch=6 step=514446 epoch=6 metrics={'time_sample_batch': 0.000304875406263101, 'time_algorithm_update': 0.020652849072123468, 'temp_loss': 1.08136075050775e-06, 'temp': 5.8964901937962406e-05, 'alpha_loss': -12284862.400275247, 'alpha': 1001080.3125, 'critic_loss': 12085494.017751135, 'actor_loss': 13.579759732074209, 'time_step': 0.025258611466810896, 'td_error': 2.615160139787588, 'init_value': -12.83185863494873, 'ave_value': -12.831860048715814} step=514446
2022-04-15 23:48.40 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_514446.pt


Epoch 7/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-16 00:29.07 [info     ] CQL_20220415194822: epoch=7 step=600187 epoch=7 metrics={'time_sample_batch': 0.00030416193850525024, 'time_algorithm_update': 0.020700133276030763, 'temp_loss': 2.8629059917853134e-06, 'temp': 7.974594527920217e-05, 'alpha_loss': -12370957.113609591, 'alpha': 1001080.3125, 'critic_loss': 12100587.54939877, 'actor_loss': 11.931004624135351, 'time_step': 0.025369210033617178, 'td_error': 2.2850078968282594, 'init_value': -11.500444412231445, 'ave_value': -11.500446219969144} step=600187
2022-04-16 00:29.07 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_600187.pt


Epoch 8/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-16 01:09.32 [info     ] CQL_20220415194822: epoch=8 step=685928 epoch=8 metrics={'time_sample_batch': 0.0003054375130790892, 'time_algorithm_update': 0.02066124657206154, 'temp_loss': 4.2987725179149425e-06, 'temp': 4.682954016648013e-05, 'alpha_loss': -12387734.390676573, 'alpha': 1001080.3125, 'critic_loss': 12113971.474160554, 'actor_loss': 10.789485654162345, 'time_step': 0.02537421233297455, 'td_error': 2.2064208532309633, 'init_value': -11.162444114685059, 'ave_value': -11.162443161010742} step=685928
2022-04-16 01:09.32 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_685928.pt


Epoch 9/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-16 01:50.03 [info     ] CQL_20220415194822: epoch=9 step=771669 epoch=9 metrics={'time_sample_batch': 0.00030707452922892274, 'time_algorithm_update': 0.020735845929954203, 'temp_loss': -1.128560036883189e-06, 'temp': 3.86948897100545e-05, 'alpha_loss': -12239177.316546343, 'alpha': 1001080.3125, 'critic_loss': 12065161.480050385, 'actor_loss': 10.028992998721105, 'time_step': 0.025429588701093762, 'td_error': 1.7640987986467795, 'init_value': -9.013802528381348, 'ave_value': -9.013804238063573} step=771669
2022-04-16 01:50.03 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_771669.pt


Epoch 10/10:   0%|          | 0/85741 [00:00<?, ?it/s]

2022-04-16 02:30.34 [info     ] CQL_20220415194822: epoch=10 step=857410 epoch=10 metrics={'time_sample_batch': 0.0003076755517096982, 'time_algorithm_update': 0.020737744080086183, 'temp_loss': 4.55514840591495e-08, 'temp': 5.123224858212597e-05, 'alpha_loss': -12257118.904223183, 'alpha': 1001080.3125, 'critic_loss': 12076661.861279901, 'actor_loss': 7.639729850049536, 'time_step': 0.025443260273266214, 'td_error': 1.3771597688444286, 'init_value': -6.577911853790283, 'ave_value': -6.577912109429451} step=857410
2022-04-16 02:30.34 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220415194822/model_857410.pt


[(1,
  {'time_sample_batch': 0.00031275933899537264,
   'time_algorithm_update': 0.020723903086432908,
   'temp_loss': 0.5277301647774326,
   'temp': 0.1243415923763854,
   'alpha_loss': -19313.581439918868,
   'alpha': 1318.5249205411626,
   'critic_loss': 17695.369889204558,
   'actor_loss': 2.185401217090781,
   'time_step': 0.02453802335468208,
   'td_error': 1.3116502787876172,
   'init_value': -5.9981889724731445,
   'ave_value': -5.998183777200564}),
 (2,
  {'time_sample_batch': 0.0003043265410616692,
   'time_algorithm_update': 0.02055568481845553,
   'temp_loss': 0.00015883790537005648,
   'temp': 0.0008094566549346069,
   'alpha_loss': -8375901.8132385975,
   'alpha': 640866.8451333335,
   'critic_loss': 7871029.914287469,
   'actor_loss': 10.528318059189855,
   'time_step': 0.025112815303733306,
   'td_error': 2.360244169918384,
   'init_value': -11.755678176879883,
   'ave_value': -11.755678390894605}),
 (3,
  {'time_sample_batch': 0.00030402676114102704,
   'time_algorithm

In [15]:
model.save_policy("cqlDet998.pt") 

  minimum = torch.tensor(
  maximum = torch.tensor(


In [16]:
model.save_model("cqlDet998model.pt")

## Off-Policy Evaluation

We do get some metrics on a test set of initial state value and average value. However, these estimates (using the critic's Q-function) of model performance are biased. They're useful for validation during training, but not much else. Instead, we fit a Q-function to the data (or a separate dataset, as I've done here) separately and evaluate the model's performance on it.

Feel free to change the chunks and number of steps.

In [13]:
from d3rlpy.ope import FQE
# metrics to evaluate with
from d3rlpy.metrics.scorer import soft_opc_scorer


ope_dataset, states_ope, actions_ope, rewards_ope = get_dataset([2,4,6,8], path="collected_data/rl_deterministic.txt") #change if you'd prefer different chunks
ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
fqe.fit(ope_train_episodes,
        eval_episodes=ope_test_episodes,
        tensorboard_dir='runs',
        n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
        scorers={
           'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer,
           'soft_opc': soft_opc_scorer(return_threshold=0)
        })

start
[ 0.00000000e+00  7.95731469e+08 -1.03891077e-02 -1.41999953e-02
 -2.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.78778459e-03 -1.34615461e-02  4.84073546e-02]
Read chunk # 3 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -7.24891077e-02 -1.35999953e-02
 -4.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.23311010e-02 -1.64283998e-01  6.00000000e-01]
Read chunk # 5 out of 10000
start
[ 0.00000000e+00  7.95731469e+08  7.01089229e-03 -4.19999531e-03
  7.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.21623335e-01 -2.86362315e-02 -8.00043364e-02]
Read chunk # 7 out of 10000
start
[ 0.00000000e+00  7.95731469e+08 -1.03989108e-01 -1.37999953e-02
  7.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.76352555e-01 -3.26280816e-01  6.00000000e-01]
Read chunk # 9 out of 10000
torch.Size([111080, 6])
2022-04-16 02:30.35 [debug    ] RoundIterator is selected.
2022-04-16 02:30.35 [info     ] Directory is created at d3rlp

Epoch 1/50:   0%|          | 0/878 [00:00<?, ?it/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


2022-04-16 02:30.38 [info     ] FQE_20220416023035: epoch=1 step=878 epoch=1 metrics={'time_sample_batch': 0.00011914167425898593, 'time_algorithm_update': 0.0029855746614634313, 'loss': 0.0002461599849627247, 'time_step': 0.003155057142427136, 'init_value': -0.22843682765960693, 'ave_value': -0.22831709708927517, 'soft_opc': nan} step=878




2022-04-16 02:30.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_878.pt


Epoch 2/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:30.42 [info     ] FQE_20220416023035: epoch=2 step=1756 epoch=2 metrics={'time_sample_batch': 0.0001231980486719918, 'time_algorithm_update': 0.0030236969233102297, 'loss': 0.0014282974988773878, 'time_step': 0.0031988797263838437, 'init_value': -0.5504694581031799, 'ave_value': -0.5503293672049991, 'soft_opc': nan} step=1756




2022-04-16 02:30.42 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_1756.pt


Epoch 3/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:30.46 [info     ] FQE_20220416023035: epoch=3 step=2634 epoch=3 metrics={'time_sample_batch': 0.000123508698847951, 'time_algorithm_update': 0.0031934004438222132, 'loss': 0.004029400426500787, 'time_step': 0.003368657922418894, 'init_value': -0.787994921207428, 'ave_value': -0.7878976750239928, 'soft_opc': nan} step=2634




2022-04-16 02:30.46 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_2634.pt


Epoch 4/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:30.49 [info     ] FQE_20220416023035: epoch=4 step=3512 epoch=4 metrics={'time_sample_batch': 0.000124397473346128, 'time_algorithm_update': 0.0031993448870057126, 'loss': 0.007525220854168146, 'time_step': 0.003378618822553978, 'init_value': -1.0022207498550415, 'ave_value': -1.002360916034561, 'soft_opc': nan} step=3512




2022-04-16 02:30.49 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_3512.pt


Epoch 5/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:30.54 [info     ] FQE_20220416023035: epoch=5 step=4390 epoch=5 metrics={'time_sample_batch': 0.00013417480746814493, 'time_algorithm_update': 0.0035719817211524773, 'loss': 0.010765274227209073, 'time_step': 0.0037644083244653933, 'init_value': -1.1627755165100098, 'ave_value': -1.1630126333066064, 'soft_opc': nan} step=4390




2022-04-16 02:30.54 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_4390.pt


Epoch 6/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:30.57 [info     ] FQE_20220416023035: epoch=6 step=5268 epoch=6 metrics={'time_sample_batch': 0.00012600747762345508, 'time_algorithm_update': 0.003410469426652565, 'loss': 0.013091343081994302, 'time_step': 0.003584749877860171, 'init_value': -1.2380139827728271, 'ave_value': -1.2383038729446794, 'soft_opc': nan} step=5268




2022-04-16 02:30.57 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_5268.pt


Epoch 7/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.02 [info     ] FQE_20220416023035: epoch=7 step=6146 epoch=7 metrics={'time_sample_batch': 0.00012883319941631483, 'time_algorithm_update': 0.003542760781655279, 'loss': 0.016386731998706525, 'time_step': 0.0037237403061503973, 'init_value': -1.3720546960830688, 'ave_value': -1.3723281069705142, 'soft_opc': nan} step=6146




2022-04-16 02:31.02 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_6146.pt


Epoch 8/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.06 [info     ] FQE_20220416023035: epoch=8 step=7024 epoch=8 metrics={'time_sample_batch': 0.00013460195146008884, 'time_algorithm_update': 0.0037369548867123544, 'loss': 0.018053457120000026, 'time_step': 0.003923984214765336, 'init_value': -1.415456771850586, 'ave_value': -1.4159563347314554, 'soft_opc': nan} step=7024




2022-04-16 02:31.06 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_7024.pt


Epoch 9/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.10 [info     ] FQE_20220416023035: epoch=9 step=7902 epoch=9 metrics={'time_sample_batch': 0.00012868873622259953, 'time_algorithm_update': 0.0035864019719777725, 'loss': 0.020974107120380222, 'time_step': 0.0037703375609969224, 'init_value': -1.553494930267334, 'ave_value': -1.5538333263033062, 'soft_opc': nan} step=7902




2022-04-16 02:31.10 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_7902.pt


Epoch 10/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.14 [info     ] FQE_20220416023035: epoch=10 step=8780 epoch=10 metrics={'time_sample_batch': 0.00012621113814091086, 'time_algorithm_update': 0.0035442499473589155, 'loss': 0.023965579658403274, 'time_step': 0.0037206905577883364, 'init_value': -1.6163628101348877, 'ave_value': -1.616753696293785, 'soft_opc': nan} step=8780




2022-04-16 02:31.14 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_8780.pt


Epoch 11/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.18 [info     ] FQE_20220416023035: epoch=11 step=9658 epoch=11 metrics={'time_sample_batch': 0.00013945586045945152, 'time_algorithm_update': 0.0039043869135863144, 'loss': 0.0252154201905678, 'time_step': 0.004103558481690009, 'init_value': -1.6846381425857544, 'ave_value': -1.6848610164560085, 'soft_opc': nan} step=9658




2022-04-16 02:31.18 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_9658.pt


Epoch 12/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.23 [info     ] FQE_20220416023035: epoch=12 step=10536 epoch=12 metrics={'time_sample_batch': 0.00013890950717795683, 'time_algorithm_update': 0.003996203863539294, 'loss': 0.02730591854972219, 'time_step': 0.004192780525103246, 'init_value': -1.7097747325897217, 'ave_value': -1.7100223125838854, 'soft_opc': nan} step=10536




2022-04-16 02:31.23 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_10536.pt


Epoch 13/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.28 [info     ] FQE_20220416023035: epoch=13 step=11414 epoch=13 metrics={'time_sample_batch': 0.00014992862736174078, 'time_algorithm_update': 0.004310467520172884, 'loss': 0.02973489029262921, 'time_step': 0.004526891306483935, 'init_value': -1.7842313051223755, 'ave_value': -1.7845746293989986, 'soft_opc': nan} step=11414




2022-04-16 02:31.28 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_11414.pt


Epoch 14/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.32 [info     ] FQE_20220416023035: epoch=14 step=12292 epoch=14 metrics={'time_sample_batch': 0.0001497428889698211, 'time_algorithm_update': 0.0044063218362239065, 'loss': 0.031877272707285076, 'time_step': 0.004620766856947358, 'init_value': -1.86216402053833, 'ave_value': -1.8628361398188722, 'soft_opc': nan} step=12292




2022-04-16 02:31.32 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_12292.pt


Epoch 15/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.37 [info     ] FQE_20220416023035: epoch=15 step=13170 epoch=15 metrics={'time_sample_batch': 0.00013555644041856493, 'time_algorithm_update': 0.003925226815469173, 'loss': 0.03220105893429627, 'time_step': 0.00411816881566493, 'init_value': -1.7957780361175537, 'ave_value': -1.7962395659541448, 'soft_opc': nan} step=13170




2022-04-16 02:31.37 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_13170.pt


Epoch 16/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.41 [info     ] FQE_20220416023035: epoch=16 step=14048 epoch=16 metrics={'time_sample_batch': 0.00013657827312147972, 'time_algorithm_update': 0.003964018170002653, 'loss': 0.030373404652768353, 'time_step': 0.0041586896553126446, 'init_value': -1.7345776557922363, 'ave_value': -1.7352293853543346, 'soft_opc': nan} step=14048




2022-04-16 02:31.41 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_14048.pt


Epoch 17/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.46 [info     ] FQE_20220416023035: epoch=17 step=14926 epoch=17 metrics={'time_sample_batch': 0.00014634909010693804, 'time_algorithm_update': 0.004316313934760648, 'loss': 0.02886914298745339, 'time_step': 0.004524589942636686, 'init_value': -1.7337546348571777, 'ave_value': -1.7345796636725788, 'soft_opc': nan} step=14926




2022-04-16 02:31.46 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_14926.pt


Epoch 18/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.51 [info     ] FQE_20220416023035: epoch=18 step=15804 epoch=18 metrics={'time_sample_batch': 0.000145303904331355, 'time_algorithm_update': 0.00420207694885671, 'loss': 0.030164917557960098, 'time_step': 0.004408499917571257, 'init_value': -1.812066674232483, 'ave_value': -1.812928592132238, 'soft_opc': nan} step=15804




2022-04-16 02:31.51 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_15804.pt


Epoch 19/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:31.56 [info     ] FQE_20220416023035: epoch=19 step=16682 epoch=19 metrics={'time_sample_batch': 0.00013428505569492767, 'time_algorithm_update': 0.0038951179161158672, 'loss': 0.0315289858962116, 'time_step': 0.004085287960895374, 'init_value': -1.8328897953033447, 'ave_value': -1.8337049069555578, 'soft_opc': nan} step=16682




2022-04-16 02:31.56 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_16682.pt


Epoch 20/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.01 [info     ] FQE_20220416023035: epoch=20 step=17560 epoch=20 metrics={'time_sample_batch': 0.00014040573311286528, 'time_algorithm_update': 0.004186181924337678, 'loss': 0.03234182231752208, 'time_step': 0.004386157544164288, 'init_value': -1.8282415866851807, 'ave_value': -1.8289167450547472, 'soft_opc': nan} step=17560




2022-04-16 02:32.01 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_17560.pt


Epoch 21/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.05 [info     ] FQE_20220416023035: epoch=21 step=18438 epoch=21 metrics={'time_sample_batch': 0.00013677405876559386, 'time_algorithm_update': 0.0040657305771777735, 'loss': 0.031840031128355796, 'time_step': 0.004261241415367039, 'init_value': -1.8357875347137451, 'ave_value': -1.8365463098278998, 'soft_opc': nan} step=18438




2022-04-16 02:32.05 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_18438.pt


Epoch 22/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.10 [info     ] FQE_20220416023035: epoch=22 step=19316 epoch=22 metrics={'time_sample_batch': 0.00014043098701702978, 'time_algorithm_update': 0.004209822565656588, 'loss': 0.0324509807861724, 'time_step': 0.004413044534131444, 'init_value': -1.864782452583313, 'ave_value': -1.8655394506221201, 'soft_opc': nan} step=19316




2022-04-16 02:32.10 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_19316.pt


Epoch 23/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.15 [info     ] FQE_20220416023035: epoch=23 step=20194 epoch=23 metrics={'time_sample_batch': 0.00014109274792508276, 'time_algorithm_update': 0.004184769334988605, 'loss': 0.032868608866862103, 'time_step': 0.004385049359401972, 'init_value': -1.8383742570877075, 'ave_value': -1.8393810563325699, 'soft_opc': nan} step=20194




2022-04-16 02:32.15 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_20194.pt


Epoch 24/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.19 [info     ] FQE_20220416023035: epoch=24 step=21072 epoch=24 metrics={'time_sample_batch': 0.00014618996335596593, 'time_algorithm_update': 0.004358206903201301, 'loss': 0.03188741395750761, 'time_step': 0.004569656485034013, 'init_value': -1.809640645980835, 'ave_value': -1.810712363120897, 'soft_opc': nan} step=21072




2022-04-16 02:32.19 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_21072.pt


Epoch 25/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.24 [info     ] FQE_20220416023035: epoch=25 step=21950 epoch=25 metrics={'time_sample_batch': 0.00014463834176030953, 'time_algorithm_update': 0.004326202060204161, 'loss': 0.03047883990812125, 'time_step': 0.004530911565098513, 'init_value': -1.7431156635284424, 'ave_value': -1.7441894749333289, 'soft_opc': nan} step=21950




2022-04-16 02:32.24 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_21950.pt


Epoch 26/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.29 [info     ] FQE_20220416023035: epoch=26 step=22828 epoch=26 metrics={'time_sample_batch': 0.00014882695973598333, 'time_algorithm_update': 0.004618687347290456, 'loss': 0.0275054973336132, 'time_step': 0.004828963029900553, 'init_value': -1.6903009414672852, 'ave_value': -1.6911078343957524, 'soft_opc': nan} step=22828




2022-04-16 02:32.29 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_22828.pt


Epoch 27/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.34 [info     ] FQE_20220416023035: epoch=27 step=23706 epoch=27 metrics={'time_sample_batch': 0.00014512522617070713, 'time_algorithm_update': 0.00448395152298356, 'loss': 0.028326731827074417, 'time_step': 0.0046899834635045915, 'init_value': -1.7107409238815308, 'ave_value': -1.711418800133264, 'soft_opc': nan} step=23706




2022-04-16 02:32.34 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_23706.pt


Epoch 28/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.39 [info     ] FQE_20220416023035: epoch=28 step=24584 epoch=28 metrics={'time_sample_batch': 0.00014933882650318883, 'time_algorithm_update': 0.004580956656731886, 'loss': 0.030346161860706905, 'time_step': 0.004792011951802803, 'init_value': -1.7637284994125366, 'ave_value': -1.7644309711968191, 'soft_opc': nan} step=24584




2022-04-16 02:32.39 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_24584.pt


Epoch 29/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.44 [info     ] FQE_20220416023035: epoch=29 step=25462 epoch=29 metrics={'time_sample_batch': 0.00013869145465060084, 'time_algorithm_update': 0.004145365641165974, 'loss': 0.030744309375907365, 'time_step': 0.004343791811775781, 'init_value': -1.7662991285324097, 'ave_value': -1.7669989213987976, 'soft_opc': nan} step=25462




2022-04-16 02:32.44 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_25462.pt


Epoch 30/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.49 [info     ] FQE_20220416023035: epoch=30 step=26340 epoch=30 metrics={'time_sample_batch': 0.00015354428041497234, 'time_algorithm_update': 0.004875266470507228, 'loss': 0.02957598905034114, 'time_step': 0.005093423000500794, 'init_value': -1.7704755067825317, 'ave_value': -1.7712099605395633, 'soft_opc': nan} step=26340




2022-04-16 02:32.49 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_26340.pt


Epoch 31/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.54 [info     ] FQE_20220416023035: epoch=31 step=27218 epoch=31 metrics={'time_sample_batch': 0.00014105934760022, 'time_algorithm_update': 0.00428303770703987, 'loss': 0.029563225869690277, 'time_step': 0.004484878585659018, 'init_value': -1.7354005575180054, 'ave_value': -1.7362601257036867, 'soft_opc': nan} step=27218




2022-04-16 02:32.54 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_27218.pt


Epoch 32/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:32.59 [info     ] FQE_20220416023035: epoch=32 step=28096 epoch=32 metrics={'time_sample_batch': 0.00014556649062519465, 'time_algorithm_update': 0.004472258150713319, 'loss': 0.028671632226459947, 'time_step': 0.004681357490174591, 'init_value': -1.7392584085464478, 'ave_value': -1.739950036772536, 'soft_opc': nan} step=28096




2022-04-16 02:32.59 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_28096.pt


Epoch 33/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.03 [info     ] FQE_20220416023035: epoch=33 step=28974 epoch=33 metrics={'time_sample_batch': 0.0001400893804424173, 'time_algorithm_update': 0.004213993804601439, 'loss': 0.02879818527986191, 'time_step': 0.004412340140288403, 'init_value': -1.750124216079712, 'ave_value': -1.7508753112040985, 'soft_opc': nan} step=28974




2022-04-16 02:33.03 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_28974.pt


Epoch 34/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.08 [info     ] FQE_20220416023035: epoch=34 step=29852 epoch=34 metrics={'time_sample_batch': 0.0001476870039329442, 'time_algorithm_update': 0.004615496122755603, 'loss': 0.029663681985756384, 'time_step': 0.004829249783909131, 'init_value': -1.7925469875335693, 'ave_value': -1.7931977181664227, 'soft_opc': nan} step=29852




2022-04-16 02:33.08 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_29852.pt


Epoch 35/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.14 [info     ] FQE_20220416023035: epoch=35 step=30730 epoch=35 metrics={'time_sample_batch': 0.00015107808732226118, 'time_algorithm_update': 0.004729251655196274, 'loss': 0.03169065988115091, 'time_step': 0.004944545261409125, 'init_value': -1.8089278936386108, 'ave_value': -1.809744958812335, 'soft_opc': nan} step=30730




2022-04-16 02:33.14 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_30730.pt


Epoch 36/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.19 [info     ] FQE_20220416023035: epoch=36 step=31608 epoch=36 metrics={'time_sample_batch': 0.00014640584350446905, 'time_algorithm_update': 0.004500647612204584, 'loss': 0.03188671774939139, 'time_step': 0.004708938012090522, 'init_value': -1.8292759656906128, 'ave_value': -1.830096120219216, 'soft_opc': nan} step=31608




2022-04-16 02:33.19 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_31608.pt


Epoch 37/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.24 [info     ] FQE_20220416023035: epoch=37 step=32486 epoch=37 metrics={'time_sample_batch': 0.00013965409002977515, 'time_algorithm_update': 0.004309186902839122, 'loss': 0.033639269637101246, 'time_step': 0.004508757102462317, 'init_value': -1.911298155784607, 'ave_value': -1.91226389748077, 'soft_opc': nan} step=32486




2022-04-16 02:33.24 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_32486.pt


Epoch 38/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.28 [info     ] FQE_20220416023035: epoch=38 step=33364 epoch=38 metrics={'time_sample_batch': 0.0001458187581194832, 'time_algorithm_update': 0.004480876520717334, 'loss': 0.03275476183350846, 'time_step': 0.004687369548649885, 'init_value': -1.8340646028518677, 'ave_value': -1.8349596584880514, 'soft_opc': nan} step=33364




2022-04-16 02:33.28 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_33364.pt


Epoch 39/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.33 [info     ] FQE_20220416023035: epoch=39 step=34242 epoch=39 metrics={'time_sample_batch': 0.0001397765578876052, 'time_algorithm_update': 0.004271186294208084, 'loss': 0.03299221052990315, 'time_step': 0.0044703258197236985, 'init_value': -1.8769742250442505, 'ave_value': -1.8779120190601373, 'soft_opc': nan} step=34242




2022-04-16 02:33.33 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_34242.pt


Epoch 40/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.38 [info     ] FQE_20220416023035: epoch=40 step=35120 epoch=40 metrics={'time_sample_batch': 0.00014861841136610862, 'time_algorithm_update': 0.004568259373884266, 'loss': 0.033787317322510585, 'time_step': 0.00478131298595246, 'init_value': -1.8500256538391113, 'ave_value': -1.8510389510944938, 'soft_opc': nan} step=35120




2022-04-16 02:33.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_35120.pt


Epoch 41/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.43 [info     ] FQE_20220416023035: epoch=41 step=35998 epoch=41 metrics={'time_sample_batch': 0.00014751402826678508, 'time_algorithm_update': 0.00457637836829948, 'loss': 0.032868853872077455, 'time_step': 0.004788737633776827, 'init_value': -1.838673710823059, 'ave_value': -1.8396480957839823, 'soft_opc': nan} step=35998




2022-04-16 02:33.43 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_35998.pt


Epoch 42/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.48 [info     ] FQE_20220416023035: epoch=42 step=36876 epoch=42 metrics={'time_sample_batch': 0.00013275841645607917, 'time_algorithm_update': 0.0039938780604299486, 'loss': 0.032860633971842725, 'time_step': 0.0041831970757938465, 'init_value': -1.8601897954940796, 'ave_value': -1.8611401822810296, 'soft_opc': nan} step=36876




2022-04-16 02:33.48 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_36876.pt


Epoch 43/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.52 [info     ] FQE_20220416023035: epoch=43 step=37754 epoch=43 metrics={'time_sample_batch': 0.0001437634161773195, 'time_algorithm_update': 0.00453113749249921, 'loss': 0.03276707478193376, 'time_step': 0.0047353419193102726, 'init_value': -1.8805840015411377, 'ave_value': -1.8815401830648164, 'soft_opc': nan} step=37754




2022-04-16 02:33.52 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_37754.pt


Epoch 44/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:33.58 [info     ] FQE_20220416023035: epoch=44 step=38632 epoch=44 metrics={'time_sample_batch': 0.00015553988193868232, 'time_algorithm_update': 0.004855139108888109, 'loss': 0.031820492851487796, 'time_step': 0.005076848836042886, 'init_value': -1.8011033535003662, 'ave_value': -1.8020631308334647, 'soft_opc': nan} step=38632




2022-04-16 02:33.58 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_38632.pt


Epoch 45/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.03 [info     ] FQE_20220416023035: epoch=45 step=39510 epoch=45 metrics={'time_sample_batch': 0.0001458649211701065, 'time_algorithm_update': 0.004554569313900889, 'loss': 0.029658214711368803, 'time_step': 0.004763913860625178, 'init_value': -1.7536741495132446, 'ave_value': -1.7545088305067054, 'soft_opc': nan} step=39510




2022-04-16 02:34.03 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_39510.pt


Epoch 46/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.08 [info     ] FQE_20220416023035: epoch=46 step=40388 epoch=46 metrics={'time_sample_batch': 0.00014555671492035677, 'time_algorithm_update': 0.004380924826602458, 'loss': 0.029396214036720113, 'time_step': 0.004588917067490841, 'init_value': -1.767454743385315, 'ave_value': -1.7680199972479629, 'soft_opc': nan} step=40388




2022-04-16 02:34.08 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_40388.pt


Epoch 47/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.13 [info     ] FQE_20220416023035: epoch=47 step=41266 epoch=47 metrics={'time_sample_batch': 0.0001473166133385313, 'time_algorithm_update': 0.004525725825229376, 'loss': 0.02933070735589395, 'time_step': 0.004734723605979276, 'init_value': -1.7582861185073853, 'ave_value': -1.7588222786023489, 'soft_opc': nan} step=41266




2022-04-16 02:34.13 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_41266.pt


Epoch 48/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.17 [info     ] FQE_20220416023035: epoch=48 step=42144 epoch=48 metrics={'time_sample_batch': 0.0001388313015392538, 'time_algorithm_update': 0.004317575272232091, 'loss': 0.0289198401462095, 'time_step': 0.00451718756169556, 'init_value': -1.7919765710830688, 'ave_value': -1.7926710625999824, 'soft_opc': nan} step=42144




2022-04-16 02:34.17 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_42144.pt


Epoch 49/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.22 [info     ] FQE_20220416023035: epoch=49 step=43022 epoch=49 metrics={'time_sample_batch': 0.0001453994890008809, 'time_algorithm_update': 0.00469416692208049, 'loss': 0.030152450756191034, 'time_step': 0.004901665761421916, 'init_value': -1.7787705659866333, 'ave_value': -1.7795009903001273, 'soft_opc': nan} step=43022




2022-04-16 02:34.22 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_43022.pt


Epoch 50/50:   0%|          | 0/878 [00:00<?, ?it/s]



2022-04-16 02:34.28 [info     ] FQE_20220416023035: epoch=50 step=43900 epoch=50 metrics={'time_sample_batch': 0.0001511989258959516, 'time_algorithm_update': 0.00497674344612417, 'loss': 0.029767772812324636, 'time_step': 0.005192360193691384, 'init_value': -1.7842823266983032, 'ave_value': -1.7847246024878274, 'soft_opc': nan} step=43900




2022-04-16 02:34.28 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220416023035/model_43900.pt


[(1,
  {'time_sample_batch': 0.00011914167425898593,
   'time_algorithm_update': 0.0029855746614634313,
   'loss': 0.0002461599849627247,
   'time_step': 0.003155057142427136,
   'init_value': -0.22843682765960693,
   'ave_value': -0.22831709708927517,
   'soft_opc': nan}),
 (2,
  {'time_sample_batch': 0.0001231980486719918,
   'time_algorithm_update': 0.0030236969233102297,
   'loss': 0.0014282974988773878,
   'time_step': 0.0031988797263838437,
   'init_value': -0.5504694581031799,
   'ave_value': -0.5503293672049991,
   'soft_opc': nan}),
 (3,
  {'time_sample_batch': 0.000123508698847951,
   'time_algorithm_update': 0.0031934004438222132,
   'loss': 0.004029400426500787,
   'time_step': 0.003368657922418894,
   'init_value': -0.787994921207428,
   'ave_value': -0.7878976750239928,
   'soft_opc': nan}),
 (4,
  {'time_sample_batch': 0.000124397473346128,
   'time_algorithm_update': 0.0031993448870057126,
   'loss': 0.007525220854168146,
   'time_step': 0.003378618822553978,
   'init_v

In [14]:
from d3rlpy.ope import FQE
# metrics to evaluate with
from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset, states_ope, actions_ope, rewards_ope = get_dataset([2,4,6,8], path="collected_data/rl_stoch_small.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })