In [1]:
import gymnasium as gym
from stable_baselines3 import PPO
from app.envs.trading_env import StockTradingEnv
import pandas as pd

In [2]:
# Load your enriched dataset
df = pd.read_csv("../data/processed/final_dataset.csv")
df = df[df["Ticker"] == "AADR"].reset_index(drop=True)

In [3]:
# Initialize environment
env = StockTradingEnv(df)

# Wrap environment to make it SB3-compatible
env = gym.wrappers.FlattenObservation(env)

In [4]:
# Define PPO agent
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=3e-4,
    gamma=0.99,
    n_steps=128,
    ent_coef=0.01,
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [5]:
# Train agent
model.learn(total_timesteps=20_000)

----------------------------
| time/              |     |
|    fps             | 186 |
|    iterations      | 1   |
|    time_elapsed    | 0   |
|    total_timesteps | 128 |
----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 168          |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 256          |
| train/                  |              |
|    approx_kl            | 0.0029227962 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.1         |
|    explained_variance   | 3.24e-05     |
|    learning_rate        | 0.0003       |
|    loss                 | 1.15e+05     |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0102      |
|    value_loss           | 2.07e+05     |
------------------------------------------
-----------------------

<stable_baselines3.ppo.ppo.PPO at 0x10c1b3110>

In [7]:
# Save model
model.save("../models/ppo_trading_agent")
print("Model saved to ../models/ppo_trading_agent.zip")

Model saved to ../models/ppo_trading_agent.zip


In [8]:
#evaluate the trained agent
model = PPO.load("../models/ppo_trading_agent")
obs, _ = env.reset()

for step in range(10):
    action, _ = model.predict(obs)
    obs, reward, done, _, _ = env.step(action)
    env.render()
    if done:
        break


Step: 1, Net Worth: 10000.00, Balance: 10000.00, Shares: 0
Step: 2, Net Worth: 10000.00, Balance: 15.52, Shares: 183.0
Step: 3, Net Worth: 10483.12, Balance: 15.52, Shares: 183.0
Step: 4, Net Worth: 10614.88, Balance: 15.52, Shares: 183.0
Step: 5, Net Worth: 10671.61, Balance: 15.52, Shares: 183.0
Step: 6, Net Worth: 10486.78, Balance: 15.52, Shares: 183.0
Step: 7, Net Worth: 10618.54, Balance: 15.52, Shares: 183.0
Step: 8, Net Worth: 10247.05, Balance: 15.52, Shares: 183.0
Step: 9, Net Worth: 10581.94, Balance: 15.52, Shares: 183.0
Step: 10, Net Worth: 10556.32, Balance: 15.52, Shares: 183.0
