In [13]:
import numpy as np
import torch
import os
from tqdm import tqdm
import json
from scipy.stats import multivariate_normal


# -----------------------
# Config
# -----------------------
n_models = 10          # number of toy models to generate
min_fluctuation = 300  # candidate bumps with |delta_N| below this are rejected

# Create dynamic output directory (one sub-directory per configuration)
base_dir = "/work/gbadarac/MonoJet_NPLM/MonoJet_NPLM_analysis/Generate_Gaussian_Toy/saved_generated_gaussian_toys"
output_dir = os.path.join(base_dir, f"N_{n_models}_fluct{min_fluctuation}")
os.makedirs(output_dir, exist_ok=True)

# Evaluation points at which the target PDF and every toy model are tabulated
x_eval_path = "/work/gbadarac/MonoJet_NPLM/MonoJet_NPLM_analysis/Train_Ensembles/Generate_Data/saved_generated_target_data/100k_target_training_set.npy"
x_eval = np.load(x_eval_path)  # expected shape: (n_points, 2)

# Fail fast on a malformed input file: everything downstream assumes a
# non-empty (n_points, 2) array and would otherwise fail with a broadcast
# error or silently produce NaN probabilities.
if x_eval.ndim != 2 or x_eval.shape[1] != 2 or x_eval.shape[0] == 0:
    raise ValueError(
        f"Expected a non-empty array of shape (n_points, 2) in {x_eval_path}, "
        f"got shape {x_eval.shape}"
    )
n_points = x_eval.shape[0]

# Target distribution params (known 2D Gaussian)
mean = np.array([-0.5, 0.6])
cov = np.diag([0.25**2, 0.4**2])  # diagonal covariance
inv_cov = np.linalg.inv(cov)
det_cov = np.linalg.det(cov)

# Reproducibility: seeds NumPy's global RNG used by the np.random.* calls
np.random.seed(1234)

# -----------------------
# Evaluate target PDF at x_eval
# -----------------------
delta = x_eval - mean
# Row-wise quadratic form delta_i^T . inv_cov . delta_i, scaled by -1/2
exponent = -0.5 * np.sum((delta @ inv_cov) * delta, axis=1)
# Normalisation of a 2D Gaussian: 1 / (2 * pi * sqrt(det(cov)))
norm_const = 1. / (2 * np.pi * np.sqrt(det_cov))
target_probs = norm_const * np.exp(exponent)  # shape: (n_points,)



In [14]:
# -----------------------
# Generate toy models with structured bumps
# -----------------------
# Each toy model is the target PDF tabulated on x_eval plus (or minus) one
# narrow Gaussian bump centred on a randomly selected evaluation point.
# Candidates whose fluctuation is below min_fluctuation are rejected and
# redrawn.

bump_width = 0.025  # Width of the Gaussian bump fixed across all models
bump_cov = np.eye(2) * bump_width**2  # isotropic 2D bump

# Upper bound on the number of candidate bumps. Without it, a
# min_fluctuation that no evaluation point can reach makes the rejection
# loop below spin forever. Raise this if the acceptance rate is very low.
max_attempts = 1000 * n_models

# Loop-invariant: probability of picking each evaluation point as the bump
# centre, proportional to the target PDF at that point.
probs_2d = target_probs / target_probs.sum()

model_probs = []
bump_params = []

i = 0  # actual model index (number of accepted models so far)
attempt = 0  # counts total attempts, accepted or rejected

while i < n_models:
    if attempt >= max_attempts:
        raise RuntimeError(
            f"Only {i} of {n_models} toy models accepted after {attempt} attempts; "
            f"min_fluctuation={min_fluctuation} may be unreachable "
            f"for bump_width={bump_width}."
        )
    attempt += 1

    # Pick the bump centre among the evaluation points, weighted by probs_2d.
    # Passing the integer length is equivalent to passing np.arange(length).
    idx_center = np.random.choice(len(x_eval), p=probs_2d)
    center = x_eval[idx_center]

    bump = multivariate_normal.pdf(x_eval, mean=center, cov=bump_cov)

    # Estimate fluctuation: sqrt of the target PDF summed under the bump
    n_bump = np.sum(target_probs * bump)
    fluctuation = np.sqrt(n_bump)
    # The sign is drawn before the rejection test so that the sequence of
    # random draws (and hence the seeded output) is the same whether or not
    # the candidate is kept.
    sign = np.random.choice([-1, 1])
    delta_N = sign * fluctuation

    if abs(delta_N) < min_fluctuation:
        continue  # try again without incrementing i

    # Rescale the bump so that its values sum to `fluctuation`
    bump *= fluctuation / bump.sum()
    perturbed = target_probs + sign * bump

    # Enforce non-negativity and normalize so the values sum to n_points
    perturbed = np.clip(perturbed, 0, None)
    perturbed /= perturbed.sum()
    perturbed *= n_points

    model_probs.append(torch.tensor(perturbed, dtype=torch.float32))
    bump_params.append({
        "model_idx": i,
        "center": center.tolist(),
        "sign": int(sign),
        "width": float(bump_width),
        "n_bump": float(n_bump),
        "fluctuation": float(fluctuation)
    })

    i += 1  # only increment when a valid model is generated


# Persist the bump metadata next to the generated models
with open(os.path.join(output_dir, "bump_params.json"), "w") as f:
    json.dump(bump_params, f, indent=2)


In [15]:
# -----------------------
# Plot marginals
# -----------------------
# For every toy model, draw the weighted 1D marginal of each feature against
# the target marginal, shade and annotate the bump, and save one PNG per model.
import matplotlib.pyplot as plt
model_probs_np = torch.stack(model_probs).numpy()  # shape: (n_models, n_points)
feature_names = ["Feature 1", "Feature 2"]
bins = 100

# Normalize target_probs once (outside loop)
target_probs_normed = target_probs / target_probs.sum()

# Loop over models
for model_idx, probs in enumerate(model_probs_np):
    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=False)

    # Normalize perturbed model so its weights sum to 1
    perturbed_normed = probs / probs.sum()

    # Bump metadata is per model, so look it up once rather than per feature
    bump_info = bump_params[model_idx]
    sigma = bump_info["width"]
    fluct = bump_info["fluctuation"]
    sign = bump_info["sign"]

    # The loop variable is deliberately not called `i`: that name holds the
    # accepted-model counter in the generation cell and must not be clobbered.
    for feat_idx in range(2):
        feature = x_eval[:, feat_idx]
        bin_edges = np.linspace(feature.min(), feature.max(), bins + 1)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        bin_widths = np.diff(bin_edges)

        # Use normalized weights for histogram
        pert_hist = np.histogram(feature, bins=bin_edges, weights=perturbed_normed)[0]
        target_hist = np.histogram(feature, bins=bin_edges, weights=target_probs_normed)[0]

        # Convert to densities
        pert_density = pert_hist / bin_widths
        target_density = target_hist / bin_widths

        ax = axes[feat_idx]
        ax.plot(bin_centers, target_density, label="Target", color="red", linestyle="--")
        ax.plot(bin_centers, pert_density, label="Perturbed Model", color="blue")

        # Shade bump region: bins whose centre lies within +-1.5 sigma of the bump
        mu = bump_info["center"][feat_idx]
        bump_region = (bin_centers >= mu - 1.5 * sigma) & (bin_centers <= mu + 1.5 * sigma)
        if not bump_region.any():
            # Bins wider than the 3-sigma window can leave the mask empty, and
            # .max() on an empty selection raises ValueError. Fall back to the
            # single bin closest to the bump centre.
            bump_region[np.abs(bin_centers - mu).argmin()] = True
        ax.fill_between(bin_centers, pert_density, target_density, where=bump_region,
                        color="blue", alpha=0.3, label="Bump region")

        # Annotate bump
        y_bump = pert_density[bump_region].max()
        y_top = ax.get_ylim()[1]
        y_text = min(y_bump * 1.05, y_top * 0.95)  # clamp to stay inside

        ax.annotate(
            f"μ={mu:.2f}, σ={sigma:.3f}, ΔN={sign * fluct:.1f}",
            xy=(mu, y_bump),
            xytext=(mu, y_text),
            ha="center", fontsize=9,
            arrowprops=dict(arrowstyle="->", color="gray"),
            clip_on=True
        )

        ax.set_ylabel("Density")
        ax.set_title(f"Model {model_idx} — Marginal: {feature_names[feat_idx]}")
        ax.legend()

    axes[1].set_xlabel("Feature value")
    plt.tight_layout()
    fig.savefig(os.path.join(output_dir, f"model_{model_idx}_marginals.png"))
    plt.close(fig)  # release the figure; many are created in this loop



In [10]:
# -----------------------
# Save models and metadata
# -----------------------
# Stack the per-model tensors column-wise: column j of the saved matrix is
# toy model j tabulated at every evaluation point.
model_probs_np = np.stack(
    tuple(model_tensor.numpy() for model_tensor in model_probs),
    axis=-1,
)  # (N, M)
f_i_path = os.path.join(output_dir, "f_i.npy")
np.save(f_i_path, model_probs_np)

# Write the bump metadata alongside the model matrix
bump_params_path = os.path.join(output_dir, "bump_params.json")
with open(bump_params_path, "w") as f:
    f.write(json.dumps(bump_params, indent=2))
