This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Barlow Twins implementation #230

Closed
Changes from 24 commits
27 commits
2eb8c82
- Added BarlowTwinsLoss and BarlowTwinsCriterion
OlivierDehaene Mar 10, 2021
a03f97f
Typos in docstrings
OlivierDehaene Mar 10, 2021
cf60b1e
Reduced MLP size for integration test
OlivierDehaene Mar 11, 2021
900178c
embedding_dim was not updated after MLP size change
OlivierDehaene Mar 11, 2021
341935e
Added debugging config on imagenette_160
OlivierDehaene Mar 14, 2021
dd19482
Tensorboard LOG_PARAMS and LOG_PARAMS_GRADIENTS to False
OlivierDehaene Mar 14, 2021
9402a52
- Added sync between processes for normalization
OlivierDehaene Mar 21, 2021
78e2899
Updated defaults.yaml
OlivierDehaene Mar 21, 2021
7110e33
Faster implementation
OlivierDehaene Mar 22, 2021
891fcec
Removed unused attribute
OlivierDehaene Mar 22, 2021
5cb45c2
Use gather_from_all instead of all_reduce
OlivierDehaene Mar 22, 2021
0857086
Removed unused import
OlivierDehaene Mar 22, 2021
e3dad35
Added docs
OlivierDehaene Mar 22, 2021
2125fb0
Backward pass was dysfunctional
OlivierDehaene Mar 23, 2021
e1fff28
Patched test
OlivierDehaene Mar 23, 2021
cc2866b
Patched test
OlivierDehaene Mar 23, 2021
664bad2
Cleaned SyncNormalizeFunction
OlivierDehaene Mar 23, 2021
1083fb6
Taking into account review comments
OlivierDehaene Mar 24, 2021
c74fb2c
reduce over the cross-correlation matrix is overall faster as it cuts…
OlivierDehaene Apr 7, 2021
32f182d
Added LARS
OlivierDehaene Apr 7, 2021
1a61226
Flake8
OlivierDehaene Apr 7, 2021
c1d2f9d
Patched config
OlivierDehaene Apr 7, 2021
1c113fe
Added Facebook copyrights
OlivierDehaene Apr 29, 2021
a5ba52b
Removed LARC mentions in doc as Barlow Twins uses its own version.
OlivierDehaene Apr 29, 2021
947c7fc
Linting + configs patch
OlivierDehaene Apr 30, 2021
28c294f
Patched barlow_twins.rst
OlivierDehaene Apr 30, 2021
fa7b119
misunderstanding around TRUNK and TRUNK_PARAMS
OlivierDehaene Apr 30, 2021
7 changes: 7 additions & 0 deletions MODEL_ZOO.md
@@ -21,6 +21,7 @@ VISSL provides reference implementation of a large number of self-supervision ap
- [DeepClusterV2](#DeepClusterV2)
- [SwAV](#SwAV)
- [MoCoV2](#MoCoV2)
- [Barlow Twins](#BarlowTwins)

## Torchvision and VISSL

@@ -205,3 +206,9 @@ There is some standard deviation in linear results if we run the same eval sever
| Method | Model | PreTrain dataset | ImageNet top-1 acc. | URL |
| ------ | ----- | ---------------- | ------------------- | --- |
| [MoCo-v2](https://arxiv.org/abs/2003.04297) | RN50 - 200 epochs - 256 batch-size | ImageNet-1K | 66.4 | [model](https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch)

### BarlowTwins

| Method | Model | PreTrain dataset | ImageNet top-1 acc. | URL |
| ------ | ----- | ---------------- | ------------------- | --- |
| [Barlow Twins](https://arxiv.org/abs/2103.03230) | RN50 - 300 epochs - 2048 batch-size | ImageNet-1K | 70.75 | [model](https://dl.fbaipublicfiles.com/vissl/model_zoo/barlow_twins/barlow_twins_32gpus_4node_imagenet1k_300ep_resnet50.torch)
@@ -0,0 +1,119 @@
# @package _global_
config:
  VERBOSE: True
  LOG_FREQUENCY: 1
  TEST_ONLY: False
  TEST_MODEL: False
  SEED_VALUE: 0
  MULTI_PROCESSING_METHOD: forkserver
  HOOKS:
    PERF_STATS:
      MONITOR_PERF_STATS: True
      ROLLING_BTIME_FREQ: 313
    TENSORBOARD_SETUP:
      USE_TENSORBOARD: True
      EXPERIMENT_LOG_DIR:
      LOG_PARAMS: False
      LOG_PARAMS_GRADIENTS: False
      FLUSH_EVERY_N_MIN: 20
  DATA:
    NUM_DATALOADER_WORKERS: 5
    TRAIN:
      DATA_SOURCES: [disk_folder]
      DATASET_NAMES: [imagenette_160_folder]
      BATCHSIZE_PER_REPLICA: 256
      LABEL_TYPE: sample_index # just an implementation detail. Label isn't used
      TRANSFORMS:
        - name: ImgReplicatePil
          num_times: 2
        - name: RandomResizedCrop
          size: 128
        - name: RandomHorizontalFlip
          p: 0.5
        - name: ImgPilColorDistortion
          strength: 0.5
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilGaussianBlur
              p: 1.0
              radius_min: 0.1
              radius_max: 2.0
          prob: [1.0, 0.1]
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilRandomSolarize
              p: 1.0
          prob: [0.0, 0.2]
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      COLLATE_FUNCTION: simclr_collator
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenette_160/
      DROP_LAST: True
  TRAINER:
    TRAIN_STEP_NAME: standard_train_step
  METERS:
    name: ""
  MODEL:
    TRUNK:
      NAME: resnet
      TRUNK_PARAMS:
        RESNETS:
          DEPTH: 50
    HEAD:
      PARAMS: [
        ["mlp", {"dims": [2048, 8192], "use_relu": True, "use_bn": True, "use_bias": False}],
        ["mlp", {"dims": [8192, 8192], "use_relu": True, "use_bn": True, "use_bias": False}],
        ["mlp", {"dims": [8192, 8192], "use_bias": False}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: apex
    AMP_PARAMS:
      USE_AMP: True
      AMP_TYPE: pytorch
  LOSS:
    name: barlow_twins_loss
    barlow_twins_loss:
      lambda_: 0.0051
      scale_loss: 0.024
      embedding_dim: 8192
  OPTIMIZER:
    name: lars
    weight_decay: 0.0000015
    momentum: 0.9
    exclude_bias_and_norm: True
    num_epochs: 1000
    regularize_bn: False
    regularize_bias: False
    param_schedulers:
      lr:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.2
          base_lr_batch_size: 256
        name: composite
        schedulers:
          - name: linear
            start_value: 0.0
            end_value: 0.2 # Automatically rescaled if needed
          - name: cosine
            start_value: 0.2 # Automatically rescaled if needed
            end_value: 0.002 # Automatically rescaled if needed
        update_interval: step
        interval_scaling: [rescaled, fixed]
        lengths: [0.01, 0.99] # 1000ep
  DISTRIBUTED:
    BACKEND: nccl
    NUM_NODES: 1
    NUM_PROC_PER_NODE: 8
    RUN_ID: auto
    INIT_METHOD: tcp
  MACHINE:
    DEVICE: gpu
  CHECKPOINT:
    AUTO_RESUME: True
    CHECKPOINT_FREQUENCY: 100
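For readers unfamiliar with the `barlow_twins_loss` parameters above, here is a minimal pure-Python sketch of the Barlow Twins objective as described in the paper. It is not the code from this PR: the function names, the `eps` value, and the use of the population variance for normalization are assumptions made for illustration. `lambda_` weights the off-diagonal (redundancy) term, `scale_loss` multiplies the whole loss, and `embedding_dim` corresponds to the side length of the cross-correlation matrix.

```python
import math


def standardize(batch, eps=1e-5):
    """Give each embedding dimension zero mean and unit variance over the batch."""
    n, dim = len(batch), len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(dim)]
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n for j in range(dim)]
    return [
        [(row[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(dim)]
        for row in batch
    ]


def barlow_twins_loss(view_a, view_b, lambda_=0.0051, scale_loss=0.024):
    """Sketch of the loss: pull the cross-correlation matrix towards the identity."""
    n, dim = len(view_a), len(view_a[0])
    za, zb = standardize(view_a), standardize(view_b)
    on_diag, off_diag = 0.0, 0.0
    for i in range(dim):
        for j in range(dim):
            # Entry (i, j) of the cross-correlation matrix between the two views.
            c_ij = sum(za[k][i] * zb[k][j] for k in range(n)) / n
            if i == j:
                on_diag += (c_ij - 1.0) ** 2  # invariance term
            else:
                off_diag += c_ij ** 2  # redundancy-reduction term
    return scale_loss * (on_diag + lambda_ * off_diag)


# Two toy batches of 2-dimensional embeddings (illustrative data only).
decorrelated = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
redundant = [[1.0, 1.0], [-1.0, -1.0], [1.0, 1.0], [-1.0, -1.0]]
print(barlow_twins_loss(decorrelated, decorrelated))
print(barlow_twins_loss(redundant, redundant))
```

With identical views, the decorrelated batch gives a loss close to zero, while the batch whose two dimensions carry the same information is penalized through the off-diagonal term only.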
116 changes: 116 additions & 0 deletions configs/config/pretrain/barlow_twins/barlow_twins_4node_resnet.yaml
@@ -0,0 +1,116 @@
# @package _global_
config:
  VERBOSE: False
  LOG_FREQUENCY: 10
  TEST_ONLY: False
  TEST_MODEL: False
  SEED_VALUE: 0
  MULTI_PROCESSING_METHOD: forkserver
  HOOKS:
    PERF_STATS:
      MONITOR_PERF_STATS: True
      ROLLING_BTIME_FREQ: 313
  DATA:
    NUM_DATALOADER_WORKERS: 5
    TRAIN:
      DATA_SOURCES: [disk_folder]
      DATASET_NAMES: [imagenet1k_folder]
      BATCHSIZE_PER_REPLICA: 64
      LABEL_TYPE: sample_index # just an implementation detail. Label isn't used
      TRANSFORMS:
        - name: ImgReplicatePil
          num_times: 2
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
          p: 0.5
        - name: ImgPilColorDistortion
          strength: 0.5
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilGaussianBlur
              p: 1.0
              radius_min: 0.1
              radius_max: 2.0
          prob: [ 1.0, 0.1 ]
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilRandomSolarize
              p: 1.0
          prob: [ 0.0, 0.2 ]
        - name: ToTensor
        - name: Normalize
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
      COLLATE_FUNCTION: simclr_collator
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenet1k/
      DROP_LAST: True
  TRAINER:
    TRAIN_STEP_NAME: standard_train_step
  METERS:
    name: ""
  MODEL:
    TRUNK:
      NAME: resnet
      TRUNK_PARAMS:
        RESNETS:
          DEPTH: 50
    HEAD:
      PARAMS: [
        ["mlp", {"dims": [2048, 8192], "use_relu": True, "use_bn": True}],
        ["mlp", {"dims": [8192, 8192], "use_relu": True, "use_bn": True}],
        ["mlp", {"dims": [8192, 8192]}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: apex
      GROUP_SIZE: 8
    AMP_PARAMS:
      USE_AMP: True
      AMP_TYPE: pytorch
  LOSS:
    name: barlow_twins_loss
    barlow_twins_loss:
      lambda_: 0.0051
      scale_loss: 0.024
      embedding_dim: 8192
  OPTIMIZER:
    name: lars
    weight_decay: 0.000001
    momentum: 0.9
    exclude_bias_and_norm: True
    num_epochs: 1000
    regularize_bn: False
    regularize_bias: False
    param_schedulers:
      lr:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.2
          base_lr_batch_size: 256
        name: composite
        schedulers:
          - name: linear
            start_value: 0.0
            end_value: 0.2 # Automatically rescaled if needed
          - name: cosine
            start_value: 0.2 # Automatically rescaled if needed
            end_value: 0.002 # Automatically rescaled if needed
        update_interval: step
        interval_scaling: [rescaled, fixed]
        lengths: [0.01, 0.99] # 1000ep
  DISTRIBUTED:
    BACKEND: nccl
    NUM_NODES: 4
    NUM_PROC_PER_NODE: 8
    RUN_ID: xxxxxxxxxxxxxxxxxxxxxxxxxxxx
    INIT_METHOD: tcp
    NCCL_DEBUG: False
  MACHINE:
    DEVICE: gpu
  CHECKPOINT:
    AUTO_RESUME: True
    CHECKPOINT_FREQUENCY: 5
    CHECKPOINT_ITER_FREQUENCY: -1 # set this variable to checkpoint every few iterations
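The learning-rate settings in the config above can be worked through numerically. The sketch below assumes three things that the diff does not spell out: that `auto_lr_scaling` applies the linear scaling rule (base value times global batch size over `base_lr_batch_size`), that `rescaled` stretches a scheduler's progress to [0, 1] within its own interval, and that `fixed` passes the global training progress through unchanged. All function names are hypothetical.

```python
import math


def lr_scale(base_lr_batch_size, batch_per_replica, num_nodes, procs_per_node):
    """Assumed linear scaling rule: global batch size over the reference batch size."""
    global_batch = batch_per_replica * num_nodes * procs_per_node
    return global_batch / base_lr_batch_size


def linear(where, start, end):
    """Linear interpolation from start to end as `where` goes from 0 to 1."""
    return start + (end - start) * where


def cosine(where, start, end):
    """Half-cosine decay from start (where=0) to end (where=1)."""
    return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * where))


def composite_lr(where, peak, final, warmup_fraction=0.01):
    """Learning rate at training progress `where` in [0, 1]."""
    if where < warmup_fraction:
        # "rescaled": the warmup scheduler sees its own progress stretched to [0, 1].
        return linear(where / warmup_fraction, 0.0, peak)
    # "fixed": the cosine scheduler sees the global progress unchanged.
    return cosine(where, peak, final)


# Values taken from the 4-node config: 64 images per replica, 4 nodes, 8 processes each.
scale = lr_scale(256, 64, 4, 8)  # 2048 / 256 = 8
peak, final = 0.2 * scale, 0.002 * scale

for where in (0.0, 0.005, 0.01, 0.5, 1.0):
    print(where, composite_lr(where, peak, final))
```

Under these assumptions the global batch size is 2048, which matches the batch size listed in the MODEL_ZOO.md entry, and the warmup covers the first 1% of training steps.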
116 changes: 116 additions & 0 deletions configs/config/test/integration_test/quick_barlow_twins.yaml
@@ -0,0 +1,116 @@
# @package _global_
config:
  VERBOSE: False
  LOG_FREQUENCY: 1
  TEST_ONLY: False
  TEST_MODEL: False
  SEED_VALUE: 0
  MULTI_PROCESSING_METHOD: forkserver
  HOOKS:
    PERF_STATS:
      MONITOR_PERF_STATS: True
      PERF_STAT_FREQUENCY: 10
      ROLLING_BTIME_FREQ: 5
  DATA:
    NUM_DATALOADER_WORKERS: 5
    TRAIN:
      DATA_SOURCES: [disk_filelist]
      DATASET_NAMES: [imagenet1k_filelist]
      BATCHSIZE_PER_REPLICA: 32
      LABEL_TYPE: sample_index # just an implementation detail. Label isn't used
      TRANSFORMS:
        - name: ImgReplicatePil
          num_times: 2
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
          p: 0.5
        - name: ImgPilColorDistortion
          strength: 0.5
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilGaussianBlur
              p: 1.0
              radius_min: 0.1
              radius_max: 2.0
          prob: [1.0, 0.1]
        - name: ImgPilMultiCropRandomApply
          transforms:
            - name: ImgPilRandomSolarize
              p: 1.0
          prob: [0.0, 0.2]
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      COLLATE_FUNCTION: simclr_collator
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      DATA_LIMIT: 500
      DROP_LAST: True
      COPY_DESTINATION_DIR: "/tmp/imagenet1k"
  TRAINER:
    TRAIN_STEP_NAME: standard_train_step
  METERS:
    name: ""
  MODEL:
    TRUNK:
      NAME: resnet
      TRUNK_PARAMS:
        RESNETS:
          DEPTH: 50
    HEAD:
      # Reduced MLP size to limit memory footprint
      PARAMS: [
        ["mlp", {"dims": [2048, 128], "use_relu": True, "use_bn": True, "use_bias": False}],
        ["mlp", {"dims": [128, 128], "use_bias": False}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: pytorch
    AMP_PARAMS:
      USE_AMP: False
  LOSS:
    name: barlow_twins_loss
    barlow_twins_loss:
      lambda_: 0.0051
      scale_loss: 0.024
      embedding_dim: 128
  OPTIMIZER:
    name: lars
    weight_decay: 0.0000015
    momentum: 0.9
    exclude_bias_and_norm: True
    num_epochs: 2
    regularize_bn: False
    regularize_bias: False
    param_schedulers:
      lr:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.2
          base_lr_batch_size: 256
        name: composite
        schedulers:
          - name: linear
            start_value: 0.0
            end_value: 0.2 # Automatically rescaled if needed
          - name: cosine
            start_value: 0.2 # Automatically rescaled if needed
            end_value: 0.002 # Automatically rescaled if needed
        update_interval: step
        interval_scaling: [rescaled, fixed]
        lengths: [0.1, 0.9] # 100ep
  DISTRIBUTED:
    BACKEND: nccl
    NUM_NODES: 1
    NUM_PROC_PER_NODE: 1
    INIT_METHOD: tcp
    RUN_ID: auto
  MACHINE:
    DEVICE: gpu
  CHECKPOINT:
    DIR: "."
    AUTO_RESUME: True
    CHECKPOINT_FREQUENCY: 1
    OVERWRITE_EXISTING: true
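One of the commits in this PR notes that `embedding_dim` was not updated after the MLP size changed. A small consistency check along the following lines can catch that class of mistake before a run starts. This is only a sketch: `check_head_matches_loss` is a hypothetical helper, not part of VISSL, and it assumes the head is a plain chain of `mlp` blocks as in the integration-test config above.

```python
def check_head_matches_loss(head_params, embedding_dim):
    """Return a list of problems; an empty list means the head and loss agree."""
    problems = []
    dims = [layer_kwargs["dims"] for _, layer_kwargs in head_params]
    # Each block's output width must equal the next block's input width.
    for index, (previous, current) in enumerate(zip(dims, dims[1:])):
        if previous[-1] != current[0]:
            problems.append(
                f"block {index} outputs {previous[-1]} but block {index + 1} expects {current[0]}"
            )
    # The loss builds an embedding_dim x embedding_dim cross-correlation matrix,
    # so the last block's output width must equal embedding_dim.
    if dims[-1][-1] != embedding_dim:
        problems.append(
            f"head outputs {dims[-1][-1]} but embedding_dim is {embedding_dim}"
        )
    return problems


# HEAD.PARAMS copied from the integration-test config.
QUICK_HEAD = [
    ["mlp", {"dims": [2048, 128], "use_relu": True, "use_bn": True, "use_bias": False}],
    ["mlp", {"dims": [128, 128], "use_bias": False}],
]

print(check_head_matches_loss(QUICK_HEAD, 128))
print(check_head_matches_loss(QUICK_HEAD, 8192))
```

The first call corresponds to the config as committed; the second reproduces the mismatch described in the commit message, with the head shrunk to 128 while the loss still expects 8192.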
1 change: 1 addition & 0 deletions dev/run_quick_tests.sh
@@ -33,6 +33,7 @@ popd
# -----------------------------------------------------------------------------

CFG_LIST=(
    "test/integration_test/quick_barlow_twins"
    "test/integration_test/quick_deepcluster_v2"
    "test/integration_test/quick_pirl"
    "test/integration_test/quick_simclr"
10 changes: 10 additions & 0 deletions docs/source/api/losses.rst
@@ -79,6 +79,16 @@ vissl.losses.deepclusterv2_loss
    :show-inheritance:


vissl.losses.barlow_twins_loss
-----------------------------------------------------

.. automodule:: vissl.losses.barlow_twins_loss
    :members:
    :undoc-members:
    :show-inheritance:



vissl.losses.cross_entropy_multiple_output_single_target
-------------------------------------------------------------------
