Skip to content

Commit

Permalink
Merge pull request #8010 from Zhyrek/add_break_continue
Browse files Browse the repository at this point in the history
Support for break and continue keywords
  • Loading branch information
asi1024 committed Dec 19, 2023
2 parents 9050fa6 + a0cd5cd commit 7f55d35
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cupyx/jit/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,9 @@ def _transpile_stmt(
if isinstance(stmt, ast.Pass):
return [';']
if isinstance(stmt, ast.Break):
raise NotImplementedError('Not implemented.')
return ['break;']
if isinstance(stmt, ast.Continue):
raise NotImplementedError('Not implemented.')
return ['continue;']
assert False


Expand Down
71 changes: 71 additions & 0 deletions tests/cupyx_tests/jit_tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,77 @@ def f(x, m):
y[:mask] += 1
assert bool((x == y).all())

def test_loop_continue(self):
@jit.rawkernel()
def f(x, y, z):
tid = jit.grid(1)

for i in range(10):
# adds 0-9, except for 5.
# Sum is 40
if i == 5:
continue
x[tid] += i

i2 = 0
while i2 < 9:
# adds 1-9 in a while loop, except for 6,
# should equal 39
i2 += 1
if i2 == 6:
continue
y[tid] += i2

for i in range(11):
# adds 0-10, but skips if the sum is greater than 3*i,
# skips 8 and 9, but not 10 (28 < 3*10), sum is 38
if z[tid] > 3*i:
continue
z[tid] += i

x = cupy.zeros(32, dtype=int)
y = cupy.zeros(32, dtype=int)
z = cupy.zeros(32, dtype=int)
f[1, 32](x, y, z)
assert bool((x == 40).all())
assert bool((y == 39).all())
assert bool((z == 38).all())

def test_loop_break(self):
@jit.rawkernel()
def f(x, y, z):
tid = jit.grid(1)

for i in range(10):
# adds 0-4,
# break at 5. Sum is 10
if i == 5:
break
x[tid] += i

i2 = 0
while i2 < 9:
# adds 1-5 in a while loop,
# breaks at 6, should equal 15
i2 += 1
if i2 == 6:
break
y[tid] += i2
for i in range(11):
# adds 0-10, but stops once the sum is greater than 3*i,
# breaks at 8 (28 > 3*8), sum is 28
if z[tid] > 3*i:
break
z[tid] += i

x = cupy.zeros(32, dtype=int)
y = cupy.zeros(32, dtype=int)
z = cupy.zeros(32, dtype=int)
f[1, 32](x, y, z)
assert bool((x == 10).all())
assert bool((y == 15).all())
assert bool((z == 28).all())

def test_shared_memory_static(self):
@jit.rawkernel()
def f(x, y):
Expand Down

0 comments on commit 7f55d35

Please sign in to comment.