Skip to content

Commit

Permalink
can draw requested point to graph
Browse files Browse the repository at this point in the history
  • Loading branch information
msaltnet committed Feb 23, 2022
1 parent bff9cef commit a2dc7d7
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 9 deletions.
6 changes: 6 additions & 0 deletions integration_tests/analyzer_ITG_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def test_update_info_func():
}
analyzer.put_trading_info(info)

analyzer.add_drawing_spot("2020-12-21T01:14:00", 26020000.0)
analyzer.add_drawing_spot("2020-12-21T01:14:01", 26182000.0)

requests = [
{
"id": "1621767064.395",
Expand Down Expand Up @@ -242,6 +245,8 @@ def test_update_info_func():
}
analyzer.put_trading_info(info)

analyzer.add_drawing_spot("2020-12-21T01:17:00", 25061000.0)

requests = [
{
"id": "1621767067.473",
Expand Down Expand Up @@ -305,6 +310,7 @@ def test_ITG_analyze_create_report(self):
analyzer.info_list = analyzer_data.get_data("info_list")
analyzer.asset_info_list = analyzer_data.get_data("asset_info_list")
analyzer.score_list = analyzer_data.get_data("score_list")
analyzer.spot_list = analyzer_data.get_data("spot_list")
analyzer.start_asset_info = analyzer.asset_info_list[0]

if os.path.isfile(analyzer.OUTPUT_FOLDER + "test_report.jpg"):
Expand Down
21 changes: 21 additions & 0 deletions integration_tests/data/analyzer_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,29 @@ def get_data(name):
return asset_info_list
elif name == "score_list":
return score_list
elif name == "spot_list":
return spot_list


spot_list = [
{
"date_time": "2020-12-20T17:01:00",
"value": 26020000.0,
},
{
"date_time": "2020-12-20T17:01:01",
"value": 26020000.0,
},
{
"date_time": "2020-12-20T17:03:00",
"value": 26029000.0,
},
{
"date_time": "2020-12-20T17:07:00",
"value": 26039000.0,
},
]

request_list = [
{
"id": "1623162294.356",
Expand Down
66 changes: 58 additions & 8 deletions smtm/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,23 @@ def get_return_report(self, graph_filename=None, index_info=None):
score_list = self.score_list
info_list = self.info_list
result_list = self.result_list
spot_list = self.spot_list

if index_info is not None:
interval_data = self.__make_interval_data(index_info)
asset_info_list = interval_data[0]
score_list = interval_data[1]
info_list = interval_data[2]
result_list = interval_data[3]
spot_list = interval_data[4]

return self.__get_return_report(
asset_info_list, score_list, info_list, result_list, graph_filename=graph_filename
asset_info_list,
score_list,
info_list,
result_list,
graph_filename=graph_filename,
spot_list=spot_list,
)

def __make_interval_data(self, index_info):
Expand All @@ -321,11 +328,13 @@ def __make_interval_data(self, index_info):
score_list = []
asset_info_list = []
result_list = []
spot_list = []
self.__make_filtered_list(start_dt, end_dt, score_list, self.score_list)
self.__make_filtered_list(start_dt, end_dt, asset_info_list, self.asset_info_list)
self.__make_filtered_list(start_dt, end_dt, result_list, self.result_list)
self.__make_filtered_list(start_dt, end_dt, spot_list, self.spot_list)

return (asset_info_list, score_list, info_list, result_list)
return (asset_info_list, score_list, info_list, result_list, spot_list)

@staticmethod
def _get_min_max_return(score_list):
Expand All @@ -342,7 +351,13 @@ def __make_filtered_list(start_dt, end_dt, dest, source):
dest.append(target)

def __get_return_report(
self, asset_info_list, score_list, info_list, result_list, graph_filename=None
self,
asset_info_list,
score_list,
info_list,
result_list,
graph_filename=None,
spot_list=None,
):
try:
graph = None
Expand All @@ -353,7 +368,12 @@ def __get_return_report(
min_max = self._get_min_max_return(score_list)
if graph_filename is not None:
graph = self.__draw_graph(
info_list, result_list, score_list, graph_filename, is_fullpath=True
info_list,
result_list,
score_list,
graph_filename,
is_fullpath=True,
spot_list=spot_list,
)
period = info_list[0]["date_time"] + " - " + info_list[-1]["date_time"]
summary = (
Expand Down Expand Up @@ -429,7 +449,9 @@ def create_report(self, tag="untitled-report"):
),
)
self.__create_report_file(tag, summary, trading_table)
self.__draw_graph(self.info_list, self.result_list, self.score_list, tag)
self.__draw_graph(
self.info_list, self.result_list, self.score_list, tag, spot_list=self.spot_list
)
return {"summary": summary, "trading_table": trading_table}
except (IndexError, AttributeError):
self.logger.error("create report FAIL")
Expand Down Expand Up @@ -503,9 +525,22 @@ def _get_rss_memory():
process = psutil.Process()
return process.memory_info().rss / 2 ** 20 # Bytes to MB

def __create_plot_data(self, info_list, result_list, score_list):
def __get_spot_info(self, spot_list, start_pos, ref_time):
spot_pos = start_pos
spot_info = None
while spot_pos < len(spot_list):
spot = spot_list[spot_pos]
spot_time = datetime.strptime(spot["date_time"], self.ISO_DATEFORMAT)
if ref_time < spot_time:
break
spot_info = spot["value"]
spot_pos += 1
return spot_info, spot_pos

def __create_plot_data(self, info_list, result_list, score_list, spot_list=None):
result_pos = 0
score_pos = 0
spot_pos = 0
last_avr_price = None
last_acc_return = 0
plot_data = []
Expand All @@ -528,6 +563,13 @@ def __create_plot_data(self, info_list, result_list, score_list):
new["sell"] = result["price"]
result_pos += 1

# 추가 spot 정보를 생성해서 추가. 없는 경우 추가 안함. 기간내 하나만 추가됨
if spot_list is not None:
spot_info = self.__get_spot_info(spot_list, spot_pos, info_time)
if spot_info[0] is not None:
new["spot"] = spot_info[0]
spot_pos = spot_info[1]

# 수익률 정보를 추가. 정보가 없는 경우 최근 정보로 채움
while score_pos < len(score_list):
score = score_list[score_pos]
Expand All @@ -553,8 +595,10 @@ def __create_plot_data(self, info_list, result_list, score_list):
plot_data.append(new)
return pd.DataFrame(plot_data)[-self.GRAPH_MAX_COUNT :]

def __draw_graph(self, info_list, result_list, score_list, filename, is_fullpath=False):
total = self.__create_plot_data(info_list, result_list, score_list)
def __draw_graph(
self, info_list, result_list, score_list, filename, is_fullpath=False, spot_list=None
):
total = self.__create_plot_data(info_list, result_list, score_list, spot_list=spot_list)
total = total.rename(
columns={
"date_time": "Date",
Expand All @@ -577,6 +621,12 @@ def __draw_graph(self, info_list, result_list, score_list, filename, is_fullpath
apds.append(mpf.make_addplot(total["avr_price"]))
if "return" in total.columns:
apds.append(mpf.make_addplot((total["return"]), panel=1, color="g", secondary_y=True))
if "spot" in total.columns:
apds.append(
mpf.make_addplot(
(total["spot"]), type="scatter", markersize=50, marker=".", color="g"
)
)

destination = self.OUTPUT_FOLDER + filename + ".jpg"
if is_fullpath:
Expand Down
43 changes: 42 additions & 1 deletion tests/analyzer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,18 @@ def fill_test_data_for_report(self, analyzer):
}
analyzer.result_list.append(dummy_result)

dummy_spot = {
"value": 5900,
"date_time": "2020-02-23T00:00:00",
}
analyzer.spot_list.append(dummy_spot)

dummy_spot2 = {
"value": 6900,
"date_time": "2020-02-23T00:00:01",
}
analyzer.spot_list.append(dummy_spot2)

dummy_asset_info = {
"balance": 23456,
"asset": {},
Expand Down Expand Up @@ -507,6 +519,12 @@ def fill_test_data_for_report(self, analyzer):
}
analyzer.result_list.append(dummy_result3)

dummy_spot3 = {
"value": 8888,
"date_time": "2020-02-23T00:01:00",
}
analyzer.spot_list.append(dummy_spot3)

target_dummy_asset2 = {
"balance": 5000,
"asset": {"mango": (600, 4.23), "apple": (500, 3.11)},
Expand Down Expand Up @@ -658,8 +676,9 @@ def test_get_return_report_return_correct_report_with_index(self, mock_plot):

analyzer.update_asset_info.assert_called()

@patch("mplfinance.make_addplot")
@patch("mplfinance.plot")
def test_get_return_report_draw_graph_when_graph_filename_exist(self, mock_plot):
def test_get_return_report_draw_graph_when_graph_filename_exist(self, mock_plot, mock_addplot):
"""
{
cumulative_return: 기준 시점부터 누적 수익률
Expand Down Expand Up @@ -709,6 +728,28 @@ def test_get_return_report_draw_graph_when_graph_filename_exist(self, mock_plot)
savefig=dict(fname="mango_graph.png", dpi=300, pad_inches=0.25),
figscale=1.25,
)
self.assertEqual(len(mock_addplot.call_args_list), 5)
self.assertEqual(mock_addplot.call_args_list[0][1]["type"], "scatter")
self.assertEqual(mock_addplot.call_args_list[0][1]["markersize"], 100)
self.assertEqual(mock_addplot.call_args_list[0][1]["marker"], "^")
self.assertEqual(mock_addplot.call_args_list[1][1]["type"], "scatter")
self.assertEqual(mock_addplot.call_args_list[1][1]["markersize"], 100)
self.assertEqual(mock_addplot.call_args_list[1][1]["marker"], "v")
self.assertEqual(mock_addplot.call_args_list[3][1]["panel"], 1)
self.assertEqual(mock_addplot.call_args_list[3][1]["color"], "g")
self.assertEqual(mock_addplot.call_args_list[3][1]["secondary_y"], True)
self.assertEqual(mock_addplot.call_args_list[4][1]["type"], "scatter")
self.assertEqual(mock_addplot.call_args_list[4][1]["markersize"], 50)
self.assertEqual(mock_addplot.call_args_list[4][1]["marker"], ".")
self.assertEqual(mock_addplot.call_args_list[0][0][0][0], 5000)
self.assertEqual(mock_addplot.call_args_list[0][0][0][1], 5000)
self.assertEqual(mock_addplot.call_args_list[1][0][0][1], 6000.0)
self.assertEqual(mock_addplot.call_args_list[2][0][0][0], 500)
self.assertEqual(mock_addplot.call_args_list[2][0][0][1], 600)
self.assertEqual(mock_addplot.call_args_list[3][0][0][0], -65.248)
self.assertEqual(mock_addplot.call_args_list[3][0][0][1], -75.067)
self.assertEqual(mock_addplot.call_args_list[4][0][0][0], 5900)
self.assertEqual(mock_addplot.call_args_list[4][0][0][1], 8888)

@patch("pandas.to_datetime")
@patch("pandas.DataFrame")
Expand Down

0 comments on commit a2dc7d7

Please sign in to comment.