fifi-research/SESaMo

Paper: arXiv:2505.19619 | License: MIT

SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows

Quick start

Install the package with pip:

pip install sesamo

Here is a quick example of how to use SESaMo to build a normalizing flow with stochastic modulation:

import torch
from sesamo import Sesamo
from sesamo.models import GaussianPrior, RealNVP, Z2Modulation, Z2Regularization
from sesamo.loss import StochmodLoss

# Initialize SESaMo
sesamo = Sesamo(
    prior=GaussianPrior(
        var=1,
        lat_shape=[1,2]
    ),
    flow=RealNVP(
        lat_shape=[1,2],
        num_coupling_layers=10,
        num_hidden_layers=2,
        num_hidden_features=40
    ),
    stochastic_modulation=Z2Modulation(),
    regularization=Z2Regularization(),
).to("cuda" if torch.cuda.is_available() else "cpu")

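# Define the action S(x) of the target distribution p(x) = exp(-S(x)) / Z.
# Placeholder only: replace with your own target. This Z2-symmetric double well
# (S(x) == S(-x)) is an illustration and assumes a leading batch dimension.
def action(x):
    return ((x**2 - 1) ** 2).flatten(start_dim=1).sum(dim=1)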
loss_fn = StochmodLoss()
optimizer = torch.optim.Adam(sesamo.parameters(), lr=5e-4)

# Training loop
for _ in range(10_000):
    # reset gradients
    optimizer.zero_grad()

    # sample from sesamo
    samples, log_prob, log_prob_stochmod, penalty = sesamo.sample_for_training(8_000)
    
    # compute action and loss
    action_samples = action(samples)
    loss = loss_fn(action_samples, log_prob, log_prob_stochmod, penalty).mean()
    
    # backpropagate and update flow parameters
    loss.backward()
    optimizer.step()
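The loop above trains the flow from its own samples, using only the action and the model log-densities. A standard objective for this kind of self-sampling training is the reverse Kullback-Leibler divergence, estimated as the mean of log q(x) + S(x) over flow samples; the exact form of StochmodLoss (and how it combines log_prob, log_prob_stochmod, and penalty) is not documented here, so the sketch below shows only the generic estimator, together with the normalized effective sample size (ESS) that is commonly used to monitor such flows. The helper names reverse_kl_loss and effective_sample_size are hypothetical and not part of the sesamo package.

```python
import torch


def reverse_kl_loss(action_samples, log_q):
    """Monte Carlo estimate of E_q[log q(x) + S(x)] = KL(q || p) - log Z."""
    return (log_q + action_samples).mean()


def effective_sample_size(action_samples, log_q):
    """Normalized ESS of the weights w = exp(-S(x)) / q(x); lies in (0, 1]."""
    log_w = -action_samples - log_q
    # ESS = (sum w)^2 / (N * sum w^2), evaluated in log space for numerical stability
    log_ess = 2 * torch.logsumexp(log_w, dim=0) - torch.logsumexp(2 * log_w, dim=0)
    return torch.exp(log_ess) / log_w.numel()


# Toy check: the target is a standard normal, S(x) = x^2 / 2 and log Z = 0.5 * log(2 * pi).
torch.manual_seed(0)
q = torch.distributions.Normal(0.0, 1.0)  # q == p, i.e. a perfectly trained "flow"
x = q.sample((10_000,))
S = 0.5 * x**2
print(reverse_kl_loss(S, q.log_prob(x)))        # ≈ -0.9189, i.e. -log Z, because KL(q || p) = 0
print(effective_sample_size(S, q.log_prob(x)))  # ≈ 1.0
```

Because log Z does not depend on the flow parameters, minimizing this loss minimizes KL(q || p); an ESS close to 1 indicates that the flow closely matches the target.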

Examples

For more examples, see the SESaMo/examples folder, which contains Jupyter notebooks for the Hubbard model and the Gaussian mixture model.
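The quick start leaves the action to the user. As a hypothetical illustration (not taken from the example notebooks, whose targets and parameters may differ), the action of a Z2-symmetric two-component Gaussian mixture in two dimensions could be written as follows. Normalization constants are dropped because they only shift log Z. The function name, the default mu and sigma, and the input layout (a leading batch dimension followed by lat_shape=[1,2]) are assumptions.

```python
import torch


def gaussian_mixture_action(x, mu=(2.0, 2.0), sigma=0.5):
    """Hypothetical target: S(x) = -log[N(x; mu, sigma^2 I) + N(x; -mu, sigma^2 I)] + const.

    The two components sit at +mu and -mu, so S(x) == S(-x) (Z2 symmetry).
    Assumes a leading batch dimension; all trailing dimensions are flattened.
    """
    x = x.flatten(start_dim=1)
    m = torch.tensor(mu, dtype=x.dtype, device=x.device)
    sq_plus = ((x - m) ** 2).sum(dim=1)   # squared distance to +mu
    sq_minus = ((x + m) ** 2).sum(dim=1)  # squared distance to -mu
    log_terms = torch.stack([-sq_plus, -sq_minus]) / (2 * sigma**2)
    return -torch.logsumexp(log_terms, dim=0)


samples = torch.randn(8, 1, 2)  # stand-in for samples drawn from the flow
print(gaussian_mixture_action(samples).shape)  # torch.Size([8])
```

A function like this can be passed as action in the training loop, where it is evaluated on each batch of flow samples.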

Run experiments

To run the experiments from the paper, follow the instructions below.

Clone the repository and move into the directory:

git clone https://github.com/fifi-research/SESaMo.git
cd SESaMo

Create a Python virtual environment and install the package in editable mode:

python -m venv .venv
source .venv/bin/activate
pip install -e .

Run an experiment with:

cd experiments
python train.py -cp configs/<experiment> -cn <model>

Available <experiment>s are:

hubbard2x1
hubbard18x100
gaussian-mixture
broken-gaussian-mixture
complex-phi4
broken-complex-phi4
broken-scalar-phi4

Available <model>s are:

realnvp
vmonf
canonicalization
sesamo
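To run every combination listed above, a small driver script can loop over the experiment and model names and issue the same train.py command for each pair. The sketch below is hypothetical: it assumes it is launched from the repository root, that every model has a config in every experiment folder (not verified here), and by default it only prints the commands. Set dry_run=False to actually start the runs.

```python
import itertools
import subprocess
import sys

EXPERIMENTS = [
    "hubbard2x1", "hubbard18x100", "gaussian-mixture", "broken-gaussian-mixture",
    "complex-phi4", "broken-complex-phi4", "broken-scalar-phi4",
]
MODELS = ["realnvp", "vmonf", "canonicalization", "sesamo"]


def build_commands(experiments=EXPERIMENTS, models=MODELS):
    """One 'python train.py -cp configs/<experiment> -cn <model>' call per pair."""
    return [
        [sys.executable, "train.py", "-cp", f"configs/{exp}", "-cn", model]
        for exp, model in itertools.product(experiments, models)
    ]


def run_all(commands, dry_run=True):
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # Run inside experiments/ (assumes the cwd is the repo root); stop on the first failure.
            subprocess.run(cmd, check=True, cwd="experiments")


run_all(build_commands())  # prints the 28 commands without running them
```

Using sys.executable instead of a bare python ensures the runs use the interpreter of the active virtual environment.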

The checkpoint, TensorBoard, config, and stats files are stored in the SESaMo/scripts/runs folder. When training completes or is interrupted, the distribution is plotted and saved as SESaMo/scripts/runs/.../samples.png.

Citation

If you use SESaMo in your research, please consider citing our paper:

@article{kreit2025sesamo,
    title={SESaMo: Symmetry-Enforcing Stochastic Modulation for Normalizing Flows}, 
    author={Janik Kreit and Dominic Schuh and Kim A. Nicoli and Lena Funcke},
    year={2025},
    eprint={2505.19619},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2505.19619}, 
}

About

SESaMo extends normalizing flows to enforce symmetries in the output distribution.
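The sketch below illustrates the general idea of enforcing a symmetry by stochastic modulation for the Z2 (sign-flip) case, assuming that modulation means applying a uniformly random group element to each flow sample. It uses plain PyTorch rather than the sesamo package and is not its implementation; the asymmetric Gaussian base is a hypothetical stand-in for a trained flow q(x). Flipping the sign of each sample with probability 1/2 yields the density p(y) = (q(y) + q(-y)) / 2, the average of q over the group, which is symmetric under y → -y whatever q is.

```python
import math

import torch

torch.manual_seed(0)

# Hypothetical stand-in for a trained flow: an asymmetric Gaussian q(x) = N(mu, I).
mu = torch.tensor([1.5, -0.5])
base = torch.distributions.MultivariateNormal(mu, torch.eye(2))


def sample_z2_modulated(n):
    """Draw x ~ q, then apply a random Z2 element: flip the sign with probability 1/2."""
    x = base.sample((n,))
    signs = torch.randint(0, 2, (n, 1)) * 2 - 1  # uniformly +1 or -1
    return signs * x


def log_prob_z2_modulated(y):
    """log p(y) for p(y) = (q(y) + q(-y)) / 2, the group average of q."""
    both = torch.stack([base.log_prob(y), base.log_prob(-y)], dim=0)
    return torch.logsumexp(both, dim=0) - math.log(2.0)


y = sample_z2_modulated(5)
print(log_prob_z2_modulated(y))
print(log_prob_z2_modulated(-y))  # same values: p is Z2-symmetric by construction
```

The symmetry holds exactly by construction, so the flow itself does not have to learn it.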
