In [8]:
import numpy as np 
import pandas as pd

import math
import sys
import os

import gymnasium as gym
gym.__version__

from citylearn.citylearn import CityLearnEnv
from citylearn.wrappers import NormalizedObservationWrapper, StableBaselines3Wrapper
from citylearn.reward_function import RewardFunction

sys.path.append("../custom_reward")
from custom_reward import CustomReward

from stable_baselines3 import SAC

import seaborn as sns
import matplotlib.pyplot as plt

In [9]:
class WrapperEnv:
    """
    Lightweight stand-in for a CityLearn environment.

    Only the attributes copied from ``env_data`` (plus ``get_metadata``) are
    exposed, so code that receives this object does not get a reference to
    the full environment. Nothing here enforces read-only access; the
    restriction is simply that no other environment data is copied over.

    Parameters
    ----------
    env_data : dict
        Mapping that must contain every key listed in ``__init__``; a missing
        key raises ``KeyError`` (keys are looked up in the listed order).
    """
    def __init__(self, env_data):
        # Each name is both the env_data key and the attribute it becomes.
        exposed_fields = (
            'observation_names',
            'action_names',
            'observation_space',
            'action_space',
            'time_steps',
            'seconds_per_time_step',
            'random_seed',
            'buildings_metadata',
            'episode_tracker',
        )
        for field in exposed_fields:
            setattr(self, field, env_data[field])

    def get_metadata(self):
        """Return the stored building metadata wrapped under a ``'buildings'`` key."""
        metadata = {'buildings': self.buildings_metadata}
        return metadata

In [10]:
def makeEnv(schema_path, reward_function, central_agent=True):
    """
    Build a CityLearn environment together with a restricted view of it.

    Parameters
    ----------
    schema_path :
        Path to the CityLearn schema file, passed to ``CityLearnEnv(schema=...)``.
    reward_function :
        Reward function forwarded unchanged to ``CityLearnEnv``.
    central_agent : bool, default True
        Forwarded to ``CityLearnEnv``. Previously hard-coded to ``True``; the
        default keeps existing calls unchanged.

    Returns
    -------
    tuple
        ``(env, wrapper_env)``: the full ``CityLearnEnv`` and a ``WrapperEnv``
        that exposes only the selected fields below.
    """
    env = CityLearnEnv(
        schema=schema_path,
        reward_function=reward_function,
        central_agent=central_agent,
    )

    env_data = dict(
        observation_names=env.observation_names,
        action_names=env.action_names,
        observation_space=env.observation_space,
        action_space=env.action_space,
        time_steps=env.time_steps,
        # These three are not copied from env; the restricted view gets None.
        random_seed=None,
        episode_tracker=None,
        seconds_per_time_step=None,
        buildings_metadata=env.get_metadata()['buildings'],
    )

    wrapper_env = WrapperEnv(env_data)
    return env, wrapper_env


In [12]:
schema_path = "../data/schema_edited.json"

# Full CityLearn env plus its restricted WrapperEnv view, using the custom reward.
env, wrapper_env = makeEnv(schema_path, CustomReward)

# Normalize observations (inner wrapper), then adapt the result to the
# Stable-Baselines3 interface (outer wrapper).
env = StableBaselines3Wrapper(NormalizedObservationWrapper(env))