In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
os.chdir("..")

In [4]:
import d3rlpy
from d3rlpy.metrics.scorer import average_value_estimation_scorer, td_error_scorer, evaluate_on_environment

import warnings
warnings.filterwarnings("ignore")

from augmentation.augmented_dataset import MDPDatasetAugmented


In [5]:
dataset, env = d3rlpy.datasets.get_cartpole()

In [6]:
AUGMENTATIONS = ["gaussian", "mixup"]

In [7]:
augmented_dataset = MDPDatasetAugmented.from_mdpdataset(dataset, augmentations=AUGMENTATIONS)

In [8]:
assert augmented_dataset.observations.shape[0] == dataset.observations.shape[0] * (1 + len(AUGMENTATIONS))

In [9]:
agent = d3rlpy.algos.DiscreteCQL(use_gpu=False)

agent.fit(
    dataset,
    eval_episodes=dataset,
    n_epochs=2,
    scorers={
        'td_error': td_error_scorer,
        'value_scale': average_value_estimation_scorer,
        'environment': evaluate_on_environment(env)
    }
)

2022-05-21 20:00.37 [debug    ] RoundIterator is selected.
2022-05-21 20:00.37 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20220521200037
2022-05-21 20:00.37 [debug    ] Building models...
2022-05-21 20:00.37 [debug    ] Models have been built.
2022-05-21 20:00.37 [info     ] Parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200037/params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DiscreteCQL', 'observation

Epoch 1/2:   0%|          | 0/3116 [00:00<?, ?it/s]

2022-05-21 20:01.03 [info     ] DiscreteCQL_20220521200037: epoch=1 step=3116 epoch=1 metrics={'time_sample_batch': 0.00013498233126737035, 'time_algorithm_update': 0.005683536783568513, 'loss': 0.6817524252753508, 'time_step': 0.005947953164194913, 'td_error': 1.188272377368648, 'value_scale': 1.0975938445805873, 'environment': 181.3} step=3116
2022-05-21 20:01.03 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200037/model_3116.pt


Epoch 2/2:   0%|          | 0/3116 [00:00<?, ?it/s]

2022-05-21 20:01.29 [info     ] DiscreteCQL_20220521200037: epoch=2 step=6232 epoch=2 metrics={'time_sample_batch': 0.00013021533694285636, 'time_algorithm_update': 0.005531626962727852, 'loss': 0.6656075393893446, 'time_step': 0.005779766531443565, 'td_error': 1.165678379460701, 'value_scale': 1.0862695665836037, 'environment': 200.0} step=6232
2022-05-21 20:01.29 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200037/model_6232.pt


[(1,
  {'time_sample_batch': 0.00013498233126737035,
   'time_algorithm_update': 0.005683536783568513,
   'loss': 0.6817524252753508,
   'time_step': 0.005947953164194913,
   'td_error': 1.188272377368648,
   'value_scale': 1.0975938445805873,
   'environment': 181.3}),
 (2,
  {'time_sample_batch': 0.00013021533694285636,
   'time_algorithm_update': 0.005531626962727852,
   'loss': 0.6656075393893446,
   'time_step': 0.005779766531443565,
   'td_error': 1.165678379460701,
   'value_scale': 1.0862695665836037,
   'environment': 200.0})]

In [10]:
augmented_agent = d3rlpy.algos.DiscreteCQL(use_gpu=False)

augmented_agent.fit(
    augmented_dataset,
    eval_episodes=augmented_dataset,
    n_epochs=2,
    scorers={
        'td_error': td_error_scorer,
        'value_scale': average_value_estimation_scorer,
        'environment': evaluate_on_environment(env)
    }
)

2022-05-21 20:01.30 [debug    ] RoundIterator is selected.
2022-05-21 20:01.30 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20220521200130
2022-05-21 20:01.30 [debug    ] Building models...
2022-05-21 20:01.30 [debug    ] Models have been built.
2022-05-21 20:01.30 [info     ] Parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200130/params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DiscreteCQL', 'observation

Epoch 1/2:   0%|          | 0/9349 [00:00<?, ?it/s]

2022-05-21 20:02.50 [info     ] DiscreteCQL_20220521200130: epoch=1 step=9349 epoch=1 metrics={'time_sample_batch': 0.00013286542836351566, 'time_algorithm_update': 0.005554967008691279, 'loss': 0.6820018052710533, 'time_step': 0.0059076453392640045, 'td_error': 1.103101697128133, 'value_scale': 2.043517615620176, 'environment': 200.0} step=9349
2022-05-21 20:02.51 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200130/model_9349.pt


Epoch 2/2:   0%|          | 0/9349 [00:00<?, ?it/s]

2022-05-21 20:04.32 [info     ] DiscreteCQL_20220521200130: epoch=2 step=18698 epoch=2 metrics={'time_sample_batch': 0.00015564132647178226, 'time_algorithm_update': 0.007047073919386236, 'loss': 0.688243582951524, 'time_step': 0.007468992432804334, 'td_error': 1.1286719410555313, 'value_scale': 3.09165765321609, 'environment': 200.0} step=18698
2022-05-21 20:04.32 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220521200130/model_18698.pt


[(1,
  {'time_sample_batch': 0.00013286542836351566,
   'time_algorithm_update': 0.005554967008691279,
   'loss': 0.6820018052710533,
   'time_step': 0.0059076453392640045,
   'td_error': 1.103101697128133,
   'value_scale': 2.043517615620176,
   'environment': 200.0}),
 (2,
  {'time_sample_batch': 0.00015564132647178226,
   'time_algorithm_update': 0.007047073919386236,
   'loss': 0.688243582951524,
   'time_step': 0.007468992432804334,
   'td_error': 1.1286719410555313,
   'value_scale': 3.09165765321609,
   'environment': 200.0})]