diff --git a/smtm/config.py b/smtm/config.py index 3729b1d..47a0782 100644 --- a/smtm/config.py +++ b/smtm/config.py @@ -1,4 +1,5 @@ class Config: + """시뮬레이션에 사용할 거래소 데이터 simulation_source: upbit, binance""" simulation_source = "upbit" candle_interval = 60 """스트림 핸들러의 레벨 diff --git a/smtm/data/simulation_data_provider.py b/smtm/data/simulation_data_provider.py index 5e0fe57..ae6fa80 100644 --- a/smtm/data/simulation_data_provider.py +++ b/smtm/data/simulation_data_provider.py @@ -1,6 +1,7 @@ """시뮬레이션을 위한 DataProvider 구현체 SimulationDataProvider 클래스""" from datetime import datetime, timedelta +from ..config import Config from .data_provider import DataProvider from ..log_manager import LogManager from .data_repository import DataRepository @@ -9,18 +10,23 @@ class SimulationDataProvider(DataProvider): """거래소로부터 과거 데이터를 수집해서 순차적으로 제공하는 클래스""" - AVAILABLE_CURRENCY = {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"} + AVAILABLE_CURRENCY = { + "upbit" : {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"}, + "binance" : {"BTC": "BTCUSDT", "ETH": "ETHUSDT", "DOGE": "DOGEUSDT", "XRP": "XRPUSDT"} + } def __init__(self, currency="BTC", interval=60): - if currency not in self.AVAILABLE_CURRENCY: + if Config.simulation_source not in self.AVAILABLE_CURRENCY.keys(): + raise UserWarning(f"not supported source: {Config.simulation_source}") + if currency not in self.AVAILABLE_CURRENCY[Config.simulation_source]: raise UserWarning(f"not supported currency: {currency}") + self.logger = LogManager.get_logger(__class__.__name__) self.repo = DataRepository("smtm.db", interval=interval) self.interval_min = interval / 60 self.data = [] self.index = 0 - - self.market = self.AVAILABLE_CURRENCY[currency] + self.market = self.AVAILABLE_CURRENCY[Config.simulation_source][currency] def initialize_simulation(self, end=None, count=100): """DataRepository를 통해서 데이터를 가져와서 초기화한다""" diff --git a/smtm/trader/simulation_trader.py b/smtm/trader/simulation_trader.py index 6d09490..c9c78c9 100644 --- a/smtm/trader/simulation_trader.py +++ b/smtm/trader/simulation_trader.py @@ -1,5 +1,6 @@ """시뮬레이션을 위한 가상 거래를 처리해주는 SimulationTrader 클래스""" +from ..config import Config from ..log_manager import LogManager from .trader import Trader from .virtual_market import VirtualMarket @@ -15,19 +16,24 @@ class SimulationTrader(Trader): amount: 거래 수량 """ - AVAILABLE_CURRENCY = {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"} + AVAILABLE_CURRENCY = { + "upbit" : {"BTC": "KRW-BTC", "ETH": "KRW-ETH", "DOGE": "KRW-DOGE", "XRP": "KRW-XRP"}, + "binance" : {"BTC": "BTCUSDT", "ETH": "ETHUSDT", "DOGE": "DOGEUSDT", "XRP": "XRPUSDT"} + } NAME = "Simulation" def __init__(self, currency="BTC", interval=60): - if currency not in self.AVAILABLE_CURRENCY: + if Config.simulation_source not in self.AVAILABLE_CURRENCY.keys(): + raise UserWarning(f"not supported source: {Config.simulation_source}") + if currency not in self.AVAILABLE_CURRENCY[Config.simulation_source]: raise UserWarning(f"not supported currency: {currency}") self.logger = LogManager.get_logger(__class__.__name__) - self.market = VirtualMarket(market=self.AVAILABLE_CURRENCY[currency], interval=interval) + self.v_market = VirtualMarket(market=self.AVAILABLE_CURRENCY[Config.simulation_source][currency], interval=interval) self.is_initialized = False def initialize_simulation(self, end, count, budget): """시뮬레이션기간, 횟수, 예산을 초기화 한다""" - self.market.initialize(end, count, budget) + self.v_market.initialize(end, count, budget) self.is_initialized = True def send_request(self, request_list, callback): @@ -59,7 +65,7 @@ def send_request(self, request_list, callback): raise UserWarning("Not initialzed") try: - result = self.market.handle_request(request_list[0]) + result = self.v_market.handle_request(request_list[0]) if result is not None: callback(result) except (TypeError, AttributeError) as msg: @@ -82,7 +88,7 @@ def get_account_info(self): raise UserWarning("Not initialzed") try: - return self.market.get_balance() + return self.v_market.get_balance() except (TypeError, AttributeError) as msg: self.logger.error(f"invalid state {msg}") raise UserWarning("invalid state") from msg diff --git a/smtm/trader/virtual_market.py b/smtm/trader/virtual_market.py index f930f95..119084c 100644 --- a/smtm/trader/virtual_market.py +++ b/smtm/trader/virtual_market.py @@ -1,4 +1,4 @@ -"""업비트 거래소의 과거 거래 정보를 이용한 가상 거래소 역할의 VirtualMarket 클래스""" +"""Config에 설정된 거래소의 과거 거래 정보를 이용한 가상 거래소 역할의 VirtualMarket 클래스""" from datetime import datetime, timedelta from ..config import Config from ..log_manager import LogManager @@ -18,8 +18,6 @@ class VirtualMarket: asset: 자산 목록, 마켓이름을 키값으로 갖고 (평균 매입 가격, 수량)을 갖는 딕셔너리 """ - URL = "https://api.upbit.com/v1/candles/minutes/1" - def __init__(self, market="KRW-BTC", interval=60): self.logger = LogManager.get_logger(__class__.__name__) self.repo = DataRepository("smtm.db", interval=interval) diff --git a/tests/simulation_trader_test.py b/tests/simulation_trader_test.py index f24a861..1cfda46 100644 --- a/tests/simulation_trader_test.py +++ b/tests/simulation_trader_test.py @@ -12,17 +12,17 @@ def tearDown(self): def test_initialize_simulation_initialize_virtual_market(self): trader = SimulationTrader() - trader.market.initialize = MagicMock() - trader.market.deposit = MagicMock() + trader.v_market.initialize = MagicMock() + trader.v_market.deposit = MagicMock() trader.initialize_simulation("mango", 500, 5000) - trader.market.initialize.assert_called_once_with("mango", 500, 5000) + trader.v_market.initialize.assert_called_once_with("mango", 500, 5000) self.assertEqual(trader.is_initialized, True) def test_initialize_simulation_set_is_initialized_False_when_invalid_market(self): trader = SimulationTrader() - trader.market = "make exception" + trader.v_market = "make exception" with self.assertRaises(AttributeError): trader.initialize_simulation("mango", 500, 5000) @@ -34,9 +34,9 @@ def test_send_request_call_callback_with_result_of_market_handle_quest(self): dummy_requests = [{"id": "mango", "type": "orange", "price": 500, "amount": 10}] callback = MagicMock() - trader.market.handle_request = MagicMock(return_value="banana") + trader.v_market.handle_request = MagicMock(return_value="banana") trader.send_request(dummy_requests, callback) - trader.market.handle_request.assert_called_once_with(dummy_requests[0]) + trader.v_market.handle_request.assert_called_once_with(dummy_requests[0]) callback.assert_called_once_with("banana") def test_send_request_call_raise_exception_UserWarning_when_is_initialized_False(self): @@ -49,7 +49,7 @@ def test_send_request_call_raise_exception_UserWarning_when_is_initialized_False def test_send_request_call_raise_exception_UserWarning_when_market_is_invalid(self): trader = SimulationTrader() trader.is_initialized = True - trader.market = "make exception" + trader.v_market = "make exception" with self.assertRaises(UserWarning): trader.send_request(None, None) @@ -64,9 +64,9 @@ def test_send_request_call_raise_exception_UserWarning_when_callback_make_TypeEr def test_get_account_info_call_callback_with_virtual_market_get_balance_result(self): trader = SimulationTrader() trader.is_initialized = True - trader.market.get_balance = MagicMock(return_value="banana") + trader.v_market.get_balance = MagicMock(return_value="banana") self.assertEqual(trader.get_account_info(), "banana") - trader.market.get_balance.assert_called_once() + trader.v_market.get_balance.assert_called_once() def test_get_account_info_call_raise_exception_UserWarning_when_is_initialized_False( self, @@ -82,7 +82,7 @@ def test_get_account_info_call_raise_exception_UserWarning_when_market_is_invali ): trader = SimulationTrader() trader.is_initialized = True - trader.market = "make exception" + trader.v_market = "make exception" with self.assertRaises(UserWarning): trader.get_account_info()