In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

# Absolute path of the directory one level above the notebook's working directory.
PROJ_ROOT = Path('..').resolve()

# Put the project root at the front of the import path so the project-local
# imports below resolve against it. Skipping the insert when it is already the
# first entry keeps repeated executions of this cell from stacking duplicates.
if sys.path[:1] != [str(PROJ_ROOT)]:
    sys.path.insert(0, str(PROJ_ROOT))

from kalman_filter_bank.filter_bank import SinusoidalFilterBank, run_filter_bank
from util import extract_low_pass_components
from optimization.optimization_util import PositionErrorContext, position_error

In [2]:
# Seed NumPy's legacy global RNG so any stochastic step is repeatable.
np.random.seed(0)

# Default figure styling for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.grid': True,
})

# Input price file and the directory that receives grid-search results.
DATA_PATH = PROJ_ROOT.joinpath('data', 'btc_1m.csv')
OUT_DIR = PROJ_ROOT.joinpath('optimization', 'grid_search_outputs')
OUT_DIR.mkdir(parents=True, exist_ok=True)  # create on first use, no error if present

In [None]:
# Grid config

# Frequencies handed to SinusoidalFilterBank as `omegas`
# (presumably one filter per entry — confirm against the filter bank).
OMEGAS = np.asarray([0.02, 0.66, 2.04, 3.90])

# Cut-off passed as `max_freq` when extracting the low-pass "truth" signal.
MAX_FREQ_FOR_TRUTH = 4.5

# Sample spacing in days: one-minute bars, 24 * 60 of them per day.
MINUTES_PER_DAY = 24 * 60
dt = 1 / MINUTES_PER_DAY

# Loss definition used by the grid search: a context class plus the function
# that scores an instance of it.
CtxClass = PositionErrorContext
loss_fn = position_error

In [None]:
# Read in data

price_df = pd.read_csv(DATA_PATH)

# Fail fast on unusable input: an empty file or a zero/NaN first price would
# otherwise surface as an opaque IndexError or as a silent inf/NaN series that
# is fed straight into the filter bank.
open_prices = price_df['Open'].to_numpy(dtype=float)
if open_prices.size == 0:
    raise ValueError(f"No rows found in {DATA_PATH}")
base_price = open_prices[0]
if not np.isfinite(base_price) or base_price == 0:
    raise ValueError(f"First Open price must be finite and non-zero, got {base_price!r}")

# Measurement series: fractional change of the open price relative to the
# first sample (so the series starts at exactly 0).
rate = (open_prices - base_price) / base_price

# Short summary so the cell's output documents what was loaded.
print(f"Loaded {len(price_df):,} rows. dt = {dt:.8f} days/sample (~{dt * 24 * 60:.3f} min/sample)")
print("Measurement preview:", rate[:5])

Loaded 14,400 rows. dt = 0.00069444 days/sample (~1.000 min/sample)
Measurement preview: [ 0.00000000e+00  0.00000000e+00 -9.04169762e-05  2.74803326e-04
  3.21153656e-04]


In [5]:
# FFT low-pass truth

# Reference signal for the loss: the project helper is given the measurement
# series, its sample spacing and the cut-off configured above, and returns a
# dict whose 'truth' entry is used as the target.
truth_dict = extract_low_pass_components(signal=rate, dt=dt, max_freq=MAX_FREQ_FOR_TRUTH)
truth = truth_dict['truth']

In [93]:
import itertools

# Number of exponents of each kind per grid row. It must match the number of
# entries in OMEGAS and the Q/R split used by the grid-search loop.
N_FILTERS = 4

# Candidate base-10 exponents. np.linspace fixes both endpoints and the number
# of levels, whereas a float-step np.arange excludes its stop value and can
# gain or lose a point through rounding.
Q_EXPONENTS = np.linspace(0.0, 1.0, 3)   # 0, 0.5, 1
R_EXPONENTS = np.linspace(-1.0, 0.0, 3)  # -1, -0.5, 0

# Every per-filter combination of exponents, one combination per row:
# each array has shape (3 ** N_FILTERS, N_FILTERS).
rgrid = np.array(list(itertools.product(R_EXPONENTS, repeat=N_FILTERS)))
qgrid = np.array(list(itertools.product(Q_EXPONENTS, repeat=N_FILTERS)))

# Cross product of the two grids with the Q combination varying slowest.
# Each row is [q_1 .. q_N, r_1 .. r_N]; np.repeat holds a Q row fixed while
# np.tile cycles through all R rows, matching a q-outer / r-inner nested loop.
full_grid = np.hstack([
    np.repeat(qgrid, len(rgrid), axis=0),
    np.tile(rgrid, (len(qgrid), 1)),
])

print(full_grid.shape)

(6561, 8)


In [94]:
# Grid search

# Each grid row holds n_filters Q exponents followed by n_filters R exponents.
# Check the layout up front so a mismatch fails immediately instead of part-way
# through a long run.
n_filters = len(OMEGAS)
if full_grid.shape[1] != 2 * n_filters:
    raise ValueError(
        f"full_grid has {full_grid.shape[1]} columns; expected {2 * n_filters} "
        f"(one Q and one R exponent per entry in OMEGAS)"
    )

# Result column names, in the same order as the columns of a grid row.
exp_columns = (
    [f'q_exp_{k}' for k in range(1, n_filters + 1)]
    + [f'r_exp_{k}' for k in range(1, n_filters + 1)]
)

results = []

for row in tqdm(full_grid):
    # Exponents are base-10: convert to the values passed to the filter bank
    # as `sigma_xi` and `rho` (presumably process / measurement noise
    # parameters — confirm against SinusoidalFilterBank).
    Q = 10.0 ** row[:n_filters]
    R = 10.0 ** row[n_filters:]
    bank = SinusoidalFilterBank(
        dim_x=2, dim_z=1,
        omegas=OMEGAS, dt=dt,
        sigma_xi=Q,
        rho=R
    )
    out = run_filter_bank(bank, rate, verbose=False)

    # Compute loss for this filter bank: score the filter output 'x' against
    # the low-pass truth signal.
    ctx = CtxClass(
        filter_state=out['x'],
        truth_position=truth
    )
    loss_val = float(loss_fn(ctx))

    # One record per grid point: the exponents that were tried plus the loss.
    run_data = dict(zip(exp_columns, row))
    run_data['loss'] = loss_val
    results.append(run_data)

# Build the frame once from the list of records, best (lowest) loss first,
# and persist it so the search does not have to be repeated to inspect results.
results_df = pd.DataFrame(results).sort_values(by='loss', ascending=True)
results_df.to_csv(OUT_DIR / 'grid_search_results.csv', index=False)
print("Grid search complete\n")
print(results_df.head(10))

100%|██████████| 6561/6561 [1:35:24<00:00,  1.15it/s]

Grid search complete

      q_exp_1  q_exp_2  q_exp_3  q_exp_4  r_exp_1  r_exp_2  r_exp_3  r_exp_4  \
2160      0.0      1.0      1.0      1.0      0.0     -1.0     -1.0     -1.0   
2133      0.0      1.0      1.0      1.0     -0.5     -1.0     -1.0     -1.0   
5040      1.0      0.0      1.0      1.0     -1.0      0.0     -1.0     -1.0   
5031      1.0      0.0      1.0      1.0     -1.0     -0.5     -1.0     -1.0   
2106      0.0      1.0      1.0      1.0     -1.0     -1.0     -1.0     -1.0   
4347      0.5      1.0      1.0      1.0      0.0     -1.0     -1.0     -1.0   
6480      1.0      1.0      1.0      1.0     -1.0     -1.0     -1.0     -1.0   
5769      1.0      0.5      1.0      1.0     -1.0      0.0     -1.0     -1.0   
5022      1.0      0.0      1.0      1.0     -1.0     -1.0     -1.0     -1.0   
4320      0.5      1.0      1.0      1.0     -0.5     -1.0     -1.0     -1.0   

          loss  
2160  0.031376  
2133  0.031502  
5040  0.031557  
5031  0.031691  
2106  0.0316


