-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from JoDio-zd/main
feat: add coreml
- Loading branch information
Showing
12 changed files
with
587 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
include(ExternalProject) | ||
|
||
if (ENABLE_NNDEPLOY_INFERENCE_COREML STREQUAL "OFF") | ||
else() | ||
set(NNDEPLOY_THIRD_PARTY_LIBRARY ${NNDEPLOY_THIRD_PARTY_LIBRARY} "/System/Library/Frameworks/CoreML.framework") | ||
set(NNDEPLOY_THIRD_PARTY_LIBRARY ${NNDEPLOY_THIRD_PARTY_LIBRARY} "/System/Library/Frameworks/CoreVideo.framework") | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
|
||
#ifndef _NNDEPLOY_INFERENCE_COREML_COREML_CONVERT_H_ | ||
#define _NNDEPLOY_INFERENCE_COREML_COREML_CONVERT_H_ | ||
|
||
#include "nndeploy/base/common.h" | ||
#include "nndeploy/base/file.h" | ||
#include "nndeploy/base/glic_stl_include.h" | ||
#include "nndeploy/base/log.h" | ||
#include "nndeploy/base/macro.h" | ||
#include "nndeploy/base/object.h" | ||
#include "nndeploy/base/status.h" | ||
#include "nndeploy/device/device.h" | ||
#include "nndeploy/device/tensor.h" | ||
#include "nndeploy/inference/coreml/coreml_include.h" | ||
#include "nndeploy/inference/coreml/coreml_inference_param.h" | ||
#include "nndeploy/inference/inference_param.h" | ||
|
||
namespace nndeploy { | ||
namespace inference { | ||
|
||
class CoremlConvert { | ||
public: | ||
// TODO: these two functions are for buffer type kind data | ||
static base::DataType convertToDataType(const OSType &src); | ||
static OSType convertFromDataType(const base::DataType &src); | ||
|
||
static base::DataFormat convertToDataFormat(const MLFeatureDescription &src); | ||
|
||
static MLFeatureDescription *convertFromDataFormat(const base::DataFormat &src); | ||
// You need to free it manually | ||
static NSObject *convertFromDeviceType(const base::DeviceType &src); | ||
|
||
static device::Tensor *convertToTensor(MLFeatureDescription *src, NSString *name, | ||
device::Device *device); | ||
static MLFeatureDescription *convertFromTensor(device::Tensor *src); | ||
|
||
static base::Status convertFromInferenceParam(CoremlInferenceParam *src, | ||
MLModelConfiguration *dst); | ||
}; | ||
|
||
} // namespace inference | ||
} // namespace nndeploy | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
#ifndef _NNDEPLOY_INFERENCE_COREML_COREML_INCLUDE_H_ | ||
#define _NNDEPLOY_INFERENCE_COREML_COREML_INCLUDE_H_ | ||
#import <CoreML/CoreML.h> | ||
#import <CoreServices/CoreServices.h> | ||
#import <Foundation/Foundation.h> | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
#ifndef _NNDEPLOY_INFERENCE_COREML_COREML_INFERENCE_H_ | ||
#define _NNDEPLOY_INFERENCE_COREML_COREML_INFERENCE_H_ | ||
|
||
#include "nndeploy/base/common.h" | ||
#include "nndeploy/base/log.h" | ||
#include "nndeploy/base/macro.h" | ||
#include "nndeploy/base/object.h" | ||
#include "nndeploy/base/shape.h" | ||
#include "nndeploy/base/status.h" | ||
#include "nndeploy/base/value.h" | ||
#include "nndeploy/device/device.h" | ||
#include "nndeploy/device/tensor.h" | ||
#include "nndeploy/inference/coreml/coreml_convert.h" | ||
#include "nndeploy/inference/coreml/coreml_include.h" | ||
#include "nndeploy/inference/coreml/coreml_inference_param.h" | ||
#include "nndeploy/inference/inference.h" | ||
#include "nndeploy/inference/inference_param.h" | ||
|
||
namespace nndeploy { | ||
namespace inference { | ||
|
||
#define CHECK_ERR(err) \ | ||
if (err) NSLog(@"error: %@", err); | ||
|
||
class CoremlInference : public Inference { | ||
public: | ||
CoremlInference(base::InferenceType type); | ||
virtual ~CoremlInference(); | ||
|
||
virtual base::Status init(); | ||
virtual base::Status deinit(); | ||
|
||
virtual base::Status reshape(base::ShapeMap &shape_map); | ||
|
||
virtual int64_t getMemorySize(); | ||
|
||
virtual float getGFLOPs(); | ||
|
||
virtual device::TensorDesc getInputTensorAlignDesc(const std::string &name); | ||
virtual device::TensorDesc getOutputTensorAlignDesc(const std::string &name); | ||
|
||
virtual base::Status run(); | ||
|
||
private: | ||
base::Status allocateInputOutputTensor(); | ||
base::Status deallocateInputOutputTensor(); | ||
MLModel *mlmodel_ = nullptr; | ||
NSError *err_ = nil; | ||
MLModelConfiguration *config_ = nullptr; | ||
NSMutableDictionary *dict_ = nullptr; | ||
NSMutableDictionary *result_ = nullptr; | ||
}; | ||
|
||
} // namespace inference | ||
} // namespace nndeploy | ||
|
||
#endif |
42 changes: 42 additions & 0 deletions
42
include/nndeploy/inference/coreml/coreml_inference_param.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
|
||
#ifndef _NNDEPLOY_INFERENCE_COREML_COREML_INFERENCE_PARAM_H_ | ||
#define _NNDEPLOY_INFERENCE_COREML_COREML_INFERENCE_PARAM_H_ | ||
|
||
#include "nndeploy/device/device.h" | ||
#include "nndeploy/inference/coreml/coreml_include.h" | ||
#include "nndeploy/inference/inference_param.h" | ||
|
||
namespace nndeploy { | ||
namespace inference { | ||
|
||
class CoremlInferenceParam : public InferenceParam { | ||
public: | ||
CoremlInferenceParam(); | ||
virtual ~CoremlInferenceParam(); | ||
|
||
CoremlInferenceParam(const CoremlInferenceParam ¶m) = default; | ||
CoremlInferenceParam &operator=(const CoremlInferenceParam ¶m) = default; | ||
|
||
PARAM_COPY(CoremlInferenceParam) | ||
PARAM_COPY_TO(CoremlInferenceParam) | ||
|
||
virtual base::Status parse(const std::string &json, bool is_path = true); | ||
virtual base::Status set(const std::string &key, base::Value &value); | ||
virtual base::Status get(const std::string &key, base::Value &value); | ||
|
||
/// @brief A Boolean value that determines whether to allow low-precision | ||
/// accumulation on a GPU. | ||
bool low_precision_acceleration_ = false; | ||
enum inferenceUnits { | ||
ALL_UNITS = 0, | ||
CPU_ONLY = 1, | ||
CPU_AND_GPU = 2, | ||
CPU_AND_NPU | ||
}; | ||
inferenceUnits inference_units_ = CPU_ONLY; | ||
}; | ||
|
||
} // namespace inference | ||
} // namespace nndeploy | ||
|
||
#endif |
Oops, something went wrong.