
MAE-CT: Masked Autoencoder Contrastive Tuning

[Project Page] [arXiv] [Models] [BibTeX] [Follow-up work (MIM-Refiner)]

PyTorch implementation of Masked AutoEncoder Contrastive Tuning (MAE-CT) from our paper
Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget.

[Figures: MAE-CT schematic; low-shot ViT-L results]

This repository provides:

  • Pretrained checkpoints for MAE, MAE-CT and MAE-CTaug
  • Instructions to generate low-shot datasets for evaluation
  • Instructions on how to use our models as a backbone

Pretrained Checkpoints

MAE reimplementation

| Weights  | Pretrain | Probe (top-1) | Probe log | k-NN (top-1) |
|----------|----------|---------------|-----------|--------------|
| ViT-B/16 | hp       | 66.7          | log       | 51.1         |
| ViT-L/16 | hp       | 75.9          | log       | 60.6         |
| ViT-H/16 | hp       | 78.0          | log       | 61.1         |
| ViT-H/14 | original | 77.2          | log       | 58.9         |

MAE-CT

| Encoder  | Pretrain | Probe (top-1) | Probe log | k-NN (top-1) |
|----------|----------|---------------|-----------|--------------|
| ViT-B/16 | hp       | 73.5          | log       | 64.1         |
| ViT-L/16 | hp       | 80.2          | log       | 78.0         |
| ViT-H/16 | hp       | 81.5          | log       | 79.4         |
| ViT-H/14 | hp       | 81.3          | log       | 79.1         |

MAE-CTaug

| Encoder  | Pretrain | Probe (top-1) | Probe log | k-NN (top-1) |
|----------|----------|---------------|-----------|--------------|
| ViT-B/16 | hp       | 76.9          | log       | 73.4         |
| ViT-L/16 | hp       | 81.5          | log       | 79.1         |
| ViT-H/16 | hp       | 82.2          | log       | 79.8         |
| ViT-H/14 | hp       | 82.0          | log       | 78.9         |

Reproducibility

  • Models can be trained using the hyperparameters provided here. Examples of how to start training runs can be found here.
  • We provide instructions for reproducing our probing results in PROBING.md.

Use checkpoints as a backbone for other tasks

The script eval_probe.py demonstrates how to load our models from a checkpoint and use them for a downstream task. The script extracts the encoder features and feeds them to a linear probe, but the code can be adapted to other downstream tasks as well.
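For orientation, here is a minimal sketch of this pattern, independent of eval_probe.py. It assumes the checkpoint stores the encoder weights under a state_dict entry with an encoder. prefix and that the encoder matches a standard timm ViT; inspect the actual checkpoint and adapt the key handling accordingly.

```python
import torch
import timm

# Load a pretrained checkpoint (path and key layout are assumptions;
# inspect the checkpoint to find the correct state_dict entry and prefixes).
ckpt = torch.load("maect_vitl16.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# Strip a potential "encoder." prefix so the keys match a plain ViT.
state_dict = {k.removeprefix("encoder."): v for k, v in state_dict.items()}

# Create a ViT-L/16 backbone without a classification head and load the weights.
backbone = timm.create_model("vit_large_patch16_224", num_classes=0)
missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)

# Extract features for a downstream task (e.g. a linear probe).
backbone.eval()
with torch.no_grad():
    images = torch.randn(8, 3, 224, 224)  # replace with a batch from your dataloader
    features = backbone(images)           # (8, 1024) pooled encoder features
```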

Setup

Set up a conda environment: conda env create --file environment_linux.yml --name maect

We use FlashAttention (paper) to greatly accelerate computations. We recommend installing it, but this repo can also be used without FlashAttention (no modifications required).

Configuration of dataset paths and environment-specific settings

  • cp template_static_config.yaml static_config.yaml
  • adjust the values in static_config.yaml to your setup

For low-shot evaluations, we use the official splits from SimCLRv2 and MSN.

To generate these ImageNet subsets, we use the ImageNetSubsetGenerator repository.
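As an illustration, below is a hedged sketch of how such a split can be turned into a PyTorch subset. It assumes the split file lists one training image filename per line (as in the official SimCLRv2 1%/10% splits); adapt it if you work with the ImageNetSubsetGenerator output instead.

```python
import os
import torch
from torchvision import datasets, transforms

# Minimal sketch: build a low-shot ImageNet subset from an official split file.
# Assumption: the split file contains one training image filename per line
# (e.g. "n01440764_10026.JPEG"), as in the SimCLRv2 1%/10% splits.
def make_lowshot_subset(imagenet_train_dir, split_file):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(imagenet_train_dir, transform=transform)
    with open(split_file) as f:
        keep = {line.strip() for line in f if line.strip()}
    # ImageFolder stores (path, class_index) tuples in dataset.samples
    indices = [i for i, (path, _) in enumerate(dataset.samples)
               if os.path.basename(path) in keep]
    return torch.utils.data.Subset(dataset, indices)
```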

[Optional] Configure Weights & Biases

This repo uses Weights & Biases for experiment tracking, but offers an alternative in case you do not want to use it. By default, W&B logging is disabled via the default_wandb_mode: disabled setting in static_config.yaml. You can enable it by setting default_wandb_mode: online in static_config.yaml or via the CLI flag --wandb_mode online.

If you enable W&B logging, the W&B entity and project will (by default) be fetched from wandb_config.yaml. You can create this file via cp template_wandb_config.yaml wandb_config.yaml and adjust the values to your setup.

Run

To run your own experiments or reproduce our results, specify the desired hyperparameters via a yaml file and start the training/evaluation run by passing the following CLI arguments to main_train.py:

  • --hp <YAML> (e.g. --hp yamls/mae/base16.yaml)
  • --devices <DEVICES> (e.g. --devices 0 to run on GPU0 or --devices 0,1,2,3 to run on 4 GPUs)

Example: Train MAE with ViT-B/16 on 4 GPUs: python main_train.py --hp yamls/mae/base16.yaml --devices 0,1,2,3

Output

Each yaml file will create a folder in your output directory (defined via output_path in static_config.yaml). The output directory is organized into subdirectories by stage_name and stage_id. Example: ~/output_path/pretrain/9j3kl092

The output directory of each run is organized as follows:

  • checkpoints: Model weights are stored here (choose the interval by adjusting the checkpoint_logger values in the run's yaml file)
  • primitive: All metrics that are written to Weights & Biases are also stored locally here. If you don't want to use W&B, you can parse metrics from the files within this directory.
  • log.txt: logfile
  • hp_resolved.yaml: a copy of the yaml file that was specified via the --hp CLI argument

The yamls used for our paper can be found here. Each stage of MAE-CT requires its own yaml file, and later stages require a reference to a checkpoint from a previous stage. This reference is defined by changing the stage_id of the initializer objects within the yaml.
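If you prefer to wire the stages together programmatically rather than editing the yamls by hand, a hedged helper could look like the sketch below. The exact nesting of initializer entries in the yaml files is an assumption; adapt the traversal to the actual structure in yamls/.

```python
import yaml

# Hedged sketch: read a stage yaml, set the stage_id of every "initializer"
# entry to the id of the previous stage, and write the file back in place.
# The assumed structure (dicts containing an "initializer" mapping with a
# "stage_id" key) must be checked against the actual yaml files.
def set_initializer_stage_ids(hp_path, previous_stage_id):
    with open(hp_path) as f:
        hp = yaml.safe_load(f)

    def walk(node):
        if isinstance(node, dict):
            if isinstance(node.get("initializer"), dict):
                node["initializer"]["stage_id"] = previous_stage_id
            for value in node.values():
                walk(value)
        elif isinstance(node, list):
            for value in node:
                walk(value)

    walk(hp)
    with open(hp_path, "w") as f:
        yaml.safe_dump(hp, f)

# Example: point the stage 2 yaml at the checkpoint of a stage 1 run.
set_initializer_stage_ids("yamls/stage2_maect_prepare_head/large16.yaml", "9j3kl092")
```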

Examples

Train models

  • Pretrain a MAE on 8 GPUs (stage 1):
    python main_train.py --hp yamls/stage1_mae/large16.yaml --devices 0,1,2,3,4,5,6,7
  • Train an NNCLR head on frozen encoder features (stage 2) with 8 GPUs:
    • change the stage_id of the initializer in the encoder to the stage_id from stage 1
    • python main_train.py --hp yamls/stage2_maect_prepare_head/large16.yaml --devices 0,1,2,3,4,5,6,7
  • Apply contrastive tuning (stage 3) with 8 GPUs:
    • change the stage_id of the initializer in the encoder and the nnclr head to the stage_id from stage 2
    • python main_train.py --hp yamls/stage3_maect_contrastive_tuning/large16.yaml --devices 0,1,2,3,4,5,6,7

Evaluate pretrained models

  • Adapt the initializer of yamls/probe.yaml to the model you want to evaluate
  • python main_train.py --hp yamls/probe.yaml --devices 0,1,2,3

Citation

If you find this repository useful, please consider giving it a star ⭐ and citing us:

@article{lehner2023maect,
      title={Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget}, 
      author={Johannes Lehner and Benedikt Alkin and Andreas Fürst and Elisabeth Rumetshofer and Lukas Miklautz and Sepp Hochreiter},
      journal={arXiv preprint arXiv:2304.10520},
      year={2023}
}
