[Model] Add PP-OCRV5_mobile_rec Model Support #43793
liu-jiaxuan wants to merge 2 commits into huggingface:main
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, pp_ocrv5_mobile_rec
yonigozlan left a comment
Hello @liu-jiaxuan! Thanks for opening this PR; however, there is quite a bit to change here to fit the standards of the Transformers library.
The biggest issue is that you've written everything from scratch without inheriting from existing models. The modular file should maximize inheritance. Even if this is a novel architecture (especially the Conv modules part, which might not exist elsewhere in the library), components like MLP blocks, attention, and layer norms should use standard library patterns by inheriting from an existing model's module in modular.
The novel modules that can't be inherited through modular should also follow library standards in terms of naming, formatting, structure, and good practices (a "PPOCRV5MobileRec" prefix for all module names, weight names standardized with other similar modules in the library, no single-letter variables, type hints, docstrings when args are not standard or obvious, never use `eval()`, etc.), and the model should support as many Transformers features as possible, such as the attention interface through flags on `PreTrainedModel` (`_supports_attention_backend`, `_supports_sdpa`, `_supports_flash_attn`, etc.).
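To make the inheritance point concrete, here is a minimal, hypothetical modular sketch; the parent classes (ViT here) and final names are assumptions for illustration only, not a prescription for this model:

```python
# modular_pp_ocrv5_mobile_rec.py -- hedged sketch, not the actual PR code.
# Standard components are inherited from an existing model; the modular
# converter then expands these into a full modeling file with renamed classes.
from transformers.models.vit.modeling_vit import ViTIntermediate, ViTSelfAttention


class PPOCRV5MobileRecMLP(ViTIntermediate):
    # Reuses the parent MLP block unchanged; only the prefix differs.
    pass


class PPOCRV5MobileRecAttention(ViTSelfAttention):
    # Inheriting the standardized attention keeps the model compatible with
    # the attention interface (sdpa, flash attention, ...) with no extra code.
    pass
```

Only the genuinely novel blocks (e.g. the Conv modules) would be written out in full in the modular file.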
Some other big things that are wrong or missing:
- We shouldn't have a cv2 dependency in image processors: the "slow" processor should use PIL/numpy functions, the fast one torch/torchvision.
- Weight initialization shouldn't be scattered across individual module constructors but centralized in `_init_weights()` on the `PreTrainedModel` class, using the transformers init utilities (see the sketch after this list).
- Attention modules are standardized across models in the transformers library, so using modular for attention modules is a must.
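As a rough illustration of the centralized-init point, a self-contained sketch; the config class, `initializer_range`, and the init scheme are placeholder assumptions, since the real values depend on the original PaddlePaddle weights:

```python
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel


class PPOCRV5MobileRecConfig(PretrainedConfig):
    # Hypothetical minimal config, just enough for the init sketch below.
    model_type = "pp_ocrv5_mobile_rec"

    def __init__(self, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)
        self.initializer_range = initializer_range


class PPOCRV5MobileRecPreTrainedModel(PreTrainedModel):
    config_class = PPOCRV5MobileRecConfig
    base_model_prefix = "model"

    def _init_weights(self, module: nn.Module) -> None:
        # All weight initialization lives here, not in individual constructors.
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
```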
Before we go deeper in reviewing this new model addition (and the other PaddlePaddle ones opened recently that are very similar), please have a good look at how other models are implemented in the library. Notably, you can have a look at the recently merged PP-DocLayoutV3 PR (here's its modular file).
We also have resources to learn more about how to contribute a new model and how to use modular: Contributing a new model, using modular.
Also, as the multiple PaddlePaddle models that currently have a new-model-addition PR open seem to be quite similar, I'd recommend focusing on one (the simplest) for now; then we'll be able to leverage modular to easily add the other models.
Happy to answer any questions you may have!
```python
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Treat an unset drop_prob (None) like 0.0 so forward doesn't crash
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
```
Is this really used anywhere? It seems to always be 0, with no way to change it. If that's the case, let's remove it entirely.
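Side note for context: what this block implements is standard stochastic depth, and torchvision already ships an equivalent op, so even if a configurable drop path were needed there would be no reason to hand-roll it. A small hedged illustration:

```python
import torch
from torchvision.ops import stochastic_depth

x = torch.randn(2, 3, 4, 4)
# Equivalent to DropPath(drop_prob=0.1) in training mode: each sample's whole
# branch is zeroed with probability 0.1 and survivors are rescaled by 1/0.9.
out = stochastic_depth(x, p=0.1, mode="row", training=True)
```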
| @auto_docstring(custom_intro="ImageProcessor for the PP-OCRv5_mobile_rec model.") | ||
| class PPOCRV5MobileRecImageProcessor(BaseImageProcessor): | ||
| r""" | ||
| Constructs a PPOCRV5MobileRec image processor. | ||
|
|
||
| Args: | ||
| rec_image_shape (`List[int]`, *optional*, defaults to `[3, 48, 320]`): | ||
| The target image shape for recognition in format [channels, height, width]. | ||
| max_img_width (`int`, *optional*, defaults to `3200`): | ||
| The maximum width allowed for the resized image. | ||
| character_list (`List[str]` or `str`, *optional*, defaults to `None`): | ||
| The list of characters for text recognition decoding. If `None`, defaults to | ||
| "0123456789abcdefghijklmnopqrstuvwxyz". | ||
| use_space_char (`bool`, *optional*, defaults to `True`): | ||
| Whether to include space character in the character list. | ||
| do_rescale (`bool`, *optional*, defaults to `True`): | ||
| Whether to rescale the image pixel values to [0, 1] by dividing by 255. | ||
| do_normalize (`bool`, *optional*, defaults to `True`): | ||
| Whether to normalize the image with mean=0.5 and std=0.5. | ||
| """ | ||
|
|
||
| model_input_names = ["pixel_values"] | ||
|
|
||
| def __init__( | ||
| self, | ||
| rec_image_shape: list[int] = [3, 48, 320], | ||
| max_img_width: int = 3200, | ||
| character_list: list[str] | str | None = None, | ||
| use_space_char: bool = True, | ||
| do_rescale: bool = True, | ||
| do_normalize: bool = True, | ||
| **kwargs, | ||
| ) -> None: | ||
| super().__init__(**kwargs) | ||
| self.rec_image_shape = rec_image_shape if rec_image_shape is not None else [3, 48, 320] | ||
| self.max_img_width = max_img_width | ||
| self.do_rescale = do_rescale | ||
| self.do_normalize = do_normalize | ||
|
|
||
| # Initialize character list for decoding | ||
| self._init_character_list(character_list, use_space_char) | ||
|
|
||
| def _init_character_list( | ||
| self, | ||
| character_list: list[str] | str | None, | ||
| use_space_char: bool, | ||
| ) -> None: | ||
| """ | ||
| Initialize the character list and character-to-index mapping for CTC decoding. | ||
|
|
||
| Args: | ||
| character_list (`List[str]` or `str`, *optional*): | ||
| The list of characters or a string of characters. If `None`, defaults to | ||
| "0123456789abcdefghijklmnopqrstuvwxyz". | ||
| use_space_char (`bool`): | ||
| Whether to include space character in the character list. | ||
| """ | ||
| if character_list is None: | ||
| characters = list("0123456789abcdefghijklmnopqrstuvwxyz") | ||
| elif isinstance(character_list, str): | ||
| characters = list(character_list) | ||
| else: | ||
| characters = list(character_list) | ||
|
|
||
| if use_space_char: | ||
| characters.append(" ") | ||
|
|
||
| # Add CTC blank token at the beginning | ||
| characters = ["blank"] + characters | ||
|
|
||
| self.character = characters | ||
| self.char_to_idx = {char: idx for idx, char in enumerate(characters)} | ||
|
|
||
| def _resize_norm_img( | ||
| self, | ||
| img: np.ndarray, | ||
| max_wh_ratio: float, | ||
| data_format: ChannelDimension | None = None, | ||
| ) -> np.ndarray: | ||
| """ | ||
| Resize and normalize a single image while maintaining aspect ratio. | ||
|
|
||
| Args: | ||
| img (`np.ndarray`): | ||
| The input image in HWC format. | ||
| max_wh_ratio (`float`): | ||
| The maximum width-to-height ratio for resizing. | ||
| data_format (`ChannelDimension`, *optional*): | ||
| The channel dimension format of the output image. | ||
|
|
||
| Returns: | ||
| `np.ndarray`: The processed image in CHW format with padding. | ||
| """ | ||
| img_c, img_h, img_w = self.rec_image_shape | ||
|
|
||
| # Calculate target width based on max_wh_ratio | ||
| target_w = int(img_h * max_wh_ratio) | ||
|
|
||
| if target_w > self.max_img_width: | ||
| # If target width exceeds max, resize to max width | ||
| resized_image = cv2.resize(img, (self.max_img_width, img_h)) | ||
| resized_w = self.max_img_width | ||
| target_w = self.max_img_width | ||
| else: | ||
| h, w = img.shape[:2] | ||
| ratio = w / float(h) | ||
| if math.ceil(img_h * ratio) > target_w: | ||
| resized_w = target_w | ||
| else: | ||
| resized_w = int(math.ceil(img_h * ratio)) | ||
| resized_image = cv2.resize(img, (resized_w, img_h)) | ||
|
|
||
| # Convert to float32 | ||
| resized_image = resized_image.astype(np.float32) | ||
|
|
||
| # Transpose to CHW format | ||
| resized_image = resized_image.transpose((2, 0, 1)) | ||
|
|
||
| # Rescale to [0, 1] | ||
| if self.do_rescale: | ||
| resized_image = resized_image / 255.0 | ||
|
|
||
| # Normalize with mean=0.5, std=0.5 | ||
| if self.do_normalize: | ||
| resized_image = (resized_image - 0.5) / 0.5 | ||
|
|
||
| # Create padded image | ||
| padding_im = np.zeros((img_c, img_h, target_w), dtype=np.float32) | ||
| padding_im[:, :, 0:resized_w] = resized_image | ||
|
|
||
| return padding_im | ||
|
|
||
| def preprocess( | ||
| self, | ||
| img: ImageInput, | ||
| rec_image_shape: list[int] | None = None, | ||
| max_img_width: int | None = None, | ||
| do_rescale: bool | None = None, | ||
| do_normalize: bool | None = None, | ||
| return_tensors: str | TensorType | None = None, | ||
| data_format: ChannelDimension = ChannelDimension.FIRST, | ||
| **kwargs, | ||
| ) -> BatchFeature: | ||
| """ | ||
| Preprocess an image for PPOCRV5MobileRec text recognition. | ||
|
|
||
| Args: | ||
| img (`ImageInput`): | ||
| The input image to preprocess. Can be a PIL Image, numpy array, or torch tensor. | ||
| rec_image_shape (`List[int]`, *optional*): | ||
| The target image shape [channels, height, width]. Defaults to `self.rec_image_shape`. | ||
| max_img_width (`int`, *optional*): | ||
| The maximum width for the resized image. Defaults to `self.max_img_width`. | ||
| do_rescale (`bool`, *optional*): | ||
| Whether to rescale pixel values to [0, 1]. Defaults to `self.do_rescale`. | ||
| do_normalize (`bool`, *optional*): | ||
| Whether to normalize with mean=0.5 and std=0.5. Defaults to `self.do_normalize`. | ||
| return_tensors (`str` or `TensorType`, *optional*): | ||
| The type of tensors to return. Can be "pt", "tf", "np", or None. | ||
| data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): | ||
| The channel dimension format of the output image. | ||
|
|
||
| Returns: | ||
| `BatchFeature`: A BatchFeature containing the processed `pixel_values`. | ||
| """ | ||
| # Use instance defaults if not specified | ||
| rec_image_shape = rec_image_shape if rec_image_shape is not None else self.rec_image_shape | ||
| max_img_width = max_img_width if max_img_width is not None else self.max_img_width | ||
| do_rescale = do_rescale if do_rescale is not None else self.do_rescale | ||
| do_normalize = do_normalize if do_normalize is not None else self.do_normalize | ||
|
|
||
| # Store original values and temporarily update for processing | ||
| original_rec_image_shape = self.rec_image_shape | ||
| original_max_img_width = self.max_img_width | ||
| original_do_rescale = self.do_rescale | ||
| original_do_normalize = self.do_normalize | ||
|
|
||
| self.rec_image_shape = rec_image_shape | ||
| self.max_img_width = max_img_width | ||
| self.do_rescale = do_rescale | ||
| self.do_normalize = do_normalize | ||
|
|
||
| try: | ||
| # Convert to numpy array | ||
| img = np.array(img) | ||
|
|
||
| # Get image dimensions | ||
| img_c, img_h, img_w = self.rec_image_shape | ||
| h, w = img.shape[:2] | ||
|
|
||
| # Calculate max_wh_ratio dynamically | ||
| base_wh_ratio = img_w / img_h | ||
| wh_ratio = w * 1.0 / h | ||
| max_wh_ratio = max(base_wh_ratio, wh_ratio) | ||
|
|
||
| # Process the image | ||
| processed_img = self._resize_norm_img(img, max_wh_ratio) | ||
|
|
||
| # Add batch dimension | ||
| processed_img = np.expand_dims(processed_img, axis=0) | ||
|
|
||
| data = {"pixel_values": processed_img} | ||
| return BatchFeature(data=data, tensor_type=return_tensors) | ||
|
|
||
| finally: | ||
| # Restore original values | ||
| self.rec_image_shape = original_rec_image_shape | ||
| self.max_img_width = original_max_img_width | ||
| self.do_rescale = original_do_rescale | ||
| self.do_normalize = original_do_normalize | ||
|
|
||
| def _ctc_decode( | ||
| self, | ||
| text_index: np.ndarray, | ||
| text_prob: np.ndarray, | ||
| is_remove_duplicate: bool = True, | ||
| ) -> list[tuple[str, float]]: | ||
| """ | ||
| Decode CTC output indices to text. | ||
|
|
||
| Args: | ||
| text_index (`np.ndarray`): | ||
| The predicted character indices with shape (batch_size, sequence_length). | ||
| text_prob (`np.ndarray`): | ||
| The predicted character probabilities with shape (batch_size, sequence_length). | ||
| is_remove_duplicate (`bool`, *optional*, defaults to `True`): | ||
| Whether to remove duplicate consecutive characters. | ||
|
|
||
| Returns: | ||
| `List[Tuple[str, float]]`: A list of tuples containing (decoded_text, confidence_score). | ||
| """ | ||
| result_list = [] | ||
| ignored_tokens = [0] # CTC blank token | ||
| batch_size = len(text_index) | ||
|
|
||
| for batch_idx in range(batch_size): | ||
| selection = np.ones(len(text_index[batch_idx]), dtype=bool) | ||
|
|
||
| if is_remove_duplicate: | ||
| selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] | ||
|
|
||
| for ignored_token in ignored_tokens: | ||
| selection &= text_index[batch_idx] != ignored_token | ||
|
|
||
| char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]] | ||
|
|
||
| if text_prob is not None: | ||
| conf_list = text_prob[batch_idx][selection] | ||
| else: | ||
| conf_list = [1] * len(selection) | ||
|
|
||
| if len(conf_list) == 0: | ||
| conf_list = [0] | ||
|
|
||
| text = "".join(char_list) | ||
| result_list.append((text, np.mean(conf_list).tolist())) | ||
|
|
||
| return result_list | ||
|
|
||
| def post_process_text_recognition( | ||
| self, | ||
| pred: np.ndarray, | ||
| ) -> tuple[list[str], list[float]]: | ||
| """ | ||
| Post-process the model output to decode text recognition results. | ||
|
|
||
| Args: | ||
| pred (`np.ndarray`): | ||
| The model output predictions. Expected shape is (batch_size, sequence_length, num_classes) | ||
| or a list/tuple containing such an array. | ||
|
|
||
| Returns: | ||
| `Tuple[List[str], List[float]]`: A tuple containing: | ||
| - texts: List of decoded text strings. | ||
| - scores: List of confidence scores for each decoded text. | ||
| """ | ||
| preds = np.array(pred[0].detach().cpu()) | ||
| preds_idx = preds.argmax(axis=-1) | ||
| preds_prob = preds.max(axis=-1) | ||
|
|
||
| text = self._ctc_decode( | ||
| preds_idx, | ||
| preds_prob, | ||
| is_remove_duplicate=True, | ||
| ) | ||
|
|
||
| texts = [] | ||
| scores = [] | ||
| for t in text: | ||
| texts.append(t[0]) | ||
| scores.append(t[1]) | ||
|
|
||
| return texts, scores | ||
|
|
||
|
|
||
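To make the CTC decoding above concrete, a tiny worked example using the class just defined (the character table and scores are invented for illustration):

```python
import numpy as np

# With character_list="hi" the table becomes ["blank", "h", "i"].
processor = PPOCRV5MobileRecImageProcessor(character_list="hi", use_space_char=False)

indices = np.array([[1, 1, 0, 2, 2, 0]])  # h h <blank> i i <blank>
probs = np.array([[0.9, 0.8, 0.99, 0.7, 0.6, 0.95]])

# Duplicate collapsing keeps steps 0 and 3, blanks are dropped, giving "hi"
# with confidence (0.9 + 0.7) / 2 = 0.8 (up to float rounding).
print(processor._ctc_decode(indices, probs))  # [('hi', 0.8)]
```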
| @auto_docstring(custom_intro="FastImageProcessor for the PP-OCRv5_mobile_rec model.") | ||
| class PPOCRV5MobileRecImageProcessorFast(BaseImageProcessorFast): | ||
| r""" | ||
| Constructs a fast PPOCRV5MobileRec image processor that supports batch processing. | ||
|
|
||
| This processor is designed to handle multiple images efficiently while maintaining | ||
| strict compatibility with [`PPOCRV5MobileRecImageProcessor`]. The preprocessing | ||
| results are guaranteed to be identical to the non-fast version. | ||
|
|
||
| Args: | ||
| rec_image_shape (`List[int]`, *optional*, defaults to `[3, 48, 320]`): | ||
| The target image shape for recognition in format [channels, height, width]. | ||
| max_img_width (`int`, *optional*, defaults to `3200`): | ||
| The maximum width allowed for the resized image. | ||
| character_list (`List[str]` or `str`, *optional*, defaults to `None`): | ||
| The list of characters for text recognition decoding. If `None`, defaults to | ||
| "0123456789abcdefghijklmnopqrstuvwxyz". | ||
| use_space_char (`bool`, *optional*, defaults to `True`): | ||
| Whether to include space character in the character list. | ||
| do_rescale (`bool`, *optional*, defaults to `True`): | ||
| Whether to rescale the image pixel values to [0, 1] by dividing by 255. | ||
| do_normalize (`bool`, *optional*, defaults to `True`): | ||
| Whether to normalize the image with mean=0.5 and std=0.5. | ||
| image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): | ||
| The mean values for image normalization. Used for validation but actual | ||
| normalization uses fixed value 0.5 in `_resize_norm_img`. | ||
| image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): | ||
| The standard deviation values for image normalization. Used for validation | ||
| but actual normalization uses fixed value 0.5 in `_resize_norm_img`. | ||
|
|
||
| Examples: | ||
|
|
||
| ```python | ||
| >>> from PIL import Image | ||
| >>> from transformers import PPOCRV5MobileRecImageProcessorFast | ||
|
|
||
| >>> processor = PPOCRV5MobileRecImageProcessorFast() | ||
|
|
||
| >>> # Process a single image | ||
| >>> image = Image.open("text_image.png") | ||
| >>> inputs = processor(image, return_tensors="pt") | ||
|
|
||
| >>> # Process multiple images in batch | ||
| >>> images = [Image.open(f"text_image_{i}.png") for i in range(4)] | ||
| >>> batch_inputs = processor(images, return_tensors="pt") | ||
| """ | ||
|
|
||
| model_input_names = ["pixel_values"] | ||
|
|
||
| def __init__( | ||
| self, | ||
| rec_image_shape: list[int] | None = None, | ||
| max_img_width: int = 3200, | ||
| character_list: list[str] | str | None = None, | ||
| use_space_char: bool = True, | ||
| do_rescale: bool = True, | ||
| do_normalize: bool = True, | ||
| image_mean: float | list[float] | None = None, | ||
| image_std: float | list[float] | None = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| super().__init__(**kwargs) | ||
| self.rec_image_shape = rec_image_shape if rec_image_shape is not None else [3, 48, 320] | ||
| self.max_img_width = max_img_width | ||
| self.do_rescale = do_rescale | ||
| self.do_normalize = do_normalize | ||
| # Set default image_mean and image_std for normalization (mean=0.5, std=0.5) | ||
| self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] | ||
| self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] | ||
|
|
||
| # Initialize character list for decoding | ||
| self._init_character_list(character_list, use_space_char) | ||
|
|
||
| def _init_character_list( | ||
| self, | ||
| character_list: list[str] | str | None, | ||
| use_space_char: bool, | ||
| ) -> None: | ||
| """ | ||
| Initialize the character list and character-to-index mapping for CTC decoding. | ||
|
|
||
| Args: | ||
| character_list (`List[str]` or `str`, *optional*): | ||
| The list of characters or a string of characters. If `None`, defaults to | ||
| "0123456789abcdefghijklmnopqrstuvwxyz". | ||
| use_space_char (`bool`): | ||
| Whether to include space character in the character list. | ||
| """ | ||
| if character_list is None: | ||
| characters = list("0123456789abcdefghijklmnopqrstuvwxyz") | ||
| elif isinstance(character_list, str): | ||
| characters = list(character_list) | ||
| else: | ||
| characters = list(character_list) | ||
|
|
||
| if use_space_char: | ||
| characters.append(" ") | ||
|
|
||
| # Add CTC blank token at the beginning | ||
| characters = ["blank"] + characters | ||
|
|
||
| self.character = characters | ||
| self.char_to_idx = {char: idx for idx, char in enumerate(characters)} | ||
|
|
||
| def _resize_norm_img( | ||
| self, | ||
| img: np.ndarray, | ||
| max_wh_ratio: float, | ||
| data_format: ChannelDimension | None = None, | ||
| ) -> np.ndarray: | ||
| """ | ||
| Resize and normalize a single image while maintaining aspect ratio. | ||
|
|
||
| This method is identical to the one in [`PPOCRV5MobileRecImageProcessor`] to ensure | ||
| consistent preprocessing results. | ||
|
|
||
| Args: | ||
| img (`np.ndarray`): | ||
| The input image in HWC format. | ||
| max_wh_ratio (`float`): | ||
| The maximum width-to-height ratio for resizing. | ||
| data_format (`ChannelDimension`, *optional*): | ||
| The channel dimension format of the output image. | ||
|
|
||
| Returns: | ||
| `np.ndarray`: The processed image in CHW format with padding. | ||
| """ | ||
| img_c, img_h, img_w = self.rec_image_shape | ||
|
|
||
| # Calculate target width based on max_wh_ratio | ||
| target_w = int(img_h * max_wh_ratio) | ||
|
|
||
| if target_w > self.max_img_width: | ||
| # If target width exceeds max, resize to max width | ||
| resized_image = cv2.resize(img, (self.max_img_width, img_h)) | ||
| resized_w = self.max_img_width | ||
| target_w = self.max_img_width | ||
| else: | ||
| h, w = img.shape[:2] | ||
| ratio = w / float(h) | ||
| if math.ceil(img_h * ratio) > target_w: | ||
| resized_w = target_w | ||
| else: | ||
| resized_w = int(math.ceil(img_h * ratio)) | ||
| resized_image = cv2.resize(img, (resized_w, img_h)) | ||
|
|
||
| # Convert to float32 | ||
| resized_image = resized_image.astype(np.float32) | ||
|
|
||
| # Transpose to CHW format | ||
| resized_image = resized_image.transpose((2, 0, 1)) | ||
|
|
||
| # Rescale to [0, 1] | ||
| if self.do_rescale: | ||
| resized_image = resized_image / 255.0 | ||
|
|
||
| # Normalize with mean=0.5, std=0.5 | ||
| if self.do_normalize: | ||
| resized_image = (resized_image - 0.5) / 0.5 | ||
|
|
||
| # Create padded image | ||
| padding_im = np.zeros((img_c, img_h, target_w), dtype=np.float32) | ||
| padding_im[:, :, 0:resized_w] = resized_image | ||
|
|
||
| return padding_im | ||
|
|
||
| def _preprocess( | ||
| self, | ||
| images: list["torch.Tensor"], | ||
| **kwargs, | ||
| ) -> BatchFeature: | ||
| """ | ||
| Preprocess a batch of images for text recognition. | ||
|
|
||
| Args: | ||
| images (`List[torch.Tensor]`): | ||
| List of images to preprocess. | ||
| **kwargs: | ||
| Additional keyword arguments. | ||
|
|
||
| Returns: | ||
| `BatchFeature`: A dictionary containing the processed pixel values. | ||
| """ | ||
| # Convert torch tensors to numpy arrays in HWC format | ||
| np_images = [] | ||
| for img in images: | ||
| # img is a torch.Tensor in CHW format, convert to HWC numpy array | ||
| if isinstance(img, torch.Tensor): | ||
| img_np = img.permute(1, 2, 0).numpy() | ||
| else: | ||
| img_np = np.array(img) | ||
| np_images.append(img_np) | ||
|
|
||
| # Calculate max width-to-height ratio across all images | ||
| for img in np_images: | ||
| imgC, imgH, imgW = self.rec_image_shape | ||
| max_wh_ratio = imgW / imgH | ||
| h, w = img.shape[:2] | ||
| wh_ratio = w / float(h) | ||
| max_wh_ratio = max(max_wh_ratio, wh_ratio) | ||
|
|
||
| # Process each image | ||
| processed_images = [] | ||
| for img in np_images: | ||
| processed_img = self._resize_norm_img( | ||
| img, | ||
| max_wh_ratio=max_wh_ratio, | ||
| ) | ||
| processed_images.append(processed_img) | ||
|
|
||
| # Stack into batch tensor | ||
| pixel_values = np.stack(processed_images, axis=0) | ||
| pixel_values = torch.from_numpy(pixel_values) | ||
|
|
||
| return BatchFeature(data={"pixel_values": pixel_values}) | ||
|
|
||
| def _ctc_decode( | ||
| self, | ||
| text_index: np.ndarray, | ||
| text_prob: np.ndarray, | ||
| is_remove_duplicate: bool = True, | ||
| ) -> list[tuple[str, float]]: | ||
| """ | ||
| Decode CTC output indices to text. | ||
|
|
||
| This method is identical to the one in [`PPOCRV5MobileRecImageProcessor`] to ensure | ||
| consistent decoding results. | ||
|
|
||
| Args: | ||
| text_index (`np.ndarray`): | ||
| The predicted character indices with shape (batch_size, sequence_length). | ||
| text_prob (`np.ndarray`): | ||
| The predicted character probabilities with shape (batch_size, sequence_length). | ||
| is_remove_duplicate (`bool`, *optional*, defaults to `True`): | ||
| Whether to remove duplicate consecutive characters. | ||
|
|
||
| Returns: | ||
| `List[Tuple[str, float]]`: A list of tuples containing (decoded_text, confidence_score). | ||
| """ | ||
| result_list = [] | ||
| ignored_tokens = [0] # CTC blank token | ||
| batch_size = len(text_index) | ||
|
|
||
| for batch_idx in range(batch_size): | ||
| selection = np.ones(len(text_index[batch_idx]), dtype=bool) | ||
|
|
||
| if is_remove_duplicate: | ||
| selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] | ||
|
|
||
| for ignored_token in ignored_tokens: | ||
| selection &= text_index[batch_idx] != ignored_token | ||
|
|
||
| char_list = [self.character[text_id] for text_id in text_index[batch_idx][selection]] | ||
|
|
||
| if text_prob is not None: | ||
| conf_list = text_prob[batch_idx][selection] | ||
| else: | ||
| conf_list = [1] * len(selection) | ||
|
|
||
| if len(conf_list) == 0: | ||
| conf_list = [0] | ||
|
|
||
| text = "".join(char_list) | ||
| result_list.append((text, np.mean(conf_list).tolist())) | ||
|
|
||
| return result_list | ||
|
|
||
| def post_process_text_recognition( | ||
| self, | ||
| pred: np.ndarray, | ||
| ) -> tuple[list[str], list[float]]: | ||
| """ | ||
| Post-process the model output to decode text recognition results. | ||
|
|
||
| This method is identical to the one in [`PPOCRV5MobileRecImageProcessor`] to ensure | ||
| consistent post-processing behavior. | ||
|
|
||
| Args: | ||
| pred (`np.ndarray`): | ||
| The model output predictions. Expected shape is (batch_size, sequence_length, num_classes) | ||
| or a list/tuple containing such an array. | ||
|
|
||
| Returns: | ||
| `Tuple[List[str], List[float]]`: A tuple containing: | ||
| - texts: List of decoded text strings. | ||
| - scores: List of confidence scores for each decoded text. | ||
| """ | ||
| preds = np.array(pred[0].detach().cpu()) | ||
| preds_idx = preds.argmax(axis=-1) | ||
| preds_prob = preds.max(axis=-1) | ||
|
|
||
| text = self._ctc_decode( | ||
| preds_idx, | ||
| preds_prob, | ||
| is_remove_duplicate=True, | ||
| ) | ||
|
|
||
| texts = [] | ||
| scores = [] | ||
| for t in text: | ||
| texts.append(t[0]) | ||
| scores.append(t[1]) | ||
|
|
||
| return texts, scores |
Let's avoid cv2 as a dependency. We use PIL for "slow" image processors, and torch/torchvision for fast ones.
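For reference, a rough sketch of what the cv2-free resizing could look like; this is illustration only, and exact numerical parity with cv2's interpolation is not guaranteed:

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


# "Slow" path: PIL instead of cv2.resize (PIL also takes a (width, height) size).
def resize_slow(img: np.ndarray, resized_w: int, img_h: int) -> np.ndarray:
    return np.array(Image.fromarray(img).resize((resized_w, img_h), Image.BILINEAR))


# "Fast" path: pure torch on a CHW float tensor.
def resize_fast(img: torch.Tensor, resized_w: int, img_h: int) -> torch.Tensor:
    # interpolate expects NCHW input, hence the unsqueeze/squeeze
    return F.interpolate(
        img.unsqueeze(0), size=(img_h, resized_w), mode="bilinear", align_corners=False
    ).squeeze(0)
```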
What does this PR do?

Fixes # (issue)

Before submitting

- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.