
Exploring generative modeling techniques with JAX on synthetic 2D datasets using EBM, DSM, VAE, and GAN.


coderconroy/generative-modeling-2d


Generative Models on Synthetic Datasets

Overview

This repository implements four generative models and trains them on several two-dimensional synthetic datasets. The implemented models are:

  • Energy-Based Models (EBM)
  • Denoising Score Matching (DSM)
  • Variational Autoencoder (VAE)
  • Generative Adversarial Network (GAN)

Each model is trained on the following synthetic datasets and then used to generate new samples from the distribution it has learned:

  • Checkerboard
  • Gaussian Mixture
  • Spiral
  • Pinwheel
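
The README names the datasets but does not say how they are generated. As a rough illustration only, the JAX sketch below draws points from plausible versions of each one. Every function name and parameter in it (number of mixture modes, circle radius, spiral turns, noise scales, number of pinwheel arms) is a hypothetical assumption, not the repository's own data code.

```python
import jax
import jax.numpy as jnp

# Hypothetical samplers for the four datasets; all shapes and parameters are
# illustrative assumptions, not the repository's data-generation code.


def sample_checkerboard(key, n):
    """Uniform points on the 'dark' cells of a 4x4 board covering [-2, 2]^2."""
    ki, kj, ku = jax.random.split(key, 3)
    i = jax.random.randint(ki, (n,), 0, 4)               # column index 0..3
    j = 2 * jax.random.randint(kj, (n,), 0, 2) + i % 2   # row index with (i + j) even
    offsets = jax.random.uniform(ku, (n, 2))             # position inside the cell
    return jnp.stack([i, j], axis=1).astype(jnp.float32) + offsets - 2.0


def sample_gaussian_mixture(key, n, num_modes=8, radius=2.0, std=0.1):
    """Equal-weight isotropic Gaussians centred on a circle."""
    kc, kn = jax.random.split(key)
    comp = jax.random.randint(kc, (n,), 0, num_modes)
    angles = 2 * jnp.pi * comp / num_modes
    centers = radius * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)
    return centers + std * jax.random.normal(kn, (n, 2))


def sample_spiral(key, n, turns=1.5, noise=0.05):
    """One Archimedean spiral arm (radius grows linearly with angle) plus noise."""
    kt, kn = jax.random.split(key)
    max_angle = 2 * jnp.pi * turns
    # sqrt(u) puts more points at larger angles, roughly evening out density along the arm.
    t = jnp.sqrt(jax.random.uniform(kt, (n,))) * max_angle
    r = 2.0 * t / max_angle
    pts = jnp.stack([r * jnp.cos(t), r * jnp.sin(t)], axis=1)
    return pts + noise * jax.random.normal(kn, (n, 2))


def sample_pinwheel(key, n, num_arms=5, radial_std=0.3, tangential_std=0.1, rate=0.25):
    """Elongated Gaussian blobs, each rotated onto its arm and twisted with radius."""
    kf, ka = jax.random.split(key)
    feats = jax.random.normal(kf, (n, 2)) * jnp.array([radial_std, tangential_std])
    feats = feats.at[:, 0].add(1.0)                      # shift blobs away from the origin
    arm = jax.random.randint(ka, (n,), 0, num_arms)
    # Base angle of the arm plus a twist that increases with the radial coordinate.
    angles = 2 * jnp.pi * arm / num_arms + rate * jnp.exp(feats[:, 0])
    c, s = jnp.cos(angles), jnp.sin(angles)
    x = feats[:, 0] * c - feats[:, 1] * s
    y = feats[:, 0] * s + feats[:, 1] * c
    return 2.0 * jnp.stack([x, y], axis=1)
```

Each sampler takes a PRNG key and a sample count and returns an `(n, 2)` array, for example `sample_pinwheel(jax.random.PRNGKey(0), 1000)`. Passing explicit keys keeps the data reproducible, which is the idiomatic way to handle randomness in JAX.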

The project explores how well each model generates new data points that reproduce the structure of the original datasets.

Models Description

Energy-Based Models (EBM)

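EBMs learn an energy function E(x) that assigns low energy to points near the data distribution and high energy elsewhere, which defines an unnormalized density p(x) ∝ exp(−E(x)). Because the normalizing constant is intractable, the gradient of the log-likelihood contains an expectation under the model's own distribution. Training approximates that term with samples drawn from the model, typically by a Markov chain Monte Carlo procedure such as Langevin dynamics.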
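
The gradient of the negative log-likelihood is ∇θE(x_data) − E_{x′∼p_θ}[∇θE(x′)], so a loss of "mean energy on data minus mean energy on model samples" has the right gradient once the samples are held fixed. The sketch below assumes unadjusted Langevin dynamics as the sampler and a plain-JAX MLP for E(x); the repository's actual sampler, architecture and step sizes are not described, so `init_mlp`, `energy`, `langevin_sample` and `ebm_loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: list of (W, b) pairs for a plain MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def energy(params, x):
    """Scalar energy E_theta(x) for a single 2D point x of shape (2,)."""
    h = x
    for w, b in params[:-1]:
        h = jax.nn.silu(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]


def langevin_sample(params, key, x, n_steps=50, step_size=0.01):
    """Unadjusted Langevin: x <- x - (eps / 2) * grad_x E(x) + sqrt(eps) * noise."""
    grad_e = jax.vmap(jax.grad(energy, argnums=1), in_axes=(None, 0))

    def step(x, k):
        noise = jax.random.normal(k, x.shape)
        x = x - 0.5 * step_size * grad_e(params, x) + jnp.sqrt(step_size) * noise
        return x, None

    x, _ = jax.lax.scan(step, x, jax.random.split(key, n_steps))
    return x


def ebm_loss(params, x_data, x_model):
    """Surrogate loss whose parameter gradient matches the NLL gradient."""
    e = jax.vmap(energy, in_axes=(None, 0))
    x_model = jax.lax.stop_gradient(x_model)  # samples are treated as constants
    return e(params, x_data).mean() - e(params, x_model).mean()
```

A training step would start chains from noise (or a persistent buffer), run `langevin_sample`, then differentiate `ebm_loss` with `jax.value_and_grad`. Pushing data energy down and sample energy up is what shapes the low-energy regions around the dataset.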

Denoising Score Matching (DSM)

DSM trains a model to estimate the score, that is, the gradient of the log density with respect to the data. Instead of matching the unknown score of the data distribution directly, it perturbs each training point with Gaussian noise and trains the model to predict the score of that corruption, which points from the noisy sample back toward the clean one. The DSM implementation uses JAX for automatic differentiation and GPU acceleration, Flax and, as an alternative approach, Haiku for defining the neural networks, and Optax for gradient-based optimization.
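
For Gaussian corruption x̃ = x + σε, the conditional score is ∇ log q(x̃ | x) = −(x̃ − x)/σ² = −ε/σ, so DSM reduces to regressing s_θ(x̃) onto −ε/σ. The sketch below assumes a single noise level σ and a plain-JAX MLP with a hand-written SGD update standing in for the Flax/Haiku models and Optax optimizer the repository uses; all names are hypothetical. Many score-based models instead train over several noise levels.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: list of (W, b) pairs for a plain MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def score(params, x):
    """s_theta(x): a 2D vector field estimating grad_x log p(x); x is (batch, 2)."""
    h = x
    for w, b in params[:-1]:
        h = jax.nn.silu(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def dsm_loss(params, key, x, sigma=0.1):
    """Single-noise-level DSM: regress s_theta(x + sigma * eps) onto -eps / sigma."""
    eps = jax.random.normal(key, x.shape)
    x_noisy = x + sigma * eps
    target = -eps / sigma  # score of q(x_noisy | x) for Gaussian corruption
    residual = score(params, x_noisy) - target
    return 0.5 * jnp.mean(jnp.sum(residual ** 2, axis=-1))


@jax.jit
def sgd_step(params, key, x, lr=1e-3):
    """One plain SGD step (a stand-in for an Optax optimizer update)."""
    loss, grads = jax.value_and_grad(dsm_loss)(params, key, x)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

Once trained, the learned score can drive the same kind of Langevin update used for EBMs, with s_θ(x) taking the place of −∇E(x).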

Variational Autoencoder (VAE)

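VAEs are latent-variable models: an encoder maps each data point to an approximate posterior q(z|x) over latent variables, and a decoder defines the likelihood p(x|z). Because the exact log-likelihood is intractable, they are trained to maximize the evidence lower bound (ELBO), which trades off reconstruction quality against the KL divergence between q(z|x) and the prior p(z).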
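
The ELBO is E_q[log p(x|z)] − KL(q(z|x) ‖ p(z)). The sketch below assumes a diagonal-Gaussian encoder, a standard-normal prior, a Gaussian decoder with a fixed observation noise `obs_std`, and a 2D latent space; these choices and all function names are hypothetical, not taken from the repository. The reparameterization trick keeps the sampled z differentiable, and the KL term uses its closed form for diagonal Gaussians.

```python
import jax
import jax.numpy as jnp

LATENT_DIM = 2  # hypothetical latent size


def init_mlp(key, sizes):
    """Hypothetical helper: list of (W, b) pairs for a plain MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def negative_elbo(params, key, x, obs_std=0.1):
    """Batch-mean -ELBO with q(z|x) = N(mu, diag(exp(logvar))), p(x|z) = N(dec(z), obs_std^2 I)."""
    enc, dec = params
    stats = mlp(enc, x)
    mu, logvar = stats[:, :LATENT_DIM], stats[:, LATENT_DIM:]
    # Reparameterization: z = mu + sigma * eps, so gradients flow through mu and logvar.
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps
    x_mean = mlp(dec, z)
    # Gaussian log-likelihood log p(x|z), summed over the two coordinates.
    log_lik = -0.5 * jnp.sum(((x - x_mean) / obs_std) ** 2
                             + 2.0 * jnp.log(obs_std) + jnp.log(2.0 * jnp.pi), axis=-1)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)).
    kl = 0.5 * jnp.sum(jnp.exp(logvar) + mu ** 2 - 1.0 - logvar, axis=-1)
    return jnp.mean(kl - log_lik)


def sample(params, key, n):
    """Generate points by decoding prior samples (returns the decoder mean)."""
    _, dec = params
    z = jax.random.normal(key, (n, LATENT_DIM))
    return mlp(dec, z)
```

Minimizing `negative_elbo` is equivalent to maximizing the ELBO. A smaller `obs_std` weights reconstruction more heavily relative to the KL term, which tends to give sharper but less regular latent spaces.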

Generative Adversarial Network (GAN)

GANs consist of two competing networks: a generator that maps random noise to samples intended to resemble the data distribution, and a discriminator that tries to distinguish generated samples from real data. The two are trained adversarially, with the generator learning to produce samples the discriminator cannot tell apart from real ones.
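
The adversarial objective can be written as two losses computed from discriminator logits. The sketch below uses the standard binary cross-entropy discriminator loss and, as a named swap-in, the common non-saturating generator loss −log D(G(z)) instead of the original minimax form log(1 − D(G(z))), which gives stronger gradients early in training. The noise dimension, network sizes and function names are hypothetical; the repository's actual losses are not described.

```python
import jax
import jax.numpy as jnp

NOISE_DIM = 2  # hypothetical size of the generator's input noise


def init_mlp(key, sizes):
    """Hypothetical helper: list of (W, b) pairs for a plain MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def discriminator_loss(d_params, g_params, key, x_real):
    """-[log D(x) + log(1 - D(G(z)))] with D = sigmoid(logit)."""
    z = jax.random.normal(key, (x_real.shape[0], NOISE_DIM))
    x_fake = jax.lax.stop_gradient(mlp(g_params, z))  # generator is fixed here
    real_logits = mlp(d_params, x_real)[:, 0]
    fake_logits = mlp(d_params, x_fake)[:, 0]
    # log_sigmoid(-l) equals log(1 - sigmoid(l)) but is numerically stable.
    return -jnp.mean(jax.nn.log_sigmoid(real_logits) + jax.nn.log_sigmoid(-fake_logits))


def generator_loss(g_params, d_params, key, batch_size):
    """Non-saturating generator loss: -log D(G(z))."""
    z = jax.random.normal(key, (batch_size, NOISE_DIM))
    fake_logits = mlp(d_params, mlp(g_params, z))[:, 0]
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))
```

Training alternates: one gradient step on `discriminator_loss` with respect to `d_params`, then one on `generator_loss` with respect to `g_params`. Working with logits and `jax.nn.log_sigmoid` avoids taking the log of probabilities that underflow to zero.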
