
[FLAVA] Separate out text and image encoders #102

Closed
wants to merge 3 commits into gh/ankitade/3/base from gh/ankitade/3/head

Conversation

@ankitade (Contributor) commented Jun 20, 2022

Separate out the encoders into their own modules without any logic changes (except fixing two minor bugs; see my annotations) and add tests.

Test plan:
pytest

Stack from ghstack (oldest at bottom):

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jun 20, 2022
ankitade added a commit that referenced this pull request Jun 20, 2022
ghstack-source-id: 5212d85775105b8176a13aedbcbf576fb3dc3291
Pull Request resolved: #102
ankitade added a commit that referenced this pull request Jun 23, 2022
ghstack-source-id: 278663179f3ff6a73ede11e714f7f8e5e9a6a8bb
Pull Request resolved: #102

codecov-commenter commented Jun 23, 2022

Codecov Report

Merging #102 (fbb769b) into gh/ankitade/3/base (d0a347a) will increase coverage by 0.19%.
The diff coverage is 83.24%.

@@                  Coverage Diff                   @@
##           gh/ankitade/3/base     #102      +/-   ##
======================================================
+ Coverage               88.85%   89.04%   +0.19%     
======================================================
  Files                      33       35       +2     
  Lines                    1722     1744      +22     
======================================================
+ Hits                     1530     1553      +23     
+ Misses                    192      191       -1     
Impacted Files                                           Coverage Δ
...orchmultimodal/models/flava/flava_image_encoder.py   74.54% <74.54%> (ø)
torchmultimodal/models/flava/flava_text_encoder.py      94.04% <94.04%> (ø)
torchmultimodal/models/flava/flava_model.py             92.92% <100.00%> (+5.90%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d0a347a...fbb769b.

Separate out the encoders into their own modules without any logic changes (except fixing two minor bugs; see my annotations) and add tests.




[ghstack-poisoned]
ankitade added a commit that referenced this pull request Jun 23, 2022
ghstack-source-id: cbef8d57b722b36a66fa0b4155d2a9f82c0b2fe0
Pull Request resolved: #102
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
Contributor Author (ankitade):

changed to self.patch_embedding.patch_size
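For context, a minimal sketch of the annotated change. The before/after for h0 is taken from the diff context above; the w0 line and the surrounding method body are assumptions, not the exact file contents.

```python
# Before (relied on a config object the refactored encoder no longer keeps):
# h0 = height // self.config.patch_size

# After (read the patch size from the patch embedding module itself):
h0 = height // self.patch_embedding.patch_size
w0 = width // self.patch_embedding.patch_size  # assumed analogous line for width
```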

pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
attentions=output.attentions,
image_labels=image_labels,
Contributor Author (ankitade):

removed this, image_labels is not a field

@ankitade ankitade marked this pull request as ready for review June 23, 2022 05:13
@ankitade has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Comment on lines +21 to +22
set_rng_seed(0)
torch.manual_seed(0)
Contributor:

I think this line is redundant
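For context, a hedged sketch of why the explicit torch.manual_seed(0) is likely redundant here: if the test helper set_rng_seed looks anything like the following (an assumption, not the repo's actual helper), it already seeds torch.

```python
import random

import torch


def set_rng_seed(seed: int) -> None:
    """Sketch of a typical test seeding helper (assumed, not the actual code)."""
    random.seed(seed)
    torch.manual_seed(seed)  # already covers what the extra torch.manual_seed(0) does
```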

atol=1e-4,
rtol=0,
)
assert_expected(out.pooler_output, out.last_hidden_state)
Contributor:

nit: maybe a bit confusing to do the transitive thing here. Can you just set the expected result to a var and compare both last_hidden_state and pooler_output to that?

Contributor Author (ankitade):

Isn't the transitive check actually making it clear which values should line up?

Contributor:

Sure, just personal preference I guess
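For illustration, a sketch of the suggested pattern next to the transitive check, reusing the test's existing assert_expected helper and tolerances; the expected tensor below is a placeholder, not the real reference values.

```python
import torch

# Transitive style (as in the PR): check the two outputs against each other.
assert_expected(out.pooler_output, out.last_hidden_state)

# Suggested style: pin both outputs to one explicit expected tensor.
expected = torch.full_like(out.last_hidden_state, 0.2)  # placeholder values only
assert_expected(out.last_hidden_state, expected, atol=1e-4, rtol=0)
assert_expected(out.pooler_output, expected, atol=1e-4, rtol=0)
```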

[
[
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
Contributor:

0.1999 due to rounding error? Maybe just make them all 0.2 for readability?

Image to Patch Embedding.
"""

def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
Contributor:

add types

Contributor Author (ankitade), Jun 24, 2022:

typing will get handled separately in another PR (either Rafi or I will do it)
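For reference, a possible typed version of the signature under discussion (a sketch only; the class name is inferred from the docstring, and the actual typing pass is deferred to the follow-up PR mentioned above).

```python
from torch import nn


class PatchEmbeddings(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        ...
```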

Comment on lines +117 to +124
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(
1, int(math.sqrt(n)), int(math.sqrt(n)), dim
).permute(0, 3, 1, 2),
scale_factor=(h0 / math.sqrt(n), w0 / math.sqrt(n)),
mode="bicubic",
align_corners=False,
)
Contributor:

What's with all the chaining here? Can we split it up a bit?

Contributor Author (ankitade):

Generally I want to avoid touching core logic as part of this refactor. I have a feeling some of the image encoder is going to get deleted in the end
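For illustration, one way the chained call could be broken up if it were ever revisited (a sketch only; as noted, the core logic stays untouched in this PR). The names n, dim, h0, w0, and patch_pos_embed are taken from the surrounding diff context.

```python
import math

from torch import nn

sqrt_n = math.sqrt(n)  # n is the number of patch position embeddings

# Reshape the flat sequence of patch position embeddings back into a square
# grid, then move the embedding dim into the channel position for interpolate.
patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_n), int(sqrt_n), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

# Rescale the grid to the new (h0, w0) patch resolution with bicubic interpolation.
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed,
    scale_factor=(h0 / sqrt_n, w0 / sqrt_n),
    mode="bicubic",
    align_corners=False,
)
```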

Comment on lines +42 to +43
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
Contributor:

Not super important, but I'm seeing this comment a lot in our embeddings classes. Is it actually adding any value? If not, maybe remove it

Comment on lines +192 to +194
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
Contributor:

This doesn't need to be a class method. Katrina is adding it as a util in #99

Contributor Author (ankitade):

actually this might need some more thought. do we want to handle attention mask HF style?

Contributor:

I would not touch FLAVA's core logic in general. Feel free to refactor around the logic, but it would be good to avoid touching the logic itself.

Contributor Author (ankitade):

I think we discussed this @apsdehal; it should be fine as long as tests are added, they pass, and the checkpoint is kept in sync.

In general we do need to refactor some things, for example: the projection being part of only the pretraining model (but it's needed for zero shot), or trying to use a common implementation of transformers. Will add you to all the PRs so we can address any concerns you have.
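For context, a hedged sketch of what a standalone extended-attention-mask utility could look like (an illustration in the spirit of the change discussed above, not the actual code from #99):

```python
import torch


def get_extended_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Convert a (batch, seq_len) mask of 1s (attend) and 0s (ignore) into an
    additive mask broadcastable over attention scores."""
    # (batch, seq_len) -> (batch, 1, 1, seq_len)
    extended_mask = attention_mask[:, None, None, :].to(dtype=torch.float)
    # Attended positions become 0.0; masked positions become a large negative
    # number, which effectively removes them after the softmax.
    return (1.0 - extended_mask) * -10000.0
```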

Comment on lines +160 to +164
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
Contributor:

Same thing here, can we remove this?

@ankitade (Contributor Author):

General comment: not going to introduce deeper refactoring changes in this PR since this is the first time we are adding tests for the encoders (will get handled as part of unification/cleanup).

ankitade added a commit to ankitade/multimodal that referenced this pull request Jun 24, 2022
Summary:
Pull Request resolved: facebookresearch#102

Separate out the encoders into their own modules without any logic changes (except fixing two minor bugs; see my annotations) and add tests.

Test Plan: pytest

Differential Revision: D37407717

Pulled By: ankitade

fbshipit-source-id: 7ebacb969b864438372ff9304a46ed2f4be4c906
ankitade added a commit to ankitade/multimodal that referenced this pull request Jun 24, 2022
Summary:
Pull Request resolved: facebookresearch#115

Pull Request resolved: facebookresearch#102

Separate out the encoders into their own modules without any logic changes (except fixing two minor bugs; see my annotations) and add tests.

Test Plan: pytest

Reviewed By: ebsmothers

Differential Revision: D37407717

Pulled By: ankitade

fbshipit-source-id: cd9e120eea4890bb813cb8bbe77577f9e2c77c40
ankitade added a commit to ankitade/multimodal that referenced this pull request Jun 24, 2022
Summary:
Pull Request resolved: facebookresearch#115

Pull Request resolved: facebookresearch#102

Separate out the encoders into their own modules without any logic changes (except fixing two minor bugs; see my annotations) and add tests.

Test Plan: pytest

Reviewed By: ebsmothers

Differential Revision: D37407717

Pulled By: ankitade

fbshipit-source-id: bb56e29c798081e8fb8f04ff9307d0f8903628a8
@facebook-github-bot facebook-github-bot deleted the gh/ankitade/3/head branch June 28, 2022 14:15