Skip to content

Commit f06378b

Browse files
Merge pull request #88 from joshyattridge/codex/update-smc.py-to-handle-close_index--2
Handle early index volumes in OB
2 parents f050646 + 560163e commit f06378b

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

smartmoneyconcepts/smc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,9 +473,12 @@ def ob(
473473
ob[obIndex] = 1
474474
top_arr[obIndex] = obTop
475475
bottom_arr[obIndex] = obBtm
476-
obVolume[obIndex] = _volume[close_index] + _volume[close_index - 1] + _volume[close_index - 2]
477-
lowVolume[obIndex] = _volume[close_index - 2]
478-
highVolume[obIndex] = _volume[close_index] + _volume[close_index - 1]
476+
vol_cur = _volume[close_index]
477+
vol_prev1 = _volume[close_index - 1] if close_index >= 1 else 0.0
478+
vol_prev2 = _volume[close_index - 2] if close_index >= 2 else 0.0
479+
obVolume[obIndex] = vol_cur + vol_prev1 + vol_prev2
480+
lowVolume[obIndex] = vol_prev2
481+
highVolume[obIndex] = vol_cur + vol_prev1
479482
max_vol = max(highVolume[obIndex], lowVolume[obIndex])
480483
percentage[obIndex] = (min(highVolume[obIndex], lowVolume[obIndex]) / max_vol * 100.0) if max_vol != 0 else 100.0
481484
active_bullish.append(obIndex)
@@ -529,9 +532,12 @@ def ob(
529532
ob[obIndex] = -1
530533
top_arr[obIndex] = obTop
531534
bottom_arr[obIndex] = obBtm
532-
obVolume[obIndex] = _volume[close_index] + _volume[close_index - 1] + _volume[close_index - 2]
533-
lowVolume[obIndex] = _volume[close_index] + _volume[close_index - 1]
534-
highVolume[obIndex] = _volume[close_index - 2]
535+
vol_cur = _volume[close_index]
536+
vol_prev1 = _volume[close_index - 1] if close_index >= 1 else 0.0
537+
vol_prev2 = _volume[close_index - 2] if close_index >= 2 else 0.0
538+
obVolume[obIndex] = vol_cur + vol_prev1 + vol_prev2
539+
lowVolume[obIndex] = vol_cur + vol_prev1
540+
highVolume[obIndex] = vol_prev2
535541
max_vol = max(highVolume[obIndex], lowVolume[obIndex])
536542
percentage[obIndex] = (min(highVolume[obIndex], lowVolume[obIndex]) / max_vol * 100.0) if max_vol != 0 else 100.0
537543
active_bearish.append(obIndex)

tests/unit_tests.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import pandas as pd
77
import unittest
88

9-
sys.path.append(os.path.abspath("../"))
9+
BASE_DIR = os.path.dirname(__file__)
10+
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "..")))
1011
from smartmoneyconcepts.smc import smc
1112

1213
# define and import test data
1314
test_instrument = "EURUSD"
1415
instrument_data = f"{test_instrument}_15M.csv"
15-
df = pd.read_csv(os.path.join("test_data", test_instrument, instrument_data))
16+
TEST_DATA_DIR = os.path.join(BASE_DIR, "test_data", test_instrument)
17+
df = pd.read_csv(os.path.join(TEST_DATA_DIR, instrument_data))
1618
df = df.set_index("Date")
1719
df.index = pd.to_datetime(df.index)
1820

@@ -24,7 +26,7 @@ def test_fvg(self):
2426
start_time = time.time()
2527
fvg_data = smc.fvg(df)
2628
fvg_result_data = pd.read_csv(
27-
os.path.join("test_data", test_instrument, "fvg_result_data.csv")
29+
os.path.join(TEST_DATA_DIR, "fvg_result_data.csv")
2830
)
2931
print("fvg test time: ", time.time() - start_time)
3032
pd.testing.assert_frame_equal(fvg_data, fvg_result_data, check_dtype=False)
@@ -33,7 +35,7 @@ def test_fvg_consecutive(self):
3335
start_time = time.time()
3436
fvg_data = smc.fvg(df, join_consecutive=True)
3537
fvg_consecutive_result_data = pd.read_csv(
36-
os.path.join("test_data", test_instrument, "fvg_consecutive_result_data.csv")
38+
os.path.join(TEST_DATA_DIR, "fvg_consecutive_result_data.csv")
3739
)
3840
print("fvg consecutive test time: ", time.time() - start_time)
3941
pd.testing.assert_frame_equal(fvg_data, fvg_consecutive_result_data, check_dtype=False)
@@ -42,9 +44,7 @@ def test_swing_highs_lows(self):
4244
start_time = time.time()
4345
swing_highs_lows_data = smc.swing_highs_lows(df, swing_length=5)
4446
swing_highs_lows_result_data = pd.read_csv(
45-
os.path.join(
46-
"test_data", test_instrument, "swing_highs_lows_result_data.csv"
47-
)
47+
os.path.join(TEST_DATA_DIR, "swing_highs_lows_result_data.csv")
4848
)
4949
print("swing_highs_lows test time: ", time.time() - start_time)
5050
pd.testing.assert_frame_equal(swing_highs_lows_data, swing_highs_lows_result_data, check_dtype=False)
@@ -54,7 +54,7 @@ def test_bos_choch(self):
5454
swing_highs_lows_data = smc.swing_highs_lows(df, swing_length=5)
5555
bos_choch_data = smc.bos_choch(df, swing_highs_lows_data)
5656
bos_choch_result_data = pd.read_csv(
57-
os.path.join("test_data", test_instrument, "bos_choch_result_data.csv")
57+
os.path.join(TEST_DATA_DIR, "bos_choch_result_data.csv")
5858
)
5959
print("bos_choch test time: ", time.time() - start_time)
6060
pd.testing.assert_frame_equal(
@@ -66,17 +66,32 @@ def test_ob(self):
6666
swing_highs_lows_data = smc.swing_highs_lows(df, swing_length=5)
6767
ob_data = smc.ob(df, swing_highs_lows_data)
6868
ob_result_data = pd.read_csv(
69-
os.path.join("test_data", test_instrument, "ob_result_data.csv")
69+
os.path.join(TEST_DATA_DIR, "ob_result_data.csv")
7070
)
7171
print("ob test time: ", time.time() - start_time)
7272
pd.testing.assert_frame_equal(ob_data, ob_result_data, check_dtype=False)
7373

74+
def test_ob_early_data(self):
75+
"""Ensure early candles do not cause index errors in OB calculation."""
76+
short_df = pd.DataFrame(
77+
{
78+
"open": [1.0, 1.1, 1.2],
79+
"high": [1.05, 1.15, 1.25],
80+
"low": [0.95, 1.05, 1.15],
81+
"close": [1.02, 1.14, 1.24],
82+
"volume": [5, 6, 7],
83+
}
84+
)
85+
swing = smc.swing_highs_lows(short_df, swing_length=1)
86+
ob_df = smc.ob(short_df, swing)
87+
self.assertEqual(len(ob_df), len(short_df))
88+
7489
def test_liquidity(self):
7590
start_time = time.time()
7691
swing_highs_lows_data = smc.swing_highs_lows(df, swing_length=5)
7792
liquidity_data = smc.liquidity(df, swing_highs_lows_data)
7893
liquidity_result_data = pd.read_csv(
79-
os.path.join("test_data", test_instrument, "liquidity_result_data.csv")
94+
os.path.join(TEST_DATA_DIR, "liquidity_result_data.csv")
8095
)
8196
print("liquidity test time: ", time.time() - start_time)
8297
pd.testing.assert_frame_equal(liquidity_data, liquidity_result_data, check_dtype=False)
@@ -86,9 +101,7 @@ def test_previous_high_low(self):
86101
start_time = time.time()
87102
previous_high_low_data = smc.previous_high_low(df, time_frame="4h")
88103
previous_high_low_result_data = pd.read_csv(
89-
os.path.join(
90-
"test_data", test_instrument, "previous_high_low_result_data_4h.csv"
91-
)
104+
os.path.join(TEST_DATA_DIR, "previous_high_low_result_data_4h.csv")
92105
)
93106
print("previous_high_low test time: ", time.time() - start_time)
94107
pd.testing.assert_frame_equal(previous_high_low_data, previous_high_low_result_data, check_dtype=False)
@@ -97,9 +110,7 @@ def test_previous_high_low(self):
97110
start_time = time.time()
98111
previous_high_low_data = smc.previous_high_low(df, time_frame="1D")
99112
previous_high_low_result_data = pd.read_csv(
100-
os.path.join(
101-
"test_data", test_instrument, "previous_high_low_result_data_1D.csv"
102-
)
113+
os.path.join(TEST_DATA_DIR, "previous_high_low_result_data_1D.csv")
103114
)
104115
print("previous_high_low test time: ", time.time() - start_time)
105116
pd.testing.assert_frame_equal(previous_high_low_data, previous_high_low_result_data, check_dtype=False)
@@ -108,9 +119,7 @@ def test_previous_high_low(self):
108119
start_time = time.time()
109120
previous_high_low_data = smc.previous_high_low(df, time_frame="W")
110121
previous_high_low_result_data = pd.read_csv(
111-
os.path.join(
112-
"test_data", test_instrument, "previous_high_low_result_data_W.csv"
113-
)
122+
os.path.join(TEST_DATA_DIR, "previous_high_low_result_data_W.csv")
114123
)
115124
print("previous_high_low test time: ", time.time() - start_time)
116125
pd.testing.assert_frame_equal(previous_high_low_data, previous_high_low_result_data, check_dtype=False)
@@ -119,7 +128,7 @@ def test_sessions(self):
119128
start_time = time.time()
120129
sessions = smc.sessions(df, session="London")
121130
sessions_result_data = pd.read_csv(
122-
os.path.join("test_data", test_instrument, "sessions_result_data.csv")
131+
os.path.join(TEST_DATA_DIR, "sessions_result_data.csv")
123132
)
124133
print("sessions test time: ", time.time() - start_time)
125134
pd.testing.assert_frame_equal(sessions, sessions_result_data, check_dtype=False)
@@ -129,7 +138,7 @@ def test_retracements(self):
129138
swing_highs_lows_data = smc.swing_highs_lows(df, swing_length=5)
130139
retracements = smc.retracements(df, swing_highs_lows_data)
131140
retracements_result_data = pd.read_csv(
132-
os.path.join("test_data", test_instrument, "retracements_result_data.csv")
141+
os.path.join(TEST_DATA_DIR, "retracements_result_data.csv")
133142
)
134143
print("retracements test time: ", time.time() - start_time)
135144
pd.testing.assert_frame_equal(retracements, retracements_result_data, check_dtype=False)

0 commit comments

Comments
 (0)