In [None]:
# Notebook setup: load environment variables, then start the Kedro session.
# Order matters: the .env values are loaded first so they are already present in the
# process environment when Kedro is (re)loaded below.
from dotenv import load_dotenv

# override=True lets values from the .env file replace variables that are already set
# in the kernel's environment.
# NOTE(review): the path is relative to the kernel's working directory, so this assumes
# the notebook runs from a direct subdirectory of the project root (e.g. notebooks/);
# confirm.
load_dotenv("../conf/local/.env", override=True)

# Load (or reload, if already loaded) the Kedro IPython extension.
%reload_ext kedro.ipython


# (Re)initialise the Kedro project session in this kernel — presumably exposing the
# usual Kedro notebook variables (catalog, context, session); verify against Kedro docs.
%reload_kedro

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
import reinforced_replenishment.envs.replenishment_env
from reinforced_replenishment.callbacks.reward_logging_callback import PlotLoggingCallback

In [None]:
# Create the environment and train a PPO agent on it.
# A single seed drives both the environment reset and the PPO agent, so that
# Restart & Run All reproduces the same training run.
SEED = 42

env = gym.make('ReplenishmentEnv-v0',
               max_steps=100,
               forecast_horizon=10,
               max_order=50,
               demand_prob=0.3,
               avg_demand=10)

# Init agent
env.reset(seed=SEED)
# Without `seed=`, PPO's own RNGs (policy initialisation, action sampling, and the seeding
# of the env it wraps internally) stay unseeded, so each run would train a different
# policy even though the env reset above used a fixed seed.
model = PPO("MlpPolicy", env, verbose=0, seed=SEED)

# Fit — long-running: 2e6 timesteps gives good results; lower it for a quick smoke test.
total_timesteps = int(2e6)
callback = PlotLoggingCallback(total_timesteps=total_timesteps)
# Trailing `;` suppresses the notebook repr of the model object returned by `learn`.
model.learn(total_timesteps=total_timesteps, callback=callback);

In [None]:
# Test the agent: roll out the trained policy for one episode and record the
# environment state after every step, for plotting in the next cell.
MAX_EVAL_STEPS = 100  # upper bound on the rollout; same value as max_steps used to create the env

obs, _ = env.reset()
env.render()

# Index 0 of each history is the post-reset state. Inventory, backorder and forecast are
# read from the env itself rather than hard-coded (or taken from a different source than
# the loop below), so the first point is correct and means the same thing as every later one.
inventory_history = [env.unwrapped.inventory]
forecast_next_period = [env.unwrapped.forecast[0]]
reward_history = [0]      # no reward has been earned before the first step
action_history = [None]   # no action precedes the initial state
backorder_history = [env.unwrapped.backorder]

# Run the simulation
for _ in range(MAX_EVAL_STEPS):
    # deterministic=True uses the policy's most likely action instead of sampling from its
    # action distribution, so the rollout shows what the agent actually learned.
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    inventory_history.append(env.unwrapped.inventory)
    forecast_next_period.append(env.unwrapped.forecast[0])  # only the first forecasted demand
    reward_history.append(reward)
    action_history.append(action)
    backorder_history.append(env.unwrapped.backorder)

    # Stop as soon as the episode ends (terminal state or time limit).
    if terminated or truncated:
        break

In [None]:
# Plot the results: every recorded history against its time-step index in one
# interactive plotly figure, with the reward on a secondary (right-hand) y-axis.
import plotly.graph_objects as go


def _line_trace(values, label, line_style, **extra_kwargs):
    """Return a markerless line trace of `values` plotted against their position index."""
    return go.Scatter(
        x=list(range(len(values))),
        y=values,
        mode='lines',
        name=label,
        line=line_style,
        **extra_kwargs,
    )


# One row per series: (history, legend label, line style, extra Scatter kwargs).
# Row order is the order traces are added, and hence their order in the legend.
trace_specs = [
    (inventory_history, 'Inventory', dict(color='red', dash='dash'), {}),
    (forecast_next_period, 'Forecast (Next Period)', dict(color='blue'), {}),
    (action_history, 'Action', dict(color='purple'), {}),
    (backorder_history, 'Backorder', dict(color='orange', dash='dot'), {}),
    # The reward is on a different scale, so it is drawn against the secondary y-axis.
    (reward_history, 'Reward', dict(color='green'), {'yaxis': 'y2'}),
]

fig = go.Figure()
for series, label, line_style, extra_kwargs in trace_specs:
    fig.add_trace(_line_trace(series, label, line_style, **extra_kwargs))

secondary_axis = {
    'title': "Reward",
    'overlaying': 'y',  # share the plotting area with the primary y-axis
    'side': 'right',
}
# Legend position is in normalised paper coordinates: y=1.5 with yanchor='top' puts the
# legend's top edge above the plotting area, right-aligned with its right edge (x=1).
legend_position = {'x': 1, 'y': 1.5, 'xanchor': 'right', 'yanchor': 'top'}

fig.update_layout(
    title="Replenishment Environment State Over Time",
    xaxis_title="Time Steps",
    yaxis_title="Units / Actions",
    yaxis2=secondary_axis,
    legend=legend_position,
    template="plotly_white",
)

fig.show()