# 1D ISL Training Notebook

This notebook trains a **1D implicit generator** using the **Invariant Statistical Loss (ISL)** on one of several selectable 1D target distributions:

- `gaussian`: $\mathcal N(0, 1)$
- `mixture`: symmetric 2-Gaussian mixture $0.5\,\mathcal N(-\delta,1) + 0.5\,\mathcal N(+\delta,1)$
- `laplace`: Laplace$(0, b)$ (double exponential)
- `student`: centered Student-t$(\nu)$ (heavy tails)
- `lognormal`: LogNormal$(0, \sigma)$ (positive, skewed)
- `pareto`: Pareto$(x_m, \alpha)$ (positive heavy tail)
- `mog3`: 3-component Gaussian mixture with default parameters (means $[-3, 0, 3]$, standard deviations $[1, 0.5, 1]$, weights $[0.3, 0.4, 0.3]$)

The code assumes you are working in the **`isl-implicit-generative-models`** repository with a `src/isl/` package providing:

- `isl.loss_1d.isl_1d_soft`
- `isl.models.MLPGenerator`
- `isl.utils.set_seed`, `isl.utils.get_device`, `isl.utils.ensure_dir`
- `isl.metrics.ksd_rbf`

You can adapt paths/imports if your layout differs.


In [None]:
# Imports and path setup
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # macOS OpenMP workaround
os.environ["OMP_NUM_THREADS"] = "1"

from pathlib import Path
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Laplace, StudentT, LogNormal, Pareto

# Locate the repository root: the nearest directory, starting at the current
# working directory and walking upward, that contains a `src/` folder.
# Adjust this if your layout differs.
ROOT = Path.cwd().resolve()
SRC = ROOT / "src"
if not SRC.exists():
    # The notebook may live in a sub-directory such as experiments/1d_univariate/
    # (two levels below the repo root). Search the ancestors nearest-first
    # instead of hard-coding a depth, which is easy to get off by one and
    # raises IndexError when the working directory is shallow.
    for _ancestor in ROOT.parents:
        if (_ancestor / "src").exists():
            ROOT = _ancestor
            SRC = ROOT / "src"
            break

# Make the `isl` package importable without installing it.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from isl.models import MLPGenerator
from isl.loss_1d import isl_1d_soft
from isl.utils import set_seed, get_device, ensure_dir
from isl.metrics import ksd_rbf

print("Using ROOT:", ROOT)


In [None]:
# Target distributions: sampling functions

def sample_real_gaussian(n: int, device: torch.device) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from the standard normal N(0, 1).

    Returns a 1D tensor of length ``n`` allocated on ``device``.
    """
    shape = (n,)
    draws = torch.randn(shape, device=device)
    return draws


def sample_real_mixture(n: int, device: torch.device, delta: float) -> torch.Tensor:
    """Draw ``n`` samples from the mixture 0.5 N(-delta, 1) + 0.5 N(+delta, 1).

    Each sample picks the left (-delta) or right (+delta) component with
    probability one half and adds unit-variance Gaussian noise around it.
    """
    # Draw order is kept (uniform picks first, then Gaussian noise) so that
    # results under a fixed seed do not change.
    pick_left = torch.rand(n, device=device) < 0.5
    noise = torch.randn(n, device=device)

    # +1 everywhere, flipped to -1 where the left component was chosen.
    side = torch.ones_like(noise).masked_fill(pick_left, -1.0)
    return side * delta + noise


def sample_real_laplace(n: int, device: torch.device, scale: float) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from a zero-centred Laplace distribution.

    ``scale`` is the Laplace scale parameter; the parameters are created on
    ``device`` so the samples live there as well.
    """
    laplace = Laplace(
        loc=torch.tensor(0.0, device=device),
        scale=torch.tensor(scale, device=device),
    )
    return laplace.sample(torch.Size([n]))


def sample_real_student(n: int, device: torch.device, df: float) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from a Student-t with ``df`` degrees of freedom.

    The distribution uses ``StudentT``'s default location and scale; only the
    degrees of freedom are supplied, as a tensor created on ``device``.
    """
    student = StudentT(df=torch.tensor(df, device=device))
    return student.sample(torch.Size([n]))


def sample_real_lognormal(n: int, device: torch.device, sigma: float) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from LogNormal(0, sigma).

    ``sigma`` is passed as the ``scale`` of the underlying normal and the
    location is fixed at 0; both parameters are created on ``device``.
    """
    lognormal = LogNormal(
        loc=torch.tensor(0.0, device=device),
        scale=torch.tensor(sigma, device=device),
    )
    return lognormal.sample(torch.Size([n]))


def sample_real_pareto(n: int, device: torch.device, xm: float, alpha: float) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from Pareto(xm, alpha).

    ``xm`` is the scale (lower bound of the support, x >= xm > 0) and
    ``alpha`` the shape parameter; both are created as tensors on ``device``.
    """
    pareto = Pareto(
        scale=torch.tensor(xm, device=device),
        alpha=torch.tensor(alpha, device=device),
    )
    return pareto.sample(torch.Size([n]))


def sample_real_mog3(
    n: int,
    device: torch.device,
    means: torch.Tensor | None = None,
    stds: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
) -> torch.Tensor:
    """Sample from a 1D 3-component Gaussian mixture.

    Default: means=[-3, 0, 3], stds=[1, 0.5, 1], weights=[0.3, 0.4, 0.3].
    """
    if means is None:
        means = torch.tensor([-3.0, 0.0, 3.0], device=device)
    if stds is None:
        stds = torch.tensor([1.0, 0.5, 1.0], device=device)
    if weights is None:
        weights = torch.tensor([0.3, 0.4, 0.3], device=device)

    weights = weights / weights.sum()
    comp_idx = torch.multinomial(weights, num_samples=n, replacement=True)
    comp_means = means[comp_idx]
    comp_stds = stds[comp_idx]

    eps = torch.randn(n, device=device)
    x = comp_means + comp_stds * eps
    return x


In [None]:
# Training loop as a reusable function

def train_isl_1d(
    target: str = "gaussian",
    steps: int = 5000,
    batch_size: int = 512,
    K: int = 32,
    cdf_bandwidth: float = 0.15,
    hist_sigma: float = 0.05,
    noise_dim: int = 4,
    hidden_dims: tuple[int, ...] = (64, 64),
    lr: float = 1e-3,
    mixture_delta: float = 2.0,
    laplace_scale: float = 1.0,
    student_df: float = 3.0,
    lognorm_sigma: float = 0.5,
    pareto_xm: float = 1.0,
    pareto_alpha: float = 2.5,
    log_every: int = 200,
    device: torch.device | None = None,
    outdir: Path | None = None,
):
    """Train a 1D generator with ISL loss for a chosen 1D target distribution.

    Builds an ``MLPGenerator`` with ``data_dim=1``, optimises it with Adam on
    ``isl_1d_soft`` using a fresh batch of real and generated samples at every
    step, then writes a checkpoint, a loss-curve plot and a real-vs-generated
    histogram into ``outdir``. When ``target == "gaussian"`` it additionally
    prints the value returned by ``ksd_rbf`` for the generated samples with
    the N(0, 1) score function ``-x``.

    Args:
        target: One of ``gaussian``, ``mixture``, ``laplace``, ``student``,
            ``lognormal``, ``pareto`` or ``mog3``.
        steps: Number of optimisation steps.
        batch_size: Number of real and of generated samples per step.
        K, cdf_bandwidth, hist_sigma: Forwarded unchanged to ``isl_1d_soft``.
        noise_dim: Dimension of the Gaussian noise fed to the generator.
        hidden_dims: Forwarded to ``MLPGenerator``.
        lr: Adam learning rate.
        mixture_delta: ``delta`` for the ``mixture`` target.
        laplace_scale: Scale for the ``laplace`` target.
        student_df: Degrees of freedom for the ``student`` target.
        lognorm_sigma: Sigma for the ``lognormal`` target.
        pareto_xm, pareto_alpha: Parameters for the ``pareto`` target.
        log_every: Print the loss every this many steps (the first and last
            step are always printed). Values <= 0 disable periodic logging.
        device: Torch device; defaults to ``get_device(prefer_gpu=True)``.
        outdir: Output directory (``Path`` or path string); defaults to
            ``ROOT / "experiments" / "1d_univariate" / "runs"``.

    Returns:
        ``(gen, losses, x_real_np, x_fake_np)``: the trained generator (left
        in eval mode), the per-step losses as a NumPy array, and NumPy arrays
        of 20000 real and 20000 generated evaluation samples.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    if device is None:
        device = get_device(prefer_gpu=True)

    # Single source of truth for "how to draw n real samples of each target";
    # shared by the training loop and the evaluation pass.
    samplers = {
        "gaussian": lambda n: sample_real_gaussian(n, device=device),
        "mixture": lambda n: sample_real_mixture(n, device=device, delta=mixture_delta),
        "laplace": lambda n: sample_real_laplace(n, device=device, scale=laplace_scale),
        "student": lambda n: sample_real_student(n, device=device, df=student_df),
        "lognormal": lambda n: sample_real_lognormal(n, device=device, sigma=lognorm_sigma),
        "pareto": lambda n: sample_real_pareto(
            n, device=device, xm=pareto_xm, alpha=pareto_alpha
        ),
        "mog3": lambda n: sample_real_mog3(n, device=device),
    }
    # Fail fast, before creating directories or building the model.
    if not isinstance(target, str) or target not in samplers:
        raise ValueError(f"Unknown target: {target!r}")
    sample_real = samplers[target]

    if outdir is None:
        outdir = ROOT / "experiments" / "1d_univariate" / "runs"
    outdir = Path(outdir)  # also accept plain string paths
    ensure_dir(outdir)

    print("=" * 70)
    print("Train 1D generator with ISL")
    print(f"  target         : {target}")
    print(f"  steps          : {steps}")
    print(f"  batch_size     : {batch_size}")
    print(f"  K              : {K}")
    print(f"  cdf_bandwidth  : {cdf_bandwidth}")
    print(f"  hist_sigma     : {hist_sigma}")
    print(f"  noise_dim      : {noise_dim}")
    print(f"  hidden_dims    : {hidden_dims}")
    print(f"  lr             : {lr}")
    print(f"  mixture_delta  : {mixture_delta}")
    print(f"  laplace_scale  : {laplace_scale}")
    print(f"  student_df     : {student_df}")
    print(f"  lognorm_sigma  : {lognorm_sigma}")
    print(f"  pareto_xm      : {pareto_xm}")
    print(f"  pareto_alpha   : {pareto_alpha}")
    print(f"  device         : {device}")
    print(f"  outdir         : {outdir}")
    print("=" * 70)

    torch.set_num_threads(1)

    gen = MLPGenerator(
        noise_dim=noise_dim,
        data_dim=1,
        hidden_dims=hidden_dims,
    ).to(device)

    optimizer = optim.Adam(gen.parameters(), lr=lr)
    losses: list[float] = []

    for step in range(1, steps + 1):
        # Real samples
        x_real = sample_real(batch_size)

        # Generator samples
        z = torch.randn(batch_size, noise_dim, device=device)
        x_fake = gen(z).view(-1)

        # ISL loss
        loss = isl_1d_soft(
            x_real.view(-1),
            x_fake,
            K=K,
            cdf_bandwidth=cdf_bandwidth,
            hist_sigma=hist_sigma,
            reduction="mean",
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; it is used for both the history and the log.
        loss_value = loss.item()
        losses.append(loss_value)

        # Guard the modulo so log_every <= 0 disables periodic logging
        # instead of raising ZeroDivisionError.
        is_boundary_step = step == 1 or step == steps
        if is_boundary_step or (log_every > 0 and step % log_every == 0):
            print(f"[{step:5d}/{steps}] ISL loss = {loss_value:.6f}")

    # Save model
    ckpt_path = outdir / f"generator_1d_{target}_K{K}.pt"
    torch.save(gen.state_dict(), ckpt_path)
    print(f"\nSaved generator checkpoint to {ckpt_path}")

    # Plot training curve
    steps_arr = np.arange(1, len(losses) + 1)
    plt.figure(figsize=(6, 4))
    plt.plot(steps_arr, losses)
    plt.xlabel("Training step")
    plt.ylabel("ISL loss")
    plt.title(f"1D ISL training curve ({target}, K={K})")
    plt.grid(True, ls="--", alpha=0.5)
    plt.tight_layout()
    curve_path = outdir / f"isl_1d_loss_curve_{target}_K{K}.png"
    plt.savefig(curve_path, dpi=200)
    plt.close()
    print(f"Saved training curve to {curve_path}")

    # Evaluation: real vs generated
    gen.eval()
    with torch.no_grad():
        N_eval = 20000
        x_real_eval = sample_real(N_eval)
        z_eval = torch.randn(N_eval, noise_dim, device=device)
        x_fake_eval = gen(z_eval).view(-1)

    x_real_np = x_real_eval.cpu().numpy()
    x_fake_np = x_fake_eval.cpu().numpy()

    print("\nSample statistics (generated):")
    print(f"  mean ≈ {x_fake_np.mean():.4f}, std ≈ {x_fake_np.std():.4f}")
    print("  (real)    mean ≈ {:.4f}, std ≈ {:.4f}".format(
        x_real_np.mean(), x_real_np.std()
    ))

    # Histogram overlay; both histograms share one range so bins line up.
    plt.figure(figsize=(6, 4))
    xmin = min(x_real_np.min(), x_fake_np.min())
    xmax = max(x_real_np.max(), x_fake_np.max())
    plt.hist(
        x_real_np,
        bins=100,
        range=(xmin, xmax),
        density=True,
        alpha=0.5,
        label="real",
    )
    plt.hist(
        x_fake_np,
        bins=100,
        range=(xmin, xmax),
        density=True,
        alpha=0.5,
        label="generated",
    )
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("density (hist)")
    plt.title(f"Real vs generated histogram ({target}, K={K})")
    plt.grid(True, ls="--", alpha=0.5)
    plt.tight_layout()
    hist_path = outdir / f"isl_1d_hist_{target}_K{K}.png"
    plt.savefig(hist_path, dpi=200)
    plt.close()
    print(f"Saved histogram plot to {hist_path}")

    # Optional: KSD for Gaussian
    if target == "gaussian":
        print("\nComputing KSD against N(0,1)...")
        samples = torch.from_numpy(x_fake_np).float().unsqueeze(1).to(device)

        def score_fn(x: torch.Tensor) -> torch.Tensor:
            # Score (gradient of the log-density) of N(0, 1).
            return -x

        with torch.no_grad():
            ksd_val = ksd_rbf(samples, score_fn).item()
        print(f"KSD (generated vs N(0,1)) ≈ {ksd_val:.4e}")

    return gen, np.array(losses), x_real_np, x_fake_np


In [None]:
# Configuration cell: edit these for quick experiments

# Seed the random number generators before anything is sampled.
seed = 42
set_seed(seed, deterministic=False)

# Every key below is a keyword argument of train_isl_1d.
config = dict(
    target="gaussian",  # one of: gaussian, mixture, laplace, student, lognormal, pareto, mog3
    steps=3000,
    batch_size=512,
    K=32,
    noise_dim=4,
    hidden_dims=(64, 64),
    lr=1e-3,
    mixture_delta=2.0,
    laplace_scale=1.0,
    student_df=3.0,
    lognorm_sigma=0.5,
    pareto_xm=1.0,
    pareto_alpha=2.5,
    log_every=200,
)

device = get_device(prefer_gpu=True)
print("Device:", device)


In [None]:
# Run training with the above configuration

# Settings taken from `config`; a missing key raises KeyError.
_config_keys = (
    "target",
    "steps",
    "batch_size",
    "K",
    "noise_dim",
    "hidden_dims",
    "lr",
    "mixture_delta",
    "laplace_scale",
    "student_df",
    "lognorm_sigma",
    "pareto_xm",
    "pareto_alpha",
    "log_every",
)
_run_kwargs = {key: config[key] for key in _config_keys}

# The ISL smoothing parameters are fixed here rather than read from `config`.
gen, losses, x_real_np, x_fake_np = train_isl_1d(
    cdf_bandwidth=0.15,
    hist_sigma=0.05,
    device=device,
    **_run_kwargs,
)
