
Official implementation for "Revisiting Discriminative vs. Generative Classifiers: Theory and Implications".


ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers


Revisiting Discriminative vs. Generative Classifiers: Theory and Implications

This is the official implementation of "Revisiting Discriminative vs. Generative Classifiers: Theory and Implications".

Dependencies

conda env create -f gen_vs_dis.yaml

Simulation Experiments

python data/generate_data.py
bash scripts/main_simulation.sh
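The simulation experiments compare a discriminative classifier (logistic regression) with a generative one (naive Bayes) as the number of training samples grows. The sketch below only illustrates that kind of comparison and is not the repository's data/generate_data.py or scripts/main_simulation.sh: it assumes Gaussian class-conditional data with identity covariance (so the naive Bayes independence assumption holds), uses scikit-learn's GaussianNB and LogisticRegression, and the helpers sample and learning_curve, as well as the values of dim, sep and n_test, are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB


def sample(means, n, rng):
    """Draw n labelled points from Gaussians with the given class means and
    identity covariance (hypothetical generator: features are independent
    given the class, so the naive Bayes assumption holds)."""
    n_classes, dim = means.shape
    y = rng.permutation(np.arange(n) % n_classes)  # balanced labels, every class present
    x = means[y] + rng.standard_normal((n, dim))
    return x, y


def learning_curve(train_sizes, dim=20, n_classes=2, sep=0.3, n_test=5000, seed=0):
    """Test error of logistic regression (discriminative) and Gaussian naive
    Bayes (generative) for each training-set size in `train_sizes`."""
    rng = np.random.default_rng(seed)
    means = sep * rng.standard_normal((n_classes, dim))  # fixed class means
    x_test, y_test = sample(means, n_test, rng)
    results = []
    for n in train_sizes:
        x_train, y_train = sample(means, n, rng)
        lr = LogisticRegression(C=1.0, max_iter=1000).fit(x_train, y_train)
        nb = GaussianNB().fit(x_train, y_train)
        err_lr = 1.0 - lr.score(x_test, y_test)
        err_nb = 1.0 - nb.score(x_test, y_test)
        results.append((n, err_lr, err_nb))
    return results


for n, err_lr, err_nb in learning_curve([10, 50, 200, 2000]):
    print(f"n={n:5d}  logistic regression error={err_lr:.3f}  naive Bayes error={err_nb:.3f}")
```

Plotting the two error columns against n gives a learning curve for each classifier; the test set is drawn once so both models are scored on the same points.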

Deep Learning Experiments

Sources of pre-trained models

  • ViT: checkpoint given by Google.

    wget https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz
  • ResNet: pre-trained ResNet50 provided by PyTorch.

  • CLIP: pre-trained checkpoint (ResNet50 backbone) provided by OpenAI (link).

  • MoCov2: checkpoint given by FAIR (link).

  • SimCLRv2: checkpoint given by Google (link). The TensorFlow checkpoint can be converted to a PyTorch version with the code in https://github.com/Separius/SimCLRv2-Pytorch.

  • MAE: checkpoint provided by FAIR (link).

  • SimMIM: checkpoint given by MSRA (link).
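Before converting or loading a checkpoint, it can help to check what it contains. The ViT checkpoint downloaded with wget above is a NumPy .npz archive, so a minimal NumPy-only sketch can list its arrays and count parameters. The function names summarize_npz and count_parameters are hypothetical, and no particular key names inside the real checkpoint are assumed.

```python
import numpy as np


def summarize_npz(path):
    """Map each array name stored in an .npz archive to its shape."""
    with np.load(path) as ckpt:  # NpzFile supports `with` and closes the file on exit
        return {name: ckpt[name].shape for name in ckpt.files}


def count_parameters(shapes):
    """Total number of scalar values across all arrays in `shapes`."""
    return sum(int(np.prod(shape)) for shape in shapes.values())
```

For example, passing the path of B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz to summarize_npz and the result to count_parameters shows the array names and the total size before any conversion code is written.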

Extract features on CIFAR10/CIFAR100

For example, to extract CIFAR10 features with the MoCov2 backbone, run:

export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTHONPATH=$PYTHONPATH:$(pwd)
python main_extract_features.py \
  --dataset cifar10 \
  --backbone moco_v2 \
  --bs 100 \
  --gpu 1
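Feature extraction runs a frozen, pre-trained backbone over the dataset in mini-batches (--bs 100 above) and keeps one feature vector per image. The NumPy-only sketch below illustrates that loop; it is not main_extract_features.py, and the stand-in backbone (a fixed random projection followed by a ReLU), the helpers make_random_backbone and extract_features, and the toy inputs are all hypothetical. With a real PyTorch model, the backbone call would typically be wrapped in torch.no_grad().

```python
import numpy as np


def make_random_backbone(in_dim, out_dim, seed=0):
    """Hypothetical stand-in for a frozen pre-trained network: a fixed random
    projection followed by a ReLU. Its weights are never updated."""
    w = np.random.default_rng(seed).standard_normal((in_dim, out_dim)) / np.sqrt(in_dim)
    return lambda x: np.maximum(x.reshape(len(x), -1) @ w, 0.0)


def extract_features(images, labels, backbone, batch_size=100):
    """Run `backbone` over `images` in mini-batches and stack the outputs,
    returning one feature vector per image together with the labels."""
    chunks = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]  # the last batch may be smaller
        chunks.append(backbone(batch))
    return np.concatenate(chunks, axis=0), np.asarray(labels)


# Toy CIFAR-sized inputs: 250 RGB images of 32x32 pixels with 10 class labels.
rng = np.random.default_rng(0)
images = rng.random((250, 3, 32, 32))
labels = rng.integers(0, 10, size=250)
backbone = make_random_backbone(3 * 32 * 32, 512)
features, y = extract_features(images, labels, backbone, batch_size=100)
print(features.shape, y.shape)
```

Because the backbone is frozen, batching only bounds memory use: the stacked features are the same as a single pass over all images.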

Analyze features

export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTHONPATH=$PYTHONPATH:$(pwd)
backbone_list=("clip" "resnet" "vit" "moco_v2" "simclr_v2" "mae" "simmim")
for mode in sigmas kl var_likelihood_diff; do
  for backbone in "${backbone_list[@]}"; do
    python plot.py \
      --dataset cifar10 \
      --backbone "$backbone" \
      --mode "$mode"
  done
done
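plot.py is run with three --mode values (sigmas, kl, var_likelihood_diff), whose exact definitions are not given in this README. As an assumption, the sketch below computes statistics of the kind a Gaussian naive Bayes model depends on: per-class, per-dimension standard deviations of the extracted features, and the closed-form KL divergence between diagonal Gaussians fitted to different classes. It is not the repository's implementation, and the names class_stats, kl_diag_gaussians and mean_pairwise_kl are hypothetical.

```python
import numpy as np


def class_stats(features, labels, eps=1e-6):
    """Per-class mean and standard deviation of every feature dimension.
    `eps` keeps the standard deviations strictly positive."""
    classes = np.unique(labels)
    means = np.stack([features[labels == c].mean(axis=0) for c in classes])
    stds = np.stack([features[labels == c].std(axis=0) for c in classes]) + eps
    return classes, means, stds


def kl_diag_gaussians(mu0, s0, mu1, s1):
    """KL(N(mu0, diag(s0**2)) || N(mu1, diag(s1**2))), summed over dimensions."""
    return float(np.sum(np.log(s1 / s0) + (s0**2 + (mu0 - mu1) ** 2) / (2 * s1**2) - 0.5))


def mean_pairwise_kl(means, stds):
    """Average KL divergence over all ordered pairs of distinct classes."""
    k = len(means)
    kls = [kl_diag_gaussians(means[i], stds[i], means[j], stds[j])
           for i in range(k) for j in range(k) if i != j]
    return float(np.mean(kls))
```

The per-dimension formula is the standard KL divergence between univariate Gaussians; summing over dimensions is valid because both distributions have diagonal covariance.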

Compare logistic regression and naive Bayes on the extracted features

For example, to compare the two classifiers on CIFAR10 features extracted with MoCov2, run:

export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTHONPATH=$PYTHONPATH:`pwd`
python main_train_offline.py \
  --dataset cifar10 \
  --backbone moco_v2 \
  --model lr_bgfs \
  --C 1 \
  --repeat 5 \
  --minmax
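The flags above appear to select L-BFGS logistic regression (--model lr_bgfs) with inverse regularization strength --C 1, results averaged over --repeat 5 runs, and min-max feature scaling (--minmax); these readings are assumptions. The scikit-learn sketch below illustrates such a comparison and is not main_train_offline.py: the random stratified train/test splits, the helper compare_lr_nb, and the synthetic stand-in features from make_classification are hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import MinMaxScaler


def compare_lr_nb(features, labels, C=1.0, repeat=5, minmax=True, test_size=0.2):
    """Mean and standard deviation of test accuracy for logistic regression
    (L-BFGS) and Gaussian naive Bayes over `repeat` random splits."""
    scores = {"logistic_regression": [], "naive_bayes": []}
    for seed in range(repeat):
        x_tr, x_te, y_tr, y_te = train_test_split(
            features, labels, test_size=test_size, random_state=seed, stratify=labels)
        if minmax:
            scaler = MinMaxScaler().fit(x_tr)  # fit on the training split only
            x_tr, x_te = scaler.transform(x_tr), scaler.transform(x_te)
        lr = LogisticRegression(C=C, solver="lbfgs", max_iter=1000).fit(x_tr, y_tr)
        nb = GaussianNB().fit(x_tr, y_tr)
        scores["logistic_regression"].append(lr.score(x_te, y_te))
        scores["naive_bayes"].append(nb.score(x_te, y_te))
    return {name: (float(np.mean(s)), float(np.std(s))) for name, s in scores.items()}


# Stand-in for extracted features: 600 samples, 32 dimensions, 3 classes.
X, y = make_classification(n_samples=600, n_features=32, n_informative=8,
                           n_classes=3, random_state=0)
results = compare_lr_nb(X, y, C=1.0, repeat=5)
for name, (mean, std) in results.items():
    print(f"{name}: accuracy {mean:.3f} +/- {std:.3f}")
```

Fitting the scaler on the training split only keeps test statistics out of preprocessing; both classifiers see exactly the same scaled data in each repeat.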

Hyperparameter Configuration

Detailed hyperparameter configurations can be found in scripts/main_plot.sh.

Acknowledgments

The code is built on the following repositories. We thank their authors for their excellent implementations.

Method                            Repository
ViT                               https://github.com/google-research/vision_transformer
ResNet                            https://github.com/pytorch/pytorch
CLIP                              https://github.com/openai/CLIP
MoCo_v2                           https://github.com/facebookresearch/moco
SimCLR_v2                         https://github.com/google-research/simclr
SimCLR_v2                         https://github.com/Separius/SimCLRv2-Pytorch
MAE                               https://github.com/facebookresearch/mae
SimMIM                            https://github.com/microsoft/SimMIM
logistic regression, naive Bayes  https://github.com/scikit-learn/scikit-learn
