# Sample Workflow for d3rlpy Experiments

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
import math
import subprocess
import os
import d3rlpy
plt.style.use('matplotlibrc')

from Python.data_sampler import *

## Building an MDPDataset

We first read in a large batch of samples from the file. As `d3rlpy` wants it in the form (observations, actions, rewards, terminal flags), we go ahead and do that. Here's a helper function to get a dataset from a list of chunks of your choosing.

In [2]:
def get_dataset(chunks : list, batch_size=30000, 
                path="collected_data/rl_deterministic.txt") -> d3rlpy.dataset.MDPDataset :
    random.seed(0)
    samples = DataSampler(path_to_data="collected_data/rl_stochpid.txt")
    samples.setting("coarse")
    states = []
    actions = []
    rewards = []
    next_states = []
    for chunk in chunks:
        samples.use_chunk(chunk)
        samples.read_chunk()
        [statesChunk, actionsChunk, rewardsChunk, nextStatesChunk] = samples.get_batch(batch_size)
        states.append(statesChunk)
        actions.append(actionsChunk)
        rewards.append(rewardsChunk)
        next_states.append(nextStatesChunk)
    states = torch.cat(states)
    actions = torch.cat(actions)
    rewards = torch.cat(rewards)
    next_states = torch.cat(next_states)
    terminals = np.zeros(len(states))
    terminals[::1111] = 1 #episode length 100, change if necessary
    print(states.shape)
    dataset = d3rlpy.dataset.MDPDataset(states.numpy(), 
                                        actions.numpy(), 
                                        rewards.numpy(), terminals)
    return dataset

We can build the dataset from there, just like this, and split into train and test sets.

In [3]:
dataset = get_dataset([i+(i-1)*2 for i in range(200)])

start
[ 0.00000000e+00  7.95731469e+08 -4.75891077e-02 -3.69999953e-02
  2.00999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.50429671e-01 -4.92727243e-01 -5.31666025e-03]
Read chunk # -1 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.25610892e-01 -3.35999953e-02
 -2.42000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.08749986e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 2 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.08389108e-01  3.32000047e-02
 -2.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.25137655e-01 -1.65270603e-01  6.00000000e-01]
Read chunk # 5 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -8.62891077e-02 -4.03999953e-02
 -1.76000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -2.96411621e-01 -3.05278115e-01  6.00000000e-01]
Read chunk # 8 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -9.61891077e-02 -3.73999953e-02
 -9.50001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.4272166

start
[ 0.00000000e+00  7.95731469e+08  9.19108923e-02 -3.75999953e-02
 -1.29000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.10756296e-01  9.23267200e-02 -6.00000000e-01]
Read chunk # 137 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.67589108e-01 -4.39999953e-02
  2.89999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.07944815e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 140 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.20789108e-01 -3.09999953e-02
 -2.24000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01 -2.74665987e-01  6.00000000e-01]
Read chunk # 143 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.02189108e-01  5.78000047e-02
  5.39998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.40938487e-01 -2.83568202e-01  6.00000000e-01]
Read chunk # 146 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -4.51891077e-02  2.50000047e-02
 -1.63000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -3.

start
[ 0.00000000e+00  7.95731469e+08  1.49210892e-01 -3.03999953e-02
  1.49998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.79615166e-01  5.81541855e-01 -6.00000000e-01]
Read chunk # 275 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  4.37810892e-01 -5.05999953e-02
 -1.02000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.09901302e-02  6.00000000e-01 -6.00000000e-01]
Read chunk # 278 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  7.48108923e-02 -5.23999953e-02
  2.36999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -3.68605982e-02 -4.87448936e-01]
Read chunk # 281 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.21610892e-01  8.00000469e-03
  2.97999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01  2.91952737e-02 -6.00000000e-01]
Read chunk # 284 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  1.48310892e-01  5.00000047e-02
  2.96999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.

[ 0.00000000e+00  7.95731469e+08  1.10510892e-01 -5.01999953e-02
 -2.71000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.67950082e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 419 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  3.13310892e-01 -2.69999953e-02
 -5.80001342e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.68944324e-01  6.00000000e-01 -6.00000000e-01]
Read chunk # 422 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -1.29489108e-01 -5.43999953e-02
  1.76999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  4.79725512e-01 -3.26286803e-01  6.00000000e-01]
Read chunk # 425 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.33210892e-01  2.12000047e-02
 -1.99000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -6.00000000e-01  5.26419148e-01 -6.00000000e-01]
Read chunk # 428 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.94489108e-01 -3.47999953e-02
  2.14999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.000000

start
[ 0.00000000e+00  7.95731469e+08  1.59410892e-01  1.36000047e-02
 -1.09000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.35316098e-01  5.03258928e-01 -6.00000000e-01]
Read chunk # 578 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.04189108e-01  1.82000047e-02
  1.36999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -4.92474183e-02 -6.00000000e-01  6.00000000e-01]
Read chunk # 581 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -2.86089108e-01 -5.23999953e-02
  8.59998658e-03  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  6.00000000e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 584 out of 4999
start
[ 0.00000000e+00  7.95731469e+08 -3.02589108e-01 -4.71999953e-02
 -1.54000134e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
 -1.09668859e-01 -6.00000000e-01  6.00000000e-01]
Read chunk # 587 out of 4999
start
[ 0.00000000e+00  7.95731469e+08  2.02108923e-02  4.70000047e-02
  1.75999866e-02  0.00000000e+00 -5.33423489e+00 -1.57091618e+00
  1.

In [4]:
print("The behavior policy value statistics are:")
dataset.compute_stats()['return']

The behavior policy value statistics are:


{'mean': -135.4019,
 'std': 100.196686,
 'min': -394.1042,
 'max': 0.0,
 'histogram': (array([ 4, 10,  3,  1,  7,  5,  3,  5,  5,  4,  2,  7,  4, 15, 20, 29, 31,
         37,  7,  1], dtype=int64),
  array([-394.1042  , -374.399   , -354.69376 , -334.98856 , -315.28336 ,
         -295.57812 , -275.87292 , -256.16772 , -236.46251 , -216.75731 ,
         -197.0521  , -177.34688 , -157.64168 , -137.93646 , -118.231255,
          -98.52605 ,  -78.82084 ,  -59.115627,  -39.41042 ,  -19.70521 ,
            0.      ], dtype=float32))}

In [5]:
from sklearn.model_selection import train_test_split
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

## Setting up an Algorithm

In [6]:
from d3rlpy.algos import CQL

from d3rlpy.preprocessing import MinMaxActionScaler
action_scaler = MinMaxActionScaler(minimum=-0.6, maximum=0.6)
#cql = CQL(action_scaler=action_scaler)

model = CQL(q_func_factory='mean', #qr -> quantile regression q function, but you don't have to use this
            reward_scaler='standard',
            action_scaler=action_scaler,
          actor_learning_rate=1e-5, 
          critic_learning_rate=0.0003, 
            use_gpu=False) #change it to true if you have one
model.build_with_dataset(dataset)

In [7]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import initial_state_value_estimation_scorer

# calculate metrics with test dataset
ave_error_init = average_value_estimation_scorer(model, test_episodes)
print(ave_error_init)

-0.055221074319288195


In [8]:
%load_ext tensorboard
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 9160), started 9 days, 13:54:19 ago. (Use '!kill 9160' to kill it.)

In [9]:
model.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=15, 
        tensorboard_dir='runs',
        scorers={
            'td_error': td_error_scorer,
            'init_value': initial_state_value_estimation_scorer,
            'ave_value': average_value_estimation_scorer
        })

2022-04-17 12:00.21 [debug    ] RoundIterator is selected.
2022-04-17 12:00.21 [info     ] Directory is created at d3rlpy_logs\CQL_20220417120021
2022-04-17 12:00.21 [debug    ] Fitting action scaler...       action_scaler=min_max
2022-04-17 12:00.21 [debug    ] Fitting reward scaler...       reward_scaler=standard
2022-04-17 12:00.21 [info     ] Parameters are saved to d3rlpy_logs\CQL_20220417120021\params.json params={'action_scaler': {'type': 'min_max', 'params': {'minimum': array(-0.6), 'maximum': array(0.6)}}, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 1e-05, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conser

Epoch 1/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:02.30 [info     ] CQL_20220417120021: epoch=1 step=693 epoch=1 metrics={'time_sample_batch': 0.000581503946543772, 'time_algorithm_update': 0.18238733375571334, 'temp_loss': 4.65626538677133, 'temp': 0.9668053239798993, 'alpha_loss': -12.911016746344849, 'alpha': 1.0322156683535115, 'critic_loss': 14.74611774461094, 'actor_loss': -0.42827668672892527, 'time_step': 0.18320987406919184, 'td_error': 6.001556134833017, 'init_value': -4.348829746246338, 'ave_value': -0.5191810367349148} step=693
2022-04-17 12:02.30 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_693.pt


Epoch 2/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:04.29 [info     ] CQL_20220417120021: epoch=2 step=1386 epoch=2 metrics={'time_sample_batch': 0.0004270537763102918, 'time_algorithm_update': 0.16782348118131124, 'temp_loss': 4.042389003745405, 'temp': 0.9060278358569565, 'alpha_loss': -7.260059795682392, 'alpha': 1.0884678769765306, 'critic_loss': 10.052958227683284, 'actor_loss': 3.121363379330002, 'time_step': 0.1684828299980659, 'td_error': 14.85021098793301, 'init_value': -13.9783935546875, 'ave_value': -4.502754552398767} step=1386
2022-04-17 12:04.29 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_1386.pt


Epoch 3/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:06.27 [info     ] CQL_20220417120021: epoch=3 step=2079 epoch=3 metrics={'time_sample_batch': 0.0005929741852555268, 'time_algorithm_update': 0.16659873010104181, 'temp_loss': 3.576845687533182, 'temp': 0.8512001107265423, 'alpha_loss': -4.1661767718767875, 'alpha': 1.1331465830878606, 'critic_loss': 9.84127060641114, 'actor_loss': 9.652977783098537, 'time_step': 0.16735042844499862, 'td_error': 19.60398619671991, 'init_value': -26.02377700805664, 'ave_value': -9.833783397163375} step=2079
2022-04-17 12:06.27 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_2079.pt


Epoch 4/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:08.25 [info     ] CQL_20220417120021: epoch=4 step=2772 epoch=4 metrics={'time_sample_batch': 0.0006621565137590681, 'time_algorithm_update': 0.16742711734634114, 'temp_loss': 3.1692982800251133, 'temp': 0.8011528140882737, 'alpha_loss': -2.8076866900260744, 'alpha': 1.17197868473086, 'critic_loss': 13.866958395915287, 'actor_loss': 17.794004983888215, 'time_step': 0.16825243095298867, 'td_error': 22.42609883305729, 'init_value': -39.32931900024414, 'ave_value': -16.39575589178855} step=2772
2022-04-17 12:08.25 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_2772.pt


Epoch 5/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:10.23 [info     ] CQL_20220417120021: epoch=5 step=3465 epoch=5 metrics={'time_sample_batch': 0.0006348157857919668, 'time_algorithm_update': 0.16598991054133075, 'temp_loss': 2.820546483580684, 'temp': 0.7550066305273367, 'alpha_loss': -2.2495455466329104, 'alpha': 1.2119492885870335, 'critic_loss': 21.397209364270406, 'actor_loss': 26.36379223227673, 'time_step': 0.166740145346131, 'td_error': 25.293200149778745, 'init_value': -54.10169219970703, 'ave_value': -22.602784375771563} step=3465
2022-04-17 12:10.23 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_3465.pt


Epoch 6/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:12.21 [info     ] CQL_20220417120021: epoch=6 step=4158 epoch=6 metrics={'time_sample_batch': 0.0006377026116177117, 'time_algorithm_update': 0.16696134033313217, 'temp_loss': 2.4986630112233787, 'temp': 0.7121622214069614, 'alpha_loss': -1.8040631297069654, 'alpha': 1.2553659978539053, 'critic_loss': 32.08857310626758, 'actor_loss': 34.95188240502648, 'time_step': 0.16790203889899097, 'td_error': 29.275144914172746, 'init_value': -70.87553405761719, 'ave_value': -30.787530470795485} step=4158
2022-04-17 12:12.21 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_4158.pt


Epoch 7/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:14.17 [info     ] CQL_20220417120021: epoch=7 step=4851 epoch=7 metrics={'time_sample_batch': 0.0004905364214083849, 'time_algorithm_update': 0.16463896963331434, 'temp_loss': 2.1880219995889485, 'temp': 0.6725771279218049, 'alpha_loss': -1.3982097943018434, 'alpha': 1.3002332290701708, 'critic_loss': 44.02269009766297, 'actor_loss': 43.06395434920406, 'time_step': 0.1653632488085594, 'td_error': 33.69266521438974, 'init_value': -85.70948791503906, 'ave_value': -36.35884067366008} step=4851
2022-04-17 12:14.17 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_4851.pt


Epoch 8/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:16.15 [info     ] CQL_20220417120021: epoch=8 step=5544 epoch=8 metrics={'time_sample_batch': 0.00048186940720243025, 'time_algorithm_update': 0.16577298816664393, 'temp_loss': 1.9091968126971313, 'temp': 0.6357587360028408, 'alpha_loss': -0.8422430705562675, 'alpha': 1.3415798963750423, 'critic_loss': 59.289354796361444, 'actor_loss': 50.756713201129244, 'time_step': 0.16642943777219213, 'td_error': 39.98317461586636, 'init_value': -102.40438079833984, 'ave_value': -42.78646596258656} step=5544
2022-04-17 12:16.15 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_5544.pt


Epoch 9/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:18.11 [info     ] CQL_20220417120021: epoch=9 step=6237 epoch=9 metrics={'time_sample_batch': 0.0006204067714630612, 'time_algorithm_update': 0.16425267476884145, 'temp_loss': 1.6732670479866678, 'temp': 0.6014262222796464, 'alpha_loss': -0.4084108769298971, 'alpha': 1.3685963265651577, 'critic_loss': 76.6834666949158, 'actor_loss': 58.06314557008069, 'time_step': 0.16500294053709352, 'td_error': 47.60893605526132, 'init_value': -118.98560333251953, 'ave_value': -49.00421224516956} step=6237
2022-04-17 12:18.11 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_6237.pt


Epoch 10/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:20.08 [info     ] CQL_20220417120021: epoch=10 step=6930 epoch=10 metrics={'time_sample_batch': 0.0006665818791024785, 'time_algorithm_update': 0.16571338352186854, 'temp_loss': 1.469061935274804, 'temp': 0.5690988520695189, 'alpha_loss': 0.07688553110907788, 'alpha': 1.377646206605314, 'critic_loss': 92.84677293015076, 'actor_loss': 65.00190007703817, 'time_step': 0.16653869781659278, 'td_error': 54.852616630438845, 'init_value': -133.86863708496094, 'ave_value': -54.96286857938301} step=6930
2022-04-17 12:20.08 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_6930.pt


Epoch 11/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:22.04 [info     ] CQL_20220417120021: epoch=11 step=7623 epoch=11 metrics={'time_sample_batch': 0.0006224964604233251, 'time_algorithm_update': 0.16451068082757153, 'temp_loss': 1.3229033974514035, 'temp': 0.5382011694481534, 'alpha_loss': 0.4170242162034303, 'alpha': 1.3621412262772068, 'critic_loss': 109.27623535509922, 'actor_loss': 71.64886636445017, 'time_step': 0.16530630151602785, 'td_error': 63.17940015315356, 'init_value': -149.45184326171875, 'ave_value': -61.386242799283785} step=7623
2022-04-17 12:22.04 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_7623.pt


Epoch 12/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:24.02 [info     ] CQL_20220417120021: epoch=12 step=8316 epoch=12 metrics={'time_sample_batch': 0.0004761088290083804, 'time_algorithm_update': 0.1660835650060084, 'temp_loss': 1.185880459404267, 'temp': 0.5087536117559216, 'alpha_loss': 0.5833540028707839, 'alpha': 1.324088328844541, 'critic_loss': 127.83576335122575, 'actor_loss': 77.80929886899125, 'time_step': 0.1666751042072907, 'td_error': 71.62691609080991, 'init_value': -164.76126098632812, 'ave_value': -66.63057370241908} step=8316
2022-04-17 12:24.02 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_8316.pt


Epoch 13/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:25.58 [info     ] CQL_20220417120021: epoch=13 step=9009 epoch=13 metrics={'time_sample_batch': 0.0005237670859905204, 'time_algorithm_update': 0.16389727420449085, 'temp_loss': 1.0804106128680242, 'temp': 0.48042657059680505, 'alpha_loss': 0.6573837014482307, 'alpha': 1.2808677620357938, 'critic_loss': 146.20540157059398, 'actor_loss': 83.85198626717792, 'time_step': 0.16469662874119967, 'td_error': 79.7875179208428, 'init_value': -177.63638305664062, 'ave_value': -70.9167471174476} step=9009
2022-04-17 12:25.58 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_9009.pt


Epoch 14/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:27.54 [info     ] CQL_20220417120021: epoch=14 step=9702 epoch=14 metrics={'time_sample_batch': 0.0005626912412877379, 'time_algorithm_update': 0.1646141273118717, 'temp_loss': 0.988661247414428, 'temp': 0.45337048633102045, 'alpha_loss': 0.6154817443819445, 'alpha': 1.2362774607079026, 'critic_loss': 166.17844334206023, 'actor_loss': 89.60769694054454, 'time_step': 0.16539324172819503, 'td_error': 88.38481607945064, 'init_value': -192.35400390625, 'ave_value': -76.79703487137917} step=9702
2022-04-17 12:27.54 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_9702.pt


Epoch 15/15:   0%|          | 0/693 [00:00<?, ?it/s]

2022-04-17 12:29.51 [info     ] CQL_20220417120021: epoch=15 step=10395 epoch=15 metrics={'time_sample_batch': 0.0005222808403026146, 'time_algorithm_update': 0.16569094293217293, 'temp_loss': 0.8922837355092146, 'temp': 0.4275907404495008, 'alpha_loss': 0.6358140777708482, 'alpha': 1.1962082086703478, 'critic_loss': 183.75308277699855, 'actor_loss': 94.97248449504461, 'time_step': 0.16642965417231662, 'td_error': 97.42876166185621, 'init_value': -205.39906311035156, 'ave_value': -82.70192911476404} step=10395
2022-04-17 12:29.51 [info     ] Model parameters are saved to d3rlpy_logs\CQL_20220417120021\model_10395.pt


[(1,
  {'time_sample_batch': 0.000581503946543772,
   'time_algorithm_update': 0.18238733375571334,
   'temp_loss': 4.65626538677133,
   'temp': 0.9668053239798993,
   'alpha_loss': -12.911016746344849,
   'alpha': 1.0322156683535115,
   'critic_loss': 14.74611774461094,
   'actor_loss': -0.42827668672892527,
   'time_step': 0.18320987406919184,
   'td_error': 6.001556134833017,
   'init_value': -4.348829746246338,
   'ave_value': -0.5191810367349148}),
 (2,
  {'time_sample_batch': 0.0004270537763102918,
   'time_algorithm_update': 0.16782348118131124,
   'temp_loss': 4.042389003745405,
   'temp': 0.9060278358569565,
   'alpha_loss': -7.260059795682392,
   'alpha': 1.0884678769765306,
   'critic_loss': 10.052958227683284,
   'actor_loss': 3.121363379330002,
   'time_step': 0.1684828299980659,
   'td_error': 14.85021098793301,
   'init_value': -13.9783935546875,
   'ave_value': -4.502754552398767}),
 (3,
  {'time_sample_batch': 0.0005929741852555268,
   'time_algorithm_update': 0.166598

In [10]:
model.save_model('cqlStochpid200model_Ep15check.pt')
model.save_policy('cqlStochpid200_Ep15check.pt')

  minimum = torch.tensor(
  maximum = torch.tensor(


## Off-Policy Evaluation

We do get some metrics on a test set of initial state value and average value. However, these estimates (using the critic's Q-function) of model performance are biased. They're useful for validation during training, but not much else. Instead, we fit a Q-function to the data (or a separate dataset, as I've done here) separately and evaluate the model's performance on it.

Feel free to change the chunks and number of steps.

In [11]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i*2 for i in range(100)], path="collected_data/rl_deterministic.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })

In [12]:
# from d3rlpy.ope import FQE
# # metrics to evaluate with
# from d3rlpy.metrics.scorer import soft_opc_scorer


# ope_dataset = get_dataset([i*2 for i in range(100)], path="collected_data/rl_stochastic.txt") #change if you'd prefer different chunks
# ope_train_episodes, ope_test_episodes = train_test_split(ope_dataset, test_size=0.2)

# fqe = FQE(algo=model, action_scaler = action_scaler, use_gpu=False) #change this if you have one!
# fqe.fit(ope_train_episodes, eval_episodes=ope_test_episodes,
#         tensorboard_dir='runs',
#         n_epochs=100, n_steps_per_epoch=10000, #change if overfitting/underfitting
#         scorers={
#            'init_value': initial_state_value_estimation_scorer,
#             'ave_value': average_value_estimation_scorer,
#            'soft_opc': soft_opc_scorer(return_threshold=0)
#         })