In [1]:
%cd /app

/app


In [2]:
import argparse
import os
import sys

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import torch
torch.multiprocessing.set_start_method('spawn')

import jax
from lob.encoding import Vocab, Message_Tokenizer

from lob import inference_no_errcorr as inference
from lob.init_train import init_train_state, load_checkpoint, load_metadata, load_args_from_checkpoint

from lob import inference_no_errcorr as inference
import lob.encoding as encoding
import preproc as preproc

import jax.numpy as jnp
import numpy as np

from pathlib import Path
import os

import pandas as pd
import plotly.graph_objs as go
import yaml

from filtration_utils import summary_table, build_zero_padded_series, plot_midprice_series_with_insertions, prepare_volatility_filtered_series, plot_midprice_series_with_mean_std

import os
import yaml
import numpy as np
import pandas as pd
from typing import Dict, Tuple, Any
from sklearn.linear_model import LinearRegression

2025-08-19 09:15:12.586420: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.8 which is older than the ptxas CUDA version (12.9.41). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# ======= CONFIGURATION =======
experiments = {
    "Plain": 'exp_21_20250804_215633_hist_plain_whole_352',
    "Heuristic": 'exp_22_20250804_215639_hist_heur_whole_360',
    "GenAI": 'exp_43_20250809_085521',
}

DATA_ROOT = "/app/data_saved"  # base dir used in your plotting code

In [4]:
import yaml
import numpy as np
import plotly.graph_objects as go

# Colors for each experiment
colors = {
    "Plain": 'black',
    "Heuristic": 'blue',
    "GenAI": 'red',
}

# Dash style per cutoff
dash_for_cutoff = {
    0.0: "solid",
    0.2: "dot",
    0.4: "dash",
    0.6: "longdash",
    0.8: "dashdot"
}

# Cutoffs to run
CUTOFFS = [0.0, 0.2, 0.4, 0.6, 0.8]

fig = go.Figure()

for label, exp_name in experiments.items():
    # === Load config ===
    CONFIG_PATH = f"{DATA_ROOT}/{exp_name}/used_config.yaml"
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)

    num_insertions      = config["num_insertions"]
    num_coolings        = config["num_coolings"]
    midprice_step_size  = config["midprice_step_size"]
    hist_msgs           = config["n_messages"]
    n_gen_msgs          = config["n_gen_msgs"]

    # Base merged for filtration
    merged = summary_table(exp_name)

    for cutoff in CUTOFFS:
        # Apply volatility filtration
        x_filt, all_series_filt, merged_filt, hist_steps_filt, gen_block_filt = prepare_volatility_filtered_series(
            merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=cutoff
        )

        # === Mean & Std after filtration ===
        mean_series = all_series_filt.mean(axis=0)
        std_series  = all_series_filt.std(axis=0)

        # ±1 std band
        fig.add_trace(go.Scatter(
            x=np.concatenate([x_filt, x_filt[::-1]]),
            y=np.concatenate([mean_series + std_series, (mean_series - std_series)[::-1]]),
            fill='toself',
            fillcolor='rgba(0,0,0,0.05)',
            line=dict(color='rgba(0,0,0,0)'),
            hoverinfo='skip',
            showlegend=False
        ))

        # Mean line
        fig.add_trace(go.Scatter(
            x=x_filt, y=mean_series, mode='lines',
            name=f"{label} Mean (cutoff={cutoff})",
            line=dict(color=colors[label], width=3, dash=dash_for_cutoff.get(cutoff, "solid"))
        ))

# Add reference lines
fig.add_hline(y=0, line=dict(color='gray', dash='dash'), name="Zero line")

fig.update_layout(
    title="Midprice Mean ±1 Std — volatility cutoffs",
    xaxis_title="Steps (sampled midprice points)",
    yaxis_title="Price – first price",
    template="plotly_white",
    hovermode="x unified",
    height=800,
    width=1200,
    legend=dict(x=0.01, y=0.99),
)
fig.show()


Before filtering: 352 samples

After filtering: 352 samples

Before filtering: 352 samples

After filtering: 282 samples

Before filtering: 352 samples

After filtering: 212 samples

Before filtering: 352 samples

After filtering: 141 samples

Before filtering: 352 samples

After filtering: 71 samples

Before filtering: 360 samples

After filtering: 360 samples

Before filtering: 360 samples

After filtering: 288 samples

Before filtering: 360 samples

After filtering: 216 samples

Before filtering: 360 samples

After filtering: 144 samples

Before filtering: 360 samples

After filtering: 72 samples

Before filtering: 168 samples

After filtering: 168 samples

Before filtering: 168 samples

After filtering: 135 samples

Before filtering: 168 samples

After filtering: 101 samples

Before filtering: 168 samples

After filtering: 68 samples

Before filtering: 168 samples

After filtering: 34 samples



In [5]:
# ===== Prefix-only β(a) with filtration @ cutoffs = {0.0, 0.2} for all experiments =====
import os, yaml, glob, re
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

# --- config / styles ---
CUTOFFS = [0.0]
PREFIX_MODE = "le"   # <= a
colors = {
    "Plain": "black",
    "Heuristic": "blue",
    "GenAI": "red",
}
dash_for_cutoff = {0.0: "solid", 0.2: "dot"}  # style by cutoff

# --- small helpers ---
def _safe_col(points_df, preferred, fallbacks):
    if preferred in points_df.columns:
        return preferred
    for c in fallbacks:
        if c in points_df.columns:
            return c
    return None

def build_and_merge(folder, batch_prefix, inp_prefix):
    # Load every .npy (shape (batch, time, feat))
    files   = glob.glob(os.path.join(folder, "*.npy"))
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")
    rec = []
    for f in files:
        nm = os.path.basename(f)
        m  = rx_iter.match(nm)
        if m:
            rng, itr = m.group(1).replace(" ", ""), int(m.group(2))
        else:
            m2 = rx_inp.match(nm)
            if not m2:
                continue
            rng, itr = m2.group(1).replace(" ", ""), 0

        batch = np.load(f)  # shape (batch_size, time, features)
        print(f"Loaded {nm} with shape {batch.shape}")

        rec.append({"range": rng, "iteration": itr, "batch": batch})
    df = pd.DataFrame(rec).sort_values(["range","iteration"]).reset_index(drop=True)

    # Parse sample ids from "range"
    df["ids"] = df["range"].str.split(",").apply(lambda L: [int(x) for x in L])

    # Explode each batch into one row per sample, with slicing logic
    rows = []
    for _, r in df.iterrows():
        for idx, sample_id in enumerate(r["ids"]):
            single = r["batch"][idx]   # (time, features)
            if r["iteration"] > 0:
                n_keep = 51 if r["iteration"] <= num_insertions else 50
                single = single[-n_keep:, :]
            rows.append({"id": int(sample_id), "iteration": int(r["iteration"]), "data": single})

    df_sorted = pd.DataFrame(rows).sort_values(["id","iteration"]).reset_index(drop=True)

    # Concatenate along time per id
    merged = []
    for id_val, grp in df_sorted.groupby("id", sort=True):
        arrs = [row.data for _, row in grp.iterrows()]
        big  = np.concatenate(arrs, axis=0)
        merged.append({"id": int(id_val), "merged_data": big})
    merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

    return df, df_sorted, merged_df

def global_beta_plot_from_raw(
    b_seq_inp,
    msg_seq_raw,
    all_series,
    x,
    hist_steps=550,
    gen_block=50,
    num_insertions=20,
    *,
    beta_theory=0.5,
    samples_used=None,
    special_first=(79, 15),
    nbins_hist=30
):
    def compute_tables(b_seq_inp, msg_seq_raw, hist_steps, gen_block, num_insertions, beta_theory):
        if isinstance(b_seq_inp, pd.DataFrame):
            b_dict_local = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
        else:
            b_dict_local = b_seq_inp
        if isinstance(msg_seq_raw, pd.DataFrame):
            m_dict_local = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}
        else:
            m_dict_local = msg_seq_raw

        EVENT_TYPE_COL = 1
        PRICE_COL      = 3
        SIZE_COL       = 5

        eps = 1e-12
        tol = 1e-12

        sample_ids = sorted(set(b_dict_local.keys()) & set(m_dict_local.keys()))
        col_names = [f"ins_{i}" for i in range(1, num_insertions + 1)]
        x_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)
        y_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)

        def per_sample(sample_id):
            messages = m_dict_local[sample_id]
            T = len(messages)
            insertion_positions = hist_steps + np.arange(1, num_insertions + 1) * gen_block
            valid_insertions = [pos for pos in insertion_positions if pos < T]
            if not valid_insertions:
                return dict(valid_insertions=[], x_norm=None, y_norm=None,
                            mask_zero=None, mask_pos=None,
                            mu_x=np.nan, mu_y=np.nan, std_x=np.nan, std_y=np.nan)

            ref_idx = valid_insertions[0]
            reference_price_ticks = float(messages[ref_idx, PRICE_COL])

            insert_sizes  = messages[valid_insertions, SIZE_COL].astype(float)
            insert_prices = messages[valid_insertions, PRICE_COL].astype(float)
            Q_cum    = np.cumsum(insert_sizes)
            notional = np.cumsum(insert_sizes * insert_prices)
            vwap_ticks_series = notional / np.maximum(Q_cum, eps)
            impact_ticks = np.abs(vwap_ticks_series - reference_price_ticks)

            evt_types    = messages[:, EVENT_TYPE_COL].astype(int)
            exec_sizes   = np.where(evt_types == 4, messages[:, SIZE_COL].astype(float), 0.0)
            cum_exec_vol = np.cumsum(exec_sizes)
            V_exp = np.array([float(cum_exec_vol[idx-1] if (idx-1) >= 0 else 0.0)
                              for idx in valid_insertions])

            rel_size   = Q_cum / np.maximum(V_exp, eps)
            log_qv     = np.log(np.maximum(rel_size, eps))
            log_imp    = np.log(np.maximum(impact_ticks, eps))
            mask_zero  = impact_ticks <= tol

            used_x_raw = log_qv[~mask_zero]
            used_y_raw = log_imp[~mask_zero]

            if used_x_raw.size >= 2:
                mu_x = float(used_x_raw.mean());  mu_y = float(used_y_raw.mean())
                std_x = float(np.sqrt(max(float(used_x_raw.var(ddof=0)), eps)))
                std_y = float(np.sqrt(max(float(used_y_raw.var(ddof=0)), eps)))
                x_norm_all = (log_qv - mu_x) / std_x
                y_norm_all = (log_imp - mu_y) / std_y
            else:
                mu_x = mu_y = np.nan
                std_x = std_y = np.nan
                x_norm_all = np.full_like(log_qv, np.nan, dtype=float)
                y_norm_all = np.full_like(log_imp, np.nan, dtype=float)

            return dict(valid_insertions=valid_insertions,
                        x_norm=x_norm_all, y_norm=y_norm_all,
                        mask_zero=mask_zero)

        for sid in sample_ids:
            res = per_sample(sid)
            valid_ins = res["valid_insertions"]
            x_norm    = res["x_norm"]
            y_norm    = res["y_norm"]
            mask_zero = res["mask_zero"]

            for j, _idx in enumerate(valid_ins):
                col = f"ins_{j+1}"
                if mask_zero[j]:
                    x_df.loc[sid, col] = "ZERO"
                    y_df.loc[sid, col] = "ZERO"
                else:
                    xv = x_norm[j]; yv = y_norm[j]
                    x_df.loc[sid, col] = float(xv) if np.isfinite(xv) else np.nan
                    y_df.loc[sid, col] = float(yv) if np.isfinite(yv) else np.nan

        return x_df, y_df

    x_df, y_df = compute_tables(
        b_seq_inp, msg_seq_raw, hist_steps, gen_block, num_insertions, beta_theory
    )

    # Tidy points
    all_ids = list(x_df.index)
    ordered_ids = [sid for sid in (special_first or []) if sid in all_ids]
    ordered_ids += [sid for sid in sorted(all_ids) if sid not in ordered_ids]
    if samples_used is not None:
        ordered_ids = ordered_ids[:samples_used]

    rows = []
    for sid in ordered_ids:
        for j, col in enumerate(x_df.columns, start=1):
            xv = x_df.loc[sid, col]
            yv = y_df.loc[sid, col]
            if isinstance(xv, (int, float, np.floating)) and isinstance(yv, (int, float, np.floating)):
                if np.isfinite(xv) and np.isfinite(yv):
                    rows.append({"sample_id": sid, "insertion": j, "x": float(xv), "y": float(yv)})
    points_df = pd.DataFrame(rows)
    if points_df.empty:
        return None, pd.DataFrame()

    return None, points_df

def _fit_beta_prefix(points_df, prefix_mode="le"):
    """
    Compute prefix-only β(a): fit y ~ x on all points with insertion <= a (or < a).
    Returns (a_values, betas_prefix).
    """
    if points_df is None or len(points_df) == 0:
        return np.array([]), np.array([])

    COL_INS = "insertion" if "insertion" in points_df.columns else _safe_col(points_df, "insertion", ["ins_idx","idx"])
    COL_X   = "x"         if "x"         in points_df.columns else _safe_col(points_df, "x", ["log_qv"])
    COL_Y   = "y"         if "y"         in points_df.columns else _safe_col(points_df, "y", ["log_impact"])
    if COL_INS is None or COL_X is None or COL_Y is None:
        return np.array([]), np.array([])

    ins_series = pd.to_numeric(points_df[COL_INS], errors="coerce").dropna().astype(int)
    if ins_series.empty:
        return np.array([]), np.array([])

    ins_max = int(ins_series.max())
    a_values = np.arange(1, ins_max + 1, dtype=int)

    lr = LinearRegression()
    betas = np.full_like(a_values, np.nan, dtype=float)
    for i, a in enumerate(a_values):
        sub = points_df[ins_series <= a] if prefix_mode == "le" else points_df[ins_series < a]
        if len(sub) >= 2:
            X = sub[COL_X].to_numpy().reshape(-1, 1)
            Y = sub[COL_Y].to_numpy()
            mask = np.isfinite(X).ravel() & np.isfinite(Y)
            if mask.sum() >= 2:
                lr.fit(X[mask].reshape(-1, 1), Y[mask])
                betas[i] = float(lr.coef_[0])
    return a_values, betas

# --- main plotting ---
fig = go.Figure()

for label, exp_name in experiments.items():
    # 1) Load config
    CONFIG_PATH = f"{DATA_ROOT}/{exp_name}/used_config.yaml"
    with open(CONFIG_PATH, "r") as f:
        config = yaml.safe_load(f)

    num_insertions      = int(config["num_insertions"])
    midprice_step_size  = int(config["midprice_step_size"])
    hist_msgs           = int(config["n_messages"])
    n_gen_msgs          = int(config["n_gen_msgs"])

    # 2) Base series once (for time axes and derived steps)
    merged = summary_table(exp_name)
    x_base, all_series_base = build_zero_padded_series(hist_msgs, n_gen_msgs, midprice_step_size, merged)
    hist_steps_base = hist_msgs // midprice_step_size
    gen_steps_base  = n_gen_msgs // midprice_step_size
    gen_block_base  = gen_steps_base + 1

    # 3) Build merged dicts for b & m ONCE per experiment
    experiment_name = exp_name
    b_folder      = f"{DATA_ROOT}/{experiment_name}/b_seq_gen_doubled"
    b_batch_pref  = "b_seq_gen_doubled_batch"
    b_inp_pref    = "b_seq_inp"

    m_folder      = f"{DATA_ROOT}/{experiment_name}/msgs_decoded_doubled"
    m_batch_pref  = "msgs_decoded_doubled_batch"
    m_inp_pref    = "m_seq_raw_inp"

    _, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)
    _, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

    b_dict_full = {int(r.id): np.array(r.merged_data) for _, r in b_merged.iterrows()}
    m_dict_full = {int(r.id): np.array(r.merged_data) for _, r in m_merged.iterrows()}

    # add a zero row at the top of each sample
    for d in (b_dict_full, m_dict_full):
        for key, arr in d.items():
            zero = np.zeros((1, arr.shape[1]), dtype=arr.dtype)
            d[key] = np.vstack([zero, arr])

    # 4) For each cutoff, filter + compute prefix β(a)
    for cutoff in CUTOFFS:
        x_filt, all_series_filt, merged_filt, hist_steps_filt, gen_block_filt = prepare_volatility_filtered_series(
            merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=cutoff
        )
        filtered_ids = set(merged_filt["id"])
        print(f"[{label}] cutoff={cutoff}: {len(filtered_ids)} samples retained")

        b_dict = {k: v for k, v in b_dict_full.items() if k in filtered_ids}
        m_dict = {k: v for k, v in m_dict_full.items() if k in filtered_ids}
        if not b_dict or not m_dict:
            print(f"[{label}] cutoff={cutoff}: no overlapping samples; skipping.")
            continue

        # Build points_df (ignore fig)
        _, points_df = global_beta_plot_from_raw(
            b_dict, m_dict, all_series_filt, x_filt,
            hist_steps=hist_steps_filt,
            gen_block=gen_block_filt,
            num_insertions=num_insertions,
            beta_theory=0.5,
            samples_used=None,
            special_first=(79, 15),
            nbins_hist=30
        )
        if points_df is None or points_df.empty:
            print(f"[{label}] cutoff={cutoff}: points_df empty; skipping.")
            continue

        a_vals, betas_prefix = _fit_beta_prefix(points_df, prefix_mode=PREFIX_MODE)
        if a_vals.size == 0:
            print(f"[{label}] cutoff={cutoff}: no valid a-values; skipping.")
            continue

        fig.add_trace(go.Scatter(
            x=a_vals,
            y=betas_prefix,
            mode="lines+markers",
            name=f"{label} [:a] (cutoff={cutoff})",
            line=dict(width=3, color=colors.get(label, None), dash=dash_for_cutoff.get(cutoff, "solid"))
        ))

# --- guides + layout ---
if len(fig.data) > 0:
    x_min = min(int(np.nanmin(tr.x)) for tr in fig.data if len(tr.x))
    x_max = max(int(np.nanmax(tr.x)) for tr in fig.data if len(tr.x))
else:
    x_min, x_max = 1, 20

fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.5, 0.5],
    mode="lines", line=dict(dash="dash"), name="y = 0.5"
))
fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.0, 0.0],
    mode="lines", name="y = 0"
))

fig.update_xaxes(title_text="a (insertion index threshold)", dtick=1)
fig.update_yaxes(title_text="β (slope)")
fig.update_layout(
    title="Prefix-only Global β(a) across experiments — volatility cutoffs = 0.0, 0.2",
    template="plotly_white",
    width=1150,
    height=620,
    legend=dict(x=0.01, y=0.99)
)
fig.show()

Loaded b_seq_gen_doubled_batch_[3868, 6391, 6213, 1572]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[7182, 2137, 4796, 1907]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[2640, 4030, 1086, 2674]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[410, 4610, 3384, 3588]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[6910, 2659, 4946, 2810]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[5317, 2178, 3859, 1730]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[3306, 251, 947, 2413]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[7078, 3728, 1563, 470]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[2649, 1217, 5285, 1579]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[6887, 5586, 415, 7000]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[3393, 6350, 2902, 6291]_iter_0.npy with sha

In [6]:
# ===== Prefix-only β(a) vs sample size (N=25,50,75,...) across experiments =====
import os, yaml, glob, re
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

# ---------------- config / styles ----------------
SAMPLE_SIZES = [25, 50, 75, 100, 150, 200]   # adjust as you wish
PREFIX_MODE  = "le"   # <= a

colors = {
    "Plain": "black",
    "Heuristic": "blue",
    "GenAI": "red",
}

dash_for_size = {
    25:  "dot",
    50:  "dash",
    75:  "longdash",
    100: "dashdot",
    150: "longdashdot",
    200: "solid",
}

# --------------- small helpers -------------------
def _safe_col(points_df, preferred, fallbacks):
    if preferred in points_df.columns:
        return preferred
    for c in fallbacks:
        if c in points_df.columns:
            return c
    return None

def build_and_merge(folder, batch_prefix, inp_prefix):
    # Load every .npy (shape (batch, time, feat))
    files   = glob.glob(os.path.join(folder, "*.npy"))
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")
    rec = []
    for f in files:
        nm = os.path.basename(f)
        m  = rx_iter.match(nm)
        if m:
            rng, itr = m.group(1).replace(" ", ""), int(m.group(2))
        else:
            m2 = rx_inp.match(nm)
            if not m2:
                continue
            rng, itr = m2.group(1).replace(" ", ""), 0

        batch = np.load(f)  # shape (batch_size, time, features)
        print(f"Loaded {nm} with shape {batch.shape}")
        rec.append({"range": rng, "iteration": itr, "batch": batch})

    if not rec:
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    df = pd.DataFrame(rec).sort_values(["range","iteration"]).reset_index(drop=True)

    # Parse sample ids from "range"
    df["ids"] = df["range"].str.split(",").apply(lambda L: [int(x) for x in L])

    # Explode each batch into one row per sample, with slicing logic
    rows = []
    for _, r in df.iterrows():
        for idx, sample_id in enumerate(r["ids"]):
            single = r["batch"][idx]   # (time, features)
            if r["iteration"] > 0:
                n_keep = 51 if r["iteration"] <= num_insertions else 50
                single = single[-n_keep:, :]
            rows.append({"id": int(sample_id), "iteration": int(r["iteration"]), "data": single})

    if not rows:
        return df, pd.DataFrame(), pd.DataFrame()

    df_sorted = pd.DataFrame(rows).sort_values(["id","iteration"]).reset_index(drop=True)

    # Concatenate along time per id
    merged = []
    for id_val, grp in df_sorted.groupby("id", sort=True):
        arrs = [row.data for _, row in grp.iterrows()]
        big  = np.concatenate(arrs, axis=0)
        merged.append({"id": int(id_val), "merged_data": big})
    merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

    return df, df_sorted, merged_df

def global_beta_points_df(
    b_seq_inp,
    msg_seq_raw,
    *,
    hist_steps=550,
    gen_block=50,
    num_insertions=20,
    special_first=(79, 15),
):
    """
    Build a tidy points_df with columns ['sample_id','insertion','x','y'] where
    x = normalized log relative size, y = normalized log impact, for each valid insertion j.
    """
    # Accept either dicts or merged DataFrames
    if isinstance(b_seq_inp, pd.DataFrame):
        b_dict_local = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
    else:
        b_dict_local = b_seq_inp or {}
    if isinstance(msg_seq_raw, pd.DataFrame):
        m_dict_local = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}
    else:
        m_dict_local = msg_seq_raw or {}

    if not b_dict_local or not m_dict_local:
        return pd.DataFrame()

    EVENT_TYPE_COL = 1
    PRICE_COL      = 3
    SIZE_COL       = 5
    eps = 1e-12
    tol = 1e-12

    sample_ids = sorted(set(b_dict_local.keys()) & set(m_dict_local.keys()))
    if not sample_ids:
        return pd.DataFrame()

    col_names = [f"ins_{i}" for i in range(1, num_insertions + 1)]
    x_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)
    y_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)

    def per_sample(sample_id):
        messages = m_dict_local[sample_id]
        T = len(messages)
        insertion_positions = hist_steps + np.arange(1, num_insertions + 1) * gen_block
        valid_insertions = [pos for pos in insertion_positions if pos < T]
        if not valid_insertions:
            return dict(valid_insertions=[], x_norm=None, y_norm=None, mask_zero=None)

        ref_idx = valid_insertions[0]
        reference_price_ticks = float(messages[ref_idx, PRICE_COL])

        insert_sizes  = messages[valid_insertions, SIZE_COL].astype(float)
        insert_prices = messages[valid_insertions, PRICE_COL].astype(float)
        Q_cum    = np.cumsum(insert_sizes)
        notional = np.cumsum(insert_sizes * insert_prices)
        vwap_ticks_series = notional / np.maximum(Q_cum, eps)
        impact_ticks = np.abs(vwap_ticks_series - reference_price_ticks)

        evt_types    = messages[:, EVENT_TYPE_COL].astype(int)
        exec_sizes   = np.where(evt_types == 4, messages[:, SIZE_COL].astype(float), 0.0)
        cum_exec_vol = np.cumsum(exec_sizes)
        V_exp = np.array([float(cum_exec_vol[idx-1] if (idx-1) >= 0 else 0.0)
                          for idx in valid_insertions])

        rel_size   = Q_cum / np.maximum(V_exp, eps)
        log_qv     = np.log(np.maximum(rel_size, eps))
        log_imp    = np.log(np.maximum(impact_ticks, eps))
        mask_zero  = impact_ticks <= tol

        used_x_raw = log_qv[~mask_zero]
        used_y_raw = log_imp[~mask_zero]

        if used_x_raw.size >= 2:
            mu_x = float(used_x_raw.mean());  mu_y = float(used_y_raw.mean())
            std_x = float(np.sqrt(max(float(used_x_raw.var(ddof=0)), eps)))
            std_y = float(np.sqrt(max(float(used_y_raw.var(ddof=0)), eps)))
            x_norm_all = (log_qv - mu_x) / std_x
            y_norm_all = (log_imp - mu_y) / std_y
        else:
            x_norm_all = np.full_like(log_qv, np.nan, dtype=float)
            y_norm_all = np.full_like(log_imp, np.nan, dtype=float)

        return dict(valid_insertions=valid_insertions,
                    x_norm=x_norm_all, y_norm=y_norm_all, mask_zero=mask_zero)

    # fill x_df / y_df
    for sid in sample_ids:
        res = per_sample(sid)
        valid_ins = res["valid_insertions"]
        x_norm    = res["x_norm"]
        y_norm    = res["y_norm"]
        mask_zero = res["mask_zero"]
        for j, _idx in enumerate(valid_ins):
            col = f"ins_{j+1}"
            if mask_zero[j]:
                x_df.loc[sid, col] = "ZERO"
                y_df.loc[sid, col] = "ZERO"
            else:
                xv = x_norm[j]; yv = y_norm[j]
                x_df.loc[sid, col] = float(xv) if np.isfinite(xv) else np.nan
                y_df.loc[sid, col] = float(yv) if np.isfinite(yv) else np.nan

    # tidy points
    all_ids = list(x_df.index)
    ordered_ids = [sid for sid in (special_first or []) if sid in all_ids]
    ordered_ids += [sid for sid in sorted(all_ids) if sid not in ordered_ids]

    rows = []
    for sid in ordered_ids:
        for j, col in enumerate(x_df.columns, start=1):
            xv = x_df.loc[sid, col]
            yv = y_df.loc[sid, col]
            if isinstance(xv, (int, float, np.floating)) and isinstance(yv, (int, float, np.floating)):
                if np.isfinite(xv) and np.isfinite(yv):
                    rows.append({"sample_id": sid, "insertion": j, "x": float(xv), "y": float(yv)})
    return pd.DataFrame(rows)

def _fit_beta_prefix(points_df, prefix_mode="le"):
    """
    Compute prefix-only β(a): fit y ~ x on all points with insertion <= a (or < a).
    Returns (a_values, betas_prefix).
    """
    if points_df is None or len(points_df) == 0:
        return np.array([]), np.array([])

    COL_INS = "insertion" if "insertion" in points_df.columns else _safe_col(points_df, "insertion", ["ins_idx","idx"])
    COL_X   = "x"         if "x"         in points_df.columns else _safe_col(points_df, "x", ["log_qv"])
    COL_Y   = "y"         if "y"         in points_df.columns else _safe_col(points_df, "y", ["log_impact"])
    if COL_INS is None or COL_X is None or COL_Y is None:
        return np.array([]), np.array([])

    ins_series = pd.to_numeric(points_df[COL_INS], errors="coerce").dropna().astype(int)
    if ins_series.empty:
        return np.array([]), np.array([])

    ins_max = int(ins_series.max())
    a_values = np.arange(1, ins_max + 1, dtype=int)

    lr = LinearRegression()
    betas = np.full_like(a_values, np.nan, dtype=float)
    for i, a in enumerate(a_values):
        sub = points_df[ins_series <= a] if prefix_mode == "le" else points_df[ins_series < a]
        if len(sub) >= 2:
            X = sub[COL_X].to_numpy().reshape(-1, 1)
            Y = sub[COL_Y].to_numpy()
            mask = np.isfinite(X).ravel() & np.isfinite(Y)
            if mask.sum() >= 2:
                lr.fit(X[mask].reshape(-1, 1), Y[mask])
                betas[i] = float(lr.coef_[0])
    return a_values, betas

# ------------------ main plotting ------------------
fig = go.Figure()

for label, exp_name in experiments.items():
    # 1) Load config
    CONFIG_PATH = f"{DATA_ROOT}/{exp_name}/used_config.yaml"
    with open(CONFIG_PATH, "r") as f:
        config = yaml.safe_load(f)

    num_insertions      = int(config["num_insertions"])
    midprice_step_size  = int(config["midprice_step_size"])
    hist_msgs           = int(config["n_messages"])
    n_gen_msgs          = int(config["n_gen_msgs"])

    # 2) Base series (time/grid params)
    merged = summary_table(exp_name)
    x_base, all_series_base = build_zero_padded_series(hist_msgs, n_gen_msgs, midprice_step_size, merged)
    hist_steps_base = hist_msgs // midprice_step_size
    gen_steps_base  = n_gen_msgs // midprice_step_size
    gen_block_base  = gen_steps_base + 1

    # 3) Build merged dicts for b & m once
    experiment_name = exp_name
    b_folder      = f"{DATA_ROOT}/{experiment_name}/b_seq_gen_doubled"
    b_batch_pref  = "b_seq_gen_doubled_batch"
    b_inp_pref    = "b_seq_inp"

    m_folder      = f"{DATA_ROOT}/{experiment_name}/msgs_decoded_doubled"
    m_batch_pref  = "msgs_decoded_doubled_batch"
    m_inp_pref    = "m_seq_raw_inp"

    _, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)
    _, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

    if b_merged.empty or m_merged.empty:
        print(f"[{label}] No merged b/m data; skipping.")
        continue

    b_dict_full = {int(r.id): np.array(r.merged_data) for _, r in b_merged.iterrows()}
    m_dict_full = {int(r.id): np.array(r.merged_data) for _, r in m_merged.iterrows()}

    # add a zero row at the top of each sample
    for d in (b_dict_full, m_dict_full):
        for key, arr in d.items():
            zero = np.zeros((1, arr.shape[1]), dtype=arr.dtype)
            d[key] = np.vstack([zero, arr])

    # ---------------- iterate over sample sizes ----------------
    all_ids_sorted = sorted(set(b_dict_full.keys()) & set(m_dict_full.keys()))
    print(f"[{label}] total available samples: {len(all_ids_sorted)}")

    for N in SAMPLE_SIZES:
        if len(all_ids_sorted) == 0:
            print(f"[{label}] N={N}: no ids available; skipping.")
            continue

        ids_N = all_ids_sorted[:min(N, len(all_ids_sorted))]
        print(f"[{label}] N={N}: using {len(ids_N)} samples")

        b_dict = {k: b_dict_full[k] for k in ids_N if k in b_dict_full}
        m_dict = {k: m_dict_full[k] for k in ids_N if k in m_dict_full}

        # Build points_df for this N
        points_df = global_beta_points_df(
            b_dict, m_dict,
            hist_steps=hist_steps_base,
            gen_block=gen_block_base,
            num_insertions=num_insertions,
            special_first=(79, 15),
        )
        if points_df is None or points_df.empty:
            print(f"[{label}] N={N}: points_df empty; skipping.")
            continue

        a_vals, betas_prefix = _fit_beta_prefix(points_df, prefix_mode=PREFIX_MODE)
        if a_vals.size == 0:
            print(f"[{label}] N={N}: no valid a-values; skipping.")
            continue

        fig.add_trace(go.Scatter(
            x=a_vals,
            y=betas_prefix,
            mode="lines+markers",
            name=f"(N={len(ids_N)}) {label} [:a] ",
            line=dict(width=3, color=colors.get(label, None), dash=dash_for_size.get(N, "solid"))
        ))

# ---------------- guides + layout ----------------
if len(fig.data) > 0:
    x_min = min(int(np.nanmin(tr.x)) for tr in fig.data if len(tr.x))
    x_max = max(int(np.nanmax(tr.x)) for tr in fig.data if len(tr.x))
else:
    x_min, x_max = 1, 20

fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.5, 0.5],
    mode="lines", line=dict(dash="dash"), name="y = 0.5"
))
fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.0, 0.0],
    mode="lines", name="y = 0"
))

fig.update_xaxes(title_text="a (insertion index threshold)", dtick=1)
fig.update_yaxes(title_text="β (slope)")
fig.update_layout(
    title="Prefix-only Global β(a) across experiments — growing sample sizes (N)",
    template="plotly_white",
    width=1150,
    height=620,
    legend=dict(x=0.01, y=0.99)
)
fig.show()

Loaded b_seq_gen_doubled_batch_[3868, 6391, 6213, 1572]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[7182, 2137, 4796, 1907]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[2640, 4030, 1086, 2674]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[410, 4610, 3384, 3588]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[6910, 2659, 4946, 2810]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[5317, 2178, 3859, 1730]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[3306, 251, 947, 2413]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[7078, 3728, 1563, 470]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[2649, 1217, 5285, 1579]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[6887, 5586, 415, 7000]_iter_0.npy with shape (4, 3020, 501)
Loaded b_seq_gen_doubled_batch_[3393, 6350, 2902, 6291]_iter_0.npy with sha

In [7]:
# ===== Prefix-only β(a) vs sample size (N=25,50,75,...) — GenAI with volatility filters 0.0/0.2/0.4 =====
import os, yaml, glob, re
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

# ---------------- config / styles ----------------
SAMPLE_SIZES = [25, 50, 75, 100, 150, 200]   # изменяй по необходимости
PREFIX_MODE  = "le"   # <= a

# используем ТОЛЬКО GenAI из твоего словаря experiments
GENAI_LABEL = "GenAI"

# Цвет по cutoff (оттенки красного), штриховка — по N
color_for_cutoff = {
    0.0: "red",
    0.2: "orangered",
    0.4: "darkred",
}
CUT_OFFS = [0.0, 0.2, 0.4]

dash_for_size = {
    25:  "dot",
    50:  "dash",
    75:  "longdash",
    100: "dashdot",
    150: "longdashdot",
    200: "solid",
}

# --------------- small helpers -------------------
def _safe_col(points_df, preferred, fallbacks):
    if preferred in points_df.columns:
        return preferred
    for c in fallbacks:
        if c in points_df.columns:
            return c
    return None

def build_and_merge(folder, batch_prefix, inp_prefix):
    # Load every .npy (shape (batch, time, feat))
    files   = glob.glob(os.path.join(folder, "*.npy"))
    rx_iter = re.compile(rf"{re.escape(batch_prefix)}_\[(.+)\]_iter_(\d+)\.npy$")
    rx_inp  = re.compile(rf"{re.escape(inp_prefix)}_\[(.+)\]\.npy$")
    rec = []
    for f in files:
        nm = os.path.basename(f)
        m  = rx_iter.match(nm)
        if m:
            rng, itr = m.group(1).replace(" ", ""), int(m.group(2))
        else:
            m2 = rx_inp.match(nm)
            if not m2:
                continue
            rng, itr = m2.group(1).replace(" ", ""), 0

        batch = np.load(f)  # shape (batch_size, time, features)
        print(f"Loaded {nm} with shape {batch.shape}")
        rec.append({"range": rng, "iteration": itr, "batch": batch})

    if not rec:
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    df = pd.DataFrame(rec).sort_values(["range","iteration"]).reset_index(drop=True)
    df["ids"] = df["range"].str.split(",").apply(lambda L: [int(x) for x in L])

    rows = []
    for _, r in df.iterrows():
        for idx, sample_id in enumerate(r["ids"]):
            single = r["batch"][idx]   # (time, features)
            if r["iteration"] > 0:
                n_keep = 51 if r["iteration"] <= num_insertions else 50
                single = single[-n_keep:, :]
            rows.append({"id": int(sample_id), "iteration": int(r["iteration"]), "data": single})

    if not rows:
        return df, pd.DataFrame(), pd.DataFrame()

    df_sorted = pd.DataFrame(rows).sort_values(["id","iteration"]).reset_index(drop=True)

    merged = []
    for id_val, grp in df_sorted.groupby("id", sort=True):
        arrs = [row.data for _, row in grp.iterrows()]
        big  = np.concatenate(arrs, axis=0)
        merged.append({"id": int(id_val), "merged_data": big})
    merged_df = pd.DataFrame(merged).sort_values("id").reset_index(drop=True)

    return df, df_sorted, merged_df

def global_beta_points_df(
    b_seq_inp,
    msg_seq_raw,
    *,
    hist_steps=550,
    gen_block=50,
    num_insertions=20,
    special_first=(79, 15),
):
    """
    Build tidy points_df: ['sample_id','insertion','x','y'].
    x = normalized log relative size, y = normalized log impact.
    """
    if isinstance(b_seq_inp, pd.DataFrame):
        b_dict_local = {int(r.id): np.array(r.merged_data) for _, r in b_seq_inp.iterrows()}
    else:
        b_dict_local = b_seq_inp or {}
    if isinstance(msg_seq_raw, pd.DataFrame):
        m_dict_local = {int(r.id): np.array(r.merged_data) for _, r in msg_seq_raw.iterrows()}
    else:
        m_dict_local = msg_seq_raw or {}

    if not b_dict_local or not m_dict_local:
        return pd.DataFrame()

    EVENT_TYPE_COL = 1
    PRICE_COL      = 3
    SIZE_COL       = 5
    eps = 1e-12
    tol = 1e-12

    sample_ids = sorted(set(b_dict_local.keys()) & set(m_dict_local.keys()))
    if not sample_ids:
        return pd.DataFrame()

    col_names = [f"ins_{i}" for i in range(1, num_insertions + 1)]
    x_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)
    y_df = pd.DataFrame(index=sample_ids, columns=col_names, dtype=object)

    def per_sample(sample_id):
        messages = m_dict_local[sample_id]
        T = len(messages)
        insertion_positions = hist_steps + np.arange(1, num_insertions + 1) * gen_block
        valid_insertions = [pos for pos in insertion_positions if pos < T]
        if not valid_insertions:
            return dict(valid_insertions=[], x_norm=None, y_norm=None, mask_zero=None)

        ref_idx = valid_insertions[0]
        reference_price_ticks = float(messages[ref_idx, PRICE_COL])

        insert_sizes  = messages[valid_insertions, SIZE_COL].astype(float)
        insert_prices = messages[valid_insertions, PRICE_COL].astype(float)
        Q_cum    = np.cumsum(insert_sizes)
        notional = np.cumsum(insert_sizes * insert_prices)
        vwap_ticks_series = notional / np.maximum(Q_cum, eps)
        impact_ticks = np.abs(vwap_ticks_series - reference_price_ticks)

        evt_types    = messages[:, EVENT_TYPE_COL].astype(int)
        exec_sizes   = np.where(evt_types == 4, messages[:, SIZE_COL].astype(float), 0.0)
        cum_exec_vol = np.cumsum(exec_sizes)
        V_exp = np.array([float(cum_exec_vol[idx-1] if (idx-1) >= 0 else 0.0)
                          for idx in valid_insertions])

        rel_size   = Q_cum / np.maximum(V_exp, eps)
        log_qv     = np.log(np.maximum(rel_size, eps))
        log_imp    = np.log(np.maximum(impact_ticks, eps))
        mask_zero  = impact_ticks <= tol

        used_x_raw = log_qv[~mask_zero]
        used_y_raw = log_imp[~mask_zero]

        if used_x_raw.size >= 2:
            mu_x = float(used_x_raw.mean());  mu_y = float(used_y_raw.mean())
            std_x = float(np.sqrt(max(float(used_x_raw.var(ddof=0)), eps)))
            std_y = float(np.sqrt(max(float(used_y_raw.var(ddof=0)), eps)))
            x_norm_all = (log_qv - mu_x) / std_x
            y_norm_all = (log_imp - mu_y) / std_y
        else:
            x_norm_all = np.full_like(log_qv, np.nan, dtype=float)
            y_norm_all = np.full_like(log_imp, np.nan, dtype=float)

        return dict(valid_insertions=valid_insertions,
                    x_norm=x_norm_all, y_norm=y_norm_all, mask_zero=mask_zero)

    for sid in sample_ids:
        res = per_sample(sid)
        valid_ins = res["valid_insertions"]
        x_norm    = res["x_norm"]
        y_norm    = res["y_norm"]
        mask_zero = res["mask_zero"]
        for j, _idx in enumerate(valid_ins):
            col = f"ins_{j+1}"
            if mask_zero[j]:
                x_df.loc[sid, col] = "ZERO"
                y_df.loc[sid, col] = "ZERO"
            else:
                xv = x_norm[j]; yv = y_norm[j]
                x_df.loc[sid, col] = float(xv) if np.isfinite(xv) else np.nan
                y_df.loc[sid, col] = float(yv) if np.isfinite(yv) else np.nan

    # tidy
    all_ids = list(x_df.index)
    ordered_ids = [sid for sid in (special_first or []) if sid in all_ids]
    ordered_ids += [sid for sid in sorted(all_ids) if sid not in ordered_ids]

    rows = []
    for sid in ordered_ids:
        for j, col in enumerate(x_df.columns, start=1):
            xv = x_df.loc[sid, col]
            yv = y_df.loc[sid, col]
            if isinstance(xv, (int, float, np.floating)) and isinstance(yv, (int, float, np.floating)):
                if np.isfinite(xv) and np.isfinite(yv):
                    rows.append({"sample_id": sid, "insertion": j, "x": float(xv), "y": float(yv)})
    return pd.DataFrame(rows)

def _fit_beta_prefix(points_df, prefix_mode="le"):
    """Compute prefix-only β(a)."""
    if points_df is None or len(points_df) == 0:
        return np.array([]), np.array([])

    COL_INS = "insertion" if "insertion" in points_df.columns else _safe_col(points_df, "insertion", ["ins_idx","idx"])
    COL_X   = "x"         if "x"         in points_df.columns else _safe_col(points_df, "x", ["log_qv"])
    COL_Y   = "y"         if "y"         in points_df.columns else _safe_col(points_df, "y", ["log_impact"])
    if COL_INS is None or COL_X is None or COL_Y is None:
        return np.array([]), np.array([])

    ins_series = pd.to_numeric(points_df[COL_INS], errors="coerce").dropna().astype(int)
    if ins_series.empty:
        return np.array([]), np.array([])

    ins_max = int(ins_series.max())
    a_values = np.arange(1, ins_max + 1, dtype=int)

    lr = LinearRegression()
    betas = np.full_like(a_values, np.nan, dtype=float)
    for i, a in enumerate(a_values):
        sub = points_df[ins_series <= a] if prefix_mode == "le" else points_df[ins_series < a]
        if len(sub) >= 2:
            X = sub[COL_X].to_numpy().reshape(-1, 1)
            Y = sub[COL_Y].to_numpy()
            mask = np.isfinite(X).ravel() & np.isfinite(Y)
            if mask.sum() >= 2:
                lr.fit(X[mask].reshape(-1, 1), Y[mask])
                betas[i] = float(lr.coef_[0])
    return a_values, betas

# ------------------ main ------------------
fig = go.Figure()

# 1) найти путь к GenAI
if GENAI_LABEL not in experiments:
    raise KeyError(f"'{GENAI_LABEL}' not found in experiments dict")
exp_name = experiments[GENAI_LABEL]

# 2) загрузить конфиг и базовые ряды
CONFIG_PATH = f"{DATA_ROOT}/{exp_name}/used_config.yaml"
with open(CONFIG_PATH, "r") as f:
    config = yaml.safe_load(f)

num_insertions      = int(config["num_insertions"])
midprice_step_size  = int(config["midprice_step_size"])
hist_msgs           = int(config["n_messages"])
n_gen_msgs          = int(config["n_gen_msgs"])

merged = summary_table(exp_name)
x_base, all_series_base = build_zero_padded_series(hist_msgs, n_gen_msgs, midprice_step_size, merged)
hist_steps_base = hist_msgs // midprice_step_size
gen_steps_base  = n_gen_msgs // midprice_step_size
gen_block_base  = gen_steps_base + 1

# 3) собрать b/m один раз (без фильтра)
experiment_name = exp_name
b_folder      = f"{DATA_ROOT}/{experiment_name}/b_seq_gen_doubled"
b_batch_pref  = "b_seq_gen_doubled_batch"
b_inp_pref    = "b_seq_inp"

m_folder      = f"{DATA_ROOT}/{experiment_name}/msgs_decoded_doubled"
m_batch_pref  = "msgs_decoded_doubled_batch"
m_inp_pref    = "m_seq_raw_inp"

_, b_sorted, b_merged = build_and_merge(b_folder, b_batch_pref, b_inp_pref)
_, m_sorted, m_merged = build_and_merge(m_folder, m_batch_pref, m_inp_pref)

if b_merged.empty or m_merged.empty:
    raise RuntimeError("No merged b/m data for GenAI")

b_dict_full = {int(r.id): np.array(r.merged_data) for _, r in b_merged.iterrows()}
m_dict_full = {int(r.id): np.array(r.merged_data) for _, r in m_merged.iterrows()}

# добавить нулевую строку
for d in (b_dict_full, m_dict_full):
    for key, arr in d.items():
        zero = np.zeros((1, arr.shape[1]), dtype=arr.dtype)
        d[key] = np.vstack([zero, arr])

# 4) пройдём по cutoffs и по N
for cutoff in CUT_OFFS:
    # фильтрация только merged-таблиц для отбора ID
    x_filt, all_series_filt, merged_filt, hist_steps_filt, gen_block_filt = prepare_volatility_filtered_series(
        merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=cutoff
    )
    filtered_ids_all = sorted(set(merged_filt["id"]))
    print(f"[GenAI] cutoff={cutoff}: {len(filtered_ids_all)} samples available")

    for N in SAMPLE_SIZES:
        if not filtered_ids_all:
            print(f"[GenAI] cutoff={cutoff}, N={N}: no ids; skipping")
            continue
        ids_N = filtered_ids_all[:min(N, len(filtered_ids_all))]
        print(f"[GenAI] cutoff={cutoff}, N={N}: using {len(ids_N)} samples")

        b_dict = {k: b_dict_full[k] for k in ids_N if k in b_dict_full}
        m_dict = {k: m_dict_full[k] for k in ids_N if k in m_dict_full}
        if not b_dict or not m_dict:
            print(f"[GenAI] cutoff={cutoff}, N={N}: empty dicts; skipping")
            continue

        # points_df и β(a)
        points_df = global_beta_points_df(
            b_dict, m_dict,
            hist_steps=hist_steps_filt,   # важно: те же шаги, что после фильтра
            gen_block=gen_block_filt,
            num_insertions=num_insertions,
            special_first=(79, 15),
        )
        if points_df is None or points_df.empty:
            print(f"[GenAI] cutoff={cutoff}, N={N}: points_df empty; skipping.")
            continue

        a_vals, betas_prefix = _fit_beta_prefix(points_df, prefix_mode=PREFIX_MODE)
        if a_vals.size == 0:
            print(f"[GenAI] cutoff={cutoff}, N={N}: no valid a; skipping.")
            continue

        fig.add_trace(go.Scatter(
            x=a_vals, y=betas_prefix,
            mode="lines+markers",
            name=f"v={cutoff} (N={len(ids_N)}) [:a] GenAI",
            line=dict(width=3, color=color_for_cutoff.get(cutoff, "red"),
                      dash=dash_for_size.get(N, "solid"))
        ))

# ---------------- guides + layout ----------------
if len(fig.data) > 0:
    x_min = min(int(np.nanmin(tr.x)) for tr in fig.data if len(tr.x))
    x_max = max(int(np.nanmax(tr.x)) for tr in fig.data if len(tr.x))
else:
    x_min, x_max = 1, 20

fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.5, 0.5],
    mode="lines", line=dict(dash="dash"), name="y = 0.5"
))
fig.add_trace(go.Scatter(
    x=[x_min, x_max], y=[0.0, 0.0],
    mode="lines", name="y = 0"
))

fig.update_xaxes(title_text="a (insertion index threshold)", dtick=1)
fig.update_yaxes(title_text="β (slope)")
fig.update_layout(
    title="Prefix-only Global β(a) — GenAI with volatility filters (v=0.0,0.2,0.4) and growing N",
    template="plotly_white",
    width=1150,
    height=620,
    legend=dict(x=0.01, y=0.99)
)
fig.show()

Loaded b_seq_gen_doubled_batch_[975, 6842, 5118, 5252]_iter_40.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[869, 5837, 2713, 5383]_iter_46.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[5964, 4470, 3909, 6469]_iter_4.npy with shape (4, 501, 501)
Loaded b_seq_gen_doubled_batch_[346, 1483, 2512, 5215]_iter_11.npy with shape (4, 501, 501)
Loaded b_seq_gen_doubled_batch_[3683, 6629, 1469, 5105]_iter_35.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[1652, 6865, 3808, 6681]_iter_41.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[2029, 5967, 5558, 3773]_iter_29.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[3683, 6629, 1469, 5105]_iter_28.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[3374, 2414, 3786, 879]_iter_19.npy with shape (4, 501, 501)
Loaded b_seq_gen_doubled_batch_[875, 731, 3699, 4695]_iter_30.npy with shape (4, 500, 501)
Loaded b_seq_gen_doubled_batch_[2640, 4030, 1086, 2674]_iter_4.npy with shape

In [8]:
# --- Prepare base GenAI data (no filtering) ---
x, all_series, merged_filt, hist_steps, gen_block = prepare_volatility_filtered_series(
    merged, hist_msgs, n_gen_msgs, midprice_step_size, volatility_cutoff=0.0
)

Before filtering: 168 samples

After filtering: 168 samples



In [9]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.linear_model import LinearRegression

# ================= CONFIG =================
DEBUG = True
MAX_IDS_TO_PRINT = 30
START_PREFIX_AT = 1
PREFIX_MODE = "le"         # "lt" -> insertion < a, "le" -> insertion <= a
SAMPLE_SIZES = [25, 50, 75, 100, 150, 200]

pd.set_option("display.width", 160)
pd.set_option("display.max_columns", 200)
pd.set_option("display.max_rows", 50)

# ================ HELPERS ================
def _safe_col(points_df, preferred, fallbacks):
    if preferred in points_df.columns:
        return preferred
    for c in fallbacks:
        if c in points_df.columns:
            return c
    return None

def _print_section(h):
    print("\n" + "="*24 + f" {h} " + "="*24)

def _format_pairs_inline(series_like):
    pairs = [f"{int(k)}:{int(series_like.loc[k])}" for k in sorted(series_like.index)]
    return ", ".join(pairs) + "."

def debug_points_df(points_df, tag=""):
    _print_section(f"[DEBUG] points_df summary {tag}")
    if points_df is None or len(points_df) == 0:
        print("points_df is None or EMPTY")
        return
    print(f"shape: {points_df.shape}")
    print(f"columns: {list(points_df.columns)}")
    print(points_df.head(5))
    col_i = _safe_col(points_df, "ins_idx", ["ins_idx", "insertion_idx", "idx", "insertion"])
    if col_i:
        ii = pd.to_numeric(points_df[col_i], errors="coerce").dropna().astype(int)
        print(f"insertion: min={ii.min()}, max={ii.max()}, unique_count={ii.nunique()}, total={len(ii)}")
        cnt_by_ins = ii.value_counts().sort_index()
        print("insertion counts: " + _format_pairs_inline(cnt_by_ins))
    if "sample_id" in points_df.columns:
        sid = points_df["sample_id"].dropna()
        print(f"sample_id: unique={sid.nunique()}, total_rows_with_id={sid.shape[0]}")

# ================= MAIN =================
fig = go.Figure()

all_ids_sorted = sorted(set(b_dict.keys()) & set(m_dict.keys()))
print(f"[GenAI] total available samples: {len(all_ids_sorted)}")

for N in SAMPLE_SIZES:
    ids_sel = all_ids_sorted[:min(N, len(all_ids_sorted))]
    print(f"[GenAI] N={N}: using {len(ids_sel)} samples")

    b_dict_filt = {k: b_dict[k] for k in ids_sel if k in b_dict}
    m_dict_filt = {k: m_dict[k] for k in ids_sel if k in m_dict}

    _fig_cut, points_df = global_beta_plot_from_raw(
        b_dict_filt, m_dict_filt, all_series, x,
        hist_steps=hist_steps,
        gen_block=gen_block,
        num_insertions=num_insertions,
        beta_theory=0.5,
        samples_used=None,
        special_first=(79, 15),
        nbins_hist=30
    )
    if DEBUG:
        debug_points_df(points_df, tag=f"(N={N})")
    if points_df is None or len(points_df) == 0:
        continue

    COL_INS = "insertion" if "insertion" in points_df.columns else _safe_col(points_df, "insertion", ["ins_idx", "idx"])
    COL_X   = "x" if "x" in points_df.columns else _safe_col(points_df, "x", ["log_qv"])
    COL_Y   = "y" if "y" in points_df.columns else _safe_col(points_df, "y", ["log_impact"])

    ins_series = pd.to_numeric(points_df[COL_INS], errors="coerce").dropna().astype(int)
    ins_min, ins_max = int(ins_series.min()), int(ins_series.max())

    lr = LinearRegression()

    # Tail [a:]
    a_tail = list(range(ins_min, ins_max + 1))
    betas_tail = []
    for a in a_tail:
        sub = points_df[points_df[COL_INS] >= a]
        betas_tail.append(lr.fit(sub[COL_X].to_numpy().reshape(-1, 1), sub[COL_Y]).coef_[0]
                          if len(sub) >= 2 else np.nan)

    # Prefix [:a]
    a_pref = list(range(max(START_PREFIX_AT, 1), ins_max + 1))
    betas_pref = []
    for a in a_pref:
        sub = points_df[points_df[COL_INS] <= a] if PREFIX_MODE == "le" else points_df[points_df[COL_INS] < a]
        betas_pref.append(lr.fit(sub[COL_X].to_numpy().reshape(-1, 1), sub[COL_Y]).coef_[0]
                          if len(sub) >= 2 else np.nan)

    # Exact [exact @ a]
    a_exact = list(range(ins_min, ins_max + 1))
    betas_exact = []
    for a in a_exact:
        sub = points_df[points_df[COL_INS] == a]
        betas_exact.append(lr.fit(sub[COL_X].to_numpy().reshape(-1, 1), sub[COL_Y]).coef_[0]
                           if len(sub) >= 2 else np.nan)

    # Plot
    fig.add_trace(go.Scatter(x=a_tail, y=betas_tail, mode="lines+markers", name=f"N={N} [a:] (tail)"))
    fig.add_trace(go.Scatter(x=a_pref, y=betas_pref, mode="lines+markers", name=f"N={N} [:a] (prefix)"))
    fig.add_trace(go.Scatter(x=a_exact, y=betas_exact, mode="lines+markers", name=f"N={N} [exact @ a]"))

# Guides
if len(fig.data) > 0:
    all_x = np.concatenate([np.asarray(tr.x, dtype=float) for tr in fig.data if len(tr.x)])
    x_min, x_max = int(np.nanmin(all_x)), int(np.nanmax(all_x))
else:
    x_min, x_max = 1, 20

fig.add_trace(go.Scatter(x=[x_min, x_max], y=[0.5, 0.5], mode="lines", line=dict(dash="dash"), name="y = 0.5"))
fig.add_trace(go.Scatter(x=[x_min, x_max], y=[0, 0], mode="lines", name="y = 0"))

fig.update_xaxes(title_text="a (insertion index threshold)", dtick=1)
fig.update_yaxes(title_text="β (slope)")
fig.update_layout(
    title=f"Global β vs a — tail, prefix, exact for GenAI across N",
    template="plotly_white", width=980, height=560
)
fig.show()

[GenAI] total available samples: 101
[GenAI] N=25: using 25 samples

shape: (436, 4)
columns: ['sample_id', 'insertion', 'x', 'y']
   sample_id  insertion         x         y
0         97          2  2.067909 -3.602896
1         97          3  0.358161 -1.734673
2         97          4  2.090601 -0.471050
3         97          5  1.673771  0.162512
4         97          6  0.912638  0.265997
insertion: min=2, max=20, unique_count=19, total=436
insertion counts: 2:13, 3:20, 4:21, 5:21, 6:21, 7:23, 8:23, 9:23, 10:23, 11:24, 12:24, 13:25, 14:25, 15:25, 16:25, 17:25, 18:25, 19:25, 20:25.
sample_id: unique=25, total_rows_with_id=436
[GenAI] N=50: using 50 samples

shape: (867, 4)
columns: ['sample_id', 'insertion', 'x', 'y']
   sample_id  insertion         x         y
0         97          2  2.067909 -3.602896
1         97          3  0.358161 -1.734673
2         97          4  2.090601 -0.471050
3         97          5  1.673771  0.162512
4         97          6  0.912638  0.265997
insert

# check if samples are right with file

In [10]:
from filtration_utils import jupyter_show_all_ids

# your existing dict:
# experiments = {"Plain": "...", "Heuristic": "...", "GenAI": "..."}

ids_info = jupyter_show_all_ids(experiments)

# Access full ordered lists without sorting:
plain_ids = ids_info["Plain"]["unique_ordered"]
heur_ids  = ids_info["Heuristic"]["unique_ordered"]
genai_ids = ids_info["GenAI"]["unique_ordered"]

# If you want exact per-file batches & iterations:
plain_batches = ids_info["Plain"]["per_file"]


== Plain ==
Total IDs (unique, first-seen order): 352
First 50: [3398, 945, 3019, 1832, 6040, 3521, 1459, 5276, 5317, 2178, 3859, 1730, 355, 4099, 2287, 4810, 5548, 6676, 6779, 2424, 1638, 2934, 1548, 2634, 869, 5837, 2713, 5383, 79, 402, 3959, 7212, 5527, 4242, 1062, 4386, 5499, 339, 4819, 797, 6031, 5446, 739, 4418, 4238, 275, 2107, 4834, 3087, 6240]

== Heuristic ==
Total IDs (unique, first-seen order): 360
First 50: [3398, 945, 3019, 1832, 6040, 3521, 1459, 5276, 5317, 2178, 3859, 1730, 355, 4099, 2287, 4810, 5548, 6676, 6779, 2424, 1638, 2934, 1548, 2634, 869, 5837, 2713, 5383, 79, 402, 3959, 7212, 5527, 4242, 1062, 4386, 5499, 339, 4819, 797, 6031, 5446, 739, 4418, 4238, 275, 2107, 4834, 3087, 6240]

== GenAI ==
Total IDs (unique, first-seen order): 168
First 50: [3398, 945, 3019, 1832, 410, 4610, 3384, 3588, 3374, 2414, 3786, 879, 1928, 2921, 461, 97, 1833, 6991, 5093, 899, 2088, 5221, 6453, 2023, 4238, 275, 2107, 4834, 2097, 6583, 4121, 6072, 5499, 339, 4819, 797, 1652, 6865, 

In [11]:
import json
import pandas as pd

# Load the JSON file
with open("/app/sample_indices_b250_bs4_ins15_cool35.json", "r") as f:
    data = json.load(f)

# Convert to DataFrame
df = pd.DataFrame(data)

# Display the table (each column = one ID in the batch)
df

Unnamed: 0,0,1,2,3
0,2640,4030,1086,2674
1,6931,4464,4402,2408
2,1638,2934,1548,2634
3,6866,6936,2681,6698
4,869,5837,2713,5383
...,...,...,...,...
245,1147,5895,6189,1771
246,3973,2113,1153,3211
247,2192,5745,485,1927
248,5402,1166,6837,1250


In [12]:
import pandas as pd

def ids_to_list_column(ids_list, per_row=4):
    # Group into batches of `per_row` and store as lists
    return [ids_list[i:i+per_row] for i in range(0, len(ids_list), per_row)]

# Build columns for each experiment
col_plain = ids_to_list_column(plain_ids,  per_row=4)
col_heur  = ids_to_list_column(heur_ids,   per_row=4)
col_genai = ids_to_list_column(genai_ids,  per_row=4)

# Align lengths by padding with None
max_len = max(len(col_plain), len(col_heur), len(col_genai))
col_plain += [None] * (max_len - len(col_plain))
col_heur  += [None] * (max_len - len(col_heur))
col_genai += [None] * (max_len - len(col_genai))

# Create DataFrame
df_all = pd.DataFrame({
    "Plain": col_plain,
    "Heuristic": col_heur,
    "GenAI": col_genai
})

display(df_all)

Unnamed: 0,Plain,Heuristic,GenAI
0,"[3398, 945, 3019, 1832]","[3398, 945, 3019, 1832]","[3398, 945, 3019, 1832]"
1,"[6040, 3521, 1459, 5276]","[6040, 3521, 1459, 5276]","[410, 4610, 3384, 3588]"
2,"[5317, 2178, 3859, 1730]","[5317, 2178, 3859, 1730]","[3374, 2414, 3786, 879]"
3,"[355, 4099, 2287, 4810]","[355, 4099, 2287, 4810]","[1928, 2921, 461, 97]"
4,"[5548, 6676, 6779, 2424]","[5548, 6676, 6779, 2424]","[1833, 6991, 5093, 899]"
...,...,...,...
85,"[865, 1225, 5874, 5305]","[3919, 201, 6004, 6542]",
86,"[3306, 251, 947, 2413]","[7182, 2137, 4796, 1907]",
87,"[5898, 699, 2074, 2026]","[865, 1225, 5874, 5305]",
88,,"[3306, 251, 947, 2413]",
