# Sample Workflow for d3rlpy Experiments

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
# Apply the matplotlib style named 'matplotlibrc'
# (presumably a style file next to this notebook -- TODO confirm).
plt.style.use('matplotlibrc')

# NOTE(review): star import. The cells below use `DataSampler`, `random` and
# `torch` without importing them explicitly, so those names presumably come
# from this module -- TODO confirm and replace with explicit imports.
from Python.data_sampler import *

## Building an MDPDataset

We first read a large batch of samples from the data file. `d3rlpy` expects the data as (observations, actions, rewards, terminal flags), so we assemble it in that form. The helper function below builds a dataset from a list of chunks of your choosing.

In [2]:
def get_dataset(chunks : list, batch_size=30000,
                path="collected_data/rl_stochastic_coarse.txt",
                episode_length=1111) -> d3rlpy.dataset.MDPDataset :
    """Build a d3rlpy ``MDPDataset`` from the requested chunks of a sample file.

    For every entry of ``chunks`` the sampler is pointed at that chunk, the
    chunk is read, and one batch is drawn from it. The per-chunk batches are
    concatenated in the order given.

    Parameters
    ----------
    chunks : list
        Chunk identifiers, forwarded one at a time to ``DataSampler.use_chunk``.
    batch_size : int
        Forwarded to ``DataSampler.get_batch`` for every chunk.
    path : str
        Data file handed to ``DataSampler``. This argument used to be ignored
        (the stochastic-coarse file was hard-coded); the default now names that
        file, so calls that do not pass ``path`` keep reading the same data.
    episode_length : int
        Stride between terminal flags (previously the hard-coded 1111).

    Returns
    -------
    d3rlpy.dataset.MDPDataset
        Dataset built from the concatenated states, actions and rewards plus
        the generated terminal flags.

    Raises
    ------
    ValueError
        If ``chunks`` is empty or ``episode_length`` is smaller than 1.

    Notes
    -----
    ``random``, ``torch`` and ``DataSampler`` are not imported explicitly in
    this notebook; presumably they are provided by
    ``from Python.data_sampler import *`` -- TODO confirm.
    """
    # Fail fast, before any file is touched.
    if len(chunks) == 0:
        raise ValueError("chunks must contain at least one chunk index")
    if episode_length < 1:
        raise ValueError("episode_length must be a positive integer")

    # Fixed seed so repeated calls start from the same RNG state
    # (presumably the sampler draws from this RNG -- TODO confirm).
    random.seed(0)
    samples = DataSampler(path_to_data=path)
    # NOTE(review): the "coarse" setting stays hard-coded; confirm it is still
    # appropriate when a different `path` is supplied.
    samples.setting("coarse")

    states = []
    actions = []
    rewards = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        # The fourth element (next states) is not needed by MDPDataset, so it
        # is no longer accumulated and concatenated.
        statesChunk, actionsChunk, rewardsChunk, _ = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)

    # One terminal flag every `episode_length` rows.
    # NOTE(review): the slice starts at index 0, so the flags land on rows
    # 0, episode_length, 2*episode_length, ... Also, the recorded run of this
    # notebook produced 2,220,000 rows from 2,000 chunks (1,110 per chunk),
    # which does not line up with a stride of 1111 -- confirm the real episode
    # length and where episodes end before relying on the episode split.
    terminals = np.zeros(len(states))
    terminals[::episode_length] = 1
    print(states.shape)
    dataset = d3rlpy.dataset.MDPDataset(states.numpy(),
                                        actions.numpy(),
                                        rewards.numpy(), terminals)
    return dataset

With the helper in place, we build the dataset, look at the return statistics of the behavior policy, and split the data into train and test sets.

In [3]:
# Build the dataset from chunks 0 through 1999.
dataset = get_dataset(list(range(2000)))

start
[ 0.00000000e+00  7.95731469e+08  1.33108923e-02 -1.71999953e-02
 -2.33000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.14820287e-01 -3.71075138e-01 -4.58600996e-01]
Read chunk # 1 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.35989108e-01 -2.71999953e-02
 -2.16000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.42911681e-01  2.75094390e-01 -2.52970956e-01]
Read chunk # 2 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.83389108e-01 -4.83999953e-02
  6.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.54419292e-02  6.00000000e-01 -3.78270480e-02]
Read chunk # 3 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -7.09891077e-02  1.22000047e-02
  1.70999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.89277181e-01 -4.18905482e-01  1.70641772e-01]
Read chunk # 4 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.96989108e-01 -3.89999953e-02
  2.91999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.41872279

[ 0.00000000e+00  7.95731469e+08  3.91910892e-01 -2.13999953e-02
  1.43999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.67501944e-01  1.05838918e-01  6.00000000e-01]
Read chunk # 70 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  9.71089229e-03  3.48000047e-02
  2.89999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.01795012e-01 -2.41953066e-01  9.12517659e-02]
Read chunk # 71 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.61891077e-02  5.82000047e-02
  4.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.35818187e-02 -5.77263510e-01 -6.00000000e-01]
Read chunk # 72 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.50889108e-01  4.14000047e-02
 -2.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.67517925e-01  6.52870541e-02 -4.88590720e-01]
Read chunk # 73 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.87110892e-01 -4.03999953e-02
 -2.96000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-

[ 0.00000000e+00  7.95731469e+08 -1.86789108e-01 -3.51999953e-02
 -1.60000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.72017270e-01 -3.18717408e-01  6.00000000e-01]
Read chunk # 136 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.78889108e-01 -1.91999953e-02
 -4.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.18447329e-02  3.40893121e-01  1.82186680e-01]
Read chunk # 137 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.93410892e-01 -4.01999953e-02
 -6.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.00932762e-01 -1.21514178e-02 -6.00000000e-01]
Read chunk # 138 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.99489108e-01  1.26000047e-02
  1.52999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.64036624e-02  6.00000000e-01  5.96353633e-01]
Read chunk # 139 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.99489108e-01 -2.57999953e-02
 -2.27000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.674027

start
[ 0.00000000e+00  7.95731469e+08 -4.45389108e-01  1.16000047e-02
 -1.28000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 209 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.89610892e-01  3.14000047e-02
 -3.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.86888161e-01  3.10754398e-01 -6.00000000e-01]
Read chunk # 210 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.86089108e-01 -1.61999953e-02
  1.34999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.68432610e-03 -6.00000000e-01  6.00000000e-01]
Read chunk # 211 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.77510892e-01 -5.39999531e-03
  2.38999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.35632829e-01  6.00000000e-01  1.65353798e-01]
Read chunk # 212 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  5.86108923e-02  4.26000047e-02
  1.16999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.

[ 0.00000000e+00  7.95731469e+08  3.60110892e-01 -2.97999953e-02
 -1.43000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  6.25000554e-02]
Read chunk # 279 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.91289108e-01 -1.17999953e-02
 -1.31000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.24242184e-01  4.86107166e-01  2.67685047e-01]
Read chunk # 280 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.14689108e-01 -3.59999953e-02
  7.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.96737366e-02 -6.00000000e-01 -5.22606707e-01]
Read chunk # 281 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.48210892e-01  2.82000047e-02
  1.04999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.45877711e-01  6.00000000e-01  6.00000000e-01]
Read chunk # 282 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.20089108e-01 -1.99999953e-02
  1.88999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

start
[ 0.00000000e+00  7.95731469e+08 -2.06589108e-01 -2.39999531e-03
  2.61999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.46578804e-01  2.42838501e-02 -7.83211412e-03]
Read chunk # 349 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.60889108e-01  2.94000047e-02
  3.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.23398608e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 350 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.55010892e-01 -7.39999531e-03
  1.27999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  3.22355715e-01 -5.96000568e-01]
Read chunk # 351 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.00089108e-01  2.30000047e-02
  1.64999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.60881413e-01 -3.04025798e-02  6.00000000e-01]
Read chunk # 352 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.43110892e-01  4.50000047e-02
 -1.76000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.

[ 0.00000000e+00  7.95731469e+08 -3.31689108e-01 -1.79999531e-03
 -2.86000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.82941132e-01  3.71736835e-01  3.00452261e-01]
Read chunk # 418 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.06310892e-01  2.82000047e-02
 -1.97000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.56901514e-01  2.96098632e-02 -4.82362997e-01]
Read chunk # 419 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.11689108e-01 -1.41999953e-02
 -2.49000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 420 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.49610892e-01 -1.55999953e-02
  6.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.06803286e-01 -4.92400926e-01  3.30339187e-01]
Read chunk # 421 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.29289108e-01 -1.39999953e-02
  1.52999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

[ 0.00000000e+00  7.95731469e+08  2.92310892e-01  8.00000469e-03
  1.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  2.46058643e-01 -1.97253380e-01]
Read chunk # 486 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.37810892e-01  5.52000047e-02
  1.46999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.18954834e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 487 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.22589108e-01  1.84000047e-02
 -2.84000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -9.07977974e-02 -4.48442848e-01  2.20792176e-01]
Read chunk # 488 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.99108923e-02 -1.13999953e-02
  3.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01 -5.59279803e-01]
Read chunk # 489 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.47710892e-01  4.32000047e-02
 -2.37000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

[ 0.00000000e+00  7.95731469e+08  3.88610892e-01 -5.03999953e-02
  5.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.41859154e-01  3.03681151e-01  6.00000000e-01]
Read chunk # 553 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.82310892e-01 -5.81999953e-02
  1.24999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.81726270e-02 -3.87057360e-02  6.00000000e-01]
Read chunk # 554 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.41289108e-01 -5.59999953e-02
 -2.06000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -2.17932179e-01 -4.76872310e-01]
Read chunk # 555 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.41610892e-01 -7.99999531e-03
  1.52999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.41256473e-01  4.31354846e-01 -2.69003000e-01]
Read chunk # 556 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.65489108e-01 -3.47999953e-02
  2.75999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

[ 0.00000000e+00  7.95731469e+08  2.41610892e-01 -1.07999953e-02
  1.28999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  5.09594218e-02 -6.00000000e-01]
Read chunk # 630 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.77189108e-01  5.08000047e-02
 -2.57000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -3.02462570e-01 -6.00000000e-01]
Read chunk # 631 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.96810892e-01  5.88000047e-02
  8.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  3.50840564e-01  5.21832352e-01]
Read chunk # 632 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.50689108e-01  2.74000047e-02
 -1.33000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.26132636e-01  1.82074005e-01 -6.00000000e-01]
Read chunk # 633 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.95010892e-01  1.04000047e-02
  1.13999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

[ 0.00000000e+00  7.95731469e+08  3.46310892e-01 -3.27999953e-02
 -2.06000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.00100224e-01 -6.00000000e-01 -2.15369820e-01]
Read chunk # 694 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.66108923e-02 -1.19999953e-02
  3.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.84317688e-02  6.00000000e-01 -1.82522903e-02]
Read chunk # 695 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.90989108e-01  3.80000469e-03
  1.93999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  2.01908833e-01]
Read chunk # 696 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.33010892e-01 -1.49999953e-02
  2.30999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  3.98559068e-01]
Read chunk # 697 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.26989108e-01  2.20000469e-03
  3.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.210837

start
[ 0.00000000e+00  7.95731469e+08 -3.20889108e-01 -5.65999953e-02
  2.03999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.58777022e-01 -2.39496919e-01  6.00000000e-01]
Read chunk # 757 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.20710892e-01  2.20000469e-03
 -2.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.08278421e-01  4.12262819e-01  2.89961868e-02]
Read chunk # 758 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.06410892e-01 -8.39999531e-03
  1.15999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.65425179e-02 -6.00000000e-01]
Read chunk # 759 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  9.71089229e-03 -3.21999953e-02
 -1.00000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.38686736e-01 -2.80457395e-01 -5.43679120e-01]
Read chunk # 760 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.86989108e-01  1.46000047e-02
  2.32999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.

start
[ 0.00000000e+00  7.95731469e+08  3.89210892e-01 -3.77999953e-02
  9.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.27726942e-01  3.83714933e-01  2.79151253e-01]
Read chunk # 820 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -8.50891077e-02  2.90000047e-02
  4.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.03992697e-02 -6.00000000e-01  4.85523747e-01]
Read chunk # 821 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.63389108e-01 -5.55999953e-02
  5.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -3.15529542e-01  6.00000000e-01]
Read chunk # 822 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.26789108e-01  2.08000047e-02
  1.71999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.67389835e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 823 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.10389108e-01 -1.45999953e-02
  1.22999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.

start
[ 0.00000000e+00  7.95731469e+08  2.47310892e-01  4.58000047e-02
  2.21999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  4.29737858e-01  1.70441730e-01]
Read chunk # 888 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.96089108e-01 -3.41999953e-02
  3.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.75940698e-01  9.05217478e-02 -6.00000000e-01]
Read chunk # 889 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.86289108e-01  5.26000047e-02
  2.05999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  6.00000000e-01]
Read chunk # 890 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.11710892e-01 -2.43999953e-02
 -3.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.65685865e-01  3.83429099e-01 -1.97159321e-01]
Read chunk # 891 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.76389108e-01 -3.79999953e-02
 -1.74000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.

[ 0.00000000e+00  7.95731469e+08  2.56910892e-01  9.00000469e-03
 -1.43000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.78866071e-01  5.24619546e-01 -4.03786102e-01]
Read chunk # 929 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.01510892e-01 -5.97999953e-02
 -2.30000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  1.74641683e-01 -1.95879908e-01]
Read chunk # 930 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.89610892e-01  2.74000047e-02
  1.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 931 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.36189108e-01 -1.39999531e-03
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.15611131e-01  4.89060160e-01  2.37447268e-01]
Read chunk # 932 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.46108923e-02 -4.23999953e-02
  1.18999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.165962

[ 0.00000000e+00  7.95731469e+08 -1.81989108e-01  1.36000047e-02
 -2.81000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.22623553e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1003 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -8.29891077e-02 -4.17999953e-02
  7.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.54405633e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1004 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.22810892e-01 -5.85999953e-02
 -1.25000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -5.47797334e-01  3.96024423e-01]
Read chunk # 1005 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.94589108e-01 -3.87999953e-02
  2.96999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.65767397e-01  1.10764138e-01  6.00000000e-01]
Read chunk # 1006 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.90589108e-01  5.02000047e-02
 -2.43000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.39

[ 0.00000000e+00  7.95731469e+08 -1.11489108e-01  4.69199999e-09
  1.81999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  1.00562255e-01  3.66352336e-01]
Read chunk # 1041 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.04689108e-01  2.16000047e-02
 -9.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  5.76939098e-01]
Read chunk # 1042 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -9.04891077e-02  3.14000047e-02
 -1.90000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -5.88415302e-01  6.00000000e-01]
Read chunk # 1043 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.31189108e-01  2.44000047e-02
  2.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.82156101e-01  6.22289409e-02  1.17524946e-03]
Read chunk # 1044 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.05010892e-01  2.78000047e-02
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00

[ 0.00000000e+00  7.95731469e+08  3.10910892e-01  1.60000047e-02
 -2.56000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  5.99961357e-01  6.00000000e-01]
Read chunk # 1113 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.28810892e-01 -4.37999953e-02
 -2.24000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1114 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.81810892e-01 -9.59999531e-03
  1.50999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.87452599e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1115 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.23110892e-01  5.92000047e-02
 -2.17000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.51221713e-01  3.45954210e-01]
Read chunk # 1116 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.08389108e-01 -4.13999953e-02
  1.02999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.20

start
[ 0.00000000e+00  7.95731469e+08  3.61010892e-01  2.76000047e-02
  8.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -9.29191756e-03 -5.95798207e-01]
Read chunk # 1179 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.90108923e-02 -4.95999953e-02
 -6.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.36971309e-01 -3.53038401e-01 -3.80783274e-01]
Read chunk # 1180 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  5.83108923e-02 -2.59999531e-03
 -1.40000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.78323061e-01 -5.28182476e-01  2.22294664e-01]
Read chunk # 1181 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.47989108e-01 -4.77999953e-02
 -2.13000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  1.17734505e-01 -5.75481862e-01]
Read chunk # 1182 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.42889108e-01 -3.95999953e-02
  2.82999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08  4.23110892e-01  3.76000047e-02
 -2.39000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.83501989e-01 -5.66218229e-01 -6.00000000e-01]
Read chunk # 1254 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  6.19108923e-02  4.66000047e-02
 -4.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  5.99029570e-01]
Read chunk # 1255 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.72910892e-01 -4.39999531e-03
  2.83999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  6.00000000e-01]
Read chunk # 1256 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.50389108e-01  3.70000047e-02
  2.96999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.07823061e-01  3.95542407e-01  1.99808921e-02]
Read chunk # 1257 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -5.71891077e-02 -1.05999953e-02
  3.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00

start
[ 0.00000000e+00  7.95731469e+08  2.29108923e-02 -4.49999953e-02
  2.40999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -1.30160881e-01  5.89425038e-01]
Read chunk # 1294 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.52989108e-01 -1.23999953e-02
 -7.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.64192424e-01  5.34820497e-01  6.00000000e-01]
Read chunk # 1295 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.17789108e-01 -5.39999531e-03
 -1.24000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01 -6.00000000e-01]
Read chunk # 1296 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -5.88910771e-03  2.82000047e-02
  2.46999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01  2.30206169e-01]
Read chunk # 1297 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.71510892e-01 -1.37999953e-02
  4.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08  1.21108923e-02 -5.39999953e-02
  3.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -7.15996498e-02 -5.04238742e-01]
Read chunk # 1361 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.09689108e-01 -9.59999531e-03
 -2.88000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.50021126e-01  5.87254850e-01  6.00000000e-01]
Read chunk # 1362 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -5.86891077e-02 -4.19999531e-03
  1.49999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.25042152e-01  2.34098181e-01 -8.02202332e-02]
Read chunk # 1363 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -7.08910771e-03  5.56000047e-02
 -2.98000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.95400961e-01  6.03934047e-02 -4.77218793e-01]
Read chunk # 1364 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.41210892e-01  4.80000469e-03
 -2.93000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.89

[ 0.00000000e+00  7.95731469e+08 -1.71189108e-01 -1.65999953e-02
  2.84999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.26507942e-02 -6.00000000e-01 -6.00000000e-01]
Read chunk # 1399 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.33108923e-02 -4.41999953e-02
 -1.89000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.18721340e-01  6.00000000e-01 -5.00722720e-01]
Read chunk # 1400 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.29110892e-01 -5.03999953e-02
 -1.95000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.46901709e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1401 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.07510892e-01  2.40000047e-02
 -1.75000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  1.18096718e-02  6.00000000e-01]
Read chunk # 1402 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.32789108e-01 -2.39999531e-03
 -3.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.21

[ 0.00000000e+00  7.95731469e+08  1.97210892e-01 -2.69999953e-02
  1.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.42729685e-02 -3.14636604e-01]
Read chunk # 1437 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.74310892e-01  4.78000047e-02
  5.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.35825996e-01 -3.23539851e-01  6.00000000e-01]
Read chunk # 1438 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.40289108e-01  3.60000469e-03
 -2.09000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  1.14544617e-01]
Read chunk # 1439 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.93210892e-01  4.74000047e-02
  1.50999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.10871077e-01  6.00000000e-01  3.22352822e-01]
Read chunk # 1440 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.66108923e-02 -1.71999953e-02
  2.43999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.56

[ 0.00000000e+00  7.95731469e+08  3.43610892e-01 -5.79999953e-02
 -5.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.01805903e-01  6.00000000e-01]
Read chunk # 1503 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.40889108e-01  2.12000047e-02
  2.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1504 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.14789108e-01 -4.65999953e-02
 -2.28000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  1.91325577e-01]
Read chunk # 1505 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.25710892e-01 -5.19999953e-02
 -2.27000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.74537816e-01  5.61299155e-01  4.05705190e-01]
Read chunk # 1506 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.45610892e-01 -3.75999953e-02
 -2.24000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.53

[ 0.00000000e+00  7.95731469e+08 -1.15689108e-01  4.48000047e-02
 -3.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  4.57295464e-02 -1.69460219e-01]
Read chunk # 1566 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.66910892e-01 -2.03999953e-02
 -4.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.81904629e-01 -3.02410134e-01 -6.00000000e-01]
Read chunk # 1567 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.85910892e-01 -4.67999953e-02
  8.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  3.47754623e-01  2.45804068e-01]
Read chunk # 1568 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.45108923e-02  1.34000047e-02
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.59992216e-01  6.00000000e-01  4.51847807e-01]
Read chunk # 1569 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.20610892e-01  1.80000469e-03
 -1.75000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.12

[ 0.00000000e+00  7.95731469e+08 -4.15089108e-01 -3.33999953e-02
 -1.34000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.82065309e-01 -6.00000000e-01 -3.87486904e-01]
Read chunk # 1604 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.99889108e-01 -3.65999953e-02
  1.48999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.74861459e-01 -1.62111065e-01 -1.65047469e-01]
Read chunk # 1605 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.54891077e-02  4.80000469e-03
  1.17999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.33078783e-02  4.84508249e-01  3.74543737e-02]
Read chunk # 1606 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.19589108e-01  4.46000047e-02
 -2.28000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.14427524e-01  1.32789194e-01  2.44731451e-01]
Read chunk # 1607 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.86891077e-02  1.76000047e-02
 -2.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.85

[ 0.00000000e+00  7.95731469e+08 -2.65689108e-01 -4.91999953e-02
  2.93999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01  6.00000000e-01]
Read chunk # 1667 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.66789108e-01  1.44000047e-02
 -2.45000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.01211239e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1668 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.00610892e-01  1.80000047e-02
  4.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.87275273e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1669 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.67489108e-01 -4.63999953e-02
 -1.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.31070251e-01 -6.00000000e-01]
Read chunk # 1670 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.12889108e-01  3.00000469e-03
 -5.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.10

[ 0.00000000e+00  7.95731469e+08  5.51089229e-03  3.00000469e-03
  7.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.63770125e-01 -1.95049547e-01  6.00000000e-01]
Read chunk # 1738 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.35210892e-01  6.20000469e-03
 -1.76000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.12514047e-01  4.48587912e-01  6.00000000e-01]
Read chunk # 1739 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.90089108e-01  3.38000047e-02
  1.15999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  1.37654753e-02]
Read chunk # 1740 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.66310892e-01  2.88000047e-02
 -8.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.50212857e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1741 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.94089108e-01 -2.79999953e-02
  3.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00

[ 0.00000000e+00  7.95731469e+08 -2.53089108e-01 -2.23999953e-02
 -2.65000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.45796232e-01  6.00000000e-01  5.22291465e-01]
Read chunk # 1808 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.34110892e-01  2.46000047e-02
 -2.43000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01  6.00000000e-01]
Read chunk # 1809 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  6.10108923e-02  3.64000047e-02
 -1.75000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  2.81122863e-01 -6.30991631e-02]
Read chunk # 1810 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.94089108e-01  4.52000047e-02
  1.13999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.49745358e-02  1.48420740e-01]
Read chunk # 1811 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.53710892e-01 -3.13999953e-02
  2.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00

Read chunk # 1881 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.83210892e-01  1.08000047e-02
 -2.84000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -3.98740866e-01 -1.01644060e-01]
Read chunk # 1882 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.14410892e-01 -3.43999953e-02
 -2.11000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.61322817e-01  2.57372467e-01  1.11872186e-01]
Read chunk # 1883 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.36210892e-01 -2.21999953e-02
 -2.67000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.25049241e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1884 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.82810892e-01 -2.45999953e-02
  1.37999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.95648565e-01 -4.47188535e-01 -4.78312842e-01]
Read chunk # 1885 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.71108923e-02  1.14000047e-02
 -1.65000134e-02  0.00000000e+00 -5

start
[ 0.00000000e+00  7.95731469e+08  3.02210892e-01  3.80000469e-03
  2.18999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  8.18392211e-02  3.14575602e-01]
Read chunk # 1942 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.61089229e-03 -1.79999953e-02
  1.61999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.16385804e-01  6.00000000e-01  5.39608749e-01]
Read chunk # 1943 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.06910892e-01 -4.77999953e-02
 -1.97000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.48319135e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1944 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.46889108e-01  1.04000047e-02
 -2.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1945 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  5.74108923e-02 -1.21999953e-02
 -1.94000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


torch.Size([2220000, 6])


In [4]:
# Show the 'return' entry of the dataset statistics as the cell output.
print("The behavior policy value statistics are:")
return_stats = dataset.compute_stats()['return']
return_stats

The behavior policy value statistics are:


{'mean': -266.96417,
 'std': 124.46084,
 'min': -676.4719,
 'max': 0.0,
 'histogram': (array([  1,   4,  17,  16,  37,  56,  64,  93, 136, 143, 154, 192, 219,
         184, 177, 163, 164, 150,  27,   2]),
  array([-676.4719  , -642.6483  , -608.8247  , -575.00116 , -541.17755 ,
         -507.35394 , -473.53033 , -439.70676 , -405.88315 , -372.05957 ,
         -338.23596 , -304.41235 , -270.58878 , -236.76517 , -202.94157 ,
         -169.11798 , -135.29439 , -101.47079 ,  -67.647194,  -33.823597,
            0.      ], dtype=float32))}

In [5]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the dataset for evaluation. A fixed random_state makes the
# split identical on every Restart & Run All, so the scorer values reported
# below are comparable between runs (seed 0 matches the seed used in
# get_dataset).
train_episodes, test_episodes = train_test_split(
    dataset,
    test_size=0.2,
    random_state=0,
)

## Setting up an Algorithm

In [6]:
from d3rlpy.algos import CQL
from d3rlpy.preprocessing import MinMaxActionScaler

# Min-max action scaling with explicit bounds of -0.6 and 0.6.
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)

# CQL configuration. 'mean' selects the plain mean Q-function; 'qr' (quantile
# regression) is an alternative, but you don't have to use it.
model = CQL(
    q_func_factory='mean',
    reward_scaler='standard',
    action_scaler=action_scaler,
    actor_learning_rate=1e-5,
    critic_learning_rate=3e-4,
    use_gpu=False,  # change it to True if you have one
)

# Build the model against the dataset so it can be scored before fit() runs.
model.build_with_dataset(dataset)

In [7]:
from d3rlpy.metrics.scorer import (
    td_error_scorer,
    average_value_estimation_scorer,
    initial_state_value_estimation_scorer,
)

# Baseline before any training: score the freshly built model on the held-out
# episodes with the average value estimation scorer, then print the result.
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

0.0751052818464132


In [8]:
# Load the TensorBoard notebook extension, then launch TensorBoard on the
# `runs` directory -- the same directory passed as `tensorboard_dir` to
# model.fit in this notebook. Keep comments on their own lines: anything that
# follows a line magic on the same line is passed to it as an argument.
%load_ext tensorboard
%tensorboard --logdir runs

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
Traceback (most recent call last):
  File "/home/dasc/anaconda3/envs/jbreeden3.10/bin/tensorboard", line 6, in <module>
    from tensorboard.main import run_main
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/main.py", line 40, in <module>
    from tensorboard import default
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/default.py", line 38, in <module>
    from tensorboard.plugins.audio import audio_plugin
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugins/audio/audio_plugin.py", line 25, in <module>
    from tensorboard import plugin_util
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugin_util.py", line 21, in <module>
    from tensorboard._vendor import bleach
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorb

In [9]:
# Scorers computed on the held-out episodes; their values appear under these
# keys in each epoch's logged metrics.
evaluation_scorers = {
    'td_error': td_error_scorer,
    'init_value': initial_state_value_estimation_scorer,
    'ave_value': average_value_estimation_scorer,
}

# Train for 40 epochs, logging to the `runs` TensorBoard directory. The call
# is left as the last expression, so its return value (the per-epoch metrics)
# is shown as the cell output.
model.fit(
    train_episodes,
    eval_episodes=test_episodes,
    n_epochs=40,
    tensorboard_dir='runs',
    scorers=evaluation_scorers,
)

2022-04-21 11:50.13 [debug    ] RoundIterator is selected.
2022-04-21 11:50.13 [info     ] Directory is created at d3rlpy_logs/CQL_20220421115013
2022-04-21 11:50.13 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-21 11:50.13 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-21 11:50.13 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220421115013/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conser

Epoch 1/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:00.24 [info     ] CQL_20220421115013: epoch=1 step=6928 epoch=1 metrics={'time_sample_batch': 0.00033248131600042963, 'time_algorithm_update': 0.08607760863271101, 'temp_loss': 3.2389753310668, 'temp': 0.7392859111086187, 'alpha_loss': -13.069082023227462, 'alpha': 1.4248492442604852, 'critic_loss': 22.17708755668261, 'actor_loss': 40.39207670698833, 'time_step': 0.08674688960746034, 'td_error': 16.560567173508446, 'init_value': -82.84752655029297, 'ave_value': -83.19478249575641} step=6928
2022-04-21 12:00.24 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_6928.pt


Epoch 2/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:11.49 [info     ] CQL_20220421115013: epoch=2 step=13856 epoch=2 metrics={'time_sample_batch': 0.00035565465206912705, 'time_algorithm_update': 0.09660828736445094, 'temp_loss': 0.7820279873513959, 'temp': 0.42524828426281797, 'alpha_loss': -10.031452970154472, 'alpha': 2.767355453334835, 'critic_loss': 54.5907691992458, 'actor_loss': 122.31750482554799, 'time_step': 0.09735213743208737, 'td_error': 44.08262380888173, 'init_value': -162.3415985107422, 'ave_value': -162.9516385855288} step=13856
2022-04-21 12:11.49 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_13856.pt


Epoch 3/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:22.33 [info     ] CQL_20220421115013: epoch=3 step=20784 epoch=3 metrics={'time_sample_batch': 0.00034207487629412504, 'time_algorithm_update': 0.09070611182454001, 'temp_loss': -0.031370135485415084, 'temp': 0.3827949507768518, 'alpha_loss': -1.6990500891197267, 'alpha': 4.318377103389824, 'critic_loss': 109.7870766075186, 'actor_loss': 201.58327104643365, 'time_step': 0.09143093094126457, 'td_error': 73.65508703299268, 'init_value': -237.88648986816406, 'ave_value': -238.62776678081892} step=20784
2022-04-21 12:22.33 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_20784.pt


Epoch 4/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:34.03 [info     ] CQL_20220421115013: epoch=4 step=27712 epoch=4 metrics={'time_sample_batch': 0.00035602556365191523, 'time_algorithm_update': 0.09727242227652332, 'temp_loss': 0.007291995188477825, 'temp': 0.40115584704517354, 'alpha_loss': 1.4236882650341804, 'alpha': 3.870152088084342, 'critic_loss': 174.9931096444229, 'actor_loss': 269.88629097971574, 'time_step': 0.09802449998761985, 'td_error': 102.32785023277717, 'init_value': -294.13165283203125, 'ave_value': -295.0563412412695} step=27712
2022-04-21 12:34.03 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_27712.pt


Epoch 5/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:45.01 [info     ] CQL_20220421115013: epoch=5 step=34640 epoch=5 metrics={'time_sample_batch': 0.0003524545158580049, 'time_algorithm_update': 0.09285861453338252, 'temp_loss': 0.006121914672536784, 'temp': 0.38508266993899404, 'alpha_loss': 0.47146528850017305, 'alpha': 3.302177274543749, 'critic_loss': 226.41000870014724, 'actor_loss': 312.78629176115606, 'time_step': 0.0936035912045155, 'td_error': 125.98956129866822, 'init_value': -324.8154602050781, 'ave_value': -325.92431236631376} step=34640
2022-04-21 12:45.01 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_34640.pt


Epoch 6/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 12:55.52 [info     ] CQL_20220421115013: epoch=6 step=41568 epoch=6 metrics={'time_sample_batch': 0.0003440085858726061, 'time_algorithm_update': 0.09165968658191778, 'temp_loss': 0.001111056212601725, 'temp': 0.3771757337955143, 'alpha_loss': 0.2012662655432428, 'alpha': 3.1290670004232375, 'critic_loss': 261.56188339311586, 'actor_loss': 338.74686862690066, 'time_step': 0.09238530062225069, 'td_error': 138.4893357259216, 'init_value': -344.94696044921875, 'ave_value': -346.2934297870945} step=41568
2022-04-21 12:55.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_41568.pt


Epoch 7/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 13:07.05 [info     ] CQL_20220421115013: epoch=7 step=48496 epoch=7 metrics={'time_sample_batch': 0.0003498173200505986, 'time_algorithm_update': 0.09488137407060583, 'temp_loss': 0.0009336879437382809, 'temp': 0.37271618468828893, 'alpha_loss': 0.1075093630947577, 'alpha': 3.0520785240323107, 'critic_loss': 287.21312825096123, 'actor_loss': 356.10531039601386, 'time_step': 0.09562255528176905, 'td_error': 150.50713562510649, 'init_value': -360.363037109375, 'ave_value': -361.95446174411944} step=48496
2022-04-21 13:07.05 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_48496.pt


Epoch 8/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 13:17.54 [info     ] CQL_20220421115013: epoch=8 step=55424 epoch=8 metrics={'time_sample_batch': 0.00034577483805993414, 'time_algorithm_update': 0.09134256788131546, 'temp_loss': 0.0009413873349982109, 'temp': 0.37095305992088196, 'alpha_loss': 0.14969404571060269, 'alpha': 2.9683321133496854, 'critic_loss': 308.30287659849637, 'actor_loss': 369.25994686716973, 'time_step': 0.09207490275831201, 'td_error': 159.3897181846129, 'init_value': -371.3344421386719, 'ave_value': -373.25395297103745} step=55424
2022-04-21 13:17.54 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_55424.pt


Epoch 9/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 13:28.55 [info     ] CQL_20220421115013: epoch=9 step=62352 epoch=9 metrics={'time_sample_batch': 0.00034685935351920185, 'time_algorithm_update': 0.09322074854071091, 'temp_loss': 0.0015006012839263016, 'temp': 0.3692227552664266, 'alpha_loss': 0.10392393775057408, 'alpha': 2.896740374505933, 'critic_loss': 321.75856122801264, 'actor_loss': 377.914640748198, 'time_step': 0.09396069897790428, 'td_error': 165.12380180463168, 'init_value': -377.9372253417969, 'ave_value': -380.1420800606324} step=62352
2022-04-21 13:28.55 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_62352.pt


Epoch 10/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 13:40.35 [info     ] CQL_20220421115013: epoch=10 step=69280 epoch=10 metrics={'time_sample_batch': 0.00035731501309480733, 'time_algorithm_update': 0.09862417746636389, 'temp_loss': 0.0012681218558547622, 'temp': 0.3684202055500575, 'alpha_loss': 0.1515770615696167, 'alpha': 2.8177254600213804, 'critic_loss': 333.2292992143653, 'actor_loss': 385.1046184425266, 'time_step': 0.09938575045203502, 'td_error': 172.1211399006587, 'init_value': -384.555419921875, 'ave_value': -386.9004716346672} step=69280
2022-04-21 13:40.35 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_69280.pt


Epoch 11/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 13:52.10 [info     ] CQL_20220421115013: epoch=11 step=76208 epoch=11 metrics={'time_sample_batch': 0.0003585660567735139, 'time_algorithm_update': 0.09796939298132, 'temp_loss': 1.6746255657882978e-05, 'temp': 0.36547123341375903, 'alpha_loss': 0.03503560832092896, 'alpha': 2.7717734593166767, 'critic_loss': 342.99529306291157, 'actor_loss': 390.85300527929434, 'time_step': 0.09872808147385269, 'td_error': 175.1543666792613, 'init_value': -389.1753845214844, 'ave_value': -391.74348497641625} step=76208
2022-04-21 13:52.10 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_76208.pt


Epoch 12/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:03.28 [info     ] CQL_20220421115013: epoch=12 step=83136 epoch=12 metrics={'time_sample_batch': 0.00034692171126535253, 'time_algorithm_update': 0.09554103970114684, 'temp_loss': -0.0009712803984914775, 'temp': 0.36610473489375917, 'alpha_loss': -0.006250801446922763, 'alpha': 2.761431029845055, 'critic_loss': 348.80758486202905, 'actor_loss': 395.52045586918575, 'time_step': 0.09626424013191778, 'td_error': 177.2163282899627, 'init_value': -393.40667724609375, 'ave_value': -396.11357302536834} step=83136
2022-04-21 14:03.28 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_83136.pt


Epoch 13/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:14.16 [info     ] CQL_20220421115013: epoch=13 step=90064 epoch=13 metrics={'time_sample_batch': 0.00032177347233884607, 'time_algorithm_update': 0.09150119697111583, 'temp_loss': -0.0009721430900665839, 'temp': 0.36759291317821513, 'alpha_loss': 0.10278147696235192, 'alpha': 2.7262383962338297, 'critic_loss': 350.300065931492, 'actor_loss': 397.83247016483864, 'time_step': 0.0921649882617228, 'td_error': 176.49191251676035, 'init_value': -393.9203796386719, 'ave_value': -396.94941175000304} step=90064
2022-04-21 14:14.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_90064.pt


Epoch 14/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:25.09 [info     ] CQL_20220421115013: epoch=14 step=96992 epoch=14 metrics={'time_sample_batch': 0.0003249383344385971, 'time_algorithm_update': 0.09214452662726748, 'temp_loss': 0.0038507886726786137, 'temp': 0.3662289146282312, 'alpha_loss': -0.022483001650088843, 'alpha': 2.70475452997255, 'critic_loss': 348.53530670616425, 'actor_loss': 398.7760192747755, 'time_step': 0.0928169010899634, 'td_error': 174.83922750364198, 'init_value': -394.4867248535156, 'ave_value': -397.6879507944949} step=96992
2022-04-21 14:25.09 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_96992.pt


Epoch 15/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:35.58 [info     ] CQL_20220421115013: epoch=15 step=103920 epoch=15 metrics={'time_sample_batch': 0.00032353036398149803, 'time_algorithm_update': 0.09161030485641047, 'temp_loss': -0.002446396822202843, 'temp': 0.3649185357430378, 'alpha_loss': 0.0963416375972976, 'alpha': 2.684158604187723, 'critic_loss': 345.85602966976495, 'actor_loss': 398.6164425019579, 'time_step': 0.09227670863374, 'td_error': 173.41514014054846, 'init_value': -393.3376159667969, 'ave_value': -396.95405322141903} step=103920
2022-04-21 14:35.58 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_103920.pt


Epoch 16/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:46.51 [info     ] CQL_20220421115013: epoch=16 step=110848 epoch=16 metrics={'time_sample_batch': 0.00032688281552620903, 'time_algorithm_update': 0.09209509551800427, 'temp_loss': 0.0008673044141716975, 'temp': 0.36439287055963454, 'alpha_loss': 0.08642494658226715, 'alpha': 2.632425872165942, 'critic_loss': 342.2782390092179, 'actor_loss': 397.2269722848236, 'time_step': 0.092767263945476, 'td_error': 172.1287611125964, 'init_value': -392.37652587890625, 'ave_value': -396.310009416185} step=110848
2022-04-21 14:46.51 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_110848.pt


Epoch 17/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 14:57.55 [info     ] CQL_20220421115013: epoch=17 step=117776 epoch=17 metrics={'time_sample_batch': 0.0003239320070583881, 'time_algorithm_update': 0.09370522135675091, 'temp_loss': 0.0006810816209002422, 'temp': 0.3629982714685364, 'alpha_loss': 0.03625618569291115, 'alpha': 2.5694400906287496, 'critic_loss': 339.46451183857596, 'actor_loss': 396.2262317269834, 'time_step': 0.09437147667308877, 'td_error': 169.1923493781587, 'init_value': -386.9462585449219, 'ave_value': -391.41588389656135} step=117776
2022-04-21 14:57.55 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_117776.pt


Epoch 18/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 15:09.05 [info     ] CQL_20220421115013: epoch=18 step=124704 epoch=18 metrics={'time_sample_batch': 0.0003292814895132122, 'time_algorithm_update': 0.09451565308190933, 'temp_loss': 0.0006029767892148677, 'temp': 0.36350554710912236, 'alpha_loss': 0.054510936099135335, 'alpha': 2.5250505440573767, 'critic_loss': 333.70101388745167, 'actor_loss': 393.90564168353006, 'time_step': 0.09519228270367718, 'td_error': 166.65178552178062, 'init_value': -386.75897216796875, 'ave_value': -391.73727577085754} step=124704
2022-04-21 15:09.05 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_124704.pt


Epoch 19/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 15:19.55 [info     ] CQL_20220421115013: epoch=19 step=131632 epoch=19 metrics={'time_sample_batch': 0.00032722857064518035, 'time_algorithm_update': 0.09183631651533669, 'temp_loss': -0.0003157147547747944, 'temp': 0.3656944708412569, 'alpha_loss': 0.11574303269185976, 'alpha': 2.48160609984921, 'critic_loss': 327.86961024166806, 'actor_loss': 391.5019153744739, 'time_step': 0.09250937312772588, 'td_error': 164.24851129191117, 'init_value': -383.9179992675781, 'ave_value': -389.5521777916642} step=131632
2022-04-21 15:19.55 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_131632.pt


Epoch 20/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 15:30.50 [info     ] CQL_20220421115013: epoch=20 step=138560 epoch=20 metrics={'time_sample_batch': 0.0003276263249663778, 'time_algorithm_update': 0.09253235932172858, 'temp_loss': -0.000803443935323898, 'temp': 0.36405753765435883, 'alpha_loss': 0.07446995459914342, 'alpha': 2.4372971633290033, 'critic_loss': 324.1151957347404, 'actor_loss': 390.28327567913243, 'time_step': 0.09320770847604677, 'td_error': 161.1472516706148, 'init_value': -383.5263366699219, 'ave_value': -389.4972804302525} step=138560
2022-04-21 15:30.50 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_138560.pt


Epoch 21/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 15:41.37 [info     ] CQL_20220421115013: epoch=21 step=145488 epoch=21 metrics={'time_sample_batch': 0.00032325918349603033, 'time_algorithm_update': 0.09138222761434991, 'temp_loss': 0.0023087718405136537, 'temp': 0.36506183394693337, 'alpha_loss': 0.07145465938295928, 'alpha': 2.3780580226323345, 'critic_loss': 317.14554573805447, 'actor_loss': 388.71824369804966, 'time_step': 0.09205060460826266, 'td_error': 156.21058388132695, 'init_value': -379.9534912109375, 'ave_value': -386.17341501469656} step=145488
2022-04-21 15:41.37 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_145488.pt


Epoch 22/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 15:52.14 [info     ] CQL_20220421115013: epoch=22 step=152416 epoch=22 metrics={'time_sample_batch': 0.0003211216755882422, 'time_algorithm_update': 0.09031135914094454, 'temp_loss': -0.0006649349533843079, 'temp': 0.36311562696157784, 'alpha_loss': 0.06671501420333334, 'alpha': 2.343121531963624, 'critic_loss': 310.42988414590957, 'actor_loss': 386.7907238931634, 'time_step': 0.09097466388971095, 'td_error': 156.54176283794678, 'init_value': -380.8991394042969, 'ave_value': -387.2388916392971} step=152416
2022-04-21 15:52.14 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_152416.pt


Epoch 23/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:03.12 [info     ] CQL_20220421115013: epoch=23 step=159344 epoch=23 metrics={'time_sample_batch': 0.00032859042665424305, 'time_algorithm_update': 0.0929140043176074, 'temp_loss': -0.0006962519806288678, 'temp': 0.36287562139160223, 'alpha_loss': 0.07117197230760489, 'alpha': 2.3050500192694523, 'critic_loss': 307.8786829575793, 'actor_loss': 385.1747925562341, 'time_step': 0.09359419893034611, 'td_error': 152.89913267123512, 'init_value': -378.15374755859375, 'ave_value': -384.55256071310646} step=159344
2022-04-21 16:03.12 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_159344.pt


Epoch 24/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:14.00 [info     ] CQL_20220421115013: epoch=24 step=166272 epoch=24 metrics={'time_sample_batch': 0.0003260050923939941, 'time_algorithm_update': 0.0915383514507263, 'temp_loss': -0.00013338399147770774, 'temp': 0.36235369468327205, 'alpha_loss': 0.059480640033294806, 'alpha': 2.279450998742641, 'critic_loss': 305.45263398801757, 'actor_loss': 383.6212434614503, 'time_step': 0.09220948853200908, 'td_error': 153.7780558300682, 'init_value': -376.93389892578125, 'ave_value': -383.46135572522826} step=166272
2022-04-21 16:14.00 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_166272.pt


Epoch 25/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:24.45 [info     ] CQL_20220421115013: epoch=25 step=173200 epoch=25 metrics={'time_sample_batch': 0.00032344050963536145, 'time_algorithm_update': 0.0909941667910261, 'temp_loss': 0.0010177633871230798, 'temp': 0.3617817800552272, 'alpha_loss': -0.015428371302688449, 'alpha': 2.249002670546602, 'critic_loss': 305.4925143580376, 'actor_loss': 382.35815847277917, 'time_step': 0.09166323983779405, 'td_error': 151.9830163939512, 'init_value': -372.828125, 'ave_value': -379.50794296732033} step=173200
2022-04-21 16:24.45 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_173200.pt


Epoch 26/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:35.27 [info     ] CQL_20220421115013: epoch=26 step=180128 epoch=26 metrics={'time_sample_batch': 0.0003240840126680722, 'time_algorithm_update': 0.0907728812435758, 'temp_loss': 0.0012598659248829562, 'temp': 0.3618432446866883, 'alpha_loss': 0.05611779912422367, 'alpha': 2.232187041303027, 'critic_loss': 305.37339683809677, 'actor_loss': 381.0610024934553, 'time_step': 0.09144208100183455, 'td_error': 151.055012673217, 'init_value': -371.3172912597656, 'ave_value': -378.1840906400251} step=180128
2022-04-21 16:35.27 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_180128.pt


Epoch 27/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:46.06 [info     ] CQL_20220421115013: epoch=27 step=187056 epoch=27 metrics={'time_sample_batch': 0.00032276747959040055, 'time_algorithm_update': 0.09020295370946581, 'temp_loss': -0.000996722697547054, 'temp': 0.3611607913779607, 'alpha_loss': 0.02553116674774932, 'alpha': 2.199727045780792, 'critic_loss': 303.8417127645181, 'actor_loss': 379.27491507409076, 'time_step': 0.0908678944887245, 'td_error': 151.1714805092278, 'init_value': -372.2703857421875, 'ave_value': -379.3775645807627} step=187056
2022-04-21 16:46.06 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_187056.pt


Epoch 28/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 16:56.47 [info     ] CQL_20220421115013: epoch=28 step=193984 epoch=28 metrics={'time_sample_batch': 0.00032247547877578206, 'time_algorithm_update': 0.0905953502889043, 'temp_loss': -0.002661704557881138, 'temp': 0.3603977707083451, 'alpha_loss': 0.042370417688789214, 'alpha': 2.183184081907911, 'critic_loss': 303.5562005534458, 'actor_loss': 377.9231832649636, 'time_step': 0.09126210563727945, 'td_error': 150.64589724695094, 'init_value': -369.5162353515625, 'ave_value': -376.9608241176777} step=193984
2022-04-21 16:56.47 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_193984.pt


Epoch 29/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 17:07.22 [info     ] CQL_20220421115013: epoch=29 step=200912 epoch=29 metrics={'time_sample_batch': 0.00032239536352576066, 'time_algorithm_update': 0.08971680284372378, 'temp_loss': 0.0030905455615364476, 'temp': 0.36107396472925124, 'alpha_loss': 0.049012036066280594, 'alpha': 2.150564372401865, 'critic_loss': 303.9037359236432, 'actor_loss': 377.110741952275, 'time_step': 0.09038400519252099, 'td_error': 152.11518715758464, 'init_value': -367.2611083984375, 'ave_value': -375.005281074438} step=200912
2022-04-21 17:07.22 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_200912.pt


Epoch 30/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 17:18.11 [info     ] CQL_20220421115013: epoch=30 step=207840 epoch=30 metrics={'time_sample_batch': 0.0003254886794035209, 'time_algorithm_update': 0.09166120126946693, 'temp_loss': -0.0005623962813918425, 'temp': 0.3596801210934722, 'alpha_loss': -0.004097521271960137, 'alpha': 2.143575791433006, 'critic_loss': 303.5697760577152, 'actor_loss': 375.7562272113685, 'time_step': 0.09233160406419917, 'td_error': 151.80717670504805, 'init_value': -366.3521728515625, 'ave_value': -374.31117526765775} step=207840
2022-04-21 17:18.11 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_207840.pt


Epoch 31/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 17:28.45 [info     ] CQL_20220421115013: epoch=31 step=214768 epoch=31 metrics={'time_sample_batch': 0.0003232552259128033, 'time_algorithm_update': 0.0894827664793226, 'temp_loss': 0.00012451626421447078, 'temp': 0.36164807951116534, 'alpha_loss': 0.053641710409832885, 'alpha': 2.126187172518316, 'critic_loss': 301.85631993123894, 'actor_loss': 373.8286955307042, 'time_step': 0.09015002082732203, 'td_error': 149.9945924684128, 'init_value': -364.87469482421875, 'ave_value': -373.04969961284945} step=214768
2022-04-21 17:28.45 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_214768.pt


Epoch 32/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 17:39.25 [info     ] CQL_20220421115013: epoch=32 step=221696 epoch=32 metrics={'time_sample_batch': 0.0003223392002577044, 'time_algorithm_update': 0.09044677290988024, 'temp_loss': 0.0016147525276056884, 'temp': 0.3578099250991336, 'alpha_loss': 0.03625635007326538, 'alpha': 2.090304651037375, 'critic_loss': 300.0765135741812, 'actor_loss': 371.7944967212633, 'time_step': 0.09111578439041869, 'td_error': 148.13623125843387, 'init_value': -361.79571533203125, 'ave_value': -370.2364439860851} step=221696
2022-04-21 17:39.25 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_221696.pt


Epoch 33/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 17:50.03 [info     ] CQL_20220421115013: epoch=33 step=228624 epoch=33 metrics={'time_sample_batch': 0.0003229059605895785, 'time_algorithm_update': 0.09017169568472591, 'temp_loss': 0.00011560220318264597, 'temp': 0.3566744191653803, 'alpha_loss': 0.010538321456203615, 'alpha': 2.0762205004829735, 'critic_loss': 296.5471138163334, 'actor_loss': 368.3070477756562, 'time_step': 0.09084160075314326, 'td_error': 146.4533405497028, 'init_value': -356.5564880371094, 'ave_value': -365.1387829347559} step=228624
2022-04-21 17:50.03 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_228624.pt


Epoch 34/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:00.44 [info     ] CQL_20220421115013: epoch=34 step=235552 epoch=34 metrics={'time_sample_batch': 0.0003215058020576318, 'time_algorithm_update': 0.09044981925096578, 'temp_loss': -0.0010950191072285106, 'temp': 0.35856941925667174, 'alpha_loss': 0.04142675211807505, 'alpha': 2.071470718013497, 'critic_loss': 291.8196984599699, 'actor_loss': 364.3278335016264, 'time_step': 0.09111290650472355, 'td_error': 145.6282216482493, 'init_value': -352.9852600097656, 'ave_value': -361.8486650639182} step=235552
2022-04-21 18:00.44 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_235552.pt


Epoch 35/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:11.30 [info     ] CQL_20220421115013: epoch=35 step=242480 epoch=35 metrics={'time_sample_batch': 0.00032443444805938155, 'time_algorithm_update': 0.09131960811708596, 'temp_loss': -0.0015185632508430376, 'temp': 0.3589464257671293, 'alpha_loss': 0.014045644557186793, 'alpha': 2.0533931623567887, 'critic_loss': 286.35663724558185, 'actor_loss': 360.7822869897715, 'time_step': 0.09199312780112648, 'td_error': 141.00195396904383, 'init_value': -348.13836669921875, 'ave_value': -356.93436600269285} step=242480
2022-04-21 18:11.30 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_242480.pt


Epoch 36/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:22.16 [info     ] CQL_20220421115013: epoch=36 step=249408 epoch=36 metrics={'time_sample_batch': 0.0003249190971427371, 'time_algorithm_update': 0.0910941548559462, 'temp_loss': 0.00042742281694714774, 'temp': 0.35977009587058706, 'alpha_loss': -0.013534218509742756, 'alpha': 2.0500835082547493, 'critic_loss': 281.40552348146264, 'actor_loss': 356.75915294761745, 'time_step': 0.0917667726928037, 'td_error': 138.4451647975536, 'init_value': -343.50421142578125, 'ave_value': -352.4618402976861} step=249408
2022-04-21 18:22.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_249408.pt


Epoch 37/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:32.52 [info     ] CQL_20220421115013: epoch=37 step=256336 epoch=37 metrics={'time_sample_batch': 0.00032253676969514985, 'time_algorithm_update': 0.08986017923569845, 'temp_loss': -0.0010712851957951975, 'temp': 0.360286415616958, 'alpha_loss': -0.01475848322492601, 'alpha': 2.0511225345641306, 'critic_loss': 278.27052969139663, 'actor_loss': 354.3621027640327, 'time_step': 0.0905305280008162, 'td_error': 137.7702907378289, 'init_value': -343.20147705078125, 'ave_value': -352.1003512276787} step=256336
2022-04-21 18:32.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_256336.pt


Epoch 38/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:43.42 [info     ] CQL_20220421115013: epoch=38 step=263264 epoch=38 metrics={'time_sample_batch': 0.0003248480671272542, 'time_algorithm_update': 0.09184583395237736, 'temp_loss': 0.00012206523104177736, 'temp': 0.36190666509985786, 'alpha_loss': 0.0034238402221727455, 'alpha': 2.044888884723875, 'critic_loss': 275.5247867621533, 'actor_loss': 352.3196665343181, 'time_step': 0.0925158032024705, 'td_error': 137.03004360076602, 'init_value': -341.3427429199219, 'ave_value': -350.2393808206395} step=263264
2022-04-21 18:43.42 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_263264.pt


Epoch 39/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 18:54.21 [info     ] CQL_20220421115013: epoch=39 step=270192 epoch=39 metrics={'time_sample_batch': 0.0003210365359282108, 'time_algorithm_update': 0.09014628548820354, 'temp_loss': 0.0022400938514923873, 'temp': 0.36381051433433975, 'alpha_loss': 0.04880675139081679, 'alpha': 2.0301015755596117, 'critic_loss': 274.2025833403128, 'actor_loss': 350.8632940199854, 'time_step': 0.09081239499340982, 'td_error': 137.92932967987878, 'init_value': -341.0992126464844, 'ave_value': -349.7714659517542} step=270192
2022-04-21 18:54.21 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_270192.pt


Epoch 40/40:   0%|          | 0/6928 [00:00<?, ?it/s]

2022-04-21 19:05.06 [info     ] CQL_20220421115013: epoch=40 step=277120 epoch=40 metrics={'time_sample_batch': 0.0003227231546582581, 'time_algorithm_update': 0.09107090109490357, 'temp_loss': -0.0017881930465167431, 'temp': 0.3647903068675111, 'alpha_loss': 0.063628239421584, 'alpha': 2.020732377799912, 'critic_loss': 273.50766204769013, 'actor_loss': 350.067744420252, 'time_step': 0.09173897290752886, 'td_error': 134.35094199225546, 'init_value': -337.4706115722656, 'ave_value': -346.0951132238864} step=277120
2022-04-21 19:05.06 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220421115013/model_277120.pt


[(1,
  {'time_sample_batch': 0.00033248131600042963,
   'time_algorithm_update': 0.08607760863271101,
   'temp_loss': 3.2389753310668,
   'temp': 0.7392859111086187,
   'alpha_loss': -13.069082023227462,
   'alpha': 1.4248492442604852,
   'critic_loss': 22.17708755668261,
   'actor_loss': 40.39207670698833,
   'time_step': 0.08674688960746034,
   'td_error': 16.560567173508446,
   'init_value': -82.84752655029297,
   'ave_value': -83.19478249575641}),
 (2,
  {'time_sample_batch': 0.00035565465206912705,
   'time_algorithm_update': 0.09660828736445094,
   'temp_loss': 0.7820279873513959,
   'temp': 0.42524828426281797,
   'alpha_loss': -10.031452970154472,
   'alpha': 2.767355453334835,
   'critic_loss': 54.5907691992458,
   'actor_loss': 122.31750482554799,
   'time_step': 0.09735213743208737,
   'td_error': 44.08262380888173,
   'init_value': -162.3415985107422,
   'ave_value': -162.9516385855288}),
 (3,
  {'time_sample_batch': 0.00034207487629412504,
   'time_algorithm_update': 0.090

In [10]:
# Output files, written relative to the current working directory.
model_file = 'cqlStochC2000_Ep40model_CPUOnly.pt'
policy_file = 'cqlStochC2000_Ep40_CPUOnly.pt'

# Persist the trained model first, then the exported policy.
model.save_model(model_file)
model.save_policy(policy_file)

  minimum = torch.tensor(
  maximum = torch.tensor(


## Off-Policy Evaluation

During training we do get test-set metrics for the initial state value and the average value. However, these estimates of model performance come from the critic's own Q-function and are therefore biased. They're useful for validation during training, but not much else. Instead, we fit a separate Q-function to the data (or to a different dataset, as I've done here) with Fitted Q Evaluation (FQE), and use that Q-function to evaluate the model's performance.

Feel free to change the chunks and number of steps.

In [11]:
# NOTE: this cell is disabled -- every line below is commented out.
# If re-enabled, it builds a dataset from chunks 2000-2099 (outside the
# range(2000) chunks passed to get_dataset for training above), splits it
# 80/20, and fits FQE on `model` for 100 epochs of 10000 steps each.
# NOTE(review): `train_test_split`, `action_scaler`,
# `initial_state_value_estimation_scorer` and `average_value_estimation_scorer`
# are not imported in this cell -- confirm they are defined earlier in the
# notebook before uncommenting.
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i+2000 for i in range(100)]) #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [12]:
# NOTE: this cell is disabled -- every line below is commented out.
# It is the same FQE recipe as the previous cell, but built from the even
# chunks 0, 2, ..., 198 and passing an explicit `path`.
# NOTE(review): get_dataset, as defined at the top of this notebook, never
# uses its `path` argument (DataSampler is always constructed with
# "collected_data/rl_stochastic_coarse.txt"), so the `path=` below would have
# no effect and these chunks would be a subset of the training chunks
# range(2000). Fix get_dataset before relying on this cell.
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i*2 for i in range(100)], path="collected_data/rl_stochastic.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [13]:
# NOTE: this cell is disabled -- every line below is commented out.
# If re-enabled, it calls to_cpu(model) and then saves the policy and the
# model. The file names here ("cqlStochpid2000...") differ from the
# "cqlStochC2000..." names used by the save cell above.
# from d3rlpy.torch_utility import to_cpu
# to_cpu(model)
# model.save_policy("cqlStochpid2000Ep40CPU.pt")
# model.save_model("cqlStochpid2000Ep40modelCPU.pt")

In [14]:
# NOTE: this cell is disabled scratch work -- every line below is commented
# out. It holds several separate attempts to inspect `model` and move its
# torch modules to the CPU before saving.
# NOTE(review): if any of this is uncommented, `print(yes)` refers to an
# undefined name (presumably meant print("yes")), and `Any`, `Dict` and
# `torch` are not imported in this cell -- confirm they are in scope.
# for key in dir(model):
#     module = getattr(model, key)
#     if isinstance(module, (torch.nn.Module, torch.nn.Parameter)):
#         print(yes)
#         print(key)
# dir(model)
# type(model)
# model.cpu()
# from d3rlpy.algos.torch.base import TorchImplBase
# new_model = TorchImplBase()
# from d3rlpy.torch_utility import _get_attributes
# model._device = "cpu:0"
# print(model._device)


# Attempt: build a state dict by hand, calling .cpu() on each torch module.
# def my_get_state_dict(impl: Any) -> Dict[str, Any]:
#     rets = {}
#     for key in _get_attributes(impl):
#         obj = getattr(impl, key)
#         if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)):
#             if isinstance(obj, (torch.nn.Module, torch.nn.Parameter)):
#                 obj.cpu()
#             rets[key] = obj.state_dict()
#     return rets

# torch.save(my_get_state_dict(model), "my_test_model.pt")

# Attempt: call .cpu() on every module/parameter attribute, then save the policy.
# for key in dir(model):
#     obj = getattr(model, key)
#     if isinstance(obj, (torch.nn.Module, torch.nn.Parameter)):
#         obj.cpu()
#         print("convert to cpu")
# model.save_policy("cqlStochpid2000Ep40modelCPU.pt")

# Attempt: run save_policy under the stdlib tracer.
# import trace
# tracer = trace.Trace()
# tracer.run('model.save_policy("cqlStochpid2000Ep40modelCPU.pt")')