### Load packages

In [35]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import datetime

from finrl.meta.data_processor import DataProcessor 
from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent 
from stable_baselines3.common.logger import configure
from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline
from argparse import ArgumentParser 

import pyfolio
from pyfolio import timeseries

### Create folders

In [36]:
from finrl import config
from finrl import config_tickers
from finrl.main import check_and_make_directories
from finrl.config import (
    DATA_SAVE_DIR,
    TRAINED_MODEL_DIR,
    TENSORBOARD_LOG_DIR,
    RESULTS_DIR,
    INDICATORS,
)
check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])

In [37]:
TRAIN_START_DATE = '2015-01-01'
TRAIN_END_DATE = '2020-01-01'
TEST_START_DATE = '2020-01-02'
TEST_END_DATE = '2022-12-01'
TRADE_START_DATE = '2022-12-01'
TRADE_END_DATE = '2023-04-01'
TIME_INTERVAL = '1D'

### Download the data(Optional)

In [None]:
YahooDownloader = YahooFinanceProcessor()
df = YahooDownloader.download_data(start_date = TRAIN_START_DATE,
                     end_date = TRADE_END_DATE,
                     ticker_list = config_tickers.DOW_30_TICKER,
                     time_interval=TIME_INTERVAL).fetch_data()

### Load data

In [41]:
# read the stock data from csv file
df = pd.read_csv("../data/processed_data/Test万科A.csv")

In [42]:
df.head()

Unnamed: 0,date,code,open,high,low,close,volume,amount,turn,MACD,RSI,ADX
0,2021-01-04,sz.000002,3268.673434,3268.673434,3167.390595,3197.31507,146844133,4079417000.0,1.5116,0.0,,
1,2021-01-05,sz.000002,3197.31507,3214.579191,3119.051059,3212.277308,116265838,3189607000.0,1.1969,0.335691,,100.0
2,2021-01-06,sz.000002,3199.616953,3308.956381,3191.560364,3308.956381,104880129,2972573000.0,1.0796,3.42575,,66.574839
3,2021-01-07,sz.000002,3337.729915,3395.276983,3267.522493,3313.560147,122675574,3544224000.0,1.2628,4.913429,,64.699809
4,2021-01-08,sz.000002,3335.428032,3389.522276,3315.862029,3376.861921,102856329,3000846000.0,1.0588,8.137296,,63.764863


In [43]:
df.tail()

Unnamed: 0,date,code,open,high,low,close,volume,amount,turn,MACD,RSI,ADX
558,2023-04-24,sz.000002,1977.462102,1977.462102,1916.716534,1925.763746,98062644,1468467000.0,1.0092,-27.172809,41.528239,21.397956
559,2023-04-25,sz.000002,1930.933582,1934.810958,1897.329651,1923.178828,64653059,956930900.0,0.6654,-31.713402,45.126354,21.609224
560,2023-04-26,sz.000002,1899.914568,1921.886369,1880.527685,1912.839157,50426190,743782700.0,0.519,-35.734258,45.620438,22.078458
561,2023-04-27,sz.000002,1912.839157,1932.226041,1897.329651,1928.348664,45595500,677275800.0,0.4693,-37.240047,28.365385,22.217476
562,2023-04-28,sz.000002,1924.471287,1956.782759,1921.886369,1956.782759,60219414,905353100.0,0.6198,-35.727162,37.155963,21.668496


### Build Environment

In [45]:
# price_ary = config["price_array"]
# tech_ary = config["tech_array"]
# turbulence_ary = config["turbulence_array"]
# if_train = config["if_train"]
# env_config should be a dict with keys above
env_config = {"price_array":df,
              "tech_array":None,
              "turbulence_array":None,
              "if_train":True}

env_train = StockTradingEnv(config=env_config)

ValueError: could not convert string to float: '2021-01-04'

### Train DRL Agents

In [None]:
agent = DRLAgent(env = env_train)

PPO_PARAMS = {
    "n_steps": 2048,
    "ent_coef": 0.01,
    "learning_rate": 0.00025,
    "batch_size": 128,
}
model_ppo = agent.get_model("ppo",model_kwargs = PPO_PARAMS)

# set up logger
tmp_path = RESULTS_DIR + '/ppo'
new_logger_ppo = configure(tmp_path, ["stdout", "csv", "tensorboard"])
# Set new logger
model_ppo.set_logger(new_logger_ppo)