Isn't loss only supposed to be calculated on masked tokens? #14

Open · EmaadKhwaja opened this issue Nov 8, 2022 · 6 comments

@EmaadKhwaja

In the training loop we have:

imgs = imgs.to(device=args.device)
logits, target = self.model(imgs)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))  # cross-entropy over every position, masked or not
loss.backward()

However, the output of the transformer is:

  _, z_indices = self.encode_to_z(x)
.
.
.
  a_indices = mask * z_indices + (~mask) * masked_indices  # keep the original token where mask is True, insert the mask token elsewhere

  a_indices = torch.cat((sos_tokens, a_indices), dim=1)

  target = torch.cat((sos_tokens, z_indices), dim=1)  # target is the full, unmasked token sequence (plus the sos token)

  logits = self.transformer(a_indices)

  return logits, target

which means the returned target consists of the original, unmasked image tokens at every position.
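
For comparison, a minimal sketch of what a masked-only loss could look like in the training loop above. This is hypothetical: it assumes the forward pass also returns the boolean mask it sampled (with True marking positions replaced by the mask token) and that logits and target still carry the sos token at position 0:

imgs = imgs.to(device=args.device)
logits, target, mask = self.model(imgs)             # hypothetical: forward additionally returns the sampled bool mask
logits, target = logits[:, 1:], target[:, 1:]       # drop the sos position
loss = F.cross_entropy(logits[mask], target[mask])  # cross-entropy on the masked positions only
loss.backward()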

The MaskGIT paper seems to suggest that the loss was only calculated on the masked tokens:

[screenshot from the MaskGIT paper]
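
For readers without the screenshot: paraphrasing the paper's masked loss from memory (notation may differ slightly), the negative log-likelihood is summed only over the positions where m_i = 1, i.e. the masked ones:

\mathcal{L}_{\text{mask}} = -\,\mathbb{E}_{Y \in \mathcal{D}} \Big[ \sum_{\forall i \in [1,N],\, m_i = 1} \log p\big(y_i \mid Y_{\bar{M}}\big) \Big]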

@darius-lam

I've attempted both strategies for a simple MaskGIT on CIFAR-10, but the generation quality still seems to be poor. There are tricks in their training scheme that the authors are not telling us in the paper.

@xuesongnie

I have the same issue. Why was the loss calculated on all tokens?

@EmaadKhwaja (Author)

EmaadKhwaja commented Sep 3, 2023

@Lamikins I believe the training issues come from an error in the masking formula. I've amended the error: #16.

@xuesongnie

@xuesongnie

@EmaadKhwaja return logits[~mask], target[~mask] seems a bit problematic; we should compute the loss on the masked tokens, i.e. return logits[mask], target[mask].

@EmaadKhwaja (Author)

@xuesongnie it's because the computed mask is applied to the wrong values. The other option would be to do r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it's different from how the formula appears in the paper.
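
To make the polarity question concrete, here is a toy sketch (not the repository's code; the cosine schedule is only an example) of the two conventions being discussed:

import math
import numpy as np
import torch

N = 16                                    # toy number of image tokens
u = np.random.uniform()

def gamma(t):                             # example mask-ratio schedule (cosine)
    return np.cos(t * np.pi / 2)

# Convention A: r counts the tokens to MASK, so True in mask means "masked".
r = math.ceil(gamma(u) * N)
mask = torch.zeros(N, dtype=torch.bool)
mask[torch.rand(N).topk(r).indices] = True
# a masked-only loss would then index logits[mask], target[mask]

# Convention B: r counts the tokens to KEEP, so True in mask means "kept".
r = math.floor((1 - gamma(u)) * N)
mask = torch.zeros(N, dtype=torch.bool)
mask[torch.rand(N).topk(r).indices] = True
# a masked-only loss would then index logits[~mask], target[~mask]

Either convention is self-consistent; problems only arise when the token count and the loss indexing follow different conventions.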

@xuesongnie

> @xuesongnie it's because the computed mask is applied to the wrong values. The other option would be to do r = math.floor((1 - self.gamma(np.random.uniform())) * z_indices.shape[1]), but I don't like that because it's different from how the formula appears in the paper.

Hi, bro. I find that performance is poor after modifying it to return logits[mask], target[mask]. It is weird. I guess the embedding layer also needs to be trained on the corresponding unmasked tokens.
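
One way to probe that guess (a sketch only, not anything from this repository or the paper) would be to keep every position in the loss but down-weight the visible tokens instead of dropping them, so their embeddings still receive some gradient:

import torch
import torch.nn.functional as F

def weighted_masked_loss(logits, target, mask, visible_weight=0.1):
    # Sketch only: masked positions get weight 1.0, visible (unmasked) positions
    # a small weight. Assumes mask is True at masked positions and that
    # logits/target no longer include the sos token.
    per_token = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                target.reshape(-1), reduction='none')
    weights = mask.reshape(-1).float() * (1.0 - visible_weight) + visible_weight
    return (per_token * weights).sum() / weights.sum()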
