Skip to content

Commit

Permalink
Merged in pipelines (pull request #174)
Browse files Browse the repository at this point in the history
Pipelines

Approved-by: Michel Steuwer <michel.steuwer@glasgow.ac.uk>
Approved-by: Christophe Dubach <christophe.dubach@ed.ac.uk>
  • Loading branch information
b-hagedorn authored and cdubach committed May 11, 2018
2 parents 158f97d + 03698fa commit 44e73a2
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 1 deletion.
194 changes: 194 additions & 0 deletions src/test/opencl/generator/pipelines/TestHalideSchedules.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package opencl.generator.pipelines

import ir._
import ir.ast._
import lift.arithmetic.SizeVar
import opencl.executor._
import opencl.ir._
import opencl.ir.pattern.{MapGlb, _}
import org.junit.Assert._
import org.junit._

object TestHalideSchedules extends TestWithExecutor

class TestHalideSchedules {

@Test
def singleStagePipelineHalideSchedules(): Unit = {
val bright = UserFun("bright", "x", "{ return min(x * 1.5f, 255.0f); }", Float, Float)
val bright4 = UserFun("bright4", "x", "{ return min(x * 1.5f, 255.0f); }", Float4, Float4)

val M = 4
val N = 4

val input = Array.tabulate(M, N) { (i, j) => i * M + j * 1.0f}

def lambda(l: Lambda): Lambda = {
λ(ArrayType(ArrayType(Float, N), M),
image => l $ image
)
}

// initial version
val brighten = MapSeq(MapSeq(bright))

// schedule: brighten.reorder(y,x)
val brighten_col_major = TransposeW() o MapSeq(MapSeq(bright)) o Transpose()

// schedule: split(x, x_outer, x_inner, 2)
val brighten_split = MapSeq(Join() o MapSeq(MapSeq(bright)) o Split(2))

// schedule: fuse(x, y, fused)
val brighten_fused = MapSeq(bright) o Join()

//tile(x, y, x_outer, y_outer, x_inner, y_inner, 4, 4)
// -- or --
//split(x, x_outer, x_inner, 4);
//split(y, y_outer, y_inner, 4);
//reorder(x_inner, y_inner, x_outer, y_outer);
val brighten_tiled =
Join() o Map(Map(Join()) o TransposeW()) o // untile
MapSeq(MapSeq(MapSeq(MapSeq(bright)))) o // compute
Map(Transpose()) o Split(2) o Map(Split(2)) // tile

// vectorize(x, 4);
// -- or --
// split(x, x, x_inner, 4);
// vectorize(x_inner);
val brighten_in_vectors = MapSeq(asScalar() o MapSeq(bright4) o asVector(4))

// split(x, x_outer, x_inner, 2);
// unroll(x_inner);
val brighten_unroll = MapSeq(Join() o MapSeq(MapSeqUnroll(bright)) o Split(2))

// tile(x, y, x_outer, y_outer, x_inner, y_inner, 2, 2)
// fuse(x_outer, y_outer, tile_index)
// parallel(tile_index);
val brighten_fused_tiles =
Join() o Map(Map(Join()) o TransposeW()) o // untile
Split(2) o // INVESTIGATE
MapGlb(MapSeq(MapSeq(bright))) o // compute
Join() o // fuse
Map(Transpose()) o Split(2) o Map(Split(2)) // tile
println(Compile(lambda(brighten_fused_tiles)))


val (outBrighten, _) = Execute(1,1,1,1,(false,false))[Array[Float]](lambda(brighten), input)
val (outBrightenColMajor, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_col_major), input)
val (outBrightenSplit, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_split), input)
val (outBrightenFused, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_fused), input)
val (outBrightenTiled, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_tiled), input)
val (outBrightenInVectors, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_in_vectors), input)
val (outBrightenUnroll, _) = Execute(1,1, 1, 1,(false,false))[Array[Float]](lambda(brighten_unroll), input)
val (outBrightenFusedTiles, _) = Execute(1,1, 4, 1,(false,false))[Array[Float]](lambda(brighten_fused_tiles), input)

assertArrayEquals(outBrightenColMajor, outBrighten, 0.0f)
assertArrayEquals(outBrightenSplit, outBrighten, 0.0f)
assertArrayEquals(outBrightenFused, outBrighten, 0.0f)
assertArrayEquals(outBrightenTiled, outBrighten, 0.0f)
assertArrayEquals(outBrightenInVectors, outBrighten, 0.0f)
assertArrayEquals(outBrightenUnroll, outBrighten, 0.0f)
assertArrayEquals(outBrightenFusedTiles, outBrighten, 0.0f)
}

@Test
def halideLesson05Fast(): Unit = {
val bright = UserFun("bright", "x", "{ return min(x * 1.5f, 255.0f); }", Float, Float)
val bright4 = UserFun("bright4", "x", "{ return min(x * 1.5f, 255.0f); }", Float4, Float4)
val M = 128
val N = 128
val input = Array.tabulate(M, N) { (i, j) => i * M + j * 1.0f}

def lambda(l: Lambda): Lambda = {
λ(ArrayType(ArrayType(Float, N), M),
image => l $ image
)
}

// [x][y]
val brighten0 = MapSeq(MapSeq(bright))

// construct
// brighten
// .tile(x, y, x_outer, y_outer, x_inner, y_inner, 64, 64) -> brighten1
// .fuse(x_outer, y_outer, tile_index) -> brighten2
// .parallel(tile_index); -> brighten3
// .tile(x_inner, y_inner, x_inner_outer, y_inner_outer, x_vectors, y_pairs, 4, 2) -> brighten4
// .vectorize(x_vectors) -> brighten5
// .unroll(y_pairs); -> brighten6

// [x_inner][y_inner][x_outer][y_outer]
val brighten1 =
Join() o Map(Map(Join()) o TransposeW()) o // REVERSE(tile)
MapSeq(MapSeq(MapSeq(MapSeq(bright)))) o // add two maps
Map(Transpose()) o Split(64) o Map(Split(64)) // tile(x, y, x_outer, y_outer, x_inner, y_inner, 64, 64)

// [x_inner][y_inner][tile_index]
val brighten2 =
Join() o Map(Map(Join()) o TransposeW()) o
Split(2) o // REVERSE(fuse) -- find out how many?!
MapSeq(MapSeq(MapSeq(bright))) o // remove one map
Join() o // fuse(x_outer, y_outer, tile_index)
Map(Transpose()) o Split(64) o Map(Split(64))

// [x_inner][y_inner][tile_index]
val brighten3 =
Join() o Map(Map(Join()) o TransposeW()) o
Split(2) o
MapGlb(MapSeq(MapSeq(bright))) o // parallel(tile_index)
Join() o
Map(Transpose()) o Split(64) o Map(Split(64))

// [x_vectors][y_pairs][x_inner_outer][y_inner_outer][tile_index]
val brighten4 =
Join() o Map(Map(Join()) o TransposeW()) o
Split(2) o
MapGlb(
Join() o Map(Map(Join()) o TransposeW()) o PrintType() o // REVERSE(tile)
MapSeq(MapSeq(MapSeq(MapSeq(bright)))) o // add two maps
Map(Transpose()) o Split(2) o MapSeq(Split(4)) // tile(x_inner, y_inner, x_inner_outer, y_inner_outer, x_vectors, y_pairs, 4, 2)
) o
Join() o
Map(Transpose()) o Split(64) o Map(Split(64))

// [x_vectors][y_pairs][x_inner_outer][y_inner_outer][tile_index]
val brighten5 =
Join() o Map(Map(Join()) o TransposeW()) o
Split(2) o
MapGlb(
Join() o Map(Map(Join()) o TransposeW()) o
MapSeq(MapSeq(Split(4) o asScalar() o MapSeq(bright4))) o // REVERSE(vectorize) = add asScalar + reconstruct array from vector using split
Map(Transpose()) o Split(2) o MapSeq(asVector(4)) // vectorize(x_vectors) = use vectorized UserFun + remove one map
) o
Join() o
Map(Transpose()) o Split(64) o Map(Split(64))

// [x_vectors][y_pairs][x_inner_outer][y_inner_outer][tile_index]
val brighten6 =
Join() o Map(Map(Join()) o TransposeW()) o
Split(2) o
MapGlb(
Join() o Map(Map(Join()) o TransposeW()) o
MapSeq(MapSeq(Split(4) o asScalar() o MapSeqUnroll(bright4))) o // unroll(y_pairs)
Map(Transpose()) o Split(2) o MapSeq(asVector(4))
) o
Join() o
Map(Transpose()) o Split(64) o Map(Split(64))


val (outBrighten0, _) = Execute(1,1,1,1,(false,false))[Array[Float]](lambda(brighten0), input)
val (outBrighten1, _) = Execute(1,1,1,1,(false,false))[Array[Float]](lambda(brighten1), input)
val (outBrighten2, _) = Execute(1,1,1,1,(false,false))[Array[Float]](lambda(brighten2), input)
val (outBrighten3, _) = Execute(1,1,32,1,(false,false))[Array[Float]](lambda(brighten3), input)
val (outBrighten4, _) = Execute(1,1,32,1,(false,false))[Array[Float]](lambda(brighten4), input)
val (outBrighten5, _) = Execute(1,1,32,1,(false,false))[Array[Float]](lambda(brighten5), input)
val (outBrighten6, _) = Execute(1,1,32,1,(false,false))[Array[Float]](lambda(brighten6), input)

assertArrayEquals(outBrighten1, outBrighten0, 0.0f)
assertArrayEquals(outBrighten2, outBrighten0, 0.0f)
assertArrayEquals(outBrighten3, outBrighten0, 0.0f)
assertArrayEquals(outBrighten4, outBrighten0, 0.0f)
assertArrayEquals(outBrighten5, outBrighten0, 0.0f)
assertArrayEquals(outBrighten6, outBrighten0, 0.0f)
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opencl.generator.stencil
package opencl.generator.pipelines

import ir._
import ir.ast._
Expand Down
156 changes: 156 additions & 0 deletions src/test/opencl/generator/pipelines/TestPipelineFusion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package opencl.generator.pipelines

import ir._
import ir.ast._
import lift.arithmetic.SizeVar
import opencl.executor._
import opencl.ir._
import opencl.ir.pattern.{MapGlb, _}
import org.junit.Assert._
import org.junit._

object TestPipelineFusion extends TestWithExecutor

class TestPipelineFusion {

@Test
def boxBlurPipelineFusion(): Unit = {
val M = SizeVar("M")
val N = SizeVar("N")

def lambda(l: Lambda): Lambda = {
λ(ArrayType(ArrayType(Float, N), M),
image => l $ image
)
}

val blurx =
Map(λ(row => Join() o
Map(λ(window =>
Reduce(add, 0.0f) o
Map(id) $ window) // some meaningful userfun
) o Slide(3,1) $ row))

val blury = Transpose() o blurx o Transpose()

val pipeline = lambda(blury o blurx)

val f0 = PrintType() o Transpose() o blurx o Transpose() o // blury
Map(Join() o Map(Reduce(add, 0.0f) o Map(id)) o Slide(3,1)) // blurx

val f1 = PrintType() o Transpose() o blurx o
Transpose() o
Map(Join()) o
Map(Map(Reduce(add, 0.0f))) o
Map(Map(Map(id))) o
Map(Slide(3,1))

val f2 = PrintType() o Transpose() o blurx o
Join() o
Map(Transpose()) o
Transpose() o
Map(Map(Reduce(add, 0.0f))) o
Map(Map(Map(id))) o
Map(Slide(3,1))

val f3 = PrintType() o Transpose() o blurx o
Join() o
Map(Transpose()) o
Map(Map(Reduce(add, 0.0f))) o
Transpose() o
Map(Map(Map(id))) o
Map(Slide(3,1))

val f4 = PrintType() o Transpose() o blurx o
Join() o
Map(Transpose()) o
Map(Map(Reduce(add, 0.0f))) o
Map(Map(Map(id))) o
Transpose() o
Map(Slide(3,1))

// high-level
val P = PrintType()
val T = Transpose()
val J = Join()
val S = Slide(3,1)
val R = Reduce(add, 0.0f)
val f = id
def *(f: Lambda) = Map(f)
def **(f: Lambda) = Map(Map(f))
def ***(f: Lambda) = Map(Map(Map((f))))
def ****(f: Lambda) = Map(Map(Map(Map((f)))))

// low-level
def %(f: Lambda) = MapSeq(f)
def %%(f: Lambda) = MapSeq(MapSeq(f))
val RS = ReduceSeq(add, 0.0f)
val TW = TransposeW()

val f5 = P o T o *(J) o **(R) o ***(f) o J o **(S) o *(T) o **(R) o ***(f) o T o *(S)
val f6 = P o T o *(J) o **(R) o J o ****(f) o **(S) o *(T) o **(R) o ***(f) o T o *(S)
val f7 = P o T o *(J) o J o ***(R) o ****(f) o **(S) o *(T) o **(R) o ***(f) o T o *(S)

// parallelizing f7
val p0 = P o T o *(J) o J o *(**(R) o ***(f) o *(S) o T o *(R) o **(f)) o T o *(S)
val p1 = P o T o *(J) o J o *(*(*(R) o **(f) o (S)) o T o *((R) o *(f))) o T o *(S)
val p2 = P o TransposeW() o *(J) o J o MapGlb(%(%(%(toGlobal(id)) o ReduceSeq(add, 0.0f)) o %(%(f)) o (S)) o T o %((%(toGlobal(id)) o ReduceSeq(add, 0.0f)) o %(f))) o T o *(S)

// moving slide forward
val f8 = P o T o *(J) o J o ***(R) o ****(f) o *(T) o **(T) o *(S) o **(R) o ***(f) o T o *(S)
val f9 = P o T o *(J) o J o ***(R) o ****(f) o *(T) o **(T) o ***(R) o *(S) o ***(f) o T o *(S)
val f10 = P o T o *(J) o J o ***(R) o ****(f) o *(T) o **(T) o ***(R) o ****(f) o *(S) o T o *(S)
val f11 = P o T o *(J) o J o ***(R) o ****(f) o P o *(T) o **(T) o P o ***(R) o ****(f) o T o *(T) o S o *(S)
val f12 = P o T o *(J) o J o *(T) o ***(R) o ****(f) o **(T) o ***(R) o ****(f) o T o *(T) o S o *(S)

// parallelizing f12
val g0 = P o T o *(J) o J o *(T) o **(*(R) o **(f) o T o *(R) o **(f)) o T o *(T) o S o *(S)

// lowering g0
// process 3x3 using workgroup and one thread computes blury
val g1 = TW o *(J) o J o *(TW) o
MapWrg(1)(MapWrg(0)(
MapLcl(
MapSeq(toGlobal(id)) o ReduceSeq(add, 0.0f) //o MapSeq(f)
) o Transpose() o // synchronization here!
MapLcl(
MapSeq(toLocal(id)) o ReduceSeq(add, 0.0f) //o MapSeq(f)
)
)) o T o *(T) o S o *(S)

// process 3x3 by single thread, reuse of blurx in private memory
val g2 = TW o *(J) o J o *(TW) o
MapGlb(1)(MapGlb(0)(
MapSeq(
MapSeq(toGlobal(id)) o ReduceSeq(add, 0.0f) //o MapSeq(f)
) o Transpose() o // synchronization here!
MapSeq(
MapSeq(toPrivate(id)) o ReduceSeq(add, 0.0f) //o MapSeq(f)
)
)) o T o *(T) o S o *(S)

val input = Array.tabulate(32, 32) { (i, j) => i * 32.0f + j }
val (outG1, _) = Execute(1,1,32,32,(false,false))[Array[Float]](lambda(g1), input)
val (outG2, _) = Execute(1,1,32,32,(false,false))[Array[Float]](lambda(g2), input)
assertArrayEquals(outG1, outG2, 0.1f)

TypeChecker(lambda(f0))
TypeChecker(lambda(f1))
TypeChecker(lambda(f2))
TypeChecker(lambda(f3))
TypeChecker(lambda(f4))
TypeChecker(lambda(f5))
TypeChecker(lambda(f6))
TypeChecker(lambda(p0))
TypeChecker(lambda(p1))
TypeChecker(lambda(p2))
TypeChecker(lambda(f8))
TypeChecker(lambda(f9))
TypeChecker(lambda(f10))
TypeChecker(lambda(f11))
TypeChecker(lambda(f12))
TypeChecker(lambda(g0))
TypeChecker(lambda(g1))
TypeChecker(lambda(g2))
}
}

0 comments on commit 44e73a2

Please sign in to comment.