In [21]:
import tool
import scipy
import pandas as pd
import numpy as np
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, SelectMultiple, fixed
import warnings
warnings.filterwarnings('ignore')

In [22]:
"""
Data initialisation
"""

asset_index = pd.read_csv("data/AIDX.csv", encoding='gbk')

index_list = asset_index['S_IRDCODE'].drop_duplicates().sample(np.random.randint(*(5, 10))).tolist()

asset_index['TRADE_DT'] = pd.to_datetime(asset_index['TRADE_DT'], format='%Y%m%d')
asset_index.sort_values(by='TRADE_DT', inplace=True)
asset_index.set_index('TRADE_DT', inplace=True)
asset_index = asset_index.pivot(columns='S_IRDCODE', values='CLOSE').ffill()[index_list].dropna()

print(index_list)
print(asset_index)

['000841.CSI', 'n00905.CSI', 'h30253.CSI', 'h20556.CSI', 'h40098.SH', 'h40037.SH', 'h20496.CSI', '000904.CSI']
S_IRDCODE   000841.CSI  n00905.CSI  h30253.CSI  h20556.CSI  h40098.SH  \
TRADE_DT                                                                
2023-01-03   9611.9678   6851.9274   3720.1662   1922.7936  7802.9055   
2023-01-04   9679.3683   6847.9662   3741.5124   1926.9373  7836.4019   
2023-01-05   9960.3645   6914.7039   3790.4024   1946.0585  7920.1189   
2023-01-06   9864.4389   6917.4087   3813.2324   1921.8768  7921.0884   
2023-01-09   9952.3246   6948.7055   3822.1741   1931.0004  7972.3173   
...                ...         ...         ...         ...        ...   
2023-12-20   9260.6286   6640.8475   3865.8508   1718.5113  8311.9133   
2023-12-21   9260.6286   6640.8475   3865.8508   1718.5113  8311.9133   
2023-12-22   9260.6286   6640.8475   3865.8508   1718.5113  8311.9133   
2023-12-25   9260.6286   6640.8475   3865.8508   1718.5113  8311.9133   
2023-12-26   

In [23]:
"""
Parameters
"""
BACKTEST_DAY = 30
MODEL_TYPE = 'BL' # MVO, RP
TARGET_RETURN = 0.0 # target return
RISK_FREE_RATE = 0.02 # risk-free rate
REBALANCE_DAYS = 20

n = len(index_list)
index_min_weight = [0 for _ in range(n)]
index_max_weight = [1 for _ in range(n)]
WEIGHT_CONSTRAINTS = list(zip(index_min_weight, index_max_weight))

In [24]:
"""
Rebalancing
"""
def rebalance(asset_index, T):
    predicts = []
    actuals = []
    
    for i in range(T, len(asset_index), T):
        
        if i+T >= len(asset_index):
            break
        
        historical_data = asset_index[i-T:i]
        future_data = asset_index[i:i+T]
        
        predict, actual = tool.evaluate(historical_data, future_data, WEIGHT_CONSTRAINTS, MODEL_TYPE, TARGET_RETURN, RISK_FREE_RATE)
        
        predicts.append(predict)
        actuals.append(actual)
    
    return predicts, actuals

predicts, actuals = rebalance(asset_index, T=REBALANCE_DAYS)
# print() rather than display(): the Output cell defines its own `display`,
# which shadows IPython's once that cell has run.
print(predicts, actuals, sep='\n')

[(1.2832922730595708, 177.08622336045525, 0.007133769353069133), (-0.06263304055164529, 74.36444513306091, -0.0011111901716443833), (-0.14858038831763312, 95.56780915453639, -0.0017639871606246922), (-0.01545741624430215, 96.29235416590824, -0.00036822670451290774), (-0.023819535264056918, 63.78792331618421, -0.0006869566053569747), (0.021919215328120578, 63.3354038590306, 3.0302409255845124e-05), (-0.06170278668279294, 57.072939575547665, -0.0014315503510142957), (-0.04892255849817917, 81.72130333130745, -0.0008433854538364285), (-0.200244083580997, 73.63471502212865, -0.002991036001358998), (-0.01136991023143305, 33.38749298529466, -0.0009395707022761417), (0.07320893509711843, 44.00956064099366, 0.0012090312723448506)]
[(-0.06263304055164529, 74.36444513306091, -0.0011111901716443833), (-0.14858038831763312, 95.56780915453639, -0.0017639871606246922), (-0.01545741624430215, 96.29235416590824, -0.00036822670451290774), (-0.023819535264056918, 63.78792331618421, -0.0006869566053569747

In [25]:
"""
Output
"""

def display(L1, L2, normalize=False, lines_to_show=None):
    line_names = ['r1', 'v1', 's1', 'r2', 'v2', 's2']
    line_styles = {
        'r1': 'r-', 'v1': 'r--', 's1': 'r:',
        'r2': 'b-', 'v2': 'b--', 's2': 'b:'
    }

    a1, b1, c1 = zip(*L1)
    a2, b2, c2 = zip(*L2)

    lines = {'r1': a1, 'v1': b1, 's1': c1, 'r2': a2, 'v2': b2, 's2': c2}

    def normalize_data(data):
        return (data - np.mean(data)) / np.std(data)

    if normalize:
        lines = {name: normalize_data(data) for name, data in lines.items()}

    plt.figure(figsize=(10, 6))
    for line in lines_to_show:
        plt.plot(lines[line], line_styles[line], label=line, marker='o')
    
    plt.title("R: Return\tV: Volatility\tS: Sharpe Ratio\n1: Predicted\t2: Actual\nButton 'normalize': Normalise each line")
    plt.legend()
    
    for pair in [('r1', 'r2'), ('v1', 'v2'), ('s1', 's2')]:
        if pair[0] in lines_to_show and pair[1] in lines_to_show:
            corr, _ = scipy.stats.spearmanr(lines[pair[0]], lines[pair[1]])
            print(f"Spearman correlation between {pair[0]} and {pair[1]}: {corr:.2f}")
    
    plt.show()


# Interactive comparison of predicted vs. actual series: toggle z-score
# normalisation and choose which lines to draw. interact() returns the wrapped
# function, so the trailing ';' keeps its repr out of the cell output.
interact(display,
         L1=fixed(predicts),
         L2=fixed(actuals),
         normalize=True,
         lines_to_show=SelectMultiple(options=['r1', 'r2', 'v1', 'v2', 's1', 's2'],
                                      value=['r1', 'r2', 'v1', 'v2', 's1', 's2'],
                                      description='Lines'));



interactive(children=(Checkbox(value=True, description='normalize'), SelectMultiple(description='Lines', index…

<function __main__.display(L1, L2, normalize=False, lines_to_show=None)>