In [1]:
import sys
import os

# Make the parent of the current working directory importable; the notebook is
# presumably launched from a sub-folder of the project, so this is the project root.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not stack duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch, math
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from util.config_util import dotdict

# Seed every global RNG the notebook can draw from so a fresh "Run All" is reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)  # covers any np.random draws as well
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Use the GPU when one is available; this module-level `device` is read by later cells.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal

In [4]:
import typing as T

def shift_axis(distance_to_origin, scaler):
    """Shift every row of ``distance_to_origin`` by a fraction of its last entry.

    Args:
        distance_to_origin: 2-D tensor; the last column of each row is the reference value.
        scaler: per-row factor multiplied with that last-column value, or ``None``
            to skip the shift.

    Returns:
        ``distance_to_origin`` itself when ``scaler`` is ``None``; otherwise
        ``distance_to_origin - scaler * distance_to_origin[:, -1]`` with the
        offset repeated across all columns.
    """
    if scaler is None:
        return distance_to_origin

    n_cols = distance_to_origin.shape[-1]
    last_column = distance_to_origin[:, -1]
    # One offset per row, repeated across every column of that row.
    offset = torch.mul(scaler, last_column).unsqueeze(-1).expand(-1, n_cols)
    return distance_to_origin - offset

def noise_scale_sampling(device: str = "cpu"):
    """Draw one noise scale from a three-bucket mixture of uniform ranges.

    Buckets (selection probability -> range):
        0.4 -> [0.0, 0.2)   very low noise
        0.4 -> [0.3, 0.7)   moderate noise
        0.2 -> [0.8, 1.2)   high noise

    The bucket is selected with torch's RNG rather than NumPy's, so
    ``torch.manual_seed`` alone makes the whole draw reproducible.

    Args:
        device: device the sampled tensor is moved to.

    Returns:
        0-dim tensor holding the sampled noise scale.
    """
    bucket = torch.rand(1).item()
    if bucket <= 0.4:
        low, high = 0.0, 0.2   # very low noise
    elif bucket <= 0.8:
        low, high = 0.3, 0.7   # moderate noise
    else:
        low, high = 0.8, 1.2   # high noise

    return Uniform(low, high).sample().to(device)


In [5]:
# Hyperprior bounds read by sample_from_hyperpriors(). The `*_fixed_variance` entries
# are handed to torch's Normal as its scale argument there.
hyperprior_params = dotdict(dict(
    # time resolution (steps per unit), sampled in a log-like transformed space
    resolution_min=25,
    resolution_max=48,
    resolution_multiplier=25,

    # linear / exponential trend
    trend_lin_min=-0.6,
    trend_lin_max=0.6,
    trend_lin_fixed_variance=0.15,
    trend_exp_min=1,
    trend_exp_max=2,
    trend_exp_fixed_variance=0.010,
    trend_exp_multiplier=100,
    offset_lin_min=-0.5,
    offset_lin_max=0.5,
    offset_exp_min=-0.5,
    offset_exp_max=0.5,

    # seasonal amplitudes: annual (a), monthly (m), weekly (w)
    a_min=-1.5,
    a_max=1.5,
    a_fixed_variance=0.35,
    m_min=-3,
    m_max=3,
    m_fixed_variance=0.35,
    w_min=-6,
    w_max=6,
    w_fixed_variance=0.35,
    harmonics_min=1,
    harmonics_max=12,

    # Weibull shape range for the noise term
    noise_k_min=0.3,
    noise_k_max=3,

    # overall series amplitude
    amplitude_min=1.2,
    amplitude_max=2,
))

In [6]:
from typing import Union

def _phase01(x: torch.Tensor, period: float) -> torch.Tensor:
    # continuous phase in [0,1)
    return torch.remainder(x, period) / period

def get_freq_component(
    phase01: torch.Tensor,          # normalized phase in [0,1)
    n_harmonics: Union[int, torch.Tensor],
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build a random periodic signal as a sum of sine/cosine harmonics of the phase.

    One sine and one cosine coefficient is drawn per harmonic ``h = 1..H`` from a
    zero-mean normal with standard deviation ``1/h``. All ``2H`` coefficients are then
    rescaled jointly to unit L2 norm, and the summed signal is divided by ``sqrt(2H)``.
    The same coefficients are shared by every row of the batch.

    Args:
        phase01: 2-D tensor ``[B, T]`` of phases.
        n_harmonics: number of harmonics ``H`` (int, or an object exposing ``.item()``).
        device: device on which the harmonic tensors are created.
        dtype: dtype used for the computation.

    Returns:
        Tensor ``[B, T]``.
    """
    if isinstance(n_harmonics, int):
        n_h = int(n_harmonics)
    else:
        n_h = int(n_harmonics.item())
    _n_rows, _n_steps = phase01.shape  # unpacking enforces a 2-D input

    orders = torch.arange(1, n_h + 1, device=device, dtype=dtype)           # [H]
    phase = phase01.to(device=device, dtype=dtype).unsqueeze(1)             # [B,1,T]

    # Draw order matters for reproducibility: sine coefficients first, then cosine.
    coef_std = 1.0 / orders
    zero_mean = torch.zeros(n_h, device=device, dtype=dtype)
    sin_coef = torch.normal(mean=zero_mean, std=coef_std).view(1, n_h, 1)
    cos_coef = torch.normal(mean=zero_mean, std=coef_std).view(1, n_h, 1)

    # Joint L2 normalisation of all coefficients (a single scalar).
    scale = torch.sqrt((sin_coef ** 2 + cos_coef ** 2).sum())
    sin_coef = sin_coef / scale
    cos_coef = cos_coef / scale

    angle = 2 * math.pi * orders.view(1, n_h, 1) * phase                    # [B,H,T]
    sin_part = (sin_coef * torch.sin(angle)).sum(dim=1)                     # [B,T]
    cos_part = (cos_coef * torch.cos(angle)).sum(dim=1)                     # [B,T]
    return (sin_part + cos_part) / math.sqrt(2.0 * n_h)


In [7]:
def sample_from_hyperpriors(hyperprior_params, n_sequence, device):
    """Sample one set of series-generation parameters from the hyperpriors.

    Args:
        hyperprior_params: attribute-accessible mapping holding the ``*_min`` / ``*_max`` /
            ``*_fixed_variance`` / ``*_multiplier`` entries read below. The
            ``*_fixed_variance`` values are passed to ``Normal`` as its scale argument.
        n_sequence: number of time steps of the series the parameters are sampled for.
        device: torch device on which the sampled tensors are placed.

    Returns:
        dotdict with ``annual_param``, ``monthly_param``, ``weekly_param``, ``trend_lin``,
        ``trend_exp``, ``noise_k``, ``amplitude``, ``offset_lin``, ``offset_exp``,
        ``resolution``, ``n_units`` (each of shape ``[1]``), ``noise_scale`` (as returned by
        ``noise_scale_sampling``) and ``harmonics`` (integer tensor of shape ``[3]`` drawn
        from ``harmonics_min..harmonics_max``, both ends included).
    """
    hpp = hyperprior_params

    def _single_hier_normal(lo, hi, fixed_std):
        # Two-level draw: mean ~ U(lo, hi), value ~ N(mean, fixed_std).
        mu = Uniform(lo, hi).sample().to(device)                 # scalar
        return Normal(mu, fixed_std).sample().to(device).view(1)  # [1]

    def _single_uniform(lo, hi):
        return Uniform(lo, hi).sample().to(device).view(1)        # [1]

    result = dotdict()

    for param, min_val, max_val, fixed_std in [
        ("annual_param", hpp.a_min, hpp.a_max, hpp.a_fixed_variance),
        ("monthly_param", hpp.m_min, hpp.m_max, hpp.m_fixed_variance),
        ("weekly_param", hpp.w_min, hpp.w_max, hpp.w_fixed_variance),
        ("trend_lin", hpp.trend_lin_min, hpp.trend_lin_max, hpp.trend_lin_fixed_variance),  # noqa
    ]:
        result[param] = _single_hier_normal(min_val, max_val, fixed_std)

    # Exponential trend: sample a factor g in the transformed space
    # g = 2 ** ((trend_exp - 1) * mm_eff), then map it back to trend_exp.
    exp_multiplier = hpp.trend_exp_multiplier
    span_upper = n_sequence / hpp.resolution_min
    mm_eff = exp_multiplier / span_upper

    def f_exp(value):
        return 2 ** ((value - 1) * mm_eff)

    def f_exp_inv(value):
        return (torch.log2(value) / mm_eff) + 1

    g_min = f_exp(torch.tensor(hpp.trend_exp_min, device=device))
    g_max = f_exp(torch.tensor(hpp.trend_exp_max, device=device))
    g = _single_hier_normal(g_min, g_max, hpp.trend_exp_fixed_variance)  # factor, [1]
    result.trend_exp = f_exp_inv(g)  # [1]

    # Force a consistent sign across the entries of each trend tensor. With the
    # single-element tensors sampled above this leaves the values unchanged.
    median_lin_sign = result.trend_lin.median().sign()
    result.trend_lin = result.trend_lin.abs() * median_lin_sign

    assert (result.trend_lin >= 0).all() or (
        result.trend_lin <= 0
    ).all(), f"non-consistent sign {result.trend_lin=} in trend_lin"

    median_exp_sign = (result.trend_exp - 1).median().sign()
    result.trend_exp = (result.trend_exp - 1).abs() * median_exp_sign + 1

    assert (result.trend_exp >= 1).all() or (
        result.trend_exp <= 1
    ).all(), f"non-consistent {result.trend_exp=} in trend_exp"

    # sub-context-specific params
    result.noise_k = _single_uniform(hpp.noise_k_min, hpp.noise_k_max)
    result.noise_scale = noise_scale_sampling(device=device)

    # domain-specific params
    result.amplitude = _single_uniform(hpp.amplitude_min, hpp.amplitude_max)
    result.offset_lin = _single_uniform(hpp.offset_lin_min, hpp.offset_lin_max)
    result.offset_exp = _single_uniform(hpp.offset_exp_min, hpp.offset_exp_max)
    # torch.randint excludes its upper bound, so add 1 to make harmonics_max reachable
    # (this also keeps the call valid when harmonics_min == harmonics_max).
    result.harmonics = torch.randint(
        hpp.harmonics_min, hpp.harmonics_max + 1, (3,), device=device
    )

    # Resolution is drawn uniformly in the transformed space log2(res * multiplier + 1).
    res_multiplier = hpp.resolution_multiplier

    def f_res(value):
        return torch.log2(value * res_multiplier + 1)

    def f_res_inv(value):
        return (2**value - 1) / res_multiplier

    rmin = f_res(torch.tensor(hpp.resolution_min, device=device))
    rmax = f_res(torch.tensor(hpp.resolution_max, device=device))
    result.resolution = f_res_inv(Uniform(rmin, rmax).sample().to(device)).view(1)    # [1]

    result.n_units = torch.ceil(n_sequence / result.resolution)

    return result


In [8]:
# Lower clip applied to trend_exp before its logarithm is taken in gen_exp_trend.
min_exp_scaler = 1e-5


def get_shift_and_span(x):
    """Return per-row distances from the first column and the per-row span.

    Args:
        x: 2-D tensor ``[B, T]``.

    Returns:
        ``(distance_to_origin, span)``, both ``[B, T]``: ``x - x[:, 0]`` and
        ``x[:, -1] - x[:, 0]`` repeated across columns, with the span clamped
        from below at ``1e-8``.
    """
    n_rows, n_cols = x.shape
    first_col = x[:, 0]
    last_col = x[:, -1]
    distance_to_origin = torch.sub(x, first_col.unsqueeze(-1).expand(n_rows, n_cols))
    span = (last_col - first_col).unsqueeze(-1).expand_as(x).clamp_min(1e-8)
    return distance_to_origin, span

def gen_linear_trend(x, trend_lin, offset_lin, distance_to_origin, span):
    """Compute the linear trend term for each time step.

    Args:
        x: 2-D tensor ``[B, T]``; only its shape is used here.
        trend_lin: per-row slope, shape ``[B]``.
        offset_lin: per-row offset factor forwarded to ``shift_axis`` (or ``None``).
        distance_to_origin: ``[B, T]`` distances from the first column.
        span: ``[B, T]`` per-row span used to normalise the shifted distances.

    Returns:
        ``[B, T]`` tensor ``trend_lin * shift_axis(distance_to_origin, offset_lin) / span``.
    """
    B, T = x.shape
    trend_linear_scaler = trend_lin.unsqueeze(-1).expand(B, T)
    # Position along the axis, shifted by offset_lin and normalised by the row span.
    relative_position = shift_axis(distance_to_origin, offset_lin) / span
    return torch.mul(relative_position, trend_linear_scaler)

def gen_exp_trend(x, trend_exp, offset_exp, distance_to_origin, span):
    """Compute the exponential trend factor ``trend_exp ** relative_position``.

    The base is clipped from below at ``min_exp_scaler`` before the logarithm is
    taken, and the exponent in log space is clamped to ``[-20, 20]`` before
    exponentiating.

    Args:
        x: 2-D tensor ``[B, T]``; only its shape is used here.
        trend_exp: per-row base, shape ``[B]``.
        offset_exp: per-row offset factor forwarded to ``shift_axis`` (or ``None``).
        distance_to_origin: ``[B, T]`` distances from the first column.
        span: ``[B, T]`` per-row span used to normalise the shifted distances.

    Returns:
        ``[B, T]`` tensor of multiplicative trend factors.
    """
    n_rows, n_cols = x.shape
    safe_base = trend_exp.clip(min=min_exp_scaler)
    log_base = torch.log(safe_base).unsqueeze(-1).expand(n_rows, n_cols)
    relative_position = shift_axis(distance_to_origin, offset_exp) / span
    bounded_log = (log_base * relative_position).clamp(min=-20.0, max=20.0)
    return torch.exp(bounded_log)

def gen_trend(x, trend_lin, offset_lin, trend_exp, offset_exp):
    """Combine the linear and exponential trends into one multiplicative trend.

    Returns:
        ``(total, linear_trend, exp_trend)`` where
        ``total = (1 + linear_trend) * exp_trend``; all three are ``[B, T]``.
    """
    distance_to_origin, span = get_shift_and_span(x)

    linear_trend = gen_linear_trend(x, trend_lin, offset_lin, distance_to_origin, span)
    exp_trend = gen_exp_trend(x, trend_exp, offset_exp, distance_to_origin, span)

    baseline = torch.ones_like(x)
    trend_comp_total = torch.mul(torch.add(baseline, linear_trend), exp_trend)
    return trend_comp_total, linear_trend, exp_trend

In [9]:
def gen_noise(x, noise_k, noise_scale, trend_comp_total, tau=3.0):
    """Generate a multiplicative noise factor centred around 1.

    Weibull samples (shape ``noise_k``, scale chosen so the distribution's median is 1)
    are squashed with ``tanh((sample - 1) / tau)`` into ``(-1, 1)``, weighted by a
    level term derived from the trend, and scaled by ``noise_scale``.

    Args:
        x: 2-D tensor ``[B, T]``; only its shape is used here.
        noise_k: Weibull concentration per row, shape ``[B]`` (clamped to ``>= 1e-6``).
        noise_scale: noise magnitude; ``unsqueeze(-1)`` is applied before broadcasting.
        trend_comp_total: ``[B, T]`` trend whose relative magnitude modulates the noise.
        tau: softness of the ``tanh`` squashing.

    Returns:
        ``[B, T]`` tensor ``1 + noise_scale * delta * level``.
    """
    n_rows, n_cols = x.shape
    shape_k = noise_k.clamp_min(1e-6)
    ln2 = torch.log(torch.tensor(2.0))
    unit_median_scale = 1.0 / (ln2 ** (1.0 / shape_k))
    weibull = torch.distributions.Weibull(
        concentration=shape_k.unsqueeze(-1).expand(n_rows, n_cols),
        scale=unit_median_scale.unsqueeze(-1).expand(n_rows, n_cols),
    )
    draws = weibull.sample()
    delta = torch.tanh((draws - 1.0) / tau)                                 # in (-1, 1)

    # Relative trend level: |trend| over its row mean, clamped to [0.1, 10] and square-rooted.
    magnitude = trend_comp_total.abs()
    level_mean = magnitude.mean(dim=1, keepdim=True).add(1e-6)              # [B,1]
    level = (magnitude / level_mean).clamp(0.1, 10.0).sqrt()                # [B,T]
    return 1 + noise_scale.unsqueeze(-1) * delta * level

In [10]:
def gen_seasonal_component(comp_type, x, amp, n_harmonics):
    """Build one multiplicative seasonal factor ``1 + amp * base``.

    Args:
        comp_type: ``"weekly"``, ``"monthly"`` or ``"annual"``; selects the period
            (7.0, 30.417 or 365.25, in the units of ``x``).
        x: 2-D tensor ``[B, T]`` of positions.
        amp: per-row amplitude, shape ``[B]``.
        n_harmonics: number of harmonics forwarded to ``get_freq_component``.

    Returns:
        ``[B, T]`` tensor.

    Raises:
        KeyError: if ``comp_type`` is not one of the three supported names.
    """
    _PERIOD = {"weekly": 7.0, "monthly": 30.417, "annual": 365.25}
    period = _PERIOD[comp_type]
    # Pass the full torch.device rather than just its type string, so the harmonic
    # tensors are created on the same device index as x (e.g. "cuda:1") instead of
    # on whichever CUDA device happens to be current.
    base = get_freq_component(_phase01(x, period), n_harmonics, device=x.device, dtype=x.dtype)  # [B,T]
    return 1 + amp.unsqueeze(-1) * base
def gen_seasonal(x, n_harmonics, annual_param, monthly_param, weekly_param, cap=0.5, eps=1e-6):
    """Combine annual, monthly and weekly factors into one bounded seasonal multiplier.

    The three component factors are multiplied together, standardised per row,
    squashed to ``1 + cap * tanh(z)`` and finally divided by their row mean (plus
    ``eps``), so each row averages to roughly 1.

    Args:
        x: 2-D tensor ``[B, T]`` of positions.
        n_harmonics: indexable with positions 0, 1, 2 giving the annual, monthly
            and weekly harmonic counts.
        annual_param, monthly_param, weekly_param: per-row amplitudes.
        cap: half-width of the squashed range before re-normalisation.
        eps: small constant guarding the two divisions.

    Returns:
        ``[B, T]`` tensor.
    """
    n_rows, n_cols = x.shape
    product = torch.ones(n_rows, n_cols, device=x.device, dtype=x.dtype)

    # Order matters: each component consumes random draws (annual, monthly, weekly).
    components = (
        ("annual", annual_param),
        ("monthly", monthly_param),
        ("weekly", weekly_param),
    )
    for idx, (comp_type, amp) in enumerate(components):
        product.mul_(gen_seasonal_component(comp_type, x, amp, n_harmonics[idx]))

    centred = product - product.mean(dim=1, keepdim=True)
    standardised = centred / (product.std(dim=1, keepdim=True) + eps)
    bounded = 1.0 + cap * torch.tanh(standardised)                 # within 1 ± cap
    return bounded / (bounded.mean(dim=1, keepdim=True) + eps)     # row mean ≈ 1

In [11]:
def _series_from_params(p, n_sequence, device):
    """Assemble one series as ``amplitude * trend * seasonality * noise``.

    Args:
        p: sampled parameters (attribute access) providing ``n_units``, ``trend_lin``,
            ``offset_lin``, ``trend_exp``, ``offset_exp``, ``harmonics``, ``annual_param``,
            ``monthly_param``, ``weekly_param``, ``amplitude``, ``noise_k``, ``noise_scale``.
        n_sequence: number of time steps ``T``.
        device: device on which the time grid is created.

    Returns:
        1-D tensor of length ``T``.
    """
    unit_grid = torch.linspace(0, 1, n_sequence, device=device)     # [T]
    x = p.n_units.to(device)[:, None] * unit_grid[None, :]          # [1, T]

    trend_comp = gen_trend(
        x,
        trend_lin=p.trend_lin,
        offset_lin=p.offset_lin,
        trend_exp=p.trend_exp,
        offset_exp=p.offset_exp,
    )[0]                                                            # [1, T]

    # Seasonality is generated before noise; both consume random draws.
    total_seasonality = gen_seasonal(
        x,
        n_harmonics=p.harmonics,
        annual_param=p.annual_param,
        monthly_param=p.monthly_param,
        weekly_param=p.weekly_param,
    )                                                               # [1, T]

    amp = p.amplitude.to(device).unsqueeze(-1)                      # [1, 1]
    noise = gen_noise(
        x, noise_k=p.noise_k, noise_scale=p.noise_scale, trend_comp_total=trend_comp
    )
    series = amp * trend_comp * total_seasonality * noise           # [1, T]
    return series.squeeze(0)                                        # [T]

def gen_series(n_sequence: int = 12000):
    """Sample hyperprior parameters and generate one synthetic series.

    Args:
        n_sequence: number of time steps of the generated series (default 12000).

    Returns:
        1-D tensor of length ``n_sequence`` on the module-level ``device``.
    """
    component_params = sample_from_hyperpriors(hyperprior_params, n_sequence, device)
    return _series_from_params(component_params, n_sequence, device)
gen_series().shape

torch.Size([12000])

In [12]:
import gradio as gr
import math

In [13]:
def sample_n_units(resolution_min, resolution_max, resolution_multiplier,
                   n_sequence=12000, n_samples=20000, seed=0):
    """Monte-Carlo sample the distribution of ``n_units = ceil(n_sequence / resolution)``.

    The resolution is drawn uniformly in the transformed space
    ``log2(resolution * multiplier + 1)`` and mapped back before dividing.

    Note: this reseeds torch's global RNG with ``seed``.

    Args:
        resolution_min, resolution_max: bounds of the resolution.
        resolution_multiplier: multiplier used by the transform.
        n_sequence: sequence length that is divided by the resolution.
        n_samples: number of Monte-Carlo draws.
        seed: seed passed to ``torch.manual_seed``.

    Returns:
        ``(n_units, stats)``: a float tensor ``[n_samples]`` of integer-valued counts
        and a dict of summary statistics (unique values, counts, probabilities, mode,
        entropy in bits, min/max, mean/std, 5%/50%/95% quantiles).
    """
    torch.manual_seed(int(seed))
    multiplier = float(resolution_multiplier)

    def _to_transformed(res):
        return torch.log2(res * multiplier + 1.0)

    def _from_transformed(val):
        return (2.0**val - 1.0) / multiplier

    # Bounds in the transformed space, uniform draws there, then map back.
    lower = _to_transformed(torch.tensor(float(resolution_min), dtype=torch.float32))
    upper = _to_transformed(torch.tensor(float(resolution_max), dtype=torch.float32))
    draws = torch.distributions.Uniform(lower, upper).sample((int(n_samples),))
    resolution = _from_transformed(draws)                                   # [n_samples]

    n_units = torch.ceil(torch.tensor(float(n_sequence)) / resolution)      # [n_samples]

    uniq, counts = torch.unique(n_units, sorted=True, return_counts=True)
    probs = counts.float() / counts.sum()

    values = n_units.float()
    q05, q50, q95 = torch.quantile(values, torch.tensor([0.05, 0.5, 0.95])).tolist()

    stats = {
        "uniq": uniq,
        "counts": counts,
        "mode": int(uniq[counts.argmax()].item()),
        "unique_ct": int(uniq.numel()),
        "probs": probs,
        "entropy": float(-(probs * torch.log2(probs + 1e-12)).sum().item()),
        "vmin": int(n_units.min().item()),
        "vmax": int(n_units.max().item()),
        "mean": values.mean().item(),
        "std": values.std(unbiased=False).item(),
        "q05": q05,
        "q50": q50,
        "q95": q95,
    }
    return n_units, stats

In [14]:
def _theoretical_bounds(resolution_min, resolution_max, resolution_multiplier, n_sequence):
    mm = float(resolution_multiplier)
    f_res     = lambda x: torch.log2(x * mm + 1.0)
    f_res_inv = lambda x: (2.0**x - 1.0) / mm
    # min n_units occurs at largest resolution, and vice versa
    lo = math.ceil(n_sequence / f_res_inv(f_res(torch.tensor(float(resolution_max)))).item())
    hi = math.ceil(n_sequence / f_res_inv(f_res(torch.tensor(float(resolution_min)))).item())
    return int(lo), int(hi)

def plot_n_units_with_stats(resolution_min, resolution_max, resolution_multiplier,
                            n_sequence=12000, n_samples=20000, seed=0):
    """Plot the sampled ``n_units`` distribution and summarise it as markdown.

    Returns:
        ``(fig, md, radio_update, radio_update, radio_update, radio_update)``: the bar
        chart, a markdown summary, and the same ``gr.update`` object four times, which
        offers the sampled min / median / max as choices with the median selected.
    """
    _, stats = sample_n_units(
        resolution_min, resolution_max, resolution_multiplier,
        n_sequence=n_sequence, n_samples=n_samples, seed=seed,
    )
    lo_th, hi_th = _theoretical_bounds(resolution_min, resolution_max, resolution_multiplier, n_sequence)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(stats["uniq"].numpy(), stats["probs"].numpy(), width=0.9, align="center")
    ax.set_xlabel("n_units")
    ax.set_ylabel("Probability")
    ax.set_title(f"Distribution of n_units (N={int(n_samples)})")
    # Dashed markers at the theoretical extremes.
    for bound in (lo_th, hi_th):
        ax.axvline(bound, linestyle="--")
    fig.tight_layout()

    q05, q50, q95 = stats["q05"], stats["q50"], stats["q95"]
    step_q05, step_q50, step_q95 = q05 / n_sequence, q50 / n_sequence, q95 / n_sequence
    md_lines = [
        "**n_units stats**",
        "",
        f"- theoretical min-max: {lo_th} - {hi_th}",
        f"- min-max: {stats['vmin']} - {stats['vmax']}",
        f"- mean: {stats['mean']:.3f}, std: {stats['std']:.3f}",
        f"- mode: {stats['mode']}  (unique values: {stats['unique_ct']})",
        f"- q05/median/q95: {q05:.1f} / {q50:.1f} / {q95:.1f}",
        f"- Step size Δx ~ q05/median/q95: {step_q05:.3f} / {step_q50:.3f} / {step_q95:.3f}",
        f"- entropy (bits): {stats['entropy']:.3f}",
        f"(max entropy = {math.log2(stats['unique_ct']):.3f})",
    ]
    md = "\n".join(md_lines)

    opts = [
        f"Min: {stats['vmin']}",
        f"Median: {int(round(stats['q50']))}",
        f"Max: {stats['vmax']}",
    ]
    # One update object, reused for each of the four Radio outputs.
    radio_update = gr.update(choices=opts, value=opts[1], label="Select n_units")
    return fig, md, radio_update, radio_update, radio_update, radio_update

In [15]:
import itertools

def plot_trends(
    n_units_choice, n_sequence: int = 12000, n_samples: int = 20000, seed: int = 0,
    # trend param bounds
    trend_lin_min: float = -0.02, offset_lin_min: float = -0.5, trend_exp_min: float = 1.0, offset_exp_min: float = -0.5,
    trend_lin_max: float = 0.02,  offset_lin_max: float = 0.5,  trend_exp_max: float = 2.0,  offset_exp_max: float = 0.5,
):
    """Plot linear, exponential and combined trends over the given parameter bounds.

    Args:
        n_units_choice: string whose last whitespace-separated token is the n_units value
            (e.g. ``"Median: 300"``).
        n_sequence: number of time steps.
        n_samples, seed: accepted but not used by this function.
        trend_lin_*, offset_lin_*, trend_exp_*, offset_exp_*: bounds of the swept parameters.

    Returns:
        ``(fig_lin, fig_exp, fig_tot, rep_trends)`` where ``rep_trends`` maps
        ``(trend_lin index, trend_exp index, offset_exp index)`` keys to representative
        total-trend arrays: four corner combinations plus the centre ``(2, 2, 1)``.
    """
    n_units = torch.as_tensor([float(n_units_choice.split()[-1])])  # [1]
    T = int(n_sequence)
    x = n_units[:, None] * torch.linspace(0.0, 1.0, T)[None, :]     # [1, T]
    distance_to_origin, span = get_shift_and_span(x)

    def _as_f32(value):
        # Wrap a scalar as a one-element float32 tensor (batch of one).
        return torch.tensor([value], dtype=torch.float32)

    def _to_numpy(tensor):
        return tensor.squeeze(0).detach().cpu().numpy()

    def _min_avg_max(lo, hi):
        return {"min": lo, "avg": 0.5 * (lo + hi), "max": hi}

    def _annotate(ax, stats_lines):
        # Legend plus a monospace statistics block below the panel.
        ax.legend(fontsize=8, ncol=3)
        ax.text(0.0, -0.35, "\n".join(stats_lines), transform=ax.transAxes, fontsize=8,
                va="top", ha="left", family="monospace")

    def _component_figure(tag, ylabel, slopes, offsets, component_fn):
        # One panel per offset (min/avg/max), one line per swept trend value.
        fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
        for ax, (okey, ovalue) in zip(axes, offsets.items()):
            stats_lines = []
            for slope in slopes:
                y = _to_numpy(
                    component_fn(x, _as_f32(slope), _as_f32(ovalue), distance_to_origin, span)
                )
                ax.plot(range(T), y, label=f"trend_{tag}={slope:.3f}")
                stats_lines.append(
                    f"{tag}={slope:+.3f} → mean={y.mean():+.3g}, std={y.std():.3g}, "
                    f"min={y.min():+.3g}, max={y.max():+.3g}"
                )
            ax.set_title(f"offset_{tag}={okey} ({ovalue:+.3f})")
            ax.set_ylabel(ylabel)
            _annotate(ax, stats_lines)
        axes[-1].set_xlabel("t")
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        return fig

    offset_lin_dict = _min_avg_max(offset_lin_min, offset_lin_max)
    offset_exp_dict = _min_avg_max(offset_exp_min, offset_exp_max)
    trend_lin_vals = np.linspace(trend_lin_min, trend_lin_max, 5)  # 5 lines per panel
    trend_exp_vals = np.linspace(trend_exp_min, trend_exp_max, 5)

    fig_lin = _component_figure("lin", "linear_trend", trend_lin_vals, offset_lin_dict, gen_linear_trend)
    fig_exp = _component_figure("exp", "exp_trend", trend_exp_vals, offset_exp_dict, gen_exp_trend)

    # Combined trend: rows = offset_exp (min/avg/max), columns = trend_lin (min/max),
    # lines = trend_exp values; offset_lin is held at its midpoint.
    ol_avg = _as_f32(offset_lin_dict['avg'])
    oe_avg = _as_f32(offset_exp_dict['avg'])
    lin_extremes = [("min", _as_f32(trend_lin_min)), ("max", _as_f32(trend_lin_max))]
    representative = {
        (0, 0, 0),  # "low-low-low"
        (0, 4, 2),  # "low-high-high"
        (1, 0, 2),  # "high-low-high"
        (1, 4, 0),  # "high-high-low"
    }

    fig_tot, axes_tot = plt.subplots(3, 2, figsize=(12, 10), sharex=True)
    rep_trends = {}
    for r, (okey, ovalue) in enumerate(offset_exp_dict.items()):
        for c, (tl_lbl, tl_val) in enumerate(lin_extremes):
            ax = axes_tot[r, c]
            stats_lines = []
            for t, te in enumerate(trend_exp_vals):
                total, _, _ = gen_trend(x, tl_val, ol_avg, _as_f32(te), _as_f32(ovalue))  # [1,T]
                y = _to_numpy(total)
                if (c, t, r) in representative:
                    rep_trends[(c, t, r)] = y
                ax.plot(range(T), y, label=f"trend_exp={te:.3f}")
                stats_lines.append(
                    f"tot(tl={tl_lbl}, exp={te:+.3f}) → mean={y.mean():+.3g} std={y.std():.3g} "
                    f"min={y.min():+.3g} max={y.max():+.3g}"
                )
            ax.set_title(f"offset_exp={okey} ({ovalue:+.3f}) | trend_lin={tl_lbl} ({tl_val.item():+.3f})")
            ax.set_ylabel("total_trend")
            if r == 2:
                ax.set_xlabel("t")
            _annotate(ax, stats_lines)
    fig_tot.tight_layout(rect=[0, 0, 1, 0.97])

    # Centre case: every swept parameter at the middle of its range.
    tl_avg = _as_f32(trend_lin_vals[2])
    te_avg = _as_f32(trend_exp_vals[2])
    centre_total, _, _ = gen_trend(x, tl_avg, ol_avg, te_avg, oe_avg)
    rep_trends[(2, 2, 1)] = _to_numpy(centre_total)

    return fig_lin, fig_exp, fig_tot, rep_trends

In [16]:
def plot_seasonality_components(
    n_units_choice, n_sequence: int = 12_000,
    a: float = 3.0, m: float = 6.0, w: float = 9.0,     # amplitude bounds
    harmonics_min: int = 2, harmonics_max: int = 12,    # harmonics bounds
):
    """Plot the annual / monthly / weekly seasonal components and their combination.

    Amplitudes are swept over the symmetric ranges ``[-a, a]``, ``[-m, m]`` and
    ``[-w, w]`` (5 values each); harmonics take the values min, integer midpoint and max.

    Args:
        n_units_choice: string whose last whitespace-separated token is the n_units value.
        n_sequence: number of time steps.
        a, m, w: absolute amplitude bounds for the annual, monthly and weekly components.
        harmonics_min, harmonics_max: harmonic-count bounds.

    Returns:
        ``(fig_ann, fig_mon, fig_wek, fig_tot, rep_seasonal)`` where ``rep_seasonal`` maps
        ``(harmonics level, 'max' | 'min')`` with level in ``{'min', 'avg', 'max'}`` to a
        representative combined seasonal array.
    """
    names = ("annual", "monthly", "weekly")
    amp_lo = {"annual": -a, "monthly": -m, "weekly": -w}
    amp_hi = {"annual": a, "monthly": m, "weekly": w}

    # timeline
    n_units = torch.as_tensor([float(n_units_choice.split()[-1])])  # [1]
    T = int(n_sequence)
    x = n_units[:, None] * torch.linspace(0.0, 1.0, T)[None, :]     # [1,T]
    dtype = x.dtype

    h_avg = int(0.5 * (harmonics_min + harmonics_max))
    h_vals = [harmonics_min, h_avg, harmonics_max]
    amp_grid = {name: np.linspace(amp_lo[name], amp_hi[name], 5) for name in names}

    def _stats_line(value, series):
        return (
            f"{value:+.2f} → mean={series.mean():+.3g} std={series.std():.3g} "
            f"min={series.min():+.3g} max={series.max():+.3g}"
        )

    def _annotate(ax, stats, legend_size):
        # Legend plus a monospace statistics block below the panel.
        ax.legend(fontsize=legend_size, ncol=3)
        ax.text(0.0, -0.35, "\n".join(stats), transform=ax.transAxes, fontsize=8,
                va="top", ha="left", family="monospace")

    def _component_figure(comp_type, symbol, title):
        # One panel per harmonic count, one line per amplitude value.
        fig, axes = plt.subplots(3, 1, figsize=(10, 16), sharex=True)
        for ax, hv in zip(axes, h_vals):
            stats = []
            for v in amp_grid[comp_type]:
                pv = torch.tensor([v], dtype=dtype)
                comp = gen_seasonal_component(comp_type, x, pv, hv).squeeze(0).cpu().numpy()
                ax.plot(range(T), comp, label=f"{symbol}={v:+.2f}, h={hv}")
                stats.append(_stats_line(v, comp))
            ax.set_title(f"{title} component — harmonics={hv}")
            _annotate(ax, stats, 8)
        axes[-1].set_xlabel("t")
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        return fig

    # Figures are built in this order because each one consumes random draws.
    fig_ann = _component_figure("annual", "a", "Annual")
    fig_mon = _component_figure("monthly", "m", "Monthly")
    fig_wek = _component_figure("weekly", "w", "Weekly")

    # Combined seasonality: sweep one amplitude while the other two stay at their
    # maximum; the swept component uses harmonics hv, the others use h_avg.
    fig_tot, axes_tot = plt.subplots(3, 3, figsize=(20, 10), sharex=True, sharey=False)
    rep_seasonal = {}
    name_list = ['min', 'avg', 'max']
    for r, hv in enumerate(h_vals):
        for c, swept in enumerate(names):
            ax = axes_tot[r, c]
            stats = []
            for t, v in enumerate(amp_grid[swept]):
                amps = [
                    torch.tensor([v if name == swept else amp_hi[name]], dtype=dtype)
                    for name in names
                ]
                harmonics = tuple(hv if name == swept else h_avg for name in names)
                total = gen_seasonal(x, harmonics, *amps).squeeze(0).cpu().numpy()
                if t == 4:
                    # Last sweep value: all three amplitudes are at their maximum.
                    rep_seasonal[(name_list[r], 'max')] = total
                ax.plot(range(T), total, label=f"{swept}={v:+.2f}")
                stats.append(_stats_line(v, total))
            ax.set_title(f"Total — sweep {swept} amp @ its harmonics={hv}")
            if r == 2:
                ax.set_xlabel("t")
            _annotate(ax, stats, 7)
    fig_tot.tight_layout(rect=[0, 0, 1, 0.97])

    # Representative "min" cases: all amplitudes at their lower bound, same harmonics for all.
    for r, hv in enumerate(h_vals):
        lows = [torch.tensor([amp_lo[name]], dtype=dtype) for name in names]
        total = gen_seasonal(x, (hv, hv, hv), *lows).squeeze(0).cpu().numpy()
        rep_seasonal[(name_list[r], 'min')] = total
    return fig_ann, fig_mon, fig_wek, fig_tot, rep_seasonal

In [17]:
def plot_noise(
    n_units_choice,
    n_sequence: int = 12_000,
    noise_k_min: float = 0.3,
    noise_k_max: float = 3.0,
    trend_bases = None,
    scale_levels = (0.1, 0.5, 1.0), # representative scales for three buckets (very low, moderate, high)
):
    """Plot the multiplicative noise for a grid of noise scales and Weibull shapes.

    One figure is produced per entry of ``trend_bases``; rows are ``scale_levels`` and
    columns are five ``noise_k`` values spanning ``[noise_k_min, noise_k_max]``.

    Args:
        n_units_choice: string whose last whitespace-separated token is the n_units value.
        n_sequence: number of time steps.
        noise_k_min, noise_k_max: range of the Weibull concentration.
        trend_bases: mapping ``name -> trend array of length n_sequence``. A ``None``
            value stands for a flat (all-ones) trend; passing ``None`` for the whole
            argument plots a single flat base stored under the name ``"flat"``.
        scale_levels: noise scales, one row each.

    Returns:
        A list holding the figures followed by ``rep_noise``, a dict mapping
        ``(name, column)`` for columns 0 and 4 of the last row to the plotted noise array.
    """
    n_units = torch.as_tensor([float(n_units_choice.split()[-1])])  # [1]
    T = int(n_sequence)
    lin = torch.linspace(0.0, 1.0, T)
    x = n_units[:, None] * lin[None, :]                              # [1,T]
    device, dtype = x.device, x.dtype

    # The default used to crash on `.items()`; fall back to one flat trend base instead.
    if trend_bases is None:
        trend_bases = {"flat": None}

    k_vals = np.linspace(noise_k_min, noise_k_max, 5)  # identical for every trend base
    n_rows = len(scale_levels)

    figs = []
    rep_noise = {}
    for name, trend in trend_bases.items():
        # Decide on the fallback before converting: torch.tensor() never returns None.
        if trend is None:
            trend_comp_total = torch.ones_like(x)                    # [1,T]
        else:
            trend_comp_total = torch.tensor(trend, dtype=torch.float32).unsqueeze(0)

        # rows = noise_scale levels, cols = 5 noise_k values
        fig, axes = plt.subplots(n_rows, 5, figsize=(2.7*7, 2.7*n_rows), sharex=True, sharey=False)
        if n_rows == 1:
            axes = np.expand_dims(axes, 0)  # keep 2-D indexing for a single row
        for r, scale in enumerate(scale_levels):
            for c, kv in enumerate(k_vals):
                ax = axes[r, c]
                k_t = torch.full((1,), float(kv),  device=device, dtype=dtype)     # [B]
                s_t = torch.full((1,), float(scale), device=device, dtype=dtype)   # [B]
                y = gen_noise(x, k_t, s_t, trend_comp_total).squeeze(0).cpu().numpy()
                ax.plot(range(T), y)
                if r == 0:
                    ax.set_title(f"k={kv:.2f}")
                if c == 0:
                    ax.set_ylabel(f"scale={scale:.2f}")
                if r == n_rows - 1:
                    ax.set_xlabel("t")
                    # Keep the extreme-k samples of the last (largest listed) scale row.
                    if c in (0, 4):
                        rep_noise[(name, c)] = y
                ax.text(0.02, 0.02,
                        f"mean={y.mean():+.3g} std={y.std():.3g} min={y.min():+.3g} max={y.max():+.3g}",
                        transform=ax.transAxes, fontsize=7, family="monospace")
        fig.suptitle("Multiplicative noise component: rows=scale, cols=noise_k")
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        figs.append(fig)
    return figs + [rep_noise]

In [18]:
def get_hyperparams(resolution_min, resolution_max, resolution_multiplier,
                    trend_lin_min, trend_lin_max, offset_lin_min, offset_lin_max,
                    trend_exp_min, trend_exp_max, offset_exp_min, offset_exp_max,
                    a, m, w, harmonics_min, harmonics_max,
                    noise_k_min, noise_k_max, amplitude_min, amplitude_max):
    """Render the given bounds as the text body of a hyperprior-parameter dict.

    The ``*_fixed_variance`` values and ``trend_exp_multiplier`` are emitted as fixed
    literals; the ``a``/``m``/``w`` amplitude bounds are emitted as symmetric
    ``-value`` / ``value`` pairs.

    Returns:
        Multi-line string of ``'key': value`` entries (without the enclosing braces).
    """
    rendered = [
        f"'resolution_min': {resolution_min}, 'resolution_max': {resolution_max}, 'resolution_multiplier': {resolution_multiplier},",
        "",
        f"'trend_lin_min': {trend_lin_min},  'trend_lin_max': {trend_lin_max},  'trend_lin_fixed_variance': 0.15,",
        f"'trend_exp_min': {trend_exp_min},  'trend_exp_max': {trend_exp_max}, 'trend_exp_fixed_variance': 0.010,",
        "'trend_exp_multiplier': 100,",
        f"'offset_lin_min': {offset_lin_min}, 'offset_lin_max': {offset_lin_max},",
        f"'offset_exp_min': {offset_exp_min}, 'offset_exp_max': {offset_exp_max},",
        "",
        f"'a_min': {-a},  'a_max': {a},  'a_fixed_variance': 0.35,",
        f"'m_min': {-m},  'm_max': {m},  'm_fixed_variance': 0.35,",
        f"'w_min': {-w},  'w_max': {w},  'w_fixed_variance': 0.35,",
        f"'harmonics_min': {harmonics_min},     'harmonics_max': {harmonics_max},",
        "",
        f"'noise_k_min': {noise_k_min},    'noise_k_max': {noise_k_max},",
        "",
        f"'amplitude_min': {amplitude_min},  'amplitude_max': {amplitude_max}",
    ]
    return "\n".join(rendered)

In [23]:
def _to_bt(arr):
    t = torch.as_tensor(arr, dtype=torch.float32)
    if t.ndim == 1: t = t.unsqueeze(0)   # [1,T]
    return t
def plot_total_series(n_units_choice, n_sequence, 
                      amplitude_min, amplitude_max,
                      trend_choice, seasonal_choice, harmonics_choice,
                      trend_bases_state, seasonal_state, noise_state):
    """Combine stored trend, seasonal and noise components into final series plots.

    The noiseless series is ``amplitude_min * trend * season``; two noisy
    variants multiply it by the noise factors stored under column index 0
    and 4 of ``noise_state`` for the chosen trend (labelled "min noise_k"
    and "max noise_k" in the figure).

    Args:
        n_units_choice: selected n_units radio value. Not used by the
            computation; kept so the Gradio wiring stays unchanged.
        n_sequence: requested sequence length. Not used; the series length
            is taken from the stored trend component.
        amplitude_min: amplitude applied to the series.
        amplitude_max: accepted for interface compatibility; currently unused.
        trend_choice: one of the keys of ``trend_map`` below.
        seasonal_choice, harmonics_choice: together form the lookup key
            ``(harmonics_choice, seasonal_choice)`` into ``seasonal_state``.
        trend_bases_state: container indexed by the ``trend_map`` tuples.
        seasonal_state: container indexed by ``(harmonics_choice, seasonal_choice)``.
        noise_state: container indexed by ``(trend_key, 0)`` and ``(trend_key, 4)``.

    Returns:
        matplotlib Figure with the three components on the top row and the
        noiseless / min-noise / max-noise series on the rows below.

    Raises:
        ValueError: if a selection is missing or a component state has not
            been generated yet.
    """
    trend_map = {"low-low-low": (0, 0, 0),
                 "low-high-high": (0, 4, 2),
                 "high-low-high" : (1, 0, 2),
                 "high-high-low": (1, 4, 0),
                 "center": (2, 2, 1)}

    # Fail early with a readable message instead of an opaque KeyError/TypeError
    # when the button is pressed before the selections/components exist.
    if trend_choice not in trend_map:
        raise ValueError(
            f"Pick a trend before generating the final series (got {trend_choice!r})."
        )
    if seasonal_choice is None or harmonics_choice is None:
        raise ValueError(
            "Pick a seasonality and an n_harmonics option before generating the final series."
        )
    for label, state in (("trend", trend_bases_state),
                         ("seasonal", seasonal_state),
                         ("noise", noise_state)):
        if state is None:
            raise ValueError(
                f"Generate the {label} components before generating the final series."
            )

    trend_key = trend_map[trend_choice]
    trend = _to_bt(trend_bases_state[trend_key])                            # [1,T] for 1-D stored data
    season = _to_bt(seasonal_state[(harmonics_choice, seasonal_choice)])    # [1,T] for 1-D stored data
    T = trend.shape[1]
    amp = torch.tensor([amplitude_min]).unsqueeze(-1)   # [1,1]

    # noiseless baseline
    base = amp * trend * season                        # [1,T]

    n_min = noise_state[(trend_key, 0)]
    n_max = noise_state[(trend_key, 4)]

    y0 = base.squeeze(0).cpu().numpy()                         # noiseless
    yA = (base * _to_bt(n_min)).squeeze(0).cpu().numpy()       # noisy @ k_min
    yB = (base * _to_bt(n_max)).squeeze(0).cpu().numpy()       # noisy @ k_max

    # components to display (noise factors only, not applied)
    tr_np = trend.squeeze(0).cpu().numpy()
    se_np = season.squeeze(0).cpu().numpy()

    # layout: 4 stacked rows; row 0 is split into 3 panels, rows 1-3 are full width.
    # constrained_layout handles spacing, so no tight_layout call is needed.
    fig = plt.figure(figsize=(14, 10), constrained_layout=True)
    outer = fig.add_gridspec(nrows=4, ncols=1, height_ratios=[1, 1, 1, 1])
    top = outer[0].subgridspec(1, 3, wspace=0.08)

    ax_tr   = fig.add_subplot(top[0, 0])
    ax_se   = fig.add_subplot(top[0, 1])
    ax_noise= fig.add_subplot(top[0, 2])

    ax_y0   = fig.add_subplot(outer[1, :])
    ax_yA   = fig.add_subplot(outer[2, :])
    ax_yB   = fig.add_subplot(outer[3, :])

    steps = range(T)

    # components row
    ax_tr.plot(steps, tr_np);      ax_tr.set_title("Trend component")
    ax_se.plot(steps, se_np);      ax_se.set_title("Seasonality component")
    ax_noise.plot(steps, n_min, label="min noise_k")
    ax_noise.plot(steps, n_max, label="max noise_k")
    ax_noise.set_title("Noise factors")
    ax_noise.legend(fontsize=8)

    # final series rows
    ax_y0.plot(steps, y0); ax_y0.set_title("Final — noiseless")
    ax_yA.plot(steps, yA); ax_yA.set_title("Final — noisy @ min noise_k")
    ax_yB.plot(steps, yB); ax_yB.set_title("Final — noisy @ max noise_k")
    ax_yB.set_xlabel("t")

    return fig

In [25]:
# Gradio UI: one tab per pipeline stage (n_units -> Trend -> Seasonal -> Noise -> Total).
# Widgets are declared first; the button callbacks are wired at the bottom of the block.
with gr.Blocks() as demo:
    # State slots that carry callback results between tabs:
    #   trend_bases_state: written by plot_trends, read by plot_noise and plot_total_series
    #   seasonal_state:    written by plot_seasonality_components, read by plot_total_series
    #   noise_state:       written by plot_noise, read by plot_total_series
    trend_bases_state = gr.State()
    seasonal_state = gr.State()
    noise_state = gr.State()
    with gr.Tab("n_units"):
        with gr.Row():
            resolution_multiplier = gr.Slider(1, 50, 25, step=1, label="Resolution Multiplier")
            resolution_min = gr.Slider(1, 365, 25, step=1, label="Min Resolution")
            resolution_max = gr.Slider(1, 365, 75, step=1, label="Max Resolution")
        with gr.Row():
            n_sequence = gr.Slider(128, 50000, 12000, step=128, label="n_sequence")
            n_samples = gr.Slider(1000, 200000, 20000, step=1000, label="num samples")
            seed = gr.Slider(0, 10_000, 0, step=1, label="seed")
        btn_n_units = gr.Button("Plot + Stats")
        with gr.Row():
            n_units_plot = gr.Plot(label="n_units distribution", scale=2)
            n_units_stats = gr.Markdown()
    with gr.Tab("Trend"):
        # Starts with no choices; it is listed as an output of btn_n_units.click below,
        # so plot_n_units_with_stats (defined elsewhere) is expected to update it.
        trend_radio = gr.Radio([], label="Generate n_units first")
        # Hidden until trend_radio has a selection (see _on_radio_change).
        with gr.Column(visible=False) as trend_panel:
            with gr.Row():
                trend_lin_min = gr.Slider(-5, 5, value=-2.0, step=0.05, label="trend_lin_min")
                trend_lin_max = gr.Slider(-3, 7, value=2.0,  step=0.05, label="trend_lin_max")
            with gr.Row():
                offset_lin_min = gr.Slider(-2.0, 2.0, value=-1.5, step=0.01, label="offset_lin_min")
                offset_lin_max = gr.Slider(-2.0, 2.0, value=1.5,  step=0.01, label="offset_lin_max")
            with gr.Row():
                trend_exp_min = gr.Slider(0.50, 6.00, value=2.0, step=0.01, label="trend_exp_min")
                trend_exp_max = gr.Slider(0.50, 8.00, value=4.0, step=0.01, label="trend_exp_max")
            with gr.Row():
                offset_exp_min = gr.Slider(-2.0, 2.0, value=-1.5, step=0.01, label="offset_exp_min")
                offset_exp_max = gr.Slider(-2.0, 2.0, value=1.5,  step=0.01, label="offset_exp_max")

            btn_trend = gr.Button("Generate Trend Components")

            with gr.Tabs():
                with gr.TabItem("Linear Trend"):
                    linear_trend_plot = gr.Plot(label="linear_trend")
                with gr.TabItem("Exponential Trend"):
                    exp_trend_plot = gr.Plot(label="exp_trend")
                with gr.TabItem("Total Trend"):
                    total_trend_plot = gr.Plot(label="trend_total")

        def _on_radio_change(choice):
            """Return an update that shows the target panel once the radio has a value."""
            return gr.update(visible=(choice is not None))
        trend_radio.change(_on_radio_change, inputs=trend_radio, outputs=trend_panel)

    with gr.Tab("Seasonal"):
        seasonal_radio = gr.Radio([], label="Generate n_units first")
        with gr.Column(visible=False) as season_panel:
            with gr.Row():
                # get_hyperparams emits each of a / m / w as a symmetric (-x, x) range.
                a = gr.Slider(0, 20.0, value=2, step=0.1, label="a")
                m = gr.Slider(0, 20.0, value=4, step=0.1, label="m")
                w = gr.Slider(0, 20.0, value=8, step=0.1, label="w")
            with gr.Row():
                harmonics_min = gr.Slider(0, 10, value=1, step=1, label="harmonics_min")
                harmonics_max = gr.Slider(1, 20, value=12,  step=1, label="harmonics_max")
            btn_season = gr.Button("Generate Seasonal Components")

            with gr.Tabs():
                with gr.TabItem("Annual Seasonal"):
                    annual_comp_plot = gr.Plot(label="annual_comp")
                with gr.TabItem("Monthly Seasonal"):
                    monthly_comp_plot = gr.Plot(label="monthly_comp")
                with gr.TabItem("Weekly Seasonal"):
                    weekly_comp_plot = gr.Plot(label="weekly_comp")
                with gr.TabItem("Total Seasonal"):
                    total_comp_plot = gr.Plot(label="total_comp")
        seasonal_radio.change(_on_radio_change, inputs=seasonal_radio, outputs=season_panel)

    with gr.Tab("Noise"):
        noise_radio = gr.Radio([], label="Generate n_units first")
        with gr.Column(visible=False) as noise_panel:
            with gr.Row():
                noise_k_min = gr.Slider(0.1, 5.0, value=1.3, step=0.1, label="noise_k_min")
                noise_k_max = gr.Slider(0.2, 8.0, value=3.0, step=0.1, label="noise_k_max")
            btn_noise = gr.Button("Generate Noise")
            # One plot per base-trend label; the list is filled in the loop below and
            # later used (plus noise_state) as the outputs of btn_noise.click.
            noise_plot_list = [None] * 5
            with gr.Tabs():
                for i, label in enumerate(["low-low-low", "low-high-high", "high-low-high", "high-high-low", "center"]):
                    with gr.TabItem("Base Trend: " + label):
                        noise_plot_list[i] = gr.Plot(label=f'noise_from_{label}')
        noise_radio.change(_on_radio_change, inputs=noise_radio, outputs=noise_panel)

    with gr.Tab("Total"):
        btn_gen_hpp = gr.Button("Get Hyperparams")
        hpp_text = gr.Textbox(lines=5)
        total_radio = gr.Radio([], label="Generate n_units first")
        with gr.Column(visible=False) as total_panel:
            with gr.Row():
                amplitude_min = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="amplitude_min")
                amplitude_max = gr.Slider(0.2, 5.0, value=3.0, step=0.1, label="amplitude_max")
            # These labels are the keys of trend_map inside plot_total_series.
            trend_total_radio = gr.Radio(['low-low-low', 'low-high-high', 'high-low-high', 'high-high-low', 'center'], label="Pick trend")
            with gr.Row():
                # plot_total_series looks up seasonal_state[(n_harmonics choice, seasonality choice)].
                seasonal_total_radio = gr.Radio(['min', 'max'], label="Pick seasonality")
                harmonic_total_radio = gr.Radio(['min', 'avg', 'max'], label="Pick n_harmonics")
            btn_total = gr.Button("Generate Final Series")
            total_plot = gr.Plot(label="total_series")

        total_radio.change(_on_radio_change, inputs=total_radio, outputs=total_panel)
        # Input order follows get_hyperparams' positional parameters.
        btn_gen_hpp.click(fn=get_hyperparams,
                          inputs=[resolution_min, resolution_max, resolution_multiplier, 
                                  trend_lin_min, trend_lin_max, offset_lin_min, offset_lin_max, 
                                  trend_exp_min, trend_exp_max, offset_exp_min, offset_exp_max, 
                                  a, m, w, harmonics_min, harmonics_max, 
                                  noise_k_min, noise_k_max, amplitude_min, amplitude_max],
                          outputs=hpp_text)

    # --- callback wiring ---
    # The n_units callback also targets the four per-tab radios.
    btn_n_units.click(
        fn=plot_n_units_with_stats,
        inputs=[resolution_min, resolution_max, resolution_multiplier, n_sequence, n_samples, seed],
        outputs=[n_units_plot, n_units_stats, trend_radio, seasonal_radio, noise_radio, total_radio]
    )
    # NOTE(review): inputs are ordered all *_min values first, then all *_max values;
    # this must match plot_trends' parameter order (defined elsewhere) — TODO confirm.
    btn_trend.click(
        fn=plot_trends,
        inputs=[trend_radio, n_sequence, n_samples, seed, 
                trend_lin_min, offset_lin_min, trend_exp_min, offset_exp_min, 
                trend_lin_max, offset_lin_max, trend_exp_max, offset_exp_max],
        outputs=[linear_trend_plot, exp_trend_plot, total_trend_plot, trend_bases_state]
    )
    btn_season.click(
        fn=plot_seasonality_components,
        inputs=[seasonal_radio, n_sequence, a, m, w, harmonics_min, harmonics_max],
        outputs=[annual_comp_plot, monthly_comp_plot, weekly_comp_plot, total_comp_plot, seasonal_state]
    )
    # Reads trend_bases_state, so trends need to be generated before noise.
    btn_noise.click(
        fn=plot_noise,
        inputs=[noise_radio, n_sequence, noise_k_min, noise_k_max, trend_bases_state],
        outputs=noise_plot_list + [noise_state]
    )
    # Reads all three states; input order follows plot_total_series' parameters.
    btn_total.click(
        fn=plot_total_series,
        inputs=[total_radio, n_sequence, amplitude_min, amplitude_max, trend_total_radio, seasonal_total_radio, harmonic_total_radio, trend_bases_state, seasonal_state, noise_state],
        outputs=[total_plot]
    )

# Start the local Gradio server for the app defined above.
demo.launch()

* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/gradio/queueing.py", line 759, in process_events
    response = await route_utils.call_process_api(
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/gradio/route_utils.py", line 354, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/gradio/blocks.py", line 2116, in process_api
    result = await self.call_function(
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/gradio/blocks.py", line 1623, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/home/lvu/playground/ts-project/venv/lib/python3.10/site-packages/anyio/_backends/

In [24]:
# Shut down the running Gradio server for `demo`, releasing its local port.
demo.close()

Closing server running on port: 7861


In [27]:
# Hyperprior settings used for the series bank, one key per line and grouped
# in the same blocks that get_hyperparams emits.
hyperprior_params = dotdict({
    # resolution
    'resolution_min': 25,
    'resolution_max': 75,
    'resolution_multiplier': 25,

    # linear / exponential trend and their offsets
    'trend_lin_min': -2,
    'trend_lin_max': 2,
    'trend_lin_fixed_variance': 0.15,
    'trend_exp_min': 2,
    'trend_exp_max': 4,
    'trend_exp_fixed_variance': 0.010,
    'trend_exp_multiplier': 100,
    'offset_lin_min': -1.5,
    'offset_lin_max': 1.5,
    'offset_exp_min': -1.5,
    'offset_exp_max': 1.5,

    # seasonality (symmetric a / m / w ranges) and harmonics
    'a_min': -2,
    'a_max': 2,
    'a_fixed_variance': 0.35,
    'm_min': -4,
    'm_max': 4,
    'm_fixed_variance': 0.35,
    'w_min': -8,
    'w_max': 8,
    'w_fixed_variance': 0.35,
    'harmonics_min': 1,
    'harmonics_max': 12,

    # noise
    'noise_k_min': 1.3,
    'noise_k_max': 3,

    # amplitude
    'amplitude_min': 1,
    'amplitude_max': 3,
})

In [28]:
import json, glob
from pathlib import Path

# --- series-bank generation settings ---
n_sequence = 12_000
out_dir = "series_bank/v2"
n_series = 1_000_000
shard_sz = 20_000
dtype = np.float32
seed = 42
device = "cpu"

# Create the output directory if needed, then seed both torch and numpy.
Path(out_dir).mkdir(parents=True, exist_ok=True)
torch.manual_seed(seed)
np.random.seed(seed)

# Snapshot of the run settings, written next to the shards for reproducibility.
meta = dict(
    n_series=n_series,
    shard_size=shard_sz,
    dtype="float32",
    seed=seed,
    n_sequence=n_sequence,
    hyperprior_params=dict(hyperprior_params),
)
meta_path = Path(out_dir) / "meta.json"
meta_path.write_text(json.dumps(meta, indent=2))
n_sequence

12000

In [90]:
import math
from tqdm import tqdm

# --- strict constants ---
L, H = 512, 96        # lengths of the window ending at / following each sampled patch end
CLIP = 10.0           # clamp bound for normalised values in the saturation test
STD_MIN = 0.02        # minimum std required over the context region
PATCH_S = 64          # number of patch end positions sampled per series
MAX_RETRIES = 12      # gated draws per slot before falling back to an ungated draw
n_series  = 1_000_000
n_sequence = 12000

num_shards = math.ceil(n_series / shard_sz)
ctx_len = int(0.8 * n_sequence)   # first 80% of each series; loop-invariant, so computed once
idx = 0
rej_nonfinite = rej_quick = rej_patch = 0

# Generate the bank shard by shard; each shard is saved as one .npy file.
for s in tqdm(range(num_shards), desc='num_shards'):
    k = min(shard_sz, n_series - idx)   # the last shard may be smaller
    buf = np.empty((k, n_sequence), dtype=dtype)

    for i in tqdm(range(k), desc='k', leave=False):
        accepted = False
        for _ in range(MAX_RETRIES):
            # synth
            y = gen_series().contiguous().float()

            # finiteness gate: NaN compares False against both thresholds below,
            # so without this check a non-finite series would pass every gate.
            if not torch.isfinite(y).all():
                rej_nonfinite += 1
                continue

            # quick series-level gate on context region
            y_ctx = y[:ctx_len]
            if y_ctx.std().item() < STD_MIN:
                rej_quick += 1
                continue

            # subsampled patch-level saturation test
            if ctx_len - (L + H) > 0:
                ends = torch.linspace(L-1, ctx_len-1-H, PATCH_S).long()
                Xs = torch.stack([y[e-L+1:e+1] for e in ends], dim=0)   # L values ending at e
                Zs = torch.stack([y[e+1:e+1+H] for e in ends], dim=0)   # H values after e
                # NOTE(review): the centre is the mean while the scale is MAD-based —
                # confirm this pairing is intentional.
                mu = Xs.mean()
                med = Xs.median()
                mad = (Xs - med).abs().median()
                sigma = (1.4826 * mad).clamp_min(1e-9)   # no 0.10 floor here
                Zn = torch.clamp((Zs - mu) / sigma, -CLIP, CLIP)
                # fraction of values pinned at the clamp bound
                clip_frac = (Zn.abs() == CLIP).float().mean().item()
                if clip_frac > 0.25:
                    rej_patch += 1
                    continue

            accepted = True
            break

        if not accepted:
            # keep drawing until finite; skip stats gates, but require finiteness
            while True:
                y = gen_series().contiguous().float()
                if torch.isfinite(y).all():
                    break

        buf[i] = y.cpu().numpy().astype(dtype)

    np.save(Path(out_dir, f"series_shard_{s:04d}.npy"), buf)
    idx += k

print(f"Rejected (non-finite): {rej_nonfinite}")
print(f"Rejected (quick std): {rej_quick}")
print(f"Rejected (patch clip): {rej_patch}")

num_shards: 100%|██████████| 50/50 [1:01:19<00:00, 73.58s/it]

Rejected (quick std): 0
Rejected (patch clip): 0





In [30]:
# Serialise the run metadata to indented JSON text so it can be logged as a string.
meta_json = json.dumps(meta, indent=2)

In [31]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard run directory for this version of the series bank.
log_dir = '../tb_test/series_bank_v2'
writer = SummaryWriter(log_dir)


In [32]:
# Record the serialised metadata and a free-form note in the TensorBoard run at step 0.
writer.add_text("Metadata", meta_json, 0)
writer.add_text("Notes", "New version", 0)