
ly1998117/HybridCBM


HybridCBM

Code for the paper "Hybrid Concept Bottleneck Models"

Environment

We run our experiments using Python 3.11. You can install the required packages using:

conda create --name hybridcbm python=3.11
conda activate hybridcbm
pip install -r requirements.txt

Because the linear probe uses cuML, you also need to install the cuML package:

pip install \
    --extra-index-url=https://pypi.nvidia.com \
    "cudf-cu12==24.10.*" "dask-cudf-cu12==24.10.*" "cuml-cu12==24.10.*" \
    "cugraph-cu12==24.10.*" "nx-cugraph-cu12==24.10.*" "cuspatial-cu12==24.10.*" \
    "cuproj-cu12==24.10.*" "cuxfilter-cu12==24.10.*" "cucim-cu12==24.10.*" \
    "pylibraft-cu12==24.10.*" "raft-dask-cu12==24.10.*" "cuvs-cu12==24.10.*"
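To confirm the RAPIDS install before running the linear probe, a quick check along these lines may help. It is a minimal sketch that assumes the import names are cudf and cuml (the usual RAPIDS module names); the helper name rapids_available is hypothetical and not part of this repository.

```python
import importlib.util


def rapids_available(modules=("cudf", "cuml")):
    """Report whether each module can be found on the import path.

    Only locates the package; it does not import it or touch the GPU,
    so it is safe to run on a machine without CUDA.
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in rapids_available().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```

A "found" result only means the wheel is installed; a mismatched CUDA driver can still fail at import time.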

Datasets

The file structure of each dataset folder is:

dataset/
|-- images/
|-- concepts/
|-- splits/

Note: Put all images of each dataset in images/. The concepts/ folder contains a CSV file, concepts.csv, which lists the candidate concepts. splits/ stores the train/val/test splits in CSV format. Download the images for each dataset and store them in the corresponding folder: datasets/{dataset name}/images/.
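Before training, it can be useful to verify that a dataset folder matches the layout above. The sketch below assumes the root folder is datasets/ and that any *.csv file in splits/ is a split file (the actual split file names are not listed here); check_dataset_layout is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def check_dataset_layout(root, name):
    """Return a list of problems with {root}/{name}/ (empty if the layout looks right).

    Expected layout, following the README:
        {root}/{name}/images/                 (all images of the dataset)
        {root}/{name}/concepts/concepts.csv   (candidate concepts)
        {root}/{name}/splits/*.csv            (train/val/test splits)
    """
    base = Path(root) / name
    problems = []

    images = base / "images"
    # Short-circuit: iterdir() is only called when the folder exists.
    if not images.is_dir() or not any(images.iterdir()):
        problems.append(f"missing or empty: {images}")

    concepts = base / "concepts" / "concepts.csv"
    if not concepts.is_file():
        problems.append(f"missing: {concepts}")

    splits = base / "splits"
    if not splits.is_dir() or not list(splits.glob("*.csv")):
        problems.append(f"no CSV split files in: {splits}")

    return problems
```

For example, `print(check_dataset_layout("datasets", "CUB"))` prints an empty list once the CUB folder is complete.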

The download links for each dataset are listed below:

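| Category | Dataset | Description | Download Link |
|---|---|---|---|
| Translator datasets | MSCOCO | Train captions | captions |
| Translator datasets | MSCOCO | Train images | images |
| Translator datasets | MSCOCO | Val images | images |
| Translator datasets | ConceptNet | ConceptNet 5.7 | ConceptNet 5.7 |
| CBM datasets | Aircraft | Aircraft dataset | click to download |
| CBM datasets | CIFAR-10 | CIFAR-10 dataset | click to download |
| CBM datasets | CIFAR-100 | CIFAR-100 dataset | click to download |
| CBM datasets | CUB | CUB dataset | click to download |
| CBM datasets | DTD | DTD dataset | click to download |
| CBM datasets | Flower | Flower dataset | click to download |
| CBM datasets | Food | Food dataset | click to download |
| CBM datasets | HAM10000 | HAM10000 dataset | click to download |
| CBM datasets | ImageNet | ImageNet dataset | click to download |
| CBM datasets | RESISC45 | RESISC45 dataset | click to download |
| CBM datasets | UCF101 | UCF101 dataset | click to download |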

Directories

  • config/ contains the config files for all experiments, including the linear probe (config/LProbe) and HybridCBM (config/HybridCBM). Edit these files to change the experiment arguments.
  • exp/ is the working directory for experiments. Config files and model checkpoints are saved here.
  • models/ contains the models:
    • Linear Probe: models/linearProb.py
    • HybridCBM: models/cbm/linearCBM.py
    • concept selection functions: models/conceptBank/concept_select.py
  • Other files:
    • datasets/preprocess.py preprocesses the datasets.
    • datasets/dataloader.py contains the dataloaders for HybridCBM and the linear probe.
    • trainLinear.py is the entry point for HybridCBM training and testing (the linear probe is run with train_probe.py).
    • train_translator.sh is the bash script that trains the translator and saves the model to weights/translator.

Linear Probe

For example, to run the 1-shot linear probe on CUB with the ViT-L/14 image encoder:

python train_probe.py --config config/LProbe/CUB.py --cfg-options dataset=CUB --cfg-options n_shots=1 --cfg-options clip_model=ViT-L/14 --cfg-options exp_root=exp/LProbe/CUB/ViT-L_14

The code automatically encodes the images and runs a hyperparameter search over the L2 regularization strength on the validation (dev) set. The best validation and test performance are saved to exp/LProbe/CUB/ViT-L_14/train_1shot.log.
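The select-on-validation pattern behind this L2 search can be sketched as follows. The repository does this with cuML; this sketch swaps in closed-form ridge regression onto one-hot labels in NumPy, which is not necessarily the model the repository fits. The function names and the lambda grid are hypothetical, and X_* stand for pre-encoded image features.

```python
import numpy as np


def fit_ridge_probe(X, y, n_classes, lam):
    """Closed-form ridge regression onto one-hot labels: W = (X^T X + lam*I)^-1 X^T Y."""
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    Y = np.eye(n_classes)[y]                        # one-hot targets
    reg = lam * np.eye(X1.shape[1])
    reg[-1, -1] = 0.0                               # do not penalize the bias
    return np.linalg.solve(X1.T @ X1 + reg, X1.T @ Y)


def predict(W, X):
    """Class with the highest linear score."""
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])
    return np.argmax(X1 @ W, axis=1)


def search_l2(X_tr, y_tr, X_val, y_val, n_classes,
              lams=(1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0)):
    """Fit one probe per lambda; keep the one with the best validation accuracy."""
    best = None
    for lam in lams:
        W = fit_ridge_probe(X_tr, y_tr, n_classes, lam)
        acc = float(np.mean(predict(W, X_val) == y_val))
        if best is None or acc > best[1]:
            best = (lam, acc, W)
    return best  # (lambda, validation accuracy, weights)
```

The chosen lambda is then used to report test accuracy, so the test split never influences the search.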

Translator Training

To train the translator, run the following command:

bash train_translator.sh

You may need to edit train_translator.sh to set the dataset and CLIP model to train with. The model is saved to weights/translator. Then set the model path in config/HybridCBM/base.py. For example, if you train the translator with the ViT-L/14 model, set translator_path in config/HybridCBM/base.py to:

translator_path = f'weights/translator/{clip_model.replace("/", "_")}-AUG_True/translator.pt'
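To see what this template resolves to for a given CLIP model, it can be expanded outside the config as a Python f-string. This is a hypothetical standalone helper; the -AUG_True suffix is copied verbatim from the template, and whether other suffixes occur is not documented here.

```python
from pathlib import Path


def translator_path_for(clip_model, root="weights/translator"):
    """Expand the translator_path template for a given CLIP model name."""
    return f'{root}/{clip_model.replace("/", "_")}-AUG_True/translator.pt'


path = translator_path_for("ViT-L/14")
print(path)  # weights/translator/ViT-L_14-AUG_True/translator.pt
if not Path(path).is_file():
    print("checkpoint not found: train it with train_translator.sh or download the pretrained one")
```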

We provide pretrained ViT-L/14 translator checkpoints at this link.

Hybrid Training

To train HybridCBM, run the following command:

python trainLinear.py --config config/HybridCBM/CUB/CUB_allshot.py \
    --cfg-options clip_model=ViT-L/14 \
    --cfg-options concept_select_fn=submodular \
    --cfg-options num_concept_per_class=10 \
    --cfg-options dynamic_concept_ratio=0.5 \
    --cfg-options lambda_discri_alpha=2 \
    --cfg-options lambda_discri_beta=0.1 \
    --cfg-options lambda_ort=0.1 \
    --cfg-options lambda_align=0.01
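When sweeping these hyperparameters, it can be convenient to build the same command from a dictionary of overrides. This is a minimal sketch; build_train_cmd is a hypothetical helper, and the override values simply mirror the command above (one --cfg-options flag per key=value pair, as the README does).

```python
import shlex


def build_train_cmd(config, overrides):
    """Build the trainLinear.py command, one --cfg-options flag per override."""
    cmd = ["python", "trainLinear.py", "--config", config]
    for key, value in overrides.items():
        cmd += ["--cfg-options", f"{key}={value}"]
    return cmd


overrides = {
    "clip_model": "ViT-L/14",
    "concept_select_fn": "submodular",
    "num_concept_per_class": 10,
    "dynamic_concept_ratio": 0.5,
    "lambda_discri_alpha": 2,
    "lambda_discri_beta": 0.1,
    "lambda_ort": 0.1,
    "lambda_align": 0.01,
}
cmd = build_train_cmd("config/HybridCBM/CUB/CUB_allshot.py", overrides)
print(shlex.join(cmd))  # shell-ready string, useful for logging
# To launch: import subprocess; subprocess.run(cmd, check=True)
```

Passing the argument list directly to subprocess avoids shell quoting issues with values such as ViT-L/14.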

After training reaches the maximum number of epochs, the checkpoint with the highest validation accuracy and its config file are saved to exp/HybridCBM/{DATASET}_Zero_L1.

Hybrid Testing

To test a trained HybridCBM model, run the following command:

python trainLinear.py --exp_root <exp dir path> --test

For example, to test HybridCBM on the CUB dataset, run:

python trainLinear.py --exp_root exp/HybridCBM/CUB_Zero_L1 --test
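To test several trained models in one go, the test commands can be generated from the output directory naming described under Hybrid Training. This sketch assumes every run was saved as exp/HybridCBM/{DATASET}_Zero_L1 (other configs may use a different suffix); build_test_cmd is a hypothetical helper, and "Flower" is used only as an illustrative dataset name.

```python
import shlex


def build_test_cmd(dataset, exp_base="exp/HybridCBM"):
    """Test command for a run saved under {exp_base}/{dataset}_Zero_L1."""
    return ["python", "trainLinear.py",
            "--exp_root", f"{exp_base}/{dataset}_Zero_L1", "--test"]


for name in ["CUB", "Flower"]:
    print(shlex.join(build_test_cmd(name)))
```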

Acknowledgments

This repository is heavily based on the following repositories: CLIP, ClipCap, DeCap, and LaBo.
