In [1]:
import argparse
import os
import sys

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
torch.multiprocessing.set_start_method('spawn')

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as onp

from pathlib import Path

import pandas as pd

from datetime import datetime
import yaml

import historical_scenario
import numpy as np
from tqdm import tqdm

from flax.training.train_state import TrainState
from lob.lobster_dataloader import LOBSTER_Dataset
import flax.linen as nn
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import json

2025-08-14 20:01:25.967352: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.8 which is older than the ptxas CUDA version (12.9.41). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
def parse_args(config_file):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default=config_file,
        help="Path to your YAML config file"
    )

    # Используем parse_known_args вместо parse_args
    args, _ = parser.parse_known_args()
    return args

In [3]:
args = parse_args("historical_scenario.yaml")
with open(args.config, "r") as f:
    cfg = yaml.safe_load(f)

In [4]:
save_folder       = cfg["save_folder"]
batch_size        = cfg["batch_size"]
n_samples         = cfg["n_samples"]
n_gen_msgs        = cfg["n_gen_msgs"]
midprice_step_size= cfg["midprice_step_size"]
num_insertions    = cfg["num_insertions"]
num_coolings      = cfg["num_coolings"]
EVENT_TYPE_i      = cfg["EVENT_TYPE_i"]
DIRECTION_i       = cfg["DIRECTION_i"]
order_volume      = cfg["order_volume"]
bsz               = cfg["bsz"]
n_messages        = cfg["n_messages"]
book_dim          = cfg["book_dim"]
n_vol_series      = cfg["n_vol_series"]
sample_top_n      = cfg["sample_top_n"]
model_size        = cfg["model_size"]
data_dir          = cfg["data_dir"]
sample_all        = cfg["sample_all"]
stock             = cfg["stock"]
tick_size         = cfg["tick_size"]
rng_seed          = cfg["rng_seed"]
ckpt_path         = cfg["ckpt_path"]
num_devices = jax.local_device_count()
print(f'num_devices: ', num_devices)

use_sample_file  = cfg["use_sample_file"]
sample_file_path = cfg["sample_file_path"]
start_batch      = cfg["start_batch"]
end_batch        = cfg["end_batch"]

order_volume = cfg["order_volume"]
order_volume_ratio = cfg["order_volume_ratio"]
use_relative_volume = cfg["use_relative_volume"]


# batch_size = 4
# n_samples = 8
# bsz = 25
# n_gen_msgs = 50
# midprice_step_size = 50
# num_insertions = 2
# num_coolings = 2
# EVENT_TYPE_i = 4
# DIRECTION_i = 0
# order_volume = 75

num_devices:  1


In [5]:
print(use_sample_file)
print(sample_file_path)
print(start_batch)
print(end_batch)

True
sample_indices_b64_bs16_ins80_cool20.json
0
125


In [6]:
# Load metadata and model
print("Loading metadata from", ckpt_path)
args_ckpt = load_metadata(ckpt_path)
print("Initializing model...")
train_state, model_cls = init_train_state(
    args_ckpt,
    n_classes=len(Vocab()),
    seq_len=n_messages * Message_Tokenizer.MSG_LEN,
    book_dim=book_dim,
    book_seq_len=n_messages,
)

import jax
from lob import init_train

def safe_deduplicate_trainstate(state):
    try:
        devices = jax.devices("gpu")
    except RuntimeError:
        devices = jax.devices("cpu")
        print("[INFO] GPU not available. Falling back to CPU.")
    else:
        print("[INFO] Running on GPU.")
    
    return jax.device_put(
        jax.tree.map(lambda x: x[0], state),
        device=devices[0]
    )

init_train.deduplicate_trainstate = safe_deduplicate_trainstate
print("Loading checkpoint...")
ckpt = load_checkpoint(train_state, ckpt_path, train=False)
state = ckpt["model"]
model = model_cls(training=False, step_rescale=1.0)

Loading metadata from checkpoints/denim-elevator-754_czg1ss71/
Initializing model...
configuring standard optimization setup
[*] Trainable Parameters: 35776312
Loading checkpoint...
[INFO] Running on GPU.


In [7]:
# prepare RNG
rng = jax.random.PRNGKey(rng_seed)

# data directory
data_path = Path(data_dir) / stock
data_path.mkdir(parents=True, exist_ok=True)
print(f"Data directory: {data_path} ({len(list(data_path.iterdir()))} files)")

# Experiment upload folder
exp_folder = historical_scenario.create_next_experiment_folder(save_folder)
print("Experiment dir:", exp_folder)
with open(exp_folder / "used_config.yaml", "w") as f_out:
    yaml.dump(cfg, f_out)

Data directory: /app/data/test_set/GOOG (37 files)
Experiment dir: /app/data_saved/exp_53_20250814_200156


In [8]:
# get dataset
ds = inference.get_dataset(data_path, n_messages, (num_insertions + num_coolings) * n_gen_msgs)

In [9]:
ds.shape

(3951, 22)

# ================= Experiment indexes =================

In [10]:
rng, rng_ = jax.random.split(rng)
if sample_all:
    sample_i = jnp.arange(
        len(ds) // batch_size * batch_size,
        dtype=jnp.int32
    ).reshape(-1, batch_size).tolist()
else:
    assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
    sample_i = jax.random.choice(
        rng_,
        jnp.arange(len(ds), dtype=jnp.int32),
        shape=(n_samples // batch_size, batch_size),
        replace=False
    ).tolist()
rng, rng_ = jax.random.split(rng)

print(sample_i)

filename = f"sample_indices_b{len(sample_i)}_bs{batch_size}_ins{num_insertions}_cool{num_coolings}.json"

# сохраняем список в файл
with open(filename, "w") as f:
    json.dump(sample_i, f)

[[2271, 2227, 3720, 1167, 1976, 3036, 197, 1384, 797, 2359, 1679, 2347, 3916, 479, 3750, 870], [790, 3390, 1800, 1048, 2146, 2113, 1922, 2580, 3423, 2166, 3568, 2203, 855, 3881, 2916, 3205], [795, 520, 1612, 3903, 2429, 2627, 3180, 3408, 611, 95, 2139, 3597, 2481, 1168, 1414, 2949], [1690, 3150, 1099, 10, 2738, 2286, 1733, 2274, 3695, 1291, 3232, 2138, 2527, 3754, 616, 2276], [2781, 1014, 689, 465, 3261, 2151, 619, 1264, 967, 1683, 3219, 354, 3562, 2984, 464, 1311], [949, 771, 1365, 428, 1079, 2853, 1511, 1792, 2630, 2252, 1767, 1512, 1020, 1429, 2788, 2085], [1424, 2944, 2961, 3922, 2028, 1555, 1139, 96, 63, 2594, 767, 3334, 90, 741, 3353, 3493], [3054, 453, 545, 1789, 370, 3438, 2244, 2059, 1543, 106, 3764, 2931, 2374, 3121, 3806, 3352], [1879, 2420, 796, 1275, 627, 2199, 3711, 667, 724, 2588, 1889, 848, 3696, 1073, 2260, 1267], [1971, 3410, 956, 2253, 3825, 2569, 1258, 230, 1967, 3124, 1636, 3847, 3863, 2667, 1832, 262], [9, 2977, 2660, 1951, 2356, 100, 3851, 2266, 1055, 1457, 3348,

# ================= Scenario debugging =================

In [11]:
# def construct_custom_msg(last_msg, current_book, sim):
#             ORDER_ID_i = 77777777
#             EVENT_TYPE = jnp.full((batch_size,), EVENT_TYPE_i)
#             SIDE = jnp.full((batch_size,), DIRECTION_i)
#             PRICE = last_msg[:, 3]  # Use previous price
#             DISPLAY_FLAG = jnp.ones((batch_size,), dtype=jnp.int32)

#             best_bid_info, best_ask_info = jax.vmap(sim.get_best_bid_and_ask_inclQuants)(current_book)

#             if DIRECTION_i == 0:
#                 best_volume = best_ask_info[:, 1]  # ← второй столбец = объём
#             else:
#                 best_volume = best_bid_info[:, 1]

#             # Decide volume based on config flag
#             if use_relative_volume:
#                 SIZE = (order_volume_ratio * best_volume).astype(jnp.int32)
#             else:
#                 SIZE = jnp.full((batch_size,), order_volume, dtype=jnp.int32)

#             print(f'SIZE: ', SIZE)

#             # Ensure SIZE is within allowed limits
#             SIZE = jnp.clip(SIZE, 1, 999999)

#             zeros = jnp.zeros((batch_size,), dtype=jnp.int32)
#             TIME_s = last_msg[:, 8]
#             TIME_ns = last_msg[:, 9]

#             msg = jnp.stack([
#                 jnp.full((batch_size,), ORDER_ID_i),
#                 EVENT_TYPE,
#                 SIDE,
#                 PRICE,
#                 DISPLAY_FLAG,
#                 SIZE,
#                 zeros, zeros,
#                 TIME_s, TIME_ns,
#                 zeros, zeros, zeros, zeros,
#             ], axis=1)

#             return msg.astype(jnp.int32)

# def run_historical_scenario(
#         n_samples: int,
#         batch_size: int,
#         ds: LOBSTER_Dataset,
#         rng: jax.dtypes.prng_key,
#         seq_len: int,
#         n_msgs: int,
#         n_gen_msgs: int,
#         train_state: TrainState,
#         model: nn.Module,
#         batchnorm: bool,
#         encoder: Dict[str, Tuple[jax.Array, jax.Array]],
#         stock_symbol: str,
#         n_vol_series: int = 500,
#         save_folder: str = './data_saved/',
#         tick_size: int = 100,
#         sample_top_n: int = -1,
#         sample_all: bool = False,
#         num_insertions: int = 2,
#         num_coolings: int = 2,
#         midprice_step_size=100,
#         EVENT_TYPE_i = 4,
#         DIRECTION_i = 0,
#         order_volume = 75,
#         use_sample_file: bool = False,
#         sample_file_path: Optional[str] = None,
#         start_batch: int = 0,
#         end_batch: int = -1
#     ):
#     """
#     Manual, step-by-step scenario runner: for each batch, processes messages one at a time,
#     updating the orderbook and tracking midprices, mimicking track_midprices_during_messages.
#     Saves processed messages, books, and midprices for each batch.

    
#     """
    
#     rng, rng_ = jax.random.split(rng)

#     if use_sample_file:
#         assert sample_file_path is not None, "Path to sample file not provided"
        
#         with open(sample_file_path, "r") as f:
#             sample_i_full = json.load(f)

#         sample_i = sample_i_full[start_batch:end_batch if end_batch != -1 else None]

#         for i, batch in enumerate(sample_i):
#             assert len(batch) == batch_size, f"Batch {i} has incorrect size {len(batch)}, expected {batch_size}"

#     else:
#         if sample_all:
#             sample_i = jnp.arange(
#                 len(ds) // batch_size * batch_size,
#                 dtype=jnp.int32
#             ).reshape(-1, batch_size).tolist()
#         else:
#             assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
#             sample_i = jax.random.choice(
#                 rng_,
#                 jnp.arange(len(ds), dtype=jnp.int32),
#                 shape=(n_samples // batch_size, batch_size),
#                 replace=False
#             ).tolist()

#     rng, rng_ = jax.random.split(rng)

    
#     save_folder = Path(save_folder)
#     save_folder.joinpath('msgs_decoded_doubled').mkdir(exist_ok=True, parents=True)
#     # save_folder.joinpath('l2_book_states_halved').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('b_seq_gen_doubled').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('mid_price').mkdir(exist_ok=True, parents=True)
#     base_save_folder = save_folder

#     for batch_i in tqdm(sample_i):
#         print('BATCH', batch_i)
#         m_seq, _, b_seq_pv, msg_seq_raw, book_l2_init = ds[batch_i]
#         m_seq = jnp.array(m_seq)
#         b_seq_pv = jnp.array(b_seq_pv)
#         msg_seq_raw = jnp.array(msg_seq_raw)
#         book_l2_init = jnp.array(book_l2_init)

#         current_book = book_l2_init

#         #=============#
#         # # Step 1: Prepare positions where to insert messages (accounting for prior insertions)
#         insertion_points = [n_msgs + (i + 1) * n_gen_msgs + i for i in range(num_insertions)]
#         insertion_points = [p for p in insertion_points if p <= msg_seq_raw.shape[1]]
#         print(f"[BATCH {batch_i}] Inserting custom messages at: {insertion_points}")

#         # # Step 2: Generate placeholder message using same logic as insert_custom_end (just 14-dim msg, no book logic yet)
#         # def construct_custom_msg(last_msg):
#         #     ORDER_ID_i = 77777777
#         #     EVENT_TYPE = jnp.full((batch_size,), EVENT_TYPE_i)
#         #     SIDE = jnp.full((batch_size,), DIRECTION_i)
#         #     PRICE = last_msg[:, 3]  # or use fixed jnp.full((batch_size,), 123456)
#         #     DISPLAY_FLAG = jnp.ones((batch_size,), dtype=jnp.int32)
            

#         #     #
#         #     SIZE = jnp.full((batch_size,), order_volume) # - need to get an actual best ask/bid volume and choose min(999999, available_best_volume)
#         #     #

#         #     zeros = jnp.zeros((batch_size,), dtype=jnp.int32)
#         #     TIME_s = last_msg[:, 8]
#         #     TIME_ns = last_msg[:, 9]

#         #     msg = jnp.stack([
#         #         jnp.full((batch_size,), ORDER_ID_i),
#         #         EVENT_TYPE,
#         #         SIDE,
#         #         PRICE,
#         #         DISPLAY_FLAG,
#         #         SIZE,
#         #         zeros, zeros,
#         #         TIME_s, TIME_ns,
#         #         zeros, zeros, zeros, zeros,
#         #     ], axis=1)

#         #     return msg.astype(jnp.int32)

        

#         # Step 3: Loop and insert messages
#         # for i, idx in enumerate(insertion_points):
#         #     custom_msg = construct_custom_msg(msg_seq_raw[:, idx - 1])
#         #     msg_seq_raw = jnp.concatenate([
#         #         msg_seq_raw[:, :idx, :],
#         #         custom_msg[:, None, :],  # shape (B, 1, 14)
#         #         msg_seq_raw[:, idx:, :]
#         #     ], axis=1)

#         for i, idx in enumerate(insertion_points):
#             # Simulate state before insertion
#             msg_prev = msg_seq_raw[:, idx - 1:idx, :]
#             sim_init, sim_state = inference.get_sims_vmap(current_book, msg_prev)

#             # Construct custom message using sim state and current book
#             custom_msg = construct_custom_msg(msg_prev[:, 0, :], sim_state, sim_init)

#             # Insert the message into sequence
#             msg_seq_raw = jnp.concatenate([
#                 msg_seq_raw[:, :idx, :],
#                 custom_msg[:, None, :],
#                 msg_seq_raw[:, idx:, :]
#             ], axis=1)

#         # print(f"[BATCH {batch_i}] msg_seq_raw shape after insertions: {msg_seq_raw.shape}")
#         #=============#

#         batch_size, T, msg_dim = msg_seq_raw.shape
#         current_book = book_l2_init

#         books = []
#         messages = []
#         midprices = []

#         for t in range(T):
#             msg = msg_seq_raw[:, t:t+1, :]
#             sim_init, sim_states = inference.get_sims_vmap(current_book, msg)
#             mid_price = inference.batched_get_safe_mid_price(sim_init, sim_states, tick_size)
#             full_l2_state = jax.vmap(sim_init.get_L2_state, in_axes=(0, None))(sim_states, current_book.shape[1])
#             current_book = full_l2_state[:, : current_book.shape[1]]
#             books.append(current_book)
#             messages.append(msg)
#             midprices.append(mid_price)

#         books = jnp.stack(books, axis=1)             # (batch, T, book_dim)
#         messages = jnp.concatenate(messages, axis=1) # (batch, T, msg_dim)
#         midprices = jnp.stack(midprices, axis=0)     # (T, batch)

#         # print(f"[BATCH {batch_i}] Finished all {T} steps")
#         # print(f"[BATCH {batch_i}] Final messages shape: {messages.shape}")
#         # print(f"[BATCH {batch_i}] Final books shape: {books.shape}")
#         # print(f"[BATCH {batch_i}] Final midprices shape: {midprices.shape}")

#         np.save(os.path.join(base_save_folder, 'msgs_decoded_doubled', f'msgs_decoded_doubled_batch_{batch_i}_iter_0.npy'), np.array(jax.device_get(messages)))
#         # np.save(os.path.join(base_save_folder, 'l2_book_states_halved', f'l2_book_states_halved_batch_{batch_i}_iter_0.npy'), np.array(jax.device_get(books)))
#         np.save(os.path.join(base_save_folder, 'mid_price', f'mid_price_batch_{batch_i}_iter_0.npy'), np.array(jax.device_get(midprices)))

#         # ========================
#         transform_L2_state_batch = jax.jit(jax.vmap(preproc.transform_L2_state, in_axes=(0, None, None)), static_argnums=(1, 2))

#         # Get midprices for each step: (T, B) → (B, T)
#         midprices_batched = midprices.T  # (B, T)
#         p_mid = midprices_batched[:, :, None]  # (B, T, 1)

#         # Add midprice as first column to each book state
#         books_with_mid = jnp.concatenate([p_mid, books], axis=-1)  # (B, T, 41)

#         # Transform each book+midprice into model input format
#         books_transformed = transform_L2_state_batch(books_with_mid, n_vol_series, tick_size)  # (B, T, D)

#         # Save transformed books
#         np.save(os.path.join(base_save_folder, 'b_seq_gen_doubled', f'b_seq_gen_doubled_batch_{batch_i}_iter_0.npy'), np.array(jax.device_get(books_transformed)))

In [12]:
def construct_custom_msg(
    last_msg: jnp.ndarray,
    current_book: jnp.ndarray,
    sim_init,
    sim_state,
    DIRECTION_i: int,
    EVENT_TYPE_i: int,
    order_volume: int,
    use_relative_volume: bool,
    order_volume_ratio: float
) -> jnp.ndarray:
    """
    Construct a custom message (e.g., a large market order) based on the last message and current order book.
    """

    batch_size = last_msg.shape[0]
    ORDER_ID_i = 77777777

    EVENT_TYPE = jnp.full((batch_size,), EVENT_TYPE_i, dtype=jnp.int32)
    SIDE = jnp.full((batch_size,), DIRECTION_i, dtype=jnp.int32)
    PRICE = last_msg[:, 3]
    DISPLAY_FLAG = jnp.ones((batch_size,), dtype=jnp.int32)

    # Correctly call per-state bid/ask info
    best_bid_info, best_ask_info = jax.vmap(
        lambda state: sim_init.get_best_bid_and_ask_inclQuants(state)
    )(sim_state)

    # Get price and volume from current book state
    PRICE = jnp.where(
        DIRECTION_i == 0,
        best_ask_info[:, 0],  # best ask price
        best_bid_info[:, 0]   # best bid price
    )

    best_volume = jnp.where(
        DIRECTION_i == 0,
        best_ask_info[:, 1],  # volume at best ask
        best_bid_info[:, 1]   # volume at best bid
    )

    if use_relative_volume:
        SIZE = (order_volume_ratio * best_volume).astype(jnp.int32)
    else:
        SIZE = jnp.full((batch_size,), order_volume, dtype=jnp.int32)

    SIZE = jnp.clip(SIZE, 1, 999999)

    zeros = jnp.zeros((batch_size,), dtype=jnp.int32)
    TIME_s = last_msg[:, 8]
    TIME_ns = last_msg[:, 9]

    custom_msg = jnp.stack([
        jnp.full((batch_size,), ORDER_ID_i, dtype=jnp.int32),
        EVENT_TYPE,
        SIDE,
        PRICE,
        DISPLAY_FLAG,
        SIZE,
        zeros, zeros,
        TIME_s,
        TIME_ns,
        zeros, zeros, zeros, zeros
    ], axis=1)

    return custom_msg.astype(jnp.int32)

In [13]:
# def run_historical_scenario(
#         n_samples: int,
#         batch_size: int,
#         ds: LOBSTER_Dataset,
#         rng: jax.dtypes.prng_key,
#         seq_len: int,
#         n_msgs: int,
#         n_gen_msgs: int,
#         train_state: TrainState,
#         model: nn.Module,
#         batchnorm: bool,
#         encoder: Dict[str, Tuple[jax.Array, jax.Array]],
#         stock_symbol: str,
#         n_vol_series: int = 500,
#         save_folder: str = './data_saved/',
#         tick_size: int = 100,
#         sample_top_n: int = -1,
#         sample_all: bool = False,
#         num_insertions: int = 2,
#         num_coolings: int = 2,
#         midprice_step_size=100,
#         EVENT_TYPE_i = 4,
#         DIRECTION_i = 0,
#         order_volume = 75,
#         use_sample_file: bool = False,
#         sample_file_path: Optional[str] = None,
#         start_batch: int = 0,
#         end_batch: int = -1
#     ):


#     rng, rng_ = jax.random.split(rng)

#     if use_sample_file:
#         assert sample_file_path is not None, "Path to sample file not provided"
        
#         with open(sample_file_path, "r") as f:
#             sample_i_full = json.load(f)

#         sample_i = sample_i_full[start_batch:end_batch if end_batch != -1 else None]

#         for i, batch in enumerate(sample_i):
#             assert len(batch) == batch_size, f"Batch {i} has incorrect size {len(batch)}, expected {batch_size}"

#     else:
#         if sample_all:
#             sample_i = jnp.arange(
#                 len(ds) // batch_size * batch_size,
#                 dtype=jnp.int32
#             ).reshape(-1, batch_size).tolist()
#         else:
#             assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
#             sample_i = jax.random.choice(
#                 rng_,
#                 jnp.arange(len(ds), dtype=jnp.int32),
#                 shape=(n_samples // batch_size, batch_size),
#                 replace=False
#             ).tolist()

#     rng, rng_ = jax.random.split(rng)
    
#     save_folder = Path(save_folder)
#     save_folder.joinpath('msgs_decoded_doubled').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('b_seq_gen_doubled').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('mid_price').mkdir(exist_ok=True, parents=True)
    
#     for batch_i in tqdm(sample_i):
#         print('BATCH', batch_i)
#         m_seq, _, b_seq_pv, msg_seq_raw, book_l2_init = ds[batch_i]
#         msg_seq_raw = jnp.array(msg_seq_raw)
#         book_l2_init = jnp.array(book_l2_init)
        
#         current_book = book_l2_init
        
        
#         insertion_points = [n_msgs + (i + 1) * n_gen_msgs + i for i in range(num_insertions)]
#         insertion_points = sorted([p for p in insertion_points if p <= msg_seq_raw.shape[1]])  
#         print(f"[BATCH {batch_i}] Planned insertion points at indices: {insertion_points}")
        
        
#         books = []
#         messages = []
#         midprices = []
        
        
#         original_idx = 0
#         steps_count = 0
#         insertion_iter = 0

        
        
        
#         while original_idx < msg_seq_raw.shape[1] or insertion_iter < len(insertion_points):
#             if insertion_iter < len(insertion_points) and steps_count == insertion_points[insertion_iter]:
#                 if steps_count == 0:
#                     last_msg = msg_seq_raw[:, 0, :]
#                 else:
#                     last_msg = messages[-1]
#                 sim_init, sim_state = inference.get_sims_vmap(current_book, last_msg[:, None, :])
#                 custom_msg = construct_custom_msg(
#                     last_msg,
#                     current_book,
#                     sim_init,
#                     sim_state,
#                     DIRECTION_i=DIRECTION_i,
#                     EVENT_TYPE_i=EVENT_TYPE_i,
#                     order_volume=order_volume,
#                     use_relative_volume=use_relative_volume,
#                     order_volume_ratio=order_volume_ratio
#                 )
                
                
#                 sim_init_ins, sim_state_ins = inference.get_sims_vmap(current_book, custom_msg[:, None, :])
                
#                 mid_price_ins = inference.batched_get_safe_mid_price(sim_init_ins, sim_state_ins, tick_size)
                
#                 full_l2_state_ins = jax.vmap(sim_init_ins.get_L2_state, in_axes=(0, None))(
#                     sim_state_ins, current_book.shape[1]
#                 )
#                 current_book = full_l2_state_ins[:, : current_book.shape[1]]  # update current book to new state
                
                
#                 messages.append(custom_msg)
#                 books.append(current_book)
#                 midprices.append(mid_price_ins)
                
#                 steps_count += 1
#                 insertion_iter += 1
#                 continue
            
#             if original_idx < msg_seq_raw.shape[1]:
#                 msg = msg_seq_raw[:, original_idx: original_idx + 1, :]
#                 sim_init_hist, sim_state_hist = inference.get_sims_vmap(current_book, msg)
#                 mid_price_hist = inference.batched_get_safe_mid_price(sim_init_hist, sim_state_hist, tick_size)
#                 full_l2_state_hist = jax.vmap(sim_init_hist.get_L2_state, in_axes=(0, None))(
#                     sim_state_hist, current_book.shape[1]
#                 )
#                 current_book = full_l2_state_hist[:, : current_book.shape[1]]
                
#                 messages.append(msg[:, 0, :])
#                 books.append(current_book)
#                 midprices.append(mid_price_hist)
                
#                 original_idx += 1
#                 steps_count += 1
#             else:
#                 print("No more historical messages. Awaiting remaining insertions...")
#                 break
        
#         messages_arr = jnp.concatenate([m.reshape(m.shape[0], 1, m.shape[1]) for m in messages], axis=1)  # (batch, T_total, 14)
#         books_arr = jnp.stack(books, axis=1)  # (batch, T_total, book_dim)
#         midprices_arr = jnp.stack(midprices, axis=0)  # (T_total, batch)
        
#         np.save(save_folder / 'msgs_decoded_doubled' / f'msgs_decoded_doubled_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(messages_arr))
#         np.save(save_folder / 'mid_price' / f'mid_price_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(midprices_arr))
        
#         transform_L2_state_batch = jax.jit(jax.vmap(preproc.transform_L2_state, in_axes=(0, None, None)), static_argnums=(1, 2))
        
#         midprices_batched = midprices_arr.T[:, :, None]
#         books_with_mid = jnp.concatenate([midprices_batched, books_arr], axis=-1)
#         books_transformed = transform_L2_state_batch(books_with_mid, n_vol_series, tick_size)
        
#         np.save(save_folder / 'b_seq_gen_doubled' / f'b_seq_gen_doubled_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(books_transformed))

In [14]:
# results = run_historical_scenario(
#         n_samples,
#         batch_size,
#         ds,
#         rng,
#         n_messages * Message_Tokenizer.MSG_LEN,
#         n_messages,
#         n_gen_msgs,
#         state,
#         model,
#         args_ckpt.batchnorm,
#         Vocab().ENCODING,
#         stock,
#         n_vol_series,
#         exp_folder,
#         tick_size,
#         sample_top_n,
#         sample_all,
#         num_insertions,
#         num_coolings,
#         midprice_step_size,
#         EVENT_TYPE_i,
#         DIRECTION_i,
#         order_volume,
#         use_sample_file,
#         sample_file_path,
#         start_batch,
#         end_batch
#     )

# ================= Shifting =================

In [15]:
import jax.numpy as jnp

def shift_sell_prices(msg, tick_size, PRICE_ABS_i, DIRECTION_i, EVENT_TYPE_i):
    """
    Сдвигает все цены SELL (ask) сообщений на +tick_size.
    """
    # копируем чтобы не портить оригинал
    msg_shifted = msg.copy()

    # mask: Execution или Limit ордера, направление = Sell (0)
    is_relevant = (msg[..., EVENT_TYPE_i] == 1) | (msg[..., EVENT_TYPE_i] == 4)  # Limit or Execution
    is_sell = (msg[..., DIRECTION_i] == 0)

    mask = is_relevant & is_sell

    msg_shifted = msg_shifted.at[..., PRICE_ABS_i].set(
        jnp.where(mask, msg[..., PRICE_ABS_i] + tick_size, msg[..., PRICE_ABS_i])
    )
    return msg_shifted


def shift_buy_prices(msg, tick_size, PRICE_ABS_i, DIRECTION_i, EVENT_TYPE_i):
    """
    Сдвигает все цены BUY (bid) сообщений на +tick_size.
    """
    msg_shifted = msg.copy()

    # mask: Execution или Limit ордера, направление = Buy (1)
    is_relevant = (msg[..., EVENT_TYPE_i] == 1) | (msg[..., EVENT_TYPE_i] == 4)  # Limit or Execution
    is_buy = (msg[..., DIRECTION_i] == 1)

    mask = is_relevant & is_buy

    msg_shifted = msg_shifted.at[..., PRICE_ABS_i].set(
        jnp.where(mask, msg[..., PRICE_ABS_i] + tick_size, msg[..., PRICE_ABS_i])
    )
    return msg_shifted

In [None]:
# just plain historical scenario

# def run_historical_scenario( 
#         n_samples: int,
#         batch_size: int,
#         ds: LOBSTER_Dataset,
#         rng: jax.dtypes.prng_key,
#         seq_len: int,
#         n_msgs: int,
#         n_gen_msgs: int,
#         train_state: TrainState,
#         model: nn.Module,
#         batchnorm: bool,
#         encoder: Dict[str, Tuple[jax.Array, jax.Array]],
#         stock_symbol: str,
#         n_vol_series: int = 500,
#         save_folder: str = './data_saved/',
#         tick_size: int = 100,
#         sample_top_n: int = -1,
#         sample_all: bool = False,
#         num_insertions: int = 2,
#         num_coolings: int = 2,
#         midprice_step_size=100,
#         EVENT_TYPE_i = 4,
#         DIRECTION_i = 0,
#         order_volume = 75,
#         use_sample_file: bool = False,
#         sample_file_path: Optional[str] = None,
#         start_batch: int = 0,
#         end_batch: int = -1
#     ):


#     rng, rng_ = jax.random.split(rng)

#     if use_sample_file:
#         assert sample_file_path is not None, "Path to sample file not provided"
        
#         with open(sample_file_path, "r") as f:
#             sample_i_full = json.load(f)

#         sample_i = sample_i_full[start_batch:end_batch if end_batch != -1 else None]

#         for i, batch in enumerate(sample_i):
#             assert len(batch) == batch_size, f"Batch {i} has incorrect size {len(batch)}, expected {batch_size}"

#     else:
#         if sample_all:
#             sample_i = jnp.arange(
#                 len(ds) // batch_size * batch_size,
#                 dtype=jnp.int32
#             ).reshape(-1, batch_size).tolist()
#         else:
#             assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
#             sample_i = jax.random.choice(
#                 rng_,
#                 jnp.arange(len(ds), dtype=jnp.int32),
#                 shape=(n_samples // batch_size, batch_size),
#                 replace=False
#             ).tolist()

#     rng, rng_ = jax.random.split(rng)
    
#     save_folder = Path(save_folder)
#     save_folder.joinpath('msgs_decoded_doubled').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('b_seq_gen_doubled').mkdir(exist_ok=True, parents=True)
#     save_folder.joinpath('mid_price').mkdir(exist_ok=True, parents=True)
    
#     for batch_i in tqdm(sample_i):
#         print('BATCH', batch_i)
#         m_seq, _, b_seq_pv, msg_seq_raw, book_l2_init = ds[batch_i]
#         msg_seq_raw = jnp.array(msg_seq_raw)
#         book_l2_init = jnp.array(book_l2_init)
        
#         current_book = book_l2_init
        
        
#         insertion_points = [n_msgs + (i + 1) * n_gen_msgs + i for i in range(num_insertions)]
#         insertion_points = sorted([p for p in insertion_points if p <= msg_seq_raw.shape[1]])  
#         print(f"[BATCH {batch_i}] Planned insertion points at indices: {insertion_points}")
        
        
#         books = []
#         messages = []
#         midprices = []
        
        
#         original_idx = 0
#         steps_count = 0
#         insertion_iter = 0

        
        
#         shift_prices_flag = False
#         while original_idx < msg_seq_raw.shape[1] or insertion_iter < len(insertion_points):
#             if insertion_iter < len(insertion_points) and steps_count == insertion_points[insertion_iter]:
#                 if steps_count == 0:
#                     last_msg = msg_seq_raw[:, 0, :]
#                 else:
#                     last_msg = messages[-1]
#                 sim_init, sim_state = inference.get_sims_vmap(current_book, last_msg[:, None, :])
#                 custom_msg = construct_custom_msg(
#                     last_msg,
#                     current_book,
#                     sim_init,
#                     sim_state,
#                     DIRECTION_i=DIRECTION_i,
#                     EVENT_TYPE_i=EVENT_TYPE_i,
#                     order_volume=order_volume,
#                     use_relative_volume=use_relative_volume,
#                     order_volume_ratio=order_volume_ratio
#                 )
                
                
#                 sim_init_ins, sim_state_ins = inference.get_sims_vmap(current_book, custom_msg[:, None, :])

#                 best_ask_vol = current_book[:, current_book.shape[1] // 2]
#                 if jnp.all(best_ask_vol == 0):
#                     shift_prices_flag = True
                
#                 mid_price_ins = inference.batched_get_safe_mid_price(sim_init_ins, sim_state_ins, tick_size)
                
#                 full_l2_state_ins = jax.vmap(sim_init_ins.get_L2_state, in_axes=(0, None))(
#                     sim_state_ins, current_book.shape[1]
#                 )
#                 current_book = full_l2_state_ins[:, : current_book.shape[1]]  # update current book to new state
                
                
#                 messages.append(custom_msg)
#                 books.append(current_book)
#                 midprices.append(mid_price_ins)
                
#                 steps_count += 1
#                 insertion_iter += 1
#                 continue
            
#             if original_idx < msg_seq_raw.shape[1]:
#                 msg = msg_seq_raw[:, original_idx: original_idx + 1, :]

#                 if shift_prices_flag:
#                     msg = shift_sell_prices(
#                         msg, tick_size,
#                         PRICE_ABS_i=3,  # replace with actual index for PRICE_ABS
#                         DIRECTION_i=DIRECTION_i,
#                         EVENT_TYPE_i=EVENT_TYPE_i
#                     )


#                 sim_init_hist, sim_state_hist = inference.get_sims_vmap(current_book, msg)
#                 mid_price_hist = inference.batched_get_safe_mid_price(sim_init_hist, sim_state_hist, tick_size)
#                 full_l2_state_hist = jax.vmap(sim_init_hist.get_L2_state, in_axes=(0, None))(
#                     sim_state_hist, current_book.shape[1]
#                 )
#                 current_book = full_l2_state_hist[:, : current_book.shape[1]]
                
#                 messages.append(msg[:, 0, :])
#                 books.append(current_book)
#                 midprices.append(mid_price_hist)
                
#                 original_idx += 1
#                 steps_count += 1
#             else:
#                 print("No more historical messages. Awaiting remaining insertions...")
#                 break
        
#         messages_arr = jnp.concatenate([m.reshape(m.shape[0], 1, m.shape[1]) for m in messages], axis=1)  # (batch, T_total, 14)
#         books_arr = jnp.stack(books, axis=1)  # (batch, T_total, book_dim)
#         midprices_arr = jnp.stack(midprices, axis=0)  # (T_total, batch)
        
#         np.save(save_folder / 'msgs_decoded_doubled' / f'msgs_decoded_doubled_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(messages_arr))
#         np.save(save_folder / 'mid_price' / f'mid_price_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(midprices_arr))
        
#         transform_L2_state_batch = jax.jit(jax.vmap(preproc.transform_L2_state, in_axes=(0, None, None)), static_argnums=(1, 2))
        
#         midprices_batched = midprices_arr.T[:, :, None]
#         books_with_mid = jnp.concatenate([midprices_batched, books_arr], axis=-1)
#         books_transformed = transform_L2_state_batch(books_with_mid, n_vol_series, tick_size)
        
#         np.save(save_folder / 'b_seq_gen_doubled' / f'b_seq_gen_doubled_batch_{batch_i}_iter_0.npy',
#                 jax.device_get(books_transformed))

In [None]:
# consuming 2 levels - actually heuristic, but works well


def run_historical_scenario(
        n_samples: int,
        batch_size: int,
        ds: LOBSTER_Dataset,
        rng: jax.dtypes.prng_key,
        seq_len: int,
        n_msgs: int,
        n_gen_msgs: int,
        train_state: TrainState,
        model: nn.Module,
        batchnorm: bool,
        encoder: Dict[str, Tuple[jax.Array, jax.Array]],
        stock_symbol: str,
        n_vol_series: int = 500,
        save_folder: str = './data_saved/',
        tick_size: int = 100,
        sample_top_n: int = -1,
        sample_all: bool = False,
        num_insertions: int = 2,
        num_coolings: int = 2,
        midprice_step_size=100,
        EVENT_TYPE_i = 4,
        DIRECTION_i = 0,
        order_volume = 75,
        use_sample_file: bool = False,
        sample_file_path: Optional[str] = None,
        start_batch: int = 0,
        end_batch: int = -1
    ):


    rng, rng_ = jax.random.split(rng)

    if use_sample_file:
        assert sample_file_path is not None, "Path to sample file not provided"
        
        with open(sample_file_path, "r") as f:
            sample_i_full = json.load(f)

        sample_i = sample_i_full[start_batch:end_batch if end_batch != -1 else None]

        for i, batch in enumerate(sample_i):
            assert len(batch) == batch_size, f"Batch {i} has incorrect size {len(batch)}, expected {batch_size}"

    else:
        if sample_all:
            sample_i = jnp.arange(
                len(ds) // batch_size * batch_size,
                dtype=jnp.int32
            ).reshape(-1, batch_size).tolist()
        else:
            assert n_samples % batch_size == 0, 'n_samples must be divisible by batch_size'
            sample_i = jax.random.choice(
                rng_,
                jnp.arange(len(ds), dtype=jnp.int32),
                shape=(n_samples // batch_size, batch_size),
                replace=False
            ).tolist()

    rng, rng_ = jax.random.split(rng)
    
    save_folder = Path(save_folder)
    save_folder.joinpath('msgs_decoded_doubled').mkdir(exist_ok=True, parents=True)
    save_folder.joinpath('b_seq_gen_doubled').mkdir(exist_ok=True, parents=True)
    save_folder.joinpath('mid_price').mkdir(exist_ok=True, parents=True)
    
    for batch_i in tqdm(sample_i):
        print('BATCH', batch_i)
        m_seq, _, b_seq_pv, msg_seq_raw, book_l2_init = ds[batch_i]
        msg_seq_raw = jnp.array(msg_seq_raw)
        book_l2_init = jnp.array(book_l2_init)
        
        current_book = book_l2_init
        
        
        insertion_points = [n_msgs + (i + 1) * n_gen_msgs + i for i in range(num_insertions)]
        insertion_points = sorted([p for p in insertion_points if p <= msg_seq_raw.shape[1]])  
        print(f"[BATCH {batch_i}] Planned insertion points at indices: {insertion_points}")
        
        
        books = []
        messages = []
        midprices = []
        
        
        original_idx = 0
        steps_count = 0
        insertion_iter = 0

        
        
        shift_ticks = 0
        while original_idx < msg_seq_raw.shape[1] or insertion_iter < len(insertion_points):
            if insertion_iter < len(insertion_points) and steps_count == insertion_points[insertion_iter]:
                if steps_count == 0:
                    last_msg = msg_seq_raw[:, 0, :]
                else:
                    last_msg = messages[-1]

                # Агрессивный ордер: съедаем весь best ask
                sim_init, sim_state = inference.get_sims_vmap(current_book, last_msg[:, None, :])
                custom_msg = construct_custom_msg(
                    last_msg,
                    current_book,
                    sim_init,
                    sim_state,
                    DIRECTION_i=DIRECTION_i,
                    EVENT_TYPE_i=EVENT_TYPE_i,
                    order_volume=order_volume,
                    use_relative_volume=use_relative_volume,
                    order_volume_ratio=order_volume_ratio
                )
                sim_init_ins, sim_state_ins = inference.get_sims_vmap(current_book, custom_msg[:, None, :])
                mid_price_ins = inference.batched_get_safe_mid_price(sim_init_ins, sim_state_ins, tick_size)
                full_l2_state_ins = jax.vmap(sim_init_ins.get_L2_state, in_axes=(0, None))(
                    sim_state_ins, current_book.shape[1]
                )
                current_book = full_l2_state_ins[:, : current_book.shape[1]]

                messages.append(custom_msg)
                books.append(current_book)
                midprices.append(mid_price_ins)

                shift_ticks += 1

                steps_count += 1
                insertion_iter += 1
                continue
            
            if original_idx < msg_seq_raw.shape[1]:
                msg = msg_seq_raw[:, original_idx: original_idx + 1, :]

                
                # сдвигаем цены sell
                msg = shift_sell_prices(
                    msg, tick_size*shift_ticks,
                    PRICE_ABS_i=3,
                    DIRECTION_i=2,
                    EVENT_TYPE_i=1
                )
                # сдвигаем цены buy
                msg = shift_buy_prices(
                    msg, tick_size*shift_ticks,
                    PRICE_ABS_i=3,
                    DIRECTION_i=2,
                    EVENT_TYPE_i=1
                )


                sim_init_hist, sim_state_hist = inference.get_sims_vmap(current_book, msg)
                mid_price_hist = inference.batched_get_safe_mid_price(sim_init_hist, sim_state_hist, tick_size)
                full_l2_state_hist = jax.vmap(sim_init_hist.get_L2_state, in_axes=(0, None))(
                    sim_state_hist, current_book.shape[1]
                )
                current_book = full_l2_state_hist[:, : current_book.shape[1]]

                messages.append(msg[:, 0, :])
                books.append(current_book)
                midprices.append(mid_price_hist)

                original_idx += 1
                steps_count += 1
            else:
                print("No more historical messages. Awaiting remaining insertions...")
                break
        
        messages_arr = jnp.concatenate([m.reshape(m.shape[0], 1, m.shape[1]) for m in messages], axis=1)  # (batch, T_total, 14)
        books_arr = jnp.stack(books, axis=1)  # (batch, T_total, book_dim)
        midprices_arr = jnp.stack(midprices, axis=0)  # (T_total, batch)
        
        np.save(save_folder / 'msgs_decoded_doubled' / f'msgs_decoded_doubled_batch_{batch_i}_iter_0.npy',
                jax.device_get(messages_arr))
        np.save(save_folder / 'mid_price' / f'mid_price_batch_{batch_i}_iter_0.npy',
                jax.device_get(midprices_arr))
        
        transform_L2_state_batch = jax.jit(jax.vmap(preproc.transform_L2_state, in_axes=(0, None, None)), static_argnums=(1, 2))
        
        midprices_batched = midprices_arr.T[:, :, None]
        books_with_mid = jnp.concatenate([midprices_batched, books_arr], axis=-1)
        books_transformed = transform_L2_state_batch(books_with_mid, n_vol_series, tick_size)
        
        np.save(save_folder / 'b_seq_gen_doubled' / f'b_seq_gen_doubled_batch_{batch_i}_iter_0.npy',
                jax.device_get(books_transformed))

In [18]:
results = run_historical_scenario(
        n_samples,
        batch_size,
        ds,
        rng,
        n_messages * Message_Tokenizer.MSG_LEN,
        n_messages,
        n_gen_msgs,
        state,
        model,
        args_ckpt.batchnorm,
        Vocab().ENCODING,
        stock,
        n_vol_series,
        exp_folder,
        tick_size,
        sample_top_n,
        sample_all,
        num_insertions,
        num_coolings,
        midprice_step_size,
        EVENT_TYPE_i,
        DIRECTION_i,
        order_volume,
        use_sample_file,
        sample_file_path,
        start_batch,
        end_batch
    )

  0%|          | 0/64 [00:00<?, ?it/s]

BATCH [2271, 2227, 3720, 1167, 1976, 3036, 197, 1384, 797, 2359, 1679, 2347, 3916, 479, 3750, 870]
[BATCH [2271, 2227, 3720, 1167, 1976, 3036, 197, 1384, 797, 2359, 1679, 2347, 3916, 479, 3750, 870]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  2%|▏         | 1/64 [02:26<2:33:34, 146.26s/it]

BATCH [790, 3390, 1800, 1048, 2146, 2113, 1922, 2580, 3423, 2166, 3568, 2203, 855, 3881, 2916, 3205]
[BATCH [790, 3390, 1800, 1048, 2146, 2113, 1922, 2580, 3423, 2166, 3568, 2203, 855, 3881, 2916, 3205]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  3%|▎         | 2/64 [04:44<2:26:25, 141.70s/it]

BATCH [795, 520, 1612, 3903, 2429, 2627, 3180, 3408, 611, 95, 2139, 3597, 2481, 1168, 1414, 2949]
[BATCH [795, 520, 1612, 3903, 2429, 2627, 3180, 3408, 611, 95, 2139, 3597, 2481, 1168, 1414, 2949]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  5%|▍         | 3/64 [07:03<2:22:35, 140.26s/it]

BATCH [1690, 3150, 1099, 10, 2738, 2286, 1733, 2274, 3695, 1291, 3232, 2138, 2527, 3754, 616, 2276]
[BATCH [1690, 3150, 1099, 10, 2738, 2286, 1733, 2274, 3695, 1291, 3232, 2138, 2527, 3754, 616, 2276]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  6%|▋         | 4/64 [09:21<2:19:21, 139.37s/it]

BATCH [2781, 1014, 689, 465, 3261, 2151, 619, 1264, 967, 1683, 3219, 354, 3562, 2984, 464, 1311]
[BATCH [2781, 1014, 689, 465, 3261, 2151, 619, 1264, 967, 1683, 3219, 354, 3562, 2984, 464, 1311]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  8%|▊         | 5/64 [11:40<2:16:53, 139.22s/it]

BATCH [949, 771, 1365, 428, 1079, 2853, 1511, 1792, 2630, 2252, 1767, 1512, 1020, 1429, 2788, 2085]
[BATCH [949, 771, 1365, 428, 1079, 2853, 1511, 1792, 2630, 2252, 1767, 1512, 1020, 1429, 2788, 2085]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


  9%|▉         | 6/64 [14:00<2:14:51, 139.50s/it]

BATCH [1424, 2944, 2961, 3922, 2028, 1555, 1139, 96, 63, 2594, 767, 3334, 90, 741, 3353, 3493]
[BATCH [1424, 2944, 2961, 3922, 2028, 1555, 1139, 96, 63, 2594, 767, 3334, 90, 741, 3353, 3493]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


 11%|█         | 7/64 [16:18<2:12:06, 139.06s/it]

BATCH [3054, 453, 545, 1789, 370, 3438, 2244, 2059, 1543, 106, 3764, 2931, 2374, 3121, 3806, 3352]
[BATCH [3054, 453, 545, 1789, 370, 3438, 2244, 2059, 1543, 106, 3764, 2931, 2374, 3121, 3806, 3352]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


 12%|█▎        | 8/64 [18:38<2:09:56, 139.23s/it]

BATCH [1879, 2420, 796, 1275, 627, 2199, 3711, 667, 724, 2588, 1889, 848, 3696, 1073, 2260, 1267]
[BATCH [1879, 2420, 796, 1275, 627, 2199, 3711, 667, 724, 2588, 1889, 848, 3696, 1073, 2260, 1267]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


 14%|█▍        | 9/64 [20:56<2:07:32, 139.13s/it]

BATCH [1971, 3410, 956, 2253, 3825, 2569, 1258, 230, 1967, 3124, 1636, 3847, 3863, 2667, 1832, 262]
[BATCH [1971, 3410, 956, 2253, 3825, 2569, 1258, 230, 1967, 3124, 1636, 3847, 3863, 2667, 1832, 262]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


 16%|█▌        | 10/64 [23:16<2:05:14, 139.16s/it]

BATCH [9, 2977, 2660, 1951, 2356, 100, 3851, 2266, 1055, 1457, 3348, 1732, 404, 432, 3699, 2970]
[BATCH [9, 2977, 2660, 1951, 2356, 100, 3851, 2266, 1055, 1457, 3348, 1732, 404, 432, 3699, 2970]] Planned insertion points at indices: [550, 601, 652, 703, 754, 805, 856, 907, 958, 1009, 1060, 1111, 1162, 1213, 1264, 1315, 1366, 1417, 1468, 1519, 1570, 1621, 1672, 1723, 1774, 1825, 1876, 1927, 1978, 2029, 2080, 2131, 2182, 2233, 2284, 2335, 2386, 2437, 2488, 2539, 2590, 2641, 2692, 2743, 2794, 2845, 2896, 2947, 2998, 3049, 3100, 3151, 3202, 3253, 3304, 3355, 3406, 3457, 3508, 3559, 3610, 3661, 3712, 3763, 3814, 3865, 3916, 3967, 4018, 4069, 4120, 4171, 4222, 4273, 4324, 4375, 4426, 4477, 4528, 4579]


 16%|█▌        | 10/64 [25:00<2:15:01, 150.02s/it]
