
[Performance] GPU Accelerated Image normalization for DirectML #20155

Closed
NevermindNilas opened this issue Mar 30, 2024 · 1 comment
Labels
ep:CUDA (issues related to the CUDA execution provider), ep:DML (issues related to the DirectML execution provider), platform:windows (issues related to the Windows platform), stale (issues that have not been addressed in a while; categorized by a bot)

Comments

@NevermindNilas

Describe the issue

As seen in the "To reproduce" section, performance is heavily bottlenecked by the requirement to normalize with NumPy.

For ONNX Runtime CUDA + PyTorch CUDA, I can easily move the normalization to the GPU using:

        # HWC uint8 -> 1x3xHxW float, normalized to [0, 1]
        frame = (
            torch.from_numpy(frame)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .float()
            .mul_(1 / 255)
        )

        frame = frame.contiguous()
        frame = frame.to(self.device).half()  # self.device being CUDA

        self.model.run_with_iobinding(self.IoBinding)

        # De-normalize the bound output and bring it back to HWC uint8 on the CPU
        frame = self.output.squeeze(0).permute(1, 2, 0).mul_(255).byte().cpu().numpy()
        return frame

And then just run inference with PyTorch tensors.
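
For reference, the CUDA-side I/O binding that the snippet above relies on is not shown; a minimal sketch of how it could look is below, assuming the same "input"/"output" tensor names and shapes as the repro code and an illustrative model path:

# Sketch only (assumptions, not from the issue): bind preallocated torch CUDA
# tensors to the session so inference reads/writes GPU memory with no host copies.
import numpy as np
import torch
import onnxruntime as ort

session = ort.InferenceSession("model-fp16.onnx", providers=["CUDAExecutionProvider"])
io_binding = session.io_binding()

# Preallocated GPU tensors; run() would write each normalized frame into frame_gpu
frame_gpu = torch.empty((1, 3, 1080, 1920), dtype=torch.float16, device="cuda").contiguous()
output_gpu = torch.empty((1, 3, 2160, 3840), dtype=torch.float16, device="cuda").contiguous()

io_binding.bind_input(
    name="input",
    device_type="cuda",
    device_id=0,
    element_type=np.float16,
    shape=tuple(frame_gpu.shape),
    buffer_ptr=frame_gpu.data_ptr(),
)
io_binding.bind_output(
    name="output",
    device_type="cuda",
    device_id=0,
    element_type=np.float16,
    shape=tuple(output_gpu.shape),
    buffer_ptr=output_gpu.data_ptr(),
)

session.run_with_iobinding(io_binding)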

Would there be a workaround that allows the normalization to be moved, in one form or another, to the GPU for faster inference?
For what it's worth, I have some performance benchmarks here:

"""
frame= 153 fps=3.0 q=-0.0 Lsize=N/A time=00:00:06.38 bitrate=N/A speed=0.126x
Compact no video encoding, 1080p, onnxruntime directml, fp16, with clamp 0-255
"""

To reproduce

import numpy as np
import onnxruntime as ort


class CompactDirectML:
    def __init__(self):
        model = (
            r"G:\TheAnimeScripter\2x_AnimeJaNai_HD_V3_Compact_583k-fp16.onnx"
        )

        print(f"Using model: {model}")

        providers = ort.get_available_providers()

        if "DmlExecutionProvider" in providers:
            self.model = ort.InferenceSession(
                model, providers=["DmlExecutionProvider"]
            )
        else:
            self.model = ort.InferenceSession(model, providers=["CPUExecutionProvider"])

        self.IoBinding = self.model.io_binding()
        self.frame = np.zeros((1, 3, 1080, 1920), dtype=np.float16)
        self.output = np.zeros((1, 3, 2160, 3840), dtype=np.float16)

        self.IoBinding.bind_output(
            name='output',
            device_type='cpu',
            device_id=0,
            element_type=np.float16,
            shape=self.output.shape,
            buffer_ptr=self.output.ctypes.data,
        )

    def run(self, frame: np.ndarray) -> np.ndarray:
        # CPU-side normalization: HWC uint8 -> 1x3x1080x1920 float16 in [0, 1]
        frame = frame.transpose((2, 0, 1))
        frame = frame.reshape(1, 3, 1080, 1920)
        frame = frame / 255
        frame = frame.astype(np.float16)

        np.copyto(self.frame, frame)

        self.IoBinding.bind_input(
            name='input',
            device_type='cpu',
            device_id=0,
            element_type=np.float16,
            shape=self.frame.shape,
            buffer_ptr=self.frame.ctypes.data,
        )

        self.model.run_with_iobinding(self.IoBinding)

        output = self.output.reshape(3, 2160, 3840)
        output = output.transpose((1, 2, 0))
        output *= 255
        output = np.clip(output, 0, 255)  # interesting side note: pure PyTorch doesn't need clipping while ONNX requires it, maybe a model issue?
        output = output.astype(np.uint8)

        return output
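
As a side note (a sketch, not part of the original issue): even while the data stays on the CPU, the normalization in run() could be made cheaper by avoiding the float64 intermediate that frame / 255 creates and by writing straight into the preallocated self.frame buffer. normalize_into below is a hypothetical helper illustrating that:

import numpy as np

def normalize_into(dst: np.ndarray, frame: np.ndarray) -> None:
    # Scale an HWC uint8 frame into dst (a 1x3xHxW float16 buffer) as [0, 1],
    # skipping the float64 intermediate produced by `frame / 255`.
    chw = frame.transpose((2, 0, 1))                        # view only, no copy
    np.multiply(chw, np.float16(1.0 / 255.0), out=dst[0])   # uint8 * fp16 -> fp16, in place

In run(), the transpose / reshape / divide / astype / copyto sequence would then collapse to a single normalize_into(self.frame, frame) call before binding the input.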

Urgency

No; used for benchmarking purposes only, to compare against NCNN inference performance.

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

DirectML

Execution Provider Library Version

No response

Model File

https://github.com/NevermindNilas/TAS-Modes-Host/releases/download/main/2x_AnimeJaNai_HD_V3_Compact_583k-fp16.onnx

Is this a quantized model?

Unknown

@github-actions bot added the ep:CUDA, ep:DML, and platform:windows labels on Mar 30, 2024
github-actions bot (Contributor)

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions bot added the stale label on Apr 30, 2024