Skip to content

Commit

Permalink
add StrategySmaMl
Browse files Browse the repository at this point in the history
  • Loading branch information
johnverkim committed Apr 9, 2023
1 parent 708fec3 commit aad85ce
Show file tree
Hide file tree
Showing 11 changed files with 703 additions and 9 deletions.
2 changes: 1 addition & 1 deletion integration_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .analyzer_ITG_test import AnalyzerIntegrationTests
from .bithumb_data_provider_ITG_test import BithumbDataProviderIntegrationTests
from .operator_ITG_test import OperatorIntegrationTests
from .simulation_operator_ITG_test import SimulationOperatorIntegrationBnhTests
from .simulation_operator_ITG_test import SimulationOperatorIntegrationTests
from .simulation_trader_ITG_test import SimulationTraderIntegrationTests
from .strategy_bnh_ITG_test import StrategyBuyAndHoldIntegrationTests
from .upbit_data_provider_ITG_test import UpbitDataProviderIntegrationTests
Expand Down
176 changes: 172 additions & 4 deletions integration_tests/simulation_operator_ITG_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
SimulationDataProvider,
SimulationOperator,
SimulationTrader,
StrategyBuyAndHold,
StrategyFactory,
Analyzer,
LogManager,
)
from .data import simulation_data
from unittest.mock import *


class SimulationOperatorIntegrationBnhTests(unittest.TestCase):
class SimulationOperatorIntegrationTests(unittest.TestCase):
def setUp(self):
LogManager.set_stream_level(20)

Expand All @@ -22,11 +22,11 @@ def tearDown(self):
def test_ITG_run_simulation_with_bnh_strategy(self):
trading_snapshot = simulation_data.get_data("bnh_snapshot")
operator = SimulationOperator()
strategy = StrategyBuyAndHold()
strategy = StrategyFactory.create("BNH")
strategy.is_simulation = True
count = 100
budget = 100000
interval = 0.001
interval = 0.0001
time_limit = 15
end_str = "2020-04-30T16:30:00"

Expand Down Expand Up @@ -97,3 +97,171 @@ def check_equal_results_list(self, a, b):
self.assertEqual(a[i]["balance"], b[i]["balance"])
self.assertEqual(a[i]["date_time"], b[i]["date_time"])
self.assertEqual(a[i]["kind"], b[i]["kind"])

def test_ITG_run_simulation_with_sma_strategy(self):
operator = SimulationOperator()
strategy = StrategyFactory.create("SMA")
strategy.is_simulation = True
count = 100
budget = 100000
interval = 0.0001
time_limit = 30
end_str = "2022-04-30T16:30:00"

data_provider = SimulationDataProvider()
data_provider.initialize_simulation(end=end_str, count=count)
trader = SimulationTrader()
trader.initialize_simulation(end=end_str, count=count, budget=budget)
analyzer = Analyzer()
analyzer.is_simulation = True

operator.initialize(
data_provider,
strategy,
trader,
analyzer,
budget=budget,
)

operator.set_interval(interval)
operator.start()
start_time = time.time()
while operator.state == "running":
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

waiting = True
start_time = time.time()
report = None

def callback(return_report):
nonlocal report
nonlocal waiting
report = return_report
waiting = False
self.assertFalse(waiting)

operator.get_score(callback)

while waiting:
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

self.assertIsNotNone(report)
self.assertEqual(report[0], 100000)

def test_ITG_run_simulation_with_rsi_strategy(self):
operator = SimulationOperator()
strategy = StrategyFactory.create("RSI")
strategy.is_simulation = True
count = 100
budget = 100000
interval = 0.0001
time_limit = 30
end_str = "2022-04-30T16:30:00"

data_provider = SimulationDataProvider()
data_provider.initialize_simulation(end=end_str, count=count)
trader = SimulationTrader()
trader.initialize_simulation(end=end_str, count=count, budget=budget)
analyzer = Analyzer()
analyzer.is_simulation = True

operator.initialize(
data_provider,
strategy,
trader,
analyzer,
budget=budget,
)

operator.set_interval(interval)
operator.start()
start_time = time.time()
while operator.state == "running":
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

waiting = True
start_time = time.time()
report = None

def callback(return_report):
nonlocal report
nonlocal waiting
report = return_report
waiting = False
self.assertFalse(waiting)

operator.get_score(callback)

while waiting:
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

self.assertIsNotNone(report)
self.assertEqual(report[0], 100000)

def test_ITG_run_simulation_with_sml_strategy(self):
operator = SimulationOperator()
strategy = StrategyFactory.create("SML")
strategy.is_simulation = True
count = 100
budget = 100000
interval = 0.0001
time_limit = 30
end_str = "2022-07-29T04:30:00"

data_provider = SimulationDataProvider()
data_provider.initialize_simulation(end=end_str, count=count)
trader = SimulationTrader()
trader.initialize_simulation(end=end_str, count=count, budget=budget)
analyzer = Analyzer()
analyzer.is_simulation = True

operator.initialize(
data_provider,
strategy,
trader,
analyzer,
budget=budget,
)

operator.set_interval(interval)
operator.start()
start_time = time.time()
while operator.state == "running":
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

waiting = True
start_time = time.time()
report = None

def callback(return_report):
nonlocal report
nonlocal waiting
report = return_report
waiting = False
self.assertFalse(waiting)

operator.get_score(callback)

while waiting:
time.sleep(0.5)
if time.time() - start_time > time_limit:
self.assertTrue(False, "Time out")
break

self.assertIsNotNone(report)
self.assertEqual(report[0], 100000)
Binary file modified requirements-dev.txt
Binary file not shown.
Binary file modified requirements.txt
Binary file not shown.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ install_requires =
jupyter
psutil
coverage
scikit-learn

3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import unittest
from setuptools import find_packages, setup
from setuptools import setup

# Package meta-data.
NAME = "smtm"
Expand All @@ -21,6 +21,7 @@
"python-dotenv",
"jupyter",
"psutil",
"scikit-learn",
]


Expand Down
2 changes: 2 additions & 0 deletions smtm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .simulation_operator import SimulationOperator
from .strategy_bnh import StrategyBuyAndHold
from .strategy_sma_0 import StrategySma0
from .strategy_sma_ml import StrategySmaMl
from .strategy_rsi import StrategyRsi
from .strategy_factory import StrategyFactory
from .virtual_market import VirtualMarket
Expand Down Expand Up @@ -37,6 +38,7 @@
"SimulationOperator",
"StrategyBuyAndHold",
"StrategySma0",
"StrategySmaMl",
"StrategyRsi",
"StrategyFactory",
"VirtualMarket",
Expand Down
4 changes: 2 additions & 2 deletions smtm/strategy_factory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Strategy 정보 조회 및 생성을 담당하는 Factory 클래스"""

from . import StrategyBuyAndHold, StrategySma0, StrategyRsi
from . import StrategyBuyAndHold, StrategySma0, StrategyRsi, StrategySmaMl


class StrategyFactory:
"""Strategy 정보 조회 및 생성을 담당하는 Factory 클래스"""

STRATEGY_LIST = [StrategyBuyAndHold, StrategySma0, StrategyRsi]
STRATEGY_LIST = [StrategyBuyAndHold, StrategySma0, StrategyRsi, StrategySmaMl]

@staticmethod
def create(code):
Expand Down
Loading

0 comments on commit aad85ce

Please sign in to comment.