### Natural Posterior Network

In [None]:
import pytorch_lightning as pl
from src.datasets import DATASET_REGISTRY
from src.models import suppress_pytorch_lightning_logs
from src.models.natpn.lightning import NaturalPosteriorNetwork

# Single seed shared by Lightning's global seeding and the data module below.
seed = 42
# Project helper from src.models; presumably silences Lightning's console
# logging (it runs before seed_everything, which emits a log line) — confirm.
suppress_pytorch_lightning_logs()
# Seed global RNGs before the data module is built so anything random it does
# is reproducible.
pl.seed_everything(seed)
# "blob" dataset from the project registry; the seed is also passed explicitly.
dm = DATASET_REGISTRY["blob"](seed=seed)
dm.prepare_data()
# Only the "test" stage is set up here. NOTE(review): presumably fit() sets up
# the training stages itself — confirm in NaturalPosteriorNetwork.
dm.setup("test")

In [None]:
# Keyword arguments for the PyTorch Lightning ``Trainer`` that the estimator
# builds (presumably forwarded verbatim — confirm in NaturalPosteriorNetwork).
trainer_params = dict(
    # Previously misspelled ``jnable_checkpointing``: Lightning's Trainer has an
    # explicit signature, so an unknown keyword raises TypeError instead of
    # disabling checkpointing.
    enable_checkpointing=False,
    enable_progress_bar=True,
    enable_model_summary=True,
    # fast_dev_run=1,  # uncomment for a quick smoke-test run
    max_epochs=1,
    # NOTE(review): ``gpus`` was removed in Lightning 2.0; on 2.x use
    # ``accelerator="cpu"`` instead.
    gpus=0,
)

# Constructor keyword arguments for NaturalPosteriorNetwork (used in the next cell).
params_dict = dict(
    latent_dim=4,
    encoder="tabular",
    flow="radial",
    flow_num_layers=4,
    residual=True,
    spectral=(False, False, False),
    lipschitz_constant=1,
    entropy_weight=1e-5,
    reconst_weight=0.1,
    evidence_scaler=1.0,
    # pretrained_enc_path="",  # Load a pretrained encoder
    # NOTE(review): 1e-1 is unusually high for AdamW — confirm this is intended.
    learning_rate=1e-1,
    learning_rate_head=1e-4,
    optim="adamw",
    warmup_epochs=1,
    finetune_epochs=1,
    trainer_params=trainer_params,
)

In [None]:
# Build the estimator from the configuration cell and train it on the data module.
estimator = NaturalPosteriorNetwork(**params_dict)
estimator.fit(dm)

# Evaluate the fitted model: first the in-distribution score, then the
# out-of-distribution detection score (both computed by the estimator).
result_id = estimator.score(dm)
result_ood = estimator.score_ood_detection(dm)

# Show in-distribution results first, then OOD-detection results.
for report in (result_id, result_ood):
    print(report)