In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from stealing_gridworld import StealingGridworld
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Settings for this walkthrough, grouped by the component that consumes them.
config = dict(
    environment=dict(
        grid_size=5,
        horizon=30,
        reward_for_depositing=100,
        reward_for_picking_up=1,
        reward_for_stealing=-200,
    ),
    reward_model=dict(
        type="NonImageCnnRewardNet",
        hid_channels=[32, 32],
        kernel_size=3,
    ),
    seed=0,
    dataset_max_size=2000,
    # A fragment_length of None means the whole trajectory is used as a single fragment.
    fragment_length=12,
    transition_oversampling=10,
    initial_epoch_multiplier=1.0,
    feedback=dict(type="preference"),
    trajectory_generator=dict(epsilon=0.1),
    visibility=dict(
        visibility="partial",
        # Visibility mask keys:
        #   "full": all of the grid is visible. Not actually used, but should be set for easier comparison.
        #   "(n-1)x(n-1)": all but the outermost ring of the grid is visible.
        #   "camera": selects the DynamicGridVisibility observation function further down.
        # visibility_mask_key="(n-1)x(n-1)",
        visibility_mask_key="camera",
    ),
    reward_trainer=dict(num_epochs=5),
)


In [4]:
# Quick construction straight from the config. The seed is passed too, so this environment is built
# with the same arguments as the explicit construction in the following cell.
env = StealingGridworld(**config["environment"], seed=config["seed"])

In [5]:
# Build the gridworld from the explicitly listed environment settings, seeded from the top-level config.
env = StealingGridworld(
    seed=config["seed"],
    **{
        option: config["environment"][option]
        for option in (
            "grid_size",
            "horizon",
            "reward_for_depositing",
            "reward_for_picking_up",
            "reward_for_stealing",
        )
    },
)

In [6]:
# Display the environment's parameter summary string.
env.params_string

'gs5_nfp2_nop1_rfd100_rfp1_rfs-200'

In [7]:
from stealing_gridworld import DynamicGridVisibility_OJ

# Visibility helper built on top of `env`; the next cell uses it to produce and render masks.
camera = DynamicGridVisibility_OJ(env)

In [14]:
# Request the camera's visibility masks with t=40 and render every mask that comes back.
masks = camera.update_visibility(t=40)

for mask in masks:
    camera.render_mask(mask)

+---+---+---+---+---+
| # | # |   |   |   |
+---+---+---+---+---+
| # | # |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
+---+---+---+---+---+
|   | # | # |   |   |
+---+---+---+---+---+
|   | # | # |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
+---+---+---+---+---+
|   |   | # | # |   |
+---+---+---+---+---+
|   |   | # | # |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
+---+---+---+---+---+
|   |   |   | # | # |
+---+---+---+---+---+
|   |   |   | # | # |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
+---+---+---+---+---+
|   |   | 

In [7]:
import numpy as np
import torch as th
from imitation.util import logger as imit_logger

import wandb
import os
import abc
from evaluate_reward_model import full_visibility_evaluator_factory, partial_visibility_evaluator_factory, camera_visibility_evaluator_factory
from imitation_modules import (
    BasicScalarFeedbackRewardTrainer,
    DeterministicMDPTrajGenerator,
    MSERewardLoss,
    NoisyObservationGathererWrapper,
    NonImageCnnRewardNet,
    RandomSingleFragmenter,
    ScalarFeedbackModel,
    ScalarRewardLearner,
    SyntheticScalarFeedbackGatherer,
)
from imitation_modules import (
    PreferenceComparisons,
    PreferenceModel,
    BasicRewardTrainer,
    CrossEntropyRewardLoss,
    SyntheticGatherer,
    RandomFragmenter,
    PreferenceComparisonNoisyObservationGathererWrapper,
)
import matplotlib.pyplot as plt

from stealing_gridworld import PartialGridVisibility, DynamicGridVisibility

# class ObservationFunction(abc.ABC):
#     """Abstract class for functions that take an observation and return a new observation."""

#     @abc.abstractmethod
#     def __call__(self, fragment):
#         """Returns a new fragment with observations, actions, and rewards filtered through an observation function.

#         Args:
#             fragment: a TrajectoryWithRew object.

#         Returns:
#             A new TrajectoryWithRew object with the same infos and terminal flag, but with the observations, actions,
#             and rewards filtered through the observation function.
#         """

# class DynamicGridVisibility(ObservationFunction):
#     def __init__(self, env: StealingGridworld, pattern=None, feedback="scalar"):
#         super().__init__()
#         self.env = env
#         self.grid_size = env.grid_size
#         self.feedback = feedback
        
#         # Define the pattern of camera movement
#         if pattern is None:
#             self.pattern = self.default_pattern()
#         else:
#             self.pattern = pattern
#         print("Pattern = ", self.pattern)
#         self.pattern_index = 0  # Start at the first position in the pattern

#         # Build the initial visibility mask
#         self.visibility_mask = self.construct_visibility_mask()

#     def default_pattern(self):
#         # Create a default movement pattern for the camera
#         # Example for a 5x5 grid, you may adjust as needed
#         half_size = self.grid_size // 2 + self.grid_size % 2
#         positions = []
#         # for i in range(half_size):
#         #     for j in range(half_size):
#         #         positions.append((i, j))
#         # return positions
#         # in a 3x3 grid, the camera moves in a spiral pattern
#         # the top leftmost corner is (0,0) for both the camera and the grid
#         # thus, the camera's top leftmost corner's pattern is going to be:
#         # (0,0) -> (0,1) -> (1,1) -> (1,0) 

#         # HARDCODED, TODO find a way to generalize this
#         if self.grid_size == 3:
#             positions = [(0,0), (0,1), (1,1), (1,0)]
#         elif self.grid_size == 5:
#             positions = [(0,0), (0,1), (0,2), (1,2), (2,2), (2,1), (2,0), (1,0)]
#         else:
#             raise NotImplementedError("Default pattern not implemented for grid size other than 3x3 or 5x5")
#         return positions

#     def construct_visibility_mask(self):
#         # Build a visibility mask based on the current pattern index
#         mask = np.zeros((self.grid_size, self.grid_size), dtype=np.bool_)
#         left_x, left_y = self.pattern[self.pattern_index]
#         camera_size = self.grid_size // 2 + self.grid_size % 2
        
#         # Calculate bounds of the camera window
#         start_x = left_x
#         end_x = min(left_x + camera_size, self.grid_size)
#         start_y = left_y
#         end_y = min(left_y + camera_size, self.grid_size)
#         print("start_x, end_x, start_y, end_y = ", start_x, end_x, start_y, end_y)
#         mask[start_x:end_x, start_y:end_y] = True
#         return mask

#     def update_visibility(self):
#         # Update the visibility mask for the next timestep
#         self.pattern_index = (self.pattern_index + 1) % len(self.pattern)
#         self.visibility_mask = self.construct_visibility_mask()

#     def __call__(self, fragments):
#         # Apply the current visibility mask to the fragments
#         self.update_visibility()  # Move the camera to the next position
#         return super().__call__(fragments)  # Call the base method to apply the mask

#     def __repr__(self):
#         return f"DynamicGridVisibility(\n    grid_size={self.grid_size},\n    visibility_mask=\n{self.visibility_mask},\n    feedback={self.feedback}\n)"


python(87237) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(87238) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [8]:
#######################################################################################################################
##################################################### Run params ######################################################
#######################################################################################################################


# Index used to build the "cuda:<n>" device string; when None, the reward net is not moved to any device.
GPU_NUMBER = 0
# Passed as `num_iterations` to the reward learner.
N_ITER = 40
# Query budget handed to `reward_learner.train` (as total_queries / total_comparisons).
N_COMPARISONS = 2000
# Not read by the cells of this walkthrough; the full-experiment cell further down uses it to disable wandb.
TESTING = True

In [9]:
# Reward network built from the environment's spaces and the reward_model settings.
reward_net = NonImageCnnRewardNet(
    env.observation_space,
    env.action_space,
    kernel_size=config["reward_model"]["kernel_size"],
    hid_channels=config["reward_model"]["hid_channels"],
)

# One NumPy generator, seeded from the config, shared by the components created below.
rng = np.random.default_rng(config["seed"])

# Device preference: CUDA (index GPU_NUMBER), then MPS, then CPU. Skipped entirely when GPU_NUMBER is None.
if GPU_NUMBER is not None:
    if th.cuda.is_available():
        device = th.device(f"cuda:{GPU_NUMBER}")
    elif th.backends.mps.is_available():
        device = th.device("mps")
    else:
        device = th.device('cpu')
    reward_net.to(device)
    print(f"Reward net on {device}.")

# Any feedback type other than 'scalar' gets the preference-comparison fragmenter/gatherer pair.
if config["feedback"]["type"] != 'scalar':
    fragmenter = RandomFragmenter(rng=rng)
    gatherer = SyntheticGatherer(rng=rng)
else:
    fragmenter = RandomSingleFragmenter(rng=rng)
    gatherer = SyntheticScalarFeedbackGatherer(rng=rng)

# Choose the observation function and the policy evaluator for the configured visibility.
# Every supported configuration defines `policy_evaluator`; unsupported ones fail here with a clear
# error instead of surfacing later as a NameError.
if config["visibility"]["visibility"] == "partial":
    if config["visibility"]["visibility_mask_key"] == "(n-1)x(n-1)":
        observation_function = PartialGridVisibility(
            env,
            mask_key=config["visibility"]["visibility_mask_key"],
            feedback=config["feedback"]["type"],
        )
        print("Debug new observation function: ", observation_function)
        policy_evaluator = partial_visibility_evaluator_factory(observation_function.visibility_mask)
    elif config["visibility"]["visibility_mask_key"] == "camera":
        observation_function = DynamicGridVisibility(env, feedback=config["feedback"]["type"], halt=4)
        print("Debug new observation function: ", observation_function)
        policy_evaluator = camera_visibility_evaluator_factory(observation_function.visibility_mask)
    else:
        raise ValueError(
            f'Visibility mask key {config["visibility"]["visibility_mask_key"]!r} cannot be used with '
            'partial visibility; expected "(n-1)x(n-1)" or "camera".'
        )

    # The synthetic gatherer only sees what the observation function lets through.
    if config["feedback"]["type"] == 'scalar':
        gatherer = NoisyObservationGathererWrapper(gatherer, observation_function)
    elif config["feedback"]["type"] == 'preference':
        gatherer = PreferenceComparisonNoisyObservationGathererWrapper(gatherer, observation_function)

elif config["visibility"]["visibility"] == "full":
    # Full visibility: no observation function, gatherer left unwrapped.
    policy_evaluator = full_visibility_evaluator_factory()

else:
    raise ValueError(
        f'Unknown visibility {config["visibility"]["visibility"]!r}. Visibility must be "full" or "partial".'
    )


Reward net on mps.
Debug new observation function:  DynamicGridVisibility(pattern=[(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0)], feedback=preference)


In [None]:
# Start from a reset observation function, request masks for t=10 with limits (10, 20), and render each.
observation_function.reset()

masks = observation_function.update_visibility(t=10, limits=(10, 20))

for mask in masks:
    observation_function.render_mask(mask)

In [10]:
# Pair the reward net with the feedback model and trainer that match the feedback type.
if config["feedback"]["type"] != 'scalar':
    # Preference feedback: PreferenceModel trained with CrossEntropyRewardLoss.
    feedback_model = PreferenceModel(reward_net)
    reward_trainer = BasicRewardTrainer(
        preference_model=feedback_model,
        loss=CrossEntropyRewardLoss(),
        rng=rng,
        epochs=config["reward_trainer"]["num_epochs"],
    )

else:
    # Scalar feedback: ScalarFeedbackModel trained with MSERewardLoss.
    feedback_model = ScalarFeedbackModel(model=reward_net)
    reward_trainer = BasicScalarFeedbackRewardTrainer(
        feedback_model=feedback_model,
        loss=MSERewardLoss(),
        rng=rng,
        epochs=config["reward_trainer"]["num_epochs"],
    )

# Trajectory source: DeterministicMDPTrajGenerator with the reward net as its reward function.
# Open question from the original author: can this stay as-is for as long as value iteration is used?
trajectory_generator = DeterministicMDPTrajGenerator(
    env=env,
    reward_fn=reward_net,
    epsilon=config["trajectory_generator"]["epsilon"],
    rng=None,  # Passing an rng does not work yet.
)



# Configure the imitation logger with two output formats: stdout and wandb.
logger = imit_logger.configure(format_strs=["stdout", "wandb"])


def save_model_params_and_dataset_callback(reward_learner):
    """Checkpoint the reward model weights and the learner's dataset under the wandb run directory.

    Files are written below ``<wandb.run.dir>/saved_reward_models``: a "latest" copy of the weights and
    of the dataset (overwritten on every call) plus per-iteration copies in the ``checkpoints``
    subdirectory, named after ``reward_learner._iteration``.

    Args:
        reward_learner: Object exposing ``model`` (with ``state_dict()``), ``dataset`` (with
            ``save(path)``) and ``_iteration``.
    """
    data_dir = os.path.join(wandb.run.dir, "saved_reward_models")
    latest_checkpoint_path = os.path.join(data_dir, "latest_checkpoint.pt")
    latest_dataset_path = os.path.join(data_dir, "latest_dataset.pkl")
    checkpoints_dir = os.path.join(data_dir, "checkpoints")
    checkpoint_iter_path = os.path.join(checkpoints_dir, f"model_weights_iter{reward_learner._iteration}.pt")
    dataset_iter_path = os.path.join(checkpoints_dir, f"dataset_iter{reward_learner._iteration}.pkl")

    # Creating the nested checkpoints directory also creates data_dir.
    os.makedirs(checkpoints_dir, exist_ok=True)
    th.save(reward_learner.model.state_dict(), latest_checkpoint_path)
    th.save(reward_learner.model.state_dict(), checkpoint_iter_path)
    reward_learner.dataset.save(latest_dataset_path)
    reward_learner.dataset.save(dataset_iter_path)


# Backward-compatible alias for the original, misspelled name of this callback.
save_model_params_anwand_dataset_callback = save_model_params_and_dataset_callback

# Keyword arguments that both learner flavours receive unchanged.
_shared_learner_kwargs = dict(
    trajectory_generator=trajectory_generator,
    reward_model=reward_net,
    num_iterations=N_ITER,
    fragmenter=fragmenter,
    reward_trainer=reward_trainer,
    fragment_length=config["fragment_length"],
    transition_oversampling=config["transition_oversampling"],
    initial_epoch_multiplier=config["initial_epoch_multiplier"],
    policy_evaluator=policy_evaluator,
    custom_logger=logger,
)

if config["feedback"]["type"] == 'scalar':
    # Scalar feedback: the gatherer and queue size go in under the "feedback_*" names.
    reward_learner = ScalarRewardLearner(
        feedback_gatherer=gatherer,
        feedback_queue_size=config["dataset_max_size"],
        #callback=save_model_params_and_dataset_callback,
        **_shared_learner_kwargs,
    )

else:
    # Preference feedback: same pieces under the "preference"/"comparison" names, plus the initial fraction.
    reward_learner = PreferenceComparisons(
        preference_gatherer=gatherer,
        comparison_queue_size=config["dataset_max_size"],
        initial_comparison_frac=0.1,
        #query_schedule="hyperbolic",
        **_shared_learner_kwargs,
    )

python(87243) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Enumerating states: 100%|██████████| 25/25 [00:01<00:00, 12.56it/s]


In [None]:
# Run reward learning. Both branches read the horizon from the plain `config` dict: this walkthrough
# never calls wandb.init, so `wandb.config` is not the source of settings here.
if config["feedback"]["type"] == 'scalar':
    result = reward_learner.train(
        # Just needs to be bigger than N_ITER * HORIZON. Value iteration doesn't really use this.
        total_timesteps=10 * N_ITER * config["environment"]["horizon"],
        total_queries=N_COMPARISONS,
    )

else:
    result = reward_learner.train(
        # Just needs to be bigger than N_ITER * HORIZON. Value iteration doesn't really use this.
        total_timesteps=10 * N_ITER * config["environment"]["horizon"],
        total_comparisons=N_COMPARISONS,
        #callback=save_model_params_and_dataset_callback,
    )

In [None]:
# Keep whatever update_visibility returns for inspection.
# NOTE(review): the positional 5 is presumably the `t` argument named in earlier calls — confirm.
vis = observation_function.update_visibility(5)

In [None]:
# Display the value stored in `vis`.
vis

In [None]:
# Copy of code from experiment ipython notebook

import os

import numpy as np
import torch as th
from imitation.util import logger as imit_logger

import wandb
from evaluate_reward_model import full_visibility_evaluator_factory, partial_visibility_evaluator_factory, camera_visibility_evaluator_factory
from imitation_modules import (
    BasicScalarFeedbackRewardTrainer,
    DeterministicMDPTrajGenerator,
    MSERewardLoss,
    NoisyObservationGathererWrapper,
    NonImageCnnRewardNet,
    RandomSingleFragmenter,
    ScalarFeedbackModel,
    ScalarRewardLearner,
    SyntheticScalarFeedbackGatherer,
)
from imitation_modules import (
    PreferenceComparisons,
    PreferenceModel,
    BasicRewardTrainer,
    CrossEntropyRewardLoss,
    SyntheticGatherer,
    RandomFragmenter,
    PreferenceComparisonNoisyObservationGathererWrapper,
)

from stealing_gridworld import PartialGridVisibility, DynamicGridVisibility, StealingGridworld

#######################################################################################################################
##################################################### Run params ######################################################
#######################################################################################################################


# Index used to build the "cuda:<n>" device string; when None, the reward net is not moved to any device.
GPU_NUMBER = 0
# Passed as `num_iterations` to the reward learner and used to size `total_timesteps` for training.
N_ITER = 40
# Query budget handed to `reward_learner.train` (as total_queries / total_comparisons).
N_COMPARISONS = 3000
# When True, wandb.init below is called with mode="disabled"; otherwise with mode="online".
TESTING = True


#######################################################################################################################
##################################################### Expt params #####################################################
#######################################################################################################################


# Experiment settings, grouped by the component that consumes them; handed to wandb.init further down.
config = dict(
    environment=dict(
        name="StealingGridworld",
        grid_size=5,
        horizon=30,
        reward_for_depositing=100,
        reward_for_picking_up=1,
        reward_for_stealing=-200,
    ),
    reward_model=dict(
        type="NonImageCnnRewardNet",
        hid_channels=[32, 32],
        kernel_size=3,
    ),
    seed=0,
    dataset_max_size=3000,
    # A fragment_length of None means the whole trajectory is used as a single fragment.
    fragment_length=12,
    transition_oversampling=10,
    initial_epoch_multiplier=1.0,
    feedback=dict(type="preference"),
    trajectory_generator=dict(epsilon=0.1),
    visibility=dict(
        visibility="partial",
        # Visibility mask keys:
        #   "full": all of the grid is visible. Not actually used, but should be set for easier comparison.
        #   "(n-1)x(n-1)": all but the outermost ring of the grid is visible.
        #   "camera": selects the DynamicGridVisibility observation function further down.
        # visibility_mask_key="(n-1)x(n-1)",
        visibility_mask_key="camera",
    ),
    reward_trainer=dict(num_epochs=3),
)

# Some validation

# Fail fast on unsupported settings, before any wandb run or model is created.
if config["feedback"]["type"] not in ("scalar", "preference"):
    raise NotImplementedError("Only scalar and preference feedback are supported at the moment.")

if config["visibility"]["visibility"] == "full" and config["visibility"]["visibility_mask_key"] != "full":
    raise ValueError(
        'If visibility is "full", then visibility mask key must be "full". '
        f'Instead, it is {config["visibility"]["visibility_mask_key"]}.'
    )

if config["visibility"]["visibility"] not in ["full", "partial"]:
    raise ValueError(
        f'Unknown visibility {config["visibility"]["visibility"]}. Visibility must be "full" or "partial".'
    )

if config["reward_model"]["type"] != "NonImageCnnRewardNet":
    raise ValueError(f'Unknown reward model type {config["reward_model"]["type"]}.')

available_visibility_mask_keys = ["full", "(n-1)x(n-1)", "camera"]
if config["visibility"]["visibility_mask_key"] not in available_visibility_mask_keys:
    raise ValueError(
        f'Unknown visibility mask key {config["visibility"]["visibility_mask_key"]}. '
        f"Available visibility mask keys are {available_visibility_mask_keys}."
    )

# An unspecified fragment length falls back to the full horizon (identity check, not equality, for None).
if config["fragment_length"] is None:
    config["fragment_length"] = config["environment"]["horizon"]
    print("Fragment length unspecified... setting it to ", config["environment"]["horizon"])

# Log in to Weights & Biases, then start a run that records `config`.
wandb.login()
run = wandb.init(
    project="assisting-bounded-humans",
    notes="Testing the preference comparisons model",
    name="Testing camera visibility with preference comparisons",
    tags=[
        "test",
        "Partial Observability"
    ],
    config=config,
    # TESTING=True selects the "disabled" mode; otherwise the run is "online".
    mode="disabled" if TESTING else "online",
)

#######################################################################################################################
################################################## Create everything ##################################################
#######################################################################################################################


# Build the gridworld from the environment settings recorded in wandb.config.
# NOTE(review): unlike the construction near the top of this notebook, no seed is passed here.
env = StealingGridworld(
    **{
        option: wandb.config["environment"][option]
        for option in (
            "grid_size",
            "horizon",
            "reward_for_depositing",
            "reward_for_picking_up",
            "reward_for_stealing",
        )
    }
)


# Reward network built from the environment's spaces and the reward_model settings.
reward_net = NonImageCnnRewardNet(
    env.observation_space,
    env.action_space,
    kernel_size=wandb.config["reward_model"]["kernel_size"],
    hid_channels=wandb.config["reward_model"]["hid_channels"],
)

# One NumPy generator, seeded from the config, shared by the components created below.
rng = np.random.default_rng(wandb.config["seed"])

# Device preference: CUDA (index GPU_NUMBER), then MPS, then CPU. Skipped entirely when GPU_NUMBER is None.
if GPU_NUMBER is not None:
    if th.cuda.is_available():
        device = th.device(f"cuda:{GPU_NUMBER}")
    elif th.backends.mps.is_available():
        device = th.device("mps")
    else:
        device = th.device('cpu')
    reward_net.to(device)
    print(f"Reward net on {device}.")

# Any feedback type other than 'scalar' gets the preference-comparison fragmenter/gatherer pair.
if config["feedback"]["type"] != 'scalar':
    fragmenter = RandomFragmenter(rng=rng)
    gatherer = SyntheticGatherer(rng=rng)
else:
    fragmenter = RandomSingleFragmenter(rng=rng)
    gatherer = SyntheticScalarFeedbackGatherer(rng=rng)

# Choose the observation function and the policy evaluator for the configured visibility.
if wandb.config["visibility"]["visibility"] == "partial":
    if wandb.config["visibility"]["visibility_mask_key"] == "(n-1)x(n-1)":
        observation_function = PartialGridVisibility(
            env,
            mask_key=wandb.config["visibility"]["visibility_mask_key"],
            feedback=config["feedback"]["type"],
        )
        print("Debug new observation function: ", observation_function)
        policy_evaluator = partial_visibility_evaluator_factory(observation_function.visibility_mask)
    elif wandb.config["visibility"]["visibility_mask_key"] == "camera":
        observation_function = DynamicGridVisibility(env, feedback=config["feedback"]["type"])
        print("Debug new observation function: ", observation_function)
        policy_evaluator = camera_visibility_evaluator_factory(observation_function.visibility_mask)
    else:
        # Without this guard, e.g. partial visibility with the "full" mask key leaves
        # `observation_function` undefined and the wrapping below fails with a NameError.
        raise ValueError(
            f'Visibility mask key {wandb.config["visibility"]["visibility_mask_key"]!r} cannot be used with '
            'partial visibility; expected "(n-1)x(n-1)" or "camera".'
        )

    # The synthetic gatherer only sees what the observation function lets through.
    if wandb.config["feedback"]["type"] == 'scalar':
        gatherer = NoisyObservationGathererWrapper(gatherer, observation_function)
    elif wandb.config["feedback"]["type"] == 'preference':
        gatherer = PreferenceComparisonNoisyObservationGathererWrapper(gatherer, observation_function)

elif wandb.config["visibility"]["visibility"] == "full":
    # Full visibility: no observation function, gatherer left unwrapped.
    policy_evaluator = full_visibility_evaluator_factory()

# Pair the reward net with the feedback model and trainer that match the feedback type.
if config["feedback"]["type"] != 'scalar':
    # Preference feedback: PreferenceModel trained with CrossEntropyRewardLoss.
    feedback_model = PreferenceModel(reward_net)
    reward_trainer = BasicRewardTrainer(
        preference_model=feedback_model,
        loss=CrossEntropyRewardLoss(),
        rng=rng,
        epochs=wandb.config["reward_trainer"]["num_epochs"],
    )

else:
    # Scalar feedback: ScalarFeedbackModel trained with MSERewardLoss.
    feedback_model = ScalarFeedbackModel(model=reward_net)
    reward_trainer = BasicScalarFeedbackRewardTrainer(
        feedback_model=feedback_model,
        loss=MSERewardLoss(),
        rng=rng,
        epochs=wandb.config["reward_trainer"]["num_epochs"],
    )

# Trajectory source: DeterministicMDPTrajGenerator with the reward net as its reward function.
# Open question from the original author: can this stay as-is for as long as value iteration is used?
trajectory_generator = DeterministicMDPTrajGenerator(
    env=env,
    reward_fn=reward_net,
    epsilon=wandb.config["trajectory_generator"]["epsilon"],
    rng=None,  # Passing an rng does not work yet.
)



# Configure the imitation logger with two output formats: stdout and wandb.
logger = imit_logger.configure(format_strs=["stdout", "wandb"])


def save_model_params_and_dataset_callback(reward_learner):
    """Write the reward model weights and the learner's dataset under the wandb run directory.

    Each call refreshes the "latest" weights/dataset files in ``saved_reward_models`` and adds
    per-iteration copies, named after ``reward_learner._iteration``, in its ``checkpoints`` subfolder.
    """
    root_dir = os.path.join(wandb.run.dir, "saved_reward_models")
    per_iteration_dir = os.path.join(root_dir, "checkpoints")
    iteration = reward_learner._iteration

    weight_targets = (
        os.path.join(root_dir, "latest_checkpoint.pt"),
        os.path.join(per_iteration_dir, f"model_weights_iter{iteration}.pt"),
    )
    dataset_targets = (
        os.path.join(root_dir, "latest_dataset.pkl"),
        os.path.join(per_iteration_dir, f"dataset_iter{iteration}.pkl"),
    )

    # Creating the nested directory also creates root_dir.
    os.makedirs(per_iteration_dir, exist_ok=True)
    for target in weight_targets:
        th.save(reward_learner.model.state_dict(), target)
    for target in dataset_targets:
        reward_learner.dataset.save(target)

# Keyword arguments that both learner flavours receive unchanged.
_shared_learner_kwargs = dict(
    trajectory_generator=trajectory_generator,
    reward_model=reward_net,
    num_iterations=N_ITER,
    fragmenter=fragmenter,
    reward_trainer=reward_trainer,
    fragment_length=wandb.config["fragment_length"],
    transition_oversampling=wandb.config["transition_oversampling"],
    initial_epoch_multiplier=wandb.config["initial_epoch_multiplier"],
    policy_evaluator=policy_evaluator,
    custom_logger=logger,
)

if config["feedback"]["type"] == 'scalar':
    # Scalar feedback: the checkpoint callback is given to the learner itself.
    reward_learner = ScalarRewardLearner(
        feedback_gatherer=gatherer,
        feedback_queue_size=wandb.config["dataset_max_size"],
        callback=save_model_params_and_dataset_callback,
        **_shared_learner_kwargs,
    )

else:
    # Preference feedback: adds the initial comparison fraction and the "hyperbolic" query schedule.
    reward_learner = PreferenceComparisons(
        preference_gatherer=gatherer,
        comparison_queue_size=wandb.config["dataset_max_size"],
        initial_comparison_frac=0.1,
        query_schedule="hyperbolic",
        **_shared_learner_kwargs,
    )

#######################################################################################################################
####################################################### Training ######################################################
#######################################################################################################################

# Run reward learning; the query budget keyword differs between the two learner types.
if config["feedback"]["type"] != 'scalar':
    result = reward_learner.train(
        total_comparisons=N_COMPARISONS,
        callback=save_model_params_and_dataset_callback,
        # Only has to exceed N_ITER * HORIZON; value iteration doesn't really use this.
        total_timesteps=10 * N_ITER * wandb.config["environment"]["horizon"],
    )

else:
    result = reward_learner.train(
        total_queries=N_COMPARISONS,
        # Only has to exceed N_ITER * HORIZON; value iteration doesn't really use this.
        total_timesteps=10 * N_ITER * wandb.config["environment"]["horizon"],
    )
    

In [11]:
# Sample from the trajectory generator (argument: 100).
trajectories = trajectory_generator.sample(100)
# Lazy generator over the lengths of the trajectories flagged terminal; none of the visible cells consumes it.
horizons = (len(traj) for traj in trajectories if traj.terminal)

Constructing state index: 100%|██████████| 196900/196900 [00:05<00:00, 34066.04it/s]


In [13]:
# Scan every trajectory for an observation whose last channel has any non-zero entry.
# The `break` only leaves the inner loop, so the outer loop moves on to the next trajectory; once the
# cell finishes, `traj` is bound to the last element of `trajectories` (later cells read `traj`).
for traj in trajectories:
    for obs in traj.obs:
        # NOTE(review): the last channel presumably encodes carried items (see the message) — confirm.
        if np.any(obs[-1]):
            print("Carrying coin, breaking")
            break
    

Carrying coin, breaking


In [19]:
import numpy as np
import time

def render_gridworld(observation=None, trajectory=None, grid_size=5, delay=1.0):
    """
    Simple ASCII rendering of the environment.

    Args:
        observation (np.array): A single observation indexed as (channel, row, col), where the channels
            read here are 0 = agent, 1 = free pellets, 2 = owned pellets, 3 = home and 4 = number of
            carried pellets (taken from cell [0, 0]). Only used when `trajectory` is None. If both are
            None, an empty grid with the home at (0, 0) is drawn.
        trajectory: Either an object exposing an `.obs` sequence of observations, or a plain sequence
            (e.g. a list) of observations. Every frame is rendered in order.
        grid_size (int): Size of the grid (default is 5).
        delay (float): Delay in seconds after rendering each frame of the trajectory.
    """
    HOME = "H"
    OWNED_PELLET = "x"
    FREE_PELLET = "."
    AGENT = "A"

    def first_location(channel):
        """Return the first (row, col) where `channel` equals 1, or None when there is none."""
        hits = np.argwhere(channel == 1)
        return hits[0] if len(hits) else None

    def render_single_observation(obs):
        if obs is not None:
            # A missing agent or home (e.g. in a masked observation) is simply not drawn.
            agent_position = first_location(obs[0, :, :])
            free_pellet_locations = np.argwhere(obs[1, :, :] == 1)
            owned_pellet_locations = np.argwhere(obs[2, :, :] == 1)
            home_location = first_location(obs[3, :, :])
            # Assumes the carried-pellet count is the same across all cells of the channel.
            agent_repr = str(obs[4, 0, 0])
        else:
            agent_position = None
            free_pellet_locations = []
            owned_pellet_locations = []
            home_location = (0, 0)
            agent_repr = AGENT

        grid = np.full((grid_size, grid_size), " ")
        if home_location is not None:
            grid[home_location[0], home_location[1]] = HOME
        for loc in free_pellet_locations:
            grid[loc[0], loc[1]] = FREE_PELLET
        for loc in owned_pellet_locations:
            grid[loc[0], loc[1]] = OWNED_PELLET

        horizontal_rule = "+" + "---+" * grid_size
        print(horizontal_rule)
        for i in range(grid_size):
            print("|", end="")
            for j in range(grid_size):
                if agent_position is not None and agent_position[0] == i and agent_position[1] == j:
                    # The agent cell shows the carried count followed by whatever lies underneath.
                    print(f"{agent_repr}{grid[i, j]} |", end="")
                else:
                    print(f" {grid[i, j]} |", end="")
            print("\n" + horizontal_rule)

    if trajectory is not None:
        # Accept both trajectory objects and bare sequences of observations.
        frames = getattr(trajectory, "obs", trajectory)
        # Count the frames actually rendered; len(trajectory) can differ from the number of observations.
        total_frames = len(frames)
        for step, obs in enumerate(frames, start=1):
            print(f"Step {step} of {total_frames}")
            render_single_observation(obs)
            time.sleep(delay)
            print("\n" * 2)
    else:
        render_single_observation(observation)

In [20]:
# Replay `traj` as ASCII frames on a 5x5 grid, sleeping 2 seconds after each frame.
render_gridworld(trajectory=traj, grid_size=5, delay=2.0)

Step 1 of 30
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   | x |0H |   |   |
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+



Step 2 of 30
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   | x | H |0  |   |
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+



Step 3 of 30
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |0  |   |
+---+---+---+---+---+
|   | x | H |   |   |
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+



Step 4 of 30
+---+---+---+---+---+
|   | . |   |0  |   |
+---+---+---+---+---+
|   |   |   |   |   |
+---+---+---+---+---+
|   | x | H |   |   |
+---+---+---+---+---+
|   | . |   |   |   |
+---+---+---+---+---+
|   |   |   |  

KeyboardInterrupt: 

In [25]:
# Shape of the observation at index 10 of `traj`.
traj.obs[10].shape

(5, 5, 5)

In [33]:
#traj.obs[10], traj.acts[10]

# 5x5 mask: 1 marks the inner 3x3 cells, 0 marks the outer ring.
visibility_mask = np.array([
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
])

# Drop the last channel of observation 10, then multiply by (1 - mask): only cells where the mask is 0 survive.
masked_obs = traj.obs[10][:-1,:,:] * (1 - visibility_mask)

# Channel order used here (the same order render_gridworld reads): 0 = agent, 1 = free pellets, 2 = owned pellets.
agent = masked_obs[0]
free_pellets = masked_obs[1]
owned_pellets = masked_obs[2]  # not used further in this cell

# A non-zero product means the agent and a free pellet share a cell where the mask is 0.
if np.any(agent * free_pellets):
    print("Agent is on a free pellet off mask")
    # NOTE(review): action id 4 is treated as "pick up" here — confirm against the environment's action encoding.
    if traj.acts[10] == 4:
        print("Agent is picking up a  pellet off mask")


Agent is on a free pellet off mask
Agent is picking up a  pellet off mask


In [36]:
# Action at index 10 together with the last channel of the observation at index 11.
traj.acts[10], traj.obs[11][-1]

(4,
 array([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]], dtype=int16))

In [40]:
# Scratch check: dict.update merges the entries of d2 into d in place.
d = {'chao': 120}
d2 = {"hello": 1}

d.update(d2)

In [41]:
# Display the merged dict.
d

{'chao': 120, 'hello': 1}

In [None]:
# Fragment the sampled trajectories.
# NOTE(review): 10 and 70 are presumably the fragment length and the number of fragments/pairs — confirm against the fragmenter's signature.
fragments = fragmenter(trajectories, 10, 70)

In [None]:
# Run the gatherer on the fragments and keep its result.
preferences = gatherer(fragments)

In [None]:
# Number of entries returned by the gatherer.
len(preferences)

In [None]:
# Number of reward entries in the first trajectory.
len(trajectories[0].rews)

In [None]:
# Display all observations of the first trajectory (this prints the whole array).
trajectories[0].obs

In [None]:
# Display the first element of the first entry in `fragments`.
fragments[0][0]

In [None]:
# Call the environment's rollout_with_policy with no arguments.
# NOTE(review): whether this method can run without a policy argument is not visible here — confirm.
trajectory_generator.env.rollout_with_policy()

In [None]:
# Scratch: the tuple is the only positional argument, so it is taken as `low` (exclusive upper bounds
# with high=None), not as `size` — the result has 4 entries rather than shape (13, 5, 5, 5).
np.random.randint((13,5,5,5))

In [None]:
# Random 0/1 array of shape (13, 5, 5, 5).
t = np.random.randint(2, size = (13,5,5,5))

In [None]:
# Shape after dropping the last index along axis 1.
t[:, :-1].shape

In [None]:
# Random 0/1 vector of length 10, displayed right away.
a = np.random.randint(2, size=(10))
a

In [None]:
# Everything except the last element.
a[:-1]

In [None]:
import numpy as np

# Random 0/1 array of shape (13, 5, 5, 5) for trying out indexing and reductions.
# randint's bound is exclusive, so it must be 2 here: a bound of 1 can only ever
# yield 0, which makes the array constant and any `.any()` reduction trivially False.
m = np.random.randint(2, size=(13,5,5,5))

In [None]:
# Confirm the shape of the random array.
m.shape

In [None]:
# For every entry along axis 0 except the last, take index 0 on axis 1 and reduce the
# two trailing axes with `any`, leaving one boolean per remaining entry.
agent_visible = m[:-1, 0].any(axis=(1, 2))
agent_visible.shape

In [None]:
# Inserting a new axis at position 1 adds a length-1 dimension after the first one.
m[:, np.newaxis].shape

In [None]:
# NOTE(review): the name `list` shadows the Python builtin for the rest of the kernel
# session; the loop in the following cell reads this variable by that name, so any
# rename has to be applied to both cells together.
list = [((1,2), (3,4)), ((5,6), (7,8)), ((9,10), (11,12)), ((13,14), (15,16))]

# get all the first tuples in each tuple pair




In [None]:
# Print each element together with its position; enumerate takes the iterable directly.
for index, item in enumerate(list):
    print(index, item)

In [None]:
limits = [((12, 24), (8, 20)), ((3, 15), (1, 13)), ((13, 25), (14, 26)), ((9, 21), (7, 19)), ((7, 19), (15, 27)), ((1, 13), (10, 22)), ((5, 17), (12, 24)), ((1, 13), (18, 30)), ((14, 26), (3, 15)), ((10, 22), (2, 14)), ((6, 18), (18, 30)), ((1, 13), (8, 20)), ((13, 25), (9, 21)), ((17, 29), (8, 20)), ((1, 13), (5, 17)), ((15, 27), (7, 19)), ((12, 24), (2, 14)), ((15, 27), (14, 26)), ((5, 17), (14, 26)), ((0, 12), (2, 14)), ((12, 24), (8, 20)), ((10, 22), (12, 24)), ((15, 27), (5, 17)), ((18, 30), (4, 16)), ((0, 12), (4, 16)), ((17, 29), (5, 17)), ((9, 21), (14, 26)), ((7, 19), (15, 27)), ((2, 14), (1, 13)), ((13, 25), (3, 15)), ((9, 21), (13, 25)), ((4, 16), (5, 17)), ((18, 30), (1, 13)), ((6, 18), (17, 29)), ((14, 26), (18, 30)), ((14, 26), (8, 20)), ((8, 20), (17, 29)), ((13, 25), (5, 17)), ((8, 20), (16, 28)), ((12, 24), (8, 20)), ((1, 13), (12, 24)), ((14, 26), (0, 12)), ((12, 24), (2, 14)), ((15, 27), (3, 15)), ((8, 20), (6, 18)), ((2, 14), (18, 30)), ((1, 13), (1, 13)), ((8, 20), (17, 29)), ((15, 27), (5, 17)), ((16, 28), (8, 20)), ((0, 12), (17, 29)), ((5, 17), (2, 14)), ((4, 16), (12, 24)), ((6, 18), (2, 14)), ((4, 16), (0, 12)), ((5, 17), (15, 27)), ((3, 15), (18, 30)), ((17, 29), (4, 16)), ((10, 22), (0, 12)), ((2, 14), (16, 28)), ((4, 16), (16, 28)), ((0, 12), (14, 26)), ((16, 28), (1, 13)), ((0, 12), (17, 29)), ((1, 13), (15, 27)), ((6, 18), (11, 23)), ((5, 17), (5, 17)), ((9, 21), (9, 21)), ((3, 15), (4, 16)), ((18, 30), (2, 14)), ((14, 26), (11, 23)), ((0, 12), (13, 25)), ((8, 20), (2, 14)), ((18, 30), (5, 17)), ((5, 17), (18, 30)), ((9, 21), (0, 12)), ((15, 27), (2, 14)), ((0, 12), (18, 30)), ((14, 26), (15, 27)), ((16, 28), (0, 12)), ((4, 16), (15, 27)), ((14, 26), (11, 23)), ((5, 17), (7, 19)), ((2, 14), (12, 24)), ((10, 22), (4, 16)), ((15, 27), (15, 27)), ((9, 21), (11, 23)), ((6, 18), (5, 17)), ((11, 23), (3, 15)), ((8, 20), (2, 14)), ((3, 15), (3, 15)), ((9, 21), (12, 24)), ((18, 30), (9, 21)), ((10, 22), (0, 12)), ((4, 16), (4, 16)), ((5, 17), 
(0, 12)), ((16, 28), (1, 13)), ((7, 19), (17, 29)), ((15, 27), (17, 29)), ((11, 23), (14, 26)), ((14, 26), (2, 14)), ((18, 30), (16, 28)), ((13, 25), (12, 24)), ((8, 20), (9, 21)), ((6, 18), (6, 18)), ((8, 20), (7, 19)), ((15, 27), (18, 30)), ((14, 26), (6, 18)), ((9, 21), (12, 24)), ((14, 26), (12, 24)), ((0, 12), (3, 15)), ((17, 29), (2, 14)), ((11, 23), (13, 25)), ((13, 25), (2, 14)), ((7, 19), (11, 23)), ((17, 29), (17, 29)), ((0, 12), (1, 13)), ((2, 14), (10, 22)), ((12, 24), (9, 21)), ((9, 21), (9, 21)), ((16, 28), (2, 14)), ((9, 21), (12, 24)), ((3, 15), (12, 24)), ((15, 27), (14, 26)), ((10, 22), (5, 17)), ((1, 13), (17, 29)), ((2, 14), (13, 25)), ((6, 18), (14, 26)), ((12, 24), (18, 30)), ((8, 20), (2, 14)), ((2, 14), (2, 14)), ((10, 22), (0, 12)), ((8, 20), (5, 17)), ((0, 12), (5, 17)), ((10, 22), (2, 14)), ((13, 25), (5, 17)), ((12, 24), (18, 30)), ((11, 23), (1, 13)), ((4, 16), (4, 16)), ((15, 27), (3, 15)), ((17, 29), (0, 12)), ((15, 27), (11, 23)), ((5, 17), (16, 28)), ((18, 30), (17, 29)), ((15, 27), (15, 27)), ((2, 14), (2, 14)), ((10, 22), (11, 23)), ((15, 27), (13, 25)), ((1, 13), (13, 25)), ((9, 21), (1, 13)), ((4, 16), (16, 28)), ((7, 19), (1, 13)), ((11, 23), (8, 20)), ((16, 28), (10, 22)), ((10, 22), (5, 17)), ((4, 16), (12, 24)), ((8, 20), (9, 21)), ((2, 14), (2, 14)), ((1, 13), (5, 17)), ((7, 19), (12, 24)), ((15, 27), (4, 16)), ((17, 29), (0, 12)), ((2, 14), (4, 16)), ((11, 23), (13, 25)), ((11, 23), (6, 18)), ((3, 15), (8, 20)), ((18, 30), (9, 21)), ((2, 14), (17, 29)), ((5, 17), (15, 27)), ((7, 19), (9, 21)), ((16, 28), (15, 27)), ((10, 22), (6, 18)), ((1, 13), (14, 26)), ((13, 25), (16, 28)), ((4, 16), (11, 23)), ((9, 21), (18, 30)), ((8, 20), (5, 17)), ((14, 26), (7, 19)), ((12, 24), (3, 15)), ((13, 25), (3, 15)), ((18, 30), (10, 22)), ((12, 24), (2, 14)), ((6, 18), (1, 13)), ((9, 21), (8, 20)), ((4, 16), (17, 29)), ((3, 15), (5, 17)), ((13, 25), (9, 21)), ((15, 27), (16, 28)), ((7, 19), (12, 24)), ((13, 25), (6, 18)), ((17, 29), (16, 
28)), ((13, 25), (16, 28)), ((5, 17), (11, 23)), ((11, 23), (9, 21)), ((0, 12), (7, 19)), ((7, 19), (3, 15)), ((4, 16), (0, 12)), ((7, 19), (5, 17)), ((7, 19), (5, 17)), ((13, 25), (18, 30)), ((18, 30), (14, 26)), ((5, 17), (18, 30)), ((1, 13), (2, 14)), ((12, 24), (10, 22)), ((2, 14), (18, 30)), ((18, 30), (14, 26)), ((5, 17), (0, 12)), ((16, 28), (18, 30)), ((4, 16), (7, 19)), ((4, 16), (10, 22)), ((8, 20), (15, 27)), ((18, 30), (0, 12)), ((14, 26), (13, 25)), ((14, 26), (14, 26)), ((17, 29), (4, 16)), ((8, 20), (0, 12)), ((10, 22), (0, 12)), ((11, 23), (12, 24)), ((9, 21), (6, 18)), ((13, 25), (11, 23)), ((10, 22), (16, 28)), ((12, 24), (6, 18)), ((10, 22), (18, 30)), ((10, 22), (2, 14)), ((16, 28), (13, 25)), ((9, 21), (17, 29)), ((1, 13), (6, 18)), ((10, 22), (0, 12)), ((16, 28), (6, 18)), ((18, 30), (14, 26)), ((6, 18), (17, 29)), ((16, 28), (17, 29)), ((4, 16), (11, 23)), ((16, 28), (4, 16)), ((6, 18), (16, 28)), ((4, 16), (15, 27)), ((16, 28), (9, 21)), ((4, 16), (15, 27)), ((8, 20), (3, 15)), ((8, 20), (12, 24)), ((17, 29), (8, 20)), ((14, 26), (1, 13)), ((10, 22), (1, 13)), ((15, 27), (5, 17)), ((12, 24), (7, 19)), ((8, 20), (3, 15)), ((4, 16), (7, 19)), ((7, 19), (16, 28)), ((5, 17), (9, 21)), ((10, 22), (14, 26)), ((4, 16), (6, 18)), ((16, 28), (12, 24)), ((4, 16), (17, 29)), ((16, 28), (13, 25)), ((6, 18), (2, 14)), ((17, 29), (3, 15)), ((13, 25), (11, 23)), ((4, 16), (11, 23)), ((5, 17), (16, 28)), ((17, 29), (14, 26)), ((4, 16), (17, 29)), ((17, 29), (18, 30)), ((2, 14), (6, 18)), ((5, 17), (4, 16)), ((15, 27), (1, 13)), ((7, 19), (0, 12)), ((16, 28), (13, 25)), ((2, 14), (16, 28)), ((14, 26), (2, 14)), ((15, 27), (9, 21)), ((17, 29), (0, 12)), ((7, 19), (2, 14)), ((7, 19), (10, 22)), ((13, 25), (0, 12)), ((6, 18), (12, 24)), ((0, 12), (11, 23)), ((15, 27), (4, 16)), ((9, 21), (7, 19)), ((14, 26), (10, 22)), ((2, 14), (7, 19)), ((0, 12), (15, 27)), ((17, 29), (1, 13)), ((3, 15), (12, 24)), ((5, 17), (3, 15)), ((18, 30), (17, 29)), ((16, 28), (13, 
25)), ((7, 19), (9, 21)), ((10, 22), (14, 26)), ((14, 26), (17, 29)), ((9, 21), (1, 13)), ((8, 20), (0, 12)), ((15, 27), (13, 25)), ((11, 23), (6, 18)), ((7, 19), (7, 19)), ((9, 21), (8, 20)), ((6, 18), (7, 19)), ((5, 17), (13, 25)), ((13, 25), (4, 16)), ((11, 23), (16, 28)), ((18, 30), (3, 15))]

In [None]:
# Keep only the first tuple of every two-tuple pair in `limits`.
limits_f1 = [first for first, _second in limits]

In [None]:
# Display the extracted first tuples.
limits_f1

In [None]:
import numpy as np

# Seed NumPy's global RNG so the draws below are repeatable.
np.random.seed(0)
# NOTE(review): the inner randint returns a single integer n in [0, 100), so this samples
# 10 distinct values from range(n) and raises ValueError whenever n < 10. If the intent
# was 10 distinct values from 0..99, confirm and pass 100 directly.
np.random.choice(np.random.randint(0,100), 10, replace=False)

In [None]:
#home_location_raveled = np.ravel_multi_index(self.home_location, (self.grid_size, self.grid_size))

# Flat index of cell (2, 2) in a 5x5 grid (row-major by default): 2 * 5 + 2 = 12.
np.ravel_multi_index((2,2), (5,5))

In [None]:
import itertools

# Print every ordered pair drawn from range(5) x range(5), then how many there were.
c = 0
for c, i in enumerate(itertools.product(range(5), range(5)), start=1):
    print(i)

print(c)

In [None]:
# Random 0/1 array of shape (10, 5, 5, 5) for trying out slicing.
# randint's bound is exclusive, so it must be 2: a bound of 1 can only yield zeros.
obs = np.random.randint(2, size=(10, 5, 5, 5))


In [None]:
# Selecting the last index on axis 1 drops that axis, leaving shape (10, 5, 5).
obs[:, -1, :, :].shape

In [None]:

# Toy example: zero out the step-to-step increments wherever the 0/1 flag at the
# earlier step is set.
carried_across_time = np.arange(10)
pickups = carried_across_time[1:] - carried_across_time[:-1]
agent_vis = np.random.randint(2, size=10)

print("carried_across_time = ", carried_across_time)
print("pickups = ", pickups)
print("agent_vis = ", agent_vis)

# The last flag has no following step, so it is dropped to line up with the increments.
flag_set = agent_vis[:-1] == 1
np.where(flag_set, 0, pickups)



In [None]:
# Differences between consecutive entries of `pos`.
# NOTE(review): `pos` is not defined in the visible part of the notebook — confirm an
# earlier cell creates it, otherwise this raises NameError on a fresh kernel.
np.diff(pos)