Skip to content
Switch branches/tags
Go to file


Failed to load latest commit information.
Latest commit message
Commit time

This is the official code for the paper Hybrid Discriminative-Generative Training via Contrastive Learning.

This is code is built upon the official code of the great work JEM.

Contrastive learning and supervised learning have both seen significant progress and success. However, thus far they have largely been treated as two separate objectives, brought together only by having a shared neural network. In this paper we show that through the perspective of hybrid discriminative-generative training of energy-based models we can make a direct connection between contrastive learning and supervised learning. Beyond presenting this unified view, we show our specific choice of approximation of the energy-based loss outperforms the existing practice in terms of classification accuracy of WideResNet on CIFAR-10 and CIFAR-100. It also leads to improved performance on robustness, out-of-distribution detection, and calibration.



Includes codes and scripts for training our method UDGE and baselines using WideResNet.

For any questions/issues please contact


The experiment environemnt is provided in this conda env.


To train a model on CIFAR10 as in the paper

# Baseline model: JEM: log q(y|x) + log q(x)
python --lr .0001 --dataset cifar10 --optimizer adam --pyxce 1.0 --pxsgld 1.0 --sigma .03 --width 10 --depth 28 --warmup_iters 1000 --log_dir ./save --id YOUR_EXP_ID

# UDGE: log q(y|x) + log q(x|y)
python --lr .0001 --dataset cifar10 --optimizer adam --pyxce 1.0 --pxycontrast 1.0 --sigma .03 --width 10 --depth 28 --warmup_iters 1000 --log_dir ./save --id YOUR_EXP_ID

# UDGE + marginal likelihood: log q(y|x) + log q(x|y) + log q(x) (for generative tasks)
python --lr .0001 --dataset cifar10 --optimizer adam --pyxce 1.0 --pxycontrast 1.0 --pxsgld 1.0 --sigma .03 --width 10 --depth 28 --warmup_iters 1000 --log_dir ./save --id YOUR_EXP_ID


The experiments of the paper basically followed the setup of JEM, where a couple changes were made to accommodate energy-based model training. Compared with the original training setup, the changes are no batch normalization, Adam as optimizer instead of SGD with momentum, decay epochs difference, warmu learning rate, and decay rate difference. These changes can downgrade the classification accuracy, if one is more interested in accuracy, set the following configs

python --dataset cifar10 --pyxce 1.0 --pxycontrast 1.0 --width 10 --depth 28 --warmup_iters -1 --lr .1 --norm batch --optimizer sgd --batch_size 128 --dropout_rate .3 --weight_decay .0005 --log_dir ./save --id YOUR_EXP_ID

Details about experiment setup can be found in experiment section of our paper, where for each experiment we detailed whether batch normalization is used or not.

You can use the viskit from this open souce code to track the experiment progress.

Stay tuned for pretrained model and distributed code that scales better.


To generate a histogram of calibration error

python --load_path /PATH/TO/YOUR/ --eval eval_ece --dataset cifar10_test --log_dir ./save/evaluation/ece_hist

To generate a histogram of OOD scores

python --load_path /PATH/TO/YOUR/ --eval logp_hist --datasets cifar10 svhn --log_dir ./save/evaluation/ood_hist

To evaluate the classifier accuracy (on CIFAR10):

python --load_path /PATH/TO/YOUR/ --eval test_clf --dataset cifar_test --log_dir ./save/evaluation/cls

To do OOD detection (on CIFAR100)

python --load_path /PATH/TO/YOUR/ --eval OOD --ood_dataset cifar100 --log_dir ./save/evaluation/ood

To generate new unconditional samples

python --load_path /PATH/TO/YOUR/ --eval uncond_samples --log_dir ./save/evaluation/uncondsample --n_sample_steps {THE_MORE_THE_BETTER (1000 minimum)} --buffer_size 10000 --n_steps 40 --print_every 100 --reinit_freq 0.05

To generate new conditional samples

python --load_path /PATH/TO/YOUR/ --eval cond_samples --log_dir ./save/evaluation/condsample --n_sample_steps {THE_MORE_THE_BETTER (1000 minimum)} --buffer_size 10000 --n_steps 40 --print_every 10 --reinit_freq 0.05 --fresh_samples


This code is based on this open source implementation. The UDGE model is adapted from this open source code.


No description, website, or topics provided.




No releases published


No packages published