Generalize CLIPArchitecture #89

Closed · wants to merge 13 commits
81 changes: 81 additions & 0 deletions test/architectures/test_clip.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from test.test_utils import assert_expected, set_rng_seed
from torchmultimodal.architectures.clip import CLIPArchitecture


class TestCLIPArchitecture:
@pytest.fixture
def start(self):
set_rng_seed(1234)

query_encoder = torch.nn.Linear(5, 3)
retrieval_encoder = torch.nn.Linear(4, 3)
encoders = torch.nn.ModuleDict(
{"query": query_encoder, "retrieval": retrieval_encoder}
)
clip = CLIPArchitecture(encoders=encoders)

input_query = torch.randint(1, 8, (2, 5), dtype=torch.float)
input_retrieval = torch.randint(1, 8, (2, 4), dtype=torch.float)

return clip, input_query, input_retrieval

def test_forward(self, start):
clip, input_query, input_retrieval = start
assert isinstance(clip, torch.nn.Module)
Contributor: Not sure it's necessary to ensure that clip is a Module; I would remove this.
out = clip(modalities={"query": input_query, "retrieval": input_retrieval})
assert (
hasattr(out, "query_embeddings")
and hasattr(out, "retrieval_embeddings")
and len(out.__dict__) == 2
)

actual_q_embedding = out.query_embeddings
actual_r_embedding = out.retrieval_embeddings
expected_q_embedding = torch.Tensor(
[[-0.8066, -0.1749, 0.5647], [-0.7709, -0.1118, 0.6271]]
)
expected_r_embedding = torch.Tensor(
[[-0.1719, 0.7932, 0.5842], [-0.2805, 0.8761, -0.3921]]
)
assert_expected(
actual=actual_q_embedding, expected=expected_q_embedding, rtol=0, atol=1e-4
)
assert_expected(
actual=actual_r_embedding, expected=expected_r_embedding, rtol=0, atol=1e-4
)

def test_forward_missing_input(self, start):
clip, input_query, _ = start
assert isinstance(clip, torch.nn.Module)

with pytest.raises(AssertionError):
clip(modalities={"query": input_query})

def test_forward_extra_input(self, start):
clip, input_query, input_retrieval = start
assert isinstance(clip, torch.nn.Module)

with pytest.warns(UserWarning):
out = clip(
modalities={
"query": input_query,
"retrieval": input_retrieval,
"extra": torch.Tensor([1]).to(dtype=float),
}
)

assert (
hasattr(out, "query_embeddings")
and hasattr(out, "retrieval_embeddings")
and len(out.__dict__) == 2
)
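
The tests above double as documentation for the new interface. As a compact reference, here is a minimal usage sketch with the same toy `nn.Linear` encoders, assuming the generalized `CLIPArchitecture` introduced by this PR:

```python
import torch
from torchmultimodal.architectures.clip import CLIPArchitecture

# Any nn.Module that maps a modality's features to the shared embedding
# size can serve as an encoder; the dict keys name the modalities.
encoders = torch.nn.ModuleDict(
    {"query": torch.nn.Linear(5, 3), "retrieval": torch.nn.Linear(4, 3)}
)
clip = CLIPArchitecture(encoders=encoders)

out = clip(modalities={"query": torch.rand(2, 5), "retrieval": torch.rand(2, 4)})

# One L2-normalized embedding tensor per modality, exposed as
# {modality}_embeddings attributes on the returned CLIPOutput.
assert out.query_embeddings.shape == (2, 3)
assert out.retrieval_embeddings.shape == (2, 3)
```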
44 changes: 26 additions & 18 deletions torchmultimodal/architectures/clip.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
+from dataclasses import make_dataclass
 from typing import Dict
 
 import torch
@@ -20,32 +22,38 @@ class CLIPArchitecture(nn.Module):
     encoders, while the loss is implemented in ContrastiveLossWithTemperature.
 
 
-    Args:   vision_encoder (nn.Module): Instantiated vision encoder.
-                See e.g. ResNetForCLIP class.
-            text_encoder (nn.Module): Instantiated text encoder.
-                See CLIPTextEncoder class.
+    Args:   encoders (nn.ModuleDict): Dict of instantiated encoders, keyed by modality.
+                E.g. {"vision": ResNetForCLIP(), "text": CLIPTextEncoder()}
 
-    Inputs: image (Tensor): Tensor containing image features.
-            text (Tensor): Tensor containing text features.
+    Inputs: modalities (Dict[str, Tensor]): Dict of Tensor features, keyed by modality.
+                Must contain one entry for every modality in ``encoders``.
+
+    Output: CLIPOutput object with fields ``{modality}_embeddings`` for every modality
+        in ``encoders``.
     """
 
     def __init__(
         self,
-        vision_encoder: nn.Module,
-        text_encoder: nn.Module,
+        encoders: nn.ModuleDict,
     ):
         super().__init__()
-        self.vision_encoder = vision_encoder
-        self.text_encoder = text_encoder
+        self.encoders = nn.ModuleDict({k: encoders[k] for k in sorted(encoders.keys())})
 
     def forward(
         self,
-        image: torch.Tensor,
-        text: torch.Tensor,
+        modalities: Dict[str, torch.Tensor],
     ) -> Dict[str, torch.Tensor]:
 
-        img_embeddings = self.vision_encoder(image)
-        text_embeddings = self.text_encoder(text)
-        img_embeddings = F.normalize(img_embeddings)
-        text_embeddings = F.normalize(text_embeddings)
-        return {"image": img_embeddings, "text": text_embeddings}
+        embeddings = {}
+        for key, encoder in self.encoders.items():
+            assert key in modalities, f"{key} missing in input"
+            embeddings[f"{key}_embeddings"] = F.normalize(encoder(modalities[key]))
+        for key in modalities.keys():
+            if key not in self.encoders:
+                warnings.warn(f"Missing encoder for extra input {key}")
Contributor: I think your choice to raise a warning here makes sense. We might also want to do the same in late_fusion for the sake of consistency (doesn't have to be done in this PR though).
+        # Return a dataclass object instead of a dictionary
+        clip_output = make_dataclass(
+            "CLIPOutput",
+            [(f"{k}_embeddings", torch.Tensor) for k in self.encoders.keys()],
+        )
+        return clip_output(**embeddings)

Contributor: Is there a specific reason we want to return a dataclass here? IMO, one of the main advantages of dataclasses is that they follow a fixed schema, so returning one dynamically feels a bit unnatural.

Contributor (author): I agree it feels unnatural (it took me a while to figure out how to make a dataclass dynamically). I used a dataclass to match the pattern set by other modules, but now I realize a lot of modules don't have it, so unless anyone is a strong proponent of output classes I can return a dictionary instead.

Contributor: The creation is dynamic, but once created the schema is fixed. An advantage of a dataclass is that we can use it for type hints. The counterpart to a dataclass is a NamedTuple, if we don't intend for inheritance. But no strong preference here.

Contributor: I would prefer NamedTuple for consistency with all our other model outputs, unless there's a clear advantage of using a dataclass over a NamedTuple.

Contributor (author): Creating a NamedTuple dynamically causes issues with mypy, such that I have to include `# type: ignore` on the namedtuple creation line. Besides that, I don't see other relative advantages of the dataclass.
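
For concreteness, a minimal, self-contained sketch of the two options debated in this thread; `keys` here is a hypothetical stand-in for `self.encoders.keys()` in the PR code:

```python
from dataclasses import make_dataclass
from typing import NamedTuple

import torch

keys = ["query", "retrieval"]  # stand-in for self.encoders.keys()

# Option 1: functional dataclass creation, as in this PR. The class is
# built at runtime, but once created its schema is fixed.
CLIPOutput = make_dataclass(
    "CLIPOutput", [(f"{k}_embeddings", torch.Tensor) for k in keys]
)
out = CLIPOutput(
    query_embeddings=torch.rand(2, 3), retrieval_embeddings=torch.rand(2, 3)
)
assert out.query_embeddings.shape == (2, 3)

# Option 2: functional NamedTuple creation. Runtime behavior is equivalent,
# but mypy rejects a non-literal field list, hence the `# type: ignore`
# the author mentions above.
CLIPOutputNT = NamedTuple(  # type: ignore
    "CLIPOutputNT", [(f"{k}_embeddings", torch.Tensor) for k in keys]
)
out_nt = CLIPOutputNT(
    query_embeddings=torch.rand(2, 3), retrieval_embeddings=torch.rand(2, 3)
)
assert out_nt.query_embeddings.shape == (2, 3)
```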