In [1]:
import argparse
import os
import sys

# Must be set before JAX initialises its backend (hence before `import jax`): stops XLA
# from preallocating a large fraction of GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# set_start_method raises RuntimeError once a start method has been set, which happens
# on every re-run of this cell in the same kernel. Only set it if it is not 'spawn' yet;
# a *different* pre-set method still raises, exactly as before.
if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
    torch.multiprocessing.set_start_method('spawn')

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as onp

from pathlib import Path

import pandas as pd

from datetime import datetime
import yaml

import historical_scenario
import numpy as np
from tqdm import tqdm

from flax.training.train_state import TrainState
from lob.lobster_dataloader import LOBSTER_Dataset
import flax.linen as nn
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union



In [2]:
def parse_args(config_file, argv=None):
    """Parse the command-line options of this notebook/script.

    Args:
        config_file: Path to the YAML config used when ``--config`` is not given.
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``,
            matching the previous behaviour; pass an explicit list to parse without
            depending on how the process was launched.

    Returns:
        argparse.Namespace with a ``config`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default=config_file,
        help="Path to your YAML config file"
    )

    # Use parse_known_args instead of parse_args: unrelated arguments (e.g. the
    # `-f <connection-file>` flag a Jupyter kernel is started with) are ignored
    # instead of making argparse exit with an error.
    args, _ = parser.parse_known_args(argv)
    return args

In [3]:
# Resolve the config path (command line `--config` wins over the default file name)
# and parse the YAML into a plain dict.
args = parse_args("1_run_exp_aggresive_scenario.yaml")
config_path = args.config
with open(config_path, "r") as config_stream:
    cfg = yaml.safe_load(config_stream)

In [4]:
# Debug overrides: these values take precedence over the YAML file. They are merged
# into `cfg` *before* unpacking so that anything serialising `cfg` afterwards (e.g. the
# experiment's used_config.yaml) records the values that were actually used, rather
# than the YAML values they replace. Use an empty dict for a purely config-driven run.
DEBUG_OVERRIDES = {
    "batch_size": 2,
    "n_samples": 4,
    "bsz": 8,
    "n_gen_msgs": 50,
    "midprice_step_size": 50,
    "num_insertions": 5,
    "num_coolings": 1,
    "EVENT_TYPE_i": 4,
    "DIRECTION_i": 0,
    "order_volume": 75,
}
cfg.update(DEBUG_OVERRIDES)

save_folder       = cfg["save_folder"]
batch_size        = cfg["batch_size"]
n_samples         = cfg["n_samples"]
n_gen_msgs        = cfg["n_gen_msgs"]
midprice_step_size= cfg["midprice_step_size"]
num_insertions    = cfg["num_insertions"]
num_coolings      = cfg["num_coolings"]
EVENT_TYPE_i      = cfg["EVENT_TYPE_i"]
DIRECTION_i       = cfg["DIRECTION_i"]
order_volume      = cfg["order_volume"]
bsz               = cfg["bsz"]
n_messages        = cfg["n_messages"]
book_dim          = cfg["book_dim"]
n_vol_series      = cfg["n_vol_series"]
sample_top_n      = cfg["sample_top_n"]
model_size        = cfg["model_size"]
data_dir          = cfg["data_dir"]
sample_all        = cfg["sample_all"]
stock             = cfg["stock"]
tick_size         = cfg["tick_size"]
rng_seed          = cfg["rng_seed"]
ckpt_path         = cfg["ckpt_path"]

num_devices = jax.local_device_count()
print('num_devices: ', num_devices)

num_devices:  1


In [5]:
# Rebuild the model from the training-time metadata stored alongside the checkpoint.
print("Loading metadata from", ckpt_path)
args_ckpt = load_metadata(ckpt_path)
print("Initializing model...")
train_state, model_cls = init_train_state(
    args_ckpt,
    n_classes=len(Vocab()),
    seq_len=n_messages * Message_Tokenizer.MSG_LEN,
    book_dim=book_dim,
    book_seq_len=n_messages,
)

from lob import init_train


def _first_available_device():
    """Return the first GPU device, or the first CPU device if JAX has no GPU backend.

    Prints which of the two was chosen.
    """
    try:
        gpu_devices = jax.devices("gpu")
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is unavailable.
        cpu_devices = jax.devices("cpu")
        print("[INFO] GPU not available. Falling back to CPU.")
        return cpu_devices[0]
    print("[INFO] Running on GPU.")
    return gpu_devices[0]


def safe_deduplicate_trainstate(state):
    """Drop-in replacement for ``init_train.deduplicate_trainstate`` that also works without a GPU.

    Keeps element 0 along the leading axis of every leaf of ``state`` (assumed to be a
    per-device replication axis -- TODO confirm) and places the result on the first
    available device.
    """
    target_device = _first_available_device()
    single_copy = jax.tree.map(lambda leaf: leaf[0], state)
    return jax.device_put(single_copy, device=target_device)


# Replace the module attribute so code in `lob.init_train` that looks the function up
# at call time (presumably `load_checkpoint`) uses the GPU/CPU-safe version.
init_train.deduplicate_trainstate = safe_deduplicate_trainstate
print("Loading checkpoint...")
ckpt = load_checkpoint(train_state, ckpt_path, train=False)
state = ckpt["model"]
model = model_cls(training=False, step_rescale=1.0)

Loading metadata from checkpoints/denim-elevator-754_czg1ss71/
Initializing model...
configuring standard optimization setup
[*] Trainable Parameters: 35776312
Loading checkpoint...
[INFO] GPU not available. Falling back to CPU.


In [6]:
# Seed JAX's PRNG from the config so the sampled batches are reproducible.
rng = jax.random.PRNGKey(rng_seed)

# Input data lives in <data_dir>/<stock>. It is read-only input, so a missing directory
# is a configuration error: fail here instead of silently creating an empty directory
# that would only make the dataset step fail later, further from the cause.
data_path = Path(data_dir) / stock
if not data_path.is_dir():
    raise FileNotFoundError(f"Data directory not found: {data_path}")
print(f"Data directory: {data_path} ({len(list(data_path.iterdir()))} files)")

# Output folder for this experiment; the config actually used is stored next to the
# results. Path(...) makes the `/` join below work even if the helper returns a str.
exp_folder = Path(historical_scenario.create_next_experiment_folder(save_folder))
print("Experiment dir:", exp_folder)
with open(exp_folder / "used_config.yaml", "w") as f_out:
    yaml.dump(cfg, f_out)

Data directory: /app/data/test_set/GOOG (37 files)
Experiment dir: /app/data_saved/exp_156_20250722_143644


In [7]:
# Build the dataset. The third argument is the number of messages after the
# `n_messages` context -- presumably one `n_gen_msgs` window per insertion and per
# cooling phase (TODO confirm against `inference.get_dataset`).
n_eval_messages = (num_insertions + num_coolings) * n_gen_msgs
ds = inference.get_dataset(data_path, n_messages, n_eval_messages)

# ================= Scenario debugging =================

In [8]:
def _sample_batch_indices(rng, n_items, n_samples, batch_size, sample_all):
    """Choose dataset indices to process, grouped into batches of ``batch_size``.

    With ``sample_all`` every index is used (a trailing partial batch is dropped);
    otherwise ``n_samples`` distinct indices are drawn at random without replacement.

    Returns:
        List of lists of ints, each inner list ``batch_size`` long.
    """
    # The second key of the split is the one used for sampling, so the drawn
    # indices for a given seed are unchanged.
    _, sample_key = jax.random.split(rng)
    if sample_all:
        n_used = n_items // batch_size * batch_size
        return jnp.arange(n_used, dtype=jnp.int32).reshape(-1, batch_size).tolist()
    assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
    return jax.random.choice(
        sample_key,
        jnp.arange(n_items, dtype=jnp.int32),
        shape=(n_samples // batch_size, batch_size),
        replace=False
    ).tolist()


def _construct_custom_msg(last_msg, event_type, direction, order_volume, order_id=77777777):
    """Build one synthetic 14-field raw message per batch row.

    Price (field 3) and timestamp (fields 8, 9) are copied from ``last_msg``; order id,
    event type, side and size come from the arguments; the display flag (field 4) is 1
    and every other field is 0.

    Args:
        last_msg: (batch, n_fields) raw messages preceding the insertion point.

    Returns:
        (batch, 14) int32 array.
    """
    n_rows = last_msg.shape[0]
    zeros = jnp.zeros((n_rows,), dtype=jnp.int32)
    fields = [
        jnp.full((n_rows,), order_id),          # 0: order id
        jnp.full((n_rows,), event_type),        # 1: event type
        jnp.full((n_rows,), direction),         # 2: side
        last_msg[:, 3],                         # 3: price (copied)
        jnp.ones((n_rows,), dtype=jnp.int32),   # 4: display flag
        jnp.full((n_rows,), order_volume),      # 5: size
        zeros, zeros,                           # 6-7
        last_msg[:, 8],                         # 8: time, seconds (copied)
        last_msg[:, 9],                         # 9: time, nanoseconds (copied)
        zeros, zeros, zeros, zeros,             # 10-13
    ]
    return jnp.stack(fields, axis=1).astype(jnp.int32)


def _insert_custom_msgs(msg_seq_raw, n_msgs, n_gen_msgs, num_insertions,
                        event_type, direction, order_volume):
    """Insert up to ``num_insertions`` synthetic messages into every sequence of the batch.

    Insertion ``i`` is placed after the first ``n_msgs + (i + 1) * n_gen_msgs`` original
    messages, i.e. after each block of ``n_gen_msgs`` messages that follows the
    ``n_msgs`` context. Insertions that would fall past the end are skipped.

    Args:
        msg_seq_raw: (batch, T, n_fields) raw messages.

    Returns:
        (new_msg_seq_raw, insertion_points), where insertion_points are the indices of
        the synthetic messages in the returned sequence.
    """
    original_len = msg_seq_raw.shape[1]
    # The i messages inserted before insertion i shift its index by +i. The bound is
    # checked on the *original* position: comparing the shifted index with the original
    # length would wrongly drop a point that still lies inside the grown sequence.
    insertion_points = [
        n_msgs + (i + 1) * n_gen_msgs + i
        for i in range(num_insertions)
        if n_msgs + (i + 1) * n_gen_msgs <= original_len
    ]
    for idx in insertion_points:
        custom_msg = _construct_custom_msg(msg_seq_raw[:, idx - 1], event_type, direction, order_volume)
        msg_seq_raw = jnp.concatenate([
            msg_seq_raw[:, :idx, :],
            custom_msg[:, None, :],  # (batch, 1, 14)
            msg_seq_raw[:, idx:, :]
        ], axis=1)
    return msg_seq_raw, insertion_points


def _replay_messages(book_l2_init, msg_seq_raw, tick_size):
    """Apply messages one at a time, starting from ``book_l2_init``.

    Each step calls ``inference.get_sims_vmap`` on the current L2 book and a single
    message, records the mid price, and takes the resulting L2 state as the next book.

    Returns:
        books: (batch, T, book_width) L2 book after each message.
        messages: (batch, T, msg_width) the messages that were applied.
        midprices: (T, batch) mid price after each message.
    """
    current_book = book_l2_init
    books, messages, midprices = [], [], []
    for t in range(msg_seq_raw.shape[1]):
        msg = msg_seq_raw[:, t:t + 1, :]
        sim_init, sim_states = inference.get_sims_vmap(current_book, msg)
        midprices.append(inference.batched_get_safe_mid_price(sim_init, sim_states, tick_size))
        n_cols = current_book.shape[1]
        # NOTE(review): the book's column count is passed as get_L2_state's second
        # argument and the result is truncated back to it -- confirm that argument's
        # meaning (levels vs. columns).
        full_l2_state = jax.vmap(sim_init.get_L2_state, in_axes=(0, None))(sim_states, n_cols)
        current_book = full_l2_state[:, :n_cols]
        books.append(current_book)
        messages.append(msg)
    return (
        jnp.stack(books, axis=1),
        jnp.concatenate(messages, axis=1),
        jnp.stack(midprices, axis=0),
    )


def _save_batch_array(base_folder, name, batch_i, arr):
    """Save ``arr`` to ``<base_folder>/<name>/<name>_batch_<batch_i>_iter_0.npy``."""
    np.save(
        os.path.join(base_folder, name, f'{name}_batch_{batch_i}_iter_0.npy'),
        np.array(jax.device_get(arr))
    )


def run_historical_scenario(
        n_samples: int,
        batch_size: int,
        ds: LOBSTER_Dataset,
        rng: jax.dtypes.prng_key,
        seq_len: int,
        n_msgs: int,
        n_gen_msgs: int,
        train_state: TrainState,
        model: nn.Module,
        batchnorm: bool,
        encoder: Dict[str, Tuple[jax.Array, jax.Array]],
        stock_symbol: str,
        n_vol_series: int = 500,
        save_folder: str = './data_saved/',
        tick_size: int = 100,
        sample_top_n: int = -1,
        sample_all: bool = False,
        num_insertions: int = 2,
        num_coolings: int = 2,
        midprice_step_size=100,
        EVENT_TYPE_i = 4,
        DIRECTION_i = 0,
        order_volume = 75
    ):
    """Replay historical message sequences with synthetic messages injected, and save the results.

    For every batch of dataset samples:

    1. Insert ``num_insertions`` synthetic messages (event type ``EVENT_TYPE_i``, side
       ``DIRECTION_i``, size ``order_volume``) after each ``n_gen_msgs``-message block
       that follows the ``n_msgs`` context.
    2. Replay the sequence message by message to track the L2 book and mid price.
    3. Save, under ``save_folder``:

       * ``msgs_decoded_doubled/``  -- applied messages, (B, T, msg_width)
       * ``l2_book_states_halved/`` -- L2 book after each message, (B, T, book_width)
       * ``mid_price/``             -- mid price after each message, (T, B)
       * ``b_seq_gen_doubled/``     -- books with the mid price prepended, passed through
         ``preproc.transform_L2_state(., n_vol_series, tick_size)``

    No model inference happens here. ``seq_len``, ``train_state``, ``model``,
    ``batchnorm``, ``encoder``, ``stock_symbol``, ``sample_top_n``, ``num_coolings`` and
    ``midprice_step_size`` are accepted but currently unused.

    Raises:
        AssertionError: if ``sample_all`` is False and ``n_samples`` is not divisible by
            ``batch_size``.

    Returns:
        None.
    """
    sample_i = _sample_batch_indices(rng, len(ds), n_samples, batch_size, sample_all)

    save_folder = Path(save_folder)
    for subdir in ('msgs_decoded_doubled', 'l2_book_states_halved', 'b_seq_gen_doubled', 'mid_price'):
        save_folder.joinpath(subdir).mkdir(exist_ok=True, parents=True)

    # Build the jitted, batch-vmapped transform once. Re-creating the wrapper per batch
    # gives JAX a new function object each time, so it re-traces and recompiles.
    transform_L2_state_batch = jax.jit(
        jax.vmap(preproc.transform_L2_state, in_axes=(0, None, None)),
        static_argnums=(1, 2)
    )

    for batch_i in tqdm(sample_i):
        print('BATCH', batch_i)
        _, _, _, msg_seq_raw, book_l2_init = ds[batch_i]
        msg_seq_raw = jnp.array(msg_seq_raw)
        book_l2_init = jnp.array(book_l2_init)

        msg_seq_raw, insertion_points = _insert_custom_msgs(
            msg_seq_raw, n_msgs, n_gen_msgs, num_insertions,
            EVENT_TYPE_i, DIRECTION_i, order_volume
        )
        print(f"[BATCH {batch_i}] Inserting custom messages at: {insertion_points}")
        print(f"[BATCH {batch_i}] msg_seq_raw shape after insertions: {msg_seq_raw.shape}")

        books, messages, midprices = _replay_messages(book_l2_init, msg_seq_raw, tick_size)
        n_steps = msg_seq_raw.shape[1]

        print(f"[BATCH {batch_i}] Finished all {n_steps} steps")
        print(f"[BATCH {batch_i}] Final messages shape: {messages.shape}")
        print(f"[BATCH {batch_i}] Final books shape: {books.shape}")
        print(f"[BATCH {batch_i}] Final midprices shape: {midprices.shape}")

        _save_batch_array(save_folder, 'msgs_decoded_doubled', batch_i, messages)
        _save_batch_array(save_folder, 'l2_book_states_halved', batch_i, books)
        _save_batch_array(save_folder, 'mid_price', batch_i, midprices)

        # Prepend the mid price as column 0 of every book state: (B, T, 1 + book_width).
        books_with_mid = jnp.concatenate([midprices.T[:, :, None], books], axis=-1)
        print(f"[BATCH {batch_i}] books_with_mid.shape: {books_with_mid.shape}")

        # Transform each book+midprice into the model's book input format: (B, T, D).
        books_transformed = transform_L2_state_batch(books_with_mid, n_vol_series, tick_size)
        print(f"[BATCH {batch_i}] books_transformed.shape: {books_transformed.shape}")

        _save_batch_array(save_folder, 'b_seq_gen_doubled', batch_i, books_transformed)


In [9]:
# Run the scenario. Arguments are passed by keyword so they cannot silently shift if the
# positional parameters of run_historical_scenario change. The function returns None;
# its outputs are the .npy files written under `exp_folder`.
results = run_historical_scenario(
    n_samples=n_samples,
    batch_size=batch_size,
    ds=ds,
    rng=rng,
    seq_len=n_messages * Message_Tokenizer.MSG_LEN,
    n_msgs=n_messages,
    n_gen_msgs=n_gen_msgs,
    train_state=state,
    model=model,
    batchnorm=args_ckpt.batchnorm,
    encoder=Vocab().ENCODING,
    stock_symbol=stock,
    n_vol_series=n_vol_series,
    save_folder=exp_folder,
    tick_size=tick_size,
    sample_top_n=sample_top_n,
    sample_all=sample_all,
    num_insertions=num_insertions,
    num_coolings=num_coolings,
    midprice_step_size=midprice_step_size,
    EVENT_TYPE_i=EVENT_TYPE_i,
    DIRECTION_i=DIRECTION_i,
    order_volume=order_volume,
)

  0%|          | 0/2 [00:00<?, ?it/s]

BATCH [24152, 8549]
[BATCH [24152, 8549]] Inserting custom messages at: [550, 601, 652, 703, 754]
[BATCH [24152, 8549]] msg_seq_raw shape after insertions: (2, 805, 14)
[BATCH [24152, 8549]] Finished all 805 steps
[BATCH [24152, 8549]] Final messages shape: (2, 805, 14)
[BATCH [24152, 8549]] Final books shape: (2, 805, 40)
[BATCH [24152, 8549]] Final midprices shape: (805, 2)
[BATCH [24152, 8549]] books_with_mid.shape: (2, 805, 41)


 50%|█████     | 1/2 [00:21<00:21, 21.31s/it]

[BATCH [24152, 8549]] books_transformed.shape: (2, 805, 501)
BATCH [9483, 19480]
[BATCH [9483, 19480]] Inserting custom messages at: [550, 601, 652, 703, 754]
[BATCH [9483, 19480]] msg_seq_raw shape after insertions: (2, 805, 14)


100%|██████████| 2/2 [00:26<00:00, 13.47s/it]

[BATCH [9483, 19480]] Finished all 805 steps
[BATCH [9483, 19480]] Final messages shape: (2, 805, 14)
[BATCH [9483, 19480]] Final books shape: (2, 805, 40)
[BATCH [9483, 19480]] Final midprices shape: (805, 2)
[BATCH [9483, 19480]] books_with_mid.shape: (2, 805, 41)
[BATCH [9483, 19480]] books_transformed.shape: (2, 805, 501)





# ================= Data saving debugging =================

In [10]:
# Load an array saved by an earlier experiment (exp_73) and one written by a later
# scenario run (exp_153), presumably to compare their saved formats.
# NOTE(review): the two files are different artefact types (b_seq_gen_doubled vs.
# l2_book_states_halved) -- confirm that is the intended comparison.
# NOTE(review): absolute, machine-specific paths; consider deriving them from `save_folder`.
example_inp_path = "/app/data_saved/exp_73_20250701_210502/b_seq_gen_doubled/b_seq_inp_[14606, 16120].npy"
current_inp_path = "/app/data_saved/exp_153_20250722_140141/l2_book_states_halved/l2_book_states_halved_batch_[13104, 5937]_iter_0.npy"

# The *_file names hold the loaded arrays, as before.
example_inp_file = np.load(example_inp_path)
current_inp_file = np.load(current_inp_path)