# Conditional Flow Diagnostics (MAF/MADE/RQS)

**Ziel:** Belastbar testen, ob und wie stark die Konditionierung \(\mathbf{X}\) in deinem Flow \(p(\mathbf{Y}\mid\mathbf{X})\) greift – skaliert auf viele (z. B. 20) Kontext-Features.

**Checks:**
1. **Permutation-Importance (ΔNLL)** je Kontext-Feature  
2. **Gradienten‑Sensitivität** von \(\log p(\mathbf{Y}\mid\mathbf{X})\) bzgl. \(\mathbf{X}\)  
3. **Counterfactual‑Sampling** (kleine Δ in einem Feature → sichtbarer Shift in Y‑Verteilung)  
4. **All‑Shuffle Sanity Check** (alle Kontexte permutieren → NLL muss hochgehen)

> **Seaborn-Hinweis:** In dieser generierten Version werden Plots mit **matplotlib** erstellt.  
> Wenn du lokal **seaborn** nutzen willst, kannst du einfach die auskommentierten `import seaborn as sns`/`sns.set()`‑Zeilen aktivieren und die Plot-Funktionen minimal anpassen.

In [1]:
# --- Imports ---
import math, json, os
import numpy as np
import torch
from contextlib import nullcontext
import matplotlib.pyplot as plt
import logging
import argparse
import yaml
from datetime import datetime
import pandas as pd
import sys

# Optional (lokal aktivierbar):
# import seaborn as sns
# sns.set()

In [2]:
# --- Configuration (adjust as needed) ---
# Use the GPU when torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
BATCH_SIZE_EVAL = 4096      # batch size passed to the NLL-based checks
PERM_REPEATS = 2            # how many times each column is permuted
COUNTERFACTUAL_DELTA = 0.1  # step size for Δx_j
TOPK = 5                    # number of top features inspected in detail
USE_AMP = False             # better left off for diagnostics
PRINT_WIDTH = 120

In [3]:
# --- Utility-Funktionen ---
@torch.no_grad()
def mean_nll(model, Y, X, bs=4096, use_amp=False, device="cpu"):
    """Return the mean negative log-likelihood of ``Y`` given context ``X``.

    The data is evaluated in batches of ``bs`` rows. Each batch is moved to
    ``device`` individually, so the full tensors never have to fit on the
    device at once.

    Args:
        model: Object exposing ``eval()`` and ``log_probs(y, x)``.
        Y: Target tensor, one row per sample.
        X: Context tensor with the same number of rows as ``Y``.
        bs: Batch size.
        use_amp: Run ``log_probs`` under CUDA autocast. Only takes effect
            when ``device`` is a CUDA device.
        device: Device string or ``torch.device`` (e.g. ``"cuda:0"``).

    Returns:
        The sum of ``-log_probs`` over all samples divided by the number of
        samples, as a Python float; ``0.0`` for empty input.

    Raises:
        ValueError: If ``X`` and ``Y`` have a different number of rows.
    """
    if X.size(0) != Y.size(0):
        raise ValueError(
            f"X and Y must have the same number of rows, got {X.size(0)} and {Y.size(0)}"
        )
    model.eval()
    # Compare the device *type* so that "cuda:0" or a torch.device object
    # enables autocast as well, not only the literal string "cuda".
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"
    total, n_seen = 0.0, 0
    for start in range(0, X.size(0), bs):
        xb = X[start:start + bs].to(device)
        yb = Y[start:start + bs].to(device)
        ctx = torch.amp.autocast(device_type="cuda") if amp_enabled else nullcontext()
        with ctx:
            nll = -model.log_probs(yb, xb)  # presumably shape [B, 1] — confirm against the model
        total += nll.sum().item()
        n_seen += yb.size(0)
    return total / max(n_seen, 1)

def perm_importance(model, Y, X, col_names, repeats=3, bs=4096, use_amp=False, device="cpu", rng=None):
    """Measure how much the mean NLL rises when one context column is shuffled.

    For each entry of ``col_names`` (matched to the columns of ``X`` by
    position) the column is permuted ``repeats`` times on a CPU copy of
    ``X`` and the resulting mean NLL is compared with the unshuffled baseline.

    Args:
        model, Y, X, bs, use_amp, device: Forwarded to ``mean_nll``.
        col_names: Names of the context columns, in column order.
        repeats: Number of independent permutations per column.
        rng: Seed or ``numpy.random.Generator`` handed to
            ``np.random.default_rng``; ``None`` gives a fresh generator.

    Returns:
        ``(base, deltas)`` where ``base`` is the unshuffled mean NLL and
        ``deltas`` maps each column name to ``(mean ΔNLL, std ΔNLL)``.
    """
    baseline = mean_nll(model, Y, X, bs, use_amp, device)
    result = {}
    X_host = X.detach().cpu()
    n_rows = X_host.shape[0]
    generator = np.random.default_rng(rng)
    for col, label in enumerate(col_names):
        increases = []
        for _ in range(repeats):
            shuffled = X_host.clone()
            order = torch.from_numpy(generator.permutation(n_rows))
            # Shuffle only this column; every other column keeps its rows.
            shuffled[:, col] = shuffled[order, col]
            shuffled_nll = mean_nll(model, Y, shuffled.to(device), bs, use_amp, device)
            increases.append(float(shuffled_nll - baseline))
        result[label] = (float(np.mean(increases)), float(np.std(increases)))
    return baseline, result

def all_shuffle_nll(model, Y, X, bs=4096, use_amp=False, device="cpu", rng=None):
    """Return the mean NLL after shuffling every context column independently.

    Works on a CPU copy of ``X``; the caller's tensor is left untouched.
    ``rng`` is a seed or ``numpy.random.Generator`` handed to
    ``np.random.default_rng``. The remaining arguments are forwarded to
    ``mean_nll``.
    """
    generator = np.random.default_rng(rng)
    shuffled = X.detach().cpu().clone()
    n_rows, n_cols = shuffled.shape[0], shuffled.shape[1]
    for col in range(n_cols):
        # A fresh permutation per column breaks the row alignment between columns too.
        order = torch.from_numpy(generator.permutation(n_rows))
        shuffled[:, col] = shuffled[order, col]
    return mean_nll(model, Y, shuffled.to(device), bs, use_amp, device)

def context_grad_sensitivity(model, Y, X, bs=2048, device="cpu"):
    """Return the mean absolute gradient of ``log p(Y|X)`` w.r.t. each context feature.

    For every sample the gradient of its log-probability with respect to its
    own context row is taken; the absolute values are averaged over all
    samples, giving one number per context column.

    The batch log-probabilities are *summed* (not averaged) before
    differentiating. Differentiating the batch mean would scale every
    per-sample gradient by ``1 / batch_size`` and make the result depend on
    ``bs``. Likewise the accumulation is weighted by sample count, so a
    shorter last batch does not get extra weight.

    NOTE(review): this assumes ``log_probs`` treats the rows of a batch
    independently in eval mode — confirm for the model in use.

    Args:
        model: Object exposing ``eval()`` and a differentiable
            ``log_probs(y, x)``.
        Y: Target tensor, one row per sample.
        X: Context tensor, one row per sample, one column per feature.
        bs: Batch size.
        device: Device the batches are moved to.

    Returns:
        CPU tensor with one entry per column of ``X``; all zeros when ``X``
        has no rows.
    """
    model.eval()
    abs_grad_sum = torch.zeros(X.size(1), device=device)
    n_seen = 0
    for start in range(0, X.size(0), bs):
        xb = X[start:start + bs].to(device).detach().requires_grad_(True)
        yb = Y[start:start + bs].to(device)
        # Sum over the batch: d(sum_i lp_i)/d(x_i) is exactly sample i's gradient.
        lp = model.log_probs(yb, xb).sum()
        (g,) = torch.autograd.grad(lp, xb, retain_graph=False, create_graph=False)
        abs_grad_sum += g.abs().sum(dim=0)
        n_seen += xb.size(0)
    return (abs_grad_sum / max(n_seen, 1)).detach().cpu()

@torch.no_grad()
def counterfactual_shift(model, X, j, delta, sample_bs=8192, device="cpu"):
    """Compare samples drawn for ``X`` with samples drawn after nudging column ``j``.

    One sample per context row is drawn via
    ``model.sample(num_samples=..., cond_inputs=...)``, first for the
    original contexts and then for a copy in which ``delta`` was added to
    column ``j``. ``X`` itself is not modified.

    Returns:
        ``(mean_shift, var_shift)``: per output dimension, the difference
        (perturbed minus original) of the sample mean and of the biased
        sample variance.
    """
    model.eval()

    def draw(contexts):
        # Sample in chunks of ``sample_bs`` rows and gather the results on the CPU.
        chunks = []
        for start in range(0, contexts.size(0), sample_bs):
            cond = contexts[start:start + sample_bs].to(device)
            chunks.append(model.sample(num_samples=cond.size(0), cond_inputs=cond))
        return torch.cat(chunks, dim=0).cpu()

    # a) samples for the unmodified contexts
    y_ref = draw(X)

    # b) samples for the perturbed contexts
    perturbed = X.clone()
    perturbed[:, j] += delta
    y_cf = draw(perturbed)

    mean_shift = y_cf.mean(dim=0) - y_ref.mean(dim=0)
    var_shift = y_cf.var(dim=0, unbiased=False) - y_ref.var(dim=0, unbiased=False)
    return mean_shift, var_shift

def topk_by_delta(deltas, k=5):
    """Rank features by their mean ΔNLL, largest first.

    ``deltas`` maps a feature name to a sequence whose first two entries are
    ``(mean, std)``. Returns ``(top, ranked)`` where ``ranked`` is the full
    list of ``(name, mean, std)`` tuples in descending order of ``mean`` and
    ``top`` holds its first ``k`` entries.
    """
    ranked = sorted(
        ((feature, stats[0], stats[1]) for feature, stats in deltas.items()),
        key=lambda row: row[1],
        reverse=True,
    )
    return ranked[:k], ranked

def print_table(rows, headers):
    """Print ``rows`` as a plain-text table with a header and separator line.

    Each column is as wide as its longest entry (header included, measured
    via ``str``). Columns are joined with ``" | "`` and the separator line
    uses ``"-+-"``.

    Args:
        rows: Iterable of row sequences; each row needs at least
            ``len(headers)`` entries. May be empty, in which case only the
            header and separator are printed.
        headers: Sequence of column titles.
    """
    # Materialise once so generators work and the rows can be scanned per column.
    rows = list(rows)
    # Build an explicit list for max(): with no rows, max(x, *()) would
    # receive a single int and raise TypeError.
    widths = [
        max([len(str(h))] + [len(str(r[i])) for r in rows])
        for i, h in enumerate(headers)
    ]
    fmt = " | ".join("{:%d}" % w for w in widths)
    print(fmt.format(*headers))
    print("-+-".join("-" * w for w in widths))
    for r in rows:
        print(fmt.format(*r))

# --- Plot-Funktionen (matplotlib-only) ---
def barplot_perm_importance(deltas, title="Permutation Importance (ΔNLL)", figsize=(10, 5)):
    """Draw a bar chart of the mean ΔNLL per feature with the std as error bars.

    ``deltas`` maps a feature name to a sequence whose first two entries are
    ``(mean, std)``; bars appear in the dict's iteration order. The figure is
    shown via ``plt.show()``.
    """
    labels = list(deltas)
    means = [stats[0] for stats in deltas.values()]
    spreads = [stats[1] for stats in deltas.values()]
    positions = np.arange(len(labels))

    plt.figure(figsize=figsize)
    plt.bar(positions, means, yerr=spreads)
    plt.xticks(positions, labels, rotation=45, ha="right")
    plt.ylabel("ΔNLL (↑ schlimmer)")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def barplot_grad_sensitivity(names, grads, title="Gradient Sensitivity ⟨|∂ log p/∂x_j|⟩", figsize=(10,5)):
    """Draw one bar per feature showing its gradient-sensitivity value.

    ``names`` supplies the tick labels and ``grads`` the bar heights, both in
    the same order. The figure is shown via ``plt.show()``.
    """
    positions = np.arange(len(names))

    plt.figure(figsize=figsize)
    plt.bar(positions, grads)
    plt.xticks(positions, names, rotation=45, ha="right")
    plt.ylabel("⟨|∂ log p/∂x_j|⟩")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def mean_var_shift_plot(mean_shift, var_shift, out_names=None, title="Counterfactual Shift (Δx_j)", figsize=(10,6)):
    """Plot the mean shift and the variance shift per output dimension.

    ``mean_shift`` and ``var_shift`` are tensors (they must support
    ``.numel()`` / ``.numpy()``). ``out_names`` is used for the x tick labels
    only when it has exactly one entry per output dimension. The figure is
    shown via ``plt.show()``.
    """
    n_out = mean_shift.numel()
    positions = np.arange(n_out)

    plt.figure(figsize=figsize)
    curves = (
        (mean_shift, "o", "-", "Δ mean(Y)"),
        (var_shift, "s", "--", "Δ var(Y)"),
    )
    for series, marker, style, label in curves:
        plt.plot(positions, series.numpy(), marker=marker, linestyle=style, label=label)

    has_matching_names = out_names is not None and len(out_names) == n_out
    if has_matching_names:
        plt.xticks(positions, out_names, rotation=45, ha="right")
    plt.legend()
    plt.xlabel("Output-Dimension")
    plt.title(title)
    plt.tight_layout()
    plt.show()

In [4]:
def _load_run_config(system_path, default_filename, run_date):
    """Parse ``-cf/--config_filename`` from ``sys.argv`` and load that YAML file.

    The file is read from ``{system_path}/conf/`` and the loaded mapping gets
    a ``'RUN_DATE'`` entry set to ``run_date``.
    """
    parser = argparse.ArgumentParser(description='Start gaNdalF')
    parser.add_argument(
        '--config_filename',
        "-cf",
        type=str,
        nargs=1,
        required=False,
        default=default_filename,
        help='Name of config file. If not given, %(default)s will be used'
    )
    # NOTE(review): parse_args() exits on unrecognized arguments. If this is
    # run where sys.argv carries foreign options (e.g. inside a notebook
    # kernel), parse_known_args() may be the better fit — confirm.
    args = parser.parse_args()
    filename = args.config_filename
    # nargs=1 yields a one-element list when the option is given on the
    # command line, while the default stays a plain string.
    if isinstance(filename, list):
        filename = filename[0]
    config_path = f"{system_path}/conf/{filename}"
    with open(config_path, 'r') as fp:
        print(f"open {config_path}")
        config = yaml.safe_load(fp)
    config['RUN_DATE'] = run_date
    return config


def load_config_and_parser(system_path):
    """Load the classifier and flow run configurations for the current OS.

    The default config file names are chosen from the value returned by
    ``get_os()`` (``"Mac"`` or ``"Linux"``). Both configurations are parsed
    from the same ``sys.argv``, so a ``-cf/--config_filename`` override
    applies to the classifier *and* the flow configuration.

    Args:
        system_path: Directory that contains the ``conf`` folder.

    Returns:
        ``(config_classifier, config_flow)``: the two loaded YAML mappings,
        each with a ``'RUN_DATE'`` entry formatted as ``%Y-%m-%d_%H-%M``.

    Raises:
        SystemExit: With status 1 when ``get_os()`` returns neither
            ``"Mac"`` nor ``"Linux"``.
    """
    run_date = datetime.now().strftime('%Y-%m-%d_%H-%M')
    os_name = get_os()
    if os_name == "Mac":
        print("load MAC config-file")
        config_file_name_classifier = "MAC_run_classifier.cfg"
        config_file_name_flow = "MAC_run_flow.cfg"
    elif os_name == "Linux":
        print("load LMU config-file")
        config_file_name_classifier = "LMU_run_classifier.cfg"
        config_file_name_flow = "LMU_run_flow.cfg"
    else:
        print("Undefined operating system")
        # Non-zero status: this is an error path, not a clean exit.
        sys.exit(1)

    config_classifier = _load_run_config(system_path, config_file_name_classifier, run_date)
    config_flow = _load_run_config(system_path, config_file_name_flow, run_date)
    return config_classifier, config_flow

In [10]:
from Handler import fnn, get_os, unsheared_shear_cuts, unsheared_mag_cut, LoggerHandler, plot_features, plot_binning_statistics_combined, plot_balrog_histogram_with_error, plot_compare_corner, calc_color
from gandalf import gaNdalF

# Print DataFrames without truncating columns or rows.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# Config files are looked up relative to the first entry of sys.path.
classifier_cfg, flow_cfg = load_config_and_parser(system_path=os.path.abspath(sys.path[0]))

# Map the configured level name to a logging constant; anything else means INFO.
log_lvl = {"DEBUG": logging.DEBUG, "ERROR": logging.ERROR}.get(flow_cfg["LOGGING_LEVEL"], logging.INFO)

run_flow_logger = LoggerHandler(
    logger_dict={
        "logger_name": "train flow logger",
        "info_logger": flow_cfg['INFO_LOGGER'],
        "error_logger": flow_cfg['ERROR_LOGGER'],
        "debug_logger": flow_cfg['DEBUG_LOGGER'],
        "stream_logger": flow_cfg['STREAM_LOGGER'],
        "stream_logging_level": log_lvl,
    },
    log_folder_path=f"{flow_cfg['PATH_OUTPUT']}/"
)

flow_model = gaNdalF(run_flow_logger, classifier_cfg=classifier_cfg, flow_cfg=flow_cfg)
# df_gandalf, df_balrog = flow_model.run_flow()

# Build the evaluation tensors from the model's test dataset.
# NOTE(review): presumably a pandas DataFrame — confirm against gaNdalF.
df_balrog = flow_model.galaxies.test_dataset
col_in, col_out = flow_cfg["INPUT_COLS"], flow_cfg["OUTPUT_COLS"]
# Use the dtype of the model's first parameter for both tensors.
X_valid = torch.tensor(df_balrog[col_in].values, dtype=next(flow_model.parameters()).dtype)
Y_valid = torch.tensor(df_balrog[col_out].values, dtype=X_valid.dtype)
names_in, names_out = list(col_in), list(col_out)

load MAC config-file


ValueError: I/O operation on closed file

In [None]:
# --- 1) Permutation importance ---
base_nll, deltas = perm_importance(flow_model, Y_valid, X_valid, names_in, repeats=PERM_REPEATS, bs=BATCH_SIZE_EVAL, use_amp=USE_AMP, device=DEVICE)
topk, all_rows = topk_by_delta(deltas, k=min(TOPK, len(names_in)))

print(f"Base NLL: {base_nll:.6f}")
# all_rows is already sorted by the exact mean ΔNLL (descending), so it is
# formatted directly instead of re-sorting on the rounded strings.
rows = [(name, f"{mean:.6f}", f"{std:.6f}") for name, mean, std in all_rows]
print_table(rows, headers=["Feature", "ΔNLL (mean)", "ΔNLL (std)"])

# Plot (all features)
barplot_perm_importance(deltas, title="Permutation Importance (ΔNLL)")

In [None]:
# --- 2) Gradient sensitivity ---
grads = context_grad_sensitivity(flow_model, Y_valid, X_valid, bs=2048, device=DEVICE).numpy()
# Rank on the raw values first and format afterwards, so the order does not
# depend on the rounding of the printed strings.
ranked = sorted(zip(names_in, grads), key=lambda pair: pair[1], reverse=True)
rows = [(name, f"{g:.6e}") for name, g in ranked]
print_table(rows, headers=["Feature", "⟨|∂ log p/∂x_j|⟩"])
barplot_grad_sensitivity(names_in, grads, title="Gradient Sensitivity ⟨|∂ log p/∂x_j|⟩")

In [None]:
# --- 3) Counterfactual sampling on the top-K features ---
# Take the feature names of the top-K entries (ranked by ΔNLL).
top_features = [entry[0] for entry in topk]
print("Top-Features (nach ΔNLL):", top_features)

# Work on a copy of at most the first 10000 context rows.
subset = X_valid[:10000].clone()
for name in top_features:
    j = names_in.index(name)
    mean_shift, var_shift = counterfactual_shift(
        flow_model, subset, j, delta=COUNTERFACTUAL_DELTA, device=DEVICE
    )
    print(f"\nFeature: {name} (Δx={COUNTERFACTUAL_DELTA})")
    print("Δ mean(Y):", mean_shift.numpy())
    print("Δ var(Y): ", var_shift.numpy())
    mean_var_shift_plot(
        mean_shift,
        var_shift,
        out_names=names_out,
        title=f"Counterfactual Shift: {name} (Δx={COUNTERFACTUAL_DELTA})",
    )

In [None]:
# --- 4) All-shuffle sanity check ---
# Shuffling every context column should make the NLL go up.
nll_all_shuffle = all_shuffle_nll(
    flow_model, Y_valid, X_valid,
    bs=BATCH_SIZE_EVAL, use_amp=USE_AMP, device=DEVICE,
)
print(
    f"NLL (all-shuffle contexts): {nll_all_shuffle:.6f}"
    f"  |  Base NLL: {base_nll:.6f}"
    f"  |  Δ = {nll_all_shuffle - base_nll:.6f}"
)