In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import d3rlpy
from d3rlpy.metrics.scorer import average_value_estimation_scorer, td_error_scorer, evaluate_on_environment

import warnings
warnings.filterwarnings("ignore")

import os
os.chdir("..")

from augmentation.augmented_dataset import MDPDatasetAugmented


In [9]:
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
dataset, env = d3rlpy.datasets.get_d4rl("door-human-v0")

Downloading dataset: http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5 to /home/gerkone/.d4rl/datasets/door-v0_demos_clipped.hdf5


load datafile: 100%|██████████| 7/7 [00:00<00:00, 119.96it/s]


In [10]:
AUGMENTATIONS = ["gaussian", "mixup"]

In [11]:
augmented_dataset = MDPDatasetAugmented.from_mdpdataset(dataset, augmentations=AUGMENTATIONS)

In [12]:
assert augmented_dataset.observations.shape[0] == dataset.observations.shape[0] * (1 + len(AUGMENTATIONS))

In [13]:
agent = d3rlpy.algos.CQL(use_gpu=False)

agent.fit(
    dataset,
    eval_episodes=dataset,
    n_epochs=2 * (1 + len(AUGMENTATIONS)),
    scorers={
        'td_error': td_error_scorer,
        'value_scale': average_value_estimation_scorer,
        'environment': evaluate_on_environment(env)
    }
)

2022-05-22 15:40.50 [debug    ] RoundIterator is selected.
2022-05-22 15:40.50 [info     ] Directory is created at d3rlpy_logs/CQL_20220522154050
2022-05-22 15:40.50 [debug    ] Building models...
2022-05-22 15:40.50 [debug    ] Models have been built.
2022-05-22 15:40.50 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220522154050/params.json params={'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conservative_weight': 5.0, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rat

Epoch 1/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:41.15 [info     ] CQL_20220522154050: epoch=1 step=26 epoch=1 metrics={'time_sample_batch': 0.0009445043710561899, 'time_algorithm_update': 0.7118320373388437, 'temp_loss': 46.71723659221943, 'temp': 0.998650835110591, 'alpha_loss': -102.42353615394005, 'alpha': 1.0013495683670044, 'critic_loss': 175.02379256028397, 'actor_loss': -18.812982999361477, 'time_step': 0.713067889213562, 'td_error': 35.472251812045, 'value_scale': 0.3178841885438229, 'environment': -53.841486284439476} step=26
2022-05-22 15:41.15 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_26.pt


Epoch 2/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:41.41 [info     ] CQL_20220522154050: epoch=2 step=52 epoch=2 metrics={'time_sample_batch': 0.0009436699060293344, 'time_algorithm_update': 0.7203797652171209, 'temp_loss': 46.82242833651029, 'temp': 0.9960554769405952, 'alpha_loss': -94.40246992844801, 'alpha': 1.0039218939267671, 'critic_loss': 144.78933803851788, 'actor_loss': -18.712352605966423, 'time_step': 0.7216175886300894, 'td_error': 26.477559076208458, 'value_scale': -0.05474662210687835, 'environment': -54.2223686593714} step=52
2022-05-22 15:41.41 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_52.pt


Epoch 3/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:42.03 [info     ] CQL_20220522154050: epoch=3 step=78 epoch=3 metrics={'time_sample_batch': 0.0008252400618333083, 'time_algorithm_update': 0.6149879327187171, 'temp_loss': 46.452384361853966, 'temp': 0.9934693047633538, 'alpha_loss': -80.12102068387546, 'alpha': 1.006357376392071, 'critic_loss': 100.80487676767203, 'actor_loss': -16.39248569195087, 'time_step': 0.6160286664962769, 'td_error': 25.697336406817406, 'value_scale': -0.5923816384115537, 'environment': -57.62098748037359} step=78
2022-05-22 15:42.03 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_78.pt


Epoch 4/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:42.26 [info     ] CQL_20220522154050: epoch=4 step=104 epoch=4 metrics={'time_sample_batch': 0.0009518770071176382, 'time_algorithm_update': 0.6013142053897564, 'temp_loss': 44.77232889028696, 'temp': 0.9909158990933344, 'alpha_loss': -63.35229829641489, 'alpha': 1.0085399013299208, 'critic_loss': 74.43186305119441, 'actor_loss': -13.12734343455388, 'time_step': 0.602516586963947, 'td_error': 31.7959023079784, 'value_scale': 1.1777528064872242, 'environment': -57.25170885213329} step=104
2022-05-22 15:42.26 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_104.pt


Epoch 5/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:42.48 [info     ] CQL_20220522154050: epoch=5 step=130 epoch=5 metrics={'time_sample_batch': 0.0008440292798555815, 'time_algorithm_update': 0.6020467373041006, 'temp_loss': 42.34855431776781, 'temp': 0.9884487207119281, 'alpha_loss': -48.96703485342172, 'alpha': 1.0103995891717763, 'critic_loss': 58.9422784952017, 'actor_loss': -10.712369038508488, 'time_step': 0.6031304322756253, 'td_error': 31.70658615559485, 'value_scale': 2.720891158987111, 'environment': -52.99139326943612} step=130
2022-05-22 15:42.48 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_130.pt


Epoch 6/6:   0%|          | 0/26 [00:00<?, ?it/s]

2022-05-22 15:43.12 [info     ] CQL_20220522154050: epoch=6 step=156 epoch=6 metrics={'time_sample_batch': 0.0008664314563457782, 'time_algorithm_update': 0.6509894316013043, 'temp_loss': 41.18097921518179, 'temp': 0.986058411689905, 'alpha_loss': -34.56473497244028, 'alpha': 1.0119349681414092, 'critic_loss': 44.56248298058143, 'actor_loss': -8.016905711247372, 'time_step': 0.6521049646230844, 'td_error': 33.805772095196, 'value_scale': 2.6485470197200196, 'environment': -53.03238902781733} step=156
2022-05-22 15:43.12 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154050/model_156.pt


[(1,
  {'time_sample_batch': 0.0009445043710561899,
   'time_algorithm_update': 0.7118320373388437,
   'temp_loss': 46.71723659221943,
   'temp': 0.998650835110591,
   'alpha_loss': -102.42353615394005,
   'alpha': 1.0013495683670044,
   'critic_loss': 175.02379256028397,
   'actor_loss': -18.812982999361477,
   'time_step': 0.713067889213562,
   'td_error': 35.472251812045,
   'value_scale': 0.3178841885438229,
   'environment': -53.841486284439476}),
 (2,
  {'time_sample_batch': 0.0009436699060293344,
   'time_algorithm_update': 0.7203797652171209,
   'temp_loss': 46.82242833651029,
   'temp': 0.9960554769405952,
   'alpha_loss': -94.40246992844801,
   'alpha': 1.0039218939267671,
   'critic_loss': 144.78933803851788,
   'actor_loss': -18.712352605966423,
   'time_step': 0.7216175886300894,
   'td_error': 26.477559076208458,
   'value_scale': -0.05474662210687835,
   'environment': -54.2223686593714}),
 (3,
  {'time_sample_batch': 0.0008252400618333083,
   'time_algorithm_update': 0.

In [14]:
augmented_agent = d3rlpy.algos.CQL(use_gpu=False)

augmented_agent.fit(
    augmented_dataset,
    eval_episodes=augmented_dataset,
    n_epochs=2,
    scorers={
        'td_error': td_error_scorer,
        'value_scale': average_value_estimation_scorer,
        'environment': evaluate_on_environment(env)
    }
)

2022-05-22 15:43.13 [debug    ] RoundIterator is selected.
2022-05-22 15:43.13 [info     ] Directory is created at d3rlpy_logs/CQL_20220522154313
2022-05-22 15:43.13 [debug    ] Building models...
2022-05-22 15:43.13 [debug    ] Models have been built.
2022-05-22 15:43.13 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220522154313/params.json params={'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conservative_weight': 5.0, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rat

Epoch 1/2:   0%|          | 0/78 [00:00<?, ?it/s]

2022-05-22 15:44.11 [info     ] CQL_20220522154313: epoch=1 step=78 epoch=1 metrics={'time_sample_batch': 0.0010720338576879257, 'time_algorithm_update': 0.6404838531445234, 'temp_loss': 46.7459525572948, 'temp': 0.996059829607988, 'alpha_loss': -92.82213993561574, 'alpha': 1.003877138480162, 'critic_loss': 144.71215927906525, 'actor_loss': -18.432279317806927, 'time_step': 0.6417846801953438, 'td_error': 25.180772176780298, 'value_scale': -0.18643188069034783, 'environment': -54.66997082253339} step=78
2022-05-22 15:44.12 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154313/model_78.pt


Epoch 2/2:   0%|          | 0/78 [00:00<?, ?it/s]

2022-05-22 15:45.07 [info     ] CQL_20220522154313: epoch=2 step=156 epoch=2 metrics={'time_sample_batch': 0.0009337938748873197, 'time_algorithm_update': 0.5982088278501462, 'temp_loss': 44.17121359018179, 'temp': 0.9884171944398147, 'alpha_loss': -59.23591589316344, 'alpha': 1.0106247235567143, 'critic_loss': 80.67120483594063, 'actor_loss': -12.796754641410631, 'time_step': 0.5993590477185372, 'td_error': 27.855670217993087, 'value_scale': 2.626322401517001, 'environment': -53.631392006400134} step=156
2022-05-22 15:45.07 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220522154313/model_156.pt


[(1,
  {'time_sample_batch': 0.0010720338576879257,
   'time_algorithm_update': 0.6404838531445234,
   'temp_loss': 46.7459525572948,
   'temp': 0.996059829607988,
   'alpha_loss': -92.82213993561574,
   'alpha': 1.003877138480162,
   'critic_loss': 144.71215927906525,
   'actor_loss': -18.432279317806927,
   'time_step': 0.6417846801953438,
   'td_error': 25.180772176780298,
   'value_scale': -0.18643188069034783,
   'environment': -54.66997082253339}),
 (2,
  {'time_sample_batch': 0.0009337938748873197,
   'time_algorithm_update': 0.5982088278501462,
   'temp_loss': 44.17121359018179,
   'temp': 0.9884171944398147,
   'alpha_loss': -59.23591589316344,
   'alpha': 1.0106247235567143,
   'critic_loss': 80.67120483594063,
   'actor_loss': -12.796754641410631,
   'time_step': 0.5993590477185372,
   'td_error': 27.855670217993087,
   'value_scale': 2.626322401517001,
   'environment': -53.631392006400134})]