In [None]:
import numpy as np
from gym_unbalanced_disk import UnbalancedDisk
import gymnasium as gym
from gymnasium import spaces

class AC_UnbalancedDisk(UnbalancedDisk):
    """Unbalanced-disk environment with a shaped swing-up reward.

    Subclasses ``UnbalancedDisk`` and replaces the reward returned by the
    parent's ``step`` with a piecewise function of the wrapped angle and the
    angular velocity. The observation, termination flags and info dict from
    the parent are passed through unchanged.
    """

    # Number of most recent angular velocities kept for stall detection.
    STALL_WINDOW = 10
    # A sample counts as "stalled" when |omega| is below this value.
    STALL_OMEGA = 0.005

    def __init__(self, umax=3., dt=0.025, render_mode='human'):
        """Create the environment and declare a 2-D Box observation space.

        Args:
            umax: Forwarded to ``UnbalancedDisk.__init__``.
            dt: Forwarded to ``UnbalancedDisk.__init__``.
            render_mode: Forwarded to ``UnbalancedDisk.__init__``.
        """
        super().__init__(umax=umax, dt=dt, render_mode=render_mode)

        self.target = np.pi

        # Declared bounds for the two observation entries (angle, velocity).
        low = np.array([-np.pi, -40], dtype=np.float32)
        high = np.array([np.pi, 40], dtype=np.float32)
        self.observation_space = spaces.Box(low=low, high=high, shape=(2,))

        # Sliding window of recent angular velocities (oldest first).
        self.recent_omegas = []

    def step(self, action):
        """Advance the parent environment and compute the shaped reward.

        Args:
            action: Passed unchanged to ``UnbalancedDisk.step``.

        Returns:
            The tuple ``(obs, reward, terminated, truncated, info)`` where
            everything except ``reward`` comes straight from the parent.
        """
        # The parent's own reward is discarded and replaced below.
        obs, _, terminated, truncated, info = super().step(action)

        th = obs[0]
        omega = obs[1]

        # Wrap the angle into [-pi, pi): th = 0 stays at 0 while th = pi maps
        # to -pi, so |theta| approaches pi as the disk nears the target.
        theta = ((th - np.pi) % (2 * np.pi)) - np.pi
        abs_theta = abs(theta)
        abs_omega = abs(omega)

        # Keep only the last STALL_WINDOW angular velocities.
        self.recent_omegas.append(omega)
        if len(self.recent_omegas) > self.STALL_WINDOW:
            self.recent_omegas.pop(0)

        # The branches use cumulative upper bounds so that the exact boundary
        # values pi/2 and 3*pi/4 belong to the band that starts there instead
        # of falling through to the last branch.
        if abs_theta < np.pi / 2:
            # Far from the target: always negative, less so at higher speed,
            # capped at -0.5.
            reward = min(-0.5, -5 + abs_omega)
        elif abs_theta < 3 * np.pi / 4:
            reward = abs_theta**2 / (1 + abs_omega)
        else:
            reward = abs_theta**4 / (1 + abs_omega)**2

            # Anti-stall: inside the band below 11*pi/12, if every velocity
            # in the window is near zero, fall back to a much smaller reward.
            in_stall_band = abs_theta < 11 * np.pi / 12
            if in_stall_band and all(
                abs(w) < self.STALL_OMEGA for w in self.recent_omegas
            ):
                reward = 1 / (1 + abs_omega)

        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the parent environment and clear the stall-detection window.

        Args:
            seed: Accepted for Gymnasium API compatibility; not forwarded.
            options: Accepted for Gymnasium API compatibility; not forwarded.

        Returns:
            The ``(obs, info)`` pair returned by ``UnbalancedDisk.reset``.
        """
        # NOTE(review): seed/options are not passed to the parent because its
        # reset signature is not visible here - confirm before forwarding.
        obs, info = super().reset()
        self.recent_omegas = []
        return obs, info


class DQN_UnbalancedDisk(AC_UnbalancedDisk):
    """Discrete-action variant of ``AC_UnbalancedDisk``.

    Exposes a ``Discrete(n_actions)`` action space; each action index selects
    one of ``n_actions`` evenly spaced torque levels in ``[-umax, umax]``,
    which is then handed to the parent's ``step``. ``reset`` is inherited.
    """

    def __init__(self, umax=3., dt=0.025, n_actions=10, render_mode='human'):
        """Create the environment and its discrete torque table.

        Args:
            umax: Largest torque magnitude; also forwarded to the parent.
            dt: Forwarded to the parent constructor.
            n_actions: Number of discrete torque levels.
            render_mode: Forwarded to the parent constructor.
        """
        super().__init__(umax=umax, dt=dt, render_mode=render_mode)

        # Torque table: action index i applies self.actions[i].
        self.actions = np.linspace(-umax, umax, n_actions)

        # Override the parent's action space with a discrete one.
        self.action_space = spaces.Discrete(n_actions)

    def step(self, action):
        """Apply the torque selected by a discrete action index.

        Args:
            action: Index into ``self.actions``; a Python int, a NumPy
                integer scalar, or a size-1 array.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple produced
            by ``AC_UnbalancedDisk.step``.

        Raises:
            ValueError: If the index is outside ``[0, len(self.actions))``.
        """
        # The action is an index of the Discrete space, not a torque value,
        # so it is looked up directly rather than matched to a nearest torque.
        idx = int(np.asarray(action).item())
        if not 0 <= idx < len(self.actions):
            raise ValueError(
                f"action index {idx} out of range for "
                f"{len(self.actions)} discrete actions"
            )

        # Build the 1-D array the parent expects.
        discrete_action = np.array([self.actions[idx]], dtype=np.float32)
        return super().step(discrete_action)

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from gymnasium.wrappers import TimeLimit

# Build the training environment: the discrete-action disk, limited to
# 500 steps per episode by TimeLimit, then wrapped in Monitor.
env = Monitor(
    TimeLimit(
        DQN_UnbalancedDisk(umax=3.0, dt=0.025),
        max_episode_steps=500,
    )
)

# Configure the DQN agent, then train it for 500k environment steps.
model_dqn = DQN(
    policy="MlpPolicy",
    env=env,
    verbose=1,
    # Optimisation settings
    learning_rate=1e-3,
    batch_size=64,
    gamma=0.99,
    # Replay buffer and update schedule
    buffer_size=50_000,
    learning_starts=1_000,
    train_freq=4,
    # Target network
    tau=1.0,
    target_update_interval=1_000,
)
model_dqn.learn(total_timesteps=500_000)
