In [11]:
import tool
import scipy
import pandas as pd
import numpy as np
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import interact, SelectMultiple, fixed
import warnings
warnings.filterwarnings('ignore')

In [12]:
"""
Data initialisation
"""

asset_index = pd.read_csv("data/AIDX.csv", encoding='gbk')

index_list = asset_index['S_IRDCODE'].drop_duplicates().sample(np.random.randint(*(5, 10))).tolist()

asset_index['TRADE_DT'] = pd.to_datetime(asset_index['TRADE_DT'], format='%Y%m%d')
asset_index.sort_values(by='TRADE_DT', inplace=True)
asset_index.set_index('TRADE_DT', inplace=True)
asset_index = asset_index.pivot(columns='S_IRDCODE', values='CLOSE').ffill()[index_list].dropna()

print(index_list)
print(asset_index)

['h01051.CSI', '931698.CSI', 'h00060.SH', 'h30136.CSI', '932093.CSI', 'L11621.CSI', '930660.CSI']
S_IRDCODE   h01051.CSI  931698.CSI  h00060.SH  h30136.CSI  932093.CSI  \
TRADE_DT                                                                
2023-01-03    810.4735   8576.9398  5206.2104   1596.3014    825.7359   
2023-01-04    809.6737   8612.3302  5251.7832   1646.3276    818.8669   
2023-01-05    814.7967   8759.5356  5282.9599   1658.7949    829.2426   
2023-01-06    825.3162   8702.6039  5280.9130   1663.0547    833.4055   
2023-01-09    829.1884   8768.6148  5306.7821   1700.2410    837.8665   
...                ...         ...        ...         ...         ...   
2023-12-20    798.2527   7638.7525  5578.6980   1485.8548    756.5774   
2023-12-21    798.2527   7638.7525  5578.6980   1485.8548    756.5774   
2023-12-22    798.2527   7638.7525  5578.6980   1485.8548    756.5774   
2023-12-25    798.2527   7638.7525  5578.6980   1485.8548    756.5774   
2023-12-26    798.2527   7

In [13]:
"""
Parameters
"""
BACKTEST_DAY = 30
MODEL_TYPE = 'RP' # MVO, RP
TARGET_RETURN = 0.0 # target return
RISK_FREE_RATE = 0.02 # risk-free rate
REBALANCE_DAYS = 20

n = len(index_list)
index_min_weight = [0 for _ in range(n)]
index_max_weight = [1 for _ in range(n)]
WEIGHT_CONSTRAINTS = list(zip(index_min_weight, index_max_weight))

In [14]:
"""
Rebalancing
"""
def rebalance(asset_index, T):
    predicts = []
    actuals = []
    
    for i in range(T, len(asset_index), T):
        
        if i+T >= len(asset_index):
            break
        
        historical_data = asset_index[i-T:i]
        future_data = asset_index[i:i+T]
        
        predict, actual = tool.evaluate(historical_data, future_data, WEIGHT_CONSTRAINTS, MODEL_TYPE, TARGET_RETURN, RISK_FREE_RATE)
        
        predicts.append(predict)
        actuals.append(actual)
    
    return predicts, actuals

# Walk-forward backtest over the selected indices, one window per REBALANCE_DAYS rows.
predicts, actuals = rebalance(asset_index, REBALANCE_DAYS)
print(predicts, actuals, sep='\n')

[(1.32813430375707, 56.37762958284468, 0.02320307386877305), (-0.013009634729776204, 30.784629345633874, -0.0010722765039384145), (0.7606891869292984, 66.25308641187765, 0.011179693310054004), (0.3067936293836631, 54.02516895854805, 0.005308518879482109), (0.34290190055774916, 24.851160305823534, 0.012993433569461199), (2.4181721718034033, 46.3761272966235, 0.0517113504640998), (-0.048657292330602155, 28.888082951019435, -0.0023766648845135394), (-0.17601729662780916, 45.964446362922885, -0.004264541665097179), (0.02094809411058799, 43.25905350313749, 2.1916663306542916e-05), (-0.14372402575829213, 24.036658392042654, -0.006811430403008641), (0.0802316717625676, 34.49104805522562, 0.001746298682084904)]
[(0.08188261887054518, 21.692814471501936, 0.002852678196821394), (0.4532695492189157, 72.24705063264201, 0.005997055179760649), (0.5030766879645534, 58.525529564251165, 0.008254119041062527), (0.30869079541496625, 29.039530502367377, 0.009941303816583102), (1.82954645224246, 45.6719918

In [15]:
"""
Output
"""

def display(L1, L2, normalize=False, lines_to_show=None):
    line_names = ['r1', 'v1', 's1', 'r2', 'v2', 's2']
    line_styles = {
        'r1': 'r-', 'v1': 'r--', 's1': 'r:',
        'r2': 'b-', 'v2': 'b--', 's2': 'b:'
    }

    a1, b1, c1 = zip(*L1)
    a2, b2, c2 = zip(*L2)

    lines = {'r1': a1, 'v1': b1, 's1': c1, 'r2': a2, 'v2': b2, 's2': c2}

    def normalize_data(data):
        return (data - np.mean(data)) / np.std(data)

    if normalize:
        lines = {name: normalize_data(data) for name, data in lines.items()}

    plt.figure(figsize=(10, 6))
    for line in lines_to_show:
        plt.plot(lines[line], line_styles[line], label=line, marker='o')
    
    plt.title("R: Return\tV: Volatility\tS: Sharpe Ratio\n1: Predicted\t2: Actual\nButton 'normalize': Normalise each line")
    plt.legend()
    
    for pair in [('r1', 'r2'), ('v1', 'v2'), ('s1', 's2')]:
        if pair[0] in lines_to_show and pair[1] in lines_to_show:
            corr, _ = scipy.stats.spearmanr(lines[pair[0]], lines[pair[1]])
            print(f"Spearman correlation between {pair[0]} and {pair[1]}: {corr:.2f}")
    
    plt.show()


# Interactive comparison: toggle normalisation and pick which series to plot.
line_options = ('r1', 'r2', 'v1', 'v2', 's1', 's2')
interact(display,
         L1=fixed(predicts),
         L2=fixed(actuals),
         normalize=True,
         lines_to_show=SelectMultiple(options=line_options,
                                      value=line_options,
                                      description='Lines'));  # ';' hides the returned function's repr



interactive(children=(Checkbox(value=True, description='normalize'), SelectMultiple(description='Lines', index…

<function __main__.display(L1, L2, normalize=False, lines_to_show=None)>