In [1]:
# Run from the project root; presumably required so the project-local modules
# imported in the next cell (lob, preproc, filtration_utils) resolve — confirm.
%cd /app

/app


In [2]:
import argparse
import os
import sys

# Must be set before `jax` is imported (below) to take effect: it stops JAX/XLA
# from preallocating a large fraction of GPU memory when the backend starts.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# force=True keeps this cell re-runnable: a second plain call in the same kernel
# raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as np

from pathlib import Path

import pandas as pd
import plotly.graph_objs as go
import yaml

from filtration_utils import (
    summary_table,
    build_zero_padded_series,
    plot_midprice_series_with_insertions,
    prepare_volatility_filtered_series,
    plot_midprice_series_with_mean_std,
)



In [None]:
# Experiment to analyse; its run parameters are read from the saved used_config.yaml.
# experiment_name = 'exp_99_20250708_173043'
experiment_name = 'exp_165_20250722_162019'

CONFIG_PATH = "/app/data_saved/" + experiment_name + "/used_config.yaml"

with open(CONFIG_PATH, 'r') as cfg_file:
    config = yaml.safe_load(cfg_file)

# Notebook-level names for the config entries used by the cells below.
num_insertions = config["num_insertions"]
num_coolings = config["num_coolings"]
midprice_step_size = config["midprice_step_size"]
hist_msgs = config["n_messages"]
n_gen_msgs = config["n_gen_msgs"]
Direction = config["DIRECTION_i"]  # 0 is reported as "Aggressive buy", anything else as "sell"

In [4]:
# Echo the run configuration loaded above.
print("Aggressive " + ("buy" if Direction == 0 else "sell") + "\n")
print("\n".join(
    f"{label}: {value}"
    for label, value in (
        ("num_insertions", num_insertions),
        ("num_coolings", num_coolings),
        ("midprice_step_size", midprice_step_size),
        ("hist_msgs", hist_msgs),
        ("n_gen_msgs", n_gen_msgs),
    )
))

Aggressive buy

num_insertions: 10
num_coolings: 30
midprice_step_size: 1
hist_msgs: 500
n_gen_msgs: 50


In [5]:
# Convert message counts into units of `midprice_step_size`.
hist_steps, gen_steps = hist_msgs // midprice_step_size, n_gen_msgs // midprice_step_size
gen_block = 1 + gen_steps  # 51 for the config printed above

# One row per sample id, with the concatenated series, plus zero-padded series for plotting.
merged = summary_table(experiment_name)
x, all_series = build_zero_padded_series(hist_msgs, n_gen_msgs, midprice_step_size, merged)
print(merged)

      id                                        merged_data
0    218  [893400, 893400, 893400, 893400, 893400, 89340...
1   1632  [886600, 886500, 886500, 886500, 886500, 88650...
2   2685  [876200, 876200, 876200, 876200, 876200, 87620...
3   3035  [861200, 861200, 861200, 861200, 861200, 86120...
4   3133  [862300, 862300, 862400, 862400, 862400, 86240...
5   3252  [867100, 867100, 867100, 867100, 867100, 86710...
6   3315  [864500, 864500, 864500, 864500, 864500, 86450...
7   3540  [870000, 870000, 870000, 870000, 870000, 87000...
8   3644  [866700, 866700, 866700, 866700, 866700, 86670...
9   5003  [884800, 884800, 884800, 884800, 884800, 88480...
10  5012  [886600, 886600, 886600, 886600, 886600, 88660...
11  5285  [889700, 889700, 889700, 889700, 889700, 88970...
12  6153  [910700, 910700, 910700, 910700, 910700, 91070...
13  6515  [915300, 915300, 915300, 915300, 915300, 91530...
14  6572  [919600, 919600, 919600, 919600, 919600, 91960...
15  6583  [918300, 918300, 918300, 91830

In [6]:
# Plot every sample's series; the helper also prints the insertion / cooling positions.
fig = plot_midprice_series_with_insertions(
    merged, all_series, x, hist_steps, gen_block,
    num_insertions, num_coolings, n_gen_msgs, midprice_step_size,
)
fig.show()

insertion positions: [551, 602, 653, 704, 755, 806, 857, 908, 959, 1010]
cooling positions:   [1061, 1112, 1163, 1214, 1265, 1316, 1367, 1418, 1469, 1520, 1571, 1622, 1673, 1724, 1775, 1826, 1877, 1928, 1979, 2030, 2081, 2132, 2183, 2234, 2285, 2336, 2387, 2438, 2489, 2540]


In [7]:
# x, all_series, merged, hist_steps, gen_block = prepare_volatility_filtered_series(merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=0.3)

In [8]:
# Mean / std summary of the same series; mean_series and std_series are the arrays
# the (commented-out) save cell below writes to disk.
fig, mean_series, std_series = plot_midprice_series_with_mean_std(
    merged=merged, all_series=all_series, x=x,
    hist_steps=hist_steps, gen_block=gen_block,
    num_insertions=num_insertions, num_coolings=num_coolings,
    n_gen_msgs=n_gen_msgs, midprice_step_size=midprice_step_size,
)
fig.show()

insertion positions: [551, 602, 653, 704, 755, 806, 857, 908, 959, 1010]
cooling positions:   [1061, 1112, 1163, 1214, 1265, 1316, 1367, 1418, 1469, 1520, 1571, 1622, 1673, 1724, 1775, 1826, 1877, 1928, 1979, 2030, 2081, 2132, 2183, 2234, 2285, 2336, 2387, 2438, 2489, 2540]


In [9]:
# np.save("/app/data_saved/exp_96_20250703_212149/exp_96_20250703_212149_mean.npy", mean_series)
# np.save("/app/data_saved/exp_96_20250703_212149/exp_96_20250703_212149_std.npy", std_series)

# ----------------------------------------------------

# ----------------------------------------------------

# ----------------------------------------------------

# ORDER-PLAYER (step through book states and messages of each sample)

In [10]:
# import os, glob, re
# import numpy as np
# import pandas as pd

# # PRINT EACH NAME + SHAPE WHEN LOADING
# def build_and_merge(folder, batch_prefix, inp_prefix):
#     # STEP 1: load every .npy (shape (batch_size, time, feat)) into a DataFrame
#     files   = glob.glob(os.path.join(folder, "*.npy"))
#     rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
#     rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")
#     rec = []
#     for f in files:
#         nm = os.path.basename(f)
#         m  = rx_iter.match(nm)
#         if m:
#             rng, itr = m.group(1).replace(" ", ""), int(m.group(2))
#         else:
#             m2 = rx_inp.match(nm)
#             if not m2:
#                 continue
#             rng, itr = m2.group(1).replace(" ", ""), 0
#         batch = np.load(f)  # shape (batch_size, time, features)

#         # ======== PRINT FILE NAME AND SHAPE =========
#         print(f"Loaded {nm} with shape {batch.shape}")
#         # ============================================

#         rec.append({"range": rng, "iteration": itr, "batch": batch})
#     df = pd.DataFrame(rec).sort_values(["range","iteration"]).reset_index(drop=True)

#     # STEP 2: parse the comma‐separated list of IDs into Python lists
#     df["ids"] = df["range"].str.split(",").apply(lambda L: [int(x) for x in L])

#     # explode each batch into one row per sample
#     rows = []
#     for _, r in df.iterrows():
#         for idx, sample_id in enumerate(r["ids"]):
#             single = r["batch"][idx]   # shape (time, features)
#             rows.append({
#                 "id":        sample_id,
#                 "iteration": r["iteration"],
#                 "data":      single
#             })
#     df_sorted = pd.DataFrame(rows).sort_values(["id","iteration"]).reset_index(drop=True)

#     # STEP 3: for each id, concatenate all its iterations end-to-end
#     merged = []
#     for id_val, grp in df_sorted.groupby("id", sort=True):
#         arrs = [row.data for _, row in grp.iterrows()]
#         big  = np.concatenate(arrs, axis=0)   # (sum_time, features)
#         merged.append({"id": id_val, "merged_data": big})
#     merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

#     return df, df_sorted, merged_df

# # — example usage —
# b_folder      = f"/app/data_saved/{experiment_name}/b_seq_gen_doubled"
# b_batch_pref  = "b_seq_gen_doubled_batch"
# b_inp_pref    = "b_seq_inp"
# _, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)

# m_folder      = f"/app/data_saved/{experiment_name}/msgs_decoded_doubled"
# m_batch_pref  = "msgs_decoded_doubled_batch"
# m_inp_pref    = "m_seq_raw_inp"
# _, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

# # build your dicts
# b_dict = { int(r.id): np.array(r.merged_data) for _, r in b_merged.iterrows() }
# m_dict = { int(r.id): np.array(r.merged_data) for _, r in m_merged.iterrows() }

# # prepend zero‐row so that your interactive plot indexing from t=1…T works
# for d in (b_dict, m_dict):
#     for key, arr in d.items():
#         zero = np.zeros((1, arr.shape[1]), dtype=arr.dtype)
#         d[key] = np.vstack([zero, arr])

In [11]:
import os, glob, re
import numpy as np
import pandas as pd

def build_and_merge(folder, batch_prefix, inp_prefix, first_gen_keep=51, later_gen_keep=50):
    """Load per-batch ``.npy`` dumps from *folder* and stitch them into one array per sample id.

    Two file-name patterns are recognised; any other ``.npy`` file in the folder is ignored:

    * ``{batch_prefix}_[id1, id2, ...]_iter_{k}.npy`` -> iteration ``k``;
    * ``{inp_prefix}_[id1, id2, ...].npy``            -> iteration 0.

    Row ``i`` of each loaded batch belongs to the ``i``-th id in the bracketed list
    (batches are presumably ``(batch_size, time, features)``). Iteration 0 is kept
    whole. For iteration ``k > 0`` only the trailing ``first_gen_keep`` (k == 1)
    or ``later_gen_keep`` (k >= 2) rows of each sample are kept. The pieces of
    each id are then concatenated along axis 0 in iteration order.

    Parameters
    ----------
    folder : str or path-like
        Directory searched (non-recursively) for ``*.npy`` files.
    batch_prefix : str
        File-name prefix of the per-iteration dumps.
    inp_prefix : str
        File-name prefix of the input dumps (treated as iteration 0).
    first_gen_keep : int, default 51
        Trailing rows kept from iteration 1.
    later_gen_keep : int, default 50
        Trailing rows kept from every iteration >= 2.

    Returns
    -------
    df : DataFrame
        One row per file: ``range`` (id list as text), ``iteration``, ``batch``, ``ids``.
    df_sorted : DataFrame
        One row per (id, file): ``id``, ``iteration``, ``data``; sorted by id, then iteration.
    merged_df : DataFrame
        One row per id: ``id``, ``merged_data``.
        All three frames are empty (with these columns) when no file matches.

    Raises
    ------
    ValueError
        If a file lists more ids than its batch has rows.
    """
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")

    # STEP 1: load every matching file into one record per file.
    rec = []
    for f in glob.glob(os.path.join(folder, "*.npy")):
        nm = os.path.basename(f)
        m = rx_iter.match(nm)
        if m:
            rng, itr = m.group(1).replace(" ", ""), int(m.group(2))
        else:
            m = rx_inp.match(nm)
            if not m:
                continue
            rng, itr = m.group(1).replace(" ", ""), 0

        batch = np.load(f)
        print(f"Loaded {nm} with shape {batch.shape}")
        rec.append({"range": rng, "iteration": itr, "batch": batch})

    if not rec:
        # An empty DataFrame has no columns, so sort_values() below would raise KeyError.
        return (
            pd.DataFrame(columns=["range", "iteration", "batch", "ids"]),
            pd.DataFrame(columns=["id", "iteration", "data"]),
            pd.DataFrame(columns=["id", "merged_data"]),
        )

    df = pd.DataFrame(rec).sort_values(["range", "iteration"]).reset_index(drop=True)

    # STEP 2: parse each file's comma-separated id list into a list of ints.
    df["ids"] = df["range"].str.split(",").apply(lambda ids: [int(x) for x in ids])

    # Explode each batch into one row per sample, trimming generated iterations.
    rows = []
    for r in df.itertuples(index=False):
        if len(r.ids) > r.batch.shape[0]:
            raise ValueError(
                f"{len(r.ids)} ids listed for a batch of {r.batch.shape[0]} rows "
                f"(range={r.range!r}, iteration={r.iteration})"
            )
        if r.iteration > 0:
            n_keep = first_gen_keep if r.iteration == 1 else later_gen_keep
        else:
            n_keep = None  # iteration 0 (history / input) is kept whole
        for idx, sample_id in enumerate(r.ids):
            single = r.batch[idx]
            if n_keep is not None:
                # Equivalent to single[-n_keep:] but also correct for n_keep == 0.
                single = single[max(len(single) - n_keep, 0):]
            rows.append({"id": sample_id, "iteration": r.iteration, "data": single})

    df_sorted = pd.DataFrame(rows).sort_values(["id", "iteration"]).reset_index(drop=True)

    # STEP 3: for each id, concatenate its pieces end-to-end in iteration order.
    merged = [
        {"id": id_val, "merged_data": np.concatenate(grp["data"].tolist(), axis=0)}
        for id_val, grp in df_sorted.groupby("id", sort=True)
    ]
    merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

    return df, df_sorted, merged_df

# Alternative book-state source (kept for reference):
# b_folder      = f"/app/data_saved/{experiment_name}/l2_book_states_halved"
# b_batch_pref  = "l2_book_states_halved_batch" # "b_seq_gen_doubled_batch" 
# b_inp_pref    = "l2_book_states_halved_batch"

b_folder      = f"/app/data_saved/{experiment_name}/b_seq_gen_doubled"
b_batch_pref  = "b_seq_gen_doubled_batch"
b_inp_pref    = "b_seq_inp"

m_folder      = f"/app/data_saved/{experiment_name}/msgs_decoded_doubled"
m_batch_pref  = "msgs_decoded_doubled_batch"
m_inp_pref    = "m_seq_raw_inp"

_, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)
_, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

# Per-sample dicts {id: array}. Each array gets an all-zero row prepended so that
# row t lines up with step t of the interactive viewer, whose slider starts at 1
# and reads rows t-1 and t. Built in a single comprehension pass.
b_dict = {
    int(row.id): np.pad(np.asarray(row.merged_data), ((1, 0), (0, 0)))
    for row in b_merged.itertuples(index=False)
}
m_dict = {
    int(row.id): np.pad(np.asarray(row.merged_data), ((1, 0), (0, 0)))
    for row in m_merged.itertuples(index=False)
}

Loaded b_seq_gen_doubled_batch_[218, 1632, 3644, 6583]_iter_0.npy with shape (4, 2510, 501)
Loaded b_seq_gen_doubled_batch_[3315, 2685, 3133, 8525]_iter_0.npy with shape (4, 2510, 501)
Loaded b_seq_gen_doubled_batch_[6749, 5012, 7549, 5285]_iter_0.npy with shape (4, 2510, 501)
Loaded b_seq_gen_doubled_batch_[6908, 5003, 3035, 3540]_iter_0.npy with shape (4, 2510, 501)
Loaded b_seq_gen_doubled_batch_[3252, 6153, 6515, 6572]_iter_0.npy with shape (4, 2510, 501)
Loaded msgs_decoded_doubled_batch_[6908, 5003, 3035, 3540]_iter_0.npy with shape (4, 2510, 14)
Loaded msgs_decoded_doubled_batch_[6749, 5012, 7549, 5285]_iter_0.npy with shape (4, 2510, 14)
Loaded msgs_decoded_doubled_batch_[3315, 2685, 3133, 8525]_iter_0.npy with shape (4, 2510, 14)
Loaded msgs_decoded_doubled_batch_[218, 1632, 3644, 6583]_iter_0.npy with shape (4, 2510, 14)
Loaded msgs_decoded_doubled_batch_[3252, 6153, 6515, 6572]_iter_0.npy with shape (4, 2510, 14)


In [12]:
# m_dict[14606].shape

In [13]:
# b_dict[14606].shape

In [14]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

def interactive_lob_plot(b_seq_inp, msg_seq_raw, level_slice=slice(240, 263)):
    """Interactive viewer that steps through one sample's book states and messages.

    Shows two bar charts side by side: the ``level_slice`` columns of book-state
    row ``t-1`` and of row ``t``. Below them it shows the decoded message row ``t``.
    Bars in the ``t`` panel are coloured by the change in absolute value versus
    ``t-1``: red if it grew, blue if it shrank, orange if unchanged (within 1e-8).

    Parameters
    ----------
    b_seq_inp : dict[int, ndarray] or DataFrame
        Book-state sequences keyed by sample id, rows indexed by step. A DataFrame
        with ``id`` / ``merged_data`` columns is converted to such a dict.
    msg_seq_raw : dict[int, ndarray] or DataFrame
        Message sequences in the same format. Columns 1..5 of a row are shown as
        event type, direction, absolute price, relative price and size.
    level_slice : slice, default slice(240, 263)
        Columns of each book-state row to plot; the x-axis is centred on the
        middle of the slice. NOTE(review): presumably the levels around the
        mid price — confirm.

    Raises
    ------
    ValueError
        If the two inputs share no sample id.
    """
    # allow DataFrame or dict
    if isinstance(b_seq_inp, pd.DataFrame):
        b_seq_inp = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
    if isinstance(msg_seq_raw, pd.DataFrame):
        msg_seq_raw = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}

    # Only ids present in both sources can be shown without a KeyError.
    sample_ids = sorted(set(b_seq_inp) & set(msg_seq_raw))
    if not sample_ids:
        raise ValueError("b_seq_inp and msg_seq_raw share no sample ids")

    et_map = {1: "Limit", 2: "PartialCancel", 3: "Delete", 4: "Execution"}
    dr_map = {1: "Buy", 0: "Sell"}

    # controls
    id_dd       = widgets.Dropdown(options=sample_ids, description="Sample ID:")
    time_slider = widgets.IntSlider(min=1, max=1, step=1, description="t:")
    btn_prev    = widgets.Button(description="←")
    btn_next    = widgets.Button(description="→")
    msg_box     = widgets.HTML()

    # Two subplots with one persistent bar trace each; traces are updated in place.
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Book state t–1", "Book state t"])
    fig.add_trace(go.Bar(x=[], y=[], marker_color='orange'), row=1, col=1)
    fig.add_trace(go.Bar(x=[], y=[]), row=1, col=2)
    fig.update_layout(width=800, height=400, showlegend=False, template='plotly_white')
    fig_widget = go.FigureWidget(fig)

    def update_slider_range(*_):
        sid = id_dd.value
        # t runs 1..last, where both the book and the message arrays have row t.
        time_slider.min = 1
        time_slider.max = min(len(b_seq_inp[sid]), len(msg_seq_raw[sid])) - 1
        time_slider.value = 1

    def update_plot(*_):
        sid = id_dd.value
        t   = time_slider.value
        arr = b_seq_inp[sid]

        # book states at t-1 and t
        s0 = arr[t - 1, level_slice]
        s1 = arr[t, level_slice]
        diff = np.abs(s1) - np.abs(s0)
        x = np.arange(len(s0)) - len(s0) // 2
        colors = ['orange' if abs(d) < 1e-8 else ('red' if d > 0 else 'blue') for d in diff]

        with fig_widget.batch_update():
            fig_widget.data[0].update(x=x, y=s0)
            fig_widget.data[1].update(x=x, y=s1, marker=dict(color=colors))
            fig_widget.layout.annotations[0].text = f"Book state {t-1}"
            fig_widget.layout.annotations[1].text = f"Book state {t}"

        # message at index t
        m = msg_seq_raw[sid][t].astype(int)
        # fields: [0]=timestamp, [1]=etype, [2]=dir, [3]=abspr, [4]=relpr, [5]=size, …
        et, dr, abspr, relpr, sz = m[1], m[2], m[3], m[4], m[5]
        info = (
            f"{et_map.get(et,'?')} • "
            f"{dr_map.get(dr,'?')} • "
            f"abs={abspr} • rel={relpr} • size={sz}"
        )
        msg_box.value = f"<b>{info}<br></b>raw:{m.tolist()}"

    def on_id_change(_change):
        update_slider_range()
        # If the slider was already at 1, setting value=1 fires no change event,
        # so redraw explicitly; otherwise the previous sample would stay on screen.
        update_plot()

    def on_prev(_b):
        if time_slider.value > time_slider.min:
            time_slider.value -= 1

    def on_next(_b):
        if time_slider.value < time_slider.max:
            time_slider.value += 1

    # wire up events
    id_dd.observe(on_id_change, names='value')
    time_slider.observe(lambda _change: update_plot(), names='value')
    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)

    # initial draw
    update_slider_range()
    update_plot()

    display(widgets.HBox([id_dd, btn_prev, btn_next, time_slider]))
    display(fig_widget, msg_box)


# Launch the viewer on the zero-padded per-sample dicts built two cells above.
interactive_lob_plot(b_dict, m_dict)

HBox(children=(Dropdown(description='Sample ID:', options=(218, 1632, 2685, 3035, 3133, 3252, 3315, 3540, 3644…

FigureWidget({
    'data': [{'marker': {'color': 'orange'},
              'type': 'bar',
              'uid': '307e4fb1-f76e-4e47-85d7-b2ba3007a20c',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x',
              'y': {'bdata': ('AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' ... 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAA='),
                    'dtype': 'f4'},
              'yaxis': 'y'},
             {'marker': {'color': [orange, red, red, red, red, red, red, red, red,
                                   red, red, orange, red, red, red, red, red, red,
                                   red, red, red, red, orange]},
              'type': 'bar',
              'uid': 'c420c7d9-8a4d-494d-a83b-e2875e8ceb83',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x2',
              'y': {'bdata': ('AAAAADQzUz/sUbg+7nz/PuJ6ND/NzA' ... '+/Z2bmvhfZTr8AAIC+3Pl+vwAAAAA='),
                    'dtype': 'f4'},
     

HTML(value='<b>Limit • Buy • abs=893200 • rel=-2 • size=100<br></b>raw:[200962422, 1, 1, 893200, -2, 100, 0, 1…

In [15]:
def show_execution_indices(m_dict):
    """Print, for every sample, the row indices whose event type is 4 (Execution).

    Parameters
    ----------
    m_dict : dict
        Maps sample id -> 2-D message array; column 1 is read as the event type.

    Prints one line per sample, in dict order; returns None.
    """
    for sample_id in m_dict:
        hits = np.flatnonzero(m_dict[sample_id][:, 1] == 4).tolist()
        if not hits:
            print(f"Sample ID {sample_id}: No Executions found")
            continue
        print(f"Sample ID {sample_id}: Execution at indices {hits}")

# Indices are rows of the zero-padded message arrays, i.e. the same t the viewer uses.
show_execution_indices(m_dict)

Sample ID 218: Execution at indices [249, 431, 432, 551, 602, 653, 704, 710, 711, 755, 802, 806, 837, 841, 857, 868, 877, 882, 883, 888, 893, 900, 908, 909, 910, 959, 965, 1002, 1010, 1103, 1202, 1203, 1567, 1569, 1570, 1571, 1626, 1628, 2130]
Sample ID 1632: Execution at indices [551, 602, 653, 704, 755, 806, 857, 908, 959, 1005, 1010, 1018, 1019, 1028, 1035, 1038, 1050, 1051, 1056, 1078, 1208, 1226, 1303, 1311, 1869]
Sample ID 2685: Execution at indices [119, 454, 496, 499, 551, 602, 653, 657, 658, 704, 755, 806, 857, 908, 959, 1010, 1059, 1060, 1061, 1196, 1462, 1463, 1698, 1897, 1898, 1899, 1934, 2240, 2279, 2302, 2325, 2437]
Sample ID 3035: Execution at indices [49, 138, 139, 140, 156, 175, 364, 383, 413, 414, 429, 468, 477, 537, 538, 551, 602, 653, 704, 755, 806, 820, 821, 822, 823, 857, 876, 908, 959, 1010, 1108, 1127, 1172, 1235, 1244, 1295, 1296, 1388, 1389, 1431, 1452, 1521, 1522, 1523, 1524, 1566, 1567, 1572, 1628, 1629, 1630, 1682, 1779, 2045, 2370, 2371, 2372, 2442, 2501]
