In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("..")

In [2]:
import torch
import pandas as pd

import numpy as np

import warnings

warnings.filterwarnings("ignore")

from augrl.augmentations import exrp
from augrl.augmentations.spin_cartpole.cartpole import CartPoleEnv

In [3]:
env = CartPoleEnv(game=False)

## Train on manually collected data
Because of what the model is trying to approximate it's hard to directly evaluate.

In [4]:
predictor = exrp.ExplicitRewardPredictor.from_env(env)

In [92]:
user_rollouts = pd.read_pickle("augrl/augmentations/handmade_results_40.pickle")
segments = user_rollouts["segment"].max() + 1
eval_segments = np.random.choice(segments, size=int(0.1 * segments))

train_rollouts = user_rollouts[~user_rollouts["segment"].isin(eval_segments)]
eval_rollouts = user_rollouts[user_rollouts["segment"].isin(eval_segments)]
user_rollouts_spin = train_rollouts[train_rollouts["preference"] == 1]
user_rollouts_no_spin = train_rollouts[train_rollouts["preference"] == 0]
user_rollouts_spin_eval = eval_rollouts[eval_rollouts["preference"] == 1]
user_rollouts_no_spin_eval = eval_rollouts[eval_rollouts["preference"] == 0]

print("Training on {} segments ({} with spin, {} without)".format(len(train_rollouts), len(user_rollouts_spin), len(user_rollouts_no_spin)))
print("Evaluating on {} segments ({} with spin, {} without)".format(len(eval_rollouts), len(user_rollouts_spin_eval), len(user_rollouts_no_spin_eval)))

obs_spin = torch.tensor(list(user_rollouts_spin["state"].values), dtype=torch.float32)
act_spin = torch.tensor(list(user_rollouts_spin["action"].values), dtype=torch.float32)
obs_no_spin = torch.tensor(list(user_rollouts_no_spin["state"].values), dtype=torch.float32)
act_no_spin = torch.tensor(list(user_rollouts_no_spin["action"].values), dtype=torch.float32)

obs_spin_eval = torch.tensor(list(user_rollouts_spin_eval["state"].values), dtype=torch.float32)
act_spin_eval = torch.tensor(list(user_rollouts_spin_eval["action"].values), dtype=torch.float32)
obs_no_spin_eval = torch.tensor(list(user_rollouts_no_spin_eval["state"].values), dtype=torch.float32)
act_no_spin_eval = torch.tensor(list(user_rollouts_no_spin_eval["action"].values), dtype=torch.float32)

Training on 145 segments (76 with spin, 69 without)
Evaluating on 13 segments (5 with spin, 8 without)


### Collect segment tuples with preference

In [77]:
segments = exrp.get_segments(obs_spin, act_spin, obs_no_spin, act_no_spin)
segments_eval = exrp.get_segments(obs_spin_eval, act_spin_eval, obs_no_spin_eval, act_no_spin_eval)
print("Training on {} actual sequences".format(len(segments["obs_left"])))
print("Evaluating on {} actual sequences".format(len(segments_eval["obs_left"])))

Training on 5256 actual sequences
Evaluating on 40 actual sequences


### Train

In [None]:
predictor.train(segments, show_pregress=True, epochs=50)

# Evaluate

In [60]:
acc = sum(predictor.prefer(segments_eval) - segments_eval["preferences"] == 0) / len(segments_eval["preferences"])
print("Predicting preferences with an accuracy of {:.2f}%".format(100 * acc))

Predicting preferences with an accuracy of 75.00%


## Generate exploratory data

In [None]:
env.reset()
while True:
    action = np.random.choice([1, 0])
    if action in (0, 1):
        state, _, terminal, _, _ = env.step(action)
        reward = predictor(torch.tensor(state, dtype=torch.float32).reshape(1, *state.shape), torch.tensor(action, dtype=torch.float32).reshape(1, *action.shape))
        print(reward)
        if terminal:
            break