Towards self-supervised learning of global and object-centric representations

For the ICLR workshop on Objects, Structure, and Causality.


First, clone the repo:

git clone ''
cd iclr-osc-22

Create an environment from scratch:

conda create -y -n "${ENV_NAME}" -c pytorch -c conda-forge \
    python black isort pytest dill pre-commit \
    hydra-core colorlog submitit fvcore tqdm wandb sphinx \
    numpy pandas matplotlib seaborn tabulate scikit-learn scikit-image \
    jupyterlab jupyterlab_code_formatter jupyter_console ipywidgets \
    pytorch tensorflow-gpu cudatoolkit-dev cudnn \
    torchvision einops opt_einsum

conda activate "${ENV_NAME}"

python -m pip install \
    better_exceptions \
    sphinx-rtd-theme sphinx-autodoc-typehints \
    hydra_colorlog hydra-submitit-launcher namesgenerator \
    tensorflow-datasets transformers datasets \
    'git+'
conda env config vars set BETTER_EXCEPTIONS=1
pre-commit install
pre-commit autoupdate

python -m pip install --editable .

Or create an environment using the provided dependency file:

conda env create -n "${ENV_NAME}" -f 'environment.yaml'
conda activate "${ENV_NAME}"
pre-commit install
python -m pip install --editable .


The project uses the "CLEVR with masks dataset", which is part of the Multi Object Datasets collection.

Download all datasets from the Google Cloud bucket (see the original website for other options):

sudo apt install -y apt-transport-https ca-certificates gnupg
echo 'deb [signed-by=/usr/share/keyrings/] cloud-sdk main' |
    sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
curl '' |
    sudo apt-key --keyring /usr/share/keyrings/ add -
sudo apt update
sudo apt install -y google-cloud-sdk

gsutil -m cp -r gs://multi-object-datasets "${HOME}/"

Data loading and visualization notebooks:

Prepare the CLEVR dataset for training and evaluation by splitting the original TFRecords file into three parts (train and val contain only RGB images, test contains the full sample dict with object masks and attributes):

python -m --data-root "${HOME}/multi-object-datasets"


A training run with default parameters can be launched by executing the following command from the root of the repository (local training, single GPU):


The project uses Hydra as a configuration manager. All defaults can be listed as:

./ --cfg job

Individual parameters from the configuration can be changed on the command line:

./ \
  training.batch_size=16 \
  model.backbone.embed_dim=64 \
  model.backbone.patch_size='[4,4]' \
  model.backbone.num_heads=8
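
As a rough illustration of what such dotted overrides do, the sketch below merges `key.path=value` strings into a nested dictionary. It is not Hydra's implementation; `apply_overrides` and the `defaults` values are hypothetical and chosen only for this example.

```python
import ast
from copy import deepcopy


def apply_overrides(config, overrides):
    """Return a copy of `config` with `key.path=value` overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        try:
            # Parse numbers and lists such as "[4,4]"; keep other values as strings.
            node[leaf] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            node[leaf] = raw_value
    return result


# Hypothetical defaults, for illustration only (not the project's real values).
defaults = {
    "training": {"batch_size": 64},
    "model": {"backbone": {"embed_dim": 128, "patch_size": [8, 8], "num_heads": 4}},
}
overrides = [
    "training.batch_size=16",
    "model.backbone.embed_dim=64",
    "model.backbone.patch_size=[4,4]",
    "model.backbone.num_heads=8",
]
config = apply_overrides(defaults, overrides)
print(config["model"]["backbone"])
```

Each override replaces a single leaf of the configuration tree and leaves the other defaults untouched.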

All available configuration groups (e.g. different loss functions, attention types, learning rate schedules, etc.) can be found in the configs folder. For example, to train with an object-wise contrastive loss that takes all object tokens from all images as negatives and overfit on a small subset of images:

./ losses/l_objects=ctr_all +overfit=overfit

Running a parameter sweep in a SLURM environment is also supported, for example:

./ --multirun hydra/launcher=submitit_slurm +slurm=slurm \
  +losses=more_objects,more_global \
  model=bb_obj_global \
  model/obj_queries=sample \
  model.backbone.embed_dim=64,128,256 \'slurm_sweep' \
  lr_scheduler=linear1_cosine4_x5 \
  lr_scheduler.decay.end_lr=0.0003 \
  optimizer.start_lr=0.0007 \
  optimizer.weight_decay=0.0001 \
  model.backbone.num_heads=4,8 \
  model.backbone.num_layers=2,4,6
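
A multirun launches one job per combination of the comma-separated values. The sketch below counts the jobs for the sweep above, considering only the comma-separated options visible in the command; the `sweep` dictionary is just a hand-written copy of those options.

```python
from itertools import product

# Swept options copied by hand from the command above.
sweep = {
    "+losses": ["more_objects", "more_global"],
    "model.backbone.embed_dim": [64, 128, 256],
    "model.backbone.num_heads": [4, 8],
    "model.backbone.num_layers": [2, 4, 6],
}

# One job per element of the Cartesian product of all swept values.
jobs = [dict(zip(sweep, values)) for values in product(*sweep.values())]
print(len(jobs))  # 2 * 3 * 2 * 3 = 36
```

Given the batch-size constraints noted under "Embedding dimension", some of these combinations may need a smaller `training.batch_size`.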


The main hyperparameters that can be configured for the experiments are listed below. Each has a corresponding configuration file in the configs folder.


  • backbone-global_fn-global_proj: global representation only. Backbone patch tokens can be aggregated either with global average pooling (avg) or an extra CLS token (cls)
  • backbone(-slot_fn-slot_proj)-global_fn-global_proj: after the backbone, two separate branches process global and object features. Backbone patch tokens can be aggregated either with global average pooling (avg) or an extra CLS token (cls)
  • backbone-slot_fn(-global_fn-global_proj)-slot_proj: after the backbone, the slot function extracts S object tokens. These S tokens are projected to yield object representations; they are also average-pooled and processed to extract global features and projections. The backbone pooling is set to avg since a CLS token would not be ignored.

Object query implementations:

  • learned: learned query tokens in fixed number
  • sample: object queries are sampled either from a single Gaussian distribution with learned parameters, or from a mixture of Gaussians with uniform component weights
  • kmeans_euclidean: object queries are initialized as the K-Means clustering of backbone features. The number of clusters can be chosen dynamically; the distance function is a simple Euclidean distance.
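
To make the kmeans_euclidean option more concrete, here is a minimal pure-Python sketch of K-Means with Euclidean distance whose centroids serve as object queries. It is not the project's code: `kmeans_queries`, the fixed iteration count, and the choice of the first tokens as initial centroids are assumptions made for the example.

```python
import math


def kmeans_queries(tokens, num_queries, num_iters=10):
    """Cluster feature vectors and return the centroids as object queries."""
    # Assumption: start from the first `num_queries` tokens for determinism.
    centroids = [list(t) for t in tokens[:num_queries]]
    for _ in range(num_iters):
        clusters = [[] for _ in centroids]
        for token in tokens:
            # Assign each token to its nearest centroid (Euclidean distance).
            distances = [math.dist(token, c) for c in centroids]
            clusters[distances.index(min(distances))].append(token)
        for i, members in enumerate(clusters):
            if members:  # keep the old centroid if a cluster is empty
                centroids[i] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids


# Six toy 2-D "backbone features" forming two well-separated groups.
tokens = [(0.0, 0.0), (4.0, 4.0), (0.0, 1.0), (4.0, 5.0), (1.0, 0.0), (5.0, 4.0)]
queries = kmeans_queries(tokens, num_queries=2)
print(queries)
```

Because `num_queries` is an argument rather than a learned parameter, the number of queries can change from one image to the next.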

Object function implementations:

  • slot-attention: slot attention decoder (iterative)
  • cross-attention: cross attention decoder
  • co attention decoder (not implemented yet)
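
The sketch below contrasts how attention weights are normalized in the two implemented decoders, following the standard formulations: cross attention applies the softmax over the input tokens, while slot attention applies it over the queries and then takes a weighted mean. It deliberately omits learned projections, the recurrent update, and the iterations of slot attention; `attention_weights` and its `mode` argument are hypothetical names and not the project's API.

```python
import math


def softmax(values):
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, tokens, mode):
    """Return an S x N weight matrix for S queries and N tokens."""
    if mode not in ("cross", "slot"):
        raise ValueError(f"unknown mode: {mode}")
    scale = 1.0 / math.sqrt(len(tokens[0]))
    logits = [
        [scale * sum(q * t for q, t in zip(query, token)) for token in tokens]
        for query in queries
    ]
    if mode == "cross":
        # Each query distributes its attention over the tokens.
        return [softmax(row) for row in logits]
    # Slot attention: each token is distributed over the queries (competition) ...
    columns = [softmax(column) for column in zip(*logits)]
    weights = [list(row) for row in zip(*columns)]
    # ... then each row is renormalized so the update is a weighted mean.
    return [[w / sum(row) for w in row] for row in weights]


queries = [(1.0, 0.0), (0.0, 1.0)]
tokens = [(2.0, 0.0), (0.0, 2.0), (1.0, 1.0)]
cross = attention_weights(queries, tokens, mode="cross")
slot = attention_weights(queries, tokens, mode="slot")
print(cross)
print(slot)
```

In both modes each query ends up with weights that sum to one over the tokens; the difference is whether the queries compete for tokens before that normalization.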

Loss functions:

  • Global image representation:
    • Contrastive loss ctr (given one image, classify positively an augmented version of that image among B-2 other unrelated images in the batch)
    • Cosine similarity loss sim (given one image and its augmented version, maximise the cosine similarity between their projected representations)
  • Object representation:
    • Contrastive loss ctr_all (one token compared to all tokens in all images)
    • Contrastive loss ctr_img (one token compared to all tokens from its original image and the augmented version)
    • Cosine similarity loss sim_img (one token compared to all tokens from its original image and the augmented version)
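
As a minimal sketch of the two loss families above, the code below computes a cosine similarity loss and a contrastive loss for single vectors in plain Python. It is not the project's implementation: the function names and the `temperature` value are assumptions, and a real training loop would operate on batched tensors.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def similarity_loss(anchor, positive):
    """Negative cosine similarity: minimized when the two vectors align."""
    return -cosine(anchor, positive)


def contrastive_loss(anchor, positive, negatives, temperature=0.1):
    """Cross-entropy of picking the positive among positive + negatives."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    # Numerically stable log-sum-exp over all candidates.
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum - logits[0]


anchor = (1.0, 0.0)
aligned = (2.0, 0.0)
orthogonal = (0.0, 3.0)
sim = similarity_loss(anchor, aligned)
easy = contrastive_loss(anchor, aligned, [orthogonal])
hard = contrastive_loss(anchor, orthogonal, [aligned])
print(sim, easy, hard)
```

Under this reading, the object-level variants differ mainly in which tokens populate `negatives`: tokens from all images for ctr_all, or only tokens from the original image and its augmented version for ctr_img.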

Embedding dimension:

  • default 64 for everything with 2x factor for all MLP hidden layers
  • 128 and 256 also work well but require a smaller batch size especially when using 8 heads. Safe (dim, batch) pairs: (64, 64), (128, 16), (256, 8)
  • It would be interesting to try a different size for the final projection head when using the matching cosine similarity loss


