Skip to content

Commit

Permalink
[transformer-kernel] turn off unit test printing (microsoft#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra committed Jan 27, 2021
1 parent cd29f8b commit 91b1b7f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/torch16.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
# - 'docs/**'
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
Expand Down
19 changes: 11 additions & 8 deletions tests/unit/test_cuda_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def check_equal(first, second, atol=1e-2, verbose=False):
diction_x = {}
diction_y = {}

for i, (x, y) in enumerate(zip(first, second)):
print(x[1], y[1])
if verbose:
for i, (x, y) in enumerate(zip(first, second)):
print(x[1], y[1])

for i, (x, y) in enumerate(zip(first, second)):
k = 0
Expand All @@ -38,18 +39,20 @@ def check_equal(first, second, atol=1e-2, verbose=False):
diction_y[k, y[1]] = y[0]
if verbose:
print()
for i, (x, y) in enumerate(zip(diction_x, diction_y)):
print(x, y)
for i, (x, y) in enumerate(zip(diction_x, diction_y)):
print(x, y)

for i, (x, y) in enumerate(zip(diction_x, diction_y)):
if (x[0] == 1): continue
print("checking ", x[1], ":")
if verbose:
print("checking ", x[1], ":")
y = diction_y[x[0], x[1]]
x = diction_x[x[0], x[1]]
x = x.cpu().detach().numpy()
y = y.cpu().detach().numpy()
print(x)
print(y)
if verbose:
print(x)
print(y)

avgx = np.sum(abs(x), dtype=float)
countx = x.shape[0]
Expand All @@ -60,8 +63,8 @@ def check_equal(first, second, atol=1e-2, verbose=False):
if avgx != float('inf') and avgx != -float('inf'):
avgx = avgx / countx
tollerance = avgx * atol
print("tollerance is ", tollerance)
if verbose:
print("tollerance is ", tollerance)
print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten()))
print('-' * 80)
Expand Down

0 comments on commit 91b1b7f

Please sign in to comment.