# Environment (with discrete state space)

In [3]:
import gymnasium as gym
import numpy as np
from scipy.stats import gamma, norm, expon

class SteelProductionEnv(gym.Env):
    """Simulated two-workstation line coupled by a finite buffer.

    Workstation 1 (WS1) produces with ``n`` hammers whose wear grows by
    gamma-distributed increments every step.  Workstation 2 (WS2) draws a
    normally distributed demand and has an exponentially distributed lifetime.

    Observation
        ``(Pt_state, bt_state, dt_state)`` as computed by :meth:`get_state`.
    Action
        ``1`` requests preventive maintenance on WS1; any other value requests
        nothing.  The rule-based interventions in :meth:`step` take priority
        over the requested action.
    Reward
        Negative cost of the step: maintenance charges, unmet-demand penalty
        and a holding charge proportional to the buffer level.

    Every simulation step advances ``time`` by exactly one unit.  The
    maintenance routines compute and return a duration, but :meth:`step` does
    not use it.
    """

    # Wear level from which a hammer is replaced during any WS1 maintenance.
    REPLACEMENT_THRESHOLD = 0.2

    def __init__(self):
        super().__init__()

        # State and action spaces
        self.observation_space = gym.spaces.Tuple((
            gym.spaces.Discrete(4),   # production-rate level
            gym.spaces.Discrete(5),   # buffer level
            gym.spaces.Discrete(2),   # demand present / absent
        ))
        self.action_space = gym.spaces.Discrete(2)

        # First workstation
        self.n = 20        # number of hammers
        self.h1 = 800      # buffer level at/above which the pace switches to 'T2'
        self.h2 = 200      # buffer level at/below which the pace switches to 'T1'
        self.ci = 50       # charged for every maintenance intervention
        self.cf = 3000     # extra charge for a WS1 corrective maintenance
        self.cu = 100      # charge per unit of unmet demand
        self.cr = 230      # charge per replaced hammer
        self.cs = 0.01     # charge per unit held in the buffer, per step

        # Hammer degradation: gamma shape/scale per production pace
        self.alpha_T1 = 1
        self.beta_T1 = 0.1
        self.alpha_T2 = 0.5
        self.beta_T2 = 0.12

        # Production: multiplier applied under pace 'T1' / 'T2'
        self.c1 = 1.0
        self.c2 = 0.8

        # Maintenance parameters (only used to compute returned durations)
        self.Tp = 4        # fixed part of a WS1 maintenance duration
        self.Tu = 1        # additional duration per replaced hammer
        self.Ta_mean = 10  # mean of the exponential extra delay of a WS1 CM

        # Second workstation
        self.demand_mean = 15
        self.demand_std = 1
        self.failure_rate = 1/336
        self.cm_duration_mean_ws2 = 3
        self.pm_duration_ws2 = 2

        # Buffer
        self.buffer_capacity = 1000

        # Episode length: ``step`` reports termination once ``time`` reaches it.
        self.max_steps = 1000

        # Only incremented by ``cm_perform_ws1``; ``reset`` leaves it untouched,
        # so it keeps counting across episodes.
        self.cumulative_cm = 0

        self._reset_episode_state()

    def _reset_episode_state(self):
        """Put every per-episode quantity back to its start-of-episode value."""
        self.hammer_degradation = np.zeros(self.n)
        self.total_replaced_hammers = 0
        self.last_num_defective_hammers = 0
        self.production_pace = 'T1'

        self.time = 0
        self.last_pm_time = 0
        self.last_cm_time = 0
        self.first_ws_stopped = False
        self.second_ws_stopped = False

        self.WS2_age = 0
        self.is_down = False
        self.unmet_demand_during_m = 0
        self.current_bt = 0

        self.current_Pt = self.c1 * max(0, norm.rvs(loc=1.0, scale=0.1)) * self.n
        self.current_dt = norm.rvs(loc=self.demand_mean, scale=self.demand_std)
        self.lifetime = expon.rvs(scale=1/self.failure_rate)
        self.current_rul = self.lifetime

        self.current_state_Pt, self.current_state_bt, self.current_state_dt = self.get_state(
            self.current_Pt, self.current_bt, self.current_dt
        )

    def get_production_rate(self):
        """Sample the WS1 production rate for the current hammer wear and pace."""
        rho_t = max(0, norm.rvs(loc=1.0, scale=0.1))
        wear = self.hammer_degradation
        # The first matching wear band decides a hammer's contribution.
        hammer_contributions = np.select(
            [wear < 0.2, wear < 0.3, wear < 0.4],
            [1.0, 0.8, 0.6],
            default=0.0,
        )
        pace_factor = self.c1 if self.production_pace == 'T1' else self.c2
        return pace_factor * rho_t * np.sum(hammer_contributions)

    def get_demand(self):
        """Return 0 while WS2 is stopped, otherwise a normally distributed demand."""
        if self.second_ws_stopped:
            return 0
        return norm.rvs(loc=self.demand_mean, scale=self.demand_std)

    def get_rul(self):
        """Return the sampled WS2 lifetime minus its current age."""
        return self.lifetime - self.WS2_age

    def get_buffer_level(self, Pt, bt, dt):
        """Return ``bt + Pt - dt`` clamped to ``[0, buffer_capacity]``."""
        new_bt = bt + Pt - dt
        return max(0, min(new_bt, self.buffer_capacity))

    def get_unmet_demand(self, Pt, bt, dt):
        """Return the part of ``dt`` that production ``Pt`` plus stock ``bt`` cannot cover."""
        if dt > Pt + bt:
            return dt - Pt - bt
        return 0

    def get_production_pace(self, bt):
        """Update and return the pace: 'T2' at ``bt >= h1``, 'T1' at ``bt <= h2``.

        Between the two thresholds the previous pace is kept.
        """
        if bt >= self.h1:
            self.production_pace = 'T2'
        elif bt <= self.h2:
            self.production_pace = 'T1'
        return self.production_pace

    def get_state(self, Pt, bt, dt):
        """Discretise ``(Pt, bt, dt)`` into ``(Pt_state, bt_state, dt_state)``.

        ``Pt_state`` is 0 for the highest quarter of ``c1 * n`` down to 3 for
        the lowest, ``bt_state`` counts how many of the levels 200/400/600/800
        the buffer has reached, and ``dt_state`` is 1 when demand is positive.
        """
        Pt_state = 3
        for level, fraction in enumerate((0.75, 0.5, 0.25)):
            if Pt >= fraction * self.c1 * self.n:
                Pt_state = level
                break

        bt_state = 0
        for level, threshold in zip((4, 3, 2, 1), (800, 600, 400, 200)):
            if bt >= threshold:
                bt_state = level
                break

        dt_state = 1 if dt > 0 else 0
        return Pt_state, bt_state, dt_state

    def update_hammer_degradation(self):
        """Add one gamma-distributed wear increment per hammer, capped at 1.0."""
        if self.production_pace == 'T1':
            alpha, beta = self.alpha_T1, self.beta_T1
        else:
            alpha, beta = self.alpha_T2, self.beta_T2

        # One vectorised draw instead of ``n`` separate scipy calls; the update
        # is done in place so existing references to the array stay valid.
        self.hammer_degradation += gamma.rvs(a=alpha, scale=beta, size=self.n)
        np.minimum(self.hammer_degradation, 1.0, out=self.hammer_degradation)

    def _replace_worn_hammers(self):
        """Reset every hammer at/above the replacement threshold; return how many."""
        worn = self.hammer_degradation >= self.REPLACEMENT_THRESHOLD
        self.last_num_defective_hammers = np.sum(worn)
        self.hammer_degradation[worn] = 0.0
        return self.last_num_defective_hammers

    def pm_perform_ws1_action1(self):
        """Agent-requested WS1 maintenance; at least three replacements are billed.

        Returns:
            ``(maintenance_duration, replacement_cost)``.
        """
        replaced = self._replace_worn_hammers()
        replacement_cost = self.cr * max(3, replaced)
        maintenance_duration = self.Tp + replaced * self.Tu
        self.last_pm_time = self.time
        return maintenance_duration, replacement_cost

    def pm_perform_ws1(self):
        """Rule-triggered WS1 preventive maintenance.

        Returns:
            ``(maintenance_duration, replacement_cost)``.
        """
        replaced = self._replace_worn_hammers()
        replacement_cost = self.cr * replaced
        maintenance_duration = self.Tp + replaced * self.Tu
        self.last_pm_time = self.time
        return maintenance_duration, replacement_cost

    def cm_perform_ws1(self):
        """WS1 corrective maintenance; the duration includes an exponential delay.

        Returns:
            ``(maintenance_duration, replacement_cost)``.
        """
        replaced = self._replace_worn_hammers()
        replacement_cost = self.cr * replaced
        maintenance_duration = (
            self.Tp + replaced * self.Tu + expon.rvs(scale=self.Ta_mean)
        )
        self.last_cm_time = self.time
        self.cumulative_cm += 1
        return maintenance_duration, replacement_cost

    def cm_perform_ws2(self):
        """Renew WS2: sample a repair duration and a fresh lifetime, reset its age.

        Returns:
            The sampled repair duration.
        """
        cm_duration = expon.rvs(scale=self.cm_duration_mean_ws2)
        self.lifetime = expon.rvs(scale=1/self.failure_rate)
        self.WS2_age = 0
        self.current_rul = self.lifetime
        return cm_duration

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Optional seed.  When given it seeds both Gymnasium's RNG and
                NumPy's global RNG, which the scipy ``rvs`` calls draw from.
                When ``None`` the global RNG is left alone so a seed set
                earlier by the caller is not discarded.
            options: Accepted for API compatibility; unused.
        """
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)

        # Also clears the buffer level and WS2 age, which would otherwise leak
        # from the previous episode into this one.
        self._reset_episode_state()

        return (self.current_state_Pt, self.current_state_bt, self.current_state_dt), {}

    def _stop_both_workstations(self):
        """Flag both workstations as stopped for the current step."""
        self.first_ws_stopped = True
        self.second_ws_stopped = True

    def _apply_maintenance(self, action):
        """Run at most one maintenance rule and return the cost it incurs.

        The rules are tried in priority order; the agent's ``action`` is only
        honoured when no rule fires.
        """
        # 1) Production state 3: corrective maintenance on WS1.
        if self.current_state_Pt == 3:
            _, replacement_cost = self.cm_perform_ws1()
            self.first_ws_stopped = True
            if self.current_bt == 0:
                # Nothing in stock either, so WS2 stops as well.
                self.second_ws_stopped = True
                if self.current_rul <= 0:
                    self.cm_perform_ws2()
            return self.cf + self.ci + replacement_cost

        # 2) Empty buffer after the first step: opportunistic maintenance.
        if self.time > 0 and self.current_bt <= 0:
            _, replacement_cost = self.pm_perform_ws1()
            if self.current_rul <= 0:
                self.cm_perform_ws2()
            self._stop_both_workstations()
            return self.ci + replacement_cost

        # 3) Buffer above capacity.  ``get_buffer_level`` clamps the level to
        #    the capacity, so this rule cannot fire with the current dynamics.
        if self.time > 0 and self.current_bt > self.buffer_capacity:
            _, replacement_cost = self.pm_perform_ws1()
            self.first_ws_stopped = True
            return self.ci + replacement_cost

        # 4) WS2 out of remaining life, or every 24th time step: maintain both.
        ws2_failed = self.current_rul <= 0
        if ws2_failed or (self.time > 0 and int(self.time) % 24 == 0):
            if ws2_failed:
                self.cm_perform_ws2()
            _, replacement_cost = self.pm_perform_ws1()
            self._stop_both_workstations()
            return self.ci + replacement_cost

        # 5) Maintenance requested by the agent.  The buffer cannot be empty
        #    here (rule 2 would have fired), so WS2 keeps running.
        if action == 1:
            _, replacement_cost = self.pm_perform_ws1_action1()
            self.first_ws_stopped = True
            return self.ci + replacement_cost

        return 0

    def step(self, action):
        """Advance the simulation by one time unit.

        Args:
            action: ``1`` to request WS1 preventive maintenance.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``terminated`` becomes ``True`` once ``time`` reaches
            ``max_steps`` and ``truncated`` is always ``False``.
        """
        self.first_ws_stopped = False
        self.second_ws_stopped = False

        reward = -self._apply_maintenance(action)

        self.update_hammer_degradation()
        self.time += 1
        if not self.second_ws_stopped:
            self.WS2_age += 1

        buffer_before = self.current_bt
        self.current_Pt = self.get_production_rate()
        self.current_dt = self.get_demand()
        self.current_bt = self.get_buffer_level(
            self.current_Pt, buffer_before, self.current_dt
        )
        self.current_rul = self.get_rul()
        self.current_state_Pt, self.current_state_bt, self.current_state_dt = self.get_state(
            self.current_Pt, self.current_bt, self.current_dt
        )

        # The shortfall is measured against the stock available *before* this
        # step's flows; using the updated buffer would count the flows twice.
        unmet_demand = self.get_unmet_demand(
            self.current_Pt, buffer_before, self.current_dt
        )
        self.production_pace = self.get_production_pace(self.current_bt)
        reward -= self.cu * unmet_demand + self.cs * self.current_bt

        terminated = self.time >= self.max_steps
        truncated = False

        return (
            (self.current_state_Pt, self.current_state_bt, self.current_state_dt),
            reward,
            terminated,
            truncated,
            {}
        )


In [None]:
!pip install stable-baselines3 gymnasium

# DDQN

In [None]:
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
import matplotlib.pyplot as plt
from tqdm import tqdm
from stable_baselines3.common.callbacks import CallbackList
import gymnasium as gym

class TupleToArrayWrapper(gym.ObservationWrapper):
    """Observation wrapper that hands observations out as float32 vectors."""

    def __init__(self, env):
        super().__init__(env)
        # Advertised space: three components, each bounded by [0, 10].
        vector_space = gym.spaces.Box(
            low=0,
            high=10,
            shape=(3,),
            dtype=np.float32,
        )
        self.observation_space = vector_space

    def observation(self, obs):
        """Return ``obs`` converted to a new ``float32`` NumPy array."""
        as_vector = np.array(obs, dtype=np.float32)
        return as_vector

class DoubleDQN(DQN):
    """DQN whose bootstrap target uses the Double-DQN decoupling.

    The online network chooses the greedy next action and the target network
    supplies the value of that action.
    """

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` optimisation steps on replay-buffer batches.

        Args:
            gradient_steps: Number of optimiser updates to perform.
            batch_size: Number of transitions sampled per update.
        """
        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)

        online_net = self.policy.q_net
        target_net = self.policy.q_net_target
        optimizer = self.policy.optimizer
        step_losses = []

        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with torch.no_grad():
                # Action selection by the online network ...
                greedy_actions = online_net(batch.next_observations).argmax(
                    dim=1, keepdim=True
                )
                # ... action evaluation by the target network.
                target_values = target_net(batch.next_observations)
                bootstrap = torch.gather(target_values, 1, greedy_actions).reshape(-1, 1)

                td_target = batch.rewards + (1 - batch.dones) * self.gamma * bootstrap

            # Q-value of the action that was actually taken.
            q_all = online_net(batch.observations)
            q_taken = torch.gather(q_all, dim=1, index=batch.actions.long())

            loss = torch.nn.functional.smooth_l1_loss(q_taken, td_target)
            step_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            optimizer.step()

        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", float(np.mean(step_losses)))


class _EpsilonSchedule:
    """Exploration schedule that ignores training progress and returns ``value``."""

    def __init__(self, value):
        self.value = value

    def __call__(self, _progress_remaining):
        return self.value


class ExponentialDecayCallback(BaseCallback):
    """Multiply the model's exploration rate by ``decay_rate`` after every step.

    Stable-Baselines3's DQN re-assigns ``exploration_rate`` from
    ``exploration_schedule`` at the end of every environment step, after the
    callbacks have run, so writing ``exploration_rate`` alone is overwritten
    by the built-in linear schedule.  This callback therefore installs its own
    schedule object on the model and decays the value held by that object.

    Args:
        decay_rate: Multiplicative factor applied once per environment step.
        min_epsilon: Floor below which the exploration rate is not decayed.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, decay_rate=0.999, min_epsilon=0.05, verbose=0):
        super().__init__(verbose)
        self.decay_rate = decay_rate
        self.min_epsilon = min_epsilon
        self._schedule = None

    def _install_schedule(self) -> None:
        """Replace the model's exploration schedule with one this callback owns."""
        # A freshly built DQN has not applied its schedule yet, so start from
        # the configured initial epsilon when the model exposes it.
        start = getattr(
            self.model, "exploration_initial_eps", self.model.exploration_rate
        )
        self._schedule = _EpsilonSchedule(start)
        self.model.exploration_schedule = self._schedule
        self.model.exploration_rate = start

    def _on_training_start(self) -> None:
        self._install_schedule()

    def _on_step(self) -> bool:
        if self._schedule is None:
            # Defensive: the callback was stepped without a training start.
            self._install_schedule()

        if self._schedule.value > self.min_epsilon:
            self._schedule.value = max(
                self._schedule.value * self.decay_rate,
                self.min_epsilon
            )
        self.model.exploration_rate = self._schedule.value
        return True

class RewardTrackerCallback(BaseCallback):
    """Accumulate per-episode return and length for the first environment.

    Only index 0 of the vectorised ``rewards`` / ``dones`` arrays is read, so
    with several parallel environments the others are ignored.

    Attributes are plain lists: ``episode_rewards`` and ``episode_lengths``
    receive one entry each time environment 0 reports ``done``.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.current_episode_reward = 0
        self.episode_lengths = []
        self.current_episode_length = 0

    def _on_step(self) -> bool:
        reward = self.locals['rewards'][0]
        done = self.locals['dones'][0]

        self.current_episode_reward += reward
        self.current_episode_length += 1

        if done:
            self.episode_rewards.append(self.current_episode_reward)
            self.episode_lengths.append(self.current_episode_length)
            self.current_episode_reward = 0
            self.current_episode_length = 0

        return True

    def plot_rewards(self):
        """Plot the reward convergence"""
        if not self.episode_rewards:
            # ``np.convolve`` rejects empty input; nothing to draw anyway.
            print("No completed episodes recorded; nothing to plot.")
            return

        plt.figure(figsize=(12, 6))

        # Smoothing window of roughly 2% of the recorded episodes, at least 1.
        window_size = max(1, len(self.episode_rewards) // 50)
        moving_avg = np.convolve(
            self.episode_rewards,
            np.ones(window_size)/window_size,
            mode='valid'
        )

        plt.plot(self.episode_rewards, alpha=0.3, label='Episode Reward')
        # 'valid' mode drops the first ``window_size - 1`` points, so the
        # average is shifted to line up with the episode it ends on.
        plt.plot(
            range(window_size-1, window_size-1 + len(moving_avg)),
            moving_avg,
            'r-',
            linewidth=2,
            label=f'Moving Avg ({window_size} episodes)'
        )

        plt.title('Reward Convergence During Training')
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('reward_convergence.png')
        plt.show()

# Keyword arguments forwarded unchanged to the DoubleDQN constructor.
hyperparams = dict(
    learning_rate=0.003,
    buffer_size=1000,
    batch_size=64,
    tau=1.0,
    gamma=0.99,
    train_freq=(1, "step"),
    target_update_interval=10,
    exploration_final_eps=0.05,
    exploration_initial_eps=1.0,
    learning_starts=64,
    policy_kwargs=dict(
        net_arch=[32, 32],
        activation_fn=nn.ReLU,
    ),
    seed=42,
)


def _make_training_env():
    """Build one environment whose tuple observation is exposed as an array."""
    return TupleToArrayWrapper(SteelProductionEnv())


# Single environment wrapped in a DummyVecEnv.
env = make_vec_env(_make_training_env, n_envs=1, vec_env_cls=DummyVecEnv)

model = DoubleDQN("MlpPolicy", env, **hyperparams)

# Callbacks are invoked in list order on every environment step.
reward_tracker = RewardTrackerCallback()
decay_callback = ExponentialDecayCallback()
callback_list = CallbackList([decay_callback, reward_tracker])

total_steps = 1_000_000
model.learn(total_timesteps=total_steps, callback=callback_list, log_interval=100)

# Persist the trained agent, then show the learning curve.
model.save("DDQN")

reward_tracker.plot_rewards()

In [None]:
# discrete DDQN

env_discrete = make_vec_env(
    lambda: SteelProductionEnv(),
    n_envs=1,
    vec_env_cls=DummyVecEnv
)
# ``load`` is a classmethod that builds and returns a new model; calling it on
# an existing instance and discarding the result leaves that instance with its
# freshly initialised (untrained) weights.  Bind the returned model instead.
# The wrapped training environment ``env`` is attached, as before.
model_discrete = DoubleDQN.load("DDQN", env=env)


## Evaluation of the trained DDQN agent

In [None]:
# -----------------------------------------------
# Evaluation of the trained agents
#    with greedy action selection
# -----------------------------------------------

n = 20
c1 = 1

# Evaluation runs directly on an unwrapped, non-vectorised environment.
env_discrete = SteelProductionEnv()

eval_episodes = 500
eval_epsilon  = 0
eval_rewards_ddqn_discrete = []


# Discrete DDQN evaluation: one greedy rollout per episode, summing rewards.
for _ in range(eval_episodes):
    obs, _ = env_discrete.reset()
    ep_reward_discrete = 0
    done = False
    truncated = False

    # Roll forward until the environment signals either kind of episode end.
    while not (done or truncated):
        action, _ = model_discrete.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env_discrete.step(action)
        ep_reward_discrete += reward

    eval_rewards_ddqn_discrete.append(ep_reward_discrete)


# Summary statistics over all evaluation episodes.
print("DDQN (discrete) results:")
print(f"ε = {eval_epsilon}")
print(f"Average reward from the discrete DDQN trained agent (thesis) during {eval_episodes} episodes: {np.mean(eval_rewards_ddqn_discrete):.2f}")
print(f"SD reward from the DDQN trained agent (thesis) during {eval_episodes} episodes: {np.std(eval_rewards_ddqn_discrete):.2f}")
print(f"Max episode reward: {np.max(eval_rewards_ddqn_discrete):.2f}")
print(f"Min episode reward: {np.min(eval_rewards_ddqn_discrete):.2f}")
