Train and evaluate an A2C agent


In [1]:
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from env import CustomEnv
from utils.load_data import load_data
import gymnasium as gym

# from gym_anytrading import gym_anytrading

# Price data fed to the environment (presumably cleaned 5-minute AAPL bars,
# judging by the path -- confirm against utils.load_data).
df = load_data('../csv_clean_5m/AAPL.csv')

# frame_bound starts at `window_size` (presumably so that a full look-back
# window precedes the first step -- confirm in CustomEnv). The upper bound is
# capped at 1000 here; `len(df) - start_index` was the earlier alternative.
window_size = 10
start_index, end_index = window_size, 1000

# Direct construction instead of gym.make('custom-v0', ...); pass
# render_mode="human" as an extra keyword to enable rendering.
env = CustomEnv(df=df, window_size=window_size, frame_bound=(start_index, end_index))

print("observation_space:", env.observation_space)

observation_space: Box(-10000000000.0, 10000000000.0, (10, 5), float64)


In [None]:
from matplotlib import pyplot as plt


# Reset the environment with a fixed seed before training.
env.reset(seed=42)

loadfile = True
agent_file_name = "_new_model"

# Resume from the saved checkpoint when requested; otherwise, or when the
# checkpoint cannot be loaded, start from a freshly initialised A2C agent.
model = None
if loadfile:
    try:
        model = A2C.load(agent_file_name, env=env)
        print('model loaded successfully')
    except Exception as e:
        # Best-effort load: fall back to a new agent, but report the reason so
        # a missing or incompatible checkpoint does not silently restart
        # training from scratch.
        print(f"could not load '{agent_file_name}' ({e!r}); creating a new A2C model")

if model is None:
    model = A2C('MlpPolicy', env, verbose=0)

# Train the agent
total_timesteps = 10000
model.learn(total_timesteps=total_timesteps, progress_bar=True)


# Save the trained model under the same name it is loaded from, so the next
# run resumes from this checkpoint.
model.save(agent_file_name)

# Ask the wrapped environment to draw its full-episode rendering.
model.get_env().unwrapped.env_method('render_all')

In [None]:


# reproduce training and test
import random
import torch
from tqdm import tqdm


# Reload the saved agent and evaluate it for a few episodes.
model = A2C.load(f"{agent_file_name}", env=env)
vec_env = model.get_env()


print('-' * 80)
seed = 42
# Seed the environment and every RNG in play so that the sampled
# (non-deterministic) predictions below are repeatable between runs.
# The (obs, info) result is discarded: each episode starts from vec_env.reset().
env.reset(seed=seed)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

total_num_episodes = 5
tbar = tqdm(range(total_num_episodes))

for episode in tbar:
    obs = vec_env.reset()

    total_reward = 0.0
    while True:
        action, _states = model.predict(obs)
        # The vectorised env returns one entry per sub-environment; the model
        # was loaded with a single env, so index 0 is the only entry.
        obs, reward, done, info = vec_env.step(action)

        total_reward += float(reward[0])
        if done[0]:
            break

        # Only the description is refreshed here. The bar itself advances once
        # per episode through the `for` loop; calling tbar.update() on every
        # step would push the counter past total_num_episodes.
        tbar.set_description(f'Episode: {episode} {reward}')

    # Report the episode return without breaking the progress-bar line.
    tqdm.write(f'Episode: {episode} total reward: {total_reward}')

tbar.close()