3x3 Median filter slow (on GPU) #7302
I'm writing a median filter for image processing. The ideal kernel size would be 9x9, but as a proof of concept I'm first implementing a 3x3 version. The current implementation works, but it's unexpectedly slow: around 5.5 ms for a roughly 2000x1300 float image on an Nvidia RTX 2080 Ti, with the data kept entirely on the GPU and the OpenCL backend. I'd expect it to be much faster, since very little work is done per thread. Since Halide does not support mutating intermediate values, I've tried to emulate that behavior by dynamically creating a new Func for each intermediate result and scheduling them all with `compute_at`. I'm attaching the code below.

```python
import halide as hl


def median_filter_3x3(values: hl.ImageParam):
    col = hl.Var("col")
    row = hl.Var("row")
    values_bound = hl.BoundaryConditions.constant_exterior(
        values, hl.f32(-1.0)
    )

    def index_to_name(prefix: str, i, step: int) -> str:
        return prefix + "_" + str(i) + "_" + str(step)

    all_variables = []
    variables = [hl.Func(index_to_name("value", x, 0)) for x in range(3 * 3)]
    for v in range(3 * 3):
        dx = v // 3
        dy = v % 3
        variables[v][col, row] = values_bound[col + dx, row + dy]
        all_variables.append(variables[v])

    # Start at 1: the step-0 names are already taken by the initial Funcs.
    current_step = 1

    def cmpswp(a: int, b: int):
        nonlocal current_step
        func_a = variables[a]
        func_b = variables[b]
        # Join the pair with an underscore so the Func name stays a valid
        # identifier (str((a, b)) would insert parentheses and a comma).
        pair = str(a) + "_" + str(b)
        is_larger = hl.Func(index_to_name("larger", pair, current_step))
        is_larger[col, row] = func_a[col, row] > func_b[col, row]
        new_a = hl.Func(index_to_name("value", a, current_step))
        new_a[col, row] = hl.select(
            is_larger[col, row], func_b[col, row], func_a[col, row]
        )
        new_b = hl.Func(index_to_name("value", b, current_step))
        new_b[col, row] = hl.select(
            is_larger[col, row], func_a[col, row], func_b[col, row]
        )
        variables[a] = new_a
        variables[b] = new_b
        all_variables.append(new_a)
        all_variables.append(new_b)
        all_variables.append(is_larger)
        current_step += 1

    # Median-of-9 sorting network; comparisons that cannot affect the
    # median (slot 4) are commented out.
    cmpswp(0, 1)
    cmpswp(3, 4)
    cmpswp(6, 7)
    cmpswp(1, 2)
    cmpswp(4, 5)
    cmpswp(7, 8)
    cmpswp(0, 1)
    cmpswp(3, 4)
    cmpswp(6, 7)
    cmpswp(0, 3)
    cmpswp(3, 6)
    # cmpswp(0, 3)
    cmpswp(1, 4)
    cmpswp(4, 7)
    cmpswp(1, 4)
    cmpswp(5, 8)
    cmpswp(2, 5)
    # cmpswp(5, 8)
    cmpswp(2, 4)
    cmpswp(4, 6)
    cmpswp(2, 4)
    # cmpswp(1, 3)
    # cmpswp(2, 3)
    # cmpswp(5, 7)
    # cmpswp(5, 6)

    smoothed_value = hl.Func("smoothed_value")
    smoothed_value[col, row] = variables[4][col, row]

    def schedule(output: hl.Func, col_inner: hl.Var) -> None:
        smoothed_value.compute_at(output, col_inner)
        for func in all_variables:
            func.compute_at(output, col_inner)

    return smoothed_value, schedule
```
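As a sanity check (separate from the Halide pipeline), the same comparator sequence can be run on a plain Python list. This sketch — hypothetical helper names, not part of the filter — confirms that after the 19 min/max swaps, slot 4 holds the median of the 9 inputs:

```python
import random

# The same compare-and-swap sequence as the cmpswp calls above,
# expressed on a plain Python list of 9 values (indices 0..8).
NETWORK = [
    (0, 1), (3, 4), (6, 7),
    (1, 2), (4, 5), (7, 8),
    (0, 1), (3, 4), (6, 7),
    (0, 3), (3, 6),
    (1, 4), (4, 7), (1, 4),
    (5, 8), (2, 5),
    (2, 4), (4, 6), (2, 4),
]


def median9(window):
    """Run the sorting network and return slot 4 (the median)."""
    v = list(window)
    for a, b in NETWORK:
        if v[a] > v[b]:
            v[a], v[b] = v[b], v[a]
    return v[4]


# Check the network against a brute-force median on random windows.
random.seed(0)
for _ in range(1000):
    window = [random.random() for _ in range(9)]
    assert median9(window) == sorted(window)[4]
```

The first nine comparisons sort each column of three; the remaining ten combine the column mins, medians, and maxes, which is why the commented-out swaps are not needed to place the median in slot 4.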
Replies: 1 comment 1 reply
I actually published a whole paper on this topic using Halide: https://dl.acm.org/doi/abs/10.1145/3450626.3459773 It includes Halide code in the supplemental material.

First, for a compare-and-swap you want to use min/max rather than a select: `(a, b) = (min(a, b), max(a, b))`.

Second, you can use a single Func and use update definitions to do the swaps. Use scatter and gather - they let you read and write groups of elements at once, e.g. to sort the elements at positions 1 and 4, you might write something like:

```
a = f(x, y, 1)
b = f(x, y, 4)
(a, b) = (min(a, b), max(a, b))
f(x, y, scatter(1, 4)) = gather(a, b)
```

Finally, the key to fast medium-support (e.g. 9x9) median filters is having each thread produce a small tile of pixels. Most of the sorting work can be shared between neighboring pixels.
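The work-sharing idea can be illustrated outside Halide. For a 3x3 median, each column of three pixels is sorted once and that sorted column is reused by up to three horizontally adjacent output pixels; three sorted columns are then combined with the standard max-of-mins / median-of-medians / min-of-maxs identity. A plain-Python sketch with hypothetical names (this is an illustration of the principle, not the paper's code):

```python
import random


def sort3(a, b, c):
    # Three min/max compare-and-swaps - the cheap comparator form.
    a, b = min(a, b), max(a, b)
    b, c = min(b, c), max(b, c)
    a, b = min(a, b), max(a, b)
    return a, b, c


def median3(a, b, c):
    # Median of three without branching.
    return max(min(a, b), min(max(a, b), c))


def median_filter_row(img, y):
    """3x3 median for one interior row of a list-of-lists image.

    Each column of 3 is sorted once in `cols` and shared by the
    neighbouring output pixels, instead of re-sorting per pixel.
    """
    w = len(img[0])
    cols = [sort3(img[y - 1][x], img[y][x], img[y + 1][x]) for x in range(w)]
    out = []
    for x in range(1, w - 1):
        (a0, b0, c0), (a1, b1, c1), (a2, b2, c2) = (
            cols[x - 1], cols[x], cols[x + 1]
        )
        # median of 9 = median(max of mins, median of medians, min of maxs)
        out.append(median3(max(a0, a1, a2), median3(b0, b1, b2), min(c0, c1, c2)))
    return out


# Check against a brute-force 3x3 median on a random 3-row image.
random.seed(1)
img = [[random.randrange(100) for _ in range(8)] for _ in range(3)]
row = median_filter_row(img, 1)
for x in range(1, 7):
    window = sorted(img[yy][xx] for yy in range(3) for xx in (x - 1, x, x + 1))
    assert row[x - 1] == window[4]
```

Per output pixel this costs one shared column sort (3 comparators, amortized) plus the combine step, rather than a full 19-comparator network per pixel; the same amortization is what pays off for larger supports like 9x9.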