In [44]:
from apad_env import APADEnv
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback

In [45]:
import time

class TimerCallback(BaseCallback):
    """Periodically print elapsed wall-clock time and an ETA during ``learn()``.

    The clock is (re)started at the beginning of each ``learn()`` call rather
    than at construction. The throughput estimate only counts timesteps taken
    since that call started. As a result, the ETA stays meaningful when:
    - the callback is built well before training begins, or
    - the model already has timesteps from an earlier ``learn()`` call.

    Args:
        print_freq: Report each time ``num_timesteps`` reaches or passes a new
            multiple of ``print_freq``. Using "reaches or passes" (rather than
            an exact modulo match) means a report still fires when timesteps
            advance by more than one per step.
        verbose: Verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: If ``print_freq`` is not positive.
    """

    def __init__(self, print_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        if print_freq <= 0:
            raise ValueError(f"print_freq must be positive, got {print_freq}")
        self.print_freq = print_freq
        # Provisional values; _on_training_start resets them when learn() begins.
        self.start_time = time.time()
        self._start_timesteps = 0
        self._next_report = print_freq

    def _on_training_start(self) -> None:
        """Reset the clock and timestep baseline at the start of ``learn()``."""
        self.start_time = time.time()
        self._start_timesteps = self.num_timesteps
        # First multiple of print_freq strictly above the starting count.
        self._next_report = (self.num_timesteps // self.print_freq + 1) * self.print_freq

    def _on_step(self) -> bool:
        """Print progress when a report threshold is crossed; never stops training."""
        if self.num_timesteps < self._next_report:
            return True
        # Advance to the next multiple strictly above the current count.
        self._next_report = (self.num_timesteps // self.print_freq + 1) * self.print_freq

        elapsed = time.time() - self.start_time
        # Always >= 1 here, because _next_report starts above _start_timesteps.
        steps_done = self.num_timesteps - self._start_timesteps
        steps_left = max(self.locals['total_timesteps'] - self.num_timesteps, 0)
        # Equivalent to steps_left / (steps_done / elapsed). Written this way so
        # that an elapsed time of zero cannot cause a division by zero.
        remaining = elapsed * steps_left / steps_done
        print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

# Train

In [46]:
# Training configuration. Named here so a re-run can be tuned in one place.
SEED = 0                  # fixed seed so re-runs are reproducible
TOTAL_TIMESTEPS = 50_000

env = APADEnv()
model = DQN(
    "MlpPolicy",
    env,
    exploration_initial_eps=1.0,    # start with 100% random actions
    exploration_final_eps=0.1,      # end with 10% random actions
    exploration_fraction=0.5,       # decay epsilon over the first half of training
    learning_rate=1e-3,             # 10x the SB3 DQN default (1e-4)
    seed=SEED,
    verbose=1,
)

# Trailing ';' suppresses the returned model's repr in the cell output.
model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=TimerCallback());

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Step 1000, 1s elapsed, 25s remaining
Step 2000, 1s elapsed, 24s remaining
Step 3000, 2s elapsed, 24s remaining
Step 4000, 2s elapsed, 23s remaining
Step 5000, 3s elapsed, 23s remaining
Step 6000, 3s elapsed, 22s remaining
Step 7000, 4s elapsed, 22s remaining
Step 8000, 4s elapsed, 21s remaining
Step 9000, 5s elapsed, 21s remaining
Step 10000, 5s elapsed, 20s remaining
Step 11000, 6s elapsed, 20s remaining
Step 12000, 6s elapsed, 20s remaining
Step 13000, 7s elapsed, 19s remaining
Step 14000, 7s elapsed, 19s remaining
Step 15000, 8s elapsed, 18s remaining
Step 16000, 8s elapsed, 18s remaining
Step 17000, 9s elapsed, 17s remaining
Step 18000, 10s elapsed, 17s remaining
Step 19000, 10s elapsed, 16s remaining
Step 20000, 11s elapsed, 16s remaining
Step 21000, 11s elapsed, 15s remaining
Step 22000, 12s elapsed, 15s remaining
Step 23000, 12s elapsed, 14s remaining
Step 24000, 13s elapsed, 14s remain

<stable_baselines3.dqn.dqn.DQN at 0x355263c10>

In [47]:
# Persist the trained DQN agent to disk so it can be reloaded without retraining.
model.save(path="apad_dqn_model")