From c64c2215e073283baadbfad7e529c2da5137e9fd Mon Sep 17 00:00:00 2001 From: msaltnet Date: Sat, 5 Aug 2023 20:59:31 +0900 Subject: [PATCH] can execute simulation with 3min candle data --- integration_tests/__init__.py | 10 +- .../data/mass_simulation_3m_config.json | 18 ++++ integration_tests/data/simulation_data.py | 92 +++++++++++++++++- integration_tests/mass_simulator_ITG_test.py | 11 ++- .../simulation_operator_ITG_test.py | 96 ++++++++++++++++++- .../simulation_trader_ITG_test.py | 7 +- integration_tests/simulator_ITG_test.py | 7 +- smtm/__init__.py | 1 + smtm/config.py | 2 + smtm/mass_simulator.py | 12 ++- smtm/simulation_data_provider.py | 7 +- smtm/simulation_trader.py | 4 +- smtm/simulator.py | 11 ++- smtm/virtual_market.py | 8 +- tests/mass_simulator_test.py | 19 +--- tests/simulation_data_provider_test.py | 6 -- tests/simulator_test.py | 7 +- tests/virtual_market_test.py | 14 +++ 18 files changed, 280 insertions(+), 52 deletions(-) create mode 100644 integration_tests/data/mass_simulation_3m_config.json create mode 100644 smtm/config.py diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py index 0d8597d..c47e75d 100644 --- a/integration_tests/__init__.py +++ b/integration_tests/__init__.py @@ -4,10 +4,16 @@ from .analyzer_ITG_test import AnalyzerIntegrationTests from .bithumb_data_provider_ITG_test import BithumbDataProviderIntegrationTests from .operator_ITG_test import OperatorIntegrationTests -from .simulation_operator_ITG_test import SimulationOperatorIntegrationTests +from .simulation_operator_ITG_test import ( + SimulationOperatorIntegrationTests, + SimulationOperator3mIntervalIntegrationTests, +) from .simulation_trader_ITG_test import SimulationTraderIntegrationTests from .strategy_bnh_ITG_test import StrategyBuyAndHoldIntegrationTests from .upbit_data_provider_ITG_test import UpbitDataProviderIntegrationTests from .simulator_ITG_test import SimulatorIntegrationTests from .data_repository_ITG_test import DataRepositoryIntegrationTests -from .mass_simulator_ITG_test import MassSimulatorIntegrationTests +from .mass_simulator_ITG_test import ( + MassSimulatorIntegrationTests, + MassSimulator3mIntervalIntegrationTests, +) diff --git a/integration_tests/data/mass_simulation_3m_config.json b/integration_tests/data/mass_simulation_3m_config.json new file mode 100644 index 0000000..338fe1d --- /dev/null +++ b/integration_tests/data/mass_simulation_3m_config.json @@ -0,0 +1,18 @@ +{ + "title": "SMA-1Hour-3min-interval", + "budget": 500000, + "strategy": "BNH", + "interval": 0.5, + "currency": "BTC", + "description": "mass-simulation-integration-test", + "period_list": [ + { + "start": "2020-04-30T17:00:00", + "end": "2020-04-30T18:00:00" + }, + { + "start": "2020-04-30T17:30:00", + "end": "2020-04-30T18:30:00" + } + ] +} \ No newline at end of file diff --git a/integration_tests/data/simulation_data.py b/integration_tests/data/simulation_data.py index 1be4d26..a3d29a9 100644 --- a/integration_tests/data/simulation_data.py +++ b/integration_tests/data/simulation_data.py @@ -1,10 +1,100 @@ def get_data(name): if name == "bnh_snapshot": return bnh_snapshot + elif name == "bnh_3m_snapshot": + return bnh_3m_snapshot elif name == "sma0_snapshot": return sma0_snapshot +bnh_3m_snapshot = [ + { + "request": { + "id": "1691237535945.211215", + "type": "buy", + "price": 11288000.0, + "amount": 0.0017, + "date_time": "2020-04-30T14:48:00", + }, + "type": "buy", + "price": 11288000.0, + "amount": 0.0017, + "msg": "success", + "balance": 80801, + "state": "done", + "date_time": "2020-04-30T15:56:00", + "kind": 2, + }, + { + "request": { + "id": "1691237535948.211215", + "type": "buy", + "price": 11313000.0, + "amount": 0.0017, + "date_time": "2020-04-30T14:51:00", + }, + "type": "buy", + "price": 11313000.0, + "amount": 0.0017, + "msg": "success", + "balance": 61559, + "state": "done", + "date_time": "2020-04-30T15:57:00", + "kind": 2, + }, + { + "request": { + "id": "1691237535952.211215", + "type": "buy", + "price": 11351000.0, + "amount": 0.0017, + "date_time": "2020-04-30T14:54:00", + }, + "type": "buy", + "price": 11351000.0, + "amount": 0.0017, + "msg": "success", + "balance": 42253, + "state": "done", + "date_time": "2020-04-30T15:58:00", + "kind": 2, + }, + { + "request": { + "id": "1691237535956.211215", + "type": "buy", + "price": 11341000.0, + "amount": 0.0017, + "date_time": "2020-04-30T14:57:00", + }, + "type": "buy", + "price": 11341000.0, + "amount": 0.0017, + "msg": "success", + "balance": 22964, + "state": "done", + "date_time": "2020-04-30T15:59:00", + "kind": 2, + }, + { + "request": { + "id": "1691237535962.211215", + "type": "buy", + "price": 11325000.0, + "amount": 0.0017, + "date_time": "2020-04-30T15:00:00", + }, + "type": "buy", + "price": 11325000.0, + "amount": 0.0017, + "msg": "success", + "balance": 3702, + "state": "done", + "date_time": "2020-04-30T16:00:00", + "kind": 2, + }, +] + bnh_snapshot = [ { "request": { @@ -298,4 +388,4 @@ def get_data(name): "date_time": "2020-04-30T15:53:00", "kind": 2, }, -] \ No newline at end of file +] diff --git a/integration_tests/mass_simulator_ITG_test.py b/integration_tests/mass_simulator_ITG_test.py index 1c49be5..a03e794 100644 --- a/integration_tests/mass_simulator_ITG_test.py +++ b/integration_tests/mass_simulator_ITG_test.py @@ -1,6 +1,6 @@ import time import unittest -from smtm import MassSimulator +from smtm import MassSimulator, Config from unittest.mock import * @@ -16,3 +16,12 @@ def test_ITG_run_single_simulation(self, mock_print): mass = MassSimulator() mass.run("integration_tests/data/mass_simulation_config.json") + + +class MassSimulator3mIntervalIntegrationTests(unittest.TestCase): + # It should be executed after set Config.candle_interval = 180 + + def test_ITG_run_single_simulation(self): + mass = MassSimulator() + + mass.run("integration_tests/data/mass_simulation_3m_config.json") diff --git a/integration_tests/simulation_operator_ITG_test.py b/integration_tests/simulation_operator_ITG_test.py index fbb9ffc..88066bd 100644 --- a/integration_tests/simulation_operator_ITG_test.py +++ b/integration_tests/simulation_operator_ITG_test.py @@ -7,6 +7,7 @@ StrategyFactory, Analyzer, LogManager, + Config, ) from .data import simulation_data from unittest.mock import * @@ -15,9 +16,11 @@ class SimulationOperatorIntegrationTests(unittest.TestCase): def setUp(self): LogManager.set_stream_level(20) + self.interval = Config.candle_interval + Config.candle_interval = 60 def tearDown(self): - pass + Config.candle_interval = self.interval def test_ITG_run_simulation_with_bnh_strategy(self): trading_snapshot = simulation_data.get_data("bnh_snapshot") @@ -30,7 +33,7 @@ def test_ITG_run_simulation_with_bnh_strategy(self): time_limit = 15 end_str = "2020-04-30T16:30:00" - data_provider = SimulationDataProvider() + data_provider = SimulationDataProvider(interval=Config.candle_interval) data_provider.initialize_simulation(end=end_str, count=count) trader = SimulationTrader() trader.initialize_simulation(end=end_str, count=count, budget=budget) @@ -265,3 +268,92 @@ def callback(return_report): self.assertIsNotNone(report) self.assertEqual(report[0], 100000) + + +class SimulationOperator3mIntervalIntegrationTests(unittest.TestCase): + def setUp(self): + LogManager.set_stream_level(20) + self.interval = Config.candle_interval + Config.candle_interval = 180 + + def tearDown(self): + Config.candle_interval = self.interval + + def test_ITG_run_simulation_with_bnh_strategy(self): + trading_snapshot = simulation_data.get_data("bnh_3m_snapshot") + operator = SimulationOperator() + strategy = StrategyFactory.create("BNH") + strategy.is_simulation = True + count = 100 + budget = 100000 + interval = 0.0001 + time_limit = 15 + end_str = "2020-04-30T16:30:00" + + data_provider = SimulationDataProvider(interval=Config.candle_interval) + data_provider.initialize_simulation(end=end_str, count=count) + trader = SimulationTrader() + trader.initialize_simulation(end=end_str, count=count, budget=budget) + analyzer = Analyzer() + analyzer.is_simulation = True + + operator.initialize( + data_provider, + strategy, + trader, + analyzer, + budget=budget, + ) + + operator.set_interval(interval) + operator.start() + start_time = time.time() + while operator.state == "running": + time.sleep(0.5) + if time.time() - start_time > time_limit: + self.assertTrue(False, "Time out") + break + + trading_results = operator.get_trading_results() + + self.check_equal_results_list(trading_results, trading_snapshot) + waiting = True + start_time = time.time() + report = None + + def callback(return_report): + nonlocal report + nonlocal waiting + report = return_report + waiting = False + self.assertFalse(waiting) + + operator.get_score(callback) + + while waiting: + time.sleep(0.5) + if time.time() - start_time > time_limit: + self.assertTrue(False, "Time out") + break + + self.assertIsNotNone(report) + self.assertEqual(report[0], 100000) + self.assertEqual(report[1], 97066) + self.assertEqual(report[2], -2.934) + self.assertEqual(report[3]["KRW-BTC"], -2.693) + + def check_equal_results_list(self, a, b): + self.assertEqual(len(a), len(b)) + for i in range(len(a)): + self.assertEqual(a[i]["request"]["type"], b[i]["request"]["type"]) + self.assertEqual(a[i]["request"]["price"], b[i]["request"]["price"]) + self.assertEqual(a[i]["request"]["amount"], b[i]["request"]["amount"]) + self.assertEqual(a[i]["request"]["date_time"], b[i]["request"]["date_time"]) + + self.assertEqual(a[i]["type"], b[i]["type"]) + self.assertEqual(a[i]["price"], b[i]["price"]) + self.assertEqual(a[i]["amount"], b[i]["amount"]) + self.assertEqual(a[i]["msg"], b[i]["msg"]) + self.assertEqual(a[i]["balance"], b[i]["balance"]) + self.assertEqual(a[i]["date_time"], b[i]["date_time"]) + self.assertEqual(a[i]["kind"], b[i]["kind"]) diff --git a/integration_tests/simulation_trader_ITG_test.py b/integration_tests/simulation_trader_ITG_test.py index a958adb..19a4f68 100644 --- a/integration_tests/simulation_trader_ITG_test.py +++ b/integration_tests/simulation_trader_ITG_test.py @@ -1,14 +1,15 @@ import unittest -from smtm import SimulationTrader +from smtm import SimulationTrader, Config from unittest.mock import * class SimulationTraderIntegrationTests(unittest.TestCase): def setUp(self): - pass + self.interval = Config.candle_interval + Config.candle_interval = 60 def tearDown(self): - pass + Config.candle_interval = self.interval def test_ITG_simulation_trader_full(self): trader = SimulationTrader() diff --git a/integration_tests/simulator_ITG_test.py b/integration_tests/simulator_ITG_test.py index 75c1b1d..9c78f22 100644 --- a/integration_tests/simulator_ITG_test.py +++ b/integration_tests/simulator_ITG_test.py @@ -1,16 +1,17 @@ import time import unittest -from smtm import Simulator +from smtm import Simulator, Config from .data import simulation_data from unittest.mock import * class SimulatorIntegrationTests(unittest.TestCase): def setUp(self): - pass + self.interval = Config.candle_interval + Config.candle_interval = 60 def tearDown(self): - pass + Config.candle_interval = self.interval @patch("builtins.print") def test_ITG_run_single_simulation(self, mock_print): diff --git a/smtm/__init__.py b/smtm/__init__.py index 8387279..72cf66e 100644 --- a/smtm/__init__.py +++ b/smtm/__init__.py @@ -1,6 +1,7 @@ """ Description for Package """ +from .config import Config from .date_converter import DateConverter from .operator import Operator from .log_manager import LogManager diff --git a/smtm/config.py b/smtm/config.py new file mode 100644 index 0000000..dc71b45 --- /dev/null +++ b/smtm/config.py @@ -0,0 +1,2 @@ +class Config: + candle_interval = 60 diff --git a/smtm/mass_simulator.py b/smtm/mass_simulator.py index 82f89b8..dcbe6b0 100644 --- a/smtm/mass_simulator.py +++ b/smtm/mass_simulator.py @@ -14,6 +14,7 @@ import pandas as pd import matplotlib.pyplot as plt +from .config import Config from .log_manager import LogManager from .analyzer import Analyzer from .strategy_factory import StrategyFactory @@ -52,7 +53,7 @@ def memory_usage(): """현재 프로세스의 이름과 메모리 사용양을 화면에 출력""" # current process RAM usage process = psutil.Process() - rss = process.memory_info().rss / 2 ** 20 # Bytes to MB + rss = process.memory_info().rss / 2**20 # Bytes to MB print(f"[{current_process().name}] memory usage: {rss: 10.5f} MB") # print(f"[{current_process().name}] memory usage: {p.memory_info().rss} MB") @@ -118,11 +119,12 @@ def get_score_callback(report): @staticmethod def get_initialized_operator(budget, strategy_code, interval, currency, start, end, tag): """시뮬레이션 오퍼레이션 생성 후 주어진 설정 값으로 초기화 하여 반환""" - dt = DateConverter.to_end_min(start_iso=start, end_iso=end) + dt = DateConverter.to_end_min( + start_iso=start, end_iso=end, interval_min=Config.candle_interval / 60 + ) end = dt[0][1] count = dt[0][2] - - data_provider = SimulationDataProvider(currency=currency) + data_provider = SimulationDataProvider(currency=currency, interval=Config.candle_interval) data_provider.initialize_simulation(end=end, count=count) strategy = StrategyFactory.create(strategy_code) @@ -130,7 +132,7 @@ def get_initialized_operator(budget, strategy_code, interval, currency, start, e raise UserWarning(f"Invalid Strategy! {strategy_code}") strategy.is_simulation = True - trader = SimulationTrader(currency=currency) + trader = SimulationTrader(currency=currency, interval=Config.candle_interval) trader.initialize_simulation(end=end, count=count, budget=budget) analyzer = Analyzer() diff --git a/smtm/simulation_data_provider.py b/smtm/simulation_data_provider.py index 69e643c..cbd4cb4 100644 --- a/smtm/simulation_data_provider.py +++ b/smtm/simulation_data_provider.py @@ -11,11 +11,12 @@ class SimulationDataProvider(DataProvider): AVAILABLE_CURRENCY = {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"} - def __init__(self, currency="BTC"): + def __init__(self, currency="BTC", interval=60): if currency not in self.AVAILABLE_CURRENCY: raise UserWarning(f"not supported currency: {currency}") self.logger = LogManager.get_logger(__class__.__name__) - self.repo = DataRepository("smtm.db") + self.repo = DataRepository("smtm.db", interval=interval) + self.interval_min = interval / 60 self.data = [] self.index = 0 @@ -26,7 +27,7 @@ def initialize_simulation(self, end=None, count=100): self.index = 0 end_dt = datetime.strptime(end, "%Y-%m-%dT%H:%M:%S") - start_dt = end_dt - timedelta(minutes=count) + start_dt = end_dt - timedelta(minutes=count * self.interval_min) start = start_dt.strftime("%Y-%m-%dT%H:%M:%S") self.data = self.repo.get_data(start, end, market=self.market) diff --git a/smtm/simulation_trader.py b/smtm/simulation_trader.py index 8726250..86bea53 100644 --- a/smtm/simulation_trader.py +++ b/smtm/simulation_trader.py @@ -18,11 +18,11 @@ class SimulationTrader(Trader): AVAILABLE_CURRENCY = {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"} NAME = "Simulation" - def __init__(self, currency="BTC"): + def __init__(self, currency="BTC", interval=60): if currency not in self.AVAILABLE_CURRENCY: raise UserWarning(f"not supported currency: {currency}") self.logger = LogManager.get_logger(__class__.__name__) - self.market = VirtualMarket(market=self.AVAILABLE_CURRENCY[currency]) + self.market = VirtualMarket(market=self.AVAILABLE_CURRENCY[currency], interval=interval) self.is_initialized = False def initialize_simulation(self, end, count, budget): diff --git a/smtm/simulator.py b/smtm/simulator.py index 6ba6888..61ecc06 100644 --- a/smtm/simulator.py +++ b/smtm/simulator.py @@ -2,6 +2,7 @@ import signal import time +from .config import Config from .log_manager import LogManager from .analyzer import Analyzer from .simulation_operator import SimulationOperator @@ -142,7 +143,9 @@ def __init__( def initialize(self): """시뮬레이션 초기화""" - dt = DateConverter.to_end_min(self.start_str + "-" + self.end_str) + dt = DateConverter.to_end_min( + self.start_str + "-" + self.end_str, interval_min=Config.candle_interval / 60 + ) end = dt[0][1] count = dt[0][2] @@ -154,9 +157,11 @@ def initialize(self): self.operator = SimulationOperator() self._print_configuration(strategy.NAME) - data_provider = SimulationDataProvider(currency=self.currency) + data_provider = SimulationDataProvider( + currency=self.currency, interval=Config.candle_interval + ) data_provider.initialize_simulation(end=end, count=count) - trader = SimulationTrader(currency=self.currency) + trader = SimulationTrader(currency=self.currency, interval=Config.candle_interval) trader.initialize_simulation(end=end, count=count, budget=self.budget) analyzer = Analyzer() analyzer.is_simulation = True diff --git a/smtm/virtual_market.py b/smtm/virtual_market.py index e256e75..c14b05e 100644 --- a/smtm/virtual_market.py +++ b/smtm/virtual_market.py @@ -1,5 +1,6 @@ """업비트 거래소의 과거 거래 정보를 이용한 가상 거래소 역할의 VirtualMarket 클래스""" from datetime import datetime, timedelta +from .config import Config from .data_repository import DataRepository from .log_manager import LogManager @@ -19,9 +20,9 @@ class VirtualMarket: URL = "https://api.upbit.com/v1/candles/minutes/1" - def __init__(self, market="KRW-BTC"): + def __init__(self, market="KRW-BTC", interval=60): self.logger = LogManager.get_logger(__class__.__name__) - self.repo = DataRepository("smtm.db") + self.repo = DataRepository("smtm.db", interval=interval) self.data = None self.turn_count = 0 self.balance = 0 @@ -29,6 +30,7 @@ def __init__(self, market="KRW-BTC"): self.asset = {} self.is_initialized = False self.market = market + self.interval = interval def initialize(self, end=None, count=100, budget=0): """ @@ -38,7 +40,7 @@ def initialize(self, end=None, count=100, budget=0): count: 거래기간까지 가져올 데이터의 갯수 """ end_dt = datetime.strptime(end, "%Y-%m-%dT%H:%M:%S") - start_dt = end_dt - timedelta(minutes=count) + start_dt = end_dt - timedelta(minutes=count * (self.interval / 60)) start = start_dt.strftime("%Y-%m-%dT%H:%M:%S") self.data = self.repo.get_data(start, end, market=self.market) self.balance = budget diff --git a/tests/mass_simulator_test.py b/tests/mass_simulator_test.py index ce5da7f..22e9b46 100644 --- a/tests/mass_simulator_test.py +++ b/tests/mass_simulator_test.py @@ -1,7 +1,7 @@ import unittest from datetime import datetime from datetime import timedelta -from smtm import MassSimulator +from smtm import MassSimulator, Config from unittest.mock import * @@ -25,12 +25,6 @@ def test_memory_usage_should_print_correctly(self, mock_process, mock_print): class MassSimulatorAnalyzeTests(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - @patch("builtins.open", new_callable=mock_open) def test_analyze_result_should_call_file_write_correctly(self, mock_file): mass = MassSimulator() @@ -98,10 +92,11 @@ def test_draw_graph_should_call_plt_correctly(self, mock_savefig, mock_plot, moc class MassSimulatorInitializeTests(unittest.TestCase): def setUp(self): - pass + self.interval = Config.candle_interval + Config.candle_interval = 60 def tearDown(self): - pass + Config.candle_interval = self.interval @patch("smtm.SimulationDataProvider.initialize_simulation") @patch("smtm.SimulationTrader.initialize_simulation") @@ -128,12 +123,6 @@ def test_get_initialized_operator_should_initialize_correctly( class MassSimulatorRunTests(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - @patch("smtm.LogManager.set_stream_level") def test_run_should_call_run_simulation_correctly(self, mock_set_stream_level): mass = MassSimulator() diff --git a/tests/simulation_data_provider_test.py b/tests/simulation_data_provider_test.py index 9931efd..4470afa 100644 --- a/tests/simulation_data_provider_test.py +++ b/tests/simulation_data_provider_test.py @@ -4,12 +4,6 @@ class SimulationDataProviderTests(unittest.TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - def test_initialize_simulation_should_call_repo_get_data_correctly(self): dp = SimulationDataProvider() dp.index = 10 diff --git a/tests/simulator_test.py b/tests/simulator_test.py index 3722ebe..f38f5b6 100644 --- a/tests/simulator_test.py +++ b/tests/simulator_test.py @@ -1,14 +1,15 @@ import unittest -from smtm import Simulator +from smtm import Simulator, Config from unittest.mock import * class SimulatorTests(unittest.TestCase): def setUp(self): - pass + self.interval = Config.candle_interval + Config.candle_interval = 60 def tearDown(self): - pass + Config.candle_interval = self.interval @patch("signal.signal") @patch("builtins.input", side_effect=EOFError) diff --git a/tests/virtual_market_test.py b/tests/virtual_market_test.py index 7539aed..d14f77a 100644 --- a/tests/virtual_market_test.py +++ b/tests/virtual_market_test.py @@ -25,6 +25,20 @@ def test_intialize_should_update_data_from_data_repository(self): "2020-04-29T15:40:00", "2020-04-30T00:00:00", market="mango_market" ) + def test_intialize_should_update_data_from_data_repository_with_3m_interval(self): + market = VirtualMarket(interval=180) + market.repo = MagicMock() + market.repo.get_data.return_value = ["mango", "orange"] + market.market = "mango_market" + market.initialize(end="2020-04-30T00:00:00", count=250, budget=7777777) + self.assertEqual(market.data[0], "mango") + self.assertEqual(market.data[1], "orange") + self.assertEqual(market.is_initialized, True) + self.assertEqual(market.balance, 7777777) + market.repo.get_data.assert_called_once_with( + "2020-04-29T11:30:00", "2020-04-30T00:00:00", market="mango_market" + ) + class VirtualMarketTests(unittest.TestCase): def setUp(self):