Skip to content

Commit

Permalink
Fix/cos nan (#1090)
Browse files Browse the repository at this point in the history
* fix pytest's cosine

* fix similarity

---------

Co-authored-by: hejunchao <hejunchao@canaan-creative.com>
Co-authored-by: HeJunchao100813 <HeJunchao100813@users.noreply.github.com>
Co-authored-by: Curio Yang <curioyhq@gmail.com>
  • Loading branch information
4 people committed Oct 16, 2023
1 parent 01da2d0 commit 3930f47
Show file tree
Hide file tree
Showing 15 changed files with 219 additions and 10 deletions.
15 changes: 15 additions & 0 deletions tests/caffe_test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import caffe
from test_runner import *
import os
Expand Down
50 changes: 47 additions & 3 deletions tests/compare_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
import enum
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import math
import os
import re
Expand All @@ -11,7 +25,38 @@


def cosine(gt: np.ndarray, pred: np.ndarray, *args):
return (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2))
# remove the NaN values in the same location.
if np.isnan(gt).any() and np.isnan(pred).any():
gt_mask = np.isnan(gt)
pred_mask = np.isnan(pred)
mask = gt_mask & pred_mask
gt = gt[~mask]
pred = pred[~mask]

# remove the INF values in the same location.
if np.isinf(gt).any() and np.isinf(pred).any():
gt_mask = np.isinf(gt)
pred_mask = np.isinf(pred)
mask = gt_mask & pred_mask
gt = gt[~mask]
pred = pred[~mask]

# return -1 if the nan/inf value is still in the array.
if np.isnan(gt).any() or np.isnan(pred).any() or np.isinf(gt).any() or np.isinf(pred).any():
return -1

# exclude the situation of all zeros in array.
if compare_arrays(gt, pred):
return 1

result = (gt @ pred) / (np.linalg.norm(gt, 2) * np.linalg.norm(pred, 2))

# When tensor gt is a multiple of tensor pred, their similarity is also 1.
return -1 if math.isnan(result) else result


def compare_arrays(gt: np.ndarray, pred: np.ndarray):
return np.array_equal(gt, pred)


def euclidean(gt: np.ndarray, pred: np.ndarray, *args):
Expand Down Expand Up @@ -86,7 +131,6 @@ def compare_ndarray(expected: np.ndarray,
threshold: float = 0.99,
dump_hist: bool = True,
dump_file: str = 'hist.csv') -> bool:

if expected.size == actual.size:
similarity = similarity_func[similarity_name](expected.flatten(), actual.flatten())
else:
Expand Down
15 changes: 15 additions & 0 deletions tests/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

from typing import List, Dict, Union, Tuple
import os
import nncase
Expand Down
15 changes: 15 additions & 0 deletions tests/generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

from typing import Any, Dict, List, Tuple, Union
import numpy as np
import os
Expand Down
4 changes: 2 additions & 2 deletions tests/importer/tflite_/basic/test_depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, x):
strides = [
[1, 1],
[1, 3],
[5, 5]
# [5, 5]
]

paddings = [
Expand All @@ -69,7 +69,7 @@ def __call__(self, x):

dilations = [
[1, 1],
# [2, 2] there is a bug in tf.nn.depthwise_conv2d that produces incorrect output shape
# [2, 2] there is a bug in tf.nn.depthwise_conv2d that produces incorrect output shape.
]


Expand Down
15 changes: 15 additions & 0 deletions tests/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

from typing import List, Dict, Union, Tuple
import os
import nncase
Expand Down
15 changes: 15 additions & 0 deletions tests/json2md.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import argparse
import json
import os
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/test_cum_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ TEST_P(CumSumTest, cum_sum) {
.expect("create tensor failed");

// actual
float_t exclusive[] = {0};
float exclusive[] = {0};
auto exclusive_ptr = hrt::create(nncase::dt_float32, {1},
{reinterpret_cast<gsl::byte *>(exclusive),
sizeof(exclusive)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
float_t reverse[] = {0};
float reverse[] = {0};
auto reverse_ptr =
hrt::create(nncase::dt_float32, {1},
{reinterpret_cast<gsl::byte *>(reverse), sizeof(reverse)},
Expand Down
6 changes: 3 additions & 3 deletions tests/kernels/test_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ class RangeTest : public KernelTest,
auto step_value = GetFloatNumber("step");
auto typecode = GetDataType("lhs_type");

float_t begin_array[] = {begin_value};
float begin_array[] = {begin_value};
begin = hrt::create(typecode, shape,
{reinterpret_cast<gsl::byte *>(begin_array),
sizeof(begin_array)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

float_t end_array[] = {end_value};
float end_array[] = {end_value};
end = hrt::create(
typecode, shape,
{reinterpret_cast<gsl::byte *>(end_array), sizeof(end_array)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

float_t step_array[] = {step_value};
float step_array[] = {step_value};
step = hrt::create(typecode, shape,
{reinterpret_cast<gsl::byte *>(step_array),
sizeof(step_array)},
Expand Down
15 changes: 15 additions & 0 deletions tests/nuc_proxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import os
import argparse
import stat
Expand Down
15 changes: 15 additions & 0 deletions tests/onnx_test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

from onnx import version_converter, helper
import onnxsim
import onnxruntime as ort
Expand Down
15 changes: 15 additions & 0 deletions tests/preprocess_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

def get_source_transpose_index(perm):
"""
transpose model output with postprocess to framework output
Expand Down
15 changes: 15 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import copy
import os
import re
Expand Down
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import os
import json
import numpy as np
Expand Down
15 changes: 15 additions & 0 deletions tests/tflite_test_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import tensorflow as tf
from test_runner import *
import os
Expand Down

0 comments on commit 3930f47

Please sign in to comment.