Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_logits method and NLLB tokenizer #756

Merged
merged 7 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH


# Prefix string kept as a module-level constant.
HF_HUB_PREFIX = 'hf-hub:'
# Directories holding model config files; starts with the bundled
# "model_configs/" folder that sits next to this module.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
Expand Down
20 changes: 19 additions & 1 deletion src/open_clip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def encode_text(self, text, normalize: bool = False):

return F.normalize(x, dim=-1) if normalize else x

def get_logits(self, image, text):
    """Compute scaled similarity logits between images and texts.

    Both inputs are encoded with ``normalize=True``, multiplied together,
    scaled by ``exp(self.logit_scale)`` and, when ``self.logit_bias`` is
    set, shifted by that bias.

    Args:
        image: Input forwarded unchanged to ``self.encode_image``.
        text: Input forwarded unchanged to ``self.encode_text``.

    Returns:
        A tuple ``(image_logits, text_logits)`` where ``text_logits`` is
        the transpose of ``image_logits``.
    """
    # NOTE(review): the ``@ ... .T`` product assumes both encoders return
    # 2-D feature matrices with the same trailing dimension - confirm.
    image_features = self.encode_image(image, normalize=True)
    text_features = self.encode_text(text, normalize=True)
    # ``*`` and ``@`` share precedence and associate left-to-right, so the
    # scale is applied to the image features before the matrix product.
    image_logits = self.logit_scale.exp() * image_features @ text_features.T
    if self.logit_bias is not None:
        image_logits += self.logit_bias
    text_logits = image_logits.T
    return image_logits, text_logits

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad. Fixed.

def forward(
self,
image: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -354,6 +363,15 @@ def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features

def get_logits(self, image, text):
    """Return image-to-text and text-to-image similarity logits.

    The image and text inputs are encoded with ``normalize=True``; their
    product is scaled by ``exp(self.logit_scale)`` and, if
    ``self.logit_bias`` is not ``None``, the bias is added in place.

    Args:
        image: Input passed straight to ``self.encode_image``.
        text: Input passed straight to ``self.encode_text``.

    Returns:
        ``(image_logits, text_logits)``; ``text_logits`` is simply
        ``image_logits`` transposed.
    """
    # NOTE(review): assumes 2-D feature matrices from both encoders with a
    # shared embedding dimension, as required by ``@`` and ``.T`` - confirm.
    image_features = self.encode_image(image, normalize=True)
    text_features = self.encode_text(text, normalize=True)
    # Evaluated left-to-right: (scale * image_features) @ text_features.T
    image_logits = self.logit_scale.exp() * image_features @ text_features.T
    if self.logit_bias is not None:
        image_logits += self.logit_bias
    text_logits = image_logits.T
    return image_logits, text_logits

def forward(
self,
image: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -603,4 +621,4 @@ def get_model_tokenize_cfg(model):
vocab_size = getattr(module, 'vocab_size', None)
if vocab_size is not None:
cfg['vocab_size'] = vocab_size
return cfg
return cfg
13 changes: 13 additions & 0 deletions src/open_clip/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import string
from functools import lru_cache, partial
from typing import Callable, List, Optional, Union
import warnings

import ftfy
import numpy as np
Expand Down Expand Up @@ -402,9 +403,15 @@ def __init__(
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'whitespace',
strip_sep_token: bool = False,
language: Optional[str] = None,
):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
if language is not None:
self.set_language(language)
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.strip_sep_token = strip_sep_token
Expand Down Expand Up @@ -438,6 +445,12 @@ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] =
)

return input_ids

def set_language(self, src_lang):
    """Set the source language on the wrapped tokenizer, if supported.

    ``self.set_lang_fn`` is only assigned in ``__init__`` when the wrapped
    tokenizer exposes a callable ``set_src_lang_special_tokens``. When the
    attribute is absent, a warning is emitted and nothing else changes.

    Args:
        src_lang: Language identifier handed unchanged to the tokenizer's
            language-setting function.
    """
    if hasattr(self, 'set_lang_fn'):
        self.set_lang_fn(src_lang)
    else:
        # Best-effort by design: unsupported tokenizers warn, not raise.
        warnings.warn('Cannot set language for the tokenizer.')


class SigLipTokenizer:
Expand Down