hse-cs/probaforms

Welcome to probaforms

License: MIT

Probaforms is a Python library of conditional generative models for tabular data, including Generative Adversarial Networks, Normalizing Flows, and Variational Autoencoders. All models share a scikit-learn-like interface, which makes them quick to apply in a variety of science and engineering applications.
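To make the "scikit-learn-like" contract concrete, here is a minimal sketch of the `fit(X, C)` / `sample(C)` pattern used in Basic usage. The class `ConditionalGaussian` is hypothetical and is not part of probaforms: it simply fits one multivariate Gaussian per discrete condition value, assuming `C` is a single column of class labels. Real probaforms models replace the Gaussian with a GAN, flow, or VAE, but the calling convention shown here is the assumption being illustrated.

```python
import numpy as np


class ConditionalGaussian:
    """Hypothetical toy model (NOT part of probaforms).

    Illustrates the fit(X, C) / sample(C) pattern by fitting one
    multivariate Gaussian per discrete condition value.
    Assumes C holds a single column of class labels.
    """

    def __init__(self, random_state=None):
        self.rng = np.random.default_rng(random_state)

    def fit(self, X, C):
        X = np.asarray(X, dtype=float)
        c = np.asarray(C).ravel()
        self.n_features_ = X.shape[1]
        self.params_ = {}
        for value in np.unique(c):
            Xk = X[c == value]
            # mean vector and full covariance of the objects with this condition
            cov = np.atleast_2d(np.cov(Xk, rowvar=False))
            self.params_[value] = (Xk.mean(axis=0), cov)
        return self  # scikit-learn convention: fit returns the estimator

    def sample(self, C):
        c = np.asarray(C).ravel()
        X_gen = np.empty((len(c), self.n_features_))
        for value in np.unique(c):
            mask = c == value
            mean, cov = self.params_[value]  # KeyError for unseen conditions
            X_gen[mask] = self.rng.multivariate_normal(mean, cov, size=int(mask.sum()))
        return X_gen


# usage: two classes, class 1 shifted by +3 in both coordinates
rng = np.random.default_rng(0)
C = rng.integers(0, 2, size=(500, 1))
X = rng.normal(size=(500, 2)) + 3.0 * C
model = ConditionalGaussian(random_state=0).fit(X, C)
X_gen = model.sample(C)  # one generated object per row of C
```

As in the probaforms example, the condition array passed to `sample` decides how many objects are generated and which condition each one follows.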

Implemented conditional models

  • Variational Autoencoder (CVAE)
  • Wasserstein GAN (WGAN)
  • Real NVP
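Real NVP is built from affine coupling layers, which are invertible by construction and have a cheap log-determinant. The sketch below shows one conditional coupling layer in plain NumPy. It is a generic illustration of the technique, not probaforms' implementation: the scale and shift functions `s` and `t` are hypothetical single-layer maps with random weights standing in for trained neural networks, and the condition `c` is simply concatenated to the untouched half of the input.

```python
import numpy as np


class AffineCoupling:
    """Toy conditional affine coupling layer in the style of Real NVP.

    s(.) and t(.) are hypothetical random linear maps standing in for the
    trained networks of a real flow (NOT the probaforms implementation).
    """

    def __init__(self, dim, cond_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.d = dim // 2  # number of features passed through unchanged
        in_dim, out_dim = self.d + cond_dim, dim - self.d
        self.Ws = rng.normal(scale=0.5, size=(in_dim, out_dim))
        self.Wt = rng.normal(scale=0.5, size=(in_dim, out_dim))

    def _s_t(self, x1, c):
        h = np.concatenate([x1, c], axis=1)
        # tanh bounds the log-scale, a common stabilisation choice
        return np.tanh(h @ self.Ws), h @ self.Wt

    def forward(self, x, c):
        x1, x2 = x[:, :self.d], x[:, self.d:]
        s, t = self._s_t(x1, c)
        y2 = x2 * np.exp(s) + t
        log_det = s.sum(axis=1)  # log|det J| is the sum of log-scales
        return np.concatenate([x1, y2], axis=1), log_det

    def inverse(self, y, c):
        y1, y2 = y[:, :self.d], y[:, self.d:]
        s, t = self._s_t(y1, c)  # y1 == x1, so s and t are recomputed exactly
        x2 = (y2 - t) * np.exp(-s)
        return np.concatenate([y1, x2], axis=1)


# usage: invert a forward pass on a small conditional batch
rng = np.random.default_rng(1)
x = rng.normal(size=(4, 2))
c = rng.integers(0, 2, size=(4, 1)).astype(float)
layer = AffineCoupling(dim=2, cond_dim=1)
y, log_det = layer.forward(x, c)
x_rec = layer.inverse(y, c)
```

A full flow stacks several such layers, alternating which half of the features is transformed, so that every feature is eventually updated.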

Installation

pip install probaforms

or

git clone https://github.com/HSE-LAMBDA/probaforms.git
cd probaforms
pip install -e .

or

poetry install

Basic usage

(See more examples in the documentation.)

The following code snippet generates noisy synthetic data, fits a conditional generative model, samples new objects, and displays the results.

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from probaforms.models import RealNVP

# generate sample X with conditions C
X, y = make_moons(n_samples=1000, noise=0.1)
C = y.reshape(-1, 1)

# fit a conditional normalizing flow model
model = RealNVP(lr=0.01, n_epochs=100)
model.fit(X, C)

# sample new objects
X_gen = model.sample(C)

# display the results
plt.scatter(X_gen[y==0, 0], X_gen[y==0, 1])
plt.scatter(X_gen[y==1, 0], X_gen[y==1, 1])
plt.show()

Support

Thanks to all our contributors