In [21]:
from vectorbtpro import *
import numpy as np
import pandas as pd
from numba import njit

@njit
def rolling_mean_nb(arr, window):
    """Trailing simple moving average over the last `window` elements.

    Args:
        arr: NumPy array, iterated along its first axis. Integer arrays are
            accepted; the result is always floating point.
        window: Number of trailing elements in each average; must be >= 1.

    Returns:
        A float64 array with the same shape as `arr`. Positions that do not
        yet have a full window behind them (index < window - 1) hold NaN.

    Raises:
        ValueError: If `window` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    # Pre-fill a float buffer with NaN instead of using np.empty_like(arr):
    # an integer `arr` would yield an integer buffer, which can represent
    # neither the NaN warm-up values nor fractional means.
    out = np.full(arr.shape, np.nan)
    # Start at the first index that has a complete window; earlier slots keep NaN.
    for end in range(window - 1, arr.shape[0]):
        out[end] = np.mean(arr[end - window + 1:end + 1])
    return out

@njit
def annualized_volatility_nb(returns, window, periods_per_year=365):
    """Trailing standard deviation of `returns`, scaled to an annual figure.

    Each output value is `np.std` (NumPy defaults) of the last `window`
    elements multiplied by `sqrt(periods_per_year)`.

    Args:
        returns: NumPy array, iterated along its first axis.
        window: Number of trailing elements in each estimate; must be >= 1.
        periods_per_year: Annualisation factor placed under the square root.
            Defaults to 365, the value that was previously hard-coded.

    Returns:
        A float64 array with the same shape as `returns`. Positions without a
        full window behind them (index < window - 1) hold NaN; a NaN inside a
        window propagates into that window's result.

    Raises:
        ValueError: If `window` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    # Loop-invariant scale factor, computed once.
    annualizer = np.sqrt(periods_per_year)
    # Float NaN-filled buffer: np.empty_like(returns) would inherit an integer
    # dtype from integer input and could not store NaN.
    out = np.full(returns.shape, np.nan)
    for end in range(window - 1, returns.shape[0]):
        out[end] = np.std(returns[end - window + 1:end + 1]) * annualizer
    return out

@njit
def determine_regime_nb(price, ma_short, ma_long, vol_short, avg_vol_threshold):
    """Label every position with a combined trend / volatility regime code.

    Trend is read from where `price` sits relative to both moving averages;
    volatility is "above average" when `vol_short` is strictly greater than
    `avg_vol_threshold`.

    Regime codes:
        -1  Unknown (a moving average, the short volatility, or the
            volatility threshold is NaN)
         0  Above Avg Vol Bull Trend     (price above both averages)
         1  Below Avg Vol Bull Trend
         2  Above Avg Vol Bear Trend     (price below both averages)
         3  Below Avg Vol Bear Trend
         4  Above Avg Vol Sideways       (anything else)
         5  Below Avg Vol Sideways

    Args:
        price: Array of prices.
        ma_short: Short moving average, aligned with `price`.
        ma_long: Long moving average, aligned with `price`.
        vol_short: Short-window volatility, aligned with `price`.
        avg_vol_threshold: Scalar volatility level separating "above" from
            "below" average.

    Returns:
        int32 array shaped like `price` holding the codes listed above.
    """
    regimes = np.empty_like(price, dtype=np.int32)

    # Every comparison against NaN is False, so a NaN threshold used to label
    # all positions as "below average vol". Report them as unknown instead.
    if np.isnan(avg_vol_threshold):
        regimes[:] = -1
        return regimes

    for i in range(len(price)):
        if np.isnan(ma_short[i]) or np.isnan(ma_long[i]) or np.isnan(vol_short[i]):
            regimes[i] = -1  # Unknown
            continue

        # Even base code per trend; the volatility state adds 0 or 1.
        if price[i] > ma_short[i] and price[i] > ma_long[i]:
            trend_base = 0  # Bull
        elif price[i] < ma_short[i] and price[i] < ma_long[i]:
            trend_base = 2  # Bear
        else:
            trend_base = 4  # Sideways

        if vol_short[i] > avg_vol_threshold:
            regimes[i] = trend_base  # Above average volatility
        else:
            regimes[i] = trend_base + 1  # Below (or equal to) average volatility
    return regimes

@njit
def calculate_regimes_nb(price, returns, ma_short_window, ma_long_window, vol_short_window, avg_vol_window):
    """Derive regime codes from prices and returns in one call.

    Builds the inputs that `determine_regime_nb` needs and returns its result:
    two rolling means of `price`, one rolling annualized volatility of
    `returns`, and a scalar volatility threshold.

    Args:
        price: Array fed to `rolling_mean_nb`.
        returns: Array fed to `annualized_volatility_nb`.
        ma_short_window: Window of the short moving average.
        ma_long_window: Window of the long moving average.
        vol_short_window: Window of the volatility compared to the threshold.
        avg_vol_window: Window of the volatility series whose NaN-ignoring
            mean becomes the threshold.

    Returns:
        Whatever `determine_regime_nb` returns for these inputs.
    """
    # Threshold: mean of the `avg_vol_window` rolling volatility, NaNs skipped.
    # NOTE: the mean is taken over the entire series, so the label at a given
    # position depends on observations that come after it.
    baseline_vol = annualized_volatility_nb(returns, avg_vol_window)
    vol_cutoff = np.nanmean(baseline_vol)

    recent_vol = annualized_volatility_nb(returns, vol_short_window)
    fast_ma = rolling_mean_nb(price, ma_short_window)
    slow_ma = rolling_mean_nb(price, ma_long_window)

    return determine_regime_nb(price, fast_ma, slow_ma, recent_vol, vol_cutoff)

# Indicator class wrapping `calculate_regimes_nb`: two inputs (price, returns),
# four window parameters, and a single `regimes` output.
# NOTE(review): the kernels above index their inputs like 1-D series; confirm
# how the factory hands over multi-column inputs before running on more than
# one column.
RegimeIndicator = vbt.IndicatorFactory(
    class_name="RegimeIndicator",
    input_names=["price", "returns"],
    param_names=[
        "ma_short_window",
        "ma_long_window",
        "vol_short_window",
        "avg_vol_window",
    ],
    output_names=["regimes"],
).with_apply_func(calculate_regimes_nb)

In [39]:
# Example data: BTC-USD bars fetched through vbt.YFData, up to the given end date
btc_data = vbt.YFData.fetch('BTC-USD', end='2024-01-01').get()
# Close-to-close percentage change; the first row has no previous close, so it is NaN
btc_data['Return'] = btc_data['Close'].pct_change()

# Run the indicator on plain NumPy arrays (the pandas index is not passed along)
regime_indicator = RegimeIndicator.run(
    btc_data['Close'].to_numpy(),
    btc_data['Return'].to_numpy(),
    ma_short_window=21,
    ma_long_window=88,
    vol_short_window=21,
    avg_vol_window=365,
)

# Attach the regimes positionally via .values, since the output does not carry
# the DataFrame's index
btc_data['Market Regime'] = regime_indicator.regimes.values

# Print the closing price next to its regime code
print(btc_data.loc[:, ['Close', 'Market Regime']])

                                  Close  Market Regime
Date                                                  
2014-09-17 00:00:00+00:00    457.334015             -1
2014-09-18 00:00:00+00:00    424.440002             -1
2014-09-19 00:00:00+00:00    394.795990             -1
2014-09-20 00:00:00+00:00    408.903992             -1
2014-09-21 00:00:00+00:00    398.821014             -1
...                                 ...            ...
2023-12-27 00:00:00+00:00  43442.855469              1
2023-12-28 00:00:00+00:00  42627.855469              5
2023-12-29 00:00:00+00:00  42099.402344              5
2023-12-30 00:00:00+00:00  42156.902344              5
2023-12-31 00:00:00+00:00  42265.187500              5

[3393 rows x 2 columns]
