In [None]:
# Sync Colab environment with pyproject.toml dependencies

# 1. Clone your repo (change URL + folder name!)
!git clone https://github.com/kziliask/drl_inventory_env.git
%cd drl_inventory_env

# 2. Install a TOML parser (tomli works well)
!pip install tomli

import tomli
from pathlib import Path
import subprocess
import sys

# 3. Load dependencies from the [project] table.
# TOML documents are UTF-8 by specification, so decode explicitly rather than
# relying on whatever default encoding the platform happens to use.
pyproject_path = Path("pyproject.toml")
pyproject = tomli.loads(pyproject_path.read_text(encoding="utf-8"))

# Missing [project] table or missing "dependencies" key both yield an empty list.
deps = pyproject.get("project", {}).get("dependencies", [])

# (Optional) If you don't want dev tools like mypy/pre-commit in Colab:
# drop every requirement string whose text starts with one of these prefixes.
deps = [d for d in deps if not d.startswith(("mypy", "pre-commit"))]

print("Dependencies from pyproject.toml:")
for d in deps:
    print("  -", d)

if deps:
    print("\nInstalling dependencies...")
    # Run pip through the current interpreter so the packages are installed
    # for this kernel; check_call raises CalledProcessError if pip fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *deps])
else:
    print("No dependencies field found under [project].")

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
import matplotlib.pyplot as plt


def plot_inventory_levels(env, model, n_steps=100, seed=42):
    """Roll out ``model`` in ``env`` and plot the first observation component.

    The policy is queried deterministically for ``n_steps`` environment steps.
    Whenever an episode ends (terminated or truncated) the environment is
    reset and the rollout continues, so the plot may span several episodes.
    Steps that produced a negative reward are highlighted in red.

    Args:
        env: Environment following the Gymnasium API, i.e. ``reset`` returns
            ``(obs, info)`` and ``step`` returns
            ``(obs, reward, terminated, truncated, info)``. ``obs[0]`` is
            plotted as the inventory level.
        model: Object exposing ``predict(obs, deterministic=True)`` that
            returns ``(action, state)``.
        n_steps: Number of environment steps to simulate (default 100).
        seed: Seed passed to the first ``env.reset`` call (default 42).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    inventory_levels = []
    rewards = []
    obs, info = env.reset(seed=seed)
    for _ in range(n_steps):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        # Record the level observed *after* the step, paired with its reward.
        inventory_levels.append(obs[0])
        rewards.append(reward)
        if terminated or truncated:
            # Later resets are unseeded so the RNG stream simply continues.
            obs, info = env.reset()

    inventory_levels = np.array(inventory_levels)
    time_steps = range(len(inventory_levels))

    plt.figure(figsize=(10, 6))
    plt.step(time_steps, inventory_levels, where="mid")
    plt.fill_between(time_steps, 0, inventory_levels, step="mid", alpha=0.2)

    # Mark every step whose reward was negative.
    neg_indices = [i for i, r in enumerate(rewards) if r < 0]
    plt.scatter(
        neg_indices,
        inventory_levels[neg_indices],
        marker="o",
        color="red",
        label="Negative Reward",
    )

    plt.title("Inventory Level per Time Step")
    plt.xlabel("Time Step")
    plt.ylabel("Inventory Level")
    # Without a legend the scatter label above would never be rendered.
    plt.legend()
    plt.show()


class MinimalInventoryEnv(gym.Env):
    """Minimal single-item inventory environment with noisy orders and lost sales.

    Observation: ``Box`` of shape ``(1,)`` holding the current inventory level
    in ``[0, max_inventory]``.
    Action: ``Discrete(max_order + 1)``, the requested order quantity.

    Each step: the requested order is perturbed by integer noise in
    {-1, 0, +1} and clipped to ``[0, max_order]``; an order that would exceed
    ``max_inventory`` is cut to the remaining capacity and penalised with -2.
    Demand is then drawn from ``demand_dist`` (with a 20% chance of a +2
    spike). Fully served demand earns +1; unmet demand is lost and costs -1.

    The episode ends after ``max_steps`` steps. Because that is a time limit
    rather than a terminal state of the inventory process, it is reported via
    ``truncated`` (``terminated`` is always ``False``), following the
    Gymnasium convention.

    Args:
        max_inventory: Storage capacity (default 10).
        max_steps: Episode length in steps (default 100).
        max_order: Largest order quantity that can be placed (default 5).
    """

    def __init__(self, max_inventory=10, max_steps=100, max_order=5):
        super().__init__()
        # Bounds passed to integers(low, high); high is exclusive, so the
        # default gives a uniform base demand of 0..3.
        self.demand_dist = [0, 4]
        self.max_steps = max_steps
        self.max_inventory = max_inventory
        # Single source of truth for the order cap, shared by the action
        # space and the clipping in step().
        self.max_order = max_order
        # 1D state: inventory level x_t ∈ [0, max_inventory]
        self.observation_space = spaces.Box(
            low=np.array([0], dtype=np.float32),
            high=np.array([self.max_inventory], dtype=np.float32),
        )
        # Discrete actions: order quantity {0, 1, ..., max_order}
        self.action_space = spaces.Discrete(self.max_order + 1)
        self.inv = 0
        self.step_count = 0

    def _get_obs(self):
        """Return the current inventory level as a float32 array of shape (1,)."""
        return np.array([self.inv], dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        """Start a new episode with a uniformly random initial inventory.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to (re)seed
                ``self.np_random``.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(obs, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.step_count = 0
        # Initial inventory is drawn from 0..max_inventory inclusive; stored
        # as a plain int so the state has one consistent type.
        self.inv = int(self.np_random.integers(0, self.max_inventory + 1))
        return self._get_obs(), {}

    def step(self, action):
        """Apply one order decision, then realise demand.

        Args:
            action: Requested order quantity (anything ``int()`` accepts).

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``terminated`` is
            always ``False`` and ``truncated`` becomes ``True`` once
            ``max_steps`` steps have been taken.
        """
        self.step_count += 1
        reward = 0.0

        # Ordering: the delivered quantity is the request plus -1/0/+1 noise.
        noise = int(self.np_random.integers(-1, 2))
        order_qty = int(np.clip(int(action) + noise, 0, self.max_order))
        if self.inv + order_qty > self.max_inventory:
            # Only the remaining capacity can be stocked. The 2.0 penalty
            # presumably mirrors the 2-point gap between a served (+1) and a
            # lost-sales (-1) step.
            order_qty = self.max_inventory - self.inv
            reward -= 2.0
        self.inv += order_qty

        # Demand: base draw plus an occasional spike of +2 (probability 0.2).
        low, high = self.demand_dist
        demand = int(self.np_random.integers(low, high))
        if self.np_random.random() < 0.2:
            demand += 2

        if demand > self.inv:
            demand = self.inv  # lost sales: only what is on hand is sold
            reward -= 1.0
        else:
            reward += 1.0
        self.inv -= demand

        # Hitting the step limit is a time-limit truncation, not a terminal
        # state, so learners may still bootstrap from the final observation.
        terminated = False
        truncated = self.step_count >= self.max_steps
        return self._get_obs(), reward, terminated, truncated, {}

In [None]:
# Validate the custom environment against the Gymnasium API before training.
env = MinimalInventoryEnv()
check_env(env, warn=True)  # will print warnings if something is off

# Quick baseline on a single environment. The agent is seeded so that this
# run is reproducible, consistent with the seeded vectorised run below.
model = PPO(
    policy="MlpPolicy",
    env=env,
    verbose=1,
    seed=42,
)

model.learn(total_timesteps=10_000, progress_bar=True)
# Save and load your trained model
# model.save("ppo_inventory")
# model = PPO.load("ppo_inventory", env=env)

# Visualise a deterministic rollout of the freshly trained policy.
plot_inventory_levels(env, model)

In [None]:
# Longer training run on 4 parallel copies of the environment, logged to
# TensorBoard.
#
# Optional network-architecture overrides (currently disabled, and not passed
# to PPO below):
# policy_kwargs = dict(
#     net_arch=dict(
#         pi=[64, 64],      # policy network
#         vf=[64, 64]       # value function network
#     ))
# policy_kwargs = dict(net_arch=[64, 64])
vec_env = make_vec_env(MinimalInventoryEnv, n_envs=4)
%load_ext tensorboard
# Rebinds `model`: from here on it refers to the agent built on vec_env.
model = PPO(
    "MlpPolicy", vec_env, verbose=0, tensorboard_log="./ppo_tensorboard/", seed=42
)
# The TensorBoard UI is launched before training starts, pointing at the same
# log directory that PPO writes to.
%reload_ext tensorboard
%tensorboard --logdir ./ppo_tensorboard/ --port 6006

model.learn(total_timesteps=100_000, progress_bar=False)

# NOTE: `env` is the single (non-vectorised) environment created in the
# previous cell; that cell must have been run before this line.
plot_inventory_levels(env, model)