In [1]:
import numpy as np
import gymnasium as gym
import gym_environment
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.airl import AIRL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
import imitation.policies.base

import enum
from typing import Union, Dict

# One seed reused for every random generator created in this notebook.
SEED = 42

# Make the project's environment class available to Gymnasium under the id
# "HVAC-v0"; episodes are truncated after 1440 steps.
gym.register(
    id="HVAC-v0",
    entry_point=gym_environment.Environment,
    max_episode_steps=1440,
)

# Vectorised environment with 8 copies of "HVAC-v0". Each copy is wrapped in
# RolloutInfoWrapper so that rollouts can be computed from it later.
env = make_vec_env(
    "HVAC-v0",
    n_envs=8,
    rng=np.random.default_rng(SEED),
    post_wrappers=[lambda wrapped_env, _: RolloutInfoWrapper(wrapped_env)],
)

class DumbPolicy(imitation.policies.base.NonTrainablePolicy):
	"""Hand-written, rule-based policy used as the demonstration "expert".

	Every room's temperature is compared with its setpoint and classified with a
	``Status``. From those classifications the policy picks an AC status of
	``1`` (chosen when heating is called for), ``-1`` (chosen when cooling is
	called for), ``0``, or the previous AC status, and returns the index of the
	matching ``(ac_status, dampers)`` tuple in the environment's ``actions`` list.

	NOTE(review): the observation layout is inferred from the indexing in
	``_choose_action``: ``(temperature, setpoint)`` pairs for each room, the
	outside temperature at ``2 * num_rooms``, the previous AC status at
	``2 * num_rooms + 2`` and the previous damper flags in every second slot
	starting at ``2 * num_rooms + 4``. Confirm against the environment.

	The action table is read from the module-level ``env``, so this class only
	works while that global vectorised environment is defined.
	"""

	class Status(enum.Enum):
		"""Result of comparing one room's temperature with its setpoint."""
		NEED_COOL = 0  # temperature above setpoint by more than EPSILON
		WANT_COOL = 1  # temperature above setpoint by at most EPSILON
		NEED_HEAT = 2  # temperature below setpoint by more than EPSILON
		WANT_HEAT = 3  # temperature below setpoint by at most EPSILON
		EQUAL = 4      # none of the above comparisons held

	# Tolerance around the setpoint that separates a WANT_* from a NEED_* status.
	EPSILON = 0.9

	@staticmethod
	def _damper_layout(statuses, targets):
		"""Build the damper part of an action for the rooms in ``statuses``.

		Args:
			statuses: one ``Status`` per room, in room order.
			targets: the statuses whose rooms get the flag ``False``.

		Returns:
			A one-element list holding the per-room flag list: ``False`` for
			rooms whose status is in ``targets``, ``True`` for all others.
		"""
		return [[status not in targets for status in statuses]]

	@staticmethod
	def _lookup_action(ac_status, dampers) -> int:
		"""Return the index of ``(ac_status, dampers)`` in the action table.

		The table is the ``actions`` attribute of the first sub-environment of
		the module-level ``env``.

		Raises:
			ValueError: if the tuple is not present in the action table.
		"""
		return env.get_attr("actions")[0].index((ac_status, dampers))

	def _choose_action(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> int:
		"""Pick an action index for a single observation.

		Decision order:
		  1. If more rooms are NEED_HEAT than NEED_COOL, heating is requested;
		     if more are NEED_COOL than NEED_HEAT, cooling is requested.
		  2. On a non-zero tie, the side with the larger accumulated
		     ``|temperature - setpoint|`` ("badness") is requested.
		  3. A heating request is only acted on when the warmest room is at
		     least as warm as outside, a cooling request only when the coldest
		     room is at most as warm as outside; otherwise AC status 0 is used
		     with the previous dampers.
		  4. With no request, the previous AC status and dampers are kept.

		Args:
			obs: a single observation, indexed positionally (see class docstring).

		Returns:
			Index of the chosen ``(ac_status, dampers)`` tuple in the
			environment's ``actions`` list.
		"""
		num_rooms = int((len(obs) - 3) / 4)

		statuses = []
		need_heat, need_cool = 0, 0
		badness_heat, badness_cool = 0, 0
		min_temp, max_temp = 100000, -100000

		old_ac_status = obs[num_rooms * 2 + 2]
		old_dampers = [[bool(obs[num_rooms * 2 + 4 + i * 2]) for i in range(num_rooms)]]

		for i in range(num_rooms):
			temp, setp = obs[i * 2], obs[i * 2 + 1]
			min_temp = min(min_temp, temp)
			max_temp = max(max_temp, temp)
			if temp < setp - self.EPSILON:
				statuses.append(self.Status.NEED_HEAT)
				badness_heat += abs(temp - setp)
				need_heat += 1
			elif temp < setp:
				statuses.append(self.Status.WANT_HEAT)
				badness_heat += abs(temp - setp)
			elif temp > setp + self.EPSILON:
				statuses.append(self.Status.NEED_COOL)
				badness_cool += abs(temp - setp)
				need_cool += 1
			elif temp > setp:
				statuses.append(self.Status.WANT_COOL)
				badness_cool += abs(temp - setp)
			else:
				statuses.append(self.Status.EQUAL)

		outside_temp = obs[num_rooms * 2]

		# Decide which AC status (1 or -1) is requested, if any.
		requested = None
		if need_heat > need_cool:
			requested = 1
		elif need_cool > need_heat:
			requested = -1
		elif need_cool > 0 and need_heat > 0:
			# Equal, non-zero NEED counts: break the tie on accumulated badness.
			if badness_cool > badness_heat:
				requested = -1
			elif badness_heat > badness_cool:
				requested = 1

		if requested == 1:
			if max_temp >= outside_temp:
				heat_states = (self.Status.NEED_HEAT, self.Status.WANT_HEAT)
				return self._lookup_action(1, self._damper_layout(statuses, heat_states))
			return self._lookup_action(0, old_dampers)

		if requested == -1:
			if min_temp <= outside_temp:
				cool_states = (self.Status.NEED_COOL, self.Status.WANT_COOL)
				return self._lookup_action(-1, self._damper_layout(statuses, cool_states))
			return self._lookup_action(0, old_dampers)

		# Nothing requested: repeat the previous AC status and damper flags.
		return self._lookup_action(old_ac_status, old_dampers)
# Instantiate the rule-based policy with the spaces of the vectorised environment.
expert = DumbPolicy(
    observation_space=env.observation_space,
    action_space=env.action_space,
)


  from .autonotebook import tqdm as notebook_tqdm


In [29]:
# Run `expert` in `env` until the stop condition built with min_episodes=3000
# is satisfied, and keep the result as the demonstration data.
rollouts = rollout.rollout(
    policy=expert,
    venv=env,
    sample_until=rollout.make_sample_until(min_episodes=3000),
    rng=np.random.default_rng(SEED),
)

  logger.warn(


: 

In [24]:
import pickle

# Write the collected rollouts to "rollouts.pkl" using the highest pickle
# protocol available.
with open("rollouts.pkl", mode="wb") as rollout_file:
    pickle.dump(rollouts, rollout_file, protocol=pickle.HIGHEST_PROTOCOL)

In [25]:
# Rebind `rollouts` to a placeholder so the name no longer refers to the
# in-memory rollouts; the next cell reads them back from "rollouts.pkl".
# NOTE(review): presumably done to release memory or to check the reload — confirm.
rollouts = 0

In [27]:
# Read the rollouts back from "rollouts.pkl". This cell relies on `pickle`
# having been imported by an earlier cell.
# pickle.load can execute arbitrary code, so only load files this notebook wrote.
with open("rollouts.pkl", mode="rb") as rollout_file:
    rollouts = pickle.load(rollout_file)

In [28]:
# Generator policy trained by PPO on the vectorised environment.
learner = PPO(
    policy=MlpPolicy,
    env=env,
    seed=SEED,
    learning_rate=5e-4,
    batch_size=64,
    n_epochs=5,
    gamma=0.95,
    clip_range=0.1,
    ent_coef=0.0,
    vf_coef=0.1,
)

# Reward network whose input layer is normalised with RunningNorm.
reward_net = BasicShapedRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

# AIRL trainer fed with the rollouts as demonstrations.
airl_trainer = AIRL(
    demonstrations=rollouts,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    demo_batch_size=2048,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=16,
)

# Evaluate the learner over 100 episodes before training; the environment is
# re-seeded first so both evaluations start from the same seed.
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner,
    env,
    n_eval_episodes=100,
    return_episode_rewards=True,
)

# Only 20_000 steps are requested here; the original note states that
# 2_000_000 steps are needed to match the expert.
airl_trainer.train(20_000)

# Same evaluation again, after training.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner,
    env,
    n_eval_episodes=100,
    return_episode_rewards=True,
)


KeyboardInterrupt: 

In [22]:
import torch

# Write the state dict of `reward_net.potential` to "reward_net.pth".
# NOTE(review): despite the file name, only the `potential` sub-module's state
# is saved, not the state of `reward_net` as a whole — confirm this is intended.
torch.save(obj=reward_net.potential.state_dict(), f="reward_net.pth")