diff --git a/data/dataset_definitions.yml b/data/dataset_definitions.yml
index 58bd45d9e1..343ccc0f43 100644
--- a/data/dataset_definitions.yml
+++ b/data/dataset_definitions.yml
@@ -1525,3 +1525,17 @@ datasets:
     image_postfix: '.jpg'
     annotation: human_matting.pickle
     dataset_meta: human_matting.json
+
+  - name: tungsten_dataset_rt_hdr_alb
+    data_source: tungsten
+    annotation_conversion:
+      converter: tungsten
+      dataset_root_dir: tungsten
+      input_subfolder: [ 'spp_4_data' ]
+      target_subfolder: spp_4096_data
+      annotation_loader: opencv_unchanged
+      features: ['color', 'albedo']
+      extension: exr
+    reader:
+      type: opencv_imread
+      reading_flag: unchanged
diff --git a/demos/denoise_ray_tracing_demo/python/README.md b/demos/denoise_ray_tracing_demo/python/README.md
new file mode 100644
index 0000000000..c1ccb1259e
--- /dev/null
+++ b/demos/denoise_ray_tracing_demo/python/README.md
@@ -0,0 +1,85 @@
+# Denoise Ray Tracing Demo
+
+This example demonstrates an OpenVINO™-based approach to denoising images rendered with Monte Carlo ray tracing methods, such as unidirectional and bidirectional path tracing.
+
+This demo also supports images from the [Tungsten dataset](https://sites.google.com/view/bilateral-grid-denoising/home/supplemental-material-dataset).
+
+## How It Works
+
+The demo workflow is the following:
+
+The demo first reads the input images and performs preprocessing such as autoexposure and padding. Then, after the model is loaded to the plugin, inference starts. Finally, the demo displays the resulting image.
+
+## Preparing to Run
+
+The list of models supported by the demo is in the `/demos/denoise_ray_tracing_demo/python/models.lst` file.
+This file can be used as a parameter for [Model Downloader](../../../tools/model_tools/README.md) and Converter to download and, if necessary, convert models to OpenVINO IR format (\*.xml + \*.bin).
+
+An example of using the Model Downloader:
+
+```sh
+omz_downloader --list models.lst
+```
+
+An example of using the Model Converter:
+
+```sh
+omz_converter --list models.lst
+```
+
+### Supported Models
+
+* denoise_rt_hdr_alb
+
+## Running
+
+Running the application with the `-h` option yields the following usage message:
+
+```
+usage: denoise_ray_tracing_demo.py [-h] -m MODEL --input_hdr INPUT_HDR
+                                   --input_albedo INPUT_ALBEDO [-d DEVICE]
+                                   [--no_show] [-o OUTPUT]
+                                   [--input_scale INPUT_SCALE]
+
+Options:
+  -h, --help            Show this help message and exit.
+  -m MODEL, --model MODEL
+                        Required. Path to an .xml file with a trained model.
+  --input_hdr INPUT_HDR
+                        Required. Path to an HDR image to infer
+  --input_albedo INPUT_ALBEDO
+                        Required. Path to an albedo image to infer
+  -d DEVICE, --device DEVICE
+                        Optional. Specify the target device to infer on. The
+                        demo will look for a suitable plugin for device
+                        specified. Default value is CPU
+  --no_show             Optional. Don't show output. Cannot be used in GUI mode
+  -o OUTPUT, --output OUTPUT
+                        Optional. Save output to the file with provided
+                        filename.
+  --input_scale INPUT_SCALE, --is INPUT_SCALE
+                        Scales values in the main input image before
+                        filtering, without scaling the output too
+```
+
+For example, to do inference on a CPU with the OpenVINO™ toolkit pre-trained `denoise_rt_hdr_alb` model, run the following command:
+
+```sh
+python denoise_ray_tracing_demo.py \
+    --model /denoise_rt_hdr_alb.xml \
+    --input_hdr data/color.exr \
+    --input_albedo data/albedo.exr \
+    --output result.exr
+```
+
+## Demo Output
+
+The demo uses an OpenCV window to display the resulting image and can save it to a file. The demo reports:
+
+* **Latency**: total processing time required to process input data (from preprocessing the data to displaying the results).
+
+## See Also
+
+* [Open Model Zoo Demos](../../README.md)
+* [Model Optimizer](https://docs.openvino.ai/latest/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html)
+* [Model Downloader](../../../tools/model_tools/README.md)
diff --git a/demos/denoise_ray_tracing_demo/python/data/albedo.exr b/demos/denoise_ray_tracing_demo/python/data/albedo.exr
new file mode 100644
index 0000000000..66b3ac734c
Binary files /dev/null and b/demos/denoise_ray_tracing_demo/python/data/albedo.exr differ
diff --git a/demos/denoise_ray_tracing_demo/python/data/color.exr b/demos/denoise_ray_tracing_demo/python/data/color.exr
new file mode 100644
index 0000000000..84e70b4b2c
Binary files /dev/null and b/demos/denoise_ray_tracing_demo/python/data/color.exr differ
diff --git a/demos/denoise_ray_tracing_demo/python/denoise_ray_tracing_demo.py b/demos/denoise_ray_tracing_demo/python/denoise_ray_tracing_demo.py
new file mode 100644
index 0000000000..5ec3c0678a
--- /dev/null
+++ b/demos/denoise_ray_tracing_demo/python/denoise_ray_tracing_demo.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python3
+
+"""
+ Copyright (C) 2023 KNS Group LLC (YADRO)
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import os
+import sys
+from time import perf_counter
+import logging as log
+from argparse import ArgumentParser, SUPPRESS
+from pathlib import Path
+
+import numpy as np
+
+#pylint: disable=wrong-import-position
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2
+
+from openvino.runtime import Core, get_version
+from utils.color import autoexposure, get_transfer_function, round_up, srgb_inverse, srgb_forward
+
+log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.DEBUG, stream=sys.stdout)
+
+
+def build_argparser():
+    parser = ArgumentParser(add_help=False)
+    args = parser.add_argument_group('Options')
+    args.add_argument('-h', '--help', action='help', default=SUPPRESS,
+                      help='Show this help message and exit.')
+    args.add_argument("-m", "--model", type=Path, required=True,
+                      help="Required. Path to an .xml file with a trained model.")
+    args.add_argument("--input_hdr", type=Path, required=True,
+                      help="Required. Path to an HDR image to infer")
+    args.add_argument("--input_albedo", type=Path, required=True,
+                      help="Required. Path to an albedo image to infer")
+    args.add_argument("-d", "--device", type=str, default="CPU",
+                      help="Optional. Specify the target device to infer on. "
+                           "The demo will look for a suitable plugin for device specified. Default value is CPU")
+    args.add_argument("--no_show", help="Optional. Don't show output. Cannot be used in GUI mode", action='store_true')
+    args.add_argument("-o", "--output", help="Optional. Save output to the file with provided filename.",
+                      default="", type=Path)
+    args.add_argument('--input_scale', '--is', type=float, default=1.,
+                      help='Scales values in the main input image before filtering, '
+                           'without scaling the output too')
+
+    return parser
+
+
+def is_srgb_image(filename):
+    return filename.suffix not in ('.pfm', '.phm', '.exr', '.hdr')
+
+
+def pad_image(image, shape):
+    image = np.pad(image, ((0, 0),
+                           (0, 0),
+                           (0, round_up(shape[2], 16) - shape[2]),
+                           (0, round_up(shape[3], 16) - shape[3])))
+
+    return image
+
+
+def load_image(filename):
+    image = cv2.imread(str(filename), cv2.IMREAD_UNCHANGED)
+    if image is None:
+        raise RuntimeError('Could not read image')
+
+    if is_srgb_image(filename):
+        image = srgb_inverse(image)
+
+    image = image[:, :, :3]
+    image = np.nan_to_num(image)
+    return image
+
+
+def load_image_features(cfg):
+    color = load_image(cfg.input_hdr)
+    albedo = load_image(cfg.input_albedo)
+    images = {"color": color, "albedo": albedo}
+    return images
+
+
+def preprocess_input(images, args):
+    color = images["color"]
+    albedo = images["albedo"]
+
+    exposure = autoexposure(color)
+    transfer = get_transfer_function()
+
+    if args.input_scale:
+        color *= args.input_scale
+    color *= exposure
+
+    color = transfer.forward(color)
+
+    # HWC -> BCHW
+    color = np.expand_dims(color.transpose((2, 0, 1)), 0)
+    shape = color.shape
+    color = pad_image(color, shape)
+
+    albedo = np.expand_dims(albedo.transpose((2, 0, 1)), 0)
+    albedo = pad_image(albedo, shape)
+
+    params = {"exposure": exposure, "transfer": transfer, "shape": shape}
+    features = {"color": color, "albedo": albedo}
+    return features, params
+
+
+def postprocess_image(image, args, params):
+    shape = params["shape"]
+    image = image[0, :, :shape[2], :shape[3]].transpose((1, 2, 0))
+    image = np.maximum(image, 0.)
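+    # Undo the preprocessing in reverse order: negative values produced by the
+    # network are clamped, then the PU encoding and the exposure scaling are
+    # inverted; LDR output formats additionally get sRGB-encoded below.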
+
+    image = params.get("transfer").inverse(image)
+
+    image = image / params.get("exposure")
+
+    if is_srgb_image(args.output):
+        image = srgb_forward(image)
+    return image
+
+
+def main():
+    args = build_argparser().parse_args()
+
+    # Plugin initialization
+    log.info('OpenVINO Runtime')
+    log.info(f'\tbuild: {get_version()}')
+    core = Core()
+
+    if 'GPU' in args.device:
+        core.set_property("GPU", {"GPU_ENABLE_LOOP_UNROLLING": "NO", "CACHE_DIR": "./"})
+
+    # Read IR
+    log.info(f'Reading model {args.model}')
+    model = core.read_model(args.model)
+
+    input_tensor_names = [model.inputs[i].get_any_name() for i in range(len(model.inputs))]
+
+    if len(model.outputs) != 1:
+        raise RuntimeError("Demo supports only single output topologies")
+    output_tensor_name = model.outputs[0].get_any_name()
+
+    # load input features
+    load_start_time = perf_counter()
+    images = load_image_features(args)
+    load_total_time = perf_counter() - load_start_time
+
+    # pre-process input features
+    preprocessing_start_time = perf_counter()
+    input_image, params_image = preprocess_input(images, args)
+    preprocessing_total_time = perf_counter() - preprocessing_start_time
+
+    # Loading model to the plugin
+    compiled_model = core.compile_model(model, args.device)
+    infer_request = compiled_model.create_infer_request()
+    log.info(f'The model {args.model} is loaded to {args.device}')
+
+    # Start sync inference
+    inference_start_time = perf_counter()
+    infer_request.infer(inputs={input_tensor_names[0]: input_image[input_tensor_names[0]],
+                                input_tensor_names[1]: input_image[input_tensor_names[1]]})
+    preds = infer_request.get_tensor(output_tensor_name).data[:]
+    inference_total_time = perf_counter() - inference_start_time
+
+    postprocessing_start_time = perf_counter()
+    result = postprocess_image(preds, args, params_image)
+    postprocessing_total_time = perf_counter() - postprocessing_start_time
+
+    total_latency = (load_total_time
+                     + preprocessing_total_time
+                     + inference_total_time
+                     + postprocessing_total_time) * 1e3
+
+    log.info("Metrics report:")
+    log.info(f"\tLatency: {total_latency:.1f} ms")
+    log.info(f"\tLoad features: {load_total_time * 1e3:.1f} ms")
+    log.info(f"\tPreprocessing: {preprocessing_total_time * 1e3:.1f} ms")
+    log.info(f"\tInference: {inference_total_time * 1e3:.1f} ms")
+    log.info(f"\tPostprocessing: {postprocessing_total_time * 1e3:.1f} ms")
+
+    if args.output.name != "":
+        result_save = result * 255 if is_srgb_image(args.output) else result
+        cv2.imwrite(str(args.output), result_save)
+    if not args.no_show:
+        input_image = images["color"]
+        if not is_srgb_image(args.input_hdr):
+            input_image = srgb_forward(input_image)
+        if not is_srgb_image(args.output):
+            result = srgb_forward(result)
+
+        imshow_image = cv2.hconcat([input_image, result])
+        cv2.namedWindow("Denoise Ray Tracing Image Demo", cv2.WINDOW_NORMAL)
+        cv2.imshow('Denoise Ray Tracing Image Demo', imshow_image)
+        cv2.waitKey(0)
+
+    sys.exit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demos/denoise_ray_tracing_demo/python/models.lst b/demos/denoise_ray_tracing_demo/python/models.lst
new file mode 100644
index 0000000000..cdff9a9ed3
--- /dev/null
+++ b/demos/denoise_ray_tracing_demo/python/models.lst
@@ -0,0 +1,3 @@
+# This file can be used with the --list option of the model downloader.
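+# denoise_rt_hdr_alb expects two inputs: a noisy HDR color image and an auxiliary albedo image.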
+denoise_rt_hdr_alb
diff --git a/demos/denoise_ray_tracing_demo/python/utils/__init__.py b/demos/denoise_ray_tracing_demo/python/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/demos/denoise_ray_tracing_demo/python/utils/color.py b/demos/denoise_ray_tracing_demo/python/utils/color.py
new file mode 100644
index 0000000000..93ce82a6dd
--- /dev/null
+++ b/demos/denoise_ray_tracing_demo/python/utils/color.py
@@ -0,0 +1,140 @@
+"""
+ Copyright (C) 2023 KNS Group LLC (YADRO)
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+HDR_Y_MAX = 65504.  # maximum HDR value
+
+
+def luminance(r, g, b):
+    return 0.212671 * r + 0.715160 * g + 0.072169 * b
+
+
+def autoexposure(image):
+    key = 0.18
+    eps = 1e-8
+    k = 16  # downsampling amount
+
+    # Compute the luminance of each pixel
+    r = image[..., 0]
+    g = image[..., 1]
+    b = image[..., 2]
+    lum = luminance(r, g, b)
+
+    # Downsample the image to minimize sensitivity to noise
+    h = lum.shape[0]  # original height
+    w = lum.shape[1]  # original width
+    hk = (h + k // 2) // k  # downsampled height
+    wk = (w + k // 2) // k  # downsampled width
+
+    lk = np.zeros((hk, wk), dtype=lum.dtype)
+    for i in range(hk):
+        for j in range(wk):
+            begin_h = i * h // hk
+            begin_w = j * w // wk
+            end_h = (i + 1) * h // hk
+            end_w = (j + 1) * w // wk
+
+            lk[i, j] = lum[begin_h:end_h, begin_w:end_w].mean()
+
+    lum = lk
+
+    # Keep only values greater than epsilon
+    lum = lum[lum > eps]
+    if lum.size == 0:
+        return 1.
+
+    # Compute the exposure value
+    return float(key / np.exp2(np.log2(lum).mean()))
+
+
+def round_up(a, b):
+    return (a + b - 1) // b * b
+
+
+def get_transfer_function():
+    return PUTransferFunction()
+
+
+# Transfer function: sRGB
+
+SRGB_A = 12.92
+SRGB_B = 1.055
+SRGB_C = 1. / 2.4
+SRGB_D = -0.055
+SRGB_Y0 = 0.0031308
+SRGB_X0 = 0.04045
+
+
+def srgb_forward(y):
+    return np.where(y <= SRGB_Y0,
+                    SRGB_A * y,
+                    SRGB_B * np.power(y, SRGB_C) + SRGB_D)
+
+
+def srgb_inverse(x):
+    return np.where(x <= SRGB_X0,
+                    x / SRGB_A,
+                    np.power((x - SRGB_D) / SRGB_B, 1. / SRGB_C))
+
+
+class SRGBTransferFunction:
+    def forward(self, y):
+        return srgb_forward(y)
+
+    def inverse(self, x):
+        return srgb_inverse(x)
+
+
+# Transfer function: PU
+
+# Fit of PU2 curve normalized at 100 cd/m^2
+# [Aydin et al., 2008, "Extending Quality Metrics to Full Luminance Range Images"]
+PU_A = 1.41283765e+03
+PU_B = 1.64593172e+00
+PU_C = 4.31384981e-01
+PU_D = -2.94139609e-03
+PU_E = 1.92653254e-01
+PU_F = 6.26026094e-03
+PU_G = 9.98620152e-01
+PU_Y0 = 1.57945760e-06
+PU_Y1 = 3.22087631e-02
+PU_X0 = 2.23151711e-03
+PU_X1 = 3.70974749e-01
+
+
+def pu_forward(y):
+    return np.where(y <= PU_Y0,
+                    PU_A * y,
+                    np.where(y <= PU_Y1,
+                             PU_B * np.power(y, PU_C) + PU_D,
+                             PU_E * np.log(y + PU_F) + PU_G))
+
+
+def pu_inverse(x):
+    return np.where(x <= PU_X0,
+                    x / PU_A,
+                    np.where(x <= PU_X1,
+                             np.power((x - PU_D) / PU_B, 1. / PU_C),
+                             np.exp((x - PU_G) / PU_E) - PU_F))
+
+
+PU_NORM_SCALE = 1. / pu_forward(HDR_Y_MAX)
+
+
+class PUTransferFunction:
+    def forward(self, y):
+        return pu_forward(y) * PU_NORM_SCALE
+
+    def inverse(self, x):
+        return pu_inverse(x / PU_NORM_SCALE)
diff --git a/models/public/denoise_rt_hdr_alb/README.md b/models/public/denoise_rt_hdr_alb/README.md
new file mode 100644
index 0000000000..14d2b53609
--- /dev/null
+++ b/models/public/denoise_rt_hdr_alb/README.md
@@ -0,0 +1,61 @@
+# denoise_rt_hdr_alb
+
+## Use Case and High-Level Description
+
+This model denoises the Monte Carlo noise inherent to stochastic ray tracing methods such as path tracing, reducing the number of samples per pixel required by up to multiple orders of magnitude (depending on the desired closeness to the ground truth).
+
+More details are provided in the [repository](https://github.com/OpenImageDenoise/oidn).
+
+## Example
+
+Example of denoising an image (left - source image, right - image after denoising):
+
+![](./assets/denoising_image.png)
+
+## Specification
+
+Accuracy metrics are obtained on the [Tungsten dataset](https://sites.google.com/view/bilateral-grid-denoising/home/supplemental-material-dataset).
+
+| Metric           | Value     |
+|------------------|-----------|
+| SSIM             | 0.99      |
+| GFlops           | 12.3637   |
+| MParams          | 0.9173   |
+| Source framework | PyTorch\* |
+
+## Inputs
+
+1. Image, name: `color`, dynamic shape in the format `B, C, H, W`, where:
+
+    - `B` - batch size
+    - `C` - number of channels
+    - `H` - image height
+    - `W` - image width
+
+2. Image, name: `albedo`, dynamic shape in the format `B, C, H, W`, where:
+
+    - `B` - batch size
+    - `C` - number of channels
+    - `H` - image height
+    - `W` - image width
+
+The `color` and `albedo` images must have the same shape.
+
+## Outputs
+
+The net output is a blob with the same shape as the input images, in the `B, C, H, W` format, where:
+
+- `B` - batch size
+- `C` - number of channels
+- `H` - image height
+- `W` - image width
+
+## Demo usage
+
+The model can be used in the following demos provided by the Open Model Zoo to show its capabilities:
+
+* [Denoise Ray Tracing Demo](../../../demos/denoise_ray_tracing_demo/python/README.md)
+
+## Legal Information
+
+[*] Other names and brands may be claimed as the property of others.
diff --git a/models/public/denoise_rt_hdr_alb/accuracy-check.yml b/models/public/denoise_rt_hdr_alb/accuracy-check.yml
new file mode 100644
index 0000000000..56f75451ed
--- /dev/null
+++ b/models/public/denoise_rt_hdr_alb/accuracy-check.yml
@@ -0,0 +1,32 @@
+models:
+  - name: denoise_rt_hdr_alb
+
+    launchers:
+      - framework: openvino
+        adapter:
+          type: denoise_rt
+        inputs:
+          - name: "color"
+            type: INPUT
+            value: ".*_color.exr"
+          - name: "albedo"
+            type: INPUT
+            value: ".*_albedo.exr"
+
+    datasets:
+      - name: tungsten_dataset_rt_hdr_alb
+        annotation_conversion:
+          converter: tungsten
+        preprocessing:
+          - type: autoexposure
+            key: 0.18
+            k: 16
+          - type: pu_transfer_function
+        postprocessing:
+          - type: pu_inverse_transfer_function
+          - type: autoexposure
+            key: 0.18
+            k: 16
+        metrics:
+          - type: ssim
+            presenter: print_vector
diff --git a/models/public/denoise_rt_hdr_alb/assets/denoising_image.png b/models/public/denoise_rt_hdr_alb/assets/denoising_image.png
new file mode 100644
index 0000000000..f870354ec4
Binary files /dev/null and b/models/public/denoise_rt_hdr_alb/assets/denoising_image.png differ
diff --git a/models/public/denoise_rt_hdr_alb/model.yml b/models/public/denoise_rt_hdr_alb/model.yml
new file mode 100644
index 0000000000..3f15b6c08d
--- /dev/null
+++ b/models/public/denoise_rt_hdr_alb/model.yml
@@ -0,0 +1,45 @@
+# Copyright (C) 2023 KNS Group LLC (YADRO)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+description: >-
+  Denoising for images rendered with ray tracing
+task_type: image_processing
+files:
+  - name: denoise_rt_hdr_alb.tza
+    size: 3670280
+    checksum: 37ea4fe857cb633d2b4c68c56b99053083dcbea01a849bbbd1f0b70ec1216e16f6d99a9eedcbd8fbd4057de686ca292a
+    source: https://media.githubusercontent.com/media/OpenImageDenoise/oidn-weights/a34b7641349c5a79e46a617d61709c35df5d6c28/rt_hdr_alb.tza
+  - name: tza.py
+    size: 5875
+    checksum: 00c09df9fa07e483402464c4c596639d03f6d0325133e41183cb6caea5705b25f1f449b29ec1bbe071a3440c863e95cc
+    source: https://raw.githubusercontent.com/OpenImageDenoise/oidn/d959bac5b7130b31c41095811ddfbe58c4cf03f4/training/tza.py
+  - name: model.py
+    size: 3380
+    checksum: 7a1922a8327b5f54a4fe403a0f99785a5909e09ba66e005f338ed9f394ab2dc25854df66c2a8e96c089561e672be8622
+    source: https://raw.githubusercontent.com/isalyahova/OIDN_onnx_model/d9a49ce5706aa72e83e93954ef49ed3893d406d0/model.py
+  - name: convert_model.py
+    size: 2787
+    checksum: ea03100e75455cd96cc04b181e31c505c627aef8e3b78958e3c76916245bcc74e6beb3eadf5d36b25c5aea84f7646ebc
+    source: https://raw.githubusercontent.com/isalyahova/OIDN_onnx_model/d9a49ce5706aa72e83e93954ef49ed3893d406d0/convert_model.py
+input_info:
+  - name: 'color'
+    layout: NCHW
+  - name: 'albedo'
+    layout: NCHW
+model_optimizer_args:
+  - --input_model=$dl_dir/denoise_rt_hdr_alb.onnx
+  - --input=color[?,3,?,?],albedo[?,3,?,?]
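+# Both inputs are dynamic ('?'): batch size, height and width are resolved at inference time.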
+framework: onnx
+license: https://github.com/OpenImageDenoise/oidn/blob/master/LICENSE.txt
diff --git a/models/public/denoise_rt_hdr_alb/pre-convert.py b/models/public/denoise_rt_hdr_alb/pre-convert.py
new file mode 100644
index 0000000000..0355600028
--- /dev/null
+++ b/models/public/denoise_rt_hdr_alb/pre-convert.py
@@ -0,0 +1,42 @@
+"""
+Copyright (C) 2023 KNS Group LLC (YADRO)
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import argparse
+import sys
+import subprocess  # nosec - disable B404:import-subprocess check
+
+from pathlib import Path
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('input_dir', type=Path)
+    parser.add_argument('output_dir', type=Path)
+    args = parser.parse_args()
+
+    saved_model_dir = args.output_dir
+
+    subprocess.run([sys.executable, '--',
+                    str(args.input_dir / 'convert_model.py'),
+                    "--features", "hdr", "alb",
+                    "--input_names", "color", "albedo",
+                    f'--input_path_tza={args.input_dir / "denoise_rt_hdr_alb.tza"}',
+                    f'--output_path_onnx={saved_model_dir / "denoise_rt_hdr_alb.onnx"}'
+                    ], check=True)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/__init__.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/__init__.py
index bd45181674..82df00c4bb 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/__init__.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/__init__.py
@@ -37,7 +37,8 @@
 )
 
 from .image_processing import (
-    ImageProcessingAdapter, SuperResolutionAdapter, MultiSuperResolutionAdapter, SuperResolutionYUV, TrimapAdapter
+    ImageProcessingAdapter, SuperResolutionAdapter, MultiSuperResolutionAdapter, SuperResolutionYUV,
+    TrimapAdapter, DenoiseRTAdapter
 )
 from .attributes_recognition import (
     HeadPoseEstimatorAdapter,
@@ -200,6 +201,7 @@
     'MultiSuperResolutionAdapter',
     'SuperResolutionYUV',
     'TrimapAdapter',
+    'DenoiseRTAdapter',
 
     'HeadPoseEstimatorAdapter',
     'VehicleAttributesRecognitionAdapter',
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/image_processing.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/image_processing.py
index 32d2ccd88c..6f8f6cbfcd 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/image_processing.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/adapters/image_processing.py
@@ -266,3 +266,22 @@
         result.append(ImageProcessingPrediction(identifier, out_img))
 
         return result
+
+
+class DenoiseRTAdapter(ImageProcessingAdapter):
+    __provider__ = 'denoise_rt'
+    prediction_types = (ImageProcessingPrediction, )
+
+    def process(self, raw, identifiers, frame_meta):
+        result = []
+        raw_outputs = self._extract_predictions(raw, frame_meta)
+        if not self.output_verified:
+            self.select_output_blob(raw_outputs)
+
+        for identifier, out_img in zip(identifiers, raw_outputs[self.target_out]):
+            out_img = out_img.transpose((1, 2, 0))
+            out_img = np.maximum(out_img, 0.)
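+            # Wrap the clamped HWC image so downstream postprocessors and metrics can consume it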
+            result.append(ImageProcessingPrediction(identifier, out_img))
+
+        return result
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/README.md b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/README.md
index 14de24ead2..4b18ebbc98 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/README.md
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/README.md
@@ -854,6 +854,13 @@ The main difference between this converter and `super_resolution` in data organization.
 * `from_landmarks` - allow to calculate hand bounding box coordinates from landmarks data instead of data provided in `bbox_file` (optional, default `False`).
 * `padding` - additional padding, in pixels, around hand bounding box, calculated in `from_landmarks` mode (optional, default `10`).
 * `num_keypoints` - number of keypoints in annotation expected by model (optional, default `21`).
+* `tungsten` - converts the Tungsten dataset for the task of denoising images rendered with ray tracing to `ImageProcessingAnnotation`.
+  * `dataset_root_dir` - path to the dataset root.
+  * `features` - list of input features (optional, default `color`).
+  * `input_subfolder` - sub-directory for input features (optional, default `spp_4_data`).
+  * `target_subfolder` - sub-directory for targets (optional, default `spp_4096_data`).
+  * `extension` - image file extension (optional, default `png`).
+  * `annotation_loader` - which library will be used for ground truth image reading. Supported: `opencv`, `opencv_unchanged` (optional, default `opencv_unchanged`).
 
 ## Customizing Dataset Meta
 There are situations when we need to customize some default dataset parameters (e.g. replace original dataset label map with own.)
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/__init__.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/__init__.py
index 1fd9ab5817..2f37624025 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/__init__.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/__init__.py
@@ -127,6 +127,7 @@
 from .speaker_identification import SpeakerReIdentificationDatasetConverter
 from .mvtec import MVTecDatasetConverter
 from .gan_annotation_converter import GANAnnotationConverter
+from .tungsten import TungstenAnnotationConverter
 from .kitti_converter import KITTIConverter
 from .smartlab_action_recognition import SmartLabActionRecognition
 from .malware_classification import MalwareClassificationDatasetConverter
@@ -258,6 +259,7 @@
     'SpeakerReIdentificationDatasetConverter',
     'MVTecDatasetConverter',
     'GANAnnotationConverter',
+    'TungstenAnnotationConverter',
     'KITTIConverter',
     'SmartLabActionRecognition',
     'MalwareClassificationDatasetConverter',
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/image_processing.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/image_processing.py
index 9b03b3263d..44fad620a7 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/image_processing.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/image_processing.py
@@ -26,7 +26,8 @@
     'pillow': GTLoader.PILLOW,
     'dicom': GTLoader.DICOM,
     'skimage': GTLoader.SKIMAGE,
-    'pillow_rgb': GTLoader.PILLOW_RGB
+    'pillow_rgb': GTLoader.PILLOW_RGB,
+    'opencv_unchanged': GTLoader.OPENCV_UNCHANGED
 }
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/tungsten.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/tungsten.py
new file mode 100644
index 0000000000..ff1f20b3d2
--- /dev/null
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/tungsten.py
@@ -0,0 +1,85 @@
+"""
+Copyright (C) 2023 KNS Group LLC (YADRO)
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +from pathlib import Path +from ..representation import ImageProcessingAnnotation +from ..config import PathField, ListField, StringField, ConfigError +from .format_converter import BaseFormatConverter, ConverterReturn +from ..utils import check_file_existence +from .image_processing import LOADERS_MAPPING + + +class TungstenAnnotationConverter(BaseFormatConverter): + __provider__ = 'tungsten' + annotation_types = (ImageProcessingAnnotation, ) + + @classmethod + def parameters(cls): + configuration_parameters = super().parameters() + configuration_parameters.update({ + 'dataset_root_dir': PathField(is_directory=True, description="path to dataset root"), + 'extension': StringField(default='png', optional=True, + description="images extension"), + 'features': ListField(value_type=str, default='color', optional=True, + description='List of features'), + 'input_subfolder': ListField(value_type=str, default='spp_4_data', optional=True, + description='sub-directory for input features'), + 'target_subfolder': StringField( + optional=True, + default='spp_4096_data', + description="sub-directory for targets." + ), + 'annotation_loader': StringField( + optional=True, choices=LOADERS_MAPPING.keys(), default='opencv_unchanged', + description="Which library will be used for ground truth image reading. " + "Supported: {}".format(', '.join(LOADERS_MAPPING.keys()))) + }) + + return configuration_parameters + + def configure(self): + self.dataset_root = self.get_value_from_config('dataset_root_dir') + self.features = self.get_value_from_config('features') + self.extension = self.get_value_from_config('extension') + self.input_subfolder = self.get_value_from_config('input_subfolder') + self.target_subfolder = self.get_value_from_config('target_subfolder') + self.annotation_loader = LOADERS_MAPPING.get(self.get_value_from_config('annotation_loader')) + if not self.annotation_loader: + raise ConfigError('provided not existing loader') + + def convert(self, check_content=False, **kwargs): + content_errors = None if not check_content else [] + annotations = [] + + for scene in self.dataset_root.iterdir(): + scene_path = Path(scene) + for folder in self.input_subfolder: + path_data = scene_path / folder + path_target = scene_path / self.target_subfolder + num_images = len(list(Path(path_data).rglob(r'*_color.{}'.format(self.extension)))) + for idx in range(num_images): + color = path_data / f'{idx}_color.{self.extension}' + albedo = path_data / f'{idx}_albedo.{self.extension}' + target = path_target / f'{idx}_color.{self.extension}' + if check_content: + if not check_file_existence(color): + content_errors.append(f'{color}: does not exist') + if not check_file_existence(albedo): + content_errors.append(f'{albedo}: does not exist') + annotations.append(ImageProcessingAnnotation([str(color), str(albedo)], str(target), + gt_loader=self.annotation_loader)) + + return ConverterReturn(annotations, self.get_meta(), content_errors) diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/main.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/main.py index 4364edf50a..a5518fc356 100644 --- a/tools/accuracy_checker/openvino/tools/accuracy_checker/main.py +++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/main.py @@ -17,6 +17,9 @@ import json import sys from datetime import datetime +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" import cv2 from .argparser import build_arguments_parser diff --git 
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/README.md b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/README.md
index c63c5bbcde..3f791b408b 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/README.md
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/README.md
@@ -139,3 +139,7 @@ Accuracy Checker supports following set of postprocessors:
 * `pooling type` - pooling type for embeddings - `mean` for mean pooling, `max` for max pooling (Optional, default `mean`).
 * `remove_padding` - remove end of string padding from word embeddings (Optional, default `True`).
 * `hand_landmarks` - converts hand landmark coordinates to source image coordinate space. Supported representations: `HandLandmarksPrediction`.
+* `autoexposure` - divides an image by its exposure value (calculated automatically). Supported representations: `ImageProcessingAnnotation`, `ImageProcessingPrediction`.
+  * `key` - exposure key value that the average scene luminance is mapped to (optional).
+  * `k` - downsampling amount used for luminance estimation (optional).
+* `pu_inverse_transfer_function` - applies the inverse perceptually uniform (PU) encoding.
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/__init__.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/__init__.py
index 7fb0f67951..578e47162d 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/__init__.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/__init__.py
@@ -56,6 +56,8 @@
 from .remove_repeats import RemoveRepeatTokens
 from .tokens_to_lower_case import TokensToLowerCase
 from .super_resolution_image_recovery import SRImageRecovery, ColorizationLABRecovery
+from .autoexposure import AutoExposureImage
+from .transfer_function import PuInverseTransferFunction
 from .argmax_segmentation_mask import ArgMaxSegmentationMask
 from .normalize_salient_map import SalientMapNormalizer
 from .min_max_normalization import MinMaxRegressionNormalization
@@ -136,6 +138,8 @@
     'SRImageRecovery',
     'ColorizationLABRecovery',
 
+    'AutoExposureImage',
+    'PuInverseTransferFunction',
 
     'SalientMapNormalizer',
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/autoexposure.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/autoexposure.py
new file mode 100644
index 0000000000..a213e79b98
--- /dev/null
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/autoexposure.py
@@ -0,0 +1,48 @@
+"""
+Copyright (C) 2023 KNS Group LLC (YADRO)
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +from ..config import NumberField +from .postprocessor import Postprocessor +from ..representation import ImageProcessingPrediction, ImageProcessingAnnotation +from ..preprocessor import AutoExposure + + +class AutoExposureImage(Postprocessor): + __provider__ = 'autoexposure' + + prediction_types = (ImageProcessingAnnotation, ) + annotation_types = (ImageProcessingPrediction, ) + + @classmethod + def parameters(cls): + parameters = super().parameters() + parameters.update({ + 'key': NumberField(value_type=float, optional=True, description="Destination width."), + 'k': NumberField(value_type=int, optional=True, min_value=1, description="Downsampling amount"), + }) + return parameters + + def configure(self): + self.params = {'key': self.get_value_from_config('key'), + 'k': self.get_value_from_config('k')} + + def process_image(self, annotation, prediction): + for prediction_, annotation_ in zip(prediction, annotation): + exposure = annotation_.metadata.get('exposure', None) + prediction_.value = prediction_.value / exposure if exposure \ + else prediction_.value / AutoExposure.autoexposure(prediction_.value, self.params) + + return annotation, prediction diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/transfer_function.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/transfer_function.py new file mode 100644 index 0000000000..75095ee78f --- /dev/null +++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/postprocessor/transfer_function.py @@ -0,0 +1,51 @@ +""" +Copyright (C) 2023 KNS Group LLC (YADRO) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import numpy as np + +from .postprocessor import Postprocessor +from ..representation import ImageProcessingPrediction, ImageProcessingAnnotation +from ..preprocessor import PuTransferFunction + + +#pylint: disable=W0223 +class PuInverseTransferFunction(Postprocessor): + """ + Fit of PU2 curve normalized at 100 cd/m^2 + [Aydin et al., 2008, "Extending Quality Metrics to Full Luminance Range Images"] + """ + __provider__ = 'pu_inverse_transfer_function' + + prediction_types = (ImageProcessingAnnotation, ) + annotation_types = (ImageProcessingPrediction, ) + + def configure(self): + self.hdr_y_max = 65504. # maximum HDR value + + def process_image_with_metadata(self, annotation, prediction, image_metadata=None): + for prediction_, _ in zip(prediction, annotation): + params = image_metadata.get('params', None) + pu_norm_scale = 1. / PuTransferFunction.pu_forward(self.hdr_y_max, params) + prediction_.value = self.pu_inverse(prediction_.value / pu_norm_scale, params) + + return annotation, prediction + + @staticmethod + def pu_inverse(data, params): + return np.where(data <= params["pu_x0"], + data / params["pu_a"], + np.where(data <= params["pu_x1"], + np.power((data - params["pu_d"]) / params["pu_b"], 1. 
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/README.md b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/README.md
index b9f314d281..06940114fc 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/README.md
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/README.md
@@ -248,7 +248,10 @@ Accuracy Checker supports following set of preprocessors:
   * `quality_factor` - quality of compression, from 0 to 100 (the higher is the better).
 * `transpose` - transpose data using specified axes order.
   * `axes` - list of dimensions in transposing order.
-
+* `autoexposure` - multiplies an image by its exposure value (calculated automatically).
+  * `key` - exposure key value that the average scene luminance is mapped to (optional).
+  * `k` - downsampling amount used for luminance estimation (optional).
+* `pu_transfer_function` - applies the perceptually uniform (PU) encoding.
 ## Optimized preprocessing via OpenVINO Inference Engine
 
 OpenVINO™ is able perform preprocessing during model execution. For enabling this behaviour you can use command line parameter `--ie_preprocessing True`.
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/__init__.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/__init__.py
index 40ab145ab3..f3b163903d 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/__init__.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/__init__.py
@@ -77,6 +77,8 @@
 from .raw_image_preprocessing import PackBayerImage
 from .trimap import TrimapPreprocessor, AlphaChannel
 from .compression import JPEGCompression
+from .autoexposure import AutoExposure
+from .transfer_function import PuTransferFunction
 
 __all__ = [
     'PreprocessingExecutor',
@@ -164,5 +166,8 @@
     'TrimapPreprocessor',
     'AlphaChannel',
-    'JPEGCompression'
+    'JPEGCompression',
+
+    'AutoExposure',
+    'PuTransferFunction'
 ]
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/autoexposure.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/autoexposure.py
new file mode 100644
index 0000000000..6111340b8f
--- /dev/null
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/autoexposure.py
@@ -0,0 +1,83 @@
+"""
+Copyright (C) 2023 KNS Group LLC (YADRO)
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" +import numpy as np + +from ..config import NumberField +from .preprocessor import Preprocessor + + +class AutoExposure(Preprocessor): + __provider__ = 'autoexposure' + + @classmethod + def parameters(cls): + parameters = super().parameters() + parameters.update({ + 'key': NumberField(value_type=float, optional=True, description="Destination width"), + 'k': NumberField(value_type=int, optional=True, min_value=1, description="Downsampling amount"), + }) + return parameters + + def configure(self): + self.params = {'key': self.get_value_from_config('key'), + 'k': self.get_value_from_config('k')} + + def process(self, image, annotation_meta=None): + exposure = self.autoexposure(image.data[0], self.params) + image.data[0] = image.data[0] * exposure + annotation_meta['exposure'] = exposure + return image + + @staticmethod + def autoexposure(image, params): + def luminance(r, g, b): + return 0.212671 * r + 0.715160 * g + 0.072169 * b + + eps = 1e-8 + key, k = params['key'], params['k'] + + # Compute the luminance of each pixel + r = image[..., 0] + g = image[..., 1] + b = image[..., 2] + lum = luminance(r, g, b) + + # Downsample the image to minimize sensitivity to noise + h = lum.shape[0] # original height + w = lum.shape[1] # original width + hk = (h + k // 2) // k # down sampled height + wk = (w + k // 2) // k # down sampled width + + lk = np.zeros((hk, wk), dtype=lum.dtype) + for i in range(hk): + for j in range(wk): + begin_h = i * h // hk + begin_w = j * w // wk + end_h = (i + 1) * h // hk + end_w = (j + 1) * w // wk + + lk[i, j] = lum[begin_h:end_h, begin_w:end_w].mean() + + lum = lk + + # Keep only values greater than epsilon + lum = lum[lum > eps] + if lum.size == 0: + return 1. + + # Compute the exposure value + return float(key / np.exp2(np.log2(lum).mean())) diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/transfer_function.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/transfer_function.py new file mode 100644 index 0000000000..f2eda32b27 --- /dev/null +++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/preprocessor/transfer_function.py @@ -0,0 +1,55 @@ +""" +Copyright (C) 2023 KNS Group LLC (YADRO) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import numpy as np + +from .preprocessor import Preprocessor + + +class PuTransferFunction(Preprocessor): + """ + Fit of PU2 curve normalized at 100 cd/m^2 + [Aydin et al., 2008, "Extending Quality Metrics to Full Luminance Range Images"] + """ + __provider__ = 'pu_transfer_function' + + def configure(self): + self.params = {"pu_y0": 1.57945760e-06, + "pu_y1": 3.22087631e-02, + "pu_x0": 2.23151711e-03, + "pu_x1": 3.70974749e-01, + "pu_a": 1.41283765e+03, + "pu_b": 1.64593172e+00, + "pu_c": 4.31384981e-01, + "pu_d": -2.94139609e-03, + "pu_e": 1.92653254e-01, + "pu_f": 6.26026094e-03, + "pu_g": 9.98620152e-01} + + self.hdr_y_max = 65504. # maximum HDR value + self.pu_norm_scale = 1. 
+
+    def process(self, image, annotation_meta=None):
+        image.data[0] = self.pu_norm_scale * self.pu_forward(image.data[0], self.params)
+        image.metadata['params'] = self.params
+        return image
+
+    @staticmethod
+    def pu_forward(data, params):
+        return np.where(data <= params["pu_y0"],
+                        params["pu_a"] * data,
+                        np.where(data <= params["pu_y1"],
+                                 params["pu_b"] * np.power(data, params["pu_c"]) + params["pu_d"],
+                                 params["pu_e"] * np.log(data + params["pu_f"]) + params["pu_g"]))
diff --git a/tools/accuracy_checker/openvino/tools/accuracy_checker/representation/image_processing.py b/tools/accuracy_checker/openvino/tools/accuracy_checker/representation/image_processing.py
index 9cb8a42d96..860dc47686 100644
--- a/tools/accuracy_checker/openvino/tools/accuracy_checker/representation/image_processing.py
+++ b/tools/accuracy_checker/openvino/tools/accuracy_checker/representation/image_processing.py
@@ -29,6 +29,7 @@ class GTLoader(Enum):
     SKIMAGE = 4
     PILLOW_RGB = 5
     NUMPY = 6
+    OPENCV_UNCHANGED = 7
 
 
 class ImageProcessingRepresentation(BaseRepresentation):
@@ -43,7 +44,8 @@ class ImageProcessingAnnotation(ImageProcessingRepresentation):
         GTLoader.RAWPY: 'rawpy',
         GTLoader.SKIMAGE: 'skimage_imread',
         GTLoader.PILLOW_RGB: 'pillow_imread',
-        GTLoader.NUMPY: 'numpy_reader'
+        GTLoader.NUMPY: 'numpy_reader',
+        GTLoader.OPENCV_UNCHANGED: {'type': 'opencv_imread', 'reading_flag': 'unchanged'},
     }
 
     def __init__(self, identifier, path_to_gt, gt_loader=GTLoader.PILLOW):
@@ -66,10 +68,15 @@ def value(self):
             data_source = self.metadata.get('additional_data_source')
             if not data_source:
                 data_source = self.metadata['data_source']
-            loader = BaseReader.provide(self._gt_loader, data_source)
+            if isinstance(self._gt_loader, str):
+                loader = BaseReader.provide(self._gt_loader, data_source)
+            else:
+                loader = BaseReader.provide(self._gt_loader['type'], data_source, config=self._gt_loader)
             if self._gt_loader == self.LOADERS[GTLoader.PILLOW]:
                 loader.convert_to_rgb = self._pillow_to_rgb if hasattr(self, '_pillow_to_rgb') else False
             gt = loader.read(self._image_path)
+            if isinstance(self._gt_loader, dict):
+                return gt
             return gt.astype(np.uint8) if self._gt_loader not in ['dicom_reader', 'rawpy', 'numpy_reader'] else gt
 
         return self._value
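Note, outside the patch: the correctness of the whole pre/postprocessing chain hinges on the PU transfer function and its inverse being exact inverses of each other. Below is a minimal, self-contained sketch — constants and piecewise definitions copied from `demos/denoise_ray_tracing_demo/python/utils/color.py` above — that checks this round-trip:

```python
import numpy as np

# PU2 curve constants (fit normalized at 100 cd/m^2), copied from utils/color.py
PU_A = 1.41283765e+03
PU_B = 1.64593172e+00
PU_C = 4.31384981e-01
PU_D = -2.94139609e-03
PU_E = 1.92653254e-01
PU_F = 6.26026094e-03
PU_G = 9.98620152e-01
PU_Y0 = 1.57945760e-06
PU_Y1 = 3.22087631e-02
PU_X0 = 2.23151711e-03
PU_X1 = 3.70974749e-01
HDR_Y_MAX = 65504.  # maximum HDR (half-float) value


def pu_forward(y):
    # linear segment -> power segment -> logarithmic segment
    return np.where(y <= PU_Y0,
                    PU_A * y,
                    np.where(y <= PU_Y1,
                             PU_B * np.power(y, PU_C) + PU_D,
                             PU_E * np.log(y + PU_F) + PU_G))


def pu_inverse(x):
    # exact piecewise inverse; PU_X0 = pu_forward(PU_Y0), PU_X1 = pu_forward(PU_Y1)
    return np.where(x <= PU_X0,
                    x / PU_A,
                    np.where(x <= PU_X1,
                             np.power((x - PU_D) / PU_B, 1. / PU_C),
                             np.exp((x - PU_G) / PU_E) - PU_F))


# Normalization used by both the demo and the accuracy-checker preprocessor:
# pu_forward(HDR_Y_MAX) maps to exactly 1.0 after scaling.
scale = 1. / pu_forward(HDR_Y_MAX)

y = np.logspace(-6, np.log10(HDR_Y_MAX), 1000)  # sweep the whole HDR range
x = pu_forward(y) * scale                       # encode (preprocessing)
assert np.allclose(pu_inverse(x / scale), y, rtol=1e-4)  # decode (postprocessing)
```

If this round-trip were lossy, the SSIM reported by the `ssim` metric would be biased by the color pipeline rather than reflecting the model itself.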