hankook/CLEL

Guiding Energy-based Models via Contrastive Latent Variables

PyTorch implementation of "Guiding Energy-based Models via Contrastive Latent Variables" (accepted as a Spotlight presentation at ICLR 2023).

TL;DR: A simple yet effective framework for improving energy-based models (EBMs) via contrastive representation learning.

Install

conda create -n ebm python=3.9
conda activate ebm
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
conda install torchmetrics -c conda-forge
conda install ignite -c pytorch-nightly
pip install omegaconf
pip install torch-fidelity
pip install kornia==0.6.3
pip install tensorboard
pip install scikit-learn

Training

export CUDA_VISIBLE_DEVICES=0
python train.py configs/cifar10.yaml

You can modify options using YAML config files or key=value command-line arguments. See utils.parse_config() and OmegaConf for details.
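As a rough illustration of how dotted key=value arguments such as model.beta=0.1 map onto a nested config, here is a minimal standard-library sketch. It is not the repository's utils.parse_config(): the helpers parse_value and apply_overrides and the base dictionary are hypothetical, and the value parsing is deliberately simpler than OmegaConf's. In OmegaConf itself, OmegaConf.from_dotlist and OmegaConf.merge provide comparable behavior.

```python
from copy import deepcopy


def parse_value(text):
    """Convert a command-line string to bool, int, or float when possible."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text  # fall back to the raw string


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted key=value overrides applied."""
    merged = deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = merged
        for name in parents:
            # Create intermediate sections that the base config lacks.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return merged


# Hypothetical base config standing in for a loaded YAML file;
# these are not the repository's actual defaults.
base = {"use_ema": False, "model": {"beta": 1.0, "ebm_augmentation": "default"}}
overrides = ["use_ema=true", "model.beta=0.1", "ood_data.name=svhn"]
config = apply_overrides(base, overrides)
print(config)
```

Later overrides win over earlier values, and keys that are absent from the base config are created on the fly, which is why the command line can introduce a section such as ood_data.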

Generation

python test_fid.py logs/cifar10/resnet_resnet18/ours/config.yaml use_ema=true

This command saves 50k generated samples to samples.pth in the log directory. You can use this file for the official PyTorch FID evaluation. Note that the FID value obtained from our code is similar to that from the official evaluation.

Out-of-distribution Detection

python test_ood.py logs/cifar10/resnet_resnet18/ours/config.yaml use_ema=true \
    ood_data.name=svhn ood_data.root=/data \
    model.beta=0.1 model.ebm_augmentation=none
