# 结果展示

In [None]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import datetime
import time
import os

from utils import config
from utils.backtest import *

## 数据读取

In [None]:
# Resolve working directories relative to the current working directory:
# trade_file/ holds the per-model account_value_*.csv files (and receives the
# combined result CSV), data_file/ holds the input CSV used to derive dates.
learn_dir = os.path.join(os.getcwd(), "learn")
root_path = os.path.join(learn_dir, "trade_file")
print("root_path: ", root_path)
data_file_path = os.path.join(learn_dir, "data_file")

# mode 1 selects stock_info_test.csv; any other value selects trade.csv.
mode = 1
test_file = os.path.join(
    data_file_path,
    "stock_info_test.csv" if mode == 1 else "trade.csv",
)
test_pd = pd.read_csv(test_file)

# Backtest window: min/max of the 'date' column with dashes removed
# (assumes 'date' holds 'YYYY-MM-DD' strings, giving 'YYYYMMDD' — confirm).
date_col = test_pd['date']
start_date, end_date = (
    raw_date.replace("-", "") for raw_date in (date_col.min(), date_col.max())
)
print(start_date, end_date)

In [None]:
# Model names whose exported account-value histories will be compared,
# mapped to their CSV files under root_path.
model_list = ["a2c", "ppo", "td3", "ddpg", "sac"]
path_dict = {
    name: os.path.join(root_path, f"account_value_{name}.csv")
    for name in model_list
}

In [None]:
# Load each model's account-value CSV into a DataFrame, keyed by model name.
account_value_dict = {
    name: pd.read_csv(path_dict[name]) for name in model_list
}

## 获取 baseline 的结果

In [None]:
def get_baseline_func():
    """Fetch the SSE 50 baseline series for the window [start_date, end_date].

    Kept as a zero-argument callable so it can be handed to the
    ``*_from_file`` helpers, which presumably invoke it only when the
    cached CSV is unavailable (TODO confirm in utils.backtest).
    ``start_date`` / ``end_date`` are read at call time, as with the
    original lambda.
    """
    return get_baseline(config.SSE_50_INDEX, start=start_date, end=end_date)


# Build the cache path with os.path.join, consistent with the other paths here.
baseline_csv_file = os.path.join(
    root_path, f"baseline_sse50_{start_date}_{end_date}.csv"
)
baseline_df = get_baseline_from_file(baseline_csv_file, get_baseline_func)

In [None]:
# Compute backtest statistics for the index baseline, using its 'close'
# column as the value series. backtest_stats is presumably provided by the
# `from utils.backtest import *` star import — its exact metrics live there.
baseline_stats = backtest_stats(baseline_df, value_col_name='close')

In [None]:
# Rebase the index so its first close maps to 1e6, presumably to match the
# models' starting capital (TODO confirm against the training config).
start_close_value = baseline_df['close'].iloc[0]
pct_from_start = (baseline_df['close'] - start_close_value) / start_close_value
baseline_df['processed_close'] = (pct_from_start + 1) * 1e+6

In [None]:
# Preview the first rows of the baseline, including processed_close.
baseline_df.head(n=5)

## 展示最终的结果

In [None]:
# One column per model (its total_assets series) followed by the rebased
# baseline column.
# NOTE(review): pd.DataFrame aligns these Series on index labels, not dates.
# The model frames come from read_csv (RangeIndex); confirm baseline_df also
# has a 0..n-1 index over the same trading days, otherwise rows are NaN-padded
# or misaligned.
data = {
    **{name: account_value_dict[name]['total_assets'] for name in model_list},
    'baseline': baseline_df['processed_close'],
}
result_account_value = pd.DataFrame(data)

In [None]:
# Convert absolute account values into cumulative returns relative to the
# starting amount. The 'baseline' column (processed_close) starts at exactly
# 1e6 by construction; the model columns are assumed to start from the same
# capital (TODO confirm against the training config).
initial_amount = 1e+6
# NOTE(review): the last row is dropped unconditionally; the reason is not
# visible here (possibly a length mismatch between the baseline and the
# account-value files) — confirm before changing.
# Vectorized arithmetic replaces the per-column apply/lambda; values are identical.
result_account_value = (
    result_account_value.iloc[:-1] - initial_amount
) / initial_amount

In [None]:
# Inspect the final returns for each model and the baseline.
result_account_value.tail(n=5)

In [None]:
# Write the combined return table next to the per-model account-value files.
result_csv_path = os.path.join(root_path, "result_account_value.csv")
result_account_value.to_csv(result_csv_path, index=False)

In [None]:
# Overlay every column (each model plus the baseline) on a single chart.
fig, ax = plt.subplots(figsize=(10, 6))
for label, series in result_account_value.items():
    ax.plot(series.index, series, label=label)

ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.set_title('Account Value Comparison')
ax.legend()
plt.show()

## 展示回测结果

In [None]:
# Compare every model's total_assets against the index named by
# config.SSE_50_INDEX[0], reusing the baseline CSV path and fetcher from above.
print(f"和 {config.SSE_50_INDEX[0]} 指数进行比较")
cmp_data = backtest_plot_from_file(
    baseline_csv_file,
    get_baseline_func,
    account_value_dict,
    value_col_name='total_assets',
)
# Transpose so each top-level key of cmp_data becomes a row
# (presumably one row per model — confirm against backtest_plot_from_file).
df = pd.DataFrame(cmp_data).transpose()
df