
Questions about function forward() in NUWA please. #9

Open
Fitzwong opened this issue Mar 3, 2022 · 1 comment

Comments


Fitzwong commented Mar 3, 2022

I'm confused that, in the forward() function of class NUWA, the ground-truth video is fed to the transformer to compute the output video, which is different from how generate() works.

frame_embeddings = self.video_transformer(
            frame_embeddings,  # calculated from ground-truth video
            context = text_embeds,
            context_mask = text_mask
        )

So when training NUWA, the loss comes from the logits. But the logits are computed not only from the text but also from the ground-truth video (in a single transformer pass, unlike the auto-regressive loop in the generate function). Isn't that a kind of cheating during training? Or should I generate the logits the same way as in generate(), and then compute the loss?
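For context, here is a minimal sketch of the difference the question describes. The names (`next_token_logits`, the start token) are hypothetical stand-ins, not the actual NUWA API: generate() must loop, sampling one token at a time and feeding it back in, while forward() runs a single parallel pass over the ground-truth tokens.

```python
import torch

torch.manual_seed(0)
num_tokens, seq_len = 16, 8

def next_token_logits(tokens):
    # stand-in for the transformer: returns logits over the vocabulary
    # for the token that should follow the given prefix
    return torch.randn(num_tokens)

# generate()-style decoding: sequential, each step conditioned on
# previously generated tokens
generated = [0]  # assumed start token
for _ in range(seq_len - 1):
    logits = next_token_logits(torch.tensor(generated))
    generated.append(int(logits.argmax()))
print(generated)
```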

@lucidrains (Owner) commented

So the reason is that we compress the video into a sequence of tokens, and then have each token predict the next token, autoregressively.
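The key point can be sketched as follows. This is a simplified illustration (the module names and sizes are assumptions, not NUWA's actual code): feeding the ground-truth tokens during training is standard teacher forcing, not cheating, because a causal mask stops position t from attending to positions after t, so the logits at position t predict token t+1 from tokens up to t only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, dim, seq_len = 16, 32, 8

token_emb = torch.nn.Embedding(num_tokens, dim)
attn = torch.nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
to_logits = torch.nn.Linear(dim, num_tokens)

# ground-truth video tokens (teacher forcing input)
video_tokens = torch.randint(0, num_tokens, (1, seq_len))
x = token_emb(video_tokens)

# causal mask: True above the diagonal means "may not attend"
causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
)
out, _ = attn(x, x, x, attn_mask=causal_mask)
logits = to_logits(out)

# shift by one: logits at position t are trained against token t+1
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, num_tokens),
    video_tokens[:, 1:].reshape(-1),
)
print(loss.item())
```

Because of the mask, this single parallel pass computes the same per-position predictions the sequential generate() loop would, which is why training on the loss above is valid.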
