A template for RL training

In [1]:
import numpy as np
import pandas as pd

from pathlib import Path
from datetime import datetime

from utils import print_log

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Root folder of the pre-split dataset; the train/val/test files are read from it below.
dataset_folder_path = Path("dataset") / "20250707" / "split"

In [4]:
# copied from 03_data_split.ipynb

# Helper functions for the new split folder structure
def load_split_data_from_folder(split_folder, split_type='train'):
    """Load the segment list and the aggregate DataFrame of one dataset split.

    Reads ``<split_type>_segments.txt`` (one ``<start> - <end>`` pair of
    ISO-format timestamps per line) and ``<split_type>_aggregate_df.pkl``
    from ``split_folder``.

    Args:
        split_folder: Folder holding the split files (``str`` or ``Path``).
        split_type: File-name prefix of the split to read (default ``'train'``).

    Returns:
        A ``(segments, df)`` tuple. ``segments`` is a list of ``(start, end)``
        ``datetime`` pairs in file order; ``df`` is the unpickled object.

    Raises:
        FileNotFoundError: If either file does not exist.
        ValueError: If a non-blank line is not of the form ``<start> - <end>``.
    """
    split_folder = Path(split_folder)  # accept plain string paths as well

    segments = []
    with open(split_folder / f'{split_type}_segments.txt', 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. an extra newline at the end of the file)
                # instead of failing on the tuple unpacking below.
                continue
            start_str, end_str = line.split(' - ')
            segments.append((datetime.fromisoformat(start_str), datetime.fromisoformat(end_str)))

    # NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
    df = pd.read_pickle(split_folder / f'{split_type}_aggregate_df.pkl')
    return segments, df

def load_signatures_from_split_folder(split_folder, split_type, appliance):
    """Load one appliance's load signatures and selected ranges for a split.

    Files are looked up under
    ``<split_folder>/load_signature_library/<split_type>/<appliance>/``.

    Args:
        split_folder: Root folder of the split data (``str`` or ``Path``).
        split_type: Sub-folder name of the split.
        appliance: Sub-folder name of the appliance.

    Returns:
        A ``(signatures_df, ranges)`` tuple. When ``load_signatures.pkl`` is
        missing, an empty ``DataFrame`` and an empty list are returned.
        ``ranges`` holds the ``(start, end)`` integer pairs read from
        ``selected_ranges.txt`` (one ``start,end`` pair per line) and is empty
        when that file is missing.

    Raises:
        ValueError: If a non-blank line of ``selected_ranges.txt`` is not two
            comma-separated integers.
    """
    # Path(...) lets callers pass a plain string as well as a Path.
    appliance_dir = Path(split_folder) / 'load_signature_library' / split_type / appliance
    sig_path = appliance_dir / 'load_signatures.pkl'
    ranges_path = appliance_dir / 'selected_ranges.txt'

    if not sig_path.exists():
        return pd.DataFrame(), []

    # NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
    signatures_df = pd.read_pickle(sig_path)

    ranges = []
    if ranges_path.exists():
        with open(ranges_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines; int('') would otherwise raise ValueError.
                    continue
                start, end = map(int, line.split(','))
                ranges.append((start, end))

    return signatures_df, ranges

In [5]:
# convert datetime objects to timezone-naive datetime objects
def convert_to_naive_datetimes_df(df):
    """Convert datetime objects in DataFrame to timezone-naive datetime objects"""
    df['datetime'] = df['datetime'].apply(lambda x: x.replace(tzinfo=None) if isinstance(x, datetime) else x)

    return df

def convert_to_naive_datetimes(segments):
    """Return a new list of ``(start, end)`` pairs with timezone info removed from both ends."""
    naive_segments = []
    for segment_start, segment_end in segments:
        naive_segments.append(
            (segment_start.replace(tzinfo=None), segment_end.replace(tzinfo=None))
        )
    return naive_segments

In [6]:
# Load the (segments, aggregate DataFrame) pair of each split: train, test and validation.
aggregate_load_segments_train, aggregate_load_df_train = load_split_data_from_folder(
    dataset_folder_path, split_type='train'
)
aggregate_load_segments_test, aggregate_load_df_test = load_split_data_from_folder(
    dataset_folder_path, split_type='test'
)
aggregate_load_segments_validation, aggregate_load_df_validation = load_split_data_from_folder(
    dataset_folder_path, split_type='val'
)

In [7]:
# Remove timezone info from the segments and from the `datetime` column of every split.
aggregate_load_segments_train = convert_to_naive_datetimes(aggregate_load_segments_train)
aggregate_load_df_train = convert_to_naive_datetimes_df(aggregate_load_df_train)

aggregate_load_segments_test = convert_to_naive_datetimes(aggregate_load_segments_test)
aggregate_load_df_test = convert_to_naive_datetimes_df(aggregate_load_df_test)

aggregate_load_segments_validation = convert_to_naive_datetimes(aggregate_load_segments_validation)
aggregate_load_df_validation = convert_to_naive_datetimes_df(aggregate_load_df_validation)

In [8]:
# Preview the training aggregate DataFrame (pandas truncates the display to the first/last rows).
aggregate_load_df_train

Unnamed: 0,timestamp,aggregate,datetime,washing_machine,dishwasher,fridge,kettle,microwave,toaster,tv,htpc,gas_oven,kitchen_lights
0,1.357603e+09,234.0,2013-01-08 00:00:05,0.0,1.0,0.0,1.0,1.0,0.0,1.0,69.0,,0.0
1,1.357603e+09,231.0,2013-01-08 00:00:11,0.0,1.0,0.0,1.0,1.0,0.0,1.0,70.0,,0.0
2,1.357603e+09,234.0,2013-01-08 00:00:17,0.0,1.0,0.0,1.0,1.0,0.0,1.0,70.0,,0.0
3,1.357603e+09,232.0,2013-01-08 00:00:23,0.0,1.0,0.0,1.0,1.0,0.0,1.0,68.0,,0.0
4,1.357603e+09,232.0,2013-01-08 00:00:30,0.0,1.0,0.0,1.0,1.0,0.0,1.0,70.0,,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2231636,1.388448e+09,178.0,2013-12-30 23:59:35,0.0,1.0,0.0,1.0,1.0,0.0,1.0,2.0,3.0,0.0
2231637,1.388448e+09,177.0,2013-12-30 23:59:41,0.0,1.0,0.0,1.0,1.0,0.0,1.0,2.0,3.0,0.0
2231638,1.388448e+09,178.0,2013-12-30 23:59:47,0.0,1.0,0.0,1.0,1.0,0.0,1.0,2.0,3.0,0.0
2231639,1.388448e+09,178.0,2013-12-30 23:59:53,0.0,1.0,0.0,1.0,1.0,0.0,1.0,2.0,3.0,0.0


In [9]:
# Preview the training segments: a list of (start, end) datetime pairs.
aggregate_load_segments_train

[(datetime.datetime(2013, 1, 8, 0, 0),
  datetime.datetime(2013, 1, 10, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 2, 27, 0, 0),
  datetime.datetime(2013, 2, 28, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 3, 8, 0, 0),
  datetime.datetime(2013, 3, 10, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 3, 28, 0, 0),
  datetime.datetime(2013, 3, 31, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 3, 22, 0, 0),
  datetime.datetime(2013, 3, 26, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 4, 8, 0, 0),
  datetime.datetime(2013, 4, 10, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 4, 26, 0, 0),
  datetime.datetime(2013, 4, 30, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 4, 1, 0, 0),
  datetime.datetime(2013, 4, 7, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 5, 15, 0, 0),
  datetime.datetime(2013, 5, 16, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 5, 1, 0, 0),
  datetime.datetime(2013, 5, 7, 23, 59, 59, 999999)),
 (datetime.datetime(2013, 5, 25, 0, 0),
  datetime.dateti

In [10]:
from rl_env.env_data_loader import SmartMeterDataLoader

# Data loader over the training split, built from its segment list and aggregate DataFrame.
sm_dl_train = SmartMeterDataLoader(
    aggregate_load_segments=aggregate_load_segments_train,
    aggregate_load_df=aggregate_load_df_train
)

# Number of divided segments the loader exposes.
# NOTE(review): the divided segments shown below each span one day — presumably one per day; confirm in SmartMeterDataLoader.
sm_dl_train.get_divided_segments_length()

162

In [11]:
# Inspect one divided segment (index 7 is just an arbitrary sample).
sm_dl_train.divided_segments[7]

array([datetime.datetime(2013, 3, 10, 0, 0),
       datetime.datetime(2013, 3, 10, 23, 59, 59, 999999)], dtype=object)

In [12]:
# sample segment: fetch the aggregate-load rows of divided segment 13 (arbitrary example index)

sm_dl_train.get_aggregate_load_segment(13)

Unnamed: 0,timestamp,aggregate,datetime
104747,1.363997e+09,335.0,2013-03-23 00:00:05
104748,1.363997e+09,336.0,2013-03-23 00:00:11
104749,1.363997e+09,333.0,2013-03-23 00:00:17
104750,1.363997e+09,334.0,2013-03-23 00:00:24
104751,1.363997e+09,331.0,2013-03-23 00:00:30
...,...,...,...
118501,1.364083e+09,179.0,2013-03-23 23:59:30
118502,1.364083e+09,171.0,2013-03-23 23:59:37
118503,1.364083e+09,171.0,2013-03-23 23:59:43
118504,1.364083e+09,171.0,2013-03-23 23:59:49


In [13]:
# create dataloaders for the validation and test sets,
# built the same way as the training loader (segments + aggregate DataFrame of the split)
sm_dl_validation = SmartMeterDataLoader(
    aggregate_load_segments=aggregate_load_segments_validation,
    aggregate_load_df=aggregate_load_df_validation
)

sm_dl_test = SmartMeterDataLoader(
    aggregate_load_segments=aggregate_load_segments_test,
    aggregate_load_df=aggregate_load_df_test
)

(Optional) Load the pre-trained H-network and related components

In the final product, the H-network should be trained together with the DDQL/PPO agent.

In [14]:
import torch
from model.H_network.h_network import HNetwork

# Date stamp that identifies which saved H-network checkpoint to load.
h_network_datetime = datetime(2025, 7, 12)

h_network_path = Path("model_trained", f"h_network_{h_network_datetime.strftime('%Y%m%d')}.pth")

# The constructor arguments have to match the architecture stored in the checkpoint.
# NOTE(review): (2, 44, 1) presumably is (input size, hidden size, output size) — confirm against HNetwork.
h_network = HNetwork(2, 44, 1)
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only one;
# load_state_dict() copies the tensors into the freshly built module either way.
h_network.load_state_dict(torch.load(h_network_path, map_location="cpu"))
# Switch to evaluation mode; eval() returns the module, so the cell also displays it.
h_network.eval()

HNetwork(
  (LSTM_1): LSTM(2, 44, batch_first=True, bidirectional=True)
  (ac1): Tanh()
  (LSTM_2): LSTM(88, 1, batch_first=True, bidirectional=True)
  (ac2): Tanh()
  (fc): Linear(in_features=2, out_features=1, bias=True)
)

In [15]:
# Scaler file saved under the same date stamp as the H-network checkpoint.
h_network_stdscaler_path = Path("model_trained", f"h_network_standardscaler_{h_network_datetime.strftime('%Y%m%d')}.pkl")
import joblib
# NOTE: joblib.load unpickles the file, which can execute arbitrary code — only load trusted files.
h_network_stdscaler = joblib.load(h_network_stdscaler_path)

Create the environment

In [None]:
import sys
# Put the rl_env/ folder itself on the import path.
# NOTE(review): presumably hrl_env imports sibling modules by bare name — confirm.
# Re-running this cell appends a duplicate entry to sys.path.
sys.path.append(str(Path('rl_env')))

from rl_env.hrl_env import SmartMeterWorld

# Training environment backed by the training data loader.
env_train = SmartMeterWorld(
    smart_meter_data_loader=sm_dl_train,
    render_mode="human",
)

# Hand the pre-trained H-network and its scaler to the environment.
env_train.set_h_network(h_network)
env_train.set_h_network_stdscaler(h_network_stdscaler)

In [None]:
from gymnasium.utils.env_checker import check_env

# This will catch many common issues
# Any failure is reported as a printed message rather than a traceback,
# so the notebook keeps running past this cell.
try:
    check_env(env_train)
    print("Environment passes all checks!")
except Exception as e:
    print(f"Environment has issues: {e}")

In [None]:
# reset() returns an (observation, info) pair — the validation and test cells unpack it
# that way — so unpack it here too instead of binding the whole tuple to `obs`.
obs, info = env_train.reset()
obs

In [None]:
# Reset the render window of the training environment.
env_train.reset_render_window()

In [None]:
# initialize a PPO agent and train it on the training environment
from stable_baselines3 import PPO

# Timestamp of this training run; it names the run directory under rl_model/PPO/
# and is reused by later cells when saving the graph and the model.
rl_datetime = datetime.now()
tensorboard_log_path = Path("rl_model", "PPO", f"{rl_datetime.strftime('%Y%m%d_%H%M%S')}")

# NOTE(review): no `seed` is passed, so training runs are not reproducible — consider setting one.
rl_model = PPO(
    "MultiInputPolicy", 
    env_train, 
    verbose=2,
    tensorboard_log=tensorboard_log_path
)

# Long-running cell: trains for 100000 environment steps.
rl_model.learn(
    total_timesteps=100000,
    progress_bar=True,
    tb_log_name="PPO_SmartMeterWorld"
)

Create a validation environment

and put the policy into the validation env

In [16]:
import sys
# Same path setup and import as in the training-environment cell, repeated so
# this cell can run on its own.
sys.path.append(str(Path('rl_env')))
from rl_env.hrl_env import SmartMeterWorld

# Validation environment backed by the validation data loader.
# With render_mode="human" the environment tries to connect to a render server
# (see the log output below this cell).
env_valid = SmartMeterWorld(
    smart_meter_data_loader=sm_dl_validation,
    render_mode="human",
)

[2025-07-13 02:56:27:240] [SmartMeterWorld] Could not connect to render server: [Errno 111] Connection refused
[2025-07-13 02:56:27:240] [SmartMeterWorld] Render mode set to 'human'. Render server at 127.0.0.1:50007. render_connected: False. render_client_socket: <socket.socket fd=77, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('0.0.0.0', 40678)>


In [17]:
# Hand the pre-trained H-network and its scaler to the validation environment.
env_valid.set_h_network(h_network)
env_valid.set_h_network_stdscaler(h_network_stdscaler)

In [20]:
# Requires `rl_model` from the PPO training cell above; running this cell in a kernel
# where that cell was not executed raises NameError.
rl_model.set_env(env_valid)

NameError: name 'rl_model' is not defined

In [None]:
# Roll the trained policy out on the validation environment, one full episode per iteration.
for _ in range(1):
    # Reset inside the loop so every episode starts from a fresh state,
    # also when the number of episodes is increased.
    obs, info = env_valid.reset()
    done = False
    while not done:
        action, _states = rl_model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env_valid.step(action)
        # An episode is over when it terminates OR is truncated; checking only the
        # first flag would keep stepping a truncated episode forever.
        done = terminated or truncated
        print_log(f"Step: {env_valid.episode.get_current_step()}, Action: {action}, Reward: {reward}")
        env_valid.render()

In [None]:
# Inspect the DataFrame of the validation episode that was just run.
env_valid.episode.df

In [None]:
# save the graph of the validation episode into this training run's directory
# (the directory is named by `rl_datetime` from the PPO training cell)

env_valid.save_graph(
    str(Path("rl_model", "PPO", f"{rl_datetime.strftime('%Y%m%d_%H%M%S')}", "graph_valid.png"))
)

In [None]:
# Close the validation environment.
env_valid.close()

In [None]:
# save the model as rl_model.zip inside this training run's directory
# (named by `rl_datetime` from the PPO training cell)
rl_model_path = Path("rl_model", "PPO", f"{rl_datetime.strftime('%Y%m%d_%H%M%S')}", "rl_model.zip")
rl_model.save(rl_model_path)

---

In [61]:
# Build the test environment and load a previously saved PPO model into it.
import sys
sys.path.append('rl_env')

from rl_env.hrl_env import SmartMeterWorld
from stable_baselines3 import PPO

env_test = SmartMeterWorld(sm_dl_test, render_mode="human")

# Hand the pre-trained H-network and its scaler to the test environment.
env_test.set_h_network(h_network)
env_test.set_h_network_stdscaler(h_network_stdscaler)

# The saved run is identified by the timestamp that names its directory.
rl_model_path = (
    Path("rl_model")
    / "PPO"
    / datetime(2025, 7, 12, 18, 25, 2).strftime('%Y%m%d_%H%M%S')
    / "rl_model.zip"
)
rl_model_loaded = PPO.load(rl_model_path, env=env_test)

[2025-07-13 06:06:39:156] [SmartMeterWorld] Render mode set to 'human'. Render server at 127.0.0.1:50007. render_connected: True. render_client_socket: <socket.socket fd=93, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=0, laddr=('127.0.0.1', 36298), raddr=('127.0.0.1', 50007)>
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [68]:
# Start a test episode.
# NOTE(review): 43 presumably selects the episode/segment (or is a seed) — confirm against SmartMeterWorld.reset.
obs, info = env_test.reset(43)

[2025-07-13 06:16:20:872] [SmartMeterWorld] Resetting environment with a new episode. Episode info: {'length': 13821, 'datetime_range': (Timestamp('2013-06-19 00:00:05'), Timestamp('2013-06-19 23:59:54'))}


In [69]:
# Inspect the info dict returned by reset().
info

{'current_step': 0,
 'battery_soc (kWh)': 0.6522993,
 'user_load (W)': 205.0,
 '(prev) grid_load (W)': None,
 'last_action (kW)': None,
 'last_battery_actiuon (kW)': None,
 'last_reward': None,
 'last_f_signal': None,
 'last_g_signal': None}

In [67]:
# Reset the render window of the test environment.
env_test.reset_render_window()

In [64]:
# Inspect the DataFrame of the current test episode.
env_test.episode.df

Unnamed: 0,timestamp,aggregate,datetime,grid_load,battery_soc
434548,1.371600e+09,205.0,2013-06-19 00:00:05,,
434549,1.371600e+09,205.0,2013-06-19 00:00:11,,
434550,1.371600e+09,208.0,2013-06-19 00:00:17,,
434551,1.371600e+09,205.0,2013-06-19 00:00:23,,
434552,1.371600e+09,204.0,2013-06-19 00:00:30,,
...,...,...,...,...,...
448364,1.371686e+09,239.0,2013-06-19 23:59:30,,
448365,1.371686e+09,238.0,2013-06-19 23:59:36,,
448366,1.371686e+09,238.0,2013-06-19 23:59:42,,
448367,1.371686e+09,240.0,2013-06-19 23:59:48,,


In [70]:
# Roll the loaded policy out on the test episode selected by the reset cell above
# (`obs` comes from that reset).
for _ in range(1):
    done = False
    while not done:
        action, _states = rl_model_loaded.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env_test.step(action)
        # An episode is over when it terminates OR is truncated; checking only the
        # first flag would keep stepping a truncated episode forever.
        done = terminated or truncated
        print_log(f"Step: {env_test.episode.get_current_step()}, Action: {action}, Reward: {reward}, Info: {info}")
        env_test.render()

[2025-07-13 06:16:24:596] Step: 1, Action: [-0.16942872], Reward: 0.32956501571081803, Info: {'current_step': 1, 'battery_soc (kWh)': 0.65225655, 'user_load (W)': 205.0, '(prev) grid_load (W)': 0.0, 'last_action (kW)': -0.6777148842811584, 'last_battery_actiuon (kW)': -0.205, 'last_reward': 0.32956501571081803, 'last_f_signal': -0.366187185049057, 'last_g_signal': 3.4508333333333336e-05}
[2025-07-13 06:16:24:605] Step: 2, Action: [-0.16942874], Reward: 0.49733292488001507, Info: {'current_step': 2, 'battery_soc (kWh)': 0.6522139, 'user_load (W)': 208.0, '(prev) grid_load (W)': 0.0, 'last_action (kW)': -0.6777149438858032, 'last_battery_actiuon (kW)': -0.205, 'last_reward': 0.49733292488001507, 'last_f_signal': -0.5525959730148315, 'last_g_signal': 3.4508333333333336e-05}
[2025-07-13 06:16:24:614] Step: 3, Action: [-0.16943836], Reward: 0.3823827569111812, Info: {'current_step': 3, 'battery_soc (kWh)': 0.65217054, 'user_load (W)': 205.0, '(prev) grid_load (W)': 0.0, 'last_action (kW)': 

In [None]:
# Save the test-episode graph next to the model that was actually evaluated.
# `rl_datetime` only exists when the PPO training cell ran in this kernel, and it
# names a different run directory than the loaded model; `rl_model_path` is set by
# the cell that loads the model.
env_test.save_graph(
    str(rl_model_path.parent / "graph_test.png")
)

In [None]:
# Close the test environment.
env_test.close()