Skip to content

Commit

Permalink
Create zDNN Status Message Environment variable (#2845)
Browse files Browse the repository at this point in the history
* Testing

* Fix syntax

* Switch compiler flag to env variable

* Fix format of tests

* My fault

* Fix env variable naming convention

* Small changes

* Fix check

* Testing things out

* Fix format

* Still trying here

* Fixing env var

* fix format

* Switching things around

* Update cmakeList file

* Cleanup

* Remove duplicate defines

* My fault

---------

Co-authored-by: Megan Hampton <hamptonm@us.ibm.com>
  • Loading branch information
hamptonm1 and MegoHam21 committed Jun 20, 2024
1 parent 693ba93 commit 6923cfd
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 18 deletions.
54 changes: 43 additions & 11 deletions src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,54 +222,86 @@ static zdnn_status zdnn_binary_elementwise_common(const zdnn_ztensor *inputA,

zdnn_status zdnn_add_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_ADD_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_ADD_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_add");
return status;
}

zdnn_status zdnn_sub_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_SUB_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_SUB_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_sub");
return status;
}

zdnn_status zdnn_mul_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MUL_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MUL_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_mul");
return status;
}

zdnn_status zdnn_div_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_DIV_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_DIV_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_div");
return status;
}

zdnn_status zdnn_min_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MIN_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MIN_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_min");
return status;
}

zdnn_status zdnn_max_ext(const zdnn_ztensor *inputA, const zdnn_ztensor *inputB,
zdnn_ztensor *output) {
return zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MAX_EXT);
zdnn_status status =
zdnn_binary_elementwise_common(inputA, inputB, output, ZDNN_MAX_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_max");
return status;
}

zdnn_status zdnn_exp_ext(const zdnn_ztensor *input, zdnn_ztensor *output) {
return zdnn_unary_elementwise_common(input, NULL, output, ZDNN_EXP_EXT);
zdnn_status status =
zdnn_unary_elementwise_common(input, NULL, output, ZDNN_EXP_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_exp");
return status;
}

zdnn_status zdnn_log_ext(const zdnn_ztensor *input, zdnn_ztensor *output) {
return zdnn_unary_elementwise_common(input, NULL, output, ZDNN_LOG_EXT);
zdnn_status status =
zdnn_unary_elementwise_common(input, NULL, output, ZDNN_LOG_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_log");
return status;
}

zdnn_status zdnn_relu_ext(const zdnn_ztensor *input, const void *clippingValue,
zdnn_ztensor *output) {
return zdnn_unary_elementwise_common(
zdnn_status status = zdnn_unary_elementwise_common(
input, clippingValue, output, ZDNN_RELU_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_relu");
return status;
}

zdnn_status zdnn_sigmoid_ext(const zdnn_ztensor *input, zdnn_ztensor *output) {
return zdnn_unary_elementwise_common(input, NULL, output, ZDNN_SIGMOID_EXT);
zdnn_status status =
zdnn_unary_elementwise_common(input, NULL, output, ZDNN_SIGMOID_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_sigmoid");
return status;
}

zdnn_status zdnn_tanh_ext(const zdnn_ztensor *input, zdnn_ztensor *output) {
return zdnn_unary_elementwise_common(input, NULL, output, ZDNN_TANH_EXT);
zdnn_status status =
zdnn_unary_elementwise_common(input, NULL, output, ZDNN_TANH_EXT);
CHECK_ZDNN_STATUS(status, "zdnn_tanh");
return status;
}

#ifdef __cplusplus
Expand Down
18 changes: 12 additions & 6 deletions src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ extern "C" {
static inline zdnn_status call_zdnn_matmul_op(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output, bool isBcast) {
zdnn_status status;
if (isBcast)
return zdnn_matmul_bcast_op(
status = zdnn_matmul_bcast_op(
inputA, inputB, inputC, (zdnn_matmul_bcast_ops)opType, output);
return zdnn_matmul_op(
inputA, inputB, inputC, (zdnn_matmul_ops)opType, output);
else
status =
zdnn_matmul_op(inputA, inputB, inputC, (zdnn_matmul_ops)opType, output);
CHECK_ZDNN_STATUS(status, "zdnn_matmul");
return status;
}

#ifndef __MVS__
Expand Down Expand Up @@ -104,7 +108,7 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA,
zdnn_ztensor *zyb = getTile(&siYB, j);
zdnn_status status =
call_zdnn_matmul_op(za, zb, zc, opType, zyb, isBcast);
assert(status == ZDNN_OK);
CHECK_ZDNN_STATUS(status, "zdnn_matmul");
if (OMZTensorSplitDebug) {
int cpuId = 0;
#ifdef __MVS__
Expand Down Expand Up @@ -152,8 +156,10 @@ static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA,
zdnn_status zdnn_matmul_op_ext(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output) {
return zdnn_matmul_op_common(
zdnn_status status = zdnn_matmul_op_common(
inputA, inputB, inputC, opType, output, /*isBcast=*/false);
CHECK_ZDNN_STATUS(status, "zdnn_matmul");
return status;
}

zdnn_status zdnn_matmul_bcast_op_ext(const zdnn_ztensor *inputA,
Expand All @@ -163,7 +169,7 @@ zdnn_status zdnn_matmul_bcast_op_ext(const zdnn_ztensor *inputA,
inputA, inputB, inputC, opType, output, /*isBcast=*/true);
// Compiler does not check the return result at this moment. Thus, check it
// here.
assert(status == ZDNN_OK);
CHECK_ZDNN_STATUS(status, "zdnn_matmul");
return status;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ zdnn_status zdnn_softmax_ext(const zdnn_ztensor *input, void *save_area,
zdnn_ztensor *zy = getTile(&siY, i);
zdnn_status status = zdnn_softmax(
zx, (siX.reuseFullZTensor) ? save_area : NULL, act_func, zy);
assert(status == ZDNN_OK);
CHECK_ZDNN_STATUS(status, "zdnn_softmax");
}
if (OMZTensorSplitDebug) {
end_time = clock();
Expand Down
22 changes: 22 additions & 0 deletions src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ extern "C" {
bool OMZTensorSplitEnabled = DEFAULT_ZTENSOR_SPLIT_ENABLED;
bool OMZTensorSplitDebug = DEFAULT_ZTENSOR_SPLIT_DEBUG;
uint32_t OMZTensorSplitSize = DEFAULT_ZTENSOR_SPLIT_SIZE;
bool OMStatusMessagesEnabled = DEFAULT_STATUS_MESSAGES_ENABLED;

static uint32_t ZTensorSplitSizeFromEnv() {
uint32_t cs = DEFAULT_ZTENSOR_SPLIT_SIZE;
Expand Down Expand Up @@ -59,6 +60,14 @@ static bool ZTensorSplitDebugFromEnv() {
return enabled;
}

static bool StatusMessagesEnabledEnv() {
int enabled = DEFAULT_STATUS_MESSAGES_ENABLED;
const char *s = getenv("OM_STATUS_MESSAGES_ENABLED");
if (s)
enabled = atoi(s);
return enabled;
}

// malloc_aligned_4k is from zdnn.
static void *malloc_aligned_4k(size_t size) {
// Request one more page + size of a pointer from the OS.
Expand Down Expand Up @@ -90,12 +99,25 @@ void zDNNExtensionInit() {
OMZTensorSplitEnabled = ZTensorSplitEnabledFromEnv();
OMZTensorSplitDebug = ZTensorSplitDebugFromEnv();
OMZTensorSplitSize = ZTensorSplitSizeFromEnv();
OMStatusMessagesEnabled = StatusMessagesEnabledEnv();
if (OMZTensorSplitDebug) {
printf("OM_ZTENSOR_SPLIT_ENABLED: %d\n", OMZTensorSplitEnabled);
printf("OM_ZTENSOR_SPLIT_SIZE: %d\n", OMZTensorSplitSize);
}
if (OMStatusMessagesEnabled) {
printf("OM_STATUS_MESSAGES_ENABLED: %d\n", OMStatusMessagesEnabled);
}
}

void checkStatus(zdnn_status status, const char *zdnn_name) {
if (OMStatusMessagesEnabled && status != ZDNN_OK) {
fprintf(stdout, "[zdnnx] %s : %s\n", zdnn_name,
zdnn_get_status_message(status));
}
}

#define CHECK_ZDNN_STATUS(status, zdnn_name) checkStatus(status, zdnn_name)

void getUnmappedShape(const zdnn_ztensor *t, UnmappedShape *shape) {
const zdnn_tensor_desc *desc = t->transformed_desc;
shape->e4 = desc->dim4;
Expand Down
18 changes: 18 additions & 0 deletions src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@ extern "C" {
#define DEFAULT_ZTENSOR_SPLIT_ENABLED 0
// zTensor splitting debug is off by default.
#define DEFAULT_ZTENSOR_SPLIT_DEBUG 0
// zDNN status message is off by default.
#define DEFAULT_STATUS_MESSAGES_ENABLED 0

extern bool OMZTensorSplitEnabled;
extern bool OMZTensorSplitDebug;
extern uint32_t OMZTensorSplitSize;
// We want to enable zdnn status messages when a user
// manually specifies the environment variable.
extern bool OMStatusMessagesEnabled;

// -----------------------------------------------------------------------------
// Misc Macros
Expand Down Expand Up @@ -143,6 +148,19 @@ inline void omUnreachable() {
#endif
}

/**
* \brief Check zdnn status
*
* Check if the zdnn status is not a zdnn_ok and print out the
* status message along with the error
*
* @param status zdnn status
* @param zdnn_name name of the zdnn api
*/
void checkStatus(zdnn_status status, const char *zdnn_name);

#define CHECK_ZDNN_STATUS(status, zdnn_name) checkStatus(status, zdnn_name)

/**
* \brief Get the unmapped shape (4D) of ztensor.
*
Expand Down

0 comments on commit 6923cfd

Please sign in to comment.