In [1]:
import sys
import os

# Go up one level to project root
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_root)

In [2]:
import torch, math
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from util.config_util import dotdict

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
seasonality_base = 3.0
w, m, a = seasonality_base*1, seasonality_base*2, seasonality_base*4

hyperprior_params = dotdict({
    # Seasonality parameters
    'a_min': -a, 'a_max': a, 'a_fixed_variance': 0.15,
    'm_min': -m, 'm_max': m, 'm_fixed_variance': 0.15,
    'w_min': -w, 'w_max': w, 'w_fixed_variance': 0.15,
    
    # Trend parameters
    'trend_lin_min': -0.015, 'trend_lin_max': 0.015, 'trend_lin_fixed_variance': 0.005,
    'trend_exp_min': 1 - 0.005, 'trend_exp_max': 1 + 0.005, 'trend_exp_fixed_variance': 0.001,
    'trend_exp_multiplier': 400,
    
    # Noise and resolution
    'noise_k_min': 0.5, 'noise_k_max': 3.5,
    'resolution_min': 0.1, 'resolution_max': 1.2, 'resolution_multiplier': 50,
    
    # Other parameters
    'harmonics_min': 2, 'harmonics_max': 8,
    'discreteness_min': 1, 'discreteness_max': 6,
    'bias_zi_min': 0.8, 'bias_zi_max': 3.0,
    'amplitude_min': 0.5, 'amplitude_max': 4.0,
    'non_negative_prob': 0.2,
    'offset_lin_min': -1.0, 'offset_lin_max': 1.2,
    'offset_exp_min': -1.5, 'offset_exp_max': 2.0,
    'f_zi_min': 0.0, 'f_zi_max': 0.6, 'f_zi_fixed_variance': 0.3
})

In [105]:
import json, glob
from pathlib import Path

n_sequence = 12_000
out_dir   = "series_bank/noise_v1_no_reject"
n_series  = 1_000_000
shard_sz  = 20_000
dtype     = np.float32
seed      = 42
device    = "cpu"

Path(out_dir).mkdir(parents=True, exist_ok=True)
torch.manual_seed(seed); np.random.seed(seed)

# snapshot meta for reproducibility
meta = {
    "n_series": n_series,
    "shard_size": shard_sz,
    "dtype": "float32",
    "seed": seed,
    "n_sequence": n_sequence,
    "hyperprior_params": dict(hyperprior_params),
}
Path(out_dir, "meta.json").write_text(json.dumps(meta, indent=2))
n_sequence

12000

In [106]:
import math
from tqdm import tqdm

# --- strict constants ---
L, H = 512, 96
CLIP = 10.0
STD_MIN = 0.02
PATCH_S = 64
MAX_RETRIES = 12
n_series  = 1_000_000
n_sequence = 12000

num_shards = math.ceil(n_series / shard_sz)
idx = 0
rej_quick = rej_patch = 0

for s in tqdm(range(num_shards), desc='num_shards'):
    k = min(shard_sz, n_series - idx)
    buf = np.empty((k, n_sequence), dtype=dtype)

    for i in tqdm(range(k), desc='k', leave=False):
        accepted = False
        for tries in range(MAX_RETRIES):
            # synth
            y = gen_series()
            y = y.contiguous().float()

            # quick series-level gate on context region
            # ctx_len = int(0.8 * n_sequence)
            # y_ctx = y[:ctx_len]
            # if y_ctx.std().item() < STD_MIN:
            #     rej_quick += 1
            #     continue

            # subsampled patch-level saturation test
            # if ctx_len - (L + H) > 0:
            #     ends = torch.linspace(L-1, ctx_len-1-H, PATCH_S).long()
            #     Xs = torch.stack([y[e-L+1:e+1] for e in ends], dim=0)
            #     Zs = torch.stack([y[e+1:e+1+H] for e in ends], dim=0)
            #     mu = Xs.mean()
            #     med = Xs.median()
            #     mad = (Xs - med).abs().median()
            #     sigma = (1.4826 * mad).clamp_min(1e-9)   # no 0.10 floor here
            #     Zn = torch.clamp((Zs - mu) / sigma, -CLIP, CLIP)
            #     clip_frac = (Zn.abs() == CLIP).float().mean().item()
            #     if clip_frac > 0.25:
            #         rej_patch += 1
            #         continue

            # accepted = True
            # break

        # if not accepted:
        #     # keep drawing until finite; skip stats gates, but require finiteness
        #     while True:
        #         y = gen_series()
        #         y = y.contiguous().float()
        #         if torch.isfinite(y).all():
        #             break

        buf[i] = y.cpu().numpy().astype(dtype)

    np.save(Path(out_dir, f"series_shard_{s:04d}.npy"), buf)
    idx += k

print(f"Rejected (quick std): {rej_quick}")
print(f"Rejected (patch clip): {rej_patch}")

num_shards: 100%|██████████| 50/50 [7:28:19<00:00, 537.99s/it]  

Rejected (quick std): 0
Rejected (patch clip): 0



