In [1]:
%cd /app

/app


In [2]:
import argparse
import os
import sys

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# force=True keeps this cell re-runnable: without it, executing the cell a second
# time in the same kernel raises RuntimeError because the start method is already set.
torch.multiprocessing.set_start_method('spawn', force=True)

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as np

from pathlib import Path
import os

import pandas as pd
import plotly.graph_objs as go
import yaml

from filtration_utils import summary_table, build_zero_padded_series, plot_midprice_series_with_insertions, prepare_volatility_filtered_series, plot_midprice_series_with_mean_std

2025-06-30 18:30:39.952825: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.8 which is older than the ptxas CUDA version (12.9.41). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Run identifier; selects the directory under /app/data_saved that is read below.
experiment_name = 'exp_50_20250610_145911'
CONFIG_PATH = f"/app/data_saved/{experiment_name}/used_config.yaml"

# Parse the YAML config stored with the run.
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

# Pull out the settings used by the following cells in a single unpacking step.
# Note that `hist_msgs` is read from the "n_messages" key.
(
    num_insertions,
    num_coolings,
    midprice_step_size,
    hist_msgs,
    n_gen_msgs,
) = (
    config[key]
    for key in (
        "num_insertions",
        "num_coolings",
        "midprice_step_size",
        "n_messages",
        "n_gen_msgs",
    )
)

In [4]:
# Echo the loaded settings, one "name: value" line each.
for label, value in (
    ("num_insertions", num_insertions),
    ("num_coolings", num_coolings),
    ("midprice_step_size", midprice_step_size),
    ("hist_msgs", hist_msgs),
    ("n_gen_msgs", n_gen_msgs),
):
    print(f'{label}: {value}')

num_insertions: 2
num_coolings: 2
midprice_step_size: 1
hist_msgs: 500
n_gen_msgs: 50


In [5]:
# Express the message counts in mid-price steps (trailing comments show the
# values for the config printed above).
hist_steps, gen_steps = (
    count // midprice_step_size for count in (hist_msgs, n_gen_msgs)
)  # 500, 50
gen_block = gen_steps + 1  # 51

# Per-sample summary for this experiment and the series derived from it.
merged = summary_table(experiment_name)
x, all_series = build_zero_padded_series(
    hist_msgs,
    n_gen_msgs,
    midprice_step_size,
    merged,
)

print(merged)

      id                                        merged_data
0   2787  [897700, 897700, 897700, 897700, 897700, 89770...
1   4668  [887000, 887000, 887000, 887000, 887000, 88700...
2  14606  [867800, 867800, 867800, 867800, 867800, 86780...
3  16120  [872300, 872300, 872300, 872300, 872300, 87230...


In [6]:
# Plot the series built in the previous cell. The helper is imported from
# filtration_utils and is called positionally here, so the argument order
# below has to match its signature.
fig = plot_midprice_series_with_insertions(
    merged,
    all_series,
    x,
    hist_steps,
    gen_block,
    num_insertions,
    num_coolings,
    n_gen_msgs,
    midprice_step_size
)
fig.show()

insertion positions: [551, 602]
cooling positions:   [653, 704]


In [7]:
# NOTE(review): this rebinds x, all_series, merged, hist_steps and gen_block.
# Re-running the cell therefore passes the previously returned `merged` back in
# (presumably the already-filtered table), and earlier cells executed afterwards
# will see these replaced values too.
x, all_series, merged, hist_steps, gen_block = prepare_volatility_filtered_series(merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=0.50)

      id     std_dev  max_abs_dev
0  14606  151.463024          700
1   4668  119.873043          500
2   2787  102.088661          300
3  16120  122.800037          200

Before filtering: 4 samples

After filtering: 2 samples


In [8]:
# Plot via plot_midprice_series_with_mean_std (imported from filtration_utils),
# using the values rebound by the volatility-filter cell; every argument is
# passed by keyword.
fig = plot_midprice_series_with_mean_std(
    merged=merged,
    all_series=all_series,
    x=x,
    hist_steps=hist_steps,
    gen_block=gen_block,
    num_insertions=num_insertions,
    num_coolings=num_coolings,
    n_gen_msgs=n_gen_msgs,
    midprice_step_size=midprice_step_size,
)
fig.show()

insertion positions: [551, 602]
cooling positions:   [653, 704]


# ----------------------------------------------------

# ----------------------------------------------------

# ----------------------------------------------------

# ORDER-PLAYER

In [9]:
import os, glob, re
import numpy as np
import pandas as pd

def build_and_merge(folder, batch_prefix, inp_prefix):
    """Load the batch ``.npy`` files in ``folder`` and stitch one series per sample id.

    Two filename shapes are recognised; any other ``.npy`` file is ignored:

    * ``{batch_prefix}_[id, id, ...]_iter_{k}.npy``  -> iteration ``k``
    * ``{inp_prefix}_[id, id, ...].npy``             -> iteration ``0``

    Row ``i`` along the first axis of a loaded array is attributed to the
    ``i``-th id listed between the brackets of its filename.

    Parameters
    ----------
    folder : str or path-like
        Directory searched (non-recursively) for ``*.npy`` files.
    batch_prefix : str
        Filename prefix of the per-iteration files.
    inp_prefix : str
        Filename prefix of the input files.

    Returns
    -------
    (df, df_sorted, merged_df) : tuple of pandas.DataFrame
        ``df``         one row per file: ``range``, ``iteration``, ``batch``, ``ids``.
        ``df_sorted``  one row per (id, iteration): ``id``, ``iteration``, ``data``.
        ``merged_df``  one row per id: ``id``, ``merged_data`` - that id's arrays
                       concatenated along axis 0 in ascending iteration order.

    Raises
    ------
    FileNotFoundError
        If no file in ``folder`` matches either filename pattern.
    ValueError
        If a filename lists more ids than its array has rows.
    """
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")

    # STEP 1: load every matching .npy into one record per file.
    # glob.escape stops '[', ']', '*' or '?' inside the folder name from being
    # interpreted as wildcards (the data filenames themselves contain brackets).
    pattern = os.path.join(glob.escape(os.fspath(folder)), "*.npy")
    records = []
    for path in glob.glob(pattern):
        name = os.path.basename(path)
        iter_match = rx_iter.match(name)
        if iter_match:
            id_list, iteration = iter_match.group(1), int(iter_match.group(2))
        else:
            inp_match = rx_inp.match(name)
            if not inp_match:
                continue
            # input files carry no iteration number; they sort before iteration 1
            id_list, iteration = inp_match.group(1), 0
        records.append({
            "range":     id_list.replace(" ", ""),
            "iteration": iteration,
            "batch":     np.load(path),
        })

    if not records:
        # An empty frame has no "range" column, so sorting it would only
        # produce a KeyError that hides the real problem.
        raise FileNotFoundError(
            f"no .npy files matching prefix {batch_prefix!r} or {inp_prefix!r} "
            f"found in {folder!r}"
        )

    df = pd.DataFrame(records).sort_values(["range", "iteration"]).reset_index(drop=True)

    # STEP 2: parse the comma-separated id list, then explode each batch into
    # one row per sample.
    df["ids"] = df["range"].str.split(",").apply(lambda parts: [int(p) for p in parts])

    rows = []
    for id_range, iteration, batch, ids in zip(df["range"], df["iteration"], df["batch"], df["ids"]):
        if len(ids) > batch.shape[0]:
            raise ValueError(
                f"file for ids [{id_range}] (iteration {iteration}) lists {len(ids)} ids "
                f"but its array has only {batch.shape[0]} rows"
            )
        rows.extend(
            {"id": sample_id, "iteration": iteration, "data": batch[pos]}
            for pos, sample_id in enumerate(ids)
        )
    df_sorted = pd.DataFrame(rows).sort_values(["id", "iteration"]).reset_index(drop=True)

    # STEP 3: for each id, concatenate all its iterations end-to-end.
    # Column access uses ["data"] rather than attribute access because `.data`
    # was a real Series attribute in older pandas releases.
    merged = [
        {"id": id_val, "merged_data": np.concatenate(list(grp["data"]), axis=0)}
        for id_val, grp in df_sorted.groupby("id", sort=True)
    ]
    merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

    return df, df_sorted, merged_df

# — example usage —
b_folder      = f"/app/data_saved/{experiment_name}/b_seq_gen_doubled"
b_batch_pref  = "b_seq_gen_doubled_batch"
b_inp_pref    = "b_seq_inp"
_, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)

m_folder      = f"/app/data_saved/{experiment_name}/msgs_decoded_doubled"
m_batch_pref  = "msgs_decoded_doubled_batch"
m_inp_pref    = "msgs_decoded_doubled_inp"
_, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

# Build id -> array dicts. One all-zero row is prepended to every array so that
# the interactive plot can index t = 1…T and still read a row at t - 1.
b_dict, m_dict = {}, {}
for target, merged_df in ((b_dict, b_merged), (m_dict, m_merged)):
    for sid, series in zip(merged_df["id"], merged_df["merged_data"]):
        series = np.array(series)
        zero_row = np.zeros((1, series.shape[1]), dtype=series.dtype)
        target[int(sid)] = np.vstack([zero_row, series])

# The interactive plot reads book rows t-1/t and message row t with the same
# index, but the two series are merged independently (files that match neither
# prefix are skipped silently). Report any id or length mismatch up front
# instead of letting the plot show states and messages that do not belong together.
for sid in sorted(set(b_dict) | set(m_dict)):
    if sid not in b_dict or sid not in m_dict:
        missing = "book" if sid not in b_dict else "message"
        print(f"WARNING: sample {sid} has no {missing} series")
    elif len(b_dict[sid]) != len(m_dict[sid]):
        print(
            f"WARNING: sample {sid}: {len(b_dict[sid])} book rows vs "
            f"{len(m_dict[sid])} message rows - book states and messages are not aligned"
        )

In [10]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

def interactive_lob_plot(b_seq_inp, msg_seq_raw, book_cols=slice(240, 263)):
    """Display a widget that steps through book states and the matching messages.

    For the selected sample id and step ``t`` the left panel shows row ``t-1``
    of the book array and the right panel row ``t``; bars whose value went up
    are drawn red, bars that went down blue, unchanged bars orange. The text
    box below shows row ``t`` of the message array.

    Parameters
    ----------
    b_seq_inp : dict or pandas.DataFrame
        Book arrays keyed by sample id, or a frame with ``id`` and
        ``merged_data`` columns that is converted into such a dict.
    msg_seq_raw : dict or pandas.DataFrame
        Message arrays in the same form. Row ``t`` is shown next to book rows
        ``t-1`` and ``t``; this function does not re-align the two inputs.
    book_cols : slice, optional
        Columns of each book row that are plotted. Defaults to
        ``slice(240, 263)``, the range this notebook has always used.

    Returns
    -------
    None
        The widgets are rendered with ``display``.
    """
    # allow DataFrame or dict
    if isinstance(b_seq_inp, pd.DataFrame):
        b_seq_inp = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
    if isinstance(msg_seq_raw, pd.DataFrame):
        msg_seq_raw = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}

    # controls
    id_dd       = widgets.Dropdown(options=sorted(b_seq_inp.keys()), description="Sample ID:")
    time_slider = widgets.IntSlider(min=1, max=1, step=1, description="t:")
    btn_prev    = widgets.Button(description="←")
    btn_next    = widgets.Button(description="→")
    msg_box     = widgets.HTML()

    # figure with two subplots
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Book state t–1", "Book state t"])
    fig.add_trace(go.Bar(x=[], y=[]), row=1, col=1)
    fig.add_trace(go.Bar(x=[], y=[]), row=1, col=2)
    fig.update_layout(width=800, height=400, showlegend=False, template='plotly_white')
    fig_widget = go.FigureWidget(fig)

    def update_slider_range(*_):
        sid = id_dd.value
        # t indexes the book array AND the message array, so the slider must
        # stop at the last row that exists in both of them.
        last = min(b_seq_inp[sid].shape[0], msg_seq_raw[sid].shape[0]) - 1
        time_slider.min = 1
        time_slider.max = last
        time_slider.value = 1

    def update_plot(*_):
        sid  = id_dd.value
        t    = time_slider.value
        arr  = b_seq_inp[sid]
        msgs = msg_seq_raw[sid]

        # take book states at t−1 and t
        s0 = arr[t - 1, book_cols]
        s1 = arr[t,     book_cols]

        diff = s1 - s0
        # centre the x axis on the middle column of the plotted range
        x = np.arange(len(s0)) - len(s0) // 2

        with fig_widget.batch_update():
            fig_widget.data = []
            fig_widget.add_bar(x=x, y=s0, row=1, col=1, marker_color='orange')
            colors = ['orange' if abs(d) < 1e-8 else ('red' if d > 0 else 'blue') for d in diff]
            fig_widget.add_bar(x=x, y=s1, row=1, col=2, marker_color=colors)
            fig_widget.layout.annotations[0].text = f"Book state {t-1}"
            fig_widget.layout.annotations[1].text = f"Book state {t}"

        # show message at index t
        m = msgs[t].astype(int)
        # fields: [0]=timestamp, [1]=etype, [2]=dir, [3]=abspr, [4]=relpr, [5]=size, …
        et, dr, abspr, relpr, sz = m[1], m[2], m[3], m[4], m[5]
        et_map = {1: "Limit", 2: "PartialCancel", 3: "Delete", 4: "Execution"}
        dr_map = {1: "Buy", 0: "Sell"}
        info = (
            f"{et_map.get(et,'?')} • "
            f"{dr_map.get(dr,'?')} • "
            f"abs={abspr} • rel={relpr} • size={sz}"
        )
        msg_box.value = f"<b>{info}<br></b>raw:{m.tolist()}"

    def on_prev(b):
        if time_slider.value > time_slider.min:
            time_slider.value -= 1

    def on_next(b):
        if time_slider.value < time_slider.max:
            time_slider.value += 1

    def on_time_change(change):
        update_plot()

    def on_id_change(change):
        # Resetting the slider to 1 emits no change event when it is already at
        # 1, so relying on the slider observer would leave the previous
        # sample on screen. Detach it while the range is reset, then redraw
        # exactly once for the new sample.
        time_slider.unobserve(on_time_change, names='value')
        try:
            update_slider_range()
        finally:
            time_slider.observe(on_time_change, names='value')
        update_plot()

    # wire up events
    id_dd.observe(on_id_change, names='value')
    time_slider.observe(on_time_change, names='value')
    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)

    # initial draw
    update_slider_range()
    update_plot()

    display(widgets.HBox([id_dd, btn_prev, btn_next, time_slider]))
    display(fig_widget, msg_box)


# Launch the viewer on the zero-padded book/message dicts built in the previous cell.
interactive_lob_plot(b_dict, m_dict)

HBox(children=(Dropdown(description='Sample ID:', options=(2787, 4668, 14606, 16120), value=2787), Button(desc…

FigureWidget({
    'data': [{'marker': {'color': 'orange'},
              'type': 'bar',
              'uid': '8f92ef12-cc43-4249-89b3-2b9883835d88',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x',
              'y': {'bdata': ('AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' ... 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAA='),
                    'dtype': 'f4'},
              'yaxis': 'y'},
             {'marker': {'color': [orange, red, red, red, red, red, red, red, red,
                                   red, red, orange, blue, blue, blue, blue, blue,
                                   blue, blue, blue, blue, blue, orange]},
              'type': 'bar',
              'uid': '2a7949dd-4f4a-4250-b700-852ba72b421e',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x2',
              'y': {'bdata': ('AAAAAFCNlz4JrDw/2KPwPobrkT/azj' ... 'K/tcj2vipcz77BymG/NDOzvgAAAAA='),
                    'dtype': 'f

HTML(value='<b>Delete • Sell • abs=897700 • rel=1 • size=73<br></b>raw:[694062990, 3, 0, 897700, 1, 73, 0, 274…

In [None]:
# Cross-check: read one raw message batch straight from disk, without going
# through build_and_merge. The path is derived from experiment_name so it
# follows the experiment selected at the top of the notebook.
path = f"/app/data_saved/{experiment_name}/msgs_decoded_doubled/msgs_decoded_doubled_batch_[2787, 4668]_iter_1.npy"
data = np.load(path)  # shape = (batch_size, time, features)
sample_2787 = data[0]  # 2787 is the first id listed in the filename
print(sample_2787[:5])

In [None]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

# Load one raw book-state batch straight from disk (iteration 1 only, not the
# merged series). The path is derived from experiment_name so it follows the
# experiment selected at the top of the notebook.
path = f"/app/data_saved/{experiment_name}/b_seq_gen_doubled/b_seq_gen_doubled_batch_[2787, 4668]_iter_1.npy"
data = np.load(path)
sample_2787 = data[0]  # 2787 is the first id listed in the filename

# Set up the stepper controls; t runs over the rows of this single file
t_slider = widgets.IntSlider(value=0, min=0, max=sample_2787.shape[0]-1, step=1, description="t:")
btn_prev = widgets.Button(description="← Prev")
btn_next = widgets.Button(description="Next →")

# Figure
fig = go.FigureWidget()
bar = fig.add_bar(x=[], y=[], marker_color='orange')
fig.update_layout(
    title="Book state for sample 2787 at t=0",
    xaxis_title="Relative price level",
    yaxis_title="Size",
    template="plotly_white",
    width=700,
    height=400
)

def update_plot(change=None):
    """Redraw the bar trace and title for the slider's current ``t``."""
    t = t_slider.value
    # same column range as the interactive LOB plot uses
    state = sample_2787[t, 240:263]
    x = np.arange(len(state)) - len(state)//2

    with fig.batch_update():
        fig.data[0].x = x
        fig.data[0].y = state
        fig.layout.title = f"Book state for sample 2787 at t={t}"

def on_prev_clicked(b):
    """Step the slider back by one, stopping at its minimum."""
    if t_slider.value > t_slider.min:
        t_slider.value -= 1

def on_next_clicked(b):
    """Step the slider forward by one, stopping at its maximum."""
    if t_slider.value < t_slider.max:
        t_slider.value += 1

# Wire up
t_slider.observe(update_plot, names="value")
btn_prev.on_click(on_prev_clicked)
btn_next.on_click(on_next_clicked)

# Initial draw
update_plot()

# Display
display(widgets.HBox([btn_prev, btn_next, t_slider]))
display(fig)

In [None]:
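# FIXME: the book state simply does not match! The messages do, but the book state shown in the widget comes from somewhere else.
# NOTE(review): the widget indexes the merged series (input file as iteration 0 plus every iteration, with a zero row prepended),
# whereas the cell above reads the iter_1 file alone - check that the book and message series start from the same file and have equal length.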