In [5]:
import os
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union

# import couple of libs some will be useful
import gym
import numpy as np
from collections import deque
import random
import re
import os
import sys
import time
import json
import itertools
from datasets import Dataset
from _code.const import PATH_MODEL_SB,PATH_DATA_INTERACTIONS
from citylearn.agents.rbc import BasicRBC as BRBC
# import stable_baselines3
from stable_baselines3 import PPO, A2C, DDPG, TD3,SAC
from stable_baselines3.common.utils import set_random_seed

from citylearn.citylearn import CityLearnEnv
from utils.rewards import CustomReward

ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType")

import functools
from citylearn.wrappers import *


In [83]:
class NormalizedObservationWrapperCustom(ObservationWrapper):
    """Wrapper for observations min-max and periodic normalization.
    
    Temporal observations including `hour`, `day_type` and `month` are periodically normalized using sine/cosine 
    transformations and then all observations are min-max normalized between 0 and 1.

    Parameters
    ----------
    env: CityLearnEnv
        CityLearn environment.
    """

    def __init__(self, env: CityLearnEnv) -> None:
        super().__init__(env)
        self.env: CityLearnEnv
        self.dataset=[]

    @property
    def shared_observations_norm(self) -> List[str]:
        """Names of common observations across all buildings i.e. observations that have the same value irrespective of the building.
        
        Includes extra three observations added during cyclic transformation of :code:`hour`, :code:`day_type` and :code:`month`.
        """

        shared_observations = []
        periodic_observation_names = list(Building.get_periodic_observation_metadata().keys())

        for o in self.env.shared_observations:
            if o in periodic_observation_names:
                shared_observations += [f'{o}_cos', f'{o}_sin']
            
            else:
                shared_observations.append(o)

        return shared_observations
    
    


    

    def get_observation_norm(self, observations: List[List[float]]) -> List[List[float]]:
        """Returns normalized observations."""

        if self.env.central_agent:
            norm_observations = []
            shared_observations = []

            for i, b in enumerate(self.env.buildings):
                for k, v in b.observations(normalize=True, periodic_normalization=True).items():
                    if i==0 or k not in self.shared_observations_norm or k not in shared_observations:
                        norm_observations.append(v)

                    else:
                        pass

                    if k in self.shared_observations_norm and k not in shared_observations:
                        shared_observations.append(k)
                    
                    else:
                        pass
            
            norm_observations = [norm_observations]

        else:
            norm_observations = [list(b.observations(normalize=True, periodic_normalization=True).values()) for b in self.env.buildings]
        
        return norm_observations
    
    def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Resets the environment with kwargs."""
        obs = self.env.reset(**kwargs)
        #print(obs)
        norm_obs = self.get_observation_norm(obs)
        self.current_obs = norm_obs
        self.dataset = []
        return self.env.reset(**kwargs)
    
    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Steps through the environment with action."""
        obs, reward, done, info = self.env.step(action)
        #print(obs)
        norm_obs = self.get_observation_norm(obs)
        print(norm_obs)
        
        
        self.dataset.append({
            "observations": self.current_obs,
            "next_observations": norm_obs,  # Assuming next observation is same as current for simplicity
            "actions": action,
            "rewards": reward,
            "dones": done,
            "info": info
        })
        
        self.current_obs = obs
        
        return self.env.step(action)

In [None]:
from citylearn.agents.rbc import OptimizedRBC as Agent

schema =  "citylearn_challenge_2022_phase_2"
env = CityLearnEnv(schema)
env = NormalizedObservationWrapperCustom(env)
env.central_agent = True
model_rbc = Agent(env)

model_rbc.learn(episodes=1)

[[7, 7, 24, 20.0, 18.3, 22.8, 20.0, 84.0, 81.0, 68.0, 81.0, 0.0, 25.0, 964.0, 0.0, 0.0, 100.0, 815.0, 0.0, 0.1707244, 1.4945166, 0.0, 0.0, 1.4945166, 0.22, 0.22, 0.22, 0.22], [7, 7, 24, 20.0, 18.3, 22.8, 20.0, 84.0, 81.0, 68.0, 81.0, 0.0, 25.0, 964.0, 0.0, 0.0, 100.0, 815.0, 0.0, 0.1707244, 0.77071667, 0.0, 0.0, 0.77071667, 0.22, 0.22, 0.22, 0.22], [7, 7, 24, 20.0, 18.3, 22.8, 20.0, 84.0, 81.0, 68.0, 81.0, 0.0, 25.0, 964.0, 0.0, 0.0, 100.0, 815.0, 0.0, 0.1707244, 9.7529096e-08, 0.0, 0.0, 9.7529096e-08, 0.22, 0.22, 0.22, 0.22], [7, 7, 24, 20.0, 18.3, 22.8, 20.0, 84.0, 81.0, 68.0, 81.0, 0.0, 25.0, 964.0, 0.0, 0.0, 100.0, 815.0, 0.0, 0.1707244, 0.63045, 0.0, 0.0, 0.63045, 0.22, 0.22, 0.22, 0.22], [7, 7, 24, 20.0, 18.3, 22.8, 20.0, 84.0, 81.0, 68.0, 81.0, 0.0, 25.0, 964.0, 0.0, 0.0, 100.0, 815.0, 0.0, 0.1707244, 0.5457, 0.0, 0.0, 0.5457, 0.22, 0.22, 0.22, 0.22]]
[[0.24999999999999978, 0.06698729810778081, 0.8535533905932737, 0.8535533905932737, 0.9829629131445341, 0.6294095225512604, 0.545

In [6]:
class StableBaselines3WrapperCustom(Wrapper):
    """Wrapper for :code:`stable-baselines3` algorithms.

    Wraps `env` in :py:class:`citylearn.wrappers.StableBaselines3ObservationWrapper`,
    :py:class:`citylearn.wrappers.StableBaselines3ActionWrapper`
    and :py:class:`citylearn.wrappers.StableBaselines3RewardWrapper`.
    
    Parameters
    ----------
    env: CityLearnEnv
        CityLearn environment.
    """

    def __init__(self, env: CityLearnEnv):
        env = StableBaselines3ActionWrapper(env)
        env = StableBaselines3RewardWrapper(env)
        env = StableBaselines3ObservationWrapper(env)
        super().__init__(env)
        self.env: CityLearnEnv
        self.dataset = []
        
    def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
        """Resets the environment with kwargs."""
        obs = self.env.reset(**kwargs)
        self.current_obs = obs
        self.dataset = []
        return self.env.reset(**kwargs)
    
    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Steps through the environment with action."""
        obs, reward, done, info = self.env.step(action)
        
        self.dataset.append({
            "observations": self.current_obs,
            "next_observations": obs,  # Assuming next observation is same as current for simplicity
            "actions": action,
            "rewards": reward,
            "dones": done,
            "info": info
        })
        
        self.current_obs = obs
        
        return self.env.step(action)
        

In [67]:
schema =  "citylearn_challenge_2022_phase_2"

In [68]:
sac_env = CityLearnEnv(schema)
sac_env.central_agent = True

In [52]:
sac_env = NormalizedObservationWrapperCustom(sac_env)

In [71]:
#sac_env = StableBaselines3WrapperCustom(sac_env)

In [70]:
sac_model = SAC(policy='MlpPolicy', env=sac_env, seed=10)

AttributeError: 'list' object has no attribute 'shape'

In [55]:
sac_model.learn(total_timesteps=1)

<stable_baselines3.sac.sac.SAC at 0x7fc118097100>

In [56]:
sac_env.observation_names

[['month_cos',
  'month_sin',
  'day_type_cos',
  'day_type_sin',
  'hour_cos',
  'hour_sin',
  'outdoor_dry_bulb_temperature',
  'outdoor_dry_bulb_temperature_predicted_6h',
  'outdoor_dry_bulb_temperature_predicted_12h',
  'outdoor_dry_bulb_temperature_predicted_24h',
  'outdoor_relative_humidity',
  'outdoor_relative_humidity_predicted_6h',
  'outdoor_relative_humidity_predicted_12h',
  'outdoor_relative_humidity_predicted_24h',
  'diffuse_solar_irradiance',
  'diffuse_solar_irradiance_predicted_6h',
  'diffuse_solar_irradiance_predicted_12h',
  'diffuse_solar_irradiance_predicted_24h',
  'direct_solar_irradiance',
  'direct_solar_irradiance_predicted_6h',
  'direct_solar_irradiance_predicted_12h',
  'direct_solar_irradiance_predicted_24h',
  'carbon_intensity',
  'non_shiftable_load',
  'solar_generation',
  'electrical_storage_soc',
  'net_electricity_consumption',
  'electricity_pricing',
  'electricity_pricing_predicted_6h',
  'electricity_pricing_predicted_12h',
  'electricity_

In [57]:
sac_env.dataset

[{'observations': array([0.0669873 , 0.25      , 0.8535534 , 0.14644662, 1.        ,
         0.5       , 0.54135334, 0.47744358, 0.6466165 , 0.54135334,
         0.82222223, 0.7888889 , 0.64444447, 0.7888889 , 0.        ,
         0.0245821 , 0.94788593, 0.        , 0.        , 0.10493179,
         0.85519415, 0.        , 0.47462252, 0.22008501, 0.        ,
         0.        , 0.5091632 , 0.03030304, 0.03030304, 0.03030304,
         0.03030304, 0.10645503, 0.        , 0.        , 0.4610448 ,
         0.03030304, 0.03030304, 0.03030304, 0.03030304, 0.        ,
         0.        , 0.        , 0.40333527, 0.03030304, 0.03030304,
         0.03030304, 0.03030304, 0.0835057 , 0.        , 0.        ,
         0.47982168, 0.03030304, 0.03030304, 0.03030304, 0.03030304,
         0.06343947, 0.        , 0.        , 0.4165131 , 0.03030304,
         0.03030304, 0.03030304, 0.03030304], dtype=float32),
  'next_observations': array([2.5000000e-01, 6.6987298e-02, 8.5355341e-01, 8.5355341e-01,
    

In [58]:
sac_env.reward_function

<citylearn.reward_function.RewardFunction at 0x7fc118097cd0>

In [59]:
first_b = sac_env.buildings[0]

In [60]:
first_b.observations().items()

dict_items([('month', 8), ('day_type', 1), ('hour', 2), ('outdoor_dry_bulb_temperature', 19.7), ('outdoor_dry_bulb_temperature_predicted_6h', 21.1), ('outdoor_dry_bulb_temperature_predicted_12h', 22.2), ('outdoor_dry_bulb_temperature_predicted_24h', 19.4), ('outdoor_relative_humidity', 78.0), ('outdoor_relative_humidity_predicted_6h', 73.0), ('outdoor_relative_humidity_predicted_12h', 73.0), ('outdoor_relative_humidity_predicted_24h', 87.0), ('diffuse_solar_irradiance', 0.0), ('diffuse_solar_irradiance_predicted_6h', 420.0), ('diffuse_solar_irradiance_predicted_12h', 683.0), ('diffuse_solar_irradiance_predicted_24h', 0.0), ('direct_solar_irradiance', 0.0), ('direct_solar_irradiance_predicted_6h', 592.0), ('direct_solar_irradiance_predicted_12h', 291.0), ('direct_solar_irradiance_predicted_24h', 0.0), ('carbon_intensity', 0.15450256), ('non_shiftable_load', 1.2106), ('solar_generation', 0.0), ('electrical_storage_soc', 0.97432476), ('net_electricity_consumption', 3.3523922), ('electrici

In [61]:
first_b.observations(normalize=True, periodic_normalization=True).items()

dict_items([('month_cos', 0.24999999999999978), ('month_sin', 0.06698729810778081), ('day_type_cos', 0.8535533905932737), ('day_type_sin', 0.8535533905932737), ('hour_cos', 0.9330127018922194), ('hour_sin', 0.75), ('outdoor_dry_bulb_temperature', 0.530075203133071), ('outdoor_dry_bulb_temperature_predicted_6h', 0.5827067660412003), ('outdoor_dry_bulb_temperature_predicted_12h', 0.6240601625064568), ('outdoor_dry_bulb_temperature_predicted_24h', 0.5187969649853647), ('outdoor_relative_humidity', 0.7555555555555555), ('outdoor_relative_humidity_predicted_6h', 0.7), ('outdoor_relative_humidity_predicted_12h', 0.7), ('outdoor_relative_humidity_predicted_24h', 0.8555555555555555), ('diffuse_solar_irradiance', 0.0), ('diffuse_solar_irradiance_predicted_6h', 0.41297935103244837), ('diffuse_solar_irradiance_predicted_12h', 0.671583087512291), ('diffuse_solar_irradiance_predicted_24h', 0.0), ('direct_solar_irradiance', 0.0), ('direct_solar_irradiance_predicted_6h', 0.621196222455404), ('direct_

In [62]:
sac_env.dataset[0]

{'observations': array([0.0669873 , 0.25      , 0.8535534 , 0.14644662, 1.        ,
        0.5       , 0.54135334, 0.47744358, 0.6466165 , 0.54135334,
        0.82222223, 0.7888889 , 0.64444447, 0.7888889 , 0.        ,
        0.0245821 , 0.94788593, 0.        , 0.        , 0.10493179,
        0.85519415, 0.        , 0.47462252, 0.22008501, 0.        ,
        0.        , 0.5091632 , 0.03030304, 0.03030304, 0.03030304,
        0.03030304, 0.10645503, 0.        , 0.        , 0.4610448 ,
        0.03030304, 0.03030304, 0.03030304, 0.03030304, 0.        ,
        0.        , 0.        , 0.40333527, 0.03030304, 0.03030304,
        0.03030304, 0.03030304, 0.0835057 , 0.        , 0.        ,
        0.47982168, 0.03030304, 0.03030304, 0.03030304, 0.03030304,
        0.06343947, 0.        , 0.        , 0.4165131 , 0.03030304,
        0.03030304, 0.03030304, 0.03030304], dtype=float32),
 'next_observations': array([2.5000000e-01, 6.6987298e-02, 8.5355341e-01, 8.5355341e-01,
        9.8296291e