[FLAVA] Separate out text and image encoders
ghstack-source-id: 278663179f3ff6a73ede11e714f7f8e5e9a6a8bb
Pull Request resolved: #102
ankitade committed Jun 23, 2022
1 parent 7b6b26c commit 860332c
Showing 8 changed files with 842 additions and 530 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
@@ -14,7 +14,7 @@ namespace_packages = True
install_types = True

# TODO (T116951827): Remove after fixing FLAVA type check errors
-exclude = models/flava/flava_model.py|modules/losses/flava.py
+exclude = models/flava/flava_model.py|models/flava/flava_text_encoder.py|modules/losses/flava.py

[mypy-PIL.*]
ignore_missing_imports = True
5 changes: 5 additions & 0 deletions test/models/flava/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
File renamed without changes.
167 changes: 167 additions & 0 deletions test/models/flava/test_flava_image_encoder.py
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
ImageEmbeddings,
ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaImageEncoder(unittest.TestCase):
def setUp(self):
set_rng_seed(0)
torch.manual_seed(0)
self.image_embedding = ImageEmbeddings(
image_size=2, patch_size=1, hidden_size=2
)

encoder = FLAVATransformerEncoder(
hidden_size=2,
num_attention_heads=1,
num_hidden_layers=1,
hidden_dropout_prob=0.0,
intermediate_size=1,
attention_probs_dropout_prob=0.0,
)
self.image_encoder = ImageTransformer(
embeddings=self.image_embedding,
encoder=encoder,
layernorm=nn.LayerNorm(2),
pooler=nn.Identity(),
)

def test_embedding(self):
input = torch.ones(2, 3, 2, 2)
out = self.image_embedding(input)
assert_expected(
out,
torch.Tensor(
[
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
]
),
atol=1e-4,
rtol=0,
)

def test_image_encoder(self):
input = torch.ones(2, 3, 2, 2)
out = self.image_encoder(input)
assert_expected(
out.last_hidden_state,
torch.Tensor(
[
[
[-0.0040, 0.0040],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
],
[
[-0.0040, 0.0040],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
[-0.9840, 0.9840],
],
]
),
atol=1e-4,
rtol=0,
)
assert_expected(out.pooler_output, out.last_hidden_state)
assert_expected(
out.hidden_states,
(
torch.Tensor(
[
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
[
[0.0000, 0.0000],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
[0.0224, 0.0573],
],
]
),
torch.Tensor(
[
[
[0.0008, 0.0008],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
],
[
[0.0008, 0.0008],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
[0.0232, 0.0581],
],
]
),
),
atol=1e-4,
rtol=0,
)
assert_expected(
out.attentions,
(
torch.Tensor(
[
[
[
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
]
],
[
[
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
[0.1999, 0.2000, 0.2000, 0.2000, 0.2000],
]
],
]
),
),
atol=1e-4,
rtol=0,
)
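
A minimal standalone usage sketch for the newly separated image encoder, built only from the constructor and forward signatures the test above exercises. All sizes are illustrative (not FLAVA defaults from this commit), and the printed shape assumes the usual ViT patch layout of one CLS token plus one token per patch, consistent with the five-position tensors in the test.

import torch
from torch import nn
from torchmultimodal.models.flava.flava_image_encoder import (
    ImageEmbeddings,
    ImageTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder

# Illustrative sizes only; the unit test above uses deliberately tiny ones.
hidden_size = 32
embeddings = ImageEmbeddings(image_size=8, patch_size=2, hidden_size=hidden_size)
encoder = FLAVATransformerEncoder(
    hidden_size=hidden_size,
    num_attention_heads=2,
    num_hidden_layers=2,
    hidden_dropout_prob=0.0,
    intermediate_size=64,
    attention_probs_dropout_prob=0.0,
)
image_encoder = ImageTransformer(
    embeddings=embeddings,
    encoder=encoder,
    layernorm=nn.LayerNorm(hidden_size),
    pooler=nn.Identity(),
)

images = torch.randn(4, 3, 8, 8)  # (batch, channels, height, width)
out = image_encoder(images)
# An 8x8 image in 2x2 patches gives 16 patches, plus one CLS token = 17 positions.
print(out.last_hidden_state.shape)  # expected: torch.Size([4, 17, 32])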
97 changes: 97 additions & 0 deletions test/models/flava/test_flava_text_encoder.py
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.models.flava.flava_text_encoder import (
TextEmbeddings,
TextTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder


class TestFlavaTextEncoder(unittest.TestCase):
def setUp(self):
set_rng_seed(0)
self.text_embedding = TextEmbeddings(
hidden_size=2,
vocab_size=3,
max_position_embeddings=2,
hidden_dropout_prob=0,
)
emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]])
self.text_embedding.word_embeddings = nn.Embedding.from_pretrained(emb_weights)
self.text_embedding.position_embeddings = nn.Embedding.from_pretrained(
emb_weights
)
self.text_embedding.token_type_embeddings = nn.Embedding.from_pretrained(
emb_weights
)

encoder = FLAVATransformerEncoder(
hidden_size=2,
num_attention_heads=1,
num_hidden_layers=1,
hidden_dropout_prob=0.0,
intermediate_size=1,
attention_probs_dropout_prob=0.0,
)
self.text_encoder = TextTransformer(
embeddings=self.text_embedding,
encoder=encoder,
layernorm=nn.LayerNorm(2),
pooler=nn.Identity(),
)

def test_embedding(self):
input_ids = torch.IntTensor([[0, 1]])
out = self.text_embedding(input_ids)
expected = torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
assert_expected(out, expected)

def test_text_transformer(self):
out = self.text_encoder(torch.IntTensor([[0, 1]]))

assert_expected(
out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
)

assert_expected(
out.hidden_states,
(
torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]),
torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]),
),
atol=1e-4,
rtol=0.0,
)

assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),))

def test_text_transformer_attn_mask(self):
input_ids = torch.IntTensor([[0, 1]])
attn_mask = torch.IntTensor([[1, 0]])
out = self.text_encoder(input_ids, attention_mask=attn_mask)

assert_expected(
out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
)

assert_expected(
out.hidden_states,
(
torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]),
torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]),
),
atol=1e-4,
rtol=0.0,
)

assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))
assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),))

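A matching standalone sketch for the separated text encoder, again assuming only the constructor and forward signatures exercised in setUp and the tests above; the vocabulary, sequence length, and hidden sizes are made up for illustration, and the embedding tables are left randomly initialized rather than loaded via from_pretrained as in the test.

import torch
from torch import nn
from torchmultimodal.models.flava.flava_text_encoder import (
    TextEmbeddings,
    TextTransformer,
)
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder

# Illustrative sizes only; the unit test above uses deliberately tiny ones.
hidden_size = 32
embeddings = TextEmbeddings(
    hidden_size=hidden_size,
    vocab_size=100,
    max_position_embeddings=16,
    hidden_dropout_prob=0.0,
)
encoder = FLAVATransformerEncoder(
    hidden_size=hidden_size,
    num_attention_heads=2,
    num_hidden_layers=2,
    hidden_dropout_prob=0.0,
    intermediate_size=64,
    attention_probs_dropout_prob=0.0,
)
text_encoder = TextTransformer(
    embeddings=embeddings,
    encoder=encoder,
    layernorm=nn.LayerNorm(hidden_size),
    pooler=nn.Identity(),
)

input_ids = torch.randint(0, 100, (4, 16))           # (batch, seq_len)
attention_mask = torch.ones(4, 16, dtype=torch.int)  # 1 = attend, 0 = ignore
out = text_encoder(input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # expected: torch.Size([4, 16, 32])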