Commit
Add get_logits method and NLLB tokenizer (#756)
* Add get_logits method and set_language for the tokenizer.
visheratin committed Dec 9, 2023
1 parent 3480360 commit ebe135b
Showing 3 changed files with 32 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/open_clip/factory.py
@@ -20,7 +20,6 @@
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH


HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
20 changes: 19 additions & 1 deletion src/open_clip/model.py
@@ -285,6 +285,15 @@ def encode_text(self, text, normalize: bool = False):

return F.normalize(x, dim=-1) if normalize else x

def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits

def forward(
self,
image: Optional[torch.Tensor] = None,
@@ -354,6 +363,15 @@ def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features

def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits

def forward(
self,
image: Optional[torch.Tensor] = None,
@@ -603,4 +621,4 @@ def get_model_tokenize_cfg(model):
vocab_size = getattr(module, 'vocab_size', None)
if vocab_size is not None:
cfg['vocab_size'] = vocab_size
return cfg
return cfg
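
For reference, a minimal sketch of how the new `get_logits` method can be called once a model is loaded. The model name, pretrained tag, and image path below are illustrative assumptions, not part of this commit.

```python
import torch
from PIL import Image
import open_clip

# Illustrative checkpoint; any CLIP / CustomTextCLIP model created by open_clip works the same way.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model.eval()

image = preprocess(Image.open('example.jpg')).unsqueeze(0)   # [1, 3, H, W]; placeholder image path
text = tokenizer(['a diagram', 'a dog', 'a cat'])             # [3, context_length]

with torch.no_grad():
    # image_logits: [num_images, num_texts]; text_logits is simply its transpose.
    image_logits, text_logits = model.get_logits(image, text)
    probs = image_logits.softmax(dim=-1)
```

The method normalizes both feature sets, applies the learned logit scale, and adds `logit_bias` when the model defines one (e.g. SigLIP-style models), so callers no longer need to assemble the similarity matrix by hand.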
13 changes: 13 additions & 0 deletions src/open_clip/tokenizer.py
@@ -9,6 +9,7 @@
import string
from functools import lru_cache, partial
from typing import Callable, List, Optional, Union
import warnings

import ftfy
import numpy as np
@@ -402,9 +403,15 @@ def __init__(
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'whitespace',
strip_sep_token: bool = False,
language: Optional[str] = None,
):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
if language is not None:
self.set_language(language)
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.strip_sep_token = strip_sep_token
@@ -438,6 +445,12 @@ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] =
)

return input_ids

def set_language(self, src_lang):
if hasattr(self, 'set_lang_fn'):
self.set_lang_fn(src_lang)
else:
warnings.warn('Cannot set language for the tokenizer.')


class SigLipTokenizer:
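
And a minimal sketch of the new tokenizer language support. The NLLB checkpoint name and language codes are illustrative assumptions; any Hugging Face tokenizer that exposes `set_src_lang_special_tokens` is handled the same way, while others fall back to a warning.

```python
from open_clip.tokenizer import HFTokenizer

# Illustrative NLLB checkpoint; the language is applied via set_language() during construction.
tokenizer = HFTokenizer(
    'facebook/nllb-200-distilled-600M',
    context_length=77,
    language='eng_Latn',
)

# The source language can also be switched later, e.g. before tokenizing German captions.
tokenizer.set_language('deu_Latn')
input_ids = tokenizer(['Ein Foto von einem Hund'])   # LongTensor of shape [1, 77]
```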
