# Sample Workflow for d3rlpy Experiments

In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
plt.style.use('matplotlibrc')

from Python.data_sampler_confounder import *

## Building an MDPDataset

We first read in a large batch of samples from the file. As `d3rlpy` wants it in the form (observations, actions, rewards, terminal flags), we go ahead and do that. Here's a helper function to get a dataset from a list of chunks of your choosing.

In [5]:
def get_dataset(chunks : list, batch_size=30000, 
                path="collected_data/rl_det_small.txt") -> d3rlpy.dataset.MDPDataset :
    random.seed(0)
    samples = DataSampler(path_to_data=path)
    samples.setting("coarse")
    states = []
    actions = []
    rewards = []
    metrics = []
    next_states = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        [statesChunk, actionsChunk, rewardsChunk, nextStatesChunk, metricChunk] = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
        next_states.append(nextStatesChunk)
        metrics.append(metricChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    metrics = np.hstack(metrics)
    
    terminals = np.zeros(len(states))
    terminals[::1111] = 1 #episode length 100, change if necessary
    dataset = d3rlpy.dataset.MDPDataset(np.hstack([states.numpy(), metrics[:,None]]), 
                                        actions.numpy(), 
                                        rewards.numpy(), 
                                        metrics, terminals)
    return dataset, states.numpy(), actions.numpy(), rewards.numpy(), metrics

We can build the dataset from there, just like this, and split into train and test sets.

In [6]:
dataset, statesOrig, actions, rewards, metrics = get_dataset([i for i in range(1000)], path="collected_data/rl_stochpid.txt")

start
[ 0.00000000e+00  7.95731469e+08 -4.75891077e-02 -3.69999953e-02
  2.00999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.50429671e-01 -4.92727243e-01 -5.31666025e-03 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 1 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.25610892e-01 -3.35999953e-02
 -2.42000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.08749986e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 2 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.90489108e-01 -5.87999953e-02
 -1.01000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.76979602e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+0

[ 0.00000000e+00  7.95731469e+08  1.14110892e-01  5.68000047e-02
  7.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.54213589e-03  3.91473614e-01 -9.17637410e-02 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 48 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.60389108e-01  1.12000047e-02
  1.94999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.68637874e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 49 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.59710892e-01 -1.19999531e-03
  1.38999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.86128266e-01  3.12352813e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0

[ 0.00000000e+00  7.95731469e+08 -7.15891077e-02 -7.99999531e-03
  2.33999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.79765872e-01 -3.17402568e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 97 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.09789108e-01 -4.57999953e-02
 -2.20000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.84885686e-01 -2.78683294e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 98 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.58389108e-01  5.40000469e-03
 -2.15000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.93350106e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0

[ 0.00000000e+00  7.95731469e+08 -3.01891077e-02  5.08000047e-02
 -7.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.53141120e-01 -8.63096507e-02  2.37374998e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 145 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.02189108e-01  5.78000047e-02
  5.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.40938487e-01 -2.83568202e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 146 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.14710892e-01  4.44000047e-02
  1.91999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  1.85963953e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -3.40389108e-01  5.44000047e-02
  8.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.85300440e-02 -4.75342936e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 193 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -8.29891077e-02 -5.09999953e-02
  1.86999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.55508053e-01 -4.43781082e-01  4.05012232e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 194 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.13110892e-01  5.96000047e-02
 -2.41000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08  3.01910892e-01 -3.27999953e-02
 -2.03000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.62788272e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 241 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.14410892e-01  1.20000469e-03
 -2.86000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 242 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.45210892e-01  1.30000047e-02
  2.58999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  8.35709207e-02  3.10798870e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08 -1.45689108e-01 -4.39999953e-02
 -6.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.09050466e-02 -3.37320496e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 288 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.19289108e-01  1.42000047e-02
 -4.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.31225955e-02 -1.03251256e-02  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 289 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.67710892e-01 -4.85999953e-02
 -8.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.81703176e-02  3.82277322e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08 -1.41189108e-01  4.20000047e-02
  1.77999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.79177971e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 334 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.76789108e-01  2.14000047e-02
  2.16999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.59804672e-01 -3.51601165e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 335 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.86108923e-02 -4.67999953e-02
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.45477645e-02  1.85831962e-01 -3.48464679e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

Read chunk # 380 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.17410892e-01  5.96000047e-02
 -1.10000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.80355676e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 381 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.56810892e-01  1.14000047e-02
 -2.73000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.61819354e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 382 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.68389108e-01 -4.81999953e-02
  4.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+

start
[ 0.00000000e+00  7.95731469e+08  3.13310892e-01 -2.69999953e-02
 -5.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.68944324e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 422 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.35410892e-01  4.48000047e-02
  2.36999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.39982990e-01  3.28420700e-01 -4.44519591e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 423 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.04289108e-01 -4.35999953e-02
  1.66999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

start
[ 0.00000000e+00  7.95731469e+08  3.87110892e-01 -5.63999953e-02
  2.15999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 466 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.51389108e-01 -4.71999953e-02
  2.04999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 467 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.02889108e-01 -3.33999953e-02
  6.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.61621743e-04 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08  1.83410892e-01  2.96000047e-02
  1.44999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.22527237e-01  5.29807774e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 510 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.39891077e-02 -5.59999953e-02
  6.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.16438740e-01 -1.45522057e-01  5.82813624e-02 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 511 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.29589108e-01  4.46000047e-02
  1.89999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.64467258e-01 -4.88968803e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -1.64589108e-01 -2.21999953e-02
  2.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.96296333e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 554 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.31189108e-01  8.60000469e-03
  1.65999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 555 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.45889108e-01  5.68000047e-02
 -2.89000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -5.33417713e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -4.14189108e-01  2.04000047e-02
 -1.60000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.11403847e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 600 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.48910892e-01  2.32000047e-02
 -2.99000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 601 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.88510892e-01 -2.59999953e-02
  1.90999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.21786566e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -2.87589108e-01  5.80000469e-03
 -1.33000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 644 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.11010892e-01 -2.03999953e-02
 -1.19000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.75683874e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 645 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.51510892e-01 -4.09999953e-02
  1.70999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.50572308e-01 -1.09160785e-02 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08 -4.08789108e-01  5.38000047e-02
  1.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.80022059e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 690 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.20891077e-02  5.28000047e-02
 -1.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.82589892e-01 -2.02580476e-01  3.23273366e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 691 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.29110892e-01 -2.19999531e-03
  5.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.21216304e-01  5.21296054e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

start
[ 0.00000000e+00  7.95731469e+08 -3.98889108e-01 -5.01999953e-02
 -1.28000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.13384948e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 734 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.10589108e-01 -5.11999953e-02
  1.26999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.28693871e-01 -4.74144892e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 735 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.50910892e-01 -5.43999953e-02
 -4.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.00525643e-01  1.01541962e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08  5.41108923e-02 -5.35999953e-02
  2.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.09880485e-02  2.50924100e-01 -4.52974371e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 779 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.74789108e-01 -1.67999953e-02
 -1.34000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.60230695e-01 -2.30482844e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 780 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.59889108e-01  1.12000047e-02
 -2.61000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08  4.24610892e-01 -4.09999953e-02
 -2.74000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 824 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.42989108e-01  3.84000047e-02
 -1.26000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.97331372e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 825 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.81710892e-01  4.84000047e-02
  2.88999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.74754239e-02  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08  3.84710892e-01 -4.93999953e-02
 -2.78000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.05778192e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 870 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.04891077e-02 -9.99995308e-04
 -9.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.03183804e-01 -1.95233071e-02  1.11702868e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 871 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.33010892e-01 -9.19999531e-03
 -6.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.61482841e-02  2.96465732e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08  1.42010892e-01 -2.87999953e-02
 -2.29000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 911 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.47410892e-01  4.02000047e-02
 -3.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.89646077e-02  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 912 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.37610892e-01 -3.33999953e-02
  4.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.53570111e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08 -2.76489108e-01 -5.51999953e-02
 -3.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.57069552e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 956 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.08289108e-01  3.56000047e-02
  1.78999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.41320942e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 957 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.15910892e-01 -2.87999953e-02
 -2.54000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.82318925e-01  4.73669145e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

In [7]:
print("The behavior policy value statistics are:")
dataset.compute_stats()['return']

The behavior policy value statistics are:


{'mean': -128.34366,
 'std': 83.63828,
 'min': -410.7058,
 'max': 0.0,
 'histogram': (array([  4,   2,  11,  11,  22,  18,  25,  33,  39,  22,  22,  41,  73,
          76, 103, 118, 146, 179,  54,   1]),
  array([-410.7058  , -390.17053 , -369.63522 , -349.09995 , -328.56464 ,
         -308.02936 , -287.49408 , -266.95877 , -246.4235  , -225.8882  ,
         -205.3529  , -184.81761 , -164.28232 , -143.74704 , -123.21175 ,
         -102.67645 ,  -82.14116 ,  -61.605873,  -41.07058 ,  -20.53529 ,
            0.      ], dtype=float32))}

In [9]:
# plt.plot(states[:,6])

In [10]:
# plt.plot(actions)

In [11]:
#plt.plot(model.predict(np.array(states)))

In [12]:
from sklearn.model_selection import train_test_split
train_episodes, test_episodes = train_test_split(dataset, test_size=5)

## Setting up an Algorithm

In [13]:
from d3rlpy.algos import CQL
from d3rlpy.models.encoders import VectorEncoderFactory

from d3rlpy.preprocessing import MinMaxActionScaler
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)
#cql = CQL(action_scaler=action_scaler)

#actor_encoder = VectorEncoderFactory(hidden_units=[12, 24, 36, 24, 12],
#                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)
#critic_encoder = VectorEncoderFactory(hidden_units=[12, 24, 24, 12],
#                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)

model = CQL(q_func_factory='mean', #qr -> quantile regression q function, but you don't have to use this
            reward_scaler='standard',
            action_scaler=action_scaler,
          actor_learning_rate=1e-5, 
          critic_learning_rate=0.0003, 
            use_gpu=True) #change it to true if you have one
model.build_with_dataset(dataset)

In [14]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer

# calculate metrics with test dataset
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

0.11205020533958533


In [15]:
%load_ext tensorboard
%tensorboard --logdir runs

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
Traceback (most recent call last):
  File "/home/dasc/anaconda3/envs/jbreeden3.10/bin/tensorboard", line 6, in <module>
    from tensorboard.main import run_main
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/main.py", line 40, in <module>
    from tensorboard import default
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/default.py", line 38, in <module>
    from tensorboard.plugins.audio import audio_plugin
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugins/audio/audio_plugin.py", line 25, in <module>
    from tensorboard import plugin_util
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugin_util.py", line 21, in <module>
    from tensorboard._vendor import bleach
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorb

In [16]:
model.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=20, 
        tensorboard_dir='runs',
        scorers={
            'td_error': td_error_scorer,
            'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer
        })

2022-04-17 14:30.51 [debug    ] RoundIterator is selected.
2022-04-17 14:30.51 [info     ] Directory is created at d3rlpy_logs/CQL_20220417143051
2022-04-17 14:30.51 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-17 14:30.51 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-17 14:30.51 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220417143051/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conser

Epoch 1/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:31.38 [info     ] CQL_20220417143051: epoch=1 step=4309 epoch=1 metrics={'time_sample_batch': 0.00025647398655861096, 'time_algorithm_update': 0.010486876391899727, 'temp_loss': 3.394779772901236, 'temp': 0.8271505036329719, 'alpha_loss': -5.329435448875717, 'alpha': 1.1590072212044393, 'critic_loss': 11.456173180039894, 'actor_loss': 13.556259882130398, 'time_step': 0.01089877417204302, 'td_error': 14.039907260250546, 'init_value': -36.51660919189453, 'ave_value': -23.69267963622168} step=4309
2022-04-17 14:31.38 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_4309.pt


Epoch 2/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:32.26 [info     ] CQL_20220417143051: epoch=2 step=8618 epoch=2 metrics={'time_sample_batch': 0.00026100028775689443, 'time_algorithm_update': 0.010442913809683572, 'temp_loss': 1.5059043224905284, 'temp': 0.5785030979143547, 'alpha_loss': 1.0276504781109979, 'alpha': 1.1825857058316158, 'critic_loss': 36.939313554907656, 'actor_loss': 62.00821614304173, 'time_step': 0.010876100614967797, 'td_error': 23.539338564792306, 'init_value': -85.83417510986328, 'ave_value': -64.78024209718849} step=8618
2022-04-17 14:32.26 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_8618.pt


Epoch 3/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:33.11 [info     ] CQL_20220417143051: epoch=3 step=12927 epoch=3 metrics={'time_sample_batch': 0.0002515944562563627, 'time_algorithm_update': 0.01006496315020147, 'temp_loss': 0.7399809290995368, 'temp': 0.4084584610551824, 'alpha_loss': 0.950213479524517, 'alpha': 0.8621799513009453, 'critic_loss': 88.46618178516565, 'actor_loss': 105.33465759957244, 'time_step': 0.010470159647497397, 'td_error': 36.49356692211578, 'init_value': -130.56927490234375, 'ave_value': -102.68417747868342} step=12927
2022-04-17 14:33.11 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_12927.pt


Epoch 4/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:33.55 [info     ] CQL_20220417143051: epoch=4 step=17236 epoch=4 metrics={'time_sample_batch': 0.00024859759724037806, 'time_algorithm_update': 0.009842102254839667, 'temp_loss': 0.4050366407578881, 'temp': 0.28418372508184236, 'alpha_loss': 0.12863432517910298, 'alpha': 0.7222284934282247, 'critic_loss': 151.25670760655575, 'actor_loss': 138.09696764801018, 'time_step': 0.01023816697197465, 'td_error': 58.39741216728756, 'init_value': -168.6675567626953, 'ave_value': -133.44459658387277} step=17236
2022-04-17 14:33.55 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_17236.pt


Epoch 5/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:34.39 [info     ] CQL_20220417143051: epoch=5 step=21545 epoch=5 metrics={'time_sample_batch': 0.0002393098960871153, 'time_algorithm_update': 0.00966177055303999, 'temp_loss': 0.2113980704580902, 'temp': 0.19776277426681266, 'alpha_loss': -0.13365597059078213, 'alpha': 0.7517724861869626, 'critic_loss': 191.00956416545043, 'actor_loss': 161.20800544701586, 'time_step': 0.010040610980257188, 'td_error': 57.91613208929736, 'init_value': -192.46023559570312, 'ave_value': -151.44290812425515} step=21545
2022-04-17 14:34.39 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_21545.pt


Epoch 6/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:35.24 [info     ] CQL_20220417143051: epoch=6 step=25854 epoch=6 metrics={'time_sample_batch': 0.0002441293929341941, 'time_algorithm_update': 0.009917642932448317, 'temp_loss': 0.056407830946176185, 'temp': 0.14656922213921666, 'alpha_loss': 0.3092040966349088, 'alpha': 0.734481450658742, 'critic_loss': 165.42193409848252, 'actor_loss': 170.48796651313023, 'time_step': 0.010310423515268478, 'td_error': 39.17976503198212, 'init_value': -198.9853515625, 'ave_value': -153.2848266230188} step=25854
2022-04-17 14:35.24 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_25854.pt


Epoch 7/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:36.08 [info     ] CQL_20220417143051: epoch=7 step=30163 epoch=7 metrics={'time_sample_batch': 0.00024069221480646884, 'time_algorithm_update': 0.009838116033416415, 'temp_loss': -0.022674334258291535, 'temp': 0.15144115515238282, 'alpha_loss': 0.44277389303474696, 'alpha': 0.5830790231609101, 'critic_loss': 110.61384977915912, 'actor_loss': 166.74116064440378, 'time_step': 0.010221961424892423, 'td_error': 24.093584087772573, 'init_value': -195.16958618164062, 'ave_value': -140.80252951750614} step=30163
2022-04-17 14:36.08 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_30163.pt


Epoch 8/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:36.51 [info     ] CQL_20220417143051: epoch=8 step=34472 epoch=8 metrics={'time_sample_batch': 0.0002354882825248118, 'time_algorithm_update': 0.009673060715737989, 'temp_loss': -0.011024754676143523, 'temp': 0.18527828975145988, 'alpha_loss': 0.14988621996417562, 'alpha': 0.4987037983245544, 'critic_loss': 71.4497553603393, 'actor_loss': 154.30534805258898, 'time_step': 0.010044538385493487, 'td_error': 16.113859330157673, 'init_value': -183.17391967773438, 'ave_value': -126.23017434303824} step=34472
2022-04-17 14:36.51 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_34472.pt


Epoch 9/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:37.35 [info     ] CQL_20220417143051: epoch=9 step=38781 epoch=9 metrics={'time_sample_batch': 0.00023643177605250728, 'time_algorithm_update': 0.009722521200509822, 'temp_loss': 0.007556327618982744, 'temp': 0.18629563593081516, 'alpha_loss': 0.06652783177895878, 'alpha': 0.4643402095805412, 'critic_loss': 54.367164103606925, 'actor_loss': 142.25170976342665, 'time_step': 0.010101073688463701, 'td_error': 9.892079591850994, 'init_value': -167.46286010742188, 'ave_value': -113.19075203925371} step=38781
2022-04-17 14:37.35 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_38781.pt


Epoch 10/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:38.20 [info     ] CQL_20220417143051: epoch=10 step=43090 epoch=10 metrics={'time_sample_batch': 0.0002409994643701149, 'time_algorithm_update': 0.00984986405495453, 'temp_loss': 0.009653684497971056, 'temp': 0.16985342507227913, 'alpha_loss': 0.09253391647402473, 'alpha': 0.42633187788134247, 'critic_loss': 46.50388413981679, 'actor_loss': 131.05878330566236, 'time_step': 0.010236612575793936, 'td_error': 7.426147986452112, 'init_value': -153.69137573242188, 'ave_value': -101.09596368198459} step=43090
2022-04-17 14:38.20 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_43090.pt


Epoch 11/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:39.04 [info     ] CQL_20220417143051: epoch=11 step=47399 epoch=11 metrics={'time_sample_batch': 0.00024294709351478305, 'time_algorithm_update': 0.009861676093734702, 'temp_loss': 0.00335006576785647, 'temp': 0.15596353130734528, 'alpha_loss': 0.04249540691179722, 'alpha': 0.4016627168456206, 'critic_loss': 40.25331124946229, 'actor_loss': 119.0042305798286, 'time_step': 0.010255077372683979, 'td_error': 5.61112264910924, 'init_value': -137.51687622070312, 'ave_value': -88.14751367455399} step=47399
2022-04-17 14:39.04 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_47399.pt


Epoch 12/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:39.49 [info     ] CQL_20220417143051: epoch=12 step=51708 epoch=12 metrics={'time_sample_batch': 0.00024320449041197046, 'time_algorithm_update': 0.009888271910321494, 'temp_loss': 0.0035049754280180264, 'temp': 0.14933948374439154, 'alpha_loss': 0.004429345684528522, 'alpha': 0.3887758792816113, 'critic_loss': 34.68189726215583, 'actor_loss': 106.19589042884895, 'time_step': 0.010281238403196372, 'td_error': 4.525804014295499, 'init_value': -121.49470520019531, 'ave_value': -75.49016029945163} step=51708
2022-04-17 14:39.49 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_51708.pt


Epoch 13/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:40.36 [info     ] CQL_20220417143051: epoch=13 step=56017 epoch=13 metrics={'time_sample_batch': 0.00026555419881151845, 'time_algorithm_update': 0.010555751696073364, 'temp_loss': -0.0002877899366109979, 'temp': 0.14575994070761994, 'alpha_loss': 0.010282139494310628, 'alpha': 0.38623289326471694, 'critic_loss': 29.9788296922616, 'actor_loss': 93.91436071389235, 'time_step': 0.011006697393343437, 'td_error': 3.847571368943194, 'init_value': -109.60237884521484, 'ave_value': -66.21730027680394} step=56017
2022-04-17 14:40.36 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_56017.pt


Epoch 14/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:41.58 [info     ] CQL_20220417143051: epoch=14 step=60326 epoch=14 metrics={'time_sample_batch': 0.00032322797954538967, 'time_algorithm_update': 0.01829489955605529, 'temp_loss': 0.0023372169868738743, 'temp': 0.14399754550471386, 'alpha_loss': 0.023614515373486474, 'alpha': 0.3795676348546343, 'critic_loss': 28.22323860905513, 'actor_loss': 82.71612669846718, 'time_step': 0.018864346369426636, 'td_error': 3.487472083822079, 'init_value': -96.92069244384766, 'ave_value': -55.79940745739786} step=60326
2022-04-17 14:41.58 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_60326.pt


Epoch 15/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:43.13 [info     ] CQL_20220417143051: epoch=15 step=64635 epoch=15 metrics={'time_sample_batch': 0.00034221443119611226, 'time_algorithm_update': 0.016598830452151242, 'temp_loss': 0.0012416781296849383, 'temp': 0.13782290004264675, 'alpha_loss': 0.044504919879827236, 'alpha': 0.3653922773107461, 'critic_loss': 26.16757603140103, 'actor_loss': 72.88567449764908, 'time_step': 0.01720664489570807, 'td_error': 3.1692989183630162, 'init_value': -87.14334869384766, 'ave_value': -47.931966016644296} step=64635
2022-04-17 14:43.13 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_64635.pt


Epoch 16/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:44.26 [info     ] CQL_20220417143051: epoch=16 step=68944 epoch=16 metrics={'time_sample_batch': 0.0003480515642712802, 'time_algorithm_update': 0.016208562310927703, 'temp_loss': -0.0008198101981591635, 'temp': 0.1378824412442327, 'alpha_loss': 0.021834030234583576, 'alpha': 0.347257697822932, 'critic_loss': 24.729340107998677, 'actor_loss': 64.30400097400579, 'time_step': 0.016832435012112696, 'td_error': 3.2022509015536054, 'init_value': -77.48933410644531, 'ave_value': -40.090120797261996} step=68944
2022-04-17 14:44.26 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_68944.pt


Epoch 17/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:45.39 [info     ] CQL_20220417143051: epoch=17 step=73253 epoch=17 metrics={'time_sample_batch': 0.00034883614896649023, 'time_algorithm_update': 0.0162096544771688, 'temp_loss': -0.0019094368435760597, 'temp': 0.1425683992828618, 'alpha_loss': -0.004704788709074858, 'alpha': 0.3414055703963888, 'critic_loss': 24.1483662056519, 'actor_loss': 56.46900077628602, 'time_step': 0.016829040327711764, 'td_error': 3.171507567201073, 'init_value': -68.4445571899414, 'ave_value': -32.62961093186258} step=73253
2022-04-17 14:45.39 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_73253.pt


Epoch 18/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:46.52 [info     ] CQL_20220417143051: epoch=18 step=77562 epoch=18 metrics={'time_sample_batch': 0.00034836838598952934, 'time_algorithm_update': 0.01612993541574998, 'temp_loss': -0.0020406958122136205, 'temp': 0.14524628639041368, 'alpha_loss': 0.001010524847482699, 'alpha': 0.34600377589789255, 'critic_loss': 23.476547409442936, 'actor_loss': 49.633640789050574, 'time_step': 0.01674685928599689, 'td_error': 3.034049813059566, 'init_value': -61.09284591674805, 'ave_value': -26.98976060308986} step=77562
2022-04-17 14:46.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_77562.pt


Epoch 19/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:48.05 [info     ] CQL_20220417143051: epoch=19 step=81871 epoch=19 metrics={'time_sample_batch': 0.0003489413873367504, 'time_algorithm_update': 0.01612538377324067, 'temp_loss': -0.0010165284560620699, 'temp': 0.14930280870491072, 'alpha_loss': -0.010404387441515948, 'alpha': 0.34648079730623915, 'critic_loss': 23.361932005841318, 'actor_loss': 43.79587581932863, 'time_step': 0.016744195239177074, 'td_error': 2.862667779896601, 'init_value': -55.07027053833008, 'ave_value': -22.754897393433897} step=81871
2022-04-17 14:48.05 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_81871.pt


Epoch 20/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-17 14:49.17 [info     ] CQL_20220417143051: epoch=20 step=86180 epoch=20 metrics={'time_sample_batch': 0.0003393909775704512, 'time_algorithm_update': 0.015996473243325624, 'temp_loss': -0.0012708520888504143, 'temp': 0.1519829111299539, 'alpha_loss': 0.004844973477449226, 'alpha': 0.3466879540099487, 'critic_loss': 23.24989951414917, 'actor_loss': 38.84604859944147, 'time_step': 0.01660438595362566, 'td_error': 2.666530147944117, 'init_value': -48.76219940185547, 'ave_value': -17.776884348079697} step=86180
2022-04-17 14:49.17 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220417143051/model_86180.pt


[(1,
  {'time_sample_batch': 0.00025647398655861096,
   'time_algorithm_update': 0.010486876391899727,
   'temp_loss': 3.394779772901236,
   'temp': 0.8271505036329719,
   'alpha_loss': -5.329435448875717,
   'alpha': 1.1590072212044393,
   'critic_loss': 11.456173180039894,
   'actor_loss': 13.556259882130398,
   'time_step': 0.01089877417204302,
   'td_error': 14.039907260250546,
   'init_value': -36.51660919189453,
   'ave_value': -23.69267963622168}),
 (2,
  {'time_sample_batch': 0.00026100028775689443,
   'time_algorithm_update': 0.010442913809683572,
   'temp_loss': 1.5059043224905284,
   'temp': 0.5785030979143547,
   'alpha_loss': 1.0276504781109979,
   'alpha': 1.1825857058316158,
   'critic_loss': 36.939313554907656,
   'actor_loss': 62.00821614304173,
   'time_step': 0.010876100614967797,
   'td_error': 23.539338564792306,
   'init_value': -85.83417510986328,
   'ave_value': -64.78024209718849}),
 (3,
  {'time_sample_batch': 0.0002515944562563627,
   'time_algorithm_update':

## Off-Policy Evaluation

We do get some metrics on a test set of initial state value and average value. However, these estimates (using the critic's Q-function) of model performance are biased. They're useful for validation during training, but not much else. Instead, we fit a Q-function to the data (or a separate dataset, as I've done here) separately and evaluate the model's performance on it.

Feel free to change the chunks and number of steps.

In [17]:
from d3rlpy.ope import FQE
# metrics to evaluate with
from d3rlpy.metrics.scorer import soft_opc_scorer


ope_dataset, states_ope, actions_ope, rewards_ope, metrics_ope = get_dataset([500+i for i in range(50)], path="collected_data/rl_stochpid.txt") #change if you'd prefer different chunks
ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=5)

fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
fqe.fit(ope_train_episodes,
        eval_episodes=ope_test_episodes,
        tensorboard_dir='runs',
        n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
        scorers={
           'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer,
           'soft_opc': soft_opc_scorer(return_threshold=0)
        })

start
[ 0.00000000e+00  7.95731469e+08 -4.19889108e-01 -1.83999953e-02
  2.71999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 501 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.30889108e-01 -2.27999953e-02
 -3.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.19275575e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 502 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.99289108e-01 -3.69999953e-02
  2.18999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.02595006e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08  1.26710892e-01  7.60000469e-03
 -1.01000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  4.54273830e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 531 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.57108923e-02  5.62000047e-02
  8.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.86939797e-02  1.54513207e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 532 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.60210892e-01 -1.49999953e-02
  4.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.25177022e-02  3.21207732e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

Epoch 1/50:   0%|          | 0/499 [00:00<?, ?it/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


2022-04-17 14:49.19 [info     ] FQE_20220417144917: epoch=1 step=499 epoch=1 metrics={'time_sample_batch': 0.00017088209698816578, 'time_algorithm_update': 0.00361565979783664, 'loss': 0.0064157983843416335, 'time_step': 0.0038696974217294453, 'init_value': -0.5316524505615234, 'ave_value': -0.3708868765649763, 'soft_opc': nan} step=499




2022-04-17 14:49.19 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_499.pt


Epoch 2/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.21 [info     ] FQE_20220417144917: epoch=2 step=998 epoch=2 metrics={'time_sample_batch': 0.0001712734092452483, 'time_algorithm_update': 0.003700443164619033, 'loss': 0.004909136375498198, 'time_step': 0.003952516582542526, 'init_value': -0.606988251209259, 'ave_value': -0.39794752515919574, 'soft_opc': nan} step=998




2022-04-17 14:49.21 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_998.pt


Epoch 3/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.23 [info     ] FQE_20220417144917: epoch=3 step=1497 epoch=3 metrics={'time_sample_batch': 0.00017068285741404684, 'time_algorithm_update': 0.0037714713561033198, 'loss': 0.0040915756946418415, 'time_step': 0.004023511806327499, 'init_value': -0.644119918346405, 'ave_value': -0.4146075309094813, 'soft_opc': nan} step=1497




2022-04-17 14:49.23 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_1497.pt


Epoch 4/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.25 [info     ] FQE_20220417144917: epoch=4 step=1996 epoch=4 metrics={'time_sample_batch': 0.00015766396073396794, 'time_algorithm_update': 0.0034528452313256886, 'loss': 0.0036742373671520777, 'time_step': 0.0036867032786887253, 'init_value': -0.759827733039856, 'ave_value': -0.486348281747049, 'soft_opc': nan} step=1996




2022-04-17 14:49.25 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_1996.pt


Epoch 5/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.27 [info     ] FQE_20220417144917: epoch=5 step=2495 epoch=5 metrics={'time_sample_batch': 0.00016188095948978034, 'time_algorithm_update': 0.003676584106170104, 'loss': 0.0038613347035229865, 'time_step': 0.003913469926149907, 'init_value': -0.9854704141616821, 'ave_value': -0.6078298386891146, 'soft_opc': nan} step=2495




2022-04-17 14:49.27 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_2495.pt


Epoch 6/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.29 [info     ] FQE_20220417144917: epoch=6 step=2994 epoch=6 metrics={'time_sample_batch': 0.00016771194213377927, 'time_algorithm_update': 0.003749705029871755, 'loss': 0.004282835563845402, 'time_step': 0.003993875755814608, 'init_value': -1.2087445259094238, 'ave_value': -0.7206144355052898, 'soft_opc': nan} step=2994




2022-04-17 14:49.29 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_2994.pt


Epoch 7/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.32 [info     ] FQE_20220417144917: epoch=7 step=3493 epoch=7 metrics={'time_sample_batch': 0.00017432124915725005, 'time_algorithm_update': 0.003948211669921875, 'loss': 0.00546504368255277, 'time_step': 0.00420607115797146, 'init_value': -1.4237947463989258, 'ave_value': -0.8347172226773591, 'soft_opc': nan} step=3493




2022-04-17 14:49.32 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_3493.pt


Epoch 8/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.34 [info     ] FQE_20220417144917: epoch=8 step=3992 epoch=8 metrics={'time_sample_batch': 0.00017195235273403253, 'time_algorithm_update': 0.0038980114197205446, 'loss': 0.006226940929492141, 'time_step': 0.004153176156695716, 'init_value': -1.7547106742858887, 'ave_value': -1.022899863362178, 'soft_opc': nan} step=3992




2022-04-17 14:49.34 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_3992.pt


Epoch 9/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.36 [info     ] FQE_20220417144917: epoch=9 step=4491 epoch=9 metrics={'time_sample_batch': 0.00017917275667668345, 'time_algorithm_update': 0.003907503728159444, 'loss': 0.007452741192041604, 'time_step': 0.004172670100638288, 'init_value': -2.0161585807800293, 'ave_value': -1.1618772529493513, 'soft_opc': nan} step=4491




2022-04-17 14:49.36 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_4491.pt


Epoch 10/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.38 [info     ] FQE_20220417144917: epoch=10 step=4990 epoch=10 metrics={'time_sample_batch': 0.0001730952329769402, 'time_algorithm_update': 0.003681645842496761, 'loss': 0.008290827251692678, 'time_step': 0.00394064247727633, 'init_value': -2.409961700439453, 'ave_value': -1.366599453841378, 'soft_opc': nan} step=4990




2022-04-17 14:49.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_4990.pt


Epoch 11/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.40 [info     ] FQE_20220417144917: epoch=11 step=5489 epoch=11 metrics={'time_sample_batch': 0.00017397484941807442, 'time_algorithm_update': 0.003788818577248491, 'loss': 0.009657414043546171, 'time_step': 0.004047604505428092, 'init_value': -2.712460994720459, 'ave_value': -1.4494164830929523, 'soft_opc': nan} step=5489




2022-04-17 14:49.40 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_5489.pt


Epoch 12/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.42 [info     ] FQE_20220417144917: epoch=12 step=5988 epoch=12 metrics={'time_sample_batch': 0.00017513063006506177, 'time_algorithm_update': 0.0038028327162136773, 'loss': 0.011020979552878766, 'time_step': 0.004060049573022999, 'init_value': -2.9527406692504883, 'ave_value': -1.523013715203516, 'soft_opc': nan} step=5988




2022-04-17 14:49.42 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_5988.pt


Epoch 13/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.44 [info     ] FQE_20220417144917: epoch=13 step=6487 epoch=13 metrics={'time_sample_batch': 0.00018187180788579112, 'time_algorithm_update': 0.003913086258576724, 'loss': 0.012410423454067078, 'time_step': 0.004177085383382732, 'init_value': -3.311816692352295, 'ave_value': -1.680697571580206, 'soft_opc': nan} step=6487




2022-04-17 14:49.44 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_6487.pt


Epoch 14/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.46 [info     ] FQE_20220417144917: epoch=14 step=6986 epoch=14 metrics={'time_sample_batch': 0.00016659008477159398, 'time_algorithm_update': 0.0035979728660507047, 'loss': 0.013890692644812287, 'time_step': 0.003842176559692872, 'init_value': -3.667163372039795, 'ave_value': -1.8186877040091802, 'soft_opc': nan} step=6986




2022-04-17 14:49.46 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_6986.pt


Epoch 15/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.48 [info     ] FQE_20220417144917: epoch=15 step=7485 epoch=15 metrics={'time_sample_batch': 0.00017376844295280012, 'time_algorithm_update': 0.003773087728955225, 'loss': 0.015850980155112108, 'time_step': 0.004031921436409195, 'init_value': -3.891869068145752, 'ave_value': -1.9237028636608844, 'soft_opc': nan} step=7485




2022-04-17 14:49.48 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_7485.pt


Epoch 16/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.51 [info     ] FQE_20220417144917: epoch=16 step=7984 epoch=16 metrics={'time_sample_batch': 0.0001759290217397686, 'time_algorithm_update': 0.003994376004817252, 'loss': 0.017090847407549785, 'time_step': 0.004260179752816179, 'init_value': -4.112129211425781, 'ave_value': -2.0796345988674476, 'soft_opc': nan} step=7984




2022-04-17 14:49.51 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_7984.pt


Epoch 17/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.53 [info     ] FQE_20220417144917: epoch=17 step=8483 epoch=17 metrics={'time_sample_batch': 0.00017544453989766642, 'time_algorithm_update': 0.0037790486712255076, 'loss': 0.01936964588757709, 'time_step': 0.00404532686741892, 'init_value': -4.276358127593994, 'ave_value': -2.118044812353076, 'soft_opc': nan} step=8483




2022-04-17 14:49.53 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_8483.pt


Epoch 18/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.55 [info     ] FQE_20220417144917: epoch=18 step=8982 epoch=18 metrics={'time_sample_batch': 0.0001743690284316191, 'time_algorithm_update': 0.003753344377200446, 'loss': 0.02077911463206847, 'time_step': 0.004013489625735847, 'init_value': -4.4546661376953125, 'ave_value': -2.1632649623478453, 'soft_opc': nan} step=8982




2022-04-17 14:49.55 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_8982.pt


Epoch 19/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.57 [info     ] FQE_20220417144917: epoch=19 step=9481 epoch=19 metrics={'time_sample_batch': 0.00017409143084753492, 'time_algorithm_update': 0.0039557894628368066, 'loss': 0.022813275800775253, 'time_step': 0.004217385290142052, 'init_value': -4.809579849243164, 'ave_value': -2.364219073390773, 'soft_opc': nan} step=9481




2022-04-17 14:49.57 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_9481.pt


Epoch 20/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:49.59 [info     ] FQE_20220417144917: epoch=20 step=9980 epoch=20 metrics={'time_sample_batch': 0.00017330928412611355, 'time_algorithm_update': 0.0038693811229331223, 'loss': 0.024269293551613204, 'time_step': 0.004126407340437711, 'init_value': -4.934338569641113, 'ave_value': -2.4498279555910476, 'soft_opc': nan} step=9980




2022-04-17 14:49.59 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_9980.pt


Epoch 21/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.01 [info     ] FQE_20220417144917: epoch=21 step=10479 epoch=21 metrics={'time_sample_batch': 0.00017175359095265727, 'time_algorithm_update': 0.0037155500155890394, 'loss': 0.02565567613530807, 'time_step': 0.003971379840063427, 'init_value': -5.025557994842529, 'ave_value': -2.474843244737031, 'soft_opc': nan} step=10479




2022-04-17 14:50.01 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_10479.pt


Epoch 22/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.03 [info     ] FQE_20220417144917: epoch=22 step=10978 epoch=22 metrics={'time_sample_batch': 0.00018666168014128844, 'time_algorithm_update': 0.0040485190007395165, 'loss': 0.027541052754912414, 'time_step': 0.004320211544304429, 'init_value': -5.222118854522705, 'ave_value': -2.6281153683542264, 'soft_opc': nan} step=10978




2022-04-17 14:50.03 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_10978.pt


Epoch 23/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.06 [info     ] FQE_20220417144917: epoch=23 step=11477 epoch=23 metrics={'time_sample_batch': 0.00017382434470381193, 'time_algorithm_update': 0.003818348557772283, 'loss': 0.029384158648044723, 'time_step': 0.00407820330831952, 'init_value': -5.446779251098633, 'ave_value': -2.7913325526424355, 'soft_opc': nan} step=11477




2022-04-17 14:50.06 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_11477.pt


Epoch 24/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.08 [info     ] FQE_20220417144917: epoch=24 step=11976 epoch=24 metrics={'time_sample_batch': 0.00017483965428415424, 'time_algorithm_update': 0.0037727035835892976, 'loss': 0.03022727899800217, 'time_step': 0.004033832129591214, 'init_value': -5.6123504638671875, 'ave_value': -2.8954869152726355, 'soft_opc': nan} step=11976




2022-04-17 14:50.08 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_11976.pt


Epoch 25/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.10 [info     ] FQE_20220417144917: epoch=25 step=12475 epoch=25 metrics={'time_sample_batch': 0.00017670447936277828, 'time_algorithm_update': 0.003848948794041941, 'loss': 0.03152495792288639, 'time_step': 0.004116024426324573, 'init_value': -5.81099796295166, 'ave_value': -2.989064640083627, 'soft_opc': nan} step=12475




2022-04-17 14:50.10 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_12475.pt


Epoch 26/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.12 [info     ] FQE_20220417144917: epoch=26 step=12974 epoch=26 metrics={'time_sample_batch': 0.00018038635024565734, 'time_algorithm_update': 0.0038847349927516165, 'loss': 0.03386724041634891, 'time_step': 0.004148048484970429, 'init_value': -5.8928680419921875, 'ave_value': -3.0567543098837273, 'soft_opc': nan} step=12974




2022-04-17 14:50.12 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_12974.pt


Epoch 27/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.14 [info     ] FQE_20220417144917: epoch=27 step=13473 epoch=27 metrics={'time_sample_batch': 0.00016936462723420473, 'time_algorithm_update': 0.003754438522583497, 'loss': 0.03486456270124352, 'time_step': 0.004009192357799094, 'init_value': -5.993074417114258, 'ave_value': -3.081640326298001, 'soft_opc': nan} step=13473




2022-04-17 14:50.14 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_13473.pt


Epoch 28/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.16 [info     ] FQE_20220417144917: epoch=28 step=13972 epoch=28 metrics={'time_sample_batch': 0.00017496101364105165, 'time_algorithm_update': 0.0038334272189704116, 'loss': 0.036730198681873376, 'time_step': 0.004097239049020893, 'init_value': -5.855929851531982, 'ave_value': -3.0318732339116905, 'soft_opc': nan} step=13972




2022-04-17 14:50.16 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_13972.pt


Epoch 29/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.18 [info     ] FQE_20220417144917: epoch=29 step=14471 epoch=29 metrics={'time_sample_batch': 0.00017468962736263543, 'time_algorithm_update': 0.0037960332476782177, 'loss': 0.037690400280793734, 'time_step': 0.0040626215313622855, 'init_value': -5.9630231857299805, 'ave_value': -3.114852908065727, 'soft_opc': nan} step=14471




2022-04-17 14:50.18 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_14471.pt


Epoch 30/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.20 [info     ] FQE_20220417144917: epoch=30 step=14970 epoch=30 metrics={'time_sample_batch': 0.00017565620208312131, 'time_algorithm_update': 0.0038511409071499934, 'loss': 0.038215707248927866, 'time_step': 0.0041170693590550245, 'init_value': -6.002175331115723, 'ave_value': -3.250533936075396, 'soft_opc': nan} step=14970




2022-04-17 14:50.20 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_14970.pt


Epoch 31/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.23 [info     ] FQE_20220417144917: epoch=31 step=15469 epoch=31 metrics={'time_sample_batch': 0.00017653581852425555, 'time_algorithm_update': 0.003762179242824027, 'loss': 0.03822059449011867, 'time_step': 0.004025078010941316, 'init_value': -6.1679229736328125, 'ave_value': -3.4049095189618366, 'soft_opc': nan} step=15469




2022-04-17 14:50.23 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_15469.pt


Epoch 32/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.25 [info     ] FQE_20220417144917: epoch=32 step=15968 epoch=32 metrics={'time_sample_batch': 0.0001786175615085151, 'time_algorithm_update': 0.003846536418479048, 'loss': 0.03918211603721973, 'time_step': 0.004114230314572015, 'init_value': -6.064202308654785, 'ave_value': -3.3368416479545404, 'soft_opc': nan} step=15968




2022-04-17 14:50.25 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_15968.pt


Epoch 33/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.27 [info     ] FQE_20220417144917: epoch=33 step=16467 epoch=33 metrics={'time_sample_batch': 0.00017752054936900167, 'time_algorithm_update': 0.003887473222965707, 'loss': 0.039955984937532885, 'time_step': 0.0041521446021620875, 'init_value': -6.117964744567871, 'ave_value': -3.319005310948233, 'soft_opc': nan} step=16467




2022-04-17 14:50.27 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_16467.pt


Epoch 34/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.29 [info     ] FQE_20220417144917: epoch=34 step=16966 epoch=34 metrics={'time_sample_batch': 0.000174494210130466, 'time_algorithm_update': 0.003838617959337865, 'loss': 0.03992145575973987, 'time_step': 0.004098919446100453, 'init_value': -6.042357444763184, 'ave_value': -3.34981762459166, 'soft_opc': nan} step=16966




2022-04-17 14:50.29 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_16966.pt


Epoch 35/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.31 [info     ] FQE_20220417144917: epoch=35 step=17465 epoch=35 metrics={'time_sample_batch': 0.00017890567053296046, 'time_algorithm_update': 0.003931068466278259, 'loss': 0.04094574340821649, 'time_step': 0.00419138189785944, 'init_value': -5.907964706420898, 'ave_value': -3.2525193051518957, 'soft_opc': nan} step=17465




2022-04-17 14:50.31 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_17465.pt


Epoch 36/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.33 [info     ] FQE_20220417144917: epoch=36 step=17964 epoch=36 metrics={'time_sample_batch': 0.00018014793166655577, 'time_algorithm_update': 0.0038854559820018456, 'loss': 0.04161045838516945, 'time_step': 0.004151818269718147, 'init_value': -5.873055458068848, 'ave_value': -3.23255893373778, 'soft_opc': nan} step=17964




2022-04-17 14:50.33 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_17964.pt


Epoch 37/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.35 [info     ] FQE_20220417144917: epoch=37 step=18463 epoch=37 metrics={'time_sample_batch': 0.00017360790459092012, 'time_algorithm_update': 0.0037580764365339565, 'loss': 0.041890645227516315, 'time_step': 0.004015363051083857, 'init_value': -6.093533515930176, 'ave_value': -3.435860408375224, 'soft_opc': nan} step=18463




2022-04-17 14:50.35 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_18463.pt


Epoch 38/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.38 [info     ] FQE_20220417144917: epoch=38 step=18962 epoch=38 metrics={'time_sample_batch': 0.00017608287100323694, 'time_algorithm_update': 0.0038777095282483912, 'loss': 0.042951546783208144, 'time_step': 0.004135554682515666, 'init_value': -6.261590957641602, 'ave_value': -3.599656532834041, 'soft_opc': nan} step=18962




2022-04-17 14:50.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_18962.pt


Epoch 39/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.40 [info     ] FQE_20220417144917: epoch=39 step=19461 epoch=39 metrics={'time_sample_batch': 0.0001729700512780933, 'time_algorithm_update': 0.00369210520344889, 'loss': 0.04345299310693872, 'time_step': 0.0039519150414782205, 'init_value': -6.2820844650268555, 'ave_value': -3.631102782852008, 'soft_opc': nan} step=19461




2022-04-17 14:50.40 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_19461.pt


Epoch 40/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.42 [info     ] FQE_20220417144917: epoch=40 step=19960 epoch=40 metrics={'time_sample_batch': 0.00016218053554007428, 'time_algorithm_update': 0.0037320147535366143, 'loss': 0.043559279374942514, 'time_step': 0.0039711438104480445, 'init_value': -6.251253128051758, 'ave_value': -3.6712330294017854, 'soft_opc': nan} step=19960




2022-04-17 14:50.42 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_19960.pt


Epoch 41/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.44 [info     ] FQE_20220417144917: epoch=41 step=20459 epoch=41 metrics={'time_sample_batch': 0.00017542542818791882, 'time_algorithm_update': 0.0038722517017372145, 'loss': 0.044412857299028545, 'time_step': 0.004128815415865911, 'init_value': -6.519750595092773, 'ave_value': -3.810371078273097, 'soft_opc': nan} step=20459




2022-04-17 14:50.44 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_20459.pt


Epoch 42/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.46 [info     ] FQE_20220417144917: epoch=42 step=20958 epoch=42 metrics={'time_sample_batch': 0.00017440008495995897, 'time_algorithm_update': 0.0038611970110264474, 'loss': 0.04503006588729221, 'time_step': 0.004117407158524813, 'init_value': -6.739398002624512, 'ave_value': -3.9870245434092104, 'soft_opc': nan} step=20958




2022-04-17 14:50.46 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_20958.pt


Epoch 43/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.48 [info     ] FQE_20220417144917: epoch=43 step=21457 epoch=43 metrics={'time_sample_batch': 0.00017496053584830794, 'time_algorithm_update': 0.003902568129117121, 'loss': 0.04604465627104029, 'time_step': 0.004160795995372092, 'init_value': -6.731407642364502, 'ave_value': -3.991797029383078, 'soft_opc': nan} step=21457




2022-04-17 14:50.48 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_21457.pt


Epoch 44/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.50 [info     ] FQE_20220417144917: epoch=44 step=21956 epoch=44 metrics={'time_sample_batch': 0.00017916702316375915, 'time_algorithm_update': 0.004009181846358733, 'loss': 0.047640407571096895, 'time_step': 0.004283694800489651, 'init_value': -6.904108047485352, 'ave_value': -4.158822243835207, 'soft_opc': nan} step=21956




2022-04-17 14:50.50 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_21956.pt


Epoch 45/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.52 [info     ] FQE_20220417144917: epoch=45 step=22455 epoch=45 metrics={'time_sample_batch': 0.00017044635000592004, 'time_algorithm_update': 0.0038172267004100975, 'loss': 0.04676636774139572, 'time_step': 0.004072802339144843, 'init_value': -6.898653984069824, 'ave_value': -4.21733201975758, 'soft_opc': nan} step=22455




2022-04-17 14:50.52 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_22455.pt


Epoch 46/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.55 [info     ] FQE_20220417144917: epoch=46 step=22954 epoch=46 metrics={'time_sample_batch': 0.000171821915315005, 'time_algorithm_update': 0.003826374042488052, 'loss': 0.04791156197958747, 'time_step': 0.004080933416056968, 'init_value': -6.905007839202881, 'ave_value': -4.2184542259125895, 'soft_opc': nan} step=22954




2022-04-17 14:50.55 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_22954.pt


Epoch 47/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.57 [info     ] FQE_20220417144917: epoch=47 step=23453 epoch=47 metrics={'time_sample_batch': 0.0001731955694531152, 'time_algorithm_update': 0.0037948277765858865, 'loss': 0.048170131867758946, 'time_step': 0.00405500312606414, 'init_value': -6.8578033447265625, 'ave_value': -4.190737831881186, 'soft_opc': nan} step=23453




2022-04-17 14:50.57 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_23453.pt


Epoch 48/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:50.59 [info     ] FQE_20220417144917: epoch=48 step=23952 epoch=48 metrics={'time_sample_batch': 0.00017546795174210725, 'time_algorithm_update': 0.003907040746990807, 'loss': 0.04807426966190196, 'time_step': 0.004172225275593912, 'init_value': -6.735415458679199, 'ave_value': -4.19954396194468, 'soft_opc': nan} step=23952




2022-04-17 14:50.59 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_23952.pt


Epoch 49/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:51.01 [info     ] FQE_20220417144917: epoch=49 step=24451 epoch=49 metrics={'time_sample_batch': 0.00016683232569264506, 'time_algorithm_update': 0.003730607653906446, 'loss': 0.04838682368714064, 'time_step': 0.003979090459361105, 'init_value': -6.630803108215332, 'ave_value': -4.155116377551075, 'soft_opc': nan} step=24451




2022-04-17 14:51.01 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_24451.pt


Epoch 50/50:   0%|          | 0/499 [00:00<?, ?it/s]



2022-04-17 14:51.03 [info     ] FQE_20220417144917: epoch=50 step=24950 epoch=50 metrics={'time_sample_batch': 0.0001380429239215736, 'time_algorithm_update': 0.003499410912125765, 'loss': 0.05065426302682758, 'time_step': 0.003701590344996634, 'init_value': -6.726969242095947, 'ave_value': -4.322503234024856, 'soft_opc': nan} step=24950




2022-04-17 14:51.03 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220417144917/model_24950.pt


[(1,
  {'time_sample_batch': 0.00017088209698816578,
   'time_algorithm_update': 0.00361565979783664,
   'loss': 0.0064157983843416335,
   'time_step': 0.0038696974217294453,
   'init_value': -0.5316524505615234,
   'ave_value': -0.3708868765649763,
   'soft_opc': nan}),
 (2,
  {'time_sample_batch': 0.0001712734092452483,
   'time_algorithm_update': 0.003700443164619033,
   'loss': 0.004909136375498198,
   'time_step': 0.003952516582542526,
   'init_value': -0.606988251209259,
   'ave_value': -0.39794752515919574,
   'soft_opc': nan}),
 (3,
  {'time_sample_batch': 0.00017068285741404684,
   'time_algorithm_update': 0.0037714713561033198,
   'loss': 0.0040915756946418415,
   'time_step': 0.004023511806327499,
   'init_value': -0.644119918346405,
   'ave_value': -0.4146075309094813,
   'soft_opc': nan}),
 (4,
  {'time_sample_batch': 0.00015766396073396794,
   'time_algorithm_update': 0.0034528452313256886,
   'loss': 0.0036742373671520777,
   'time_step': 0.0036867032786887253,
   'init_

In [18]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset, states_ope, actions_ope, rewards_ope, metrics_ope = get_dataset([2,4,6,8], path="collected_data/rl_stoch_small.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [19]:
model.save_policy("workflowConfound.pt")
model.save_model("workflowConfoundmodel.pt")

  minimum = torch.tensor(
  maximum = torch.tensor(


In [21]:
from d3rlpy.torch_utility import to_cpu
to_cpu(model)
model.save_policy("workflowConfoundCPU.pt")
model.save_model("workflowConfoundmodelCPU.pt")