# Sample Workflow for d3rlpy Experiments

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
plt.style.use('matplotlibrc')

from Python.data_sampler_confounder import *

## Building an MDPDataset

We first read in a large batch of samples from the file. As `d3rlpy` wants it in the form (observations, actions, rewards, terminal flags), we go ahead and do that. Here's a helper function to get a dataset from a list of chunks of your choosing.

In [2]:
def get_dataset(chunks : list, batch_size=30000, 
                path="collected_data/rl_det_small.txt") -> d3rlpy.dataset.MDPDataset :
    random.seed(0)
    samples = DataSampler(path_to_data=path)
    samples.setting("coarse")
    states = []
    actions = []
    rewards = []
    metrics = []
    next_states = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        [statesChunk, actionsChunk, rewardsChunk, nextStatesChunk, metricChunk] = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
        next_states.append(nextStatesChunk)
        metrics.append(metricChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    metrics = np.hstack(metrics)
    
    terminals = np.zeros(len(states))
    terminals[::1111] = 1 #episode length 100, change if necessary
    dataset = d3rlpy.dataset.MDPDataset(np.hstack([states.numpy(), metrics[:,None]]), 
                                        actions.numpy(), 
                                        rewards.numpy(),
                                        terminals)
    return dataset, states.numpy(), actions.numpy(), rewards.numpy(), metrics

We can build the dataset from there, just like this, and split into train and test sets.

In [3]:
dataset, statesOrig, actions, rewards, metrics = get_dataset([i for i in range(1000)], path="collected_data/rl_stochpid.txt")

start
[ 0.00000000e+00  7.95731469e+08 -4.75891077e-02 -3.69999953e-02
  2.00999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.50429671e-01 -4.92727243e-01 -5.31666025e-03 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 1 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.25610892e-01 -3.35999953e-02
 -2.42000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.08749986e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 2 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.90489108e-01 -5.87999953e-02
 -1.01000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.76979602e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+0

start
[ 0.00000000e+00  7.95731469e+08  2.74610892e-01  1.40000047e-02
  7.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.58596924e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 42 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.29289108e-01 -1.49999953e-02
  1.51999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 43 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -7.03891077e-02 -1.45999953e-02
  2.71999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.45547944e-01 -1.42260606e-01  1.10586690e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e

[ 0.00000000e+00  7.95731469e+08  3.81410892e-01  4.20000469e-03
  2.96999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 83 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.15010892e-01  2.00000469e-03
 -2.16000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.16022107e-01  3.75612115e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 84 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.51610892e-01 -2.79999953e-02
  7.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.22183941e-03  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0

[ 0.00000000e+00  7.95731469e+08  2.81510892e-01 -5.59999531e-03
  1.82999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.78617263e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 126 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.38089108e-01 -5.87999953e-02
 -2.28000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.05010402e-01 -4.22664253e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 127 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.07889108e-01  2.08000047e-02
  2.89999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

Read chunk # 169 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.08810892e-01 -2.45999953e-02
 -6.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.40229373e-02  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 170 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.28010892e-01  2.18000047e-02
 -5.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.97408545e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 171 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.64110892e-01  3.64000047e-02
  2.99999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.62531137e-01  4.43218747e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+

start
[ 0.00000000e+00  7.95731469e+08 -1.25589108e-01 -3.99999531e-03
 -2.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -1.77044027e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 211 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.43810892e-01  5.14000047e-02
 -1.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.70996225e-01  6.00000000e-01 -5.87902697e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 212 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.34110892e-01 -9.59999531e-03
 -1.16000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  4.88245618e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08 -3.79989108e-01  2.20000469e-03
 -1.09000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.56992418e-01 -2.96852309e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 251 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.74189108e-01  3.20000469e-03
 -3.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.84206661e-01 -4.52868124e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 252 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.58789108e-01  1.86000047e-02
 -2.16000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -1.79889108e-01 -2.13999953e-02
 -8.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.48109576e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 294 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.24108923e-02 -5.33999953e-02
  6.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.81836613e-01  3.10953755e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 295 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.41289108e-01  3.60000047e-02
 -1.40001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.27789575e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08  2.86108923e-02 -4.67999953e-02
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.45477645e-02  1.85831962e-01 -3.48464679e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 336 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.39589108e-01  5.60000047e-02
 -8.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 337 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.50310892e-01  6.00004692e-04
  1.55999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.71425846e-01  5.62564812e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08 -4.60891077e-02 -3.75999953e-02
  2.17999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.26744318e-01 -1.77785343e-01  2.32142424e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 379 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.51410892e-01  1.16000047e-02
 -6.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.11237961e-02  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 380 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.17410892e-01  5.96000047e-02
 -1.10000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.80355676e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08  4.24108923e-02 -2.53999953e-02
 -2.47000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.96936582e-01  4.32361244e-01 -5.72850429e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 421 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.13310892e-01 -2.69999953e-02
 -5.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.68944324e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 422 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.35410892e-01  4.48000047e-02
  2.36999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.39982990e-01  3.28420700e-01 -4.44519591e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08  9.79108923e-02 -3.07999953e-02
 -8.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.30596921e-01  3.81428057e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 463 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.10892292e-04  3.64000047e-02
 -5.00001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.64499871e-01  1.94654762e-01  2.53659250e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 464 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.07510892e-01 -9.99995308e-04
  2.90999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08  4.27610892e-01 -2.69999953e-02
 -2.42000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.47210887e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 506 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.92108923e-02  5.98000047e-02
 -1.50000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.86552443e-01 -3.30897196e-02  3.81565194e-02 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 507 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.71489108e-01 -3.63999953e-02
 -1.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.50002015e-02 -4.59228413e-01  1.85328784e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08  3.06410892e-01 -2.75999953e-02
 -2.58000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 546 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.93689108e-01 -4.61999953e-02
 -2.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -9.38454379e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 547 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.79310892e-01  5.60000047e-02
  2.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.77325909e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -3.61389108e-01  3.10000047e-02
  2.16999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.59282033e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 588 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.35410892e-01 -3.61999953e-02
 -3.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.60741722e-02  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 589 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.02108923e-02  4.70000047e-02
  1.75999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.48938817e-01  4.91111161e-02 -1.57500779e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -3.10989108e-01 -2.43999953e-02
  2.71999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 630 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.21789108e-01 -2.17999953e-02
 -2.92000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -3.82416703e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 631 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.38289108e-01 -5.57999953e-02
 -1.59000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.62532413e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08 -1.38489108e-01 -2.81999953e-02
  4.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.13255921e-01 -4.32566674e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 672 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.55389108e-01  6.80000469e-03
  1.46999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.32480300e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 673 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.64510892e-01  3.54000047e-02
 -2.15000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -9.66830564e-02  6.00000000e-01  8.48428897e-03 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08 -1.19889108e-01 -3.89999953e-02
 -2.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.47814416e-02  4.88277814e-02  2.15746351e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 713 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.06010892e-01  2.40000469e-03
 -1.79000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.75113582e-01  4.91313101e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 714 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.34610892e-01  8.00004692e-04
  1.94999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  5.90440929e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -9.43891077e-02 -2.07999953e-02
  1.10999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.62169314e-01  5.48901001e-03  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 755 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.91089108e-01 -1.59999953e-02
  1.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.11554359e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 756 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -5.17891077e-02  1.26000047e-02
 -2.89000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  4.16364354e-04  1.35081714e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

start
[ 0.00000000e+00  7.95731469e+08 -8.92891077e-02 -3.59999531e-03
 -1.07000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.36854482e-01 -3.88110213e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 797 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.34110892e-01 -3.27999953e-02
 -2.75000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  5.45336489e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 798 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.32010892e-01  4.46000047e-02
  2.64999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.10023985e-01  4.11402961e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

start
[ 0.00000000e+00  7.95731469e+08  4.43210892e-01 -3.45999953e-02
 -1.11000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.98925499e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 840 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.85891077e-02 -3.55999953e-02
 -1.07000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.92845363e-02  2.31902228e-02  3.64467207e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 841 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.02710892e-01 -1.31999953e-02
 -2.19000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

[ 0.00000000e+00  7.95731469e+08 -4.06089108e-01 -5.17999953e-02
 -1.20000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.50229477e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 881 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.44289108e-01  2.46000047e-02
  1.88999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.25252275e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 882 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.19810892e-01  2.10000047e-02
  1.23999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.07509890e-02  3.81243688e-01 -5.66420114e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -5.59891077e-02 -5.01999953e-02
  2.85999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -3.95459471e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 921 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.01289108e-01  3.82000047e-02
 -1.08000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.36003757e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 922 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.61789108e-01  3.98000047e-02
  1.17999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.43758421e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

[ 0.00000000e+00  7.95731469e+08 -1.31889108e-01  4.84000047e-02
  2.36999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -4.34342001e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 962 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.24189108e-01  1.02000047e-02
  2.12999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 963 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.17310892e-01 -1.63999953e-02
 -5.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.05230079e-01  6.00000000e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06 

In [4]:
print("The behavior policy value statistics are:")
dataset.compute_stats()['return']

The behavior policy value statistics are:


{'mean': -128.34366,
 'std': 83.63828,
 'min': -410.7058,
 'max': 0.0,
 'histogram': (array([  4,   2,  11,  11,  22,  18,  25,  33,  39,  22,  22,  41,  73,
          76, 103, 118, 146, 179,  54,   1]),
  array([-410.7058  , -390.17053 , -369.63522 , -349.09995 , -328.56464 ,
         -308.02936 , -287.49408 , -266.95877 , -246.4235  , -225.8882  ,
         -205.3529  , -184.81761 , -164.28232 , -143.74704 , -123.21175 ,
         -102.67645 ,  -82.14116 ,  -61.605873,  -41.07058 ,  -20.53529 ,
            0.      ], dtype=float32))}

In [5]:
# plt.plot(states[:,6])

In [6]:
# plt.plot(actions)

In [7]:
#plt.plot(model.predict(np.array(states)))

In [8]:
from sklearn.model_selection import train_test_split
train_episodes, test_episodes = train_test_split(dataset, test_size=5)

## Setting up an Algorithm

In [9]:
from d3rlpy.algos import CQL
from d3rlpy.models.encoders import VectorEncoderFactory

from d3rlpy.preprocessing import MinMaxActionScaler
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)
#cql = CQL(action_scaler=action_scaler)

#actor_encoder = VectorEncoderFactory(hidden_units=[12, 24, 36, 24, 12],
#                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)
#critic_encoder = VectorEncoderFactory(hidden_units=[12, 24, 24, 12],
#                                       activation='relu', use_batch_norm=True, dropout_rate=0.2)

model = CQL(q_func_factory='mean', #qr -> quantile regression q function, but you don't have to use this
            reward_scaler='standard',
            action_scaler=action_scaler,
          actor_learning_rate=1e-5, 
          critic_learning_rate=0.0003, 
            use_gpu=False) #change it to true if you have one
model.build_with_dataset(dataset)

In [10]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer

# calculate metrics with test dataset
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

0.036240149109481694


In [11]:
%load_ext tensorboard
%tensorboard --logdir runs

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
Traceback (most recent call last):
  File "/home/dasc/anaconda3/envs/jbreeden3.10/bin/tensorboard", line 6, in <module>
    from tensorboard.main import run_main
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/main.py", line 40, in <module>
    from tensorboard import default
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/default.py", line 38, in <module>
    from tensorboard.plugins.audio import audio_plugin
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugins/audio/audio_plugin.py", line 25, in <module>
    from tensorboard import plugin_util
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugin_util.py", line 21, in <module>
    from tensorboard._vendor import bleach
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorb

In [12]:
model.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=20, 
        tensorboard_dir='runs',
        scorers={
            'td_error': td_error_scorer,
            'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer
        })

2022-04-19 22:11.02 [debug    ] RoundIterator is selected.
2022-04-19 22:11.02 [info     ] Directory is created at d3rlpy_logs/CQL_20220419221102
2022-04-19 22:11.02 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-19 22:11.02 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-19 22:11.02 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220419221102/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conser

Epoch 1/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:17.55 [info     ] CQL_20220419221102: epoch=1 step=4309 epoch=1 metrics={'time_sample_batch': 0.0002942477005157152, 'time_algorithm_update': 0.0949602771120742, 'temp_loss': 3.3563365883162137, 'temp': 0.8274853446694421, 'alpha_loss': -5.3284055914767325, 'alpha': 1.1596206161311342, 'critic_loss': 14.556474138008909, 'actor_loss': 14.42031031683473, 'time_step': 0.09546220493471576, 'td_error': 42.75731813645616, 'init_value': -40.43320083618164, 'ave_value': -42.613514377727164} step=4309
2022-04-19 22:17.55 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_4309.pt


Epoch 2/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:24.31 [info     ] CQL_20220419221102: epoch=2 step=8618 epoch=2 metrics={'time_sample_batch': 0.00029686355458027414, 'time_algorithm_update': 0.09109495576392418, 'temp_loss': 1.433966846913134, 'temp': 0.5810934323104249, 'alpha_loss': 0.8743829604048756, 'alpha': 1.2082513440205842, 'critic_loss': 58.028899153248844, 'actor_loss': 59.137008949841054, 'time_step': 0.09161658169747948, 'td_error': 106.87199282273092, 'init_value': -89.94698333740234, 'ave_value': -92.61096563358834} step=8618
2022-04-19 22:24.31 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_8618.pt


Epoch 3/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:31.30 [info     ] CQL_20220419221102: epoch=3 step=12927 epoch=3 metrics={'time_sample_batch': 0.0003005738434312404, 'time_algorithm_update': 0.09634248555799349, 'temp_loss': 0.749083649471703, 'temp': 0.4081343829106167, 'alpha_loss': 0.650248865760611, 'alpha': 0.9301082586766285, 'critic_loss': 146.9632348514649, 'actor_loss': 100.72336047020498, 'time_step': 0.09687018062487975, 'td_error': 193.55068503932813, 'init_value': -131.15982055664062, 'ave_value': -136.00240695253217} step=12927
2022-04-19 22:31.30 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_12927.pt


Epoch 4/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:39.49 [info     ] CQL_20220419221102: epoch=4 step=17236 epoch=4 metrics={'time_sample_batch': 0.0003109431427242751, 'time_algorithm_update': 0.11488491555437684, 'temp_loss': 0.39013049914917663, 'temp': 0.2848145050089828, 'alpha_loss': 0.2087285240038324, 'alpha': 0.8030027640089963, 'critic_loss': 240.0664827328432, 'actor_loss': 134.08949815206114, 'time_step': 0.11542637156178653, 'td_error': 261.57004429956646, 'init_value': -156.93736267089844, 'ave_value': -168.5158550030906} step=17236
2022-04-19 22:39.49 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_17236.pt


Epoch 5/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:48.47 [info     ] CQL_20220419221102: epoch=5 step=21545 epoch=5 metrics={'time_sample_batch': 0.00031468917499647586, 'time_algorithm_update': 0.1236955491537615, 'temp_loss': 0.17618451108762648, 'temp': 0.20046993068543242, 'alpha_loss': -0.08472264432545398, 'alpha': 0.7743930268188064, 'critic_loss': 282.5280676681499, 'actor_loss': 156.27328294146278, 'time_step': 0.12424190970080223, 'td_error': 252.4103402473303, 'init_value': -170.63327026367188, 'ave_value': -182.52380850228909} step=21545
2022-04-19 22:48.47 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_21545.pt


Epoch 6/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 22:57.45 [info     ] CQL_20220419221102: epoch=6 step=25854 epoch=6 metrics={'time_sample_batch': 0.00031902801688958495, 'time_algorithm_update': 0.12379728610733043, 'temp_loss': -0.004596691555203081, 'temp': 0.1719077514589566, 'alpha_loss': -0.17038375066212785, 'alpha': 0.8714529377010423, 'critic_loss': 236.80534115554727, 'actor_loss': 158.33001848682503, 'time_step': 0.12435068777637212, 'td_error': 198.5880509401033, 'init_value': -165.71743774414062, 'ave_value': -172.61232575403275} step=25854
2022-04-19 22:57.45 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_25854.pt


Epoch 7/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:06.54 [info     ] CQL_20220419221102: epoch=7 step=30163 epoch=7 metrics={'time_sample_batch': 0.0003180600120064609, 'time_algorithm_update': 0.12633513673935945, 'temp_loss': -0.017925693101976278, 'temp': 0.21080256721887103, 'alpha_loss': 0.10270508740004454, 'alpha': 0.8559894490391396, 'critic_loss': 186.372580247181, 'actor_loss': 152.10947347921183, 'time_step': 0.12688564020344098, 'td_error': 161.79219923503183, 'init_value': -160.05250549316406, 'ave_value': -164.16293524019352} step=30163
2022-04-19 23:06.54 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_30163.pt


Epoch 8/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:15.52 [info     ] CQL_20220419221102: epoch=8 step=34472 epoch=8 metrics={'time_sample_batch': 0.000289961588468483, 'time_algorithm_update': 0.12377106222151382, 'temp_loss': -0.005759977841998125, 'temp': 0.23127082938812094, 'alpha_loss': 0.22606762075897036, 'alpha': 0.7972883741743301, 'critic_loss': 155.25799165902987, 'actor_loss': 144.18125360630094, 'time_step': 0.12427494536422194, 'td_error': 138.43777733396627, 'init_value': -152.90280151367188, 'ave_value': -154.5705534830072} step=34472
2022-04-19 23:15.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_34472.pt


Epoch 9/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:24.25 [info     ] CQL_20220419221102: epoch=9 step=38781 epoch=9 metrics={'time_sample_batch': 0.0002788225883506904, 'time_algorithm_update': 0.11802902656576422, 'temp_loss': 0.007313199625486699, 'temp': 0.23281729777760063, 'alpha_loss': 0.14905008324245658, 'alpha': 0.7144407852671489, 'critic_loss': 133.52626686744662, 'actor_loss': 135.68263947191855, 'time_step': 0.11851645982246106, 'td_error': 118.34011759411678, 'init_value': -146.5259246826172, 'ave_value': -144.604932321353} step=38781
2022-04-19 23:24.25 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_38781.pt


Epoch 10/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:32.49 [info     ] CQL_20220419221102: epoch=10 step=43090 epoch=10 metrics={'time_sample_batch': 0.00027853392979251617, 'time_algorithm_update': 0.11615104331138046, 'temp_loss': 0.009121860412896015, 'temp': 0.2110497505441014, 'alpha_loss': 0.1223275434449392, 'alpha': 0.661769500531794, 'critic_loss': 118.33193561227948, 'actor_loss': 126.44100323241281, 'time_step': 0.11663800548527793, 'td_error': 105.96753625339002, 'init_value': -139.5361785888672, 'ave_value': -134.54899241407847} step=43090
2022-04-19 23:32.49 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_43090.pt


Epoch 11/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:40.19 [info     ] CQL_20220419221102: epoch=11 step=47399 epoch=11 metrics={'time_sample_batch': 0.0002738023529088146, 'time_algorithm_update': 0.10345847564054549, 'temp_loss': 0.008475536929335304, 'temp': 0.19362516949170702, 'alpha_loss': 0.12927227651405057, 'alpha': 0.6182584137893327, 'critic_loss': 106.1702196261921, 'actor_loss': 115.99569073772453, 'time_step': 0.10393995291452503, 'td_error': 96.32544417943325, 'init_value': -130.2510528564453, 'ave_value': -123.55505539042456} step=47399
2022-04-19 23:40.19 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_47399.pt


Epoch 12/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:47.21 [info     ] CQL_20220419221102: epoch=12 step=51708 epoch=12 metrics={'time_sample_batch': 0.000271410808178133, 'time_algorithm_update': 0.09735728660241776, 'temp_loss': 0.005181032489167375, 'temp': 0.17666317732781206, 'alpha_loss': 0.16099750202322236, 'alpha': 0.5571298645582352, 'critic_loss': 95.83278830368465, 'actor_loss': 104.9805906627316, 'time_step': 0.09783677408550809, 'td_error': 86.06879753195109, 'init_value': -118.01374816894531, 'ave_value': -111.48864969556128} step=51708
2022-04-19 23:47.21 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_51708.pt


Epoch 13/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-19 23:53.53 [info     ] CQL_20220419221102: epoch=13 step=56017 epoch=13 metrics={'time_sample_batch': 0.00026742984314035665, 'time_algorithm_update': 0.09007087537323302, 'temp_loss': 0.0016318530144548594, 'temp': 0.17279985630752925, 'alpha_loss': 0.20061577807912073, 'alpha': 0.48085259131777724, 'critic_loss': 85.32960751931445, 'actor_loss': 93.63471656333655, 'time_step': 0.0905451445381446, 'td_error': 77.19506397297816, 'init_value': -107.0666275024414, 'ave_value': -99.88611851515265} step=56017
2022-04-19 23:53.53 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_56017.pt


Epoch 14/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:00.29 [info     ] CQL_20220419221102: epoch=14 step=60326 epoch=14 metrics={'time_sample_batch': 0.00026590482738793217, 'time_algorithm_update': 0.09130026055311874, 'temp_loss': 0.0035376046392694895, 'temp': 0.1671432684402395, 'alpha_loss': 0.11052481622281572, 'alpha': 0.4139938001415242, 'critic_loss': 75.44673093534395, 'actor_loss': 82.71799307783564, 'time_step': 0.09177315901868939, 'td_error': 68.67558767531212, 'init_value': -95.81947326660156, 'ave_value': -88.19525177018801} step=60326
2022-04-20 00:00.29 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_60326.pt


Epoch 15/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:06.58 [info     ] CQL_20220419221102: epoch=15 step=64635 epoch=15 metrics={'time_sample_batch': 0.0002666952771027256, 'time_algorithm_update': 0.0894916065886226, 'temp_loss': 0.0014608436022216424, 'temp': 0.16087286907078965, 'alpha_loss': 0.04240428444321001, 'alpha': 0.3821500287403997, 'critic_loss': 66.284584440718, 'actor_loss': 72.8176559258791, 'time_step': 0.0899654419634065, 'td_error': 61.055948395236726, 'init_value': -85.58955383300781, 'ave_value': -77.41612184125024} step=64635
2022-04-20 00:06.58 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_64635.pt


Epoch 16/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:13.27 [info     ] CQL_20220419221102: epoch=16 step=68944 epoch=16 metrics={'time_sample_batch': 0.0002671359835270802, 'time_algorithm_update': 0.08951672065780898, 'temp_loss': 0.0014139697121347767, 'temp': 0.15803881152776145, 'alpha_loss': 0.02480023988895907, 'alpha': 0.3722928632824513, 'critic_loss': 59.65695867021943, 'actor_loss': 63.787340154889066, 'time_step': 0.08999228593671811, 'td_error': 55.27025692120207, 'init_value': -77.82819366455078, 'ave_value': -69.48831853793817} step=68944
2022-04-20 00:13.27 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_68944.pt


Epoch 17/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:20.03 [info     ] CQL_20220419221102: epoch=17 step=73253 epoch=17 metrics={'time_sample_batch': 0.00026828740859811837, 'time_algorithm_update': 0.09110659284806356, 'temp_loss': -0.0005183148424473523, 'temp': 0.15847869600904432, 'alpha_loss': 0.037226675149273775, 'alpha': 0.3562882154605925, 'critic_loss': 54.49308914536803, 'actor_loss': 55.96208312157883, 'time_step': 0.09158216958036, 'td_error': 50.78203982574372, 'init_value': -71.2973861694336, 'ave_value': -61.40176129468658} step=73253
2022-04-20 00:20.03 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_73253.pt


Epoch 18/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:26.34 [info     ] CQL_20220419221102: epoch=18 step=77562 epoch=18 metrics={'time_sample_batch': 0.00026656043698268137, 'time_algorithm_update': 0.09003088208134585, 'temp_loss': -0.0007079859010183163, 'temp': 0.1591937603412966, 'alpha_loss': 0.03546344073660514, 'alpha': 0.3395906947665548, 'critic_loss': 50.38716014073658, 'actor_loss': 49.12737749435698, 'time_step': 0.09050565447600377, 'td_error': 47.37245499388279, 'init_value': -65.65174865722656, 'ave_value': -54.84566032555055} step=77562
2022-04-20 00:26.34 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_77562.pt


Epoch 19/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:33.00 [info     ] CQL_20220419221102: epoch=19 step=81871 epoch=19 metrics={'time_sample_batch': 0.0002866302020147522, 'time_algorithm_update': 0.08872183106786055, 'temp_loss': 0.001319684875674825, 'temp': 0.15800843214416038, 'alpha_loss': 0.03141124635772278, 'alpha': 0.3235924220688245, 'critic_loss': 47.11704555173904, 'actor_loss': 43.250848152211994, 'time_step': 0.08922901107086939, 'td_error': 44.971144851405676, 'init_value': -60.14690399169922, 'ave_value': -49.00401668789419} step=81871
2022-04-20 00:33.00 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_81871.pt


Epoch 20/20:   0%|          | 0/4309 [00:00<?, ?it/s]

2022-04-20 00:39.24 [info     ] CQL_20220419221102: epoch=20 step=86180 epoch=20 metrics={'time_sample_batch': 0.00029989416512404994, 'time_algorithm_update': 0.08839562365836681, 'temp_loss': 0.0002608095024154541, 'temp': 0.15719266224762324, 'alpha_loss': 0.024233677144664956, 'alpha': 0.307824973510628, 'critic_loss': 45.30775872274739, 'actor_loss': 38.40102719209343, 'time_step': 0.08892502378604063, 'td_error': 43.14962546233025, 'init_value': -56.522430419921875, 'ave_value': -44.782660068351404} step=86180
2022-04-20 00:39.24 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220419221102/model_86180.pt


[(1,
  {'time_sample_batch': 0.0002942477005157152,
   'time_algorithm_update': 0.0949602771120742,
   'temp_loss': 3.3563365883162137,
   'temp': 0.8274853446694421,
   'alpha_loss': -5.3284055914767325,
   'alpha': 1.1596206161311342,
   'critic_loss': 14.556474138008909,
   'actor_loss': 14.42031031683473,
   'time_step': 0.09546220493471576,
   'td_error': 42.75731813645616,
   'init_value': -40.43320083618164,
   'ave_value': -42.613514377727164}),
 (2,
  {'time_sample_batch': 0.00029686355458027414,
   'time_algorithm_update': 0.09109495576392418,
   'temp_loss': 1.433966846913134,
   'temp': 0.5810934323104249,
   'alpha_loss': 0.8743829604048756,
   'alpha': 1.2082513440205842,
   'critic_loss': 58.028899153248844,
   'actor_loss': 59.137008949841054,
   'time_step': 0.09161658169747948,
   'td_error': 106.87199282273092,
   'init_value': -89.94698333740234,
   'ave_value': -92.61096563358834}),
 (3,
  {'time_sample_batch': 0.0003005738434312404,
   'time_algorithm_update': 0.0

## Off-Policy Evaluation

We do get some metrics on a test set of initial state value and average value. However, these estimates (using the critic's Q-function) of model performance are biased. They're useful for validation during training, but not much else. Instead, we fit a Q-function to the data (or a separate dataset, as I've done here) separately and evaluate the model's performance on it.

Feel free to change the chunks and number of steps.

In [13]:
from d3rlpy.ope import FQE
# metrics to evaluate with
from d3rlpy.metrics.scorer import soft_opc_scorer


ope_dataset, states_ope, actions_ope, rewards_ope, metrics_ope = get_dataset([500+i for i in range(50)], path="collected_data/rl_stochpid.txt") #change if you'd prefer different chunks
ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=5)

fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
fqe.fit(ope_train_episodes,
        eval_episodes=ope_test_episodes,
        tensorboard_dir='runs',
        n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
        scorers={
           'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer,
           'soft_opc': soft_opc_scorer(return_threshold=0)
        })

start
[ 0.00000000e+00  7.95731469e+08 -4.19889108e-01 -1.83999953e-02
  2.71999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 501 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.30889108e-01 -2.27999953e-02
 -3.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.19275575e-02 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 502 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.99289108e-01 -3.69999953e-02
  2.18999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.02595006e-01 -6.00000000e-01  6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

start
[ 0.00000000e+00  7.95731469e+08  4.78108923e-02 -2.91999953e-02
  1.03999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.09695008e-02 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 542 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.13810892e-01  4.16000047e-02
  2.69999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.26782935e-01 -6.00000000e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.77814500e+06  0.00000000e+00
  0.00000000e+00  5.33423489e+00 -6.44487563e-06 -2.76958740e-05
 -7.93993673e-06]
Read chunk # 543 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.69891077e-02  1.12000047e-02
  2.97999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.70870777e-01 -2.76326953e-01  5.89576724e-01 -1.48980673e+11
  2.13971128e+07  0.00000000e+00  6.7781450

Epoch 1/50:   0%|          | 0/488 [00:00<?, ?it/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


2022-04-20 00:39.26 [info     ] FQE_20220420003924: epoch=1 step=488 epoch=1 metrics={'time_sample_batch': 0.00012289451771095152, 'time_algorithm_update': 0.0027904974632575862, 'loss': 0.006417373936722574, 'time_step': 0.002968227277036573, 'init_value': -0.5982769727706909, 'ave_value': -0.43931858867675333, 'soft_opc': nan} step=488




2022-04-20 00:39.26 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_488.pt


Epoch 2/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.28 [info     ] FQE_20220420003924: epoch=2 step=976 epoch=2 metrics={'time_sample_batch': 0.00012487124224178127, 'time_algorithm_update': 0.0029631032318365377, 'loss': 0.005279628622925795, 'time_step': 0.0031398119496517493, 'init_value': -0.7484461069107056, 'ave_value': -0.5154725129056621, 'soft_opc': nan} step=976




2022-04-20 00:39.28 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_976.pt


Epoch 3/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.29 [info     ] FQE_20220420003924: epoch=3 step=1464 epoch=3 metrics={'time_sample_batch': 0.00012738050007429278, 'time_algorithm_update': 0.003082964263978552, 'loss': 0.004364420129530147, 'time_step': 0.0032620928326591117, 'init_value': -0.8721564412117004, 'ave_value': -0.5933400834492735, 'soft_opc': nan} step=1464




2022-04-20 00:39.29 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_1464.pt


Epoch 4/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.31 [info     ] FQE_20220420003924: epoch=4 step=1952 epoch=4 metrics={'time_sample_batch': 0.00013102371184552303, 'time_algorithm_update': 0.0031307862430322367, 'loss': 0.003956772288504118, 'time_step': 0.0033208200188933825, 'init_value': -0.983167827129364, 'ave_value': -0.6887489480639363, 'soft_opc': nan} step=1952




2022-04-20 00:39.31 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_1952.pt


Epoch 5/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.33 [info     ] FQE_20220420003924: epoch=5 step=2440 epoch=5 metrics={'time_sample_batch': 0.00013042277977114817, 'time_algorithm_update': 0.0031579337159141165, 'loss': 0.004161490146051635, 'time_step': 0.0033448372707992305, 'init_value': -1.153977870941162, 'ave_value': -0.8386571701126055, 'soft_opc': nan} step=2440




2022-04-20 00:39.33 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_2440.pt


Epoch 6/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.34 [info     ] FQE_20220420003924: epoch=6 step=2928 epoch=6 metrics={'time_sample_batch': 0.00012741567658596352, 'time_algorithm_update': 0.0029806670595387942, 'loss': 0.004999283205700626, 'time_step': 0.0031634134347321556, 'init_value': -1.373203992843628, 'ave_value': -1.0266364138271358, 'soft_opc': nan} step=2928




2022-04-20 00:39.34 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_2928.pt


Epoch 7/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.36 [info     ] FQE_20220420003924: epoch=7 step=3416 epoch=7 metrics={'time_sample_batch': 0.00014430919631582792, 'time_algorithm_update': 0.0035258653711100095, 'loss': 0.006157608581234541, 'time_step': 0.0037346970839578597, 'init_value': -1.5891978740692139, 'ave_value': -1.203451210974573, 'soft_opc': nan} step=3416




2022-04-20 00:39.36 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_3416.pt


Epoch 8/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.38 [info     ] FQE_20220420003924: epoch=8 step=3904 epoch=8 metrics={'time_sample_batch': 0.0001422855697694372, 'time_algorithm_update': 0.0034551884307235966, 'loss': 0.007524814536378979, 'time_step': 0.0036637250517235426, 'init_value': -1.8449022769927979, 'ave_value': -1.4208285344882055, 'soft_opc': nan} step=3904




2022-04-20 00:39.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_3904.pt


Epoch 9/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.40 [info     ] FQE_20220420003924: epoch=9 step=4392 epoch=9 metrics={'time_sample_batch': 0.0001346405412330002, 'time_algorithm_update': 0.0033497092176656253, 'loss': 0.008831948822471083, 'time_step': 0.00354531358499996, 'init_value': -1.9440587759017944, 'ave_value': -1.4518538788820172, 'soft_opc': nan} step=4392




2022-04-20 00:39.40 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_4392.pt


Epoch 10/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.42 [info     ] FQE_20220420003924: epoch=10 step=4880 epoch=10 metrics={'time_sample_batch': 0.00012271228383799068, 'time_algorithm_update': 0.00301492947046874, 'loss': 0.009790376015293763, 'time_step': 0.003193025706244297, 'init_value': -2.024367332458496, 'ave_value': -1.5473622371484568, 'soft_opc': nan} step=4880




2022-04-20 00:39.42 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_4880.pt


Epoch 11/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.43 [info     ] FQE_20220420003924: epoch=11 step=5368 epoch=11 metrics={'time_sample_batch': 0.0001337953278275787, 'time_algorithm_update': 0.0032646387326912803, 'loss': 0.011848693605002824, 'time_step': 0.003457534020064307, 'init_value': -2.3174424171447754, 'ave_value': -1.7441613090736372, 'soft_opc': nan} step=5368




2022-04-20 00:39.43 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_5368.pt


Epoch 12/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.45 [info     ] FQE_20220420003924: epoch=12 step=5856 epoch=12 metrics={'time_sample_batch': 0.00013716201313206407, 'time_algorithm_update': 0.003342354395350472, 'loss': 0.01320765173223946, 'time_step': 0.003540106972710031, 'init_value': -2.54571270942688, 'ave_value': -1.9141415060667304, 'soft_opc': nan} step=5856




2022-04-20 00:39.45 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_5856.pt


Epoch 13/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.47 [info     ] FQE_20220420003924: epoch=13 step=6344 epoch=13 metrics={'time_sample_batch': 0.00013066803822751904, 'time_algorithm_update': 0.0031198395080253728, 'loss': 0.015324753359891474, 'time_step': 0.0033090656898060785, 'init_value': -2.7945895195007324, 'ave_value': -2.1522641044640327, 'soft_opc': nan} step=6344




2022-04-20 00:39.47 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_6344.pt


Epoch 14/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.49 [info     ] FQE_20220420003924: epoch=14 step=6832 epoch=14 metrics={'time_sample_batch': 0.00013119177740128314, 'time_algorithm_update': 0.0032259230730963535, 'loss': 0.017214422466302077, 'time_step': 0.0034163858069748173, 'init_value': -3.0538666248321533, 'ave_value': -2.3805908503618327, 'soft_opc': nan} step=6832




2022-04-20 00:39.49 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_6832.pt


Epoch 15/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.50 [info     ] FQE_20220420003924: epoch=15 step=7320 epoch=15 metrics={'time_sample_batch': 0.00013633878504643675, 'time_algorithm_update': 0.003261429364563989, 'loss': 0.019388009240606525, 'time_step': 0.003454603132654409, 'init_value': -3.2629752159118652, 'ave_value': -2.5689745998704754, 'soft_opc': nan} step=7320




2022-04-20 00:39.50 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_7320.pt


Epoch 16/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.52 [info     ] FQE_20220420003924: epoch=16 step=7808 epoch=16 metrics={'time_sample_batch': 0.00012590846077340548, 'time_algorithm_update': 0.003084328330931116, 'loss': 0.02038884438339575, 'time_step': 0.003268139772727841, 'init_value': -3.348785877227783, 'ave_value': -2.628985432794502, 'soft_opc': nan} step=7808




2022-04-20 00:39.52 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_7808.pt


Epoch 17/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.54 [info     ] FQE_20220420003924: epoch=17 step=8296 epoch=17 metrics={'time_sample_batch': 0.00013634367067305768, 'time_algorithm_update': 0.0033837713179041126, 'loss': 0.022147912566443204, 'time_step': 0.0035795080857198747, 'init_value': -3.4434609413146973, 'ave_value': -2.7030916678126866, 'soft_opc': nan} step=8296




2022-04-20 00:39.54 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_8296.pt


Epoch 18/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.55 [info     ] FQE_20220420003924: epoch=18 step=8784 epoch=18 metrics={'time_sample_batch': 0.00012732529249347624, 'time_algorithm_update': 0.0031359576788104948, 'loss': 0.02448768632249815, 'time_step': 0.003319314268768811, 'init_value': -3.6947708129882812, 'ave_value': -2.9765469510651923, 'soft_opc': nan} step=8784




2022-04-20 00:39.55 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_8784.pt


Epoch 19/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.57 [info     ] FQE_20220420003924: epoch=19 step=9272 epoch=19 metrics={'time_sample_batch': 0.00014022530102338946, 'time_algorithm_update': 0.0034393296867120463, 'loss': 0.025644345533438636, 'time_step': 0.003643076927935491, 'init_value': -3.6932075023651123, 'ave_value': -2.994259896493173, 'soft_opc': nan} step=9272




2022-04-20 00:39.57 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_9272.pt


Epoch 20/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:39.59 [info     ] FQE_20220420003924: epoch=20 step=9760 epoch=20 metrics={'time_sample_batch': 0.00013383538996587034, 'time_algorithm_update': 0.0032492152980116546, 'loss': 0.027254946861219532, 'time_step': 0.003442578139852305, 'init_value': -3.8820242881774902, 'ave_value': -3.1814473303534965, 'soft_opc': nan} step=9760




2022-04-20 00:39.59 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_9760.pt


Epoch 21/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.01 [info     ] FQE_20220420003924: epoch=21 step=10248 epoch=21 metrics={'time_sample_batch': 0.00013568997383117676, 'time_algorithm_update': 0.003350267644788398, 'loss': 0.027967500375377655, 'time_step': 0.00354421431901025, 'init_value': -3.8817286491394043, 'ave_value': -3.3098189225873433, 'soft_opc': nan} step=10248




2022-04-20 00:40.01 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_10248.pt


Epoch 22/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.03 [info     ] FQE_20220420003924: epoch=22 step=10736 epoch=22 metrics={'time_sample_batch': 0.00015444149736498224, 'time_algorithm_update': 0.003778496238051868, 'loss': 0.029363211961682564, 'time_step': 0.00400198043369856, 'init_value': -4.0058794021606445, 'ave_value': -3.4362077859070923, 'soft_opc': nan} step=10736




2022-04-20 00:40.03 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_10736.pt


Epoch 23/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.05 [info     ] FQE_20220420003924: epoch=23 step=11224 epoch=23 metrics={'time_sample_batch': 0.00012813239801125447, 'time_algorithm_update': 0.003098508373635714, 'loss': 0.030118491618224174, 'time_step': 0.003283233916173216, 'init_value': -4.065986156463623, 'ave_value': -3.611926514076757, 'soft_opc': nan} step=11224




2022-04-20 00:40.05 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_11224.pt


Epoch 24/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.07 [info     ] FQE_20220420003924: epoch=24 step=11712 epoch=24 metrics={'time_sample_batch': 0.00014258310443065206, 'time_algorithm_update': 0.0035002817873094902, 'loss': 0.031683429674632835, 'time_step': 0.0037075852761503125, 'init_value': -4.153302192687988, 'ave_value': -3.648630090914331, 'soft_opc': nan} step=11712




2022-04-20 00:40.07 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_11712.pt


Epoch 25/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.08 [info     ] FQE_20220420003924: epoch=25 step=12200 epoch=25 metrics={'time_sample_batch': 0.00013553265665398268, 'time_algorithm_update': 0.0033124924683180013, 'loss': 0.03221388939208901, 'time_step': 0.0035091980558926944, 'init_value': -4.2006402015686035, 'ave_value': -3.802707265824885, 'soft_opc': nan} step=12200




2022-04-20 00:40.08 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_12200.pt


Epoch 26/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.10 [info     ] FQE_20220420003924: epoch=26 step=12688 epoch=26 metrics={'time_sample_batch': 0.00013553021384067222, 'time_algorithm_update': 0.003337342719562718, 'loss': 0.033268087292347696, 'time_step': 0.0035342911227804717, 'init_value': -4.172314643859863, 'ave_value': -3.87185348237957, 'soft_opc': nan} step=12688




2022-04-20 00:40.10 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_12688.pt


Epoch 27/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.12 [info     ] FQE_20220420003924: epoch=27 step=13176 epoch=27 metrics={'time_sample_batch': 0.000145026406303781, 'time_algorithm_update': 0.003570414957452993, 'loss': 0.033936252757634004, 'time_step': 0.003779176805840164, 'init_value': -4.2653350830078125, 'ave_value': -4.0487124434032955, 'soft_opc': nan} step=13176




2022-04-20 00:40.12 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_13176.pt


Epoch 28/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.14 [info     ] FQE_20220420003924: epoch=28 step=13664 epoch=28 metrics={'time_sample_batch': 0.0001334039891352419, 'time_algorithm_update': 0.0032691613572542783, 'loss': 0.03494575255470289, 'time_step': 0.0034627464951061814, 'init_value': -4.233272552490234, 'ave_value': -4.085520140641444, 'soft_opc': nan} step=13664




2022-04-20 00:40.14 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_13664.pt


Epoch 29/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.16 [info     ] FQE_20220420003924: epoch=29 step=14152 epoch=29 metrics={'time_sample_batch': 0.00012761891865339435, 'time_algorithm_update': 0.0031704560655062314, 'loss': 0.03574185624034488, 'time_step': 0.0033553286654050235, 'init_value': -4.314733982086182, 'ave_value': -4.2874289956178755, 'soft_opc': nan} step=14152




2022-04-20 00:40.16 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_14152.pt


Epoch 30/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.17 [info     ] FQE_20220420003924: epoch=30 step=14640 epoch=30 metrics={'time_sample_batch': 0.00014407859473931984, 'time_algorithm_update': 0.0034528887662731234, 'loss': 0.035811741158154366, 'time_step': 0.0036592385807975394, 'init_value': -4.177708625793457, 'ave_value': -4.301293865367099, 'soft_opc': nan} step=14640




2022-04-20 00:40.18 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_14640.pt


Epoch 31/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.19 [info     ] FQE_20220420003924: epoch=31 step=15128 epoch=31 metrics={'time_sample_batch': 0.0001277034399939365, 'time_algorithm_update': 0.003222054633937898, 'loss': 0.03749301874596, 'time_step': 0.003407031297683716, 'init_value': -4.229691505432129, 'ave_value': -4.537267483483563, 'soft_opc': nan} step=15128




2022-04-20 00:40.19 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_15128.pt


Epoch 32/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.21 [info     ] FQE_20220420003924: epoch=32 step=15616 epoch=32 metrics={'time_sample_batch': 0.00014084724129223434, 'time_algorithm_update': 0.0034241818013738415, 'loss': 0.038104243888080765, 'time_step': 0.0036288240893942413, 'init_value': -4.211088180541992, 'ave_value': -4.606346569050539, 'soft_opc': nan} step=15616




2022-04-20 00:40.21 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_15616.pt


Epoch 33/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.23 [info     ] FQE_20220420003924: epoch=33 step=16104 epoch=33 metrics={'time_sample_batch': 0.00013874397903192238, 'time_algorithm_update': 0.003461291066935805, 'loss': 0.03749363966091521, 'time_step': 0.0036630576751271234, 'init_value': -4.086598873138428, 'ave_value': -4.586143603518202, 'soft_opc': nan} step=16104




2022-04-20 00:40.23 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_16104.pt


Epoch 34/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.25 [info     ] FQE_20220420003924: epoch=34 step=16592 epoch=34 metrics={'time_sample_batch': 0.00012963228538388113, 'time_algorithm_update': 0.0032400777105425225, 'loss': 0.03751879942185459, 'time_step': 0.003429097718879825, 'init_value': -4.2154364585876465, 'ave_value': -4.795114633961841, 'soft_opc': nan} step=16592




2022-04-20 00:40.25 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_16592.pt


Epoch 35/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.27 [info     ] FQE_20220420003924: epoch=35 step=17080 epoch=35 metrics={'time_sample_batch': 0.00014077444545558242, 'time_algorithm_update': 0.0034492582571311074, 'loss': 0.03751981660669296, 'time_step': 0.0036531525557158425, 'init_value': -4.028115749359131, 'ave_value': -4.665661801325308, 'soft_opc': nan} step=17080




2022-04-20 00:40.27 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_17080.pt


Epoch 36/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.29 [info     ] FQE_20220420003924: epoch=36 step=17568 epoch=36 metrics={'time_sample_batch': 0.0001462922721612649, 'time_algorithm_update': 0.0035655503390265293, 'loss': 0.03825420593665378, 'time_step': 0.0037780076753897744, 'init_value': -4.002947807312012, 'ave_value': -4.678927783482783, 'soft_opc': nan} step=17568




2022-04-20 00:40.29 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_17568.pt


Epoch 37/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.30 [info     ] FQE_20220420003924: epoch=37 step=18056 epoch=37 metrics={'time_sample_batch': 0.00014048570492228524, 'time_algorithm_update': 0.0034595586237360218, 'loss': 0.03819517409223796, 'time_step': 0.003661914438497825, 'init_value': -4.079201698303223, 'ave_value': -4.855711725756929, 'soft_opc': nan} step=18056




2022-04-20 00:40.30 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_18056.pt


Epoch 38/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.32 [info     ] FQE_20220420003924: epoch=38 step=18544 epoch=38 metrics={'time_sample_batch': 0.0001389130217130067, 'time_algorithm_update': 0.003409626542544756, 'loss': 0.03988853701224291, 'time_step': 0.0036112929953903447, 'init_value': -4.007801055908203, 'ave_value': -4.806240004805831, 'soft_opc': nan} step=18544




2022-04-20 00:40.32 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_18544.pt


Epoch 39/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.34 [info     ] FQE_20220420003924: epoch=39 step=19032 epoch=39 metrics={'time_sample_batch': 0.0001434268521480873, 'time_algorithm_update': 0.0034699293433642776, 'loss': 0.03945087136335282, 'time_step': 0.003677111180102239, 'init_value': -3.9443111419677734, 'ave_value': -4.824925829472843, 'soft_opc': nan} step=19032




2022-04-20 00:40.34 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_19032.pt


Epoch 40/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.36 [info     ] FQE_20220420003924: epoch=40 step=19520 epoch=40 metrics={'time_sample_batch': 0.00012788665099222152, 'time_algorithm_update': 0.0031498592407976996, 'loss': 0.040562854578802515, 'time_step': 0.0033323124783938046, 'init_value': -3.9190287590026855, 'ave_value': -4.824077691827808, 'soft_opc': nan} step=19520




2022-04-20 00:40.36 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_19520.pt


Epoch 41/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.38 [info     ] FQE_20220420003924: epoch=41 step=20008 epoch=41 metrics={'time_sample_batch': 0.00014985242827993925, 'time_algorithm_update': 0.0036218908966564743, 'loss': 0.040679092541170976, 'time_step': 0.0038381072341418655, 'init_value': -3.9594502449035645, 'ave_value': -4.913341886277672, 'soft_opc': nan} step=20008




2022-04-20 00:40.38 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_20008.pt


Epoch 42/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.40 [info     ] FQE_20220420003924: epoch=42 step=20496 epoch=42 metrics={'time_sample_batch': 0.00013968885922041095, 'time_algorithm_update': 0.0034329573639103623, 'loss': 0.042655397750452524, 'time_step': 0.003636683596939337, 'init_value': -4.011895656585693, 'ave_value': -4.957711175089484, 'soft_opc': nan} step=20496




2022-04-20 00:40.40 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_20496.pt


Epoch 43/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.41 [info     ] FQE_20220420003924: epoch=43 step=20984 epoch=43 metrics={'time_sample_batch': 0.0001327991485595703, 'time_algorithm_update': 0.0032549862001763014, 'loss': 0.043227249286500125, 'time_step': 0.0034463229726572507, 'init_value': -4.102808952331543, 'ave_value': -5.141284259665119, 'soft_opc': nan} step=20984




2022-04-20 00:40.41 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_20984.pt


Epoch 44/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.43 [info     ] FQE_20220420003924: epoch=44 step=21472 epoch=44 metrics={'time_sample_batch': 0.00014689076142232927, 'time_algorithm_update': 0.003626546410263562, 'loss': 0.043502941728115356, 'time_step': 0.003841032747362481, 'init_value': -4.056172847747803, 'ave_value': -5.022783115404146, 'soft_opc': nan} step=21472




2022-04-20 00:40.43 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_21472.pt


Epoch 45/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.45 [info     ] FQE_20220420003924: epoch=45 step=21960 epoch=45 metrics={'time_sample_batch': 0.00012523424429971663, 'time_algorithm_update': 0.0031223985992494176, 'loss': 0.04262225046853123, 'time_step': 0.0033010415366438567, 'init_value': -4.0525922775268555, 'ave_value': -4.900346231782758, 'soft_opc': nan} step=21960




2022-04-20 00:40.45 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_21960.pt


Epoch 46/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.47 [info     ] FQE_20220420003924: epoch=46 step=22448 epoch=46 metrics={'time_sample_batch': 0.00014514805840664223, 'time_algorithm_update': 0.0035717159998221474, 'loss': 0.04318135952649516, 'time_step': 0.0037819332763796946, 'init_value': -4.196473121643066, 'ave_value': -4.970783726954245, 'soft_opc': nan} step=22448




2022-04-20 00:40.47 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_22448.pt


Epoch 47/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.49 [info     ] FQE_20220420003924: epoch=47 step=22936 epoch=47 metrics={'time_sample_batch': 0.00014223573637790366, 'time_algorithm_update': 0.003471812752426648, 'loss': 0.04292831054629117, 'time_step': 0.0036784215051619733, 'init_value': -3.936641216278076, 'ave_value': -4.731990851137015, 'soft_opc': nan} step=22936




2022-04-20 00:40.49 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_22936.pt


Epoch 48/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.51 [info     ] FQE_20220420003924: epoch=48 step=23424 epoch=48 metrics={'time_sample_batch': 0.0001544742310633425, 'time_algorithm_update': 0.003713375720821443, 'loss': 0.043118747866084634, 'time_step': 0.003938336352832982, 'init_value': -3.9366371631622314, 'ave_value': -4.72132796328347, 'soft_opc': nan} step=23424




2022-04-20 00:40.51 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_23424.pt


Epoch 49/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.53 [info     ] FQE_20220420003924: epoch=49 step=23912 epoch=49 metrics={'time_sample_batch': 0.00013307127796235633, 'time_algorithm_update': 0.003349130270911045, 'loss': 0.043667806815317844, 'time_step': 0.0035412648662191924, 'init_value': -4.005101203918457, 'ave_value': -4.73658860536309, 'soft_opc': nan} step=23912




2022-04-20 00:40.53 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_23912.pt


Epoch 50/50:   0%|          | 0/488 [00:00<?, ?it/s]



2022-04-20 00:40.55 [info     ] FQE_20220420003924: epoch=50 step=24400 epoch=50 metrics={'time_sample_batch': 0.00013844986430934218, 'time_algorithm_update': 0.0034278616553447285, 'loss': 0.04401947254793085, 'time_step': 0.0036319538218076114, 'init_value': -4.013291835784912, 'ave_value': -4.730147983963425, 'soft_opc': nan} step=24400




2022-04-20 00:40.55 [info     ] Model parameters are saved to d3rlpy_logs/FQE_20220420003924/model_24400.pt


[(1,
  {'time_sample_batch': 0.00012289451771095152,
   'time_algorithm_update': 0.0027904974632575862,
   'loss': 0.006417373936722574,
   'time_step': 0.002968227277036573,
   'init_value': -0.5982769727706909,
   'ave_value': -0.43931858867675333,
   'soft_opc': nan}),
 (2,
  {'time_sample_batch': 0.00012487124224178127,
   'time_algorithm_update': 0.0029631032318365377,
   'loss': 0.005279628622925795,
   'time_step': 0.0031398119496517493,
   'init_value': -0.7484461069107056,
   'ave_value': -0.5154725129056621,
   'soft_opc': nan}),
 (3,
  {'time_sample_batch': 0.00012738050007429278,
   'time_algorithm_update': 0.003082964263978552,
   'loss': 0.004364420129530147,
   'time_step': 0.0032620928326591117,
   'init_value': -0.8721564412117004,
   'ave_value': -0.5933400834492735,
   'soft_opc': nan}),
 (4,
  {'time_sample_batch': 0.00013102371184552303,
   'time_algorithm_update': 0.0031307862430322367,
   'loss': 0.003956772288504118,
   'time_step': 0.0033208200188933825,
   'in

In [14]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset, states_ope, actions_ope, rewards_ope, metrics_ope = get_dataset([2,4,6,8], path="collected_data/rl_stoch_small.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=50, n_steps_per_epoch=1000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [15]:
model.save_policy("workflowConfound_CPUonly_redo.pt")
model.save_model("workflowConfoundmodel_CPUonly_redo.pt")

  minimum = torch.tensor(
  maximum = torch.tensor(


In [16]:
# from d3rlpy.torch_utility import to_cpu
# to_cpu(model)
# model.save_policy("workflowConfoundCPU.pt")
# model.save_model("workflowConfoundmodelCPU.pt")