Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft CLIP model support #3796

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@

from pathlib import Path

from ..config import PathField, BoolField
from ..config import PathField, BoolField, StringField
from ..representation import ClassificationAnnotation
from ..utils import read_txt, get_path, check_file_existence, read_json

from .format_converter import BaseFormatConverter, ConverterReturn, verify_label_map
from ._nlp_common import get_tokenizer


class ImageNetFormatConverter(BaseFormatConverter):
__provider__ = 'imagenet'
annotation_types = (ClassificationAnnotation, )
max_seq_length = 128

@classmethod
def parameters(cls):
Expand All @@ -47,6 +49,15 @@ def parameters(cls):
),
'dataset_meta_file': PathField(
description='path to json file with dataset meta (e.g. label_map, color_encoding)', optional=True
),
'prepare_input_ids_from_labels': BoolField(
optional=True, default=False,
description="Convert label strings into image captions list. It's required for CLIP models"
),
'lower_case': BoolField(optional=True, default=False, description='Switch tokens to lower case register'),
'model_id': StringField(
optional=True,
description='The model id of a predefined tokenizer hosted inside a model repo on huggingface.co'
)
})
return configuration_parameters
Expand All @@ -57,6 +68,10 @@ def configure(self):
self.has_background = self.get_value_from_config('has_background')
self.images_dir = self.get_value_from_config('images_dir') or self.annotation_file.parent
self.dataset_meta = self.get_value_from_config('dataset_meta_file')
self.prepare_input_ids_from_labels = self.get_value_from_config('prepare_input_ids_from_labels')
self.lower_case = self.get_value_from_config('lower_case')
self.model_id = self.get_value_from_config('model_id')
self.tokenizer, self.external_tok = get_tokenizer(self.config, self.lower_case)

def convert(self, check_content=False, progress_callback=None, progress_interval=100, **kwargs):
annotation = []
Expand All @@ -78,7 +93,23 @@ def convert(self, check_content=False, progress_callback=None, progress_interval
return ConverterReturn(annotation, self.get_meta(), content_errors)

@staticmethod
def _create_meta(labels_file, dataset_meta, has_background=False):
def _create_captions(label_map, tokenizer, max_seq_length=None):
    """Build tokenized captions and attention masks for every label.

    For each value of ``label_map`` the text before the first comma is
    inserted into the template ``"This is a picture of <label>."``, the
    caption is tokenized, cut to at most ``max_seq_length`` tokens and
    converted to ids.

    Args:
        label_map: mapping whose values are label strings; only the part
            before the first comma of each value is used.
        tokenizer: object providing ``tokenize(text, add_special_tokens=...)``
            and ``convert_tokens_to_ids(tokens)``.
        max_seq_length: maximum number of tokens kept per caption. When
            ``None`` the converter-wide
            ``ImageNetFormatConverter.max_seq_length`` is used.

    Returns:
        Tuple ``(tokenized_captions, input_masks)``: two lists with one entry
        per label, in ``label_map`` iteration order. Each mask is a list of
        ones with the same length as the matching id list.
    """
    if max_seq_length is None:
        max_seq_length = ImageNetFormatConverter.max_seq_length

    tokenized_captions = []
    input_masks = []
    for label in label_map.values():
        first_label = label.split(',')[0]
        caption = f"This is a picture of {first_label}."
        tokens = tokenizer.tokenize(caption, add_special_tokens=True)
        # Truncate BEFORE converting to ids: previously the cut was applied
        # to ``tokens`` after the ids had been computed, so it had no effect.
        # NOTE(review): plain truncation may drop a trailing special token if
        # the tokenizer appends one - confirm against the tokenizer in use.
        tokens = tokens[:max_seq_length]
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        tokenized_captions.append(input_ids)
        input_masks.append([1] * len(input_ids))
    return tokenized_captions, input_masks

@staticmethod
def _create_meta(labels_file, dataset_meta, tokenizer, has_background=False, prepare_input_ids_from_labels=False, ):
meta = {}
label_map = {}
if dataset_meta:
Expand Down Expand Up @@ -106,8 +137,12 @@ def _create_meta(labels_file, dataset_meta, has_background=False):
label_map[0] = 'background'
meta['background_label'] = 0

if prepare_input_ids_from_labels:
(captions, masks) = ImageNetFormatConverter._create_captions(label_map, tokenizer)
meta['input_ids'] = captions
meta['input_masks'] = masks
return meta

def get_meta(self):
    """Return dataset meta built by ``_create_meta``, or ``None`` if it is empty.

    The converter's labels file, dataset meta file, tokenizer, background
    flag and caption-preparation flag are forwarded in the order expected by
    ``_create_meta``.
    """
    # Single call using the tokenizer-aware argument order; the earlier
    # duplicate call passed ``has_background`` in the tokenizer position and
    # its result was discarded.
    meta = self._create_meta(
        self.labels_file, self.dataset_meta, self.tokenizer,
        self.has_background, self.prepare_input_ids_from_labels
    ) or None
    return meta