Skip to content

Commit

Permalink
Add HF BERT configs for Microsoft PubMed CLIP model (#491)
Browse files Browse the repository at this point in the history
* Add bert HF config

* Add ClsLastHiddenStatePooler to avoid needing to override ClsPooler bool arg

* Update pooler comment for clarity
  • Loading branch information
rwightman committed Apr 16, 2023
1 parent 37b729b commit ff2df73
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/open_clip/hf_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,15 @@
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/bert
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
},
"pooler": "cls_pooler",
},
}
14 changes: 14 additions & 0 deletions src/open_clip/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
return x.last_hidden_state[:, self.cls_token_position, :]


@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """Pool by selecting the CLS token's vector from the last hidden state.

    NOTE: this gives the same result as ClsPooler with
    use_pooler_output=False; it exists as a separate pooler so that
    behaviour can be chosen without overriding ClsPooler's bool argument.
    """

    def __init__(self):
        super().__init__()
        # Index used on the second axis of last_hidden_state in forward()
        # (presumably the token/sequence dimension -- confirm against callers).
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # attention_mask is accepted but not used by this pooler.
        last_hidden = x.last_hidden_state
        return last_hidden[:, self.cls_token_position, :]


class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
output_tokens: torch.jit.Final[bool]
Expand Down

0 comments on commit ff2df73

Please sign in to comment.