
Implement Barlow Twins #229

Closed · 3 tasks done
OlivierDehaene opened this issue Mar 10, 2021 · 11 comments

@OlivierDehaene (Contributor) commented Mar 10, 2021

🌟 New SSL approach addition

Approach description

Implement Barlow Twins (arXiv link).


Pseudocode:

# f: encoder network
# lambda: weight on the off-diagonal terms
# N: batch size
# D: dimensionality of the representation
#
# mm: matrix-matrix multiplication
# off_diagonal: off-diagonal elements of a matrix
# eye: identity matrix

for x in loader: # load a batch with N samples
    # two randomly augmented versions of x
    y_a, y_b = augment(x)

    # compute representations
    z_a = f(y_a) # NxD
    z_b = f(y_b) # NxD

    # normalize repr. along the batch dimension
    z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD
    z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD

    # cross-correlation matrix
    c = mm(z_a_norm.T, z_b_norm) / N # DxD

    # loss
    c_diff = (c - eye(D)).pow(2) # DxD
    # multiply off-diagonal elems of c_diff by lambda
    off_diagonal(c_diff).mul_(lambda)
    loss = c_diff.sum()

    # optimization step
    loss.backward()
    optimizer.step()
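
For reference, a minimal runnable PyTorch version of the loss above (a sketch only; the off_diagonal helper is one common way to index the off-diagonal entries, and the lambd default of 5e-3 is the λ reported in the paper):

import torch

def off_diagonal(m):
    # flattened view of all off-diagonal elements of a square matrix
    n, _ = m.shape
    return m.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def barlow_twins_loss(z_a, z_b, lambd=5e-3):
    # z_a, z_b: (N, D) representations of the two augmented views
    N = z_a.size(0)
    # normalize each dimension along the batch
    z_a = (z_a - z_a.mean(0)) / z_a.std(0)
    z_b = (z_b - z_b.mean(0)) / z_b.std(0)
    # DxD cross-correlation matrix
    c = (z_a.T @ z_b) / N
    # invariance term: diagonal entries pushed toward 1
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    # redundancy-reduction term: off-diagonal entries pushed toward 0
    off_diag = off_diagonal(c).pow(2).sum()
    return on_diag + lambd * off_diag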

Open source status

The model implementation is not yet available. However, it will be open sourced at: https://github.com/facebookresearch/barlowtwins

  • the model implementation is available
  • the model weights are available
  • who are the authors: Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny
@OlivierDehaene (Contributor, Author)

I would be glad to work on this if you want.

@prigoyal (Contributor) commented Mar 10, 2021

EDIT: We are checking with the authors whether they plan to provide an implementation themselves! I'll update here if they do, in which case I'd recommend waiting. We will communicate the plan.

Hi @OlivierDehaene, the proposal looks great and we would love to have this in VISSL. Go ahead :)

Also take a look at the projects/ folder, where you can add a README etc. to ensure the authors are properly credited for their work in VISSL.

@IgorSusmelj

You could use the code here: https://github.com/IgorSusmelj/barlowtwins
We used the SimSiam model and just swapped the loss. We're still tuning some parameters, since we're running the code on CIFAR-10 and the paper does not mention experiments on this dataset.

@OlivierDehaene (Contributor, Author)

@prigoyal sorry I only just saw your edit.

I started an implementation here.

@jzbontar

Hi Olivier,

Paper author here. Thanks for volunteering to implement Barlow Twins in VISSL! Please go ahead with the implementation.
I can review your code and, potentially, run it on our cluster to verify the accuracy.

Also, the official code should be out next week. We are making final tweaks to the codebase.

Best,
Jure

@OlivierDehaene (Contributor, Author) commented Mar 11, 2021

Hello Jure,

Cool! I think I'll wait for the code to come out before continuing, as I'm not really sure about the values of some hyper-parameters.

Nice paper by the way :)

Cheers

@prigoyal (Contributor)

@jzbontar, is it possible to share the hyperparams setup in @OlivierDehaene's PR? :)

@jzbontar

@OlivierDehaene,

I also think that waiting for us to release the code would make sense, so that the VISSL implementation and our official implementation match. Like I said, you can expect us to release the code sometime next week.

Out of curiosity, which hyper-parameters are you unsure about?

@OlivierDehaene (Contributor, Author) commented Mar 14, 2021

> Out of curiosity, which hyper-parameters are you unsure about?

@jzbontar I only have questions regarding the optimization hyper-parameters:

  • Do you adjust the learning rate as a function of the minibatch size, like in MoCo and SimCLR?
  • Do you use mixed precision? If so, and with Apex, which optimization level do you use?

The first can be inferred from the "We follow the optimization protocol described in BYOL" sentence, but I would prefer to be sure :)

@jzbontar

  • This is how we set the learning rate (linear warmup over the first 10 epochs, then a cosine decay down to 0.1% of the base rate; the base rate scales linearly with the batch size):

import math

def adjust_learning_rate(args, optimizer, loader, step):
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    # linear scaling rule: base LR is proportional to the batch size
    base_lr = args.learning_rate * args.batch_size / 256
    if step < warmup_steps:
        # linear warmup from 0 to base_lr
        lr = base_lr * step / warmup_steps
    else:
        # cosine decay from base_lr to end_lr = 0.001 * base_lr
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

  • We use torch AMP and train in half precision; a rough sketch of the resulting training step is below.
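
For illustration, a training loop wired up with these pieces might look roughly like this (a sketch, not the official code; f, augment, loader, optimizer, args, and the barlow_twins_loss helper from earlier in the thread are placeholders):

import torch

scaler = torch.cuda.amp.GradScaler()

for epoch in range(args.epochs):
    for i, x in enumerate(loader):
        step = epoch * len(loader) + i
        adjust_learning_rate(args, optimizer, loader, step)
        y_a, y_b = augment(x)  # two augmented views of the batch
        optimizer.zero_grad()
        # forward pass and loss run in half precision
        with torch.cuda.amp.autocast():
            loss = barlow_twins_loss(f(y_a), f(y_b))
        # scale the loss to avoid gradient underflow in fp16
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()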

@jzbontar

We released the code for Barlow Twins.

facebook-github-bot pushed a commit that referenced this issue Apr 30, 2021
Summary:
## Required (TBC)

- [x] BarlowTwinsLoss and Criterion
- [x] Documentation
  - [x] Loss
  - [x] SSL Approaches + Index
  - [x] Model Zoo
  - [x] Project
- [x] Default configs
  - [x] pretrain
  - [x] test/integration
  - [x] debugging/pretrain
- [x] Benchmarks
  - [x] ImageNet: 70.75 Top-1 accuracy for 300 epochs
  - [x] Imagenette-160: 88.8 Top-1 accuracy

closes #229

Pull Request resolved: #230

Reviewed By: iseessel

Differential Revision: D28118605

Pulled By: prigoyal

fbshipit-source-id: 4436d6fd9d115b80ef5c5396318caa3cb26faadb