75 changes: 75 additions & 0 deletions validated/vision/classification/resnet/README.md
@@ -196,6 +196,78 @@ bash run_tuning.sh --input_model=path/to/model \ # model path as *.onnx
### Model inference
We use onnxruntime to perform Resnet50_fp32 and Resnet50_int8 inference. View the notebook [onnxrt_inference](../onnxrt_inference.ipynb) to understand how to use these 2 models for doing inference as well as which preprocess and postprocess we use.

## Model inference with AMD Ryzen AI for offloading AI workload to AMD NPU
This example shows how to compile and run [resnet50-v2-7.onnx](https://github.com/lihaofd/models/blob/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx) on an AMD Ryzen AI NPU. It was validated on an AMD Ryzen™ AI 5 340 with 50 NPU TOPS.

### Install the Ryzen AI MSI and the related NPU driver from https://ryzenai.docs.amd.com/en/latest/inst.html, then activate the conda environment

```shell
conda activate ryzen-ai-1.x.0
```
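Before compiling, it may help to confirm that the activated environment exposes the Vitis AI execution provider. The following is a minimal sketch and not part of the original example: `has_vitis_ai` is a hypothetical helper, and it assumes the onnxruntime build installed by Ryzen AI reports the provider under the name `VitisAIExecutionProvider` (the name used in compile.py).

```python
def has_vitis_ai(providers):
    """Return True if the Vitis AI execution provider is in the given list."""
    return "VitisAIExecutionProvider" in providers


if __name__ == "__main__":
    try:
        import onnxruntime as ort
    except ImportError:
        print("onnxruntime is not installed in this environment")
    else:
        # get_available_providers() lists the execution providers
        # supported by this onnxruntime build.
        print("Vitis AI EP available:", has_vitis_ai(ort.get_available_providers()))
```

If the check prints `False`, the conda environment is probably not the one created by the Ryzen AI installer.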

### Compile NPU Cache

```shell
cd RyzenAI
python compile.py resnet50-v2-7.onnx

WARNING: Logging before InitGoogleLogging() is written to STDERR
I20251024 16:56:58.719477 34428 register_ssmlp.cpp:124] Registering Custom Operator: com.amd:SSMLP
I20251024 16:56:58.720481 34428 register_matmulnbits.cpp:110] Registering Custom Operator: com.amd:MatMulNBits
I20251024 16:56:58.835477 34428 vitisai_compile_model.cpp:1266] Vitis AI EP Load ONNX Model Success
I20251024 16:56:58.835477 34428 vitisai_compile_model.cpp:1267] Graph Input Node Name/Shape (1)
I20251024 16:56:58.835477 34428 vitisai_compile_model.cpp:1271] data : [-1x3x224x224]
I20251024 16:56:58.835477 34428 vitisai_compile_model.cpp:1277] Graph Output Node Name/Shape (1)
I20251024 16:56:58.835477 34428 vitisai_compile_model.cpp:1281] resnetv24_dense0_fwd : [-1x1000]
Adding RYZEN_AI_INSTALLATION_PATH=C:\Program Files\RyzenAI\... to installation search path

...

Compilation Complete
(WARNING:95, CRITICAL-WARNING:0, ERROR:0)
[Vitis AI EP] No. of Operators : CPU 1 VAIML 140
[Vitis AI EP] No. of Subgraphs : NPU 1 Actually running on NPU 1


NPU cache model is saved as resnet50-v2-7_ctx.onnx

```
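The `No. of Operators` line in the log above summarizes how many operators were assigned to each device. As a rough illustration, the sketch below parses that line and computes the share assigned to `VAIML`. The helper names `parse_operator_placement` and `offload_fraction` are hypothetical, and the parser assumes the line keeps the alternating "name count" layout shown in the log.

```python
import re


def parse_operator_placement(log_text):
    """Extract per-device operator counts from the 'No. of Operators' line."""
    match = re.search(r"No\. of Operators\s*:\s*(.+)", log_text)
    if match is None:
        return {}
    tokens = match.group(1).split()
    # Tokens alternate between a device name and its operator count.
    return {name: int(count) for name, count in zip(tokens[::2], tokens[1::2])}


def offload_fraction(counts, device="VAIML"):
    """Fraction of all operators that were assigned to the given device."""
    total = sum(counts.values())
    return counts.get(device, 0) / total if total else 0.0


log = "[Vitis AI EP] No. of Operators : CPU 1 VAIML 140"
counts = parse_operator_placement(log)
print(counts)
print(f"{offload_fraction(counts):.1%} of operators assigned to VAIML")
```

For the log shown here, 140 of 141 operators are assigned to `VAIML` and one stays on the CPU.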

### Run the NPU cache model directly (set num_runs to 10 in run.py; NPU utilization can be tracked in Task Manager)

```shell
cd RyzenAI
python run.py resnet50-v2-7_ctx.onnx

WARNING: Logging before InitGoogleLogging() is written to STDERR
I20251024 17:34:19.560344 6760 register_ssmlp.cpp:124] Registering Custom Operator: com.amd:SSMLP
I20251024 17:34:19.561345 6760 register_matmulnbits.cpp:110] Registering Custom Operator: com.amd:MatMulNBits
I20251024 17:34:19.574344 6760 vitisai_compile_model.cpp:1266] Vitis AI EP Load ONNX Model Success
I20251024 17:34:19.574344 6760 vitisai_compile_model.cpp:1267] Graph Input Node Name/Shape (1)
I20251024 17:34:19.574344 6760 vitisai_compile_model.cpp:1271] data : [-1x3x224x224]
I20251024 17:34:19.574344 6760 vitisai_compile_model.cpp:1277] Graph Output Node Name/Shape (1)
I20251024 17:34:19.574344 6760 vitisai_compile_model.cpp:1281] resnetv24_dense0_fwd : [-1x1000]
[Vitis AI EP] No. of Subgraphs supported by Vitis AI EP: VAIML 1

Top 3 Probabilities
[208 209 207]
------------------------------------|------------
Classification |Percentage
------------------------------------|------------
Labrador retriever | 54.27
------------------------------------|------------
Chesapeake Bay retriever | 6.09
------------------------------------|------------
golden retriever | 4.74
------------------------------------|------------
INFO: Test passed

```
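run.py is not included in this excerpt, so the following is only a sketch of how a `num_runs` loop might time repeated inference calls. `time_runs` and the stand-in workload are hypothetical; in practice the callable would wrap the inference call on the loaded session.

```python
import time


def time_runs(fn, num_runs=10):
    """Call fn num_runs times and return the per-call latencies in milliseconds."""
    latencies_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return latencies_ms


# Stand-in workload; replace with a function that runs one inference.
latencies = time_runs(lambda: sum(range(10_000)), num_runs=10)
mean_ms = sum(latencies) / len(latencies)
print(f"runs: {len(latencies)}, mean latency: {mean_ms:.3f} ms")
```

Repeating the call several times keeps the NPU busy long enough for its utilization to become visible in Task Manager.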




## References
* **ResNetv1**
[Deep residual learning for image recognition](https://arxiv.org/abs/1512.03385)
@@ -210,6 +282,9 @@ In European Conference on Computer Vision, pp. 630-645. Springer, Cham, 2016.

* [Intel® Neural Compressor](https://github.com/intel/neural-compressor)

* [AMD Ryzen AI](https://ryzenai.docs.amd.com/en/latest/inst.html)


## Contributors
* [ankkhedia](https://github.com/ankkhedia) (Amazon AI)
* [abhinavs95](https://github.com/abhinavs95) (Amazon AI)
25 changes: 25 additions & 0 deletions validated/vision/classification/resnet/RyzenAI/compile.py
@@ -0,0 +1,25 @@
import os
import sys
import onnxruntime as ort

onnx_model_path = sys.argv[1]
model_name = os.path.splitext(os.path.basename(onnx_model_path))[0]
ctx_cache = model_name + "_ctx.onnx"

# Delete any pre-existing EP context cache model
if os.path.exists(ctx_cache):
    print(f"INFO: EP context model {ctx_cache} already exists. Deleting it.")
    os.remove(ctx_cache)

session_options = ort.SessionOptions()
session_options.add_session_config_entry('ep.context_enable', '1')
session_options.add_session_config_entry('ep.context_file_path', ctx_cache)
session_options.add_session_config_entry('ep.context_embed_mode', '1')
onnx_session = ort.InferenceSession(
    onnx_model_path,
    sess_options=session_options,
    providers=["VitisAIExecutionProvider"],
    provider_options=[{"cache_dir": os.getcwd(), "cache_key": model_name}],
)
79 changes: 79 additions & 0 deletions validated/vision/classification/resnet/RyzenAI/image_utils.py
@@ -0,0 +1,79 @@
from typing import Any, List, Tuple

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image # type: ignore [import-untyped]
from torchvision import transforms


def load_and_preprocess_image(image_file: str) -> torch.Tensor:
    """
    Load and preprocess image_file for inference test
    It works for all imagenet images
    """

    img = Image.open(image_file).convert("RGB")
    # preprocessing pipeline
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_preprocessed = preprocess(img)
    # Add a leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    return torch.unsqueeze(img_preprocessed, 0)


def softmax(vector: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """
    Calculate softmax of a vector
    """
    # Subtract the maximum before exponentiating to avoid overflow for large logits
    e = np.exp(vector - np.max(vector))
    res: npt.NDArray[np.float32] = e / e.sum()
    return res


def top_n_probabilities(
    res: npt.NDArray[np.float32],
    labels: List[str],
    top_n: int = 3,
    run_softmax: bool = False,
) -> List[Tuple[Any, Any]]:
    """
    Compute probabilities of the top_n classifications from res
    Inputs:
    res: output as 1-D numpy array from fully connected layer or softmax
    labels: class names indexed by class id
    top_n: number of top classifications to return
    run_softmax: whether or not to run softmax on res
    """
    # Class indices sorted by descending score
    indices = np.flip(np.argsort(res))
    if run_softmax:
        percentage = softmax(res) * 100
    else:
        percentage = res * 100

    print(indices[:top_n])
    top_n_result = [(labels[idx], percentage[idx].item()) for idx in indices[:top_n]]

    return top_n_result


def top3_probabilities(
    data_in: npt.NDArray[np.float32], labels: List[str], run_softmax: bool = False
) -> List[Tuple[Any, Any]]:
    """
    Helper function to get top 3 probabilities for backward compatibility
    """
    top3 = top_n_probabilities(data_in, labels, top_n=3, run_softmax=run_softmax)

    return top3


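def load_labels(label_file: str) -> List[str]:
    """
    Read one class label per line from label_file
    """
    # The context manager closes the file even if reading raises
    with open(label_file) as classes_fh:
        labels = [line.strip() for line in classes_fh]
    return labels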
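The post-processing in image_utils.py boils down to a softmax followed by a top-n selection. The sketch below restates both steps in pure Python so they can be checked without NumPy or an NPU. It is an illustration rather than the module's code: `stable_softmax`, `top_n`, the logits, and the `class_*` labels are hypothetical, and the softmax subtracts the maximum logit, a standard guard against overflow.

```python
import heapq
import math


def stable_softmax(logits):
    """Softmax that subtracts the maximum logit first to avoid overflow in exp()."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_n(probabilities, labels, n=3):
    """Return (label, percentage) pairs for the n highest probabilities."""
    best = heapq.nlargest(n, range(len(probabilities)), key=probabilities.__getitem__)
    return [(labels[i], probabilities[i] * 100) for i in best]


# math.exp(1000.0) on its own would raise OverflowError; the shifted form does not.
logits = [1000.0, 999.0, 998.0, 0.0]
labels = ["class_a", "class_b", "class_c", "class_d"]
probs = stable_softmax(logits)
for label, pct in top_n(probs, labels, n=2):
    print(f"{label:10s} {pct:6.2f}")
```

Passing `n` through to the selection step is what lets a caller ask for more or fewer than three results.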