Fix issue getting context_length and vocab_size from HF models
rwightman committed Apr 16, 2023
1 parent 395851e commit b8fd066
Showing 1 changed file with 9 additions and 2 deletions.
src/open_clip/hf_model.py
@@ -2,8 +2,8 @@
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""

import re
import warnings

import torch
import torch.nn as nn
@@ -132,7 +132,14 @@ def __init__(
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)
        if not self.vocab_size or not self.context_length:
            warnings.warn(
                f'vocab_size ({self.vocab_size}) and context_length ({self.context_length}) were not properly set.')

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
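Below is a minimal sketch (not part of the commit) of the fallback behavior this change introduces: read vocab_size and max_position_embeddings from a HuggingFace config, defaulting to 0 when absent. It assumes the transformers package is installed; roberta-base is only an illustrative model choice.

    # Illustrative only: mirrors the getattr fallbacks added in this commit.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained('roberta-base')  # example HF text model config

    # Fall back to 0 when the config lacks the attribute, matching the commit.
    vocab_size = getattr(config, 'vocab_size', 0)
    context_length = getattr(config, 'max_position_embeddings', 0)
    print(vocab_size, context_length)  # roberta-base: 50265 514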
