In [1]:
# Work from the project root. The data paths used below are absolute under /app,
# and the local `lob` / `preproc` imports presumably resolve relative to it (TODO confirm).
%cd /app

/app


In [2]:
import argparse
import os
import sys

# Make JAX/XLA allocate GPU memory on demand instead of preallocating a large
# fraction up front. This only takes effect if set before `jax` is first imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# Use the 'spawn' start method for worker processes. set_start_method() raises
# RuntimeError("context has already been set") when called a second time, which
# made re-running this cell fail; only set it if it is not 'spawn' already.
if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
    torch.multiprocessing.set_start_method('spawn')

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as onp

from pathlib import Path
import os

import pandas as pd
import plotly.graph_objs as go

2025-05-15 22:29:24.026180: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.8 which is older than the ptxas CUDA version (12.9.41). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Active experiment: the saved run under /app/data_saved/ to analyse, together
# with the generation settings it was produced with.
experiment_name = 'exp_27_20250515_221221'
hist_msgs, n_gen_msgs = 500, 50        # history length / messages per generated block
midprice_step_size = 1                 # one mid-price sample every N messages
num_insertions, num_coolings = 20, 40  # number of insertion / cooling events

# Settings of earlier runs, kept for reference:
#
# experiment_name = 'exp_3_20250429_213228'
# num_insertions = 20
# num_coolings = 20
# midprice_step_size = 1
# hist_msgs = 500
# n_gen_msgs = 50
#
# experiment_name = '3'
# num_insertions = 10
# num_coolings = 10
# midprice_step_size = 10
# hist_msgs = 500
# n_gen_msgs = 500

In [4]:

# Settings of the run being analysed. This cell OVERRIDES n_gen_msgs,
# midprice_step_size, num_insertions and num_coolings from the previous cell,
# so these are the values the cells below actually use.
batch_size = 3   # presumably copied from the run config; not referenced in the cells shown
n_samples = 9    # presumably copied from the run config; not referenced in the cells shown

# generation
n_gen_msgs = 50
midprice_step_size = 50

num_insertions = 2
num_coolings = 2

In [5]:
import os, glob, re
import numpy as np
import pandas as pd

# 1) Directory and filename pattern of the saved mid-price arrays.
DATA_DIR = f"/app/data_saved/{experiment_name}/mid_price"
pattern = os.path.join(DATA_DIR, "mid_price_batch_*_iter_*.npy")

# 2) Collect all matching paths (sorted so the load order is deterministic).
files = sorted(glob.glob(pattern))
if not files:
    raise FileNotFoundError(f"No .npy files matching {pattern}")

# 3) Parse each filename, load the array and split it into one record per
#    sample. The bracketed list in the name holds the sample ids of the batch,
#    e.g. "mid_price_batch_[3464, 4169, 4497, 6855]_iter_3.npy".
rx = re.compile(r"mid_price_batch_\[([\d,\s]+)\]_iter_(\d+)\.npy$")
records = []
for f in files:
    m = rx.search(os.path.basename(f))
    if not m:
        continue  # matched the glob but not the stricter regex; ignore it
    rng = m.group(1).replace(" ", "")      # e.g. "3464,4169,4497,6855"
    itr = int(m.group(2))                  # iteration number
    arr = np.load(f)                       # expected shape (n_steps, batch_size)
    ids = [int(tok) for tok in rng.split(",")]

    # Column `col` of the array belongs to ids[col]. Fail with the file name
    # and shape instead of an opaque IndexError if the layout does not match.
    if arr.ndim != 2 or arr.shape[1] < len(ids):
        raise ValueError(
            f"{f}: expected a 2-D array with at least {len(ids)} columns "
            f"(one per sample id), got shape {arr.shape}"
        )
    for col, sample_id in enumerate(ids):
        records.append({"id": sample_id, "iteration": itr, "data": arr[:, col]})

# 4) Build the DataFrame ordered by sample id, then by iteration.
df = pd.DataFrame.from_records(records, columns=["id", "iteration", "data"])
df = df.sort_values(["id", "iteration"]).reset_index(drop=True)

# 5) Concatenate each id's chunks end-to-end in iteration order, and
# 6) store the result as a plain Python list.
# An explicit per-group loop is used instead of groupby().agg(): agg() is meant
# for scalar reductions, and array-valued results from it are version-fragile
# in pandas (e.g. length-1 arrays may be unwrapped to scalars).
merged = pd.DataFrame(
    [
        {"id": sample_id, "merged_data": np.concatenate(grp["data"].tolist()).tolist()}
        for sample_id, grp in df.groupby("id", sort=True)
    ],
    columns=["id", "merged_data"],
)

In [6]:
# Experiment '3' has a hard-coded set of sample ids that are excluded from the
# analysis (the reason is not recorded here).
if experiment_name == '3':
    bad_ids = {4169, 10853, 13577}
    is_bad = merged["id"].isin(bad_ids)
    merged = merged.loc[~is_bad].reset_index(drop=True)

In [7]:
# Quick look at the loaded samples: one row per id with its merged series.
print(f"{merged}")

      id                                        merged_data
0   1895  [891800, 892000, 892000, 892100, 892200, 89210...
1   2787  [897700, 897700, 897700, 897600, 897700, 89770...
2   4668  [886800, 886900, 886800, 886600, 886700, 88670...
3  14606  [867800, 867900, 868000, 868100, 868200, 86820...
4  16120  [872200, 872400, 872500, 872400, 872400, 87240...
5  22181  [875100, 875200, 875300, 875500, 875400, 87550...


In [8]:
# Convert message counts into mid-price sample counts (the series holds one
# mid-price sample every `midprice_step_size` messages).
hist_steps = hist_msgs // midprice_step_size    # samples covering the history
gen_steps = n_gen_msgs // midprice_step_size    # samples per generated block
gen_block = gen_steps + 1                       # spacing between events: block + 1 extra step

# Every id must contribute a series of the same length so the paths can be
# stacked; report the offending lengths instead of a bare np.vstack shape error.
if merged.empty:
    raise ValueError("No mid-price series to plot (`merged` is empty).")
series_lengths = {r.id: len(r.merged_data) for r in merged.itertuples(index=False)}
if len(set(series_lengths.values())) != 1:
    raise ValueError(f"Mid-price series lengths differ between ids: {series_lengths}")

# Zero each sample's series at its first value so the paths are comparable.
all_series = []
for row in merged.itertuples(index=False):
    data = np.array(row.merged_data)
    all_series.append(data - data[0])
all_series = np.vstack(all_series)
x = np.arange(1, all_series.shape[1] + 1)      # 1-based sample index

# Cross-sample mean/std (computed here but not drawn in this figure).
mean_s = all_series.mean(axis=0)
std_s = all_series.std(axis=0)

fig = go.Figure()

# Individual sample paths, fully opaque so each one is distinguishable.
for row, arr0 in zip(merged.itertuples(False), all_series):
    fig.add_trace(go.Scatter(x=x, y=arr0, mode='lines',
                             opacity=1.0, hoverinfo='skip',
                             line=dict(width=1), name=f"id {row.id}"))

# End of history (blue dashed).
fig.add_vline(x=hist_steps, line=dict(color='blue', width=2, dash='dash'))

# Event k (1-based) sits k generated blocks after the end of history.
events = np.arange(1, num_insertions + num_coolings + 1)
positions = hist_steps + gen_block * events

# First num_insertions events are insertions (solid red) ...
for pos in positions[:num_insertions]:
    fig.add_vline(x=pos, line=dict(color='red', width=2, dash='solid'))
# ... the following num_coolings events are coolings (dashed red).
for pos in positions[num_insertions:]:
    fig.add_vline(x=pos, line=dict(color='red', width=2, dash='dash'))

print("insertion positions:", positions[:num_insertions].tolist())
print("cooling positions:  ", positions[num_insertions:].tolist())

# Zero line & layout.
fig.add_hline(y=0, line=dict(color='black', width=2, dash='solid'),
              annotation_text="0-line", annotation_position="bottom right")
fig.update_layout(
    title="All midprice series (zeroed) with insertion/cooling lines",
    xaxis_title="Index", yaxis_title="Price – first_price",
    template="plotly_white", hovermode="x unified",
    height=800, width=1200, margin={"b":150}
)
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False)

info = (
    f"Midprice every {midprice_step_size} msgs;<br>"
    f"solid red = insertion (every {n_gen_msgs} msgs +1 step);<br>"
    f"dashed red = cooling (same spacing); blue = end of history."
)
fig.add_annotation(text=info, xref="paper", yref="paper",
                   x=0, y=-0.225, showarrow=False, align="left")
fig.show()

insertion positions: [12, 14]
cooling positions:   [16, 18]


In [9]:
# Convert message counts into mid-price sample counts.
hist_steps = hist_msgs // midprice_step_size   # samples in the historical block
gen_steps = n_gen_msgs // midprice_step_size   # samples in one generated block
gen_block = gen_steps + 1                      # spacing between events (one extra step per block)

# Zero every sample's path at its first value and stack them (one row per id).
zeroed_paths = []
for sample in merged.itertuples(index=False):
    values = np.array(sample.merged_data)
    zeroed_paths.append(values - values[0])
all_series = np.vstack(zeroed_paths)
time = all_series.shape[1]
x = np.arange(1, time + 1)

# Cross-sample statistics for the band and the mean line.
mean_series = all_series.mean(axis=0)
std_series = all_series.std(axis=0)

fig = go.Figure()

# Individual paths, drawn faintly so the mean and band stand out.
for sample, path in zip(merged.itertuples(index=False), all_series):
    fig.add_trace(go.Scatter(
        x=x, y=path, mode='lines', name=f"id {sample.id}",
        line={"width": 1}, opacity=0.2, hoverinfo='skip',
    ))

# Shaded ±1 std band: upper edge left-to-right, then lower edge right-to-left.
fig.add_trace(go.Scatter(
    x=np.r_[x, x[::-1]],
    y=np.r_[mean_series + std_series, (mean_series - std_series)[::-1]],
    fill='toself', fillcolor='rgba(0,0,0,0.1)',
    line={"color": 'rgba(0,0,0,0)'}, hoverinfo='skip',
    showlegend=True, name='±1 std',
))

# Bold mean path on top.
fig.add_trace(go.Scatter(x=x, y=mean_series, mode='lines', name='Mean',
                         line={"color": 'black', "width": 4}))

# End of history (blue dashed).
fig.add_vline(x=hist_steps, line={"color": 'blue', "width": 2, "dash": 'dash'})

# Event k (1-based) sits k generated blocks after the end of history; the first
# num_insertions events are insertions (solid), the rest coolings (dashed).
events = np.arange(1, num_insertions + num_coolings + 1)
positions = hist_steps + gen_block * events
for k, pos in enumerate(positions):
    style = 'solid' if k < num_insertions else 'dash'
    fig.add_vline(x=pos, line={"color": 'red', "width": 2, "dash": style})

print("insertion positions:", positions[:num_insertions].tolist())
print("cooling positions:  ", positions[num_insertions:].tolist())

# Zero line.
fig.add_hline(y=0, line={"color": 'black', "width": 2, "dash": 'solid'},
              annotation_text="0-line", annotation_position="bottom right")

# Layout.
fig.update_layout(
    title="All midprice series (zeroed) with insertion/cooling lines",
    xaxis_title="Index", yaxis_title="Price – first_price",
    template="plotly_white", hovermode="x unified",
    height=800, width=1200, margin={"b": 150},
)
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False)

info_text = (
    f"Midprice every {midprice_step_size} msgs;<br>"
    f"solid red = insertion (every {n_gen_msgs} msgs +1 step);<br>"
    f"dashed red = cooling (same spacing); blue = end of history."
)
fig.add_annotation(
    text=info_text, xref="paper", yref="paper",
    x=0, y=-0.225, showarrow=False, align="left", font={"size": 12},
)

fig.show()

insertion positions: [12, 14]
cooling positions:   [16, 18]


In [10]:
import os, glob, re
import numpy as np
import pandas as pd

def build_and_merge(folder, batch_prefix, inp_prefix):
    """Load per-batch ``.npy`` files from ``folder`` and merge them per sample id.

    Two kinds of files are recognised:

    * ``{batch_prefix}_[id1, id2, ...]_iter_{k}.npy`` -- data of iteration ``k``;
    * ``{inp_prefix}_[id1, id2, ...].npy`` -- the batch's ``inp`` file, treated
      as iteration 0 and always ordered before a generated chunk that also
      has iteration number 0.

    Each file is expected to hold an array of shape ``(batch_size, time,
    features)`` whose i-th entry belongs to the i-th id in the bracketed list.
    Other ``.npy`` files in ``folder`` are ignored.

    Returns
    -------
    df : pd.DataFrame
        One row per file: ``range`` (id list without spaces), ``iteration``,
        ``batch`` (the loaded array) and ``ids`` (list of ints).
    df_sorted : pd.DataFrame
        One row per (sample id, file): ``id``, ``iteration`` and ``data`` (that
        sample's slice of the batch), sorted by id and iteration.
    merged_df : pd.DataFrame
        One row per id: ``id`` and ``merged_data``, all of the id's slices
        concatenated along axis 0 in iteration order.

    Raises
    ------
    FileNotFoundError
        If no file in ``folder`` matches either naming pattern.
    """
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")

    # STEP 1: load every matching file. Sorting the paths makes the result
    # independent of the filesystem's listing order.
    rec = []
    for path in sorted(glob.glob(os.path.join(folder, "*.npy"))):
        name = os.path.basename(path)
        m = rx_iter.match(name)
        if m:
            rng, itr, is_generated = m.group(1).replace(" ", ""), int(m.group(2)), True
        else:
            m = rx_inp.match(name)
            if not m:
                continue
            rng, itr, is_generated = m.group(1).replace(" ", ""), 0, False
        rec.append({
            "range": rng,
            "iteration": itr,
            "batch": np.load(path),
            "is_generated": is_generated,
        })
    if not rec:
        raise FileNotFoundError(
            f"No files matching {batch_prefix}_[...]_iter_N.npy or "
            f"{inp_prefix}_[...].npy in {folder}"
        )

    # The inp file shares iteration number 0 with a possible "_iter_0" file;
    # the is_generated tie-breaker (False < True) keeps the inp chunk first
    # instead of leaving the order to the file listing.
    df = (
        pd.DataFrame(rec)
        .sort_values(["range", "iteration", "is_generated"])
        .reset_index(drop=True)
    )

    # STEP 2: parse the comma-separated ids and explode each batch into one row per sample.
    df["ids"] = df["range"].str.split(",").apply(lambda parts: [int(x) for x in parts])
    rows = [
        {
            "id": sample_id,
            "iteration": r.iteration,
            "data": r.batch[idx],          # this sample's (time, features) slice
            "is_generated": r.is_generated,
        }
        for r in df.itertuples(index=False)
        for idx, sample_id in enumerate(r.ids)
    ]
    df_sorted = (
        pd.DataFrame(rows)
        .sort_values(["id", "iteration", "is_generated"])
        .reset_index(drop=True)
    )

    # STEP 3: for each id, concatenate all its slices end-to-end along axis 0.
    merged_df = pd.DataFrame(
        [
            {"id": id_val, "merged_data": np.concatenate(grp["data"].tolist(), axis=0)}
            for id_val, grp in df_sorted.groupby("id", sort=True)
        ],
        columns=["id", "merged_data"],
    )

    # The tie-breaker column is internal; return the same columns as before.
    return (
        df.drop(columns="is_generated"),
        df_sorted.drop(columns="is_generated"),
        merged_df,
    )

# Load the book-state sequences and the decoded messages of the experiment and
# merge each sample's chunks into one time-ordered array.
b_folder = f"/app/data_saved/{experiment_name}/b_seq_gen_doubled"
b_batch_pref = "b_seq_gen_doubled_batch"
b_inp_pref = "b_seq_inp"
_, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)

m_folder = f"/app/data_saved/{experiment_name}/msgs_decoded_doubled"
m_batch_pref = "msgs_decoded_doubled_batch"
m_inp_pref = "msgs_decoded_doubled_inp"
_, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

# id -> merged array lookups for the interactive viewer.
b_dict = {int(rec.id): np.array(rec.merged_data) for rec in b_merged.itertuples(index=False)}
m_dict = {int(rec.id): np.array(rec.merged_data) for rec in m_merged.itertuples(index=False)}

# Prepend an all-zero row to every array. The viewer compares rows t-1 and t
# for t >= 1, so this lets it also show the first stored row.
for lookup in (b_dict, m_dict):
    for sample_id in list(lookup):
        arr = lookup[sample_id]
        lookup[sample_id] = np.concatenate(
            [np.zeros((1, arr.shape[1]), dtype=arr.dtype), arr], axis=0
        )

In [13]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

def interactive_lob_plot(b_seq_inp, msg_seq_raw, book_slice=slice(240, 263)):
    """Interactive viewer comparing consecutive order-book states of one sample.

    The viewer shows two bar charts, the book state at ``t-1`` (left) and at
    ``t`` (right), and the message at index ``t``. A dropdown selects the sample
    id; a slider and arrow buttons step through ``t``. Right-hand bars are
    coloured by their change versus ``t-1``: red = increased, blue = decreased,
    orange = unchanged.

    Parameters
    ----------
    b_seq_inp : dict[int, np.ndarray] or pd.DataFrame
        Book-state sequences per sample id, indexed as ``arr[t, column]``.
        A DataFrame must have ``id`` and ``merged_data`` columns.
    msg_seq_raw : dict[int, np.ndarray] or pd.DataFrame
        Message sequences per sample id, indexed as ``msgs[t, field]``; same
        DataFrame convention as ``b_seq_inp``. Expected to have the same ids
        and at least as many rows as the matching book sequence.
    book_slice : slice, optional
        Columns of each book-state row to plot. Defaults to ``slice(240, 263)``;
        pass ``slice(None)`` to plot the full row.
    """
    # allow DataFrame or dict
    if isinstance(b_seq_inp, pd.DataFrame):
        b_seq_inp = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
    if isinstance(msg_seq_raw, pd.DataFrame):
        msg_seq_raw = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}

    # controls
    id_dd       = widgets.Dropdown(options=sorted(b_seq_inp.keys()), description="Sample ID:")
    time_slider = widgets.IntSlider(min=1, max=1, step=1, description="t:")
    btn_prev    = widgets.Button(description="←")
    btn_next    = widgets.Button(description="→")
    msg_box     = widgets.HTML()

    # Two side-by-side bar charts; update_plot() updates these traces in place.
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Book state t–1", "Book state t"])
    fig.add_trace(go.Bar(x=[], y=[], marker_color='orange'), row=1, col=1)
    fig.add_trace(go.Bar(x=[], y=[]), row=1, col=2)
    fig.update_layout(width=800, height=400, showlegend=False, template='plotly_white')
    fig_widget = go.FigureWidget(fig)

    et_map = {1: "Limit", 2: "PartialCancel", 3: "Delete", 4: "Execution"}
    dr_map = {1: "Buy", 0: "Sell"}

    def update_slider_range(*_):
        """Reset the slider to t = 1 … T-1 for the selected sample."""
        arr = b_seq_inp[id_dd.value]
        time_slider.min = 1
        time_slider.max = arr.shape[0] - 1
        time_slider.value = 1

    def update_plot(*_):
        """Redraw both book states and the message for the current id and t."""
        sid = id_dd.value
        t = time_slider.value
        arr = b_seq_inp[sid]
        msgs = msg_seq_raw[sid]

        # Book states at t-1 and t.
        s0 = arr[t - 1, book_slice]
        s1 = arr[t, book_slice]
        diff = s1 - s0
        x = np.arange(len(s0)) - len(s0) // 2   # centre the x positions on 0
        colors = ['orange' if abs(d) < 1e-8 else ('red' if d > 0 else 'blue') for d in diff]

        with fig_widget.batch_update():
            prev_bar, curr_bar = fig_widget.data
            prev_bar.x, prev_bar.y = x, s0
            curr_bar.x, curr_bar.y = x, s1
            curr_bar.marker.color = colors
            fig_widget.layout.annotations[0].text = f"Book state {t-1}"
            fig_widget.layout.annotations[1].text = f"Book state {t}"

        # Message at index t. Field layout as assumed by this viewer (TODO
        # confirm against the message encoder): [0]=timestamp, [1]=etype,
        # [2]=dir, [3]=abspr, [4]=relpr, [5]=size, …
        m = msgs[t].astype(int)
        et, dr, abspr, relpr, sz = m[1], m[2], m[3], m[4], m[5]
        info = (
            f"{et_map.get(et,'?')} • "
            f"{dr_map.get(dr,'?')} • "
            f"abs={abspr} • rel={relpr} • size={sz}"
        )
        msg_box.value = f"<b>{info}<br></b>raw:{m.tolist()}"

    def on_id_change(_change):
        update_slider_range()
        # The slider observer only fires when its value actually changes; if t
        # was already 1 the plot would keep showing the previous sample, so
        # redraw explicitly.
        update_plot()

    def on_prev(_btn):
        if time_slider.value > time_slider.min:
            time_slider.value -= 1

    def on_next(_btn):
        if time_slider.value < time_slider.max:
            time_slider.value += 1

    # wire up events
    id_dd.observe(on_id_change, names='value')
    time_slider.observe(lambda c: update_plot(), names='value')
    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)

    # initial draw
    update_slider_range()
    update_plot()

    display(widgets.HBox([id_dd, btn_prev, btn_next, time_slider]))
    display(fig_widget, msg_box)


# Launch the viewer on the zero-padded book-state and message lookups.
interactive_lob_plot(b_seq_inp=b_dict, msg_seq_raw=m_dict)

HBox(children=(Dropdown(description='Sample ID:', options=(1895, 2787, 4668, 14606, 16120, 22181), value=1895)…

FigureWidget({
    'data': [{'marker': {'color': 'orange'},
              'type': 'bar',
              'uid': 'ccae450c-410c-4fc4-aed4-3dd4e250dffe',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x',
              'y': {'bdata': ('AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' ... 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAA='),
                    'dtype': 'f4'},
              'yaxis': 'y'},
             {'marker': {'color': [orange, red, red, red, red, red, red, red, red,
                                   red, red, orange, blue, blue, blue, blue, blue,
                                   blue, blue, blue, blue, blue, orange]},
              'type': 'bar',
              'uid': '06374043-c809-41f2-b6dd-68ec6dd2083c',
              'x': {'bdata': '9fb3+Pn6+/z9/v8AAQIDBAUGBwgJCgs=', 'dtype': 'i1'},
              'xaxis': 'x2',
              'y': {'bdata': ('AAAAAK9H4T5WDg0/MN0EP5qZGT+pxm' ... 'rA2s4Xv/T9FL/qJvG+wcohvwAAAAA='),
                    'dtype': 'f

HTML(value='<b>Delete • Buy • abs=891500 • rel=-9 • size=500<br></b>raw:[-2, 3, 1, 891500, -9, 500, 0, 387358,…