# Sample Workflow for d3rlpy Experiments

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
plt.style.use('matplotlibrc')

from Python.data_sampler import *

## Building an MDPDataset

We first read in a large batch of samples from the file. As `d3rlpy` wants it in the form (observations, actions, rewards, terminal flags), we go ahead and do that. Here's a helper function to get a dataset from a list of chunks of your choosing.

In [2]:
def get_dataset(chunks : list, batch_size=30000, 
                path="collected_data/rl_deterministic.txt") -> d3rlpy.dataset.MDPDataset :
    random.seed(0)
    samples = DataSampler(path_to_data="collected_data/rl_stochpid2.txt")
    samples.setting("double")
    states = []
    actions = []
    rewards = []
    next_states = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        [statesChunk, actionsChunk, rewardsChunk, nextStatesChunk] = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
        next_states.append(nextStatesChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    terminals = np.zeros(len(states))
    terminals[::2222] = 1 #episode length 100, change if necessary
    print(states.shape)
    dataset = d3rlpy.dataset.MDPDataset(states.numpy(), 
                                        actions.numpy(), 
                                        rewards.numpy(), terminals)
    return dataset

We can build the dataset from there, just like this, and split into train and test sets.

In [3]:
dataset = get_dataset([i for i in range(2000)])

start
[ 0.00000000e+00  7.95731469e+08 -1.30891077e-02  1.00000047e-02
 -1.13000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.18111235e-01  7.86276171e-02  2.40256998e-01]
Read chunk # 1 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.14710892e-01  4.26000047e-02
  2.94999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 2 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.14710892e-01  4.26000047e-02
  2.94999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 3 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.13891077e-02  2.58000047e-02
 -1.58000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.21545490e-01  4.84695201e-02  4.95970982e-01]
Read chunk # 4 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.13891077e-02  2.58000047e-02
 -1.58000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.21545490

[ 0.00000000e+00  7.95731469e+08 -4.87891077e-02  1.90000047e-02
 -2.46000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  1.20983127e-01  6.00000000e-01]
Read chunk # 75 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.12089108e-01 -5.99999953e-02
 -4.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.16286537e-01 -2.80152941e-01  4.76846770e-01]
Read chunk # 76 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.12089108e-01 -5.99999953e-02
 -4.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.16286537e-01 -2.80152941e-01  4.76846770e-01]
Read chunk # 77 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.10889108e-01  2.00004692e-04
  2.79999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.42296700e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 78 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.10889108e-01  2.00004692e-04
  2.79999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.42296700e-

[ 0.00000000e+00  7.95731469e+08  1.99610892e-01 -7.79999531e-03
 -1.56000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 151 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.48910892e-01 -2.61999953e-02
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.19613549e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 152 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.48910892e-01 -2.61999953e-02
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.19613549e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 153 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.92789108e-01  2.40000047e-02
 -2.32000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.58171273e-01  6.00000000e-01]
Read chunk # 154 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.92789108e-01  2.40000047e-02
 -2.32000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.000000

start
[ 0.00000000e+00  7.95731469e+08  4.14710892e-01 -1.99999953e-02
 -2.27000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.62293447e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 191 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.23610892e-01  1.28000047e-02
  7.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.99715337e-01  2.22184822e-01 -6.00000000e-01]
Read chunk # 192 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.23610892e-01  1.28000047e-02
  7.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.99715337e-01  2.22184822e-01 -6.00000000e-01]
Read chunk # 193 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.42889108e-01  4.82000047e-02
 -2.11000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 194 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.42889108e-01  4.82000047e-02
 -2.11000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.

[ 0.00000000e+00  7.95731469e+08  3.76310892e-01  4.40000047e-02
  8.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.40916293e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 265 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.75108923e-02 -5.89999953e-02
  1.05999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.21981698e-01  3.04380441e-01 -6.00000000e-01]
Read chunk # 266 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.75108923e-02 -5.89999953e-02
  1.05999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.21981698e-01  3.04380441e-01 -6.00000000e-01]
Read chunk # 267 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.47789108e-01 -5.55999953e-02
  5.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.92820940e-01 -3.46694881e-01  1.80867451e-01]
Read chunk # 268 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.47789108e-01 -5.55999953e-02
  5.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.928209

start
[ 0.00000000e+00  7.95731469e+08 -3.77289108e-01 -5.97999953e-02
 -2.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.02250335e-01 -4.92475219e-01  6.00000000e-01]
Read chunk # 338 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.77289108e-01 -5.97999953e-02
 -2.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.02250335e-01 -4.92475219e-01  6.00000000e-01]
Read chunk # 339 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.91289108e-01  4.76000047e-02
  2.65999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 340 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.91289108e-01  4.76000047e-02
  2.65999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 341 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  9.97108923e-02 -5.39999953e-02
  1.45999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.

[ 0.00000000e+00  7.95731469e+08 -6.37891077e-02 -7.19999531e-03
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.75205064e-01 -7.49380215e-02  4.86160307e-01]
Read chunk # 408 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.37891077e-02 -7.19999531e-03
 -9.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.75205064e-01 -7.49380215e-02  4.86160307e-01]
Read chunk # 409 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.43891077e-02 -2.65999953e-02
 -2.05000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.11000651e-01  5.75256707e-02 -1.52382941e-01]
Read chunk # 410 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.43891077e-02 -2.65999953e-02
 -2.05000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.11000651e-01  5.75256707e-02 -1.52382941e-01]
Read chunk # 411 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.86289108e-01 -5.85999953e-02
  1.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.216439

start
[ 0.00000000e+00  7.95731469e+08  2.03810892e-01 -2.35999953e-02
  7.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.14430119e-01  2.10197063e-01 -6.00000000e-01]
Read chunk # 481 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.20289108e-01 -1.77999953e-02
 -2.98000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.31057564e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 482 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.20289108e-01 -1.77999953e-02
 -2.98000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.31057564e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 483 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.54189108e-01  5.16000047e-02
  1.35999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.37024307e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 484 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.54189108e-01  5.16000047e-02
  1.35999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.

start
[ 0.00000000e+00  7.95731469e+08 -3.73891077e-02 -5.91999953e-02
  8.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.95780533e-01  7.06591907e-02  2.92273685e-01]
Read chunk # 552 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.73891077e-02 -5.91999953e-02
  8.89998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.95780533e-01  7.06591907e-02  2.92273685e-01]
Read chunk # 553 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.64810892e-01  1.10000047e-02
 -2.35000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 554 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.64810892e-01  1.10000047e-02
 -2.35000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 555 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.04510892e-01  3.00000047e-02
  1.14999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.

[ 0.00000000e+00  7.95731469e+08  4.51108923e-02 -4.57999953e-02
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.36132244e-02  6.43264423e-02 -6.00000000e-01]
Read chunk # 622 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.51108923e-02 -4.57999953e-02
 -6.10001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.36132244e-02  6.43264423e-02 -6.00000000e-01]
Read chunk # 623 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.79289108e-01  1.68000047e-02
  1.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.35762929e-02 -5.50346264e-01  6.00000000e-01]
Read chunk # 624 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.79289108e-01  1.68000047e-02
  1.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.35762929e-02 -5.50346264e-01  6.00000000e-01]
Read chunk # 625 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  6.88108923e-02 -8.59999531e-03
 -8.30001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.335455

[ 0.00000000e+00  7.95731469e+08  1.55510892e-01  3.20000469e-03
  1.95999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.83475799e-01  3.54217400e-01 -4.39586592e-01]
Read chunk # 694 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.55510892e-01  3.20000469e-03
  1.95999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.83475799e-01  3.54217400e-01 -4.39586592e-01]
Read chunk # 695 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.59589108e-01 -1.11999953e-02
 -1.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.32621835e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 696 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.59589108e-01 -1.11999953e-02
 -1.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.32621835e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 697 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.16310892e-01  3.54000047e-02
  2.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -8.077755

[ 0.00000000e+00  7.95731469e+08 -4.10589108e-01 -4.05999953e-02
  1.81999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 763 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.81510892e-01  3.72000047e-02
  2.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.06747989e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 764 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.81510892e-01  3.72000047e-02
  2.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  9.06747989e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 765 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.48910892e-01 -2.53999953e-02
  2.60999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  4.23396419e-01 -6.00000000e-01]
Read chunk # 766 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.48910892e-01 -2.53999953e-02
  2.60999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

[ 0.00000000e+00  7.95731469e+08  1.28810892e-01  1.68000047e-02
 -1.91000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.92051337e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 834 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.28810892e-01  1.68000047e-02
 -1.91000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.92051337e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 835 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.00689108e-01  1.10000047e-02
  1.26999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.73246818e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 836 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.00689108e-01  1.10000047e-02
  1.26999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.73246818e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 837 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.05289108e-01  5.60000469e-03
 -4.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.849696

start
[ 0.00000000e+00  7.95731469e+08  1.57010892e-01  3.20000469e-03
  5.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.30862571e-01  3.91163016e-02 -6.00000000e-01]
Read chunk # 906 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.57010892e-01  3.20000469e-03
  5.19998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.30862571e-01  3.91163016e-02 -6.00000000e-01]
Read chunk # 907 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  8.53108923e-02 -1.17999953e-02
 -1.78000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.42754963e-02  3.39775427e-01 -6.00000000e-01]
Read chunk # 908 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  8.53108923e-02 -1.17999953e-02
 -1.78000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.42754963e-02  3.39775427e-01 -6.00000000e-01]
Read chunk # 909 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.66010892e-01  7.20000469e-03
  3.99986580e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.

start
[ 0.00000000e+00  7.95731469e+08 -3.28891077e-02 -5.37999953e-02
  1.53999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.08951733e-01 -1.79644854e-01 -3.78056952e-02]
Read chunk # 977 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.67889108e-01  5.20000047e-02
  9.99865802e-05  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.01147894e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 978 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.67889108e-01  5.20000047e-02
  9.99865802e-05  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.01147894e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 979 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.47110892e-01 -2.65999953e-02
  7.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 980 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.47110892e-01 -2.65999953e-02
  7.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.

[ 0.00000000e+00  7.95731469e+08 -1.93989108e-01  3.40000047e-02
 -8.20001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.01633043e-01 -4.34502924e-01  6.00000000e-01]
Read chunk # 1049 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.32110892e-01  2.50000047e-02
 -2.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.03860927e-01  4.87365637e-01 -6.00000000e-01]
Read chunk # 1050 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.32110892e-01  2.50000047e-02
 -2.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.03860927e-01  4.87365637e-01 -6.00000000e-01]
Read chunk # 1051 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.06610892e-01 -4.33999953e-02
  2.49999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1052 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.06610892e-01 -4.33999953e-02
  2.49999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00

[ 0.00000000e+00  7.95731469e+08 -4.49289108e-01 -5.81999953e-02
 -9.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.03119689e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1122 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.49289108e-01 -5.81999953e-02
 -9.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.03119689e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1123 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.28489108e-01  1.74000047e-02
 -2.84000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.57696629e-01 -4.39532191e-01  6.00000000e-01]
Read chunk # 1124 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.28489108e-01  1.74000047e-02
 -2.84000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.57696629e-01 -4.39532191e-01  6.00000000e-01]
Read chunk # 1125 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.46589108e-01 -1.81999953e-02
  5.79998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00

start
[ 0.00000000e+00  7.95731469e+08  4.35410892e-01 -5.07999953e-02
 -8.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.71953034e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1195 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.35410892e-01 -5.07999953e-02
 -8.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.71953034e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1196 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.18910771e-03 -3.45999953e-02
  9.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.57986176e-01  4.97180104e-02 -2.11775841e-01]
Read chunk # 1197 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.18910771e-03 -3.45999953e-02
  9.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.57986176e-01  4.97180104e-02 -2.11775841e-01]
Read chunk # 1198 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.25610892e-01 -5.15999953e-02
  1.69999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08  1.81010892e-01  4.69199999e-09
  2.91999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.92819825e-01 -6.00000000e-01]
Read chunk # 1267 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.81010892e-01  4.69199999e-09
  2.91999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.92819825e-01 -6.00000000e-01]
Read chunk # 1268 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  5.21089229e-03 -5.77999953e-02
 -2.30000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.92905271e-01  1.48265026e-02 -3.57410612e-01]
Read chunk # 1269 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  5.21089229e-03 -5.77999953e-02
 -2.30000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.92905271e-01  1.48265026e-02 -3.57410612e-01]
Read chunk # 1270 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.14189108e-01 -2.87999953e-02
 -2.23000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.64

start
[ 0.00000000e+00  7.95731469e+08  4.24610892e-01 -2.35999953e-02
  6.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.46589841e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1305 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.24610892e-01 -2.35999953e-02
  6.69998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  2.46589841e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1306 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.62089108e-01  3.94000047e-02
  2.78999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.44483149e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1307 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.62089108e-01  3.94000047e-02
  2.78999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -5.44483149e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1308 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.12610892e-01 -9.79999531e-03
 -1.54000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08  8.05108923e-02  8.60000469e-03
  1.03999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.69121098e-01 -1.19645797e-01 -6.00000000e-01]
Read chunk # 1344 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.59310892e-01 -3.03999953e-02
 -3.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.62114936e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1345 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.59310892e-01 -3.03999953e-02
 -3.90001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.62114936e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1346 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.19089108e-01  4.64000047e-02
  1.54999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.49056993e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1347 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.19089108e-01  4.64000047e-02
  1.54999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.49

[ 0.00000000e+00  7.95731469e+08 -4.16589108e-01 -2.19999953e-02
  1.78999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1383 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.16589108e-01 -2.19999953e-02
  1.78999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1384 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.48189108e-01  4.98000047e-02
 -2.21000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1385 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.48189108e-01  4.98000047e-02
 -2.21000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1386 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.98910771e-03  2.18000047e-02
  1.35999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.92

start
[ 0.00000000e+00  7.95731469e+08 -3.47889108e-01 -4.95999953e-02
  1.90999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1423 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.47889108e-01 -4.95999953e-02
  1.90999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1424 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.22189108e-01  1.28000047e-02
 -7.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.36585787e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1425 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.22189108e-01  1.28000047e-02
 -7.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.36585787e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1426 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.90891077e-02  1.74000047e-02
 -6.60001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


Read chunk # 1495 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.05789108e-01  2.82000047e-02
  2.38999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.36249272e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1496 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.21289108e-01  5.26000047e-02
 -2.66000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -5.80485057e-01  6.00000000e-01]
Read chunk # 1497 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.21289108e-01  5.26000047e-02
 -2.66000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -5.80485057e-01  6.00000000e-01]
Read chunk # 1498 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.22710892e-01  5.34000047e-02
 -2.24000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  2.88410444e-01 -6.00000000e-01]
Read chunk # 1499 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.22710892e-01  5.34000047e-02
 -2.24000134e-02  0.00000000e+00 -5

start
[ 0.00000000e+00  7.95731469e+08  3.52910892e-01 -4.97999953e-02
  1.08999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.64698244e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1569 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.52910892e-01 -4.97999953e-02
  1.08999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  3.64698244e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 1570 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.68108923e-02 -4.29999953e-02
  1.10999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.06469743e-01  2.57429530e-03 -6.00000000e-01]
Read chunk # 1571 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.68108923e-02 -4.29999953e-02
  1.10999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.06469743e-01  2.57429530e-03 -6.00000000e-01]
Read chunk # 1572 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.15789108e-01  2.08000047e-02
  1.29998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


[ 0.00000000e+00  7.95731469e+08  1.05110892e-01 -2.69999953e-02
 -2.50000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.67522493e-02  6.00000000e-01 -5.97139237e-01]
Read chunk # 1640 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.37610892e-01  1.02000047e-02
  4.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.78429015e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1641 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.37610892e-01  1.02000047e-02
  4.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.78429015e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 1642 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.23189108e-01 -4.33999953e-02
 -2.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.09574839e-01 -4.54637812e-01  6.00000000e-01]
Read chunk # 1643 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.23189108e-01 -4.33999953e-02
 -2.00013420e-04  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.09

start
[ 0.00000000e+00  7.95731469e+08 -4.03689108e-01  4.50000047e-02
 -2.30000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1711 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.03689108e-01  4.50000047e-02
 -2.30000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 1712 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.20789108e-01  5.52000047e-02
 -2.96000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.10363558e-02  6.00000000e-01]
Read chunk # 1713 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.20789108e-01  5.52000047e-02
 -2.96000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -4.10363558e-02  6.00000000e-01]
Read chunk # 1714 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.93889108e-01 -9.19999531e-03
 -1.52000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


start
[ 0.00000000e+00  7.95731469e+08 -1.25589108e-01 -4.79999953e-02
 -4.70001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  7.53771117e-02 -5.92966477e-01  6.00000000e-01]
Read chunk # 1782 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.24610892e-01  1.60000469e-03
  8.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.11119783e-04  4.27170387e-01 -3.45557112e-01]
Read chunk # 1783 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.24610892e-01  1.60000469e-03
  8.99998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.11119783e-04  4.27170387e-01 -3.45557112e-01]
Read chunk # 1784 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.39310892e-01 -4.71999953e-02
  2.62999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -1.64116219e-03 -6.00000000e-01]
Read chunk # 1785 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.39310892e-01 -4.71999953e-02
  2.62999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00


Read chunk # 1853 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.06108923e-02 -3.81999953e-02
  1.17999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -1.92536450e-01  3.03176353e-01]
Read chunk # 1854 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.82891077e-02 -1.05999953e-02
 -2.54000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -1.97565638e-02  6.00000000e-01]
Read chunk # 1855 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -6.82891077e-02 -1.05999953e-02
 -2.54000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -1.97565638e-02  6.00000000e-01]
Read chunk # 1856 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.32989108e-01 -4.37999953e-02
  2.01999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  5.59395707e-01]
Read chunk # 1857 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.32989108e-01 -4.37999953e-02
  2.01999866e-02  0.00000000e+00 -5

[ 0.00000000e+00  7.95731469e+08 -1.16289108e-01  3.04000047e-02
  1.32999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.22124397e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1927 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.16289108e-01  3.04000047e-02
  1.32999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -7.22124397e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 1928 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.88891077e-02 -4.85999953e-02
 -2.26000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -7.82786801e-02  1.69427052e-01]
Read chunk # 1929 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.88891077e-02 -4.85999953e-02
 -2.26000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -7.82786801e-02  1.69427052e-01]
Read chunk # 1930 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.99210892e-01  1.02000047e-02
  2.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.52

[ 0.00000000e+00  7.95731469e+08 -3.58989108e-01 -5.17999953e-02
  3.09998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  5.19684716e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 2000 out of 4999
torch.Size([2220000, 6])


In [4]:
print("The behavior policy value statistics are:")
dataset.compute_stats()['return']

The behavior policy value statistics are:


{'mean': -248.49933,
 'std': 154.67868,
 'min': -752.22614,
 'max': 0.0,
 'histogram': (array([ 10,   9,   6,   9,  17,  16,  29,  34,  32,  39,  38,  62,  57,
          85,  94, 116, 155, 177,  14,   1]),
  array([-752.22614 , -714.6148  , -677.00354 , -639.3922  , -601.7809  ,
         -564.1696  , -526.5583  , -488.947   , -451.3357  , -413.72437 ,
         -376.11307 , -338.50177 , -300.89044 , -263.27914 , -225.66785 ,
         -188.05653 , -150.44522 , -112.83392 ,  -75.22261 ,  -37.611305,
            0.      ], dtype=float32))}

In [5]:
from sklearn.model_selection import train_test_split
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

## Setting up an Algorithm

In [6]:
from d3rlpy.algos import CQL

from d3rlpy.preprocessing import MinMaxActionScaler
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)
#cql = CQL(action_scaler=action_scaler)

model = CQL(q_func_factory='mean', #qr -> quantile regression q function, but you don't have to use this
            reward_scaler='standard',
            action_scaler=action_scaler,
          actor_learning_rate=1e-5, 
          critic_learning_rate=0.0003, 
            use_gpu=False) #change it to true if you have one
model.build_with_dataset(dataset)

In [7]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer

# calculate metrics with test dataset
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

0.06862036204800842


In [8]:
%load_ext tensorboard
%tensorboard --logdir runs

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
Traceback (most recent call last):
  File "/home/dasc/anaconda3/envs/jbreeden3.10/bin/tensorboard", line 6, in <module>
    from tensorboard.main import run_main
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/main.py", line 40, in <module>
    from tensorboard import default
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/default.py", line 38, in <module>
    from tensorboard.plugins.audio import audio_plugin
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugins/audio/audio_plugin.py", line 25, in <module>
    from tensorboard import plugin_util
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorboard/plugin_util.py", line 21, in <module>
    from tensorboard._vendor import bleach
  File "/home/dasc/anaconda3/envs/jbreeden3.10/lib/python3.10/site-packages/tensorb

In [9]:
model.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=40, 
        tensorboard_dir='runs',
        scorers={
            'td_error': td_error_scorer,
            'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer
        })

2022-04-20 15:33.35 [debug    ] RoundIterator is selected.
2022-04-20 15:33.35 [info     ] Directory is created at d3rlpy_logs/CQL_20220420153335
2022-04-20 15:33.35 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-20 15:33.35 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-20 15:33.35 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220420153335/params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conser

Epoch 1/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 15:43.19 [info     ] CQL_20220420153335: epoch=1 step=6931 epoch=1 metrics={'time_sample_batch': 0.0003195364501737375, 'time_algorithm_update': 0.08248434316508269, 'temp_loss': 2.768967193367677, 'temp': 0.747011374728949, 'alpha_loss': -3.7443262735867933, 'alpha': 1.2409134573827074, 'critic_loss': 19.95090428755373, 'actor_loss': 24.710188812992985, 'time_step': 0.08313859886738528, 'td_error': 33.174810272332685, 'init_value': -63.10578155517578, 'ave_value': -54.6618558456178} step=6931
2022-04-20 15:43.19 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_6931.pt


Epoch 2/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 15:53.46 [info     ] CQL_20220420153335: epoch=2 step=13862 epoch=2 metrics={'time_sample_batch': 0.00032477519192610995, 'time_algorithm_update': 0.0886658003381485, 'temp_loss': 0.9284023452228221, 'temp': 0.4178826983761708, 'alpha_loss': 0.5071393584290255, 'alpha': 1.1343545683739833, 'critic_loss': 108.15097504456035, 'actor_loss': 88.52169504579811, 'time_step': 0.0893653748338485, 'td_error': 70.28582000488592, 'init_value': -123.6652603149414, 'ave_value': -108.84911071964142} step=13862
2022-04-20 15:53.46 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_13862.pt


Epoch 3/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:05.16 [info     ] CQL_20220420153335: epoch=3 step=20793 epoch=3 metrics={'time_sample_batch': 0.00032805914900172514, 'time_algorithm_update': 0.0977247973538257, 'temp_loss': 0.22757699012183527, 'temp': 0.2424380859114738, 'alpha_loss': 0.20690906138441204, 'alpha': 0.9889010477764337, 'critic_loss': 140.47778062255605, 'actor_loss': 122.39168678804353, 'time_step': 0.09842739148567188, 'td_error': 50.315468023362236, 'init_value': -134.0041961669922, 'ave_value': -118.94367652338013} step=20793
2022-04-20 16:05.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_20793.pt


Epoch 4/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:16.50 [info     ] CQL_20220420153335: epoch=4 step=27724 epoch=4 metrics={'time_sample_batch': 0.000325340399785469, 'time_algorithm_update': 0.0983536173953, 'temp_loss': -0.0037530718904118673, 'temp': 0.2063181700794715, 'alpha_loss': 0.5169044976315491, 'alpha': 0.7449245462082558, 'critic_loss': 95.58839527264833, 'actor_loss': 120.62241271501529, 'time_step': 0.09905372111763966, 'td_error': 36.03331333118191, 'init_value': -124.92079162597656, 'ave_value': -111.11229975425006} step=27724
2022-04-20 16:16.50 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_27724.pt


Epoch 5/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:28.14 [info     ] CQL_20220420153335: epoch=5 step=34655 epoch=5 metrics={'time_sample_batch': 0.000326121735755129, 'time_algorithm_update': 0.09681554316029613, 'temp_loss': 0.00032073859677773253, 'temp': 0.21318708915521944, 'alpha_loss': 0.11916536563081076, 'alpha': 0.5930740851513873, 'critic_loss': 72.87345870424187, 'actor_loss': 110.38359594668442, 'time_step': 0.09752248737344561, 'td_error': 29.422435653500496, 'init_value': -111.53463745117188, 'ave_value': -99.04465417601233} step=34655
2022-04-20 16:28.14 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_34655.pt


Epoch 6/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:39.02 [info     ] CQL_20220420153335: epoch=6 step=41586 epoch=6 metrics={'time_sample_batch': 0.00032564307545673073, 'time_algorithm_update': 0.09167907135153966, 'temp_loss': 0.0003894992316108266, 'temp': 0.2109469890577283, 'alpha_loss': -0.025462081025799915, 'alpha': 0.5646655735748015, 'critic_loss': 62.11889878170364, 'actor_loss': 99.92246999700698, 'time_step': 0.0923823975257725, 'td_error': 26.04638565232455, 'init_value': -101.17337036132812, 'ave_value': -89.98971871626529} step=41586
2022-04-20 16:39.02 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_41586.pt


Epoch 7/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:49.15 [info     ] CQL_20220420153335: epoch=7 step=48517 epoch=7 metrics={'time_sample_batch': 0.0003239985686018601, 'time_algorithm_update': 0.08672099240052203, 'temp_loss': 0.0030239860644304564, 'temp': 0.2073025358366495, 'alpha_loss': 0.003089397885721052, 'alpha': 0.5822872236562384, 'critic_loss': 57.2463680279064, 'actor_loss': 90.91751046676757, 'time_step': 0.08741982264223595, 'td_error': 24.36936655382426, 'init_value': -93.56272888183594, 'ave_value': -82.80826739042014} step=48517
2022-04-20 16:49.15 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_48517.pt


Epoch 8/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 16:59.06 [info     ] CQL_20220420153335: epoch=8 step=55448 epoch=8 metrics={'time_sample_batch': 0.0003220476365987227, 'time_algorithm_update': 0.08342171730098744, 'temp_loss': 0.004485020050345976, 'temp': 0.1901695298036616, 'alpha_loss': 0.06641539061849445, 'alpha': 0.5605142538620785, 'critic_loss': 52.78483647724678, 'actor_loss': 82.80584979604254, 'time_step': 0.08411620778105953, 'td_error': 22.323002429077384, 'init_value': -85.10530090332031, 'ave_value': -74.75787503606774} step=55448
2022-04-20 16:59.06 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_55448.pt


Epoch 9/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:08.48 [info     ] CQL_20220420153335: epoch=9 step=62379 epoch=9 metrics={'time_sample_batch': 0.0003249175000579442, 'time_algorithm_update': 0.08224816821798275, 'temp_loss': 0.006599882804418324, 'temp': 0.17452520775323718, 'alpha_loss': 0.10390414225898836, 'alpha': 0.4931810137952428, 'critic_loss': 48.13895591788608, 'actor_loss': 75.21096081307846, 'time_step': 0.08294775936273661, 'td_error': 20.526469998278113, 'init_value': -77.03153991699219, 'ave_value': -67.35667724761393} step=62379
2022-04-20 17:08.48 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_62379.pt


Epoch 10/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:18.16 [info     ] CQL_20220420153335: epoch=10 step=69310 epoch=10 metrics={'time_sample_batch': 0.0003184678148284238, 'time_algorithm_update': 0.08021046705635367, 'temp_loss': 0.005132685609276337, 'temp': 0.15643376244204915, 'alpha_loss': 0.09971451566570537, 'alpha': 0.4129131031731234, 'critic_loss': 43.89378713472738, 'actor_loss': 67.52094793588013, 'time_step': 0.08089852057877875, 'td_error': 18.769228545019985, 'init_value': -68.04108428955078, 'ave_value': -59.685035043441815} step=69310
2022-04-20 17:18.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_69310.pt


Epoch 11/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:27.46 [info     ] CQL_20220420153335: epoch=11 step=76241 epoch=11 metrics={'time_sample_batch': 0.00031931935889503173, 'time_algorithm_update': 0.08037730770664218, 'temp_loss': 0.0029247197671961366, 'temp': 0.14155166227353838, 'alpha_loss': 0.04788319427675982, 'alpha': 0.3651454767964478, 'critic_loss': 41.49490181719684, 'actor_loss': 59.15035770032644, 'time_step': 0.08106976634856512, 'td_error': 17.69526330397221, 'init_value': -57.75691223144531, 'ave_value': -50.403590080907776} step=76241
2022-04-20 17:27.46 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_76241.pt


Epoch 12/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:37.16 [info     ] CQL_20220420153335: epoch=12 step=83172 epoch=12 metrics={'time_sample_batch': 0.00031797852527869935, 'time_algorithm_update': 0.08048216403261646, 'temp_loss': -0.0004010571326874761, 'temp': 0.13769456127594257, 'alpha_loss': 0.02082934724180883, 'alpha': 0.3401045528708246, 'critic_loss': 39.57202903427374, 'actor_loss': 50.979965523128264, 'time_step': 0.08116866551235355, 'td_error': 16.777016053725905, 'init_value': -49.18160629272461, 'ave_value': -42.96880842322817} step=83172
2022-04-20 17:37.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_83172.pt


Epoch 13/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:46.47 [info     ] CQL_20220420153335: epoch=13 step=90103 epoch=13 metrics={'time_sample_batch': 0.00031973149177574384, 'time_algorithm_update': 0.08073819129089654, 'temp_loss': 0.0014446701686227012, 'temp': 0.1361372949774003, 'alpha_loss': -0.000595928755046377, 'alpha': 0.3358078405224491, 'critic_loss': 37.693017975865395, 'actor_loss': 43.944004089934474, 'time_step': 0.08142964015394351, 'td_error': 15.99224538878447, 'init_value': -42.24564743041992, 'ave_value': -36.78737932989227} step=90103
2022-04-20 17:46.47 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_90103.pt


Epoch 14/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 17:56.16 [info     ] CQL_20220420153335: epoch=14 step=97034 epoch=14 metrics={'time_sample_batch': 0.00032026157838719314, 'time_algorithm_update': 0.08030264929007906, 'temp_loss': -0.0005968265873214058, 'temp': 0.13560916850610208, 'alpha_loss': 0.0007307421655814911, 'alpha': 0.33201629468638255, 'critic_loss': 35.9071600111407, 'actor_loss': 38.12408530524654, 'time_step': 0.08099038427895307, 'td_error': 14.979487830269212, 'init_value': -35.7498779296875, 'ave_value': -30.96101388114645} step=97034
2022-04-20 17:56.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_97034.pt


Epoch 15/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:06.17 [info     ] CQL_20220420153335: epoch=15 step=103965 epoch=15 metrics={'time_sample_batch': 0.00032824321736368915, 'time_algorithm_update': 0.08490330428895522, 'temp_loss': -0.002152978329687624, 'temp': 0.1394934087729423, 'alpha_loss': 0.002416838392537948, 'alpha': 0.3318499759916722, 'critic_loss': 34.25877252817498, 'actor_loss': 33.43355451223913, 'time_step': 0.08561741131034149, 'td_error': 14.786470801166613, 'init_value': -31.70958709716797, 'ave_value': -27.404839726220924} step=103965
2022-04-20 18:06.17 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_103965.pt


Epoch 16/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:15.57 [info     ] CQL_20220420153335: epoch=16 step=110896 epoch=16 metrics={'time_sample_batch': 0.0003214864534073415, 'time_algorithm_update': 0.08181669601646273, 'temp_loss': -0.0005173225958386572, 'temp': 0.14345253905964833, 'alpha_loss': 0.0037911748273703584, 'alpha': 0.3336265975607541, 'critic_loss': 33.51147911114673, 'actor_loss': 29.857486369444374, 'time_step': 0.08251286784458119, 'td_error': 14.650923135655656, 'init_value': -28.326433181762695, 'ave_value': -24.258286584602637} step=110896
2022-04-20 18:15.57 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_110896.pt


Epoch 17/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:25.50 [info     ] CQL_20220420153335: epoch=17 step=117827 epoch=17 metrics={'time_sample_batch': 0.000328288899065182, 'time_algorithm_update': 0.08369722378028624, 'temp_loss': 0.000888251322270197, 'temp': 0.14437474433374067, 'alpha_loss': 0.00099368561372341, 'alpha': 0.3264582410052713, 'critic_loss': 32.8589497888317, 'actor_loss': 27.97708348958906, 'time_step': 0.08442094996453295, 'td_error': 14.475354351968747, 'init_value': -26.714262008666992, 'ave_value': -22.844318368899508} step=117827
2022-04-20 18:25.50 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_117827.pt


Epoch 18/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:35.52 [info     ] CQL_20220420153335: epoch=18 step=124758 epoch=18 metrics={'time_sample_batch': 0.000329833477198037, 'time_algorithm_update': 0.08505588368305894, 'temp_loss': -0.00099877519264129, 'temp': 0.14405766762002964, 'alpha_loss': 0.014528448485693448, 'alpha': 0.32068602295514226, 'critic_loss': 32.555570086494114, 'actor_loss': 26.53302738405447, 'time_step': 0.08577188787101203, 'td_error': 14.284168447327097, 'init_value': -26.172460556030273, 'ave_value': -22.438121131699607} step=124758
2022-04-20 18:35.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_124758.pt


Epoch 19/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:45.58 [info     ] CQL_20220420153335: epoch=19 step=131689 epoch=19 metrics={'time_sample_batch': 0.0003275014744952478, 'time_algorithm_update': 0.08560412309503783, 'temp_loss': -0.00045233357533054137, 'temp': 0.1447703934794225, 'alpha_loss': 0.014104906014329645, 'alpha': 0.3121550122547108, 'critic_loss': 32.33882812195708, 'actor_loss': 25.2322373046897, 'time_step': 0.08631860017948484, 'td_error': 14.472456722290977, 'init_value': -24.117769241333008, 'ave_value': -20.473093399938517} step=131689
2022-04-20 18:45.58 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_131689.pt


Epoch 20/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 18:55.47 [info     ] CQL_20220420153335: epoch=20 step=138620 epoch=20 metrics={'time_sample_batch': 0.00032830117946234535, 'time_algorithm_update': 0.08324358250358338, 'temp_loss': 0.0009850554951840035, 'temp': 0.14425049875054657, 'alpha_loss': 0.005612188641143342, 'alpha': 0.3049798773413801, 'critic_loss': 32.28727322318925, 'actor_loss': 23.61768613056331, 'time_step': 0.08394761787935212, 'td_error': 14.505782212885435, 'init_value': -22.956058502197266, 'ave_value': -19.469582391840106} step=138620
2022-04-20 18:55.47 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_138620.pt


Epoch 21/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:05.31 [info     ] CQL_20220420153335: epoch=21 step=145551 epoch=21 metrics={'time_sample_batch': 0.000321599316105081, 'time_algorithm_update': 0.08246351010767417, 'temp_loss': 0.00023333186613520238, 'temp': 0.14279457217058464, 'alpha_loss': 0.01033803305834628, 'alpha': 0.29522051430336754, 'critic_loss': 32.50856063103129, 'actor_loss': 22.45435404228865, 'time_step': 0.08315563999958119, 'td_error': 14.899241102132113, 'init_value': -22.585514068603516, 'ave_value': -19.14331042893599} step=145551
2022-04-20 19:05.31 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_145551.pt


Epoch 22/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:15.02 [info     ] CQL_20220420153335: epoch=22 step=152482 epoch=22 metrics={'time_sample_batch': 0.0003204203635785267, 'time_algorithm_update': 0.0805817625598635, 'temp_loss': -0.0008667614144024382, 'temp': 0.14282810681354016, 'alpha_loss': 0.011199554575525777, 'alpha': 0.2878621310369739, 'critic_loss': 32.80337396854447, 'actor_loss': 21.34520415330308, 'time_step': 0.08126844999990049, 'td_error': 14.820875349924577, 'init_value': -20.153118133544922, 'ave_value': -16.765282869575458} step=152482
2022-04-20 19:15.02 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_152482.pt


Epoch 23/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:24.38 [info     ] CQL_20220420153335: epoch=23 step=159413 epoch=23 metrics={'time_sample_batch': 0.00032470512142464844, 'time_algorithm_update': 0.0812665158545467, 'temp_loss': 0.0010672057373469362, 'temp': 0.14395537136987663, 'alpha_loss': -0.002764151327194225, 'alpha': 0.28377017374937197, 'critic_loss': 32.8712438999996, 'actor_loss': 20.337352455496497, 'time_step': 0.08196416160551574, 'td_error': 15.149727380027574, 'init_value': -19.727561950683594, 'ave_value': -16.54500926638619} step=159413
2022-04-20 19:24.38 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_159413.pt


Epoch 24/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:33.51 [info     ] CQL_20220420153335: epoch=24 step=166344 epoch=24 metrics={'time_sample_batch': 0.00032130290102845165, 'time_algorithm_update': 0.07790266399413816, 'temp_loss': -0.0008683141129029034, 'temp': 0.14407250584418804, 'alpha_loss': 0.013311585278565758, 'alpha': 0.28118171219856014, 'critic_loss': 32.988214973152296, 'actor_loss': 19.778074436851814, 'time_step': 0.07859918534919101, 'td_error': 15.333377086055687, 'init_value': -18.741764068603516, 'ave_value': -15.594141673438502} step=166344
2022-04-20 19:33.51 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_166344.pt


Epoch 25/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:43.16 [info     ] CQL_20220420153335: epoch=25 step=173275 epoch=25 metrics={'time_sample_batch': 0.00032133919183799603, 'time_algorithm_update': 0.07971809316620547, 'temp_loss': 0.0008447710764703216, 'temp': 0.1428879699959347, 'alpha_loss': 0.005489138972057275, 'alpha': 0.27313694880734857, 'critic_loss': 33.226343921382885, 'actor_loss': 19.23314597911605, 'time_step': 0.08041496263783897, 'td_error': 15.447653880901449, 'init_value': -18.2351016998291, 'ave_value': -15.09805850832822} step=173275
2022-04-20 19:43.16 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_173275.pt


Epoch 26/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 19:52.37 [info     ] CQL_20220420153335: epoch=26 step=180206 epoch=26 metrics={'time_sample_batch': 0.0003207676545862614, 'time_algorithm_update': 0.0792261794641616, 'temp_loss': -0.0003199391598463738, 'temp': 0.14322206769055265, 'alpha_loss': 0.00887306261378385, 'alpha': 0.2649668730701861, 'critic_loss': 33.368791053080486, 'actor_loss': 18.908768829579063, 'time_step': 0.07991938002094737, 'td_error': 15.594426117688672, 'init_value': -19.2792911529541, 'ave_value': -16.229726369469436} step=180206
2022-04-20 19:52.37 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_180206.pt


Epoch 27/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:02.04 [info     ] CQL_20220420153335: epoch=27 step=187137 epoch=27 metrics={'time_sample_batch': 0.0003233112103494279, 'time_algorithm_update': 0.07992174814246533, 'temp_loss': 0.0001568342031467055, 'temp': 0.14259757760919167, 'alpha_loss': 0.0017858914882907695, 'alpha': 0.2644224464162365, 'critic_loss': 33.464383466739775, 'actor_loss': 18.56777122218601, 'time_step': 0.08062002500953173, 'td_error': 15.436174382836409, 'init_value': -17.285856246948242, 'ave_value': -14.101178528394387} step=187137
2022-04-20 20:02.04 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_187137.pt


Epoch 28/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:11.39 [info     ] CQL_20220420153335: epoch=28 step=194068 epoch=28 metrics={'time_sample_batch': 0.00032522485397574433, 'time_algorithm_update': 0.08119419030065816, 'temp_loss': 0.0001927567814374532, 'temp': 0.14230889529785265, 'alpha_loss': 0.0013443664351561596, 'alpha': 0.26226040383400706, 'critic_loss': 33.40805724015949, 'actor_loss': 18.53683039751976, 'time_step': 0.08189484712245436, 'td_error': 15.392710682648492, 'init_value': -17.966524124145508, 'ave_value': -14.829914393010275} step=194068
2022-04-20 20:11.39 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_194068.pt


Epoch 29/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:21.14 [info     ] CQL_20220420153335: epoch=29 step=200999 epoch=29 metrics={'time_sample_batch': 0.0003250075563038089, 'time_algorithm_update': 0.0811735180579744, 'temp_loss': 0.0002677118835652032, 'temp': 0.1426031330097566, 'alpha_loss': 0.0012760942851264604, 'alpha': 0.25935442095761896, 'critic_loss': 33.22956035278556, 'actor_loss': 18.528199693845682, 'time_step': 0.08187273779811152, 'td_error': 15.709614645561128, 'init_value': -18.06094741821289, 'ave_value': -14.884930604752732} step=200999
2022-04-20 20:21.14 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_200999.pt


Epoch 30/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:30.52 [info     ] CQL_20220420153335: epoch=30 step=207930 epoch=30 metrics={'time_sample_batch': 0.00032206198092818244, 'time_algorithm_update': 0.08162737205729927, 'temp_loss': -0.001012232339988801, 'temp': 0.1431194973243915, 'alpha_loss': 0.008610225190067228, 'alpha': 0.25780932192695105, 'critic_loss': 33.700963118774276, 'actor_loss': 18.397712151160782, 'time_step': 0.08231968010096233, 'td_error': 15.917722390131443, 'init_value': -18.304237365722656, 'ave_value': -15.114217640742284} step=207930
2022-04-20 20:30.52 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_207930.pt


Epoch 31/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:40.25 [info     ] CQL_20220420153335: epoch=31 step=214861 epoch=31 metrics={'time_sample_batch': 0.0003255595549964712, 'time_algorithm_update': 0.08096633544303246, 'temp_loss': 0.0002327666471566768, 'temp': 0.14350672243795876, 'alpha_loss': -0.0033354885077030113, 'alpha': 0.2553833740482666, 'critic_loss': 34.181138967086596, 'actor_loss': 18.188530241216764, 'time_step': 0.08167323724237326, 'td_error': 16.20987056259284, 'init_value': -18.09563636779785, 'ave_value': -14.93292512752959} step=214861
2022-04-20 20:40.25 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_214861.pt


Epoch 32/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 20:50.11 [info     ] CQL_20220420153335: epoch=32 step=221792 epoch=32 metrics={'time_sample_batch': 0.0003230113553856379, 'time_algorithm_update': 0.082442383318301, 'temp_loss': -0.0002656264815356165, 'temp': 0.1433630742916313, 'alpha_loss': 0.003772665719807023, 'alpha': 0.2559037572110203, 'critic_loss': 34.728345656524816, 'actor_loss': 17.829108091242436, 'time_step': 0.08313539323654032, 'td_error': 16.493621691235813, 'init_value': -17.505754470825195, 'ave_value': -14.204159662112794} step=221792
2022-04-20 20:50.11 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_221792.pt


Epoch 33/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:00.48 [info     ] CQL_20220420153335: epoch=33 step=228723 epoch=33 metrics={'time_sample_batch': 0.000344621448904865, 'time_algorithm_update': 0.08982309062249422, 'temp_loss': 0.0008486445471810349, 'temp': 0.1432147224027518, 'alpha_loss': 0.00130200397277543, 'alpha': 0.25170524383310183, 'critic_loss': 35.26336981520613, 'actor_loss': 17.76131019382796, 'time_step': 0.09055875276800289, 'td_error': 16.79346998350522, 'init_value': -17.841386795043945, 'ave_value': -14.440899504690345} step=228723
2022-04-20 21:00.48 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_228723.pt


Epoch 34/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:11.48 [info     ] CQL_20220420153335: epoch=34 step=235654 epoch=34 metrics={'time_sample_batch': 0.0003455211513917213, 'time_algorithm_update': 0.09323200930456564, 'temp_loss': 0.000501227442580546, 'temp': 0.1423527768719955, 'alpha_loss': 0.004456212818138052, 'alpha': 0.2502512654590325, 'critic_loss': 35.49671606548144, 'actor_loss': 17.665711751987864, 'time_step': 0.09396325979479084, 'td_error': 16.87797868509053, 'init_value': -17.48982048034668, 'ave_value': -14.098896501718935} step=235654
2022-04-20 21:11.48 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_235654.pt


Epoch 35/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:22.47 [info     ] CQL_20220420153335: epoch=35 step=242585 epoch=35 metrics={'time_sample_batch': 0.00034734728428867194, 'time_algorithm_update': 0.09288503029459638, 'temp_loss': -0.0005125424518054957, 'temp': 0.14187589404499043, 'alpha_loss': 0.009451469327820237, 'alpha': 0.24313579278493785, 'critic_loss': 35.557294490404104, 'actor_loss': 17.63520174000002, 'time_step': 0.09361921827646238, 'td_error': 16.86568043426902, 'init_value': -17.50206756591797, 'ave_value': -14.039677733182488} step=242585
2022-04-20 21:22.47 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_242585.pt


Epoch 36/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:34.02 [info     ] CQL_20220420153335: epoch=36 step=249516 epoch=36 metrics={'time_sample_batch': 0.00035171167552256166, 'time_algorithm_update': 0.09510776701126242, 'temp_loss': 0.00034916582879118696, 'temp': 0.142622708145579, 'alpha_loss': 0.0042074592563311934, 'alpha': 0.23646915714654068, 'critic_loss': 35.614315938621075, 'actor_loss': 17.53239012240091, 'time_step': 0.09585913237925474, 'td_error': 17.00824076211856, 'init_value': -17.2319278717041, 'ave_value': -13.770526571433306} step=249516
2022-04-20 21:34.02 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_249516.pt


Epoch 37/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:44.56 [info     ] CQL_20220420153335: epoch=37 step=256447 epoch=37 metrics={'time_sample_batch': 0.00034629247728974386, 'time_algorithm_update': 0.09229520361001418, 'temp_loss': 4.190294639546171e-05, 'temp': 0.1410367480261366, 'alpha_loss': 0.008115342866725758, 'alpha': 0.23400517922986416, 'critic_loss': 35.91814717850357, 'actor_loss': 17.356215077137264, 'time_step': 0.09302937979306722, 'td_error': 17.089060899035896, 'init_value': -17.40532875061035, 'ave_value': -13.955509605871045} step=256447
2022-04-20 21:44.56 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_256447.pt


Epoch 38/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 21:55.50 [info     ] CQL_20220420153335: epoch=38 step=263378 epoch=38 metrics={'time_sample_batch': 0.0003475429794692388, 'time_algorithm_update': 0.09222452775721013, 'temp_loss': 0.0001288628786776224, 'temp': 0.14015199191994077, 'alpha_loss': 0.004610733376362086, 'alpha': 0.2302455517606222, 'critic_loss': 35.76431987404634, 'actor_loss': 17.269734122552425, 'time_step': 0.09296144260642608, 'td_error': 17.05119015887538, 'init_value': -17.28110694885254, 'ave_value': -13.84102226605183} step=263378
2022-04-20 21:55.50 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_263378.pt


Epoch 39/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 22:06.38 [info     ] CQL_20220420153335: epoch=39 step=270309 epoch=39 metrics={'time_sample_batch': 0.0003423051320875281, 'time_algorithm_update': 0.09143564601667584, 'temp_loss': 0.00021143193109988384, 'temp': 0.13938470317453658, 'alpha_loss': 0.0018873487403012722, 'alpha': 0.22648190581519922, 'critic_loss': 35.705481725614696, 'actor_loss': 17.279242882281736, 'time_step': 0.09216279901651193, 'td_error': 16.95232798074316, 'init_value': -16.499317169189453, 'ave_value': -13.00099126757284} step=270309
2022-04-20 22:06.39 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_270309.pt


Epoch 40/40:   0%|          | 0/6931 [00:00<?, ?it/s]

2022-04-20 22:17.39 [info     ] CQL_20220420153335: epoch=40 step=277240 epoch=40 metrics={'time_sample_batch': 0.00033363634125182953, 'time_algorithm_update': 0.09320847132628725, 'temp_loss': 0.0001576871940318687, 'temp': 0.13916002034650807, 'alpha_loss': -0.0005196051136630011, 'alpha': 0.22511139057996968, 'critic_loss': 35.590238842995475, 'actor_loss': 17.075815960790685, 'time_step': 0.09389307945773227, 'td_error': 16.91298724805621, 'init_value': -16.443775177001953, 'ave_value': -13.03374250195694} step=277240
2022-04-20 22:17.39 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220420153335/model_277240.pt


[(1,
  {'time_sample_batch': 0.0003195364501737375,
   'time_algorithm_update': 0.08248434316508269,
   'temp_loss': 2.768967193367677,
   'temp': 0.747011374728949,
   'alpha_loss': -3.7443262735867933,
   'alpha': 1.2409134573827074,
   'critic_loss': 19.95090428755373,
   'actor_loss': 24.710188812992985,
   'time_step': 0.08313859886738528,
   'td_error': 33.174810272332685,
   'init_value': -63.10578155517578,
   'ave_value': -54.6618558456178}),
 (2,
  {'time_sample_batch': 0.00032477519192610995,
   'time_algorithm_update': 0.0886658003381485,
   'temp_loss': 0.9284023452228221,
   'temp': 0.4178826983761708,
   'alpha_loss': 0.5071393584290255,
   'alpha': 1.1343545683739833,
   'critic_loss': 108.15097504456035,
   'actor_loss': 88.52169504579811,
   'time_step': 0.0893653748338485,
   'td_error': 70.28582000488592,
   'init_value': -123.6652603149414,
   'ave_value': -108.84911071964142}),
 (3,
  {'time_sample_batch': 0.00032805914900172514,
   'time_algorithm_update': 0.0977

In [10]:
model.save_model('cqlStochpidDouble2000_Ep40_CPUOnly.pt')
model.save_policy('cqlStochpidDobule2000_Ep40_CPUOnly.pt')

  minimum = torch.tensor(
  maximum = torch.tensor(


## Off-Policy Evaluation

We do get some metrics on a test set of initial state value and average value. However, these estimates (using the critic's Q-function) of model performance are biased. They're useful for validation during training, but not much else. Instead, we fit a Q-function to the data (or a separate dataset, as I've done here) separately and evaluate the model's performance on it.

Feel free to change the chunks and number of steps.

In [11]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i+2000 for i in range(100)]) #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [12]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i*2 for i in range(100)], path="collected_data/rl_stochastic.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [13]:
# from d3rlpy.torch_utility import to_cpu
# to_cpu(model)
# model.save_policy("cqlStochpid2000Ep40CPU.pt")
# model.save_model("cqlStochpid2000Ep40modelCPU.pt")

In [14]:
# for key in dir(model):
#     module = getattr(model, key)
#     if isinstance(module, (torch.nn.Module, torch.nn.Parameter)):
#         print(yes)
#         print(key)
# dir(model)
# type(model)
# model.cpu()
# from d3rlpy.algos.torch.base import TorchImplBase
# new_model = TorchImplBase()
# from d3rlpy.torch_utility import _get_attributes
# model._device = "cpu:0"
# print(model._device)


# def my_get_state_dict(impl: Any) -> Dict[str, Any]:
#     rets = {}
#     for key in _get_attributes(impl):
#         obj = getattr(impl, key)
#         if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)):
#             if isinstance(obj, (torch.nn.Module, torch.nn.Parameter)):
#                 obj.cpu()
#             rets[key] = obj.state_dict()
#     return rets

# torch.save(my_get_state_dict(model), "my_test_model.pt")

# for key in dir(model):
#     obj = getattr(model, key)
#     if isinstance(obj, (torch.nn.Module, torch.nn.Parameter)):
#         obj.cpu()
#         print("convert to cpu")
# model.save_policy("cqlStochpid2000Ep40modelCPU.pt")

# import trace
# tracer = trace.Trace()
# tracer.run('model.save_policy("cqlStochpid2000Ep40modelCPU.pt")')