In [1]:
import os
from typing import Dict, List

import numpy as np
import yaml

from src.multivariate_hawkes_training.lob_event_combinator import LOBEventCombinator
# NOTE(review): switches the process working directory to a machine-specific
# absolute path. The un-prefixed imports in a later cell
# (`from time_prediction_model...`) presumably rely on this so that the `src`
# folder is importable -- confirm, and consider a configurable project root.
os.chdir('C:\\Users\\Admin\\Desktop\\phd\\multivariate_hawkes\\src')

from src.conf.training.model.multivariate_hawkes_training_conf import MultivariateHawkesTrainingConf
from src.lob_data_loader.lob_data_loader import LOBDataLoader
from src.lob_period.lob_period_extractor import LOBPeriodExtractor
from src.events_extractor.multivariate_lob_events_extractor import MultivariateLOBEventsExtractor
import src.constants as CONST
from src.conf.events_conf.events_conf import EventsConf
from src.conf.testing.testing_conf import TestingConf

def get_conf(path: str) -> dict:
    """Read a YAML configuration file and return its parsed content.

    The value returned is whatever ``yaml.safe_load`` produces for the file,
    i.e. plain Python data rather than a typed configuration object. Callers
    in this notebook turn it into one themselves (for example with
    ``TestingConf.from_dict``).

    Args:
        path: Location of the YAML file to read.

    Returns:
        The parsed YAML document (a mapping for the configuration files used
        by this notebook).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Explicit encoding: configuration files are conventionally UTF-8, while
    # the platform default text encoding (e.g. cp1252 on Windows) may differ.
    with open(path, 'r', encoding='utf-8') as conf_file:
        return yaml.safe_load(conf_file)
def get_event_type_times_maps_with_combined_types(
    event_type_times_map: Dict[str, np.ndarray],
    combined_name_events_to_combine_map: Dict[str, List[str]]
) -> List[Dict[str, np.ndarray]]:
    """Add combined event types to an event-type -> event-times map.

    The single input map is wrapped in a one-element list and handed to a
    ``LOBEventCombinator``. For every entry of
    ``combined_name_events_to_combine_map`` the combinator is asked for the
    maps extended with that combination, and the result is fed back into the
    combinator so that later combinations are applied on top of earlier ones.

    Args:
        event_type_times_map: One mapping from event-type name to its array
            of event times (a single dict, not a list of dicts).
        combined_name_events_to_combine_map: Mapping from the name of the new
            combined event type to the event-type names it is built from.
            When it is empty or ``None`` nothing is combined.

    Returns:
        A list of event-type -> event-times maps. With nothing to combine
        this is simply ``[event_type_times_map]``.
    """
    # This starting value is also the result when there is nothing to combine.
    # Previously the result was only bound inside the loop, so an empty
    # mapping failed with UnboundLocalError at the return statement.
    event_type_times_maps = [event_type_times_map]
    if not combined_name_events_to_combine_map:
        return event_type_times_maps

    lob_event_combinator = LOBEventCombinator(event_type_times_maps)

    for combination_name, lob_events_to_combine in combined_name_events_to_combine_map.items():
        event_type_times_maps = lob_event_combinator.get_event_type_times_maps_with_new_combination(
            lob_events_to_combine,
            combination_name=combination_name,
        )
        # Feed the extended maps back so the next combination builds on them.
        lob_event_combinator.event_type_times_maps = event_type_times_maps

    return event_type_times_maps

def get_event_type_times_maps_filtered(
    event_type_times_map: List[Dict[str, np.ndarray]],
    events_to_compute: List[str]
) -> List[Dict[str, np.ndarray]]:
    """Restrict every map to the requested event types.

    Args:
        event_type_times_map: Sequence of mappings from event-type name to
            event times (despite the singular name, one mapping per element).
        events_to_compute: Event-type names to keep.

    Returns:
        A new list with one new dict per input mapping, holding only the
        entries whose key is in ``events_to_compute``. Values are passed
        through as-is (not copied) and the original key order is preserved.
    """
    filtered_maps = []
    for single_map in event_type_times_map:
        kept_entries = {}
        for event_name, event_times in single_map.items():
            if event_name in events_to_compute:
                kept_entries[event_name] = event_times
        filtered_maps.append(kept_entries)
    return filtered_maps


In [2]:
from time_prediction_model.period_for_simulation import PeriodForSimulation
from time_prediction_model.time_prediction_model_factory.time_prediction_model_factory import TimePredictionModelFactory
from time_prediction_tester.every_time_prediction_tester import EveryTimePredictionTester


# --- Configuration -----------------------------------------------------------
# Every absolute location is derived from this single root, so moving the
# project only requires editing one line.
PROJECT_ROOT = 'C:\\Users\\Admin\\Desktop\\phd\\multivariate_hawkes'
SYMBOL = 'BTC_USD'
MODEL_NAME = 'univariate_hawkes'
# Numeric id embedded in the order-book file name. The same value is passed to
# the model factory below, so it is defined once to keep the two in sync.
ORDERBOOK_FILE_ID = 1705164940479

# The previous code wrapped these absolute paths in
# os.path.join(CONST.<folder>, <absolute path>). Joining onto an absolute path
# discards the first component, so the CONST folders were never actually used;
# the paths are now built directly and resolve to the same locations.
TESTING_CONF_PATH = os.path.join(
    PROJECT_ROOT, 'conf', 'conf_testing', 'btc_usd_testing_conf.yml'
)
EVENTS_CONF_PATH = os.path.join(
    PROJECT_ROOT, 'conf', 'conf_events', 'mid_price_change_events_conf.yml'
)
ORDERBOOK_PATH = os.path.join(
    PROJECT_ROOT, 'data', 'orderbook_changes', SYMBOL,
    f'orderbook_changes_{ORDERBOOK_FILE_ID}.tsv'
)
TRAINED_PARAMS_FOLDER = os.path.join(
    PROJECT_ROOT, 'data', 'trained_params', MODEL_NAME, SYMBOL
)

# Window boundaries relative to the simulation start.
# NOTE(review): units are presumably seconds -- confirm. The warm-up length is
# hard-coded here while the tester below receives
# testing_conf.seconds_warm_up_period; verify the two are meant to agree.
WARMUP_DURATION = 150
SIMULATION_DURATION = 120
# Factor applied to the 'Timestamp' column before event extraction and undone
# afterwards to obtain the "..._in_seconds" maps.
MS_PER_SECOND = 1000

# --- Load configuration ------------------------------------------------------
testing_conf_map = get_conf(TESTING_CONF_PATH)
testing_conf = TestingConf.from_dict(testing_conf_map)

events_conf_map = get_conf(EVENTS_CONF_PATH)
events_conf = EventsConf.from_dict(events_conf_map)

# --- Load the order book and cut out the warm-up + simulation window ---------
lob_df_loader = LOBDataLoader()
# NOTE(review): meaning of the second argument (10) is not visible here --
# presumably the number of book levels to load; confirm against LOBDataLoader.
lob_df = lob_df_loader.get_lob_dataframe(ORDERBOOK_PATH, 10)

lob_period_extractor = LOBPeriodExtractor(lob_df)
start_simulation_time = 1705163925

start_warmup_time = start_simulation_time - WARMUP_DURATION
end_simulation_time = start_simulation_time + SIMULATION_DURATION

lob_period = lob_period_extractor.get_lob_period(start_warmup_time, end_simulation_time)
lob_df_for_events = lob_period.get_lob_df_with_timestamp_column()

# Rescale timestamps before handing the frame to the events extractor.
# NOTE(review): this writes into the frame returned by lob_period; if that
# method returns a cached object rather than a fresh frame, re-running the
# cell would rescale twice -- confirm.
lob_df_for_events['Timestamp'] = lob_df_for_events['Timestamp'] * MS_PER_SECOND

# --- Extract, combine and filter events --------------------------------------
lob_events_extractor = MultivariateLOBEventsExtractor(
    lob_df_for_events,
    events_conf.num_levels_in_a_side,
    events_conf.num_levels_for_which_save_events
)

event_type_times_map = lob_events_extractor.get_events()
# Re-key by the event type's `.name` so the maps use plain strings.
event_type_times_map = {
    key.name: value for key, value in event_type_times_map.items()
}

event_type_times_maps = get_event_type_times_maps_with_combined_types(
    event_type_times_map,
    events_conf.combined_event_types_map
)

event_type_times_maps = get_event_type_times_maps_filtered(
    event_type_times_maps,
    events_conf.events_to_compute
)

# Undo the rescaling applied above so event times are back in seconds.
event_type_times_maps_formatted_in_seconds = [
    {
        event_type: (times / MS_PER_SECOND)
        for event_type, times in single_map.items()
    }
    for single_map in event_type_times_maps
]

# --- Build the model and the tester ------------------------------------------
# NOTE(review): the meaning of the positional argument 30 is not visible in
# this notebook -- confirm against TimePredictionModelFactory.
time_prediction_model_factory = TimePredictionModelFactory(
    MODEL_NAME,
    30,
    TRAINED_PARAMS_FOLDER,
    ORDERBOOK_FILE_ID,
    start_simulation_time,
)

time_prediction_model = time_prediction_model_factory.get_model()

period_for_simulation = PeriodForSimulation(
    event_type_event_times_map=event_type_times_maps_formatted_in_seconds[0],
    event_types_to_predict=['MID_PRICE_CHANGE'],
    event_types_order=events_conf.events_to_compute
)

time_prediction_tester = EveryTimePredictionTester(
    time_prediction_model,
    period_for_simulation,
    testing_conf.seconds_warm_up_period,
)


In [51]:
# Show the best ask/bid prices for the first ten rows whose Timestamp is >= 120.
df_mio = lob_period.get_lob_df_with_timestamp_column()
rows_from_120 = df_mio['Timestamp'] >= 120
df_mio.loc[rows_from_120, ['Timestamp', 'AskPrice1', 'BidPrice1']].head(10)

Unnamed: 0,Timestamp,AskPrice1,BidPrice1
584,120.173,42706,42705
585,120.174,42706,42705
586,120.413,42706,42705
587,120.643,42706,42705
588,120.644,42706,42705
589,120.843,42706,42705
590,121.107,42706,42705
591,121.341,42706,42705
592,121.597,42706,42705
593,121.826,42706,42705


In [3]:
# Display the tester's predicted event times (a dict keyed by event type).
time_prediction_tester.get_predicted_event_times()

{'MID_PRICE_CHANGE': array([  1.02750062,   1.02750062,   1.99116184,   1.99116184,
          1.5495146 ,   1.63169641,   1.63169641,   1.52576133,
          1.52576133,   1.57808795,   1.57808795,   2.39960844,
          2.39960844,   2.76425344,   3.57383118,   3.57383118,
         10.82165586,   4.69242562,   4.69242562,   4.69636523,
          4.69636523,   4.89041673,   4.89041673,   5.13381426,
          5.4285627 ,   5.60927219,   5.83472651,   6.44090007,
          6.44090007,   6.76141451,   7.0035912 ,   7.54079355,
          7.6467832 ,   7.87796948,   8.08544018,   8.49281967,
          8.5493804 ,   8.77827004,  18.06524533,  10.82283437,
         11.94768772,  12.44569404,  13.11164001,  13.50623888,
         13.51713711,  13.30198582,  13.30198582,  13.57404342,
         13.58522048,  13.58522048,  13.70288102,  13.9167903 ,
         14.1273272 ,  14.1273272 ,  14.40482925,  14.40482925,
         14.4956865 ,  14.4956865 ,  21.1450764 ,  21.1450764 ,
         16.47681227