In [1]:
# Work from /app: the relative paths used in later cells
# ('data_saved/', './checkpoints/...', 'data/test_set/') are resolved against it.
%cd /app

/app


In [2]:
import argparse
import os
import sys

# Turn off XLA client memory preallocation; set here, before `import jax` below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
# force=True keeps this cell re-runnable: without it, executing the cell a
# second time in the same kernel raises
# RuntimeError("context has already been set").
torch.multiprocessing.set_start_method('spawn', force=True)

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as onp

from pathlib import Path
import os

import pandas as pd

2025-04-29 21:32:24.615661: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.6 which is older than the ptxas CUDA version (12.8.93). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [4]:
import os
from pathlib import Path
from datetime import datetime

def create_next_experiment_folder(save_folder: str) -> Path:
    """
    Create and return the next numbered experiment folder inside ``save_folder``.

    Scans ``save_folder`` for sub-directories named ``exp_<number>`` or
    ``exp_<number>_<anything>``, takes the largest ``<number>`` found
    (0 when there is none) and creates ``exp_<max+1>_YYYYMMDD_HHMMSS``,
    where the timestamp is the current local time.

    Args:
        save_folder: Path of an existing directory that holds the experiment
            folders. It is deliberately not created automatically.

    Returns:
        Path of the newly created experiment folder.

    Raises:
        FileNotFoundError: If ``save_folder`` does not exist.
        FileExistsError: If the target folder already exists
            (raised by ``Path.mkdir``).
    """
    base = Path(save_folder)
    # The base folder is not created on purpose: the caller is expected to
    # point at an existing location.
    if not base.exists():
        raise FileNotFoundError(f"Каталога {save_folder!r} не существует")

    indices = []
    for entry in base.iterdir():
        if not (entry.is_dir() and entry.name.startswith("exp_")):
            continue
        # The "exp_" prefix guarantees that split("_") yields at least two parts.
        number = entry.name.split("_")[1]
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() cannot parse, which would crash the scan.
        if number.isdecimal():
            indices.append(int(number))

    next_idx = max(indices, default=0) + 1
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    new_folder = base / f"exp_{next_idx}_{timestamp}"
    new_folder.mkdir()
    return new_folder

In [5]:
# == Parameters for Config == #

# Parent folder for results; the next cell creates a new numbered
# exp_<n>_<timestamp> sub-folder inside it.
save_folder = 'data_saved/'

batch_size = 4
n_samples = 20

n_gen_msgs = 50  # how many messages to generate into the future
midprice_step_size = 1

num_insertions = 20
num_coolings = 20

EVENT_TYPE_i = 4
DIRECTION_i = 0
order_volume = 75

# ======== #

# scale down to single GPU, single sample inference
bsz = 1  # 1, 10
num_devices = 1

n_messages = 500  # length of input sequence to model
book_dim = 501  # b_enc.shape[1] 500+1=501
n_vol_series = 500
sample_top_n = -1

model_size = 'large'
data_dir = 'data/test_set/'
sample_all = False  # action='store_true'
stock = 'GOOG'  # 'GOOG', 'INTC'

tick_size = 100
rng_seed = 42


# ======== #

# Derived sizes: sequence lengths are counted in tokens, i.e. number of
# messages times Message_Tokenizer.MSG_LEN.
v = Vocab()
n_classes = len(v)
seq_len = n_messages * Message_Tokenizer.MSG_LEN
book_seq_len = n_messages

n_eval_messages = n_gen_msgs
eval_seq_len = n_eval_messages * Message_Tokenizer.MSG_LEN

# Seeded PRNG key; `rng_` is the spare half of the split
# (not used in the cells visible here).
rng = jax.random.key(rng_seed)
rng, rng_ = jax.random.split(rng)

In [6]:
# Create this run's output folder and use it as `save_folder` from here on.
# NOTE: `save_folder` is rebound from the configured parent folder to the new
# sub-folder, so re-running this cell without re-running the config cell
# would create the next experiment folder *inside* the one just created.
experiment_dir = create_next_experiment_folder(save_folder)
save_folder = experiment_dir
print("Created new experiment directory:", save_folder)

Created new experiment directory: data_saved/exp_3_20250429_213228


In [7]:
# Checkpoint directory used for each supported stock.
# Alternative checkpoints kept for reference:
#   GOOG: './checkpoints/treasured-leaf-149_84yhvzjt/'   # 0.5 y GOOG, (full model)
#         './checkpoints/stilted-deluge-759_8g3vqor4'    # small model
#   INTC: './checkpoints/pleasant-cherry-152_i6h5n74c/'  # 0.5 y INTC, (full model)
ckpt_paths_by_stock = {
    'GOOG': './checkpoints/denim-elevator-754_czg1ss71/',  # large model
    'INTC': './checkpoints/eager-sea-755_2rw1ofs3/',  # large model
}

ckpt_path = ckpt_paths_by_stock.get(stock)
if ckpt_path is None:
    raise ValueError(f'stock {stock} not recognized')

print('Loading metadata:', ckpt_path)
args_ckpt = load_metadata(ckpt_path)

Loading metadata: ./checkpoints/denim-elevator-754_czg1ss71/


In [8]:
# Build a fresh train state from the checkpoint's arguments, then load the
# saved weights from disk into it.

print('Initializing model...')
new_train_state, model_cls = init_train_state(
    args_ckpt,
    n_classes=n_classes,
    seq_len=seq_len,
    book_dim=book_dim,
    book_seq_len=book_seq_len,
)

print('\nLoading model checkpoint...')
ckpt = load_checkpoint(
    new_train_state,
    ckpt_path,
    train=False,
)
state = ckpt['model']

# Model instance configured for inference (training=False).
model = model_cls(training=False, step_rescale=1.0)

# jax.tree_leaves is deprecated; jax.tree_util.tree_leaves is available in
# every JAX version and returns the same flat list of parameter arrays.
param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
print('param count:', param_count)

Initializing model...
configuring standard optimization setup
[*] Trainable Parameters: 35776312

Loading model checkpoint...
param count: 35776312



jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) or jax.tree_util.tree_leaves (any JAX version).



In [9]:
# Read the batchnorm flag stored in the checkpoint arguments.
batchnorm = args_ckpt.batchnorm

# Scale down to single GPU, single sample inference by overriding the values
# loaded from the checkpoint metadata (bsz: 1 or 10).
args_ckpt.num_devices = num_devices
args_ckpt.bsz = bsz

# Append the stock symbol to the base test-set directory.
# NOTE: not idempotent - running this cell twice appends the symbol twice.
data_dir += stock
print(f"Directory Path: {data_dir}")

Directory Path: data/test_set/GOOG


In [10]:
# Convert the data directory to a Path, make sure it exists, and report how
# many regular files it contains (sub-directories are not counted).
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
folder_path = Path(data_dir)
file_count = sum(1 for entry in folder_path.iterdir() if entry.is_file())
print(f"There are {file_count} files in the folder {str(data_dir)}.")

There are 36 files in the folder data/test_set/GOOG.


In [11]:
# Sanity check: number of input messages and number of messages to generate.
print(n_messages, n_eval_messages)

500 50


In [12]:
from pathlib import Path
import os

# Print current working directory to help verify the relative data path
print(f"Current working directory: {os.getcwd()}")

# Derive the test-set directory from the configured `stock` instead of
# hard-coding GOOG, so the data always matches the checkpoint that was
# selected for the same stock.
data_dir = Path("data/test_set") / stock

# Best-effort diagnostics only: a failure here is printed and the dataset
# load below is still attempted.
try:
    data_dir.mkdir(parents=True, exist_ok=True)
    print(f"Successfully created or verified directory: {data_dir}")

    file_count = sum(1 for entry in data_dir.iterdir() if entry.is_file())
    print(f"There are {file_count} files in the folder {data_dir}.")
except Exception as e:
    print(f"Error: {str(e)}")

# Build the inference dataset from the files in data_dir, using n_messages
# input messages and n_eval_messages evaluation messages.
ds = inference.get_dataset(data_dir, n_messages, n_eval_messages)

Current working directory: /app
Successfully created or verified directory: data/test_set/GOOG
There are 36 files in the folder data/test_set/GOOG.


# check 1 by 1

In [13]:
# Confirm which experiment folder the generation run below will be given.
print(save_folder)

data_saved/exp_3_20250429_213228


In [14]:
# Run the generation scenario. All arguments are positional and must stay in
# exactly this order; every value comes from the config / setup cells above.
# NOTE(review): the recorded run of this cell stopped with a CUDA
# RESOURCE_EXHAUSTED error after the first of 5 progress-bar steps (about 3 h);
# the error text itself suggests lowering XLA_PYTHON_CLIENT_MEM_FRACTION or
# disabling CUDA graphs via XLA_FLAGS=--xla_gpu_enable_command_buffer=.
inference.run_generation_scenario(
    # sample counts, dataset and PRNG key
    n_samples, batch_size, ds, rng,
    # sequence sizes
    seq_len, n_messages, n_gen_msgs,
    # loaded model and its settings
    state, model, batchnorm,
    # encoding, stock and output settings
    v.ENCODING, stock, n_vol_series, save_folder, tick_size,
    # sampling options
    sample_top_n, sample_all,
    # scenario parameters
    num_insertions, num_coolings, midprice_step_size,
    # inserted order definition
    EVENT_TYPE_i, DIRECTION_i, order_volume,
)

  0%|          | 0/5 [00:00<?, ?it/s]

BATCH [7694, 15801, 3956, 30284]

ITERATION  1
Im using m_seq_raw_inp[:, :1, :]
midprice after processing -499 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :2, :]
midprice after processing -498 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :3, :]
midprice after processing -497 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :4, :]
midprice after processing -496 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :5, :]
midprice after processing -495 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :6, :]
midprice after processing -494 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :7, :]
midprice after processing -493 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :8, :]
midprice after processing -492 messages = [888600 870400 901500 925100]
Im using m_seq_raw_inp[:, :9, :]
midprice after processing -491 messages = [888600 870400 901500 925000]
Im using

 20%|██        | 1/5 [3:08:23<12:33:34, 11303.53s/it]

midprice after processing 2000 messages = [890900 871000 900300 926000]
midprices are successfully calculated iteration no. 40
midprices:  [Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871200, 900300, 926000], dtype=int32), Array([890900, 871200, 900300, 926000], dtype=int32), Array([890900, 871200, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890900, 871100, 900300, 926000], dtype=int32), Array([890800, 871100, 900300, 926000], dtype=int3

 20%|██        | 1/5 [3:10:45<12:43:00, 11445.18s/it]E0430 00:46:08.042236  317408 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: RESOURCE_EXHAUSTED: CUDA driver ran out of memory trying to instantiate CUDA graph with 36 nodes and 0 conditionals (total of 0 alive CUDA graphs in the process). You can try to (a) Give more memory to CUDA driver by reducing XLA_PYTHON_CLIENT_MEM_FRACTION (b) Disable CUDA graph with 'XLA_FLAGS=--xla_gpu_enable_command_buffer=' (empty set). Original error: Failed to instantiate CUDA graph:CUDA_ERROR_OUT_OF_MEMORY: out of memory


midprice after processing -363 messages = [906200 861000 894000 907600]
Im using m_seq_raw_inp[:, :138, :]



