In [1]:
%load_ext autoreload
%autoreload 2
import os

# Move to the repository root so the `augrl` package import and the relative
# data path used further down resolve. The guard keeps this cell idempotent:
# an unconditional chdir("..") climbs one more level every time the cell is
# re-run (or when the kernel already starts at the root), which breaks both.
if not os.path.isdir("augrl"):
    os.chdir("..")

In [2]:
import torch
import pandas as pd
import d3rlpy
import numpy as np

import warnings

warnings.filterwarnings("ignore")

from augrl.augmentations import exrp
from augrl.augmentations.spin_cartpole.cartpole import CartPoleEnv

In [3]:
# Environment handed to the reward predictor's `from_env` constructor below;
# no `reward_fn` is passed at this point.
# NOTE(review): `game=False` presumably disables an interactive/human-play
# mode — confirm in CartPoleEnv.
env = CartPoleEnv(game=False)

## Train on manually collected data
Because of what the model is trying to approximate, it is hard to evaluate directly.

In [4]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [5]:
# Reward predictor constructed from the environment via `from_env`, given the
# selected device and a learning rate of 1e-4.
# NOTE(review): `from_env` presumably sizes the model from the environment's
# observation/action spaces — confirm in augrl.augmentations.exrp.
predictor = exrp.ExplicitRewardPredictor.from_env(env, device, lr=1e-4)

In [6]:
def _rollouts_to_tensor(rollouts, column):
    """Convert one column of a rollout frame into a single float32 tensor.

    Args:
        rollouts: DataFrame holding one rollout per row.
        column: Name of the column to convert ("state" or "action" here).

    Returns:
        A float32 tensor built from the column's values, one entry per row.
    """
    return torch.tensor(list(rollouts[column].values), dtype=torch.float32)


SPLIT_SEED = 0       # fixed so the train/eval split is the same on every run
EVAL_FRACTION = 0.1  # share of segment ids held out for evaluation

# NOTE: unpickling can execute arbitrary code — only load trusted, project-local files.
user_rollouts = pd.read_pickle("augrl/augmentations/spin_cartpole/preferences/handmade_results_150.pickle")

# Number of segment ids, taking ids to run from 0 to the maximum. Named
# `n_segments` so it is not confused with the `segments` dict that the
# following cell builds.
n_segments = user_rollouts["segment"].max() + 1

# Draw the held-out ids WITHOUT replacement: sampling with replacement (the
# default) can return the same id several times, which silently shrinks the
# evaluation split below the requested fraction.
split_rng = np.random.default_rng(SPLIT_SEED)
eval_segments = split_rng.choice(n_segments, size=int(EVAL_FRACTION * n_segments), replace=False)

# Every rollout goes to exactly one side of the split, decided by its segment id.
is_eval = user_rollouts["segment"].isin(eval_segments)
train_rollouts = user_rollouts[~is_eval]
eval_rollouts = user_rollouts[is_eval]
assert len(train_rollouts) + len(eval_rollouts) == len(user_rollouts)

# Separate rollouts by their preference label (1 = spin, 0 = no spin).
user_rollouts_spin = train_rollouts[train_rollouts["preference"] == 1]
user_rollouts_no_spin = train_rollouts[train_rollouts["preference"] == 0]
user_rollouts_spin_eval = eval_rollouts[eval_rollouts["preference"] == 1]
user_rollouts_no_spin_eval = eval_rollouts[eval_rollouts["preference"] == 0]

# NOTE(review): the counts below are row counts of the rollout frames; the
# messages call them "segments" — confirm that one row is one segment.
print("Training on {} segments ({} with spin, {} without)".format(len(train_rollouts), len(user_rollouts_spin), len(user_rollouts_no_spin)))
print("Evaluating on {} segments ({} with spin, {} without)".format(len(eval_rollouts), len(user_rollouts_spin_eval), len(user_rollouts_no_spin_eval)))

# Observation/action tensors for the training split.
obs_spin = _rollouts_to_tensor(user_rollouts_spin, "state")
act_spin = _rollouts_to_tensor(user_rollouts_spin, "action")
obs_no_spin = _rollouts_to_tensor(user_rollouts_no_spin, "state")
act_no_spin = _rollouts_to_tensor(user_rollouts_no_spin, "action")

# Observation/action tensors for the evaluation split.
obs_spin_eval = _rollouts_to_tensor(user_rollouts_spin_eval, "state")
act_spin_eval = _rollouts_to_tensor(user_rollouts_spin_eval, "action")
obs_no_spin_eval = _rollouts_to_tensor(user_rollouts_no_spin_eval, "state")
act_no_spin_eval = _rollouts_to_tensor(user_rollouts_no_spin_eval, "action")

Training on 433 segments (224 with spin, 209 without)
Evaluating on 48 segments (24 with spin, 24 without)


### Collect segment tuples with preference

In [7]:
# Build the training and evaluation segment collections with the same helper.
segments = exrp.get_segments(obs_spin, act_spin, obs_no_spin, act_no_spin)
segments_eval = exrp.get_segments(
    obs_spin_eval, act_spin_eval, obs_no_spin_eval, act_no_spin_eval
)

# Report the size of each collection, measured on its "obs_left" entry.
for template, dataset in (
    ("Training on {} actual sequences", segments),
    ("Evaluating on {} actual sequences", segments_eval),
):
    print(template.format(len(dataset["obs_left"])))

Training on 46816 actual sequences
Evaluating on 576 actual sequences


### Train

In [None]:
# Fit the reward predictor on the training segments for 50 epochs.
# NOTE(review): `show_pregress` looks like a misspelling of `show_progress`;
# check which keyword `ExplicitRewardPredictor.train` actually accepts before
# renaming it here.
predictor.train(segments, show_pregress=True, epochs=50)

### Evaluate

In [8]:
# A prediction counts as correct when it differs from the recorded preference
# by exactly zero; accuracy is the share of such entries.
# NOTE(review): the saved output (47.74%) is close to a coin flip for a 0/1
# label, and the training cell above has no execution count — this figure may
# come from an untrained predictor. Re-run after training.
predicted = predictor.prefer(segments_eval)
labels = segments_eval["preferences"]
n_correct = sum(predicted - labels == 0)
acc = n_correct / len(labels)
print("Predicting preferences with an accuracy of {:.2f}%".format(100 * acc))

Predicting preferences with an accuracy of 47.74%


## Generate exploratory data

In [9]:
# Rebind `env` to a new environment that receives the reward predictor as its
# `reward_fn`; the cells below use this instance.
# NOTE(review): this shadows the `env` created near the top of the notebook,
# so re-running an earlier cell afterwards picks up this instance instead.
env = CartPoleEnv(game=False, reward_fn=predictor)

In [10]:
# DQN hyperparameters gathered in one place; `use_gpu=False` keeps the agent
# on the CPU regardless of the device selected for the reward predictor.
dqn_params = dict(
    batch_size=32,
    learning_rate=2.5e-4,
    target_update_interval=100,
    use_gpu=False,
)
dqn = d3rlpy.algos.DQN(**dqn_params)
dqn.build_with_env(env)

# Exploration strategy for online data collection (constant value of 0.3).
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(0.3)

# Replay buffer for the collected experience, capped at `maxlen=100000`.
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)

In [None]:
# Evaluate on a dedicated environment instance built with the same arguments
# as the training environment. Passing the training `env` as `eval_env` lets
# the periodic evaluation reset and step the very environment the training
# loop is in the middle of an episode with, so the transitions collected
# afterwards no longer follow from the observation the loop is holding.
eval_env = CartPoleEnv(game=False, reward_fn=predictor)

# Collect experience online for 10000 steps in epochs of 1000 steps; updates
# start after the first 1000 steps.
dqn.fit_online(
    env,
    buffer,
    explorer,
    n_steps=10000,
    eval_env=eval_env,
    n_steps_per_epoch=1000,
    update_start_step=1000,
)

In [21]:
# Play one episode with uniformly random actions, rendering after every step.
env.reset()
terminal = False
while not terminal:
    # The draw is always 0 or 1, so no validity check on the action is needed.
    action = np.random.choice([1, 0])
    # NOTE(review): the fourth return value is discarded; if it is a
    # truncation flag, an episode that only truncates never ends this loop —
    # confirm against CartPoleEnv.step.
    state, reward, terminal, _, _ = env.step(action)
    env.render()