In [None]:
# Notebook bootstrap: load local environment variables, then start the Kedro
# IPython extension. The order matters: the .env file is read before Kedro loads.
from dotenv import load_dotenv

# Read variables from the project's local .env file. override=True lets values
# from the file replace variables already set in the process environment.
load_dotenv("../conf/local/.env", override=True)

# Load Kedro's IPython extension; it provides the %reload_kedro magic used below.
%load_ext kedro.ipython


# (Re)load the Kedro project for this kernel session.
# NOTE(review): presumably this exposes Kedro's `catalog`, `context`, `session`
# and `pipelines` variables; none of them is used in the visible cells, so
# confirm this step is still required.
%reload_kedro

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
import reinforced_replenishment.envs.replenishment_env

In [None]:
# Train a PPO agent on the replenishment environment and persist it to disk.
# `env` and `model` stay bound in the kernel namespace for the cells that follow.

SEED = 42  # single seed shared by the environment and the agent
TOTAL_TIMESTEPS = 10000  # training budget handed to model.learn()
MODEL_PATH = "../data/06_models/ppo_replenishment"  # where the trained agent is saved

# Create the environment
env = gym.make('ReplenishmentEnv-v0')

# Seed the environment's random number generator on the first reset
env.reset(seed=SEED)

# Initialize the PPO agent.
# Seeding only the environment leaves the policy initialisation and the action
# sampling unseeded, so two runs of this cell would produce different agents.
# Passing `seed` to PPO makes training repeatable (as far as the compute backend
# allows).
model = PPO("MlpPolicy", env, verbose=1, seed=SEED)

# Train the agent
model.learn(total_timesteps=TOTAL_TIMESTEPS)

# Save the model
model.save(MODEL_PATH)

In [None]:
# Roll out the trained agent for one episode (at most MAX_EVAL_STEPS steps) and
# record the inventory level and the first forecast entry after every step.
# Requires `model` from the training cell to be present in the kernel.

EVAL_SEED = 42  # fixed seed so the rollout can be reproduced
MAX_EVAL_STEPS = 100  # upper bound on the number of simulated steps

# Step through the environment exactly as gym.make() builds it, i.e. with the same
# wrapper stack the agent was trained on. Stepping the unwrapped environment would
# bypass those wrappers; in particular a registered time limit (if there is one
# for this env id -- not visible from this notebook) could never set `truncated`.
eval_env = gym.make('ReplenishmentEnv-v0')

# Underlying custom environment; used only to read its internal state below.
env = eval_env.unwrapped

# reset() returns (observation, info); seeding it makes the episode repeatable.
obs, _info = eval_env.reset(seed=EVAL_SEED)

# Initialize storage for observations
inventory_history = []
forecast_next_period = []  # Store only the first entry of the forecast

# Run the simulation
for _step in range(MAX_EVAL_STEPS):
    # deterministic=True picks the policy's most likely action instead of sampling,
    # so the recorded trajectory reflects the learned policy rather than
    # exploration noise.
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = eval_env.step(action)

    # Save the current state
    # NOTE(review): if `env.inventory` is a mutable array that the environment
    # updates in place, append a copy instead -- confirm against the env class.
    inventory_history.append(env.inventory)
    forecast_next_period.append(env.forecast[0])  # Save only the first forecasted demand

    if terminated or truncated:  # End the loop if the episode is over
        break

In [None]:
import plotly.graph_objects as go

# Interactive Plotly chart of the recorded rollout: one curve per tracked series.
# Each entry is (values, legend label, line style); the list order is the order
# in which the traces are drawn.
series_specs = [
    (inventory_history, 'Inventory', {'color': 'red', 'dash': 'dash'}),
    (forecast_next_period, 'Forecast (Next Period)', {'color': 'blue'}),
]

fig = go.Figure()

for values, label, line_style in series_specs:
    # The x axis is simply the step index of each recorded value.
    step_index = list(range(len(values)))
    trace = go.Scatter(
        x=step_index,
        y=values,
        mode='lines+markers',
        name=label,
        line=line_style,
        marker={'size': 6},
    )
    fig.add_trace(trace)

# Titles, axis labels, legend anchored at the top-left corner, white theme
layout_options = {
    'title': "Replenishment Environment State Over Time",
    'xaxis_title': "Time Steps",
    'yaxis_title': "Units",
    'legend': {'x': 0, 'y': 1},
    'template': "plotly_white",
}
fig.update_layout(**layout_options)

# Render the interactive figure
fig.show()