use TextEncoder in coca encode_image #321

Merged
merged 25 commits on Jan 6, 2023
Changes from 10 commits
111 changes: 38 additions & 73 deletions src/open_clip/coca_model.py
@@ -22,7 +22,8 @@ class MultimodalCfg(CLIPTextCfg):
dim_head: int = 64
heads: int = 8
n_queries: int = 256
dim_latents: int = None
attn_pooler_heads: int = 8
latent_dim: int = 512


def _build_input_dependent_text_tower(
@@ -76,62 +77,54 @@ def __init__(
):
super().__init__()


norm_layer = (
LayerNormFp32
if cast_dtype in (torch.float16, torch.bfloat16)
else LayerNorm
)

text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
self.context_length = self.positional_embedding.shape[0] - 1

self.cls_token = nn.Parameter(torch.randn(embed_dim))
self.visual = _build_vision_tower(
embed_dim, vision_cfg, quick_gelu, cast_dtype
)
self.heads = text_cfg["heads"]

self.multimodal_decoder, multimodal_cfg = _build_input_dependent_text_tower(
embed_dim, multimodal_cfg, quick_gelu, cast_dtype
)

text = _build_input_dependent_text_tower(multimodal_cfg.width, text_cfg, quick_gelu, cast_dtype, multimodal=False)
self.text = text
self.visual = _build_vision_tower(
multimodal_cfg.width, vision_cfg, quick_gelu, cast_dtype
)
self.img_attn_pool = AttentionalPooler(
multimodal_cfg.width, multimodal_cfg.heads, n_queries=n_queries + 1
multimodal_cfg.width, multimodal_cfg.attn_pooler_heads, n_queries=n_queries + 1 # extra query for contrastive_loss
)

self.img_attn_pool_norm = norm_layer(embed_dim)

self.dim_latents = multimodal_cfg.dim_latents if multimodal_cfg.dim_latents else multimodal_cfg.width
self.to_text_latent = nn.Linear(embed_dim, self.dim_latents, bias=False)
self.img_attn_pool_norm = norm_layer(multimodal_cfg.width)
vocab_size = (
self.text.config.vocab_size
if "hf_model_name" in text_cfg and text_cfg["hf_model_name"] is not None
else multimodal_cfg.vocab_size # for hf models
)

self.to_logits = nn.Sequential(
norm_layer(embed_dim), nn.Linear(embed_dim, self.vocab_size, bias=False)
norm_layer(multimodal_cfg.width), nn.Linear(multimodal_cfg.width, vocab_size, bias=False)
)

self.to_img_latent = nn.Linear(multimodal_cfg.width, multimodal_cfg.latent_dim, bias=False)
self.to_txt_latent = nn.Linear(multimodal_cfg.width, multimodal_cfg.latent_dim, bias=False)
# tie embedding weights and projection
self.to_logits[-1].weight = self.token_embedding.weight
# self.to_logits[-1].weight = self.text.token_embedding.weight

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.pad_id = 0

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
self.text.set_grad_checkpointing(enable)
self.multimodal_decoder.grad_checkpointing = enable

def encode_image(self, images, normalize=True, return_tokens=False):
x = self.visual(images, output_tokens=True)

if hasattr(self.visual, "ln_post"):
if hasattr(self.visual, "ln_post") and self.visual.ln_post is not None:
x = self.visual.ln_post(x)

if hasattr(self.visual, "proj") and self.visual.proj is not None:
@@ -140,73 +133,48 @@ def encode_image(self, images, normalize=True, return_tokens=False):
x = self.img_attn_pool(x, x)
x = self.img_attn_pool_norm(x)

image_latent = x[:, 0]
image_latent = self.to_img_latent(x[:, 0])
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent

return (image_latent, x[:, 1:]) if return_tokens else image_latent

def _repeat(self, t, N):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def _build_cls_mask(self, text, cast_dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(*cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask

def encode_text(self, text, normalize=True, return_tokens=False):
text = text[:, :-1] # make space for CLS token
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = torch.cat(
[
x + self.positional_embedding[:seq_len, :].to(cast_dtype),
self._repeat(self.cls_token + self.positional_embedding[-1, :], x.shape[0])
],
dim=1
)
seq_len += 1 # seq is 1 longer as we added CLS
attn_mask = self.attn_mask[None, :seq_len, :seq_len].expand(
text.shape[0] * self.heads, seq_len, seq_len
)
cls_mask = self._build_cls_mask(text, cast_dtype)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask + cls_mask)
x = x.permute(1, 0, 2) # LND -> NLD
cls_emb, token_emb = self.text(text, output_tokens=True)

x = x[torch.arange(x.shape[0]), :] @ self.text_projection

cls_emb = x[torch.arange(x.shape[0]), -1]
token_emb = x[torch.arange(x.shape[0]), :-1]

cls_emb = self.ln_final(cls_emb)
text_latent = self.to_text_latent(cls_emb)
text_latent = self.to_txt_latent(cls_emb)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent

return (text_latent, token_emb) if return_tokens else text_latent

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
labels = text[:, 1:]

text_latents, text_tokens = self.encode_text(text, return_tokens=True)
image_latents, image_tokens = self.encode_image(image, return_tokens=True)
text_latent, token_embs = self.encode_text(text, return_tokens=True)
image_latent, image_embs = self.encode_image(image, return_tokens=True)

text_tokens = self.multimodal_decoder(image_tokens, text_tokens)
logits = self.to_logits(text_tokens)
token_embs = self.multimodal_decoder(image_embs, token_embs)
logits = self.to_logits(token_embs)
if output_dict:
return {
"image_features":image_latent,
"text_features":text_latent,
"logits":logits,
"labels":labels,
"logit_scale":self.logit_scale.exp()
}

return image_latents, text_latents, logits, labels, self.logit_scale.exp()
return image_latent, text_latent, logits, labels, self.logit_scale.exp()

def generate(
self,
image,
text,
seq_len,
max_seq_len=None,
max_seq_len=77,
mask_prob = 0.0,
temperature = 1.,
filter_logits_fn = top_k,
@@ -217,9 +185,6 @@ def generate(

assert mask_prob < 1, "mask_prob must be smaller than 1."

if max_seq_len is None:
max_seq_len = self.context_length

was_training = self.training
num_dims = len(text.shape)

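For reference, a rough usage sketch of the CoCa interface as refactored here: both encoders return a (latent, tokens) pair when return_tokens=True, and forward can emit a dict via output_dict=True. The factory and tokenizer calls below are assumptions based on the wider open_clip API, not part of this diff.

# Hedged usage sketch of the CoCa changes in this file; create_model_and_transforms
# and tokenize are assumed from the broader open_clip API.
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("coca_ViT-B-32")
images = torch.randn(2, 3, 224, 224)                       # dummy image batch
text = open_clip.tokenize(["a photo of a cat", "a photo of a dog"])

# Contrastive latent plus per-token embeddings from each tower
image_latent, image_embs = model.encode_image(images, return_tokens=True)
text_latent, token_embs = model.encode_text(text, return_tokens=True)

# New dict-style output alongside the legacy tuple
out = model(images, text, output_dict=True)
print(out["logits"].shape, out["image_features"].shape, out["labels"].shape)
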
24 changes: 16 additions & 8 deletions src/open_clip/hf_model.py
@@ -79,7 +79,6 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):

return x.last_hidden_state[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""

@@ -90,7 +89,8 @@ def __init__(
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True):
pretrained: bool = True
):
super().__init__()

self.output_dim = output_dim
@@ -113,11 +113,10 @@ def __init__(
else:
self.config = config
self.transformer = AutoModel.from_config(config)

if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
pooler_type = (arch_dict[self.config.model_type]["pooler"])

self.pooler = _POOLERS[pooler_type]()

d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
@@ -132,12 +131,21 @@ def __init__(
nn.Linear(hidden_size, output_dim, bias=False),
)

def forward(self, x: TensorType) -> TensorType:
def forward(self, x: TensorType, output_tokens=False) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
projected = self.proj(pooled_out)

return self.proj(pooled_out)
if output_tokens:
tokens = self.proj(
out.last_hidden_state[:, 1:, :]
if type(self.pooler) == ClsPooler
else out.last_hidden_state
)
return projected, tokens

return projected

def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
if not unlocked_layers: # full freezing
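
A small sketch of the new output_tokens path in HFTextEncoder.forward: when the pooler is a ClsPooler the CLS position is dropped from the returned token embeddings, otherwise the full projected sequence is returned. The constructor arguments and the roberta-base checkpoint below are illustrative assumptions.

# Illustrative only: exercises HFTextEncoder with and without output_tokens.
import torch
from open_clip.hf_model import HFTextEncoder

encoder = HFTextEncoder("roberta-base", output_dim=512, proj="linear")
ids = torch.randint(5, 1000, (2, 12))         # dummy token ids, avoiding the pad id

pooled = encoder(ids)                          # default path, unchanged
pooled, tokens = encoder(ids, output_tokens=True)
# With a ClsPooler the returned tokens have one fewer position than the input.
print(pooled.shape, tokens.shape)
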
20 changes: 18 additions & 2 deletions src/open_clip/model.py
@@ -51,6 +51,8 @@ class CLIPTextCfg:
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
embed_cls: bool = False
pad_id: int = 0


def get_cast_dtype(precision: str):
@@ -146,6 +148,8 @@ def _build_text_tower(
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
pad_id=text_cfg.pad_id,
act_layer=act_layer,
norm_layer=norm_layer,
)
@@ -202,9 +206,15 @@ def encode_text(self, text, normalize: bool = False):
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if output_dict:
return {
"image_features":image_features,
"text_features":text_features,
"logit_scale":self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()


@@ -242,9 +252,15 @@ def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features

def forward(self, image, text):
def forward(self, image, text, output_dict=False):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
if output_dict:
return {
"image_features":image_features,
"text_features":text_features,
"logit_scale":self.logit_scale.exp()
}
return image_features, text_features, self.logit_scale.exp()


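Both CLIP.forward and CustomTextCLIP.forward gain the same optional output_dict flag, so training code can consume either return form. A minimal sketch of the two call styles, assuming the usual factory entry point:

# Minimal sketch of the dual return forms; create_model_and_transforms and
# tokenize are assumed from the broader open_clip API, not this diff.
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("ViT-B-32")
images = torch.randn(2, 3, 224, 224)
text = open_clip.tokenize(["a diagram", "a photo"])

image_f, text_f, scale = model(images, text)       # legacy tuple output
out = model(images, text, output_dict=True)        # new dict output
assert torch.allclose(out["logit_scale"], scale)
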
9 changes: 6 additions & 3 deletions src/open_clip/model_configs/coca_ViT-B-32.json
@@ -7,18 +7,21 @@
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
"layers": 12,
"embed_cls": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
"layers": 12,
"latent_dim": 512,
"attn_pooler_heads": 8
},
"custom_text": true
}
11 changes: 7 additions & 4 deletions src/open_clip/model_configs/coca_base.json
@@ -1,13 +1,15 @@
{
"embed_dim": 768,
"embed_dim": 512,
"multimodal_cfg": {
"width": 768,
"context_length": 76,
"mlp_ratio": 4,
"layers": 12,
"dim_head": 64,
"heads": 12,
"n_queries": 256
"n_queries": 256,
"latent_dim": 512,
"attn_pooler_heads": 8
},
"vision_cfg": {
"image_size": 288,
@@ -16,11 +18,12 @@
"patch_size": 18
},
"text_cfg": {
"context_length": 77,
"context_length": 76,
"vocab_size": 64000,
"layers": 12,
"heads": 12,
"width": 768
"width": 768,
"embed_cls": true
},
"custom_text": "True"
}
21 changes: 21 additions & 0 deletions src/open_clip/model_configs/coca_roberta-ViT-B-32.json
@@ -0,0 +1,21 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "linear"
},
"multimodal_cfg": {
"context_length": 76,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
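
The new coca_roberta-ViT-B-32 config pairs a RoBERTa HF text tower with the CoCa multimodal decoder. Assuming the standard model-registry lookup, it could be instantiated roughly like this:

# Assumed instantiation of the new config via open_clip's model registry;
# get_tokenizer resolving hf_tokenizer_name ("roberta-base") is an assumption.
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms("coca_roberta-ViT-B-32")
tokenizer = open_clip.get_tokenizer("coca_roberta-ViT-B-32")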