
Commit 4beaad5

Back out "[inductor] make thread order consistent with loop order (pytorch#106827)"
Summary: D48295371 causes a batch fusion failure, which will block our mc proposal on all mc models.

Test Plan: Without the revert, f469732293; with the revert, f472266199.

Reviewed By: yanboliang

Differential Revision: D48593029

fbshipit-source-id: 751a3f6b20e51b728044852a4a5fd3a376529cce
jackiexu1992 authored and facebook-github-bot committed Aug 23, 2023
1 parent 1809dc4 commit 4beaad5
Showing 3 changed files with 7 additions and 30 deletions.
torch/_inductor/codegen/triton.py (14 changes: 2 additions & 12 deletions)
@@ -833,14 +833,11 @@ def set_last_usage(self, nodes):
         )
 
     def initialize_range_tree(self, pid_cache):
-        names = list(
-            reversed(["xindex", "yindex", "zindex"][: len(self.numels) - 1])
-        ) + ["rindex"]
+        names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
         for i in range(len(self.numels)):
-            pid_idx = i if names[i][0] == "r" else "xyz".find(names[i][0])
             self.range_trees.append(
                 IterationRangesRoot(
-                    names[i], self.numels[i], names[i][0], pid_idx, self, pid_cache
+                    names[i], self.numels[i], names[i][0], i, self, pid_cache
                 )
             )
         for tree in self.range_trees:
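
For context, the backed-out patch reversed the spatial index names and remapped each range tree onto a program_id axis via pid_idx; the backout restores the plain positional mapping, where tree i is driven by program_id(i). A minimal standalone sketch of the two orderings (the helper functions below are hypothetical; only the list expressions come from the hunk):

```python
# Hypothetical helpers contrasting the two naming schemes in the hunk above;
# `numels` stands in for self.numels from the diff.
def tree_names_restored(numels):
    # Backout behavior: natural x, y, z order with the reduction index last.
    return ["xindex", "yindex", "zindex"][: len(numels) - 1] + ["rindex"]


def tree_names_backed_out(numels):
    # Reverted-patch behavior: spatial names reversed before "rindex".
    return list(
        reversed(["xindex", "yindex", "zindex"][: len(numels) - 1])
    ) + ["rindex"]


print(tree_names_restored([64, 32, 128]))    # ['xindex', 'yindex', 'rindex']
print(tree_names_backed_out([64, 32, 128]))  # ['yindex', 'xindex', 'rindex']
```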
@@ -2003,13 +2000,6 @@ def dense_size_str(self):
                 sizes.append(f"{tree.prefix.upper()}BLOCK")
             elif tree.prefix == "r" and tree.numel != 1:
                 sizes.append("1")
-
-        if sizes[0:3] == ["ZBLOCK", "YBLOCK", "XBLOCK"]:
-            sizes[0:3] = reversed(sizes[0:3])
-
-        if sizes[0:2] == ["YBLOCK", "XBLOCK"]:
-            sizes[0:2] = reversed(sizes[0:2])
-
         return f"[{', '.join(sizes)}]"
 
     def call_kernel(self, name: str):
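
The deleted lines post-processed the dense size string. A sketch of what they did, lifted into a hypothetical standalone function so it can be run in isolation:

```python
# Hypothetical standalone version of the deleted reordering: when the range
# trees emitted block names in z/y/x (or y/x) order, the list was flipped so
# XBLOCK came first. After the backout, tree order is kept as-is.
def reorder_blocks(sizes):
    if sizes[0:3] == ["ZBLOCK", "YBLOCK", "XBLOCK"]:
        sizes[0:3] = reversed(sizes[0:3])
    if sizes[0:2] == ["YBLOCK", "XBLOCK"]:
        sizes[0:2] = reversed(sizes[0:2])
    return f"[{', '.join(sizes)}]"


print(reorder_blocks(["YBLOCK", "XBLOCK", "RBLOCK"]))  # [XBLOCK, YBLOCK, RBLOCK]
```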
torch/_inductor/scheduler.py (8 changes: 4 additions & 4 deletions)
@@ -905,15 +905,15 @@ def index_cmp(a, b):
 
             # equivalent to
             # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
-            a_first = sum(
+            a_first = all(
                 sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
             )
-            b_first = sum(
+            b_first = all(
                 sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
             )
-            if a_first > b_first:
+            if a_first and not b_first:
                 return -1
-            if b_first > a_first:
+            if b_first and not a_first:
                 return 1
 
             # otherwise contiguous
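
The restored comparator only prefers an index when it dominates on every dimension (all), matching the np.logical_or(...).all() comment; the backed-out sum variant compared how many dimensions favored each side, which can reorder loops even without strict dominance. A runnable sketch, with the stride_lengths lookup replaced by direct stride tuples (that simplification is mine, not the source's):

```python
from functools import cmp_to_key


# Hypothetical standalone form of the restored comparator: order index a
# before b only when every dimension of a has a smaller (or broadcast-zero)
# stride.
def index_cmp(stride_len_a, stride_len_b):
    a_first = all(
        sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
    )
    b_first = all(
        sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
    )
    if a_first and not b_first:
        return -1
    if b_first and not a_first:
        return 1
    return 0  # no strict winner: leave order unchanged


# (1, 8) has uniformly smaller strides than (8, 64), so it sorts first.
print(sorted([(8, 64), (1, 8)], key=cmp_to_key(index_cmp)))
# [(1, 8), (8, 64)]
```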
torch/_inductor/triton_heuristics.py (15 changes: 1 addition & 14 deletions)
@@ -672,10 +672,6 @@ def triton_config(
     override the num_elements_per_warp.
     """
     # Ideally we want to read this from some device config
-
-    # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
-    size_hints = list(reversed(size_hints))
-
     maxGridSize = [2147483647, 65535, 65535]
 
     target = conditional_product(x, y, z)
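
The deleted reversal meant that for 2-D size_hints [a, b], a was mapped to YBLOCK and b to XBLOCK; after the backout, hints map positionally (a to XBLOCK). A one-line illustration with made-up hint values:

```python
size_hints = [256, 1024]  # hypothetical [a, b] hints, values made up
print(list(reversed(size_hints)))  # [1024, 256]: b -> XBLOCK, a -> YBLOCK
```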
@@ -1009,18 +1005,9 @@ def foreach(meta, num_warps, filename=None):
     )
 
 
-def grid(*numels):
+def grid(xnumel, ynumel=None, znumel=None):
     """Helper function to compute triton grids"""
 
-    if len(numels) == 1:
-        xnumel, ynumel, znumel = numels[0], None, None
-    elif len(numels) == 2:
-        xnumel, ynumel, znumel = numels[1], numels[0], None
-    elif len(numels) == 3:
-        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
-    else:
-        raise AssertionError(f"invalid size for numels {len(numels)}")
-
     def get_grid_dim(numel, block):
         if numel is None:
             return 1
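
The restored grid() takes explicit per-axis numels instead of a reversed varargs list. A minimal runnable sketch of how such a helper can be wired up; the ceiling-division body and the returned closure over meta are assumptions, and only the signature plus the numel-is-None case come from the diff:

```python
# A sketch under stated assumptions, not the actual implementation.
def grid(xnumel, ynumel=None, znumel=None):
    """Helper function to compute triton grids"""

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        return (numel + block - 1) // block  # ceiling division (assumption)

    def grid_fn(meta):
        # `meta` carries tuned block sizes, e.g. {"XBLOCK": 128, "YBLOCK": 32}.
        return (
            get_grid_dim(xnumel, meta["XBLOCK"]),
            get_grid_dim(ynumel, meta.get("YBLOCK", 1)),
            get_grid_dim(znumel, meta.get("ZBLOCK", 1)),
        )

    return grid_fn


print(grid(1024, 64)({"XBLOCK": 128, "YBLOCK": 32}))  # (8, 2, 1)
```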
