In [1]:
# Work from the project root: later cells use relative paths such as
# './checkpoints/...' and 'data/test_set/'.
%cd /app

/app


In [2]:
import argparse
import os
import sys

# Must be set before JAX initializes its GPU backend (JAX is imported below);
# stops XLA from preallocating a large fraction of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# force=True keeps this cell re-runnable: without it a second execution in the
# same kernel raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as onp

from pathlib import Path
import os

import pandas as pd

2025-04-24 18:54:48.706589: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.6 which is older than the ptxas CUDA version (12.8.93). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Visualization helpers


import plotly.graph_objs as go
import numpy as np
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

import plotly.io as pio
# Render Plotly figures inline. Alternative: 'notebook_connected' (loads plotly.js from a CDN).
pio.renderers.default = 'notebook'  # Or 'notebook_connected'


def interactive_lob_plot(b_seq_inp, msg_seq_raw, book_slice=slice(240, 263)):
    """Interactive viewer for consecutive book states of batch samples 0 and 1.

    For each of the first two samples, the left column shows the book vector
    at step ``i`` and the right column the one at step ``i + 1``. The right-hand
    bars are coloured by the change in absolute value per bin: red = grew,
    blue = shrank, orange = unchanged. A text box per sample decodes the raw
    message ``msg_seq_raw[sample][i + 1]``. Navigate with the ←/→ buttons or
    by typing an index into the "Jump to i" box.

    Args:
        b_seq_inp: book sequences indexed as ``b_seq_inp[sample][step]``; must
            expose ``.shape[1]`` (number of steps) and hold at least 2 samples.
        msg_seq_raw: raw message sequences indexed as
            ``msg_seq_raw[sample][step]``; fields read are 1 (event type),
            2 (direction), 3 (absolute price), 4 (relative price), 5 (size).
        book_slice: slice of each book vector to plot. The default
            ``slice(240, 263)`` reproduces the previous hard-coded window
            (presumably the bins around the mid price — TODO confirm).
    """
    index = {"value": 0}

    # Lookup tables for the message info box (built once, not on every redraw).
    event_type_map = {1: "Limit Order", 2: "Partial Cancel", 3: "Delete", 4: "Execution"}
    direction_map = {1: "Buy", 0: "Sell"}

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            f"Sample 0 – book state {index['value']}",
            f"Sample 0 – book state {index['value'] + 1}",
            f"Sample 1 – book state {index['value']}",
            f"Sample 1 – book state {index['value'] + 1}"
        ),
        specs=[[{}, {}], [{}, {}]],
        vertical_spacing=0.2
    )

    # Empty placeholder bars, one per subplot; update_plot replaces them.
    for k in range(4):
        fig.add_trace(go.Bar(x=[], y=[]), row=1 if k < 2 else 2, col=1 if k % 2 == 0 else 2)

    fig.update_layout(
        title_text="L2 Book State Diff – Sample 0 and Sample 1",
        height=1000,
        width=1000,
        showlegend=False,
        template='plotly_white'
    )

    for r in [1, 2]:
        for c in [1, 2]:
            fig.update_xaxes(title_text="Bin Index", row=r, col=c)
            fig.update_yaxes(title_text="Volume", row=r, col=c)

    fig_widget = go.FigureWidget(fig)

    msg_text_boxes = [widgets.HTML(), widgets.HTML()]
    display_area = widgets.VBox([widgets.HBox(msg_text_boxes)])

    def update_plot(i):
        """Redraw both samples for the transition from step i to step i + 1."""
        fig_widget.data = []

        for sample in [0, 1]:
            snapshot_0 = np.array(b_seq_inp[sample][i])[book_slice]
            snapshot_1 = np.array(b_seq_inp[sample][i + 1])[book_slice]
            # Change in magnitude per bin between the two states.
            diff = abs(snapshot_1) - abs(snapshot_0)

            colors_0 = ['orange'] * len(snapshot_0)
            colors_1 = ['orange' if np.isclose(d, 0) else ('red' if d > 0 else 'blue') for d in diff]
            # Centre the x axis on the middle bin of the plotted window.
            bin_x = np.arange(len(snapshot_0)) - len(snapshot_0) // 2

            fig_widget.add_trace(go.Bar(x=bin_x, y=snapshot_0, marker_color=colors_0), row=sample + 1, col=1)
            fig_widget.add_trace(go.Bar(x=bin_x, y=snapshot_1, marker_color=colors_1), row=sample + 1, col=2)

            fig_widget.layout.annotations[sample * 2 + 0].text = f"Sample {sample} – book state {i}"
            fig_widget.layout.annotations[sample * 2 + 1].text = f"Sample {sample} – book state {i + 1}"

            # Message at step i + 1 (presumably the one that produced state i + 1).
            msg = np.array(msg_seq_raw[sample][i + 1])
            event_type = int(msg[1])
            direction = int(msg[2])
            abs_price = int(msg[3])
            rel_price = int(msg[4])
            size = int(msg[5])

            event_type_str = event_type_map.get(event_type, "Unknown")
            direction_str = direction_map.get(direction, "Unknown")

            msg_desc = f"<b>Sample {sample}:</b> {event_type_str}, {direction_str}, rel_price={rel_price}, abs_price={abs_price}, size={size}"
            full_msg = f"<i>Raw:</i> {msg.tolist()}"

            msg_text_boxes[sample].value = msg_desc + "<br>" + full_msg

        index["value"] = i
        # Sync the text box. index is already updated, so the observer this
        # assignment fires sees new_i == index["value"] and skips a second redraw.
        jump_box.value = str(i)

    def on_click_left(b):
        """Step back one state, stopping at 0."""
        if index["value"] > 0:
            index["value"] -= 1
            update_plot(index["value"])

    def on_click_right(b):
        """Step forward one state; the last valid start is shape[1] - 2 (needs i + 1)."""
        if index["value"] < b_seq_inp.shape[1] - 2:
            index["value"] += 1
            update_plot(index["value"])

    def on_enter_text(change):
        """Jump to the index typed into the text box."""
        try:
            new_i = int(change['new'])
        except (TypeError, ValueError):
            return  # partial / non-numeric input while the user is typing
        if new_i == index["value"]:
            return  # already displayed (includes the echo from update_plot)
        if 0 <= new_i < b_seq_inp.shape[1] - 1:
            # Errors raised while redrawing are no longer silently swallowed.
            update_plot(new_i)

    button_left = widgets.Button(description="←")
    button_right = widgets.Button(description="→")
    jump_box = widgets.Text(value=str(index["value"]), description='Jump to i:', layout=widgets.Layout(width="150px"))

    button_left.on_click(on_click_left)
    button_right.on_click(on_click_right)
    jump_box.observe(on_enter_text, names='value')

    update_plot(index["value"])
    display(widgets.HBox([button_left, button_right, jump_box]))
    display(fig_widget)
    display(display_area)

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def _level_colors(vol_before, vol_after, unchanged_color):
    """Per-level bar colours: ``unchanged_color`` where volume is equal, else 'red'."""
    return [unchanged_color if vb == va else 'red' for vb, va in zip(vol_before, vol_after)]


def _level_hover(side, prices, volumes):
    """Per-level hover strings of the form '<side> | price: P | volume: V'."""
    return [f"{side} | price: {int(p)} | volume: {int(v)}" for p, v in zip(prices, volumes)]


def plot_l2_change_all_samples(sim_states_init, new_sim_state, sort_by_price=True):
    """Plot per-level L2 volume before vs. after a simulator update, one row per sample.

    Args:
        sim_states_init: simulator state before the update, unpacked as
            ``(asks, bids, _, _)``. ``asks``/``bids`` are indexed
            ``[sample, level, field]`` with field 0 = price, 1 = volume.
        new_sim_state: simulator state after the update, same layout.
        sort_by_price: if True, drop levels whose *before* price is -1
            (treated as empty) and order both sides by ascending price, so the
            best bid sits at x = -1 and the best ask at x = 0. If False, levels
            are plotted in stored order with bids reversed, so the first stored
            bid level sits at x = -1.

    Bars whose volume differs between before and after are drawn red.

    NOTE(review): before/after levels are paired by position — the *before*
    mask and sort order are also applied to the *after* arrays. If the update
    moves a different price into a slot, the comparison is across prices;
    confirm this is intended.
    """
    asks_before, bids_before, _, _ = sim_states_init
    asks_after, bids_after, _, _ = new_sim_state

    batch_size = asks_before.shape[0]

    fig = make_subplots(
        rows=batch_size, cols=1,
        subplot_titles=[f"Sample {i}" for i in range(batch_size)],
        shared_xaxes=False
    )

    for i in range(batch_size):
        a_b = np.array(asks_before[i])
        a_a = np.array(asks_after[i])
        b_b = np.array(bids_before[i])
        b_a = np.array(bids_after[i])

        if sort_by_price:
            # Drop empty levels (price == -1), judged on the "before" book.
            valid_asks_mask = a_b[:, 0] != -1
            valid_bids_mask = b_b[:, 0] != -1

            a_b, a_a = a_b[valid_asks_mask], a_a[valid_asks_mask]
            b_b, b_a = b_b[valid_bids_mask], b_a[valid_bids_mask]

            # Both sides ascending by price. With x_bids = -n..-1 and
            # x_asks = 0..m-1 this gives a conventional depth chart: best bid
            # (highest) at x = -1, best ask (lowest) at x = 0. Sorting bids
            # descending put the best bid at the far left instead.
            a_order = np.argsort(a_b[:, 0], kind="stable")
            b_order = np.argsort(b_b[:, 0], kind="stable")

            a_b, a_a = a_b[a_order], a_a[a_order]
            b_b, b_a = b_b[b_order], b_a[b_order]
        else:
            # Raw order: reverse bids so the first stored level ends up at x = -1.
            b_b = b_b[::-1]
            b_a = b_a[::-1]

        # Volumes (field 1) and prices (field 0).
        vol_b_b, vol_b_a = b_b[:, 1], b_a[:, 1]
        vol_a_b, vol_a_a = a_b[:, 1], a_a[:, 1]
        price_b_b, price_b_a = b_b[:, 0], b_a[:, 0]
        price_a_b, price_a_a = a_b[:, 0], a_a[:, 0]

        # Levels are placed by index, not by price: bids left of 0, asks from 0.
        x_bids = np.arange(-len(vol_b_b), 0)
        x_asks = np.arange(0, len(vol_a_b))

        trace_specs = [
            ('bids before', x_bids, vol_b_b, _level_colors(vol_b_b, vol_b_a, 'yellow'), _level_hover('bid', price_b_b, vol_b_b)),
            ('bids after', x_bids, vol_b_a, _level_colors(vol_b_b, vol_b_a, 'orange'), _level_hover('bid', price_b_a, vol_b_a)),
            ('asks before', x_asks, vol_a_b, _level_colors(vol_a_b, vol_a_a, 'lightblue'), _level_hover('ask', price_a_b, vol_a_b)),
            ('asks after', x_asks, vol_a_a, _level_colors(vol_a_b, vol_a_a, 'navy'), _level_hover('ask', price_a_a, vol_a_a)),
        ]
        for name, x, y, colors, hover in trace_specs:
            fig.add_trace(go.Bar(
                x=x,
                y=y,
                name=name,
                marker_color=colors,
                hovertext=hover,
                hoverinfo="text",
                showlegend=(i == 0)  # one legend entry per trace kind
            ), row=i + 1, col=1)

    x_title = "Level index (sorted by price)" if sort_by_price else "Level index (raw order)"
    fig.update_layout(
        title_text="L2 Depth Change by Level – All Samples in Batch",
        barmode="group",
        height=300 * batch_size,
        width=1100,
        template="plotly_white"
    )
    # update_layout(xaxis_title=...) only labels the first subplot; label every row.
    fig.update_xaxes(title_text=x_title)
    fig.update_yaxes(title_text="Volume")

    fig.show()


In [5]:
# --- Run configuration ---
save_folder = 'data_saved/'  # intended output folder (presumably for run_generation_scenario's save_folder)
batch_size = 2
model_size = 'large'  # NOTE(review): not read in the visible cells; the checkpoint is chosen by path below
data_dir ='data/test_set/'  # dataset root; the stock ticker is appended in a later cell
rng_seed = 42
sample_all = False # mirrors a CLI flag declared with action='store_true'

In [6]:
stock = 'GOOG'  # 'GOOG', 'INTC'

# Generation and context sizes (original note: "decreasing").
n_gen_msgs = 500 # how many messages to generate into the future
book_dim = 501 # must equal b_enc.shape[1]: 500 + 1 = 501 (presumably 500 volume bins + one leading entry — TODO confirm)
n_messages = 500  # length of input sequence to model

n_samples = 2  # passed to run_generation_scenario as its first argument
tick_size = 100  # price tick size

In [7]:
# Derived sizes for the model and the data loader.
v = Vocab()
n_classes = len(v)  # vocabulary size, passed to init_train_state as n_classes
seq_len = n_messages * Message_Tokenizer.MSG_LEN  # input length in tokens (MSG_LEN tokens per message)
book_seq_len = n_messages  # one book state per input message

n_eval_messages = n_gen_msgs  # how many to load from dataset
eval_seq_len = n_eval_messages * Message_Tokenizer.MSG_LEN  # NOTE(review): not used in the visible cells

rng = jax.random.key(rng_seed)
rng, rng_ = jax.random.split(rng)  # NOTE(review): rng_ is not used in the visible cells

In [8]:
# Pick the checkpoint trained for `stock`; unknown tickers fail fast.
if stock == 'GOOG':
    # ckpt_path = './checkpoints/treasured-leaf-149_84yhvzjt/' # 0.5 y GOOG, (full model)
    ckpt_path = './checkpoints/denim-elevator-754_czg1ss71/' # large model
    # ckpt_path = './checkpoints/stilted-deluge-759_8g3vqor4'  # small model
elif stock == 'INTC':
    # ckpt_path = './checkpoints/pleasant-cherry-152_i6h5n74c/' # 0.5 y INTC, (full model)
    ckpt_path = './checkpoints/eager-sea-755_2rw1ofs3/'  # large model
else:
    raise ValueError(f'stock {stock} not recognized')

In [9]:
# Load the run arguments stored with the checkpoint (used below: .bsz, .num_devices, .batchnorm).
print('Loading metadata:', ckpt_path)
args_ckpt = load_metadata(ckpt_path)

Loading metadata: ./checkpoints/denim-elevator-754_czg1ss71/


In [10]:
# scale down to single GPU, single sample inference
args_ckpt.bsz = 1 #1, 10
args_ckpt.num_devices = 1

batchnorm = args_ckpt.batchnorm  # passed to run_generation_scenario later

In [11]:
# Build the model skeleton from the checkpoint's saved args, then restore the
# trained weights from disk.

print('Initializing model...')
new_train_state, model_cls = init_train_state(
    args_ckpt,
    n_classes=n_classes,
    seq_len=seq_len,
    book_dim=book_dim,
    book_seq_len=book_seq_len,
)

print('\nLoading model checkpoint...')
ckpt = load_checkpoint(
    new_train_state,
    ckpt_path,
    train=False,
)
state = ckpt['model']

# Inference-mode model instance.
model = model_cls(training=False, step_rescale=1.0)

# jax.tree_leaves is a deprecated alias (removed in newer JAX releases);
# jax.tree_util.tree_leaves is the stable API.
param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
print('param count:', param_count)

Initializing model...


In [None]:
# Append the ticker to the dataset root. Guarded so that re-running this cell
# does not produce 'data/test_set/GOOGGOOG', and so it still works after a
# later cell has turned data_dir into a Path.
if Path(data_dir).name != stock:
    data_dir = str(Path(data_dir) / stock)
print(f"Directory Path: {data_dir}")

In [None]:
# Normalise data_dir to a Path, make sure the directory exists, and count the
# regular files in it.
# NOTE(review): mkdir on an *input* directory hides a wrong path (you get an
# empty folder and a count of 0 instead of an error).
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
folder_path = data_dir
file_count = sum(1 for entry in folder_path.iterdir() if entry.is_file())
print(f"There are {file_count} files in the folder {str(data_dir)}.")

In [None]:
# Sanity check: context length (messages) vs. number of messages loaded for evaluation.
print(n_messages, n_eval_messages)

In [None]:
from pathlib import Path
import os

# Print current working directory to help verify the path
print(f"Current working directory: {os.getcwd()}")

# Relative path to the test set of the selected stock. This was previously
# hard-coded to "data/test_set/GOOG", which silently loaded GOOG data even
# when stock == 'INTC' (and the INTC checkpoint) was selected.
data_dir = Path("data/test_set") / stock

try:
    Path(data_dir).mkdir(parents=True, exist_ok=True)
    print(f"Successfully created or verified directory: {data_dir}")

    file_count = len([f for f in Path(data_dir).iterdir() if f.is_file()])
    print(f"There are {file_count} files in the folder {data_dir}.")
except Exception as e:
    # Best-effort diagnostics only; get_dataset below reports real load errors.
    print(f"Error: {str(e)}")

ds = inference.get_dataset(data_dir, n_messages, n_eval_messages)

In [None]:
# Display the dataset object.
ds

# Check samples one by one

In [None]:
# Hand-picked dataset indices, presumably one per batch element (batch_size = 2).
batch_i = [4497, 6855]
# batch_i = [3397, 5855]
# NOTE(review): msg_seq_raw and book_l2_init are rebound by the generation cell below.
m_seq, _, b_seq_pv, msg_seq_raw, book_l2_init = ds[batch_i]

In [None]:
# Shapes of the first element of each loaded array. The labels previously read
# "m_seq.shape" etc. even though element [0] was being printed.
for name, arr in [('m_seq', m_seq), ('b_seq_pv', b_seq_pv),
                  ('msg_seq_raw', msg_seq_raw), ('book_l2_init', book_l2_init)]:
    print(f'{name}[0].shape: {arr[0].shape}\n\n')

In [None]:
# Generate continuations for the loaded batch. The config-cell values
# save_folder, tick_size and sample_all are passed through rather than
# re-typed as literals, so editing the config cell actually takes effect.
# NOTE(review): this rebinds msg_seq_raw and book_l2_init loaded from ds above.
b_seq_inp, msg_seq_raw, midprices, m_seq_gen_doubled, b_seq_gen_doubled, msgs_decoded_doubled, l2_book_states_halved, l2_book_states, sim_init, sim_states_init, book_l2_init = inference.run_generation_scenario(
    n_samples,
    batch_size,
    ds,
    rng,
    seq_len,
    n_messages,
    n_gen_msgs,
    state,
    model,
    batchnorm,
    v.ENCODING,
    stock,
    n_vol_series=500,
    # sim_book_levels: int = 20,
    # sim_queue_len: int = 100,
    # data_levels: int = 10,
    save_folder=save_folder,
    tick_size=tick_size,
    sample_top_n=-1,
    sample_all=sample_all,
    # Insertion settings
    num_insertions=20,
    num_coolings=20,
)

In [None]:
# Batched (vmap over axis 0 of the first argument) and jitted version of
# inference.transform_L2_state. Arguments 2 and 3 are static, so they must be
# hashable Python values (e.g. transform_L2_state_batch(book_raw, 500, 100)).
transform_L2_state_batch = jax.jit(
        jax.vmap(
            inference.transform_L2_state,
            in_axes=(0, None, None)
        ),
        static_argnums=(1, 2)
    )

In [None]:
# Lazily built jit(vmap(inference.transform_L2_state)). Creating a fresh
# jax.jit(jax.vmap(...)) inside insert_custom_end gave a new function object,
# and hence a fresh trace, on every call.
_TRANSFORM_L2_STATE_BATCH = None


def _get_transform_L2_state_batch():
    """Return the cached jitted, batched transform_L2_state, creating it on first use."""
    global _TRANSFORM_L2_STATE_BATCH
    if _TRANSFORM_L2_STATE_BATCH is None:
        _TRANSFORM_L2_STATE_BATCH = jax.jit(
            jax.vmap(inference.transform_L2_state, in_axes=(0, None, None)),
            static_argnums=(1, 2),
        )
    return _TRANSFORM_L2_STATE_BATCH


def insert_custom_end(m_seq_gen_doubled, b_seq_gen_doubled, msgs_decoded_doubled,
                        l2_book_states_halved, encoder, mid_price, tick_size = 100,
                        n_vol_series=500):
    """Append one synthetic execution message to the end of the generated sequences.

    The simulator is rebuilt from ``l2_book_states_halved[:, -2]`` together with
    the last decoded message. Then, for every batch element, a message with event
    type 4 and direction 0 is applied. It is priced at the current best ask, with
    size ``min(best-ask volume, 75)``. Each sequence drops its oldest message /
    book state, and the new one is appended at the end.

    Args:
        m_seq_gen_doubled: encoded token sequences, 2-D ``(batch, tokens)``;
            shifted left by ``Message_Tokenizer.MSG_LEN`` tokens.
        b_seq_gen_doubled: encoded book sequences, 3-D ``(batch, steps, dim)``.
        msgs_decoded_doubled: decoded messages, 3-D ``(batch, n_msgs, 14)``;
            fields 8/9 of the last message supply the time (s / ns).
        l2_book_states_halved: L2 book states indexed ``[:, step]``; only
            step -2 is used.
        encoder: passed to ``encoding.encode_msg``.
        mid_price: mid price before the insertion, one per batch element
            (expanded to ``(batch, 1)`` here).
        tick_size: tick used for the mid-price change and passed as the third
            argument of ``transform_L2_state`` (previously hard-coded to 100;
            assumed to be its tick-size argument — TODO confirm).
        n_vol_series: second (static) argument of ``transform_L2_state``;
            500 reproduces the previous hard-coded value.

    Returns:
        ``(UPDATED_m_seq_gen_doubled, b_seq_gen_doubled,
        UPDATED_msgs_decoded_doubled, new_l2_book_states_halved, p_mid_new)``.
        ``new_l2_book_states_halved`` is the L2 book of the *new* state only
        (not an updated sequence); ``p_mid_new`` has shape ``(batch, 1)``.
    """
    ORDER_ID_i = 1236128736
    EVENT_TYPE_i = 4   # shown as "Execution" by the event-type map in interactive_lob_plot
    DIRECTION_i = 0    # shown as "Sell" by the direction map in interactive_lob_plot

    # Simulator built from the second-to-last book state plus the last decoded
    # message (presumably replayed so the state matches the latest book).
    sim_init, sim_states_init = inference.get_sims_vmap(
        l2_book_states_halved[:, -2], msgs_decoded_doubled[:, -1:])

    batched_p_abs = jax.vmap(sim_init.get_best_ask)(sim_states_init)
    mid_price = jnp.expand_dims(mid_price, axis=-1)

    # Re-use the timestamp of the last decoded message.
    batched_time_s = msgs_decoded_doubled[:, -1, 8].astype(jnp.int32)
    batched_time_ns = msgs_decoded_doubled[:, -1, 9].astype(jnp.int32)

    batch_size = batched_time_ns.shape[0]

    best_bid_ask = jax.vmap(sim_init.get_best_bid_and_ask_inclQuants)(sim_states_init)
    best_ask_volume = best_bid_ask[1][:, 1]
    # Size capped at 75 and at the volume available at the best ask.
    batched_quantity = jnp.minimum(best_ask_volume, 75)

    batched_new_order_id = jnp.full((batch_size,), ORDER_ID_i, dtype=jnp.int32)
    batched_EVENT_TYPE = jnp.full((batch_size,), EVENT_TYPE_i, dtype=jnp.int32)
    batched_side = jnp.full((batch_size,), DIRECTION_i, dtype=jnp.int32)

    batched_sim_msg = jax.vmap(inference.construct_sim_msg)(
        batched_EVENT_TYPE,
        batched_side,
        batched_quantity,
        batched_p_abs,
        batched_new_order_id,
        batched_time_s,
        batched_time_ns,
    )

    # =================== apply order, build new book state =================== #

    new_sim_state = jax.vmap(sim_init.process_order_array)(sim_states_init, batched_sim_msg)

    p_mid_new = inference.batched_get_safe_mid_price(sim_init, new_sim_state, tick_size)
    p_mid_new = p_mid_new[:, None]
    # Mid-price change in whole ticks (floor division), prepended to the raw L2 book.
    p_change = ((p_mid_new - mid_price) // tick_size).astype(jnp.int32)

    book_l2 = jax.vmap(sim_init.get_L2_state, in_axes=(0, None))(new_sim_state, 20)
    new_l2_book_states_halved = book_l2

    new_book_raw = jnp.concatenate([p_change, book_l2], axis=1)[:, None, :]
    # Presumably (batch, 1, n_vol_series + 1) so it concatenates on the step axis.
    new_book = _get_transform_L2_state_batch()(new_book_raw, n_vol_series, tick_size)

    # Drop the oldest book state and append the new one.
    b_seq_gen_doubled = jnp.concatenate([b_seq_gen_doubled[:, 1:, :], new_book], axis=1)

    # =================== build the 14-field decoded message =================== #

    zeros = jnp.zeros((batch_size, 1), dtype=jnp.int32)
    ins_msg = jnp.concatenate([
        batched_new_order_id.reshape(-1, 1),             # 0  order id
        batched_EVENT_TYPE.reshape(-1, 1),               # 1  event type
        batched_side.reshape(-1, 1),                     # 2  direction
        batched_p_abs.reshape(-1, 1),                    # 3  absolute price
        jnp.full((batch_size, 1), 1, dtype=jnp.int32),   # 4  relative price (fixed placeholder 1 — TODO confirm)
        batched_quantity.reshape(-1, 1),                 # 5  size
        zeros,                                           # 6  dt_s
        zeros,                                           # 7  dt_ns
        batched_time_s.reshape(-1, 1),                   # 8  time_s
        batched_time_ns.reshape(-1, 1),                  # 9  time_ns
        zeros,                                           # 10 ref_price
        zeros,                                           # 11 ref_size
        zeros,                                           # 12 ref_time_s
        zeros,                                           # 13 ref_time_ns
    ], axis=1)

    # Drop the oldest decoded message and append the new one.
    UPDATED_msgs_decoded_doubled = jnp.concatenate(
        [msgs_decoded_doubled[:, 1:, :], ins_msg[:, None, :]], axis=1)

    # Same shift for the token sequence: drop one message worth of tokens.
    msg_encoded = jax.vmap(lambda m: encoding.encode_msg(m, encoder))(ins_msg)
    shift = Message_Tokenizer.MSG_LEN
    UPDATED_m_seq_gen_doubled = jnp.concatenate([m_seq_gen_doubled[:, shift:], msg_encoded], axis=1)

    return UPDATED_m_seq_gen_doubled, b_seq_gen_doubled, UPDATED_msgs_decoded_doubled, new_l2_book_states_halved, p_mid_new

In [None]:
# Insert the custom execution at the end of the generated sequences.
# NOTE(review): b_seq_gen_doubled is rebound to the *updated* book sequence,
# while the message outputs get UPDATED_* names. Re-running this cell therefore
# shifts the books again but not the messages — run it once per generation.
# NOTE(review): l2_book_states (not l2_book_states_halved) is passed as the
# l2_book_states_halved argument; confirm this is intended.
(UPDATED_m_seq_gen_doubled,
 b_seq_gen_doubled,
 UPDATED_msgs_decoded_doubled,
 new_l2_book_states_halved,
 p_mid_new) = insert_custom_end(
    m_seq_gen_doubled,
    b_seq_gen_doubled,
    msgs_decoded_doubled,
    l2_book_states,
    encoder=v.ENCODING,
    mid_price=midprices[-1],
)

In [None]:
# Expected to equal m_seq_gen_doubled.shape if encode_msg emits MSG_LEN tokens
# (one message shifted out, one appended).
UPDATED_m_seq_gen_doubled.shape