No reason for no supervision: Improved generalization in supervised models

Project Website · Paper (arXiv) · Paper (ICLR 2023, notable top 25%)

In this repository, we provide:

  • Several pretrained t-ReX and t-ReX* models in PyTorch (see here).
  • Code for training our t-ReX and t-ReX* models on the ImageNet-1K dataset in PyTorch (see here).
  • Code for running transfer learning evaluations of pretrained models via linear classification over pre-extracted features on 16 downstream datasets (see here).

If you find this repository useful, please consider citing us:

@inproceedings{sariyildiz2023improving,
    title={No Reason for No Supervision: Improved Generalization in Supervised Models},
    author={Sariyildiz, Mert Bulent and Kalantidis, Yannis and Alahari, Karteek and Larlus, Diane},
    booktitle={International Conference on Learning Representations},
    year={2023},
}

Model Zoo

In the table below, we provide links to several pretrained t-ReX and t-ReX* models. These are the models that produce the results reported in the paper, as well as models reproduced with the cleaner codebase released in this repo. The transfer performance of these models is averaged over 15 datasets: the 13 transfer datasets mainly used in the paper plus two additions, the iNaturalist datasets. To run transfer evaluations, see the corresponding section of this readme.

| Model | ResNet50 Checkpoint | Full Checkpoint | ImageNet-1K (Top-1 %) | Average Transfer (Log odds) |
|-------|---------------------|-----------------|-----------------------|-----------------------------|
| Models reported in the paper | | | | |
| t-ReX | Link | - | 78.0 | 1.1704 |
| t-ReX* | Link | - | 80.2 | 0.8829 |
| Models reproduced with this code base | | | | |
| t-ReX | Link | Link | 77.9 | 1.1664 |
| t-ReX* | Link | Link | 80.2 | 0.8800 |
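
For reference, here is a minimal sketch of how an average log-odds score can be computed, assuming the standard logit transform of top-1 accuracy (the exact aggregation details are described in the paper):

import math

def log_odds(top1_accuracy: float) -> float:
    # Logit transform of an accuracy in (0, 1): log(p / (1 - p)).
    return math.log(top1_accuracy / (1.0 - top1_accuracy))

# Hypothetical per-dataset top-1 accuracies for a transfer suite.
accuracies = [0.85, 0.72, 0.91]
average_transfer = sum(log_odds(a) for a in accuracies) / len(accuracies)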

Full checkpoints contain separate state dictionaries for the model, the optimizer, and the gradient scaler (for mixed-precision training); we share them for reference. The ResNet50 checkpoints, on the other hand, can be used directly:

import torch as th
from torchvision.models import resnet50

# Load the backbone weights onto the CPU.
ckpt = th.load("trex.pth", map_location="cpu")
net = resnet50()
# The checkpoint has no classifier head, so only fc.* should be missing.
msg = net.load_state_dict(ckpt, strict=False)
assert msg.missing_keys == ["fc.weight", "fc.bias"] and msg.unexpected_keys == []
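
To inspect a full checkpoint instead, here is a minimal sketch; the file name and the state-dictionary keys below are assumptions for illustration, so check ckpt.keys() against the actual file first:

import torch as th

ckpt = th.load("trex_full.pth", map_location="cpu")  # hypothetical file name
print(ckpt.keys())  # verify the actual key names before indexing

# Assuming keys such as "model", "optimizer" and "scaler" (not confirmed here):
model_state = ckpt["model"]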

Training t-ReX models

Installation

We developed this code using recent versions of PyTorch, torchvision, and TensorBoard. We recommend creating a new conda environment to manage these packages.

conda create -n trex
conda activate trex
conda install pytorch=1.13.1 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
pip install tensorboard
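
As a quick sanity check of the environment (a convenience snippet, not part of the original setup):

import torch
import torchvision

print(torch.__version__)          # expected: 1.13.1
print(torchvision.__version__)
print(torch.version.cuda)         # expected: 11.6
print(torch.cuda.is_available())  # should be True on a GPU machine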

Dataset

We train our models on the ILSVRC-2012 dataset (also known as ImageNet-1K), available on the ImageNet website. Once you have downloaded the dataset, make sure that data_dir=/path/to/imagenet contains train and val directories, each with 1000 sub-directories for the images of the ImageNet-1K classes.
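
To verify the layout programmatically, a small sketch (the path is a placeholder, as above):

from pathlib import Path

data_dir = Path("/path/to/imagenet")  # placeholder
for split in ("train", "val"):
    # Each split should contain 1000 class sub-directories.
    classes = [d for d in (data_dir / split).iterdir() if d.is_dir()]
    assert len(classes) == 1000, f"{split}: found {len(classes)} class directories"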

Training commands

Below, we provide commands for training plain t-ReX and t-ReX-OCM models on ImageNet-1K. Note that the results reported in the paper were obtained with 100-epoch trainings on 4 GPUs, each processing a batch of 64 samples (an effective batch size of 256). If you want to use fewer GPUs, increase the batch size, etc., see the arguments of main.py.

Commands for t-ReX-OCM models

t-ReX-OCM models are defined by Equation (2) of the paper.

Command for training a t-ReX-OCM-1 model (named t-ReX* in the paper)
data_dir=/path/to/imagenet
output_dir=/path/where/to/save/checkpoints
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust the --nproc_per_node argument below accordingly

python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 main.py  \
    --output_dir=${output_dir} \
    --data_dir=${data_dir} \
    --seed=${RANDOM} \
    --pr_hidden_layers=1 \
    --mc_global_scale 0.40 1.00 \
    --mc_local_scale 0.05 0.40
Command for training a t-ReX-OCM-3 model (named t-ReX in the paper)
data_dir=/path/to/imagenet
output_dir=/path/where/to/save/checkpoints
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust the --nproc_per_node argument below accordingly

python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 main.py  \
    --output_dir=${output_dir} \
    --data_dir=${data_dir} \
    --seed=${RANDOM} \
    --pr_hidden_layers=3 \
    --mc_global_scale 0.25 1.00 \
    --mc_local_scale 0.05 0.25

Commands for plain t-ReX models

Plain t-ReX models are defined by Equation (1) of the paper. Compared to the commands for training the OCM models above, we simply add the --memory_size=0 argument, which disables the OCM component.

Command for training a plain t-ReX-1 model
data_dir=/path/to/imagenet
output_dir=/path/where/to/save/checkpoints
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust the --nproc_per_node argument below accordingly

python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 main.py  \
    --output_dir=${output_dir} \
    --data_dir=${data_dir} \
    --seed=${RANDOM} \
    --pr_hidden_layers=1 \
    --mc_global_scale 0.40 1.00 \
    --mc_local_scale 0.05 0.40 \
    --memory_size=0
Command for training a plain t-ReX-3 model
data_dir=/path/to/imagenet
output_dir=/path/where/to/save/checkpoints
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust the --nproc_per_node argument below accordingly

python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 main.py  \
    --output_dir=${output_dir} \
    --data_dir=${data_dir} \
    --seed=${RANDOM} \
    --pr_hidden_layers=3 \
    --mc_global_scale 0.25 1.00 \
    --mc_local_scale 0.05 0.25 \
    --memory_size=0

Transfer learning evaluation suite

We provide the evaluation code under the transfer folder. Please see the instructions there.
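
For orientation, here is a minimal sketch of the linear-probe idea (a linear classifier trained on frozen, pre-extracted features). It uses scikit-learn and placeholder .npy file names for illustration; it is not the evaluation code in the transfer folder:

import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical pre-extracted features, e.g. (num_samples, 2048) for ResNet50.
train_feats, train_labels = np.load("train_feats.npy"), np.load("train_labels.npy")
test_feats, test_labels = np.load("test_feats.npy"), np.load("test_labels.npy")

# The regularization strength would normally be tuned on a validation split.
clf = LogisticRegression(max_iter=1000)
clf.fit(train_feats, train_labels)
print("Top-1 accuracy:", clf.score(test_feats, test_labels))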

Acknowledgement

Our implementation builds on several public code repositories, such as DINO, MoCo, and the PyTorch examples. We thank all the authors and developers for making their code accessible.
