
Loss implementation #5

Closed
laura-wang opened this issue Mar 22, 2021 · 2 comments

Comments

@laura-wang

Hi, Jure Zbontar, great work!

I am trying to re-implement the code, but I found an inconsistency in the loss implementation (aside from the scale-loss).

In the paper, Eq. 1 shows that the redundancy reduction term is computed on the cross-correlation matrix, and the current implementation in this repo also computes it on the cross-correlation matrix c:
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.args.scale_loss)
off_diag = off_diagonal(c).pow_(2).sum().mul(self.args.scale_loss)
loss = on_diag + self.args.lambd * off_diag
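For reference, this is Eq. 1 from the paper as I read it:

$$\mathcal{L}_{BT} = \sum_i (1 - \mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{j \neq i} \mathcal{C}_{ij}^2$$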

But in the pseudocode in Algorithm 1 of the paper, it seems that the loss is computed on c_diff rather than on c. Could you please elaborate on this? Thanks a lot!

@jzbontar
Contributor

Hi Laura,

the pseudocode in the paper does in fact show a different way of computing the Barlow Twins loss than the method we use in main.py. However, both ways of computing the loss produce the same result: scaling the squared off-diagonal entries by lambd before summing is the same as summing them first and scaling the sum afterwards.

Consider the following code snippet:

import torch

def off_diagonal(x):
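    # return a flattened view of the off-diagonal elements of a square matrix;
    # for a contiguous input this is a view, so in-place operations on the
    # result modify x itself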
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:]

c = torch.rand(3, 3)
lambd = 0.1

# compute the loss as in the pseudocode (Algorithm 1 in the paper):
# square everything first, then scale the off-diagonal entries in place
c_diff = (c - torch.eye(3)).pow(2)
off_diagonal(c_diff).mul_(lambd)  # modifies c_diff through the view
loss = c_diff.sum()
print('from pseudocode:', loss.item())

# compute the loss as in main.py: sum the diagonal and off-diagonal terms
# separately, then scale the off-diagonal sum
# note: add_(-1).pow_(2) and pow_(2) operate on views of c and therefore
# mutate it in place; that is harmless here because c is not read again
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
off_diag = off_diagonal(c).pow_(2).sum()
loss = on_diag + lambd * off_diag
print('from main.py:', loss.item())
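If it helps, here is a small self-contained variant of the snippet above (just an illustration, not code from main.py) that wraps the two computations in functions and asserts that they agree; c is cloned before each call because the main.py version mutates its argument in place:

import torch

def off_diagonal(x):
    # flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:]

def loss_pseudocode(c, lambd):
    # square everything first, then scale the off-diagonal entries in place
    c_diff = (c - torch.eye(c.shape[0])).pow(2)
    off_diagonal(c_diff).mul_(lambd)
    return c_diff.sum()

def loss_main_py(c, lambd):
    # sum diagonal and off-diagonal terms separately, then scale the latter
    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = off_diagonal(c).pow_(2).sum()
    return on_diag + lambd * off_diag

torch.manual_seed(0)
c = torch.rand(3, 3)
assert torch.allclose(loss_pseudocode(c.clone(), 0.1), loss_main_py(c.clone(), 0.1))
print('losses match')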

@laura-wang
Author

Thanks for your quick response! I tried it, and they produce the same results. Thanks a lot!
