In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("..")

In [2]:
import torch
import pandas as pd
import d3rlpy
import numpy as np

import warnings

warnings.filterwarnings("ignore")

from augrl.augmentations import exrp
from augrl.augmentations.spin_cartpole.cartpole import CartPoleEnv

In [3]:
env = CartPoleEnv(game=False) # parameter update starts after 1K steps

## Train on manually collected data
Because of what the model is trying to approximate it's hard to directly evaluate.

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [5]:
predictor = exrp.ExplicitRewardPredictor.from_env(env, device, lr=5e-5)

In [None]:
user_rollouts = pd.read_pickle("augrl/augmentations/spin_cartpole/preferences/handmade_results_170.pickle")
segments = user_rollouts["segment"].max() + 1
eval_segments = np.random.choice(segments, size=int(0.1 * segments))

train_rollouts = user_rollouts[~user_rollouts["segment"].isin(eval_segments)]
eval_rollouts = user_rollouts[user_rollouts["segment"].isin(eval_segments)]
user_rollouts_spin = train_rollouts[train_rollouts["preference"] == 1]
user_rollouts_no_spin = train_rollouts[train_rollouts["preference"] == 0]
user_rollouts_spin_eval = eval_rollouts[eval_rollouts["preference"] == 1]
user_rollouts_no_spin_eval = eval_rollouts[eval_rollouts["preference"] == 0]

print("Training on {} segments ({} with spin, {} without)".format(len(train_rollouts), len(user_rollouts_spin), len(user_rollouts_no_spin)))
print("Evaluating on {} segments ({} with spin, {} without)".format(len(eval_rollouts), len(user_rollouts_spin_eval), len(user_rollouts_no_spin_eval)))

obs_spin = torch.tensor(list(user_rollouts_spin["state"].values), dtype=torch.float32)
act_spin = torch.tensor(list(user_rollouts_spin["action"].values), dtype=torch.float32)
obs_no_spin = torch.tensor(list(user_rollouts_no_spin["state"].values), dtype=torch.float32)
act_no_spin = torch.tensor(list(user_rollouts_no_spin["action"].values), dtype=torch.float32)

obs_spin_eval = torch.tensor(list(user_rollouts_spin_eval["state"].values), dtype=torch.float32)
act_spin_eval = torch.tensor(list(user_rollouts_spin_eval["action"].values), dtype=torch.float32)
obs_no_spin_eval = torch.tensor(list(user_rollouts_no_spin_eval["state"].values), dtype=torch.float32)
act_no_spin_eval = torch.tensor(list(user_rollouts_no_spin_eval["action"].values), dtype=torch.float32)

### Collect segment tuples with preference

In [None]:
segments = exrp.get_segments(obs_spin, act_spin, obs_no_spin, act_no_spin)
segments_eval = exrp.get_segments(obs_spin_eval, act_spin_eval, obs_no_spin_eval, act_no_spin_eval)
print("Training on {} actual sequences".format(len(segments["obs_left"])))
print("Evaluating on {} actual sequences".format(len(segments_eval["obs_left"])))

### Train

In [None]:
predictor.train(segments, show_pregress=True, epochs=6, batch_size=32)

# Evaluate

In [None]:
acc = sum(predictor.prefer(segments_eval) - segments_eval["preferences"] == 0) / len(segments_eval["preferences"])
print("Predicting preferences with an accuracy of {:.2f}%".format(100 * acc))

In [None]:
predictor.save("augrl/augmentations/spin_cartpole/predictors/predictor_170.pt")

In [6]:
predictor.load("augrl/augmentations/spin_cartpole/predictors/predictor_150.pt")

## Try in online RL

In [7]:
import d3rlpy

from IPython.display import Video
from gym.wrappers import RecordVideo
import pygame

def eval_with_env(env, algo):
    s = env.reset()
    reward = 0
    steps = 0
    while True:
        a = algo.predict(s.reshape(1, *s.shape))
        s, r, done, _ = env.step(a[0])
        reward += r
        steps += 1
        env.render()
        if done: break
    print("{:.2f} {}".format(reward, steps))

class Spec:
    def __init__(self):
        self.id = ""
        self.kwargs = {}

screen = pygame.display.set_mode((600, 400))

In [8]:
env_predictor = CartPoleEnv(game=False, reward_fn=predictor, timeout=500)
env_predictor.spec = Spec()
env_predictor_ = RecordVideo(env_predictor, "./videos",  name_prefix="DDQN", episode_trigger= lambda x: True)

In [17]:
def reward_angle(s, a):
    return 0.9 * abs(s[3]) - 0.1 * abs(s[0])

env_analytical = CartPoleEnv(game=False, reward_fn=reward_angle, timeout=500)
env_analytical.spec = Spec()
env_analytical_ = RecordVideo(env_analytical, "./videos",  name_prefix="DDQN", episode_trigger= lambda x: True)

### Learned reward

In [18]:
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=20000, env=env_predictor)
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(0.3)
agent_preferences = d3rlpy.algos.DoubleDQN(
    learning_rate=2e-4,
    target_update_interval=1000,
    use_gpu=True,
)
agent_preferences.build_with_env(env_predictor)

In [19]:
agent_preferences.fit_online(
    env_predictor,
    buffer,
    explorer,
    n_steps=100000,
    eval_env=None,
    n_steps_per_epoch=1000,
    update_start_step=1000,
    show_progress=False
)

2022-08-21 10:50.49 [info     ] Directory is created at d3rlpy_logs/DoubleDQN_online_20220821105049
2022-08-21 10:50.49 [info     ] Parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 0.0002, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 1000, 'use_gpu': 0, 'algorithm': 'DoubleDQN', 'observation_shape': (4,), 'action_size': 2}
2022-08-21 10:50.50 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_1000.pt
2022-08-21 10:50.50 [inf

2022-08-21 10:51.55 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_16000.pt
2022-08-21 10:51.55 [info     ] DoubleDQN_online_20220821105049: epoch=16 step=16000 epoch=16 metrics={'time_inference': 0.0005589728355407715, 'time_environment_step': 0.0005984787940979004, 'time_sample_batch': 0.00010277080535888672, 'time_algorithm_update': 0.0029579751491546633, 'loss': 0.0029158688566676572, 'time_step': 0.0042724616527557375, 'rollout_return': 6.297490271251826} step=16000
2022-08-21 10:51.59 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_17000.pt
2022-08-21 10:51.59 [info     ] DoubleDQN_online_20220821105049: epoch=17 step=17000 epoch=17 metrics={'time_inference': 0.0005638575553894042, 'time_environment_step': 0.0006029458045959473, 'time_sample_batch': 0.00010447454452514648, 'time_algorithm_update': 0.0029540801048278807, 'loss': 0.0035737427226704313, 'time_step': 0.004279881000518799, 'rollou

2022-08-21 10:53.03 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_32000.pt
2022-08-21 10:53.03 [info     ] DoubleDQN_online_20220821105049: epoch=32 step=32000 epoch=32 metrics={'time_inference': 0.000558551549911499, 'time_environment_step': 0.0005982522964477539, 'time_sample_batch': 0.00010383224487304688, 'time_algorithm_update': 0.0029537267684936522, 'loss': 0.00536268044525059, 'time_step': 0.004269518136978149, 'rollout_return': 9.130738861278846} step=32000
2022-08-21 10:53.08 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_33000.pt
2022-08-21 10:53.08 [info     ] DoubleDQN_online_20220821105049: epoch=33 step=33000 epoch=33 metrics={'time_inference': 0.0005596327781677246, 'time_environment_step': 0.000600313663482666, 'time_sample_batch': 0.00010341119766235352, 'time_algorithm_update': 0.0029552578926086426, 'loss': 0.005311725050967653, 'time_step': 0.004273883819580078, 'rollout_retu

2022-08-21 10:54.12 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_48000.pt
2022-08-21 10:54.12 [info     ] DoubleDQN_online_20220821105049: epoch=48 step=48000 epoch=48 metrics={'time_inference': 0.000559701919555664, 'time_environment_step': 0.0006012797355651855, 'time_sample_batch': 0.00010428357124328613, 'time_algorithm_update': 0.0029633378982543944, 'loss': 0.011227703651296906, 'time_step': 0.004283747434616089, 'rollout_return': 17.332793607301678} step=48000
2022-08-21 10:54.17 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_49000.pt
2022-08-21 10:54.17 [info     ] DoubleDQN_online_20220821105049: epoch=49 step=49000 epoch=49 metrics={'time_inference': 0.0005598194599151611, 'time_environment_step': 0.0006026773452758789, 'time_sample_batch': 0.00010479331016540528, 'time_algorithm_update': 0.002936204195022583, 'loss': 0.012890483717201277, 'time_step': 0.004258781433105469, 'rollout_re

2022-08-21 10:55.21 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_64000.pt
2022-08-21 10:55.21 [info     ] DoubleDQN_online_20220821105049: epoch=64 step=64000 epoch=64 metrics={'time_inference': 0.000557593822479248, 'time_environment_step': 0.0005993402004241943, 'time_sample_batch': 0.00010387063026428223, 'time_algorithm_update': 0.0029408652782440185, 'loss': 0.0324309671621304, 'time_step': 0.004256298542022705, 'rollout_return': 26.04090190694842} step=64000
2022-08-21 10:55.25 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_65000.pt
2022-08-21 10:55.25 [info     ] DoubleDQN_online_20220821105049: epoch=65 step=65000 epoch=65 metrics={'time_inference': 0.0005553507804870605, 'time_environment_step': 0.0005973711013793945, 'time_sample_batch': 0.00010346698760986328, 'time_algorithm_update': 0.002944708824157715, 'loss': 0.03729680099850521, 'time_step': 0.00425477385520935, 'rollout_return'

2022-08-21 10:56.30 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_80000.pt
2022-08-21 10:56.30 [info     ] DoubleDQN_online_20220821105049: epoch=80 step=80000 epoch=80 metrics={'time_inference': 0.000562077283859253, 'time_environment_step': 0.0006046087741851807, 'time_sample_batch': 0.00010488224029541016, 'time_algorithm_update': 0.002969735860824585, 'loss': 0.040138979281298816, 'time_step': 0.004296490430831909, 'rollout_return': 8.34275312673417} step=80000
2022-08-21 10:56.34 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_81000.pt
2022-08-21 10:56.34 [info     ] DoubleDQN_online_20220821105049: epoch=81 step=81000 epoch=81 metrics={'time_inference': 0.0005609242916107177, 'time_environment_step': 0.0006036829948425293, 'time_sample_batch': 0.00010461139678955078, 'time_algorithm_update': 0.0029692089557647704, 'loss': 0.0428991912426427, 'time_step': 0.00429333758354187, 'rollout_return'

2022-08-21 10:57.39 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_96000.pt
2022-08-21 10:57.39 [info     ] DoubleDQN_online_20220821105049: epoch=96 step=96000 epoch=96 metrics={'time_inference': 0.0005596742630004883, 'time_environment_step': 0.0006037428379058838, 'time_sample_batch': 0.00010488486289978027, 'time_algorithm_update': 0.0029621007442474363, 'loss': 0.04846719340514392, 'time_step': 0.004285364627838135, 'rollout_return': 17.361455139929603} step=96000
2022-08-21 10:57.43 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105049/model_97000.pt
2022-08-21 10:57.43 [info     ] DoubleDQN_online_20220821105049: epoch=97 step=97000 epoch=97 metrics={'time_inference': 0.0005624465942382813, 'time_environment_step': 0.000605431318283081, 'time_sample_batch': 0.00010502958297729492, 'time_algorithm_update': 0.002973371744155884, 'loss': 0.046504643791355195, 'time_step': 0.004300935506820679, 'rollout_ret

In [28]:
agent_preferences.load_model("d3rlpy_logs/DoubleDQN_online_20220821105049/model_63000.pt")
eval_with_env(env_analytical_, agent_preferences)

713.99 212


In [21]:
os.system("ls -l videos")

total 40
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2158 Aug 21 10:57 DDQN-episode-0.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14154 Aug 21 10:57 DDQN-episode-0.mp4
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2164 Aug 21 10:45 DDQN-episode-67.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14653 Aug 21 10:45 DDQN-episode-67.mp4


0

In [22]:
Video("videos/DDQN-episode-0.mp4", embed=True)

### Handcrafted reward

In [23]:
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=20000, env=env_analytical)
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(0.3)
agent_analytical = d3rlpy.algos.DoubleDQN(
    learning_rate=2e-4,
    target_update_interval=1000,
    use_gpu=True,
)
agent_analytical.build_with_env(env_analytical)

In [24]:
agent_analytical.fit_online(
    env_analytical,
    buffer,
    explorer,
    n_steps=100000,
    eval_env=None,
    n_steps_per_epoch=1000,
    update_start_step=1000,
    show_progress=False
)

2022-08-21 10:57.56 [info     ] Directory is created at d3rlpy_logs/DoubleDQN_online_20220821105756
2022-08-21 10:57.56 [info     ] Parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 0.0002, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 1000, 'use_gpu': 0, 'algorithm': 'DoubleDQN', 'observation_shape': (4,), 'action_size': 2}
2022-08-21 10:57.57 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_1000.pt
2022-08-21 10:57.57 [inf

2022-08-21 10:58.52 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_16000.pt
2022-08-21 10:58.52 [info     ] DoubleDQN_online_20220821105756: epoch=16 step=16000 epoch=16 metrics={'time_inference': 0.0005641376972198486, 'time_environment_step': 3.7349462509155274e-05, 'time_sample_batch': 8.645987510681153e-05, 'time_algorithm_update': 0.0029731197357177734, 'loss': 2.090922241896391, 'time_step': 0.003707923412322998, 'rollout_return': 319.04194307733394} step=16000
2022-08-21 10:58.56 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_17000.pt
2022-08-21 10:58.56 [info     ] DoubleDQN_online_20220821105756: epoch=17 step=17000 epoch=17 metrics={'time_inference': 0.0005614445209503174, 'time_environment_step': 3.707098960876465e-05, 'time_sample_batch': 8.677411079406738e-05, 'time_algorithm_update': 0.0029900155067443847, 'loss': 2.1412644590735437, 'time_step': 0.003721771240234375, 'rollout_return

2022-08-21 10:59.52 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_32000.pt
2022-08-21 10:59.52 [info     ] DoubleDQN_online_20220821105756: epoch=32 step=32000 epoch=32 metrics={'time_inference': 0.0005616354942321778, 'time_environment_step': 3.7314891815185545e-05, 'time_sample_batch': 8.638858795166015e-05, 'time_algorithm_update': 0.0029529454708099364, 'loss': 3.517882876634598, 'time_step': 0.0036857998371124267, 'rollout_return': 516.0958866566303} step=32000
2022-08-21 10:59.55 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_33000.pt
2022-08-21 10:59.55 [info     ] DoubleDQN_online_20220821105756: epoch=33 step=33000 epoch=33 metrics={'time_inference': 0.0005551528930664063, 'time_environment_step': 3.698253631591797e-05, 'time_sample_batch': 8.681917190551758e-05, 'time_algorithm_update': 0.0029383742809295653, 'loss': 3.5735247147083284, 'time_step': 0.003664762020111084, 'rollout_return

2022-08-21 11:00.51 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_48000.pt
2022-08-21 11:00.51 [info     ] DoubleDQN_online_20220821105756: epoch=48 step=48000 epoch=48 metrics={'time_inference': 0.0005593857765197754, 'time_environment_step': 3.679800033569336e-05, 'time_sample_batch': 8.644199371337891e-05, 'time_algorithm_update': 0.002937267780303955, 'loss': 4.4131822943687435, 'time_step': 0.0036674065589904787, 'rollout_return': 606.4699025945483} step=48000
2022-08-21 11:00.54 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_49000.pt
2022-08-21 11:00.54 [info     ] DoubleDQN_online_20220821105756: epoch=49 step=49000 epoch=49 metrics={'time_inference': 0.0005605680942535401, 'time_environment_step': 3.69417667388916e-05, 'time_sample_batch': 8.701062202453614e-05, 'time_algorithm_update': 0.0029446520805358886, 'loss': 4.468394647121429, 'time_step': 0.003677037000656128, 'rollout_return': 

2022-08-21 11:01.50 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_64000.pt
2022-08-21 11:01.50 [info     ] DoubleDQN_online_20220821105756: epoch=64 step=64000 epoch=64 metrics={'time_inference': 0.0005569941997528076, 'time_environment_step': 3.667187690734863e-05, 'time_sample_batch': 8.47625732421875e-05, 'time_algorithm_update': 0.002932533264160156, 'loss': 5.350506774961948, 'time_step': 0.0036580235958099364, 'rollout_return': 469.224211943564} step=64000
2022-08-21 11:01.53 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_65000.pt
2022-08-21 11:01.53 [info     ] DoubleDQN_online_20220821105756: epoch=65 step=65000 epoch=65 metrics={'time_inference': 0.0005518445968627929, 'time_environment_step': 3.6298036575317384e-05, 'time_sample_batch': 8.467674255371094e-05, 'time_algorithm_update': 0.002913485050201416, 'loss': 5.301417937278748, 'time_step': 0.0036331439018249512, 'rollout_return': 4

2022-08-21 11:02.49 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_80000.pt
2022-08-21 11:02.49 [info     ] DoubleDQN_online_20220821105756: epoch=80 step=80000 epoch=80 metrics={'time_inference': 0.0005618197917938233, 'time_environment_step': 3.7165641784667965e-05, 'time_sample_batch': 8.6273193359375e-05, 'time_algorithm_update': 0.0029104382991790773, 'loss': 5.174024652481079, 'time_step': 0.003643324136734009, 'rollout_return': 478.4448313051121} step=80000
2022-08-21 11:02.52 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_81000.pt
2022-08-21 11:02.52 [info     ] DoubleDQN_online_20220821105756: epoch=81 step=81000 epoch=81 metrics={'time_inference': 0.0005602607727050781, 'time_environment_step': 3.752493858337402e-05, 'time_sample_batch': 8.670544624328613e-05, 'time_algorithm_update': 0.0030303316116333007, 'loss': 5.21940882652998, 'time_step': 0.0037621424198150634, 'rollout_return': 5

2022-08-21 11:03.48 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_96000.pt
2022-08-21 11:03.48 [info     ] DoubleDQN_online_20220821105756: epoch=96 step=96000 epoch=96 metrics={'time_inference': 0.0005582706928253173, 'time_environment_step': 3.664088249206543e-05, 'time_sample_batch': 8.562850952148437e-05, 'time_algorithm_update': 0.002939398765563965, 'loss': 6.945534284353256, 'time_step': 0.0036667945384979248, 'rollout_return': 401.1051525186183} step=96000
2022-08-21 11:03.51 [info     ] Model parameters are saved to d3rlpy_logs/DoubleDQN_online_20220821105756/model_97000.pt
2022-08-21 11:03.51 [info     ] DoubleDQN_online_20220821105756: epoch=97 step=97000 epoch=97 metrics={'time_inference': 0.00056325364112854, 'time_environment_step': 3.726911544799805e-05, 'time_sample_batch': 8.691596984863282e-05, 'time_algorithm_update': 0.0029537584781646727, 'loss': 7.005337695300579, 'time_step': 0.0036888298988342286, 'rollout_return': 5

In [33]:
agent_preferences.load_model("d3rlpy_logs/DoubleDQN_online_20220821105756/model_71000.pt")
eval_with_env(env_analytical, agent_analytical)

416.71 152


In [26]:
os.system("ls -l videos")

total 40
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2158 Aug 21 10:57 DDQN-episode-0.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14154 Aug 21 10:57 DDQN-episode-0.mp4
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2164 Aug 21 10:45 DDQN-episode-67.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14653 Aug 21 10:45 DDQN-episode-67.mp4


0

In [27]:
Video("videos/DDQN-episode-1.mp4", embed=True)

## Try in offline RL

### Build the dataset

In [34]:
N_STEPS = 10000
env_predictor_ = RecordVideo(env_predictor, "./videos",  name_prefix="CQL", episode_trigger= lambda x: True)
env_analytical_ = RecordVideo(env_analytical, "./videos",  name_prefix="CQL", episode_trigger= lambda x: True)

In [35]:
ds_buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=N_STEPS, env=env_predictor)
# use trained policy from before
agent_preferences.collect(env_predictor, ds_buffer, n_steps=N_STEPS, show_progress=False)

dataset_learned = ds_buffer.to_mdp_dataset()
dataset_learned.dump("augrl/augmentations/spin_cartpole/datasets/ddqn_150_medium_learned_dataset.h5")



In [36]:
ds_buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=N_STEPS, env=env_analytical)
# use trained policy from before
agent_preferences.collect(env_analytical, ds_buffer, n_steps=N_STEPS, show_progress=False)

dataset_analytical = ds_buffer.to_mdp_dataset()
dataset_analytical.dump("augrl/augmentations/spin_cartpole/datasets/ddqn_medium_designed_dataset.h5")



In [37]:
dataset_learned = d3rlpy.datasets.MDPDataset.load("augrl/augmentations/spin_cartpole/datasets/ddqn_150_medium_learned_dataset.h5")
dataset_learned.rewards.shape

(9999,)

In [38]:
cql_preferences = d3rlpy.algos.DiscreteCQL(use_gpu=True)

_ = cql_preferences.fit(
    dataset_learned,
    n_epochs=8,
    show_progress=False,
    verbose=False
)

2022-08-21 11:07.59 [debug    ] RoundIterator is selected.
2022-08-21 11:07.59 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20220821110759
2022-08-21 11:07.59 [debug    ] Building models...
2022-08-21 11:07.59 [debug    ] Models have been built.
2022-08-21 11:08.01 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_312.pt
2022-08-21 11:08.02 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_624.pt
2022-08-21 11:08.04 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_936.pt
2022-08-21 11:08.05 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_1248.pt
2022-08-21 11:08.06 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_1560.pt
2022-08-21 11:08.08 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110759/model_1872.pt
2022-08-21 11:08.09 [info     ] Model parameters are

In [39]:
os.system("ls -l videos")

total 88
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2158 Aug 21 10:57 DDQN-episode-0.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14154 Aug 21 10:57 DDQN-episode-0.mp4
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2159 Aug 21 11:04 DDQN-episode-1.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 42536 Aug 21 11:04 DDQN-episode-1.mp4
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti  2164 Aug 21 10:45 DDQN-episode-67.meta.json
-rw-rw-r-- 1 gianluca99galletti gianluca99galletti 14653 Aug 21 10:45 DDQN-episode-67.mp4


0

In [40]:
eval_with_env(env_analytical_, cql_preferences)
Video("videos/CQL-episode-0.mp4", embed=True)

441.68 100


In [41]:
dataset_analytical = d3rlpy.datasets.MDPDataset.load("augrl/augmentations/spin_cartpole/datasets/ddqn_medium_designed_dataset.h5")
dataset_analytical.rewards.shape

(9999,)

In [42]:
cql_analytical = d3rlpy.algos.DiscreteCQL(use_gpu=True)

_ = cql_analytical.fit(
    dataset_analytical,
    n_epochs=8,
    show_progress=False,
    verbose=False
)

2022-08-21 11:08.36 [debug    ] RoundIterator is selected.
2022-08-21 11:08.36 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20220821110836
2022-08-21 11:08.36 [debug    ] Building models...
2022-08-21 11:08.36 [debug    ] Models have been built.
2022-08-21 11:08.37 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_312.pt
2022-08-21 11:08.39 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_624.pt
2022-08-21 11:08.40 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_936.pt
2022-08-21 11:08.41 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_1248.pt
2022-08-21 11:08.43 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_1560.pt
2022-08-21 11:08.44 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20220821110836/model_1872.pt
2022-08-21 11:08.46 [info     ] Model parameters are

In [43]:
eval_with_env(env_analytical_, cql_analytical)
Video("videos/CQL-episode-1.mp4", embed=True)

457.15 146
