In [None]:
from tqdm import tqdm
import random

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import gymnasium as gym
import quantstats as qs
# import gym_anytrading

import sys
sys.path.append("C:/Users/WilliamFetzner/Documents/Trading/gym-anytrading/gym_anytrading")
import importlib
%load_ext autoreload
%autoreload 2
from envs import MyForexEnv, Actions
# sys.path.append("C:/Users/WilliamFetzner/Documents/Trading/gym-anytrading/")
# from datasets import FOREX_EURUSD_RENKO
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.callbacks import BaseCallback
import torch

In [None]:
import os

# Label for the dataset loaded in this cell.
name = 'FOREX_EURUSD_RENKO'

# NOTE(review): absolute, machine-specific location of the Renko CSV; a
# notebook has no __file__, so the path cannot be derived from it here.
path = "C:/Users/WilliamFetzner/Documents/Trading/gym-anytrading/gym_anytrading/datasets/data/renko_full_data_81.csv"

# Read the CSV using its 'datetime' column as the index; parse_dates=True asks
# pandas to parse that index as dates.
FOREX_EURUSD_RENKO = pd.read_csv(
    path,
    index_col='datetime',
    parse_dates=True,
)

In [None]:
def print_stats(reward_over_episodes):
    """Print and return the minimum, mean and maximum of a set of rewards.

    Args:
        reward_over_episodes: array-like of numeric rewards accepted by
            ``np.min`` / ``np.mean`` / ``np.max``.

    Returns:
        A ``(min, avg, max)`` tuple of the three statistics.
    """
    # Local names deliberately avoid shadowing the builtins `min` and `max`.
    lowest_reward = np.min(reward_over_episodes)
    mean_reward = np.mean(reward_over_episodes)
    highest_reward = np.max(reward_over_episodes)

    print (f'Min. Reward          : {lowest_reward:>10.3f}')
    print (f'Avg. Reward          : {mean_reward:>10.3f}')
    print (f'Max. Reward          : {highest_reward:>10.3f}')

    return lowest_reward, mean_reward, highest_reward


# ProgressBarCallback for model.learn()
class ProgressBarCallback(BaseCallback):
    """Show a tqdm progress bar while ``model.learn()`` is running.

    The bar is refreshed every ``check_freq`` callback invocations. On each
    refresh it is moved to the callback's current ``num_timesteps`` rather than
    advanced by a fixed amount, so it stays accurate when a single callback
    call covers more than one environment step and when the number of steps
    is not a multiple of ``check_freq``.

    Args:
        check_freq: refresh the bar every this many callback calls (> 0).
        verbose: verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: if ``check_freq`` is not strictly positive.
    """

    def __init__(self, check_freq: int, verbose: int = 1):
        super().__init__(verbose)
        if check_freq <= 0:
            # Fail early: 0 would raise ZeroDivisionError on the first step and
            # a negative value would push the bar backwards.
            raise ValueError(f"check_freq must be a positive integer, got {check_freq!r}")
        self.check_freq = check_freq
        # Created in _on_training_start; None until training begins.
        self.progress_bar = None

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        self.progress_bar = tqdm(total=self.model._total_timesteps, desc="model.learn()")

    def _sync_progress_bar(self) -> None:
        """Advance the bar so its counter matches the timesteps done so far."""
        pending = self.num_timesteps - self.progress_bar.n
        if pending > 0:
            self.progress_bar.update(pending)

    def _on_step(self) -> bool:
        """Refresh the bar every ``check_freq`` calls; always continue training."""
        if self.n_calls % self.check_freq == 0:
            self._sync_progress_bar()
        return True

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        if self.progress_bar is not None:
            # Flush the steps taken since the last periodic refresh so the
            # final count is not left short before the bar is closed.
            self._sync_progress_bar()
            self.progress_bar.close()

In [None]:
# Training environment: frame_bound runs from row 10 up to the 80% mark of the
# dataset; the remaining rows are left for the test environment.
train_end_index = int(len(FOREX_EURUSD_RENKO) * .8)
train_frame_bound = (10, train_end_index)

env_train = MyForexEnv(
    df=FOREX_EURUSD_RENKO,
    window_size=10,
    frame_bound=train_frame_bound,
    trade_fee=0.0001,
    unit_side='right',
    sma_length=4,
    smoothing_sma=4,
)

In [None]:
# One seed shared by the environment reset and every RNG seeded below.
seed_ppo = 42
obs_ppo, info_ppo = env_train.reset(seed=seed_ppo)
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(seed_ppo)

vec_env_ppo = None

# Training budget, in timesteps, passed to model.learn().
total_learning_timesteps_ppo = 5_000_000

# Look the policy up by alias: use 'MlpPolicy' and fall back to
# 'MlpLstmPolicy' only when the first alias is not registered.
policy_dict_ppo = PPO.policy_aliases
policy_ppo = policy_dict_ppo.get('MlpPolicy')
policy_ppo = policy_dict_ppo.get('MlpLstmPolicy') if policy_ppo is None else policy_ppo

model_ppo = PPO(policy_ppo, env_train, verbose=0)

for label, obj in (('model', model_ppo), ('policy', model_ppo.policy)):
    print(f'{label} {type(obj)}')

# Train with the custom progress-bar callback, refreshed every 100 calls.
progress_callback = ProgressBarCallback(100)
model_ppo.learn(
    total_timesteps=total_learning_timesteps_ppo,
    callback=progress_callback,
)

env_train.close()

In [None]:
# Test environment: frame_bound runs from the 80% mark (where the training
# frame ends) up to the 99% mark of the dataset.
test_start_index = int(len(FOREX_EURUSD_RENKO) * .8)
test_end_index = int(len(FOREX_EURUSD_RENKO) * .99)

env_test = MyForexEnv(
    df=FOREX_EURUSD_RENKO,
    window_size=10,
    frame_bound=(test_start_index, test_end_index),
    trade_fee=0.0001,
    unit_side='right',
    sma_length=4,
    smoothing_sma=4,
)

In [None]:
# Run the trained PPO policy through one full episode of the test environment,
# counting how many times each action is chosen.
action_stats = {Actions.Sell: 0, Actions.Buy: 0}

# Start from the TEST environment's first observation. The reset result must be
# bound to `obs_ppo`, the name fed to predict() below; otherwise the first
# action would be computed from the stale observation left over from
# `env_train.reset(...)` in the training cell.
obs_ppo, info_ppo = env_test.reset(seed=seed_ppo)
observation = obs_ppo  # keep the previously exported name bound to the initial observation

done = False
while not done:
    action_ppo, _states = model_ppo.predict(obs_ppo)
    action_stats[Actions(action_ppo)] += 1
    obs_ppo, reward_ppo, terminated, truncated, info_ppo = env_test.step(action_ppo)
    # The episode is over when the env reports either termination or truncation.
    done = terminated or truncated

env_test.close()

print("action_stats:", action_stats)
print("info:", info_ppo)

In [None]:
# Draw the whole test episode on one wide figure via the underlying
# (unwrapped) environment's render_all().
render_figsize = (16, 6)
plt.figure(figsize=render_figsize)
env_test.unwrapped.render_all()
plt.show()

In [None]:
# Build a returns series from the test episode's recorded total profit and
# hand it to quantstats for reporting.
qs.extend_pandas()

window_size = 10
row_count = len(FOREX_EURUSD_RENKO)
start_index, end_index = int(row_count * .8), int(row_count * .99)

# Timestamps for the recorded history: the rows after the first test row, up
# to (but excluding) the end of the test frame.
# NOTE(review): assumes the env records exactly one 'total_profit' entry per
# row in this slice; confirm against MyForexEnv.
history_index = FOREX_EURUSD_RENKO.index[start_index + 1:end_index]
total_profit_history = env_test.unwrapped.history['total_profit']
net_worth = pd.Series(total_profit_history, index=history_index)

# Period-over-period change; the first row is dropped because pct_change()
# has no previous value to compare it with.
returns = net_worth.pct_change().iloc[1:]

qs.reports.full(returns)
qs.reports.html(returns, output='SB3_a2c_quantstats_ppo_4_3.html')