
[BUG] Excessive Copy Loops #121

Open
matth2k opened this issue Dec 12, 2023 · 5 comments
Labels
bug Something isn't working

Comments

@matth2k

matth2k commented Dec 12, 2023

Describe the bug
Excessive copy loops are created due to data type conversion of tensors expressed in the linalg dialect.

To Reproduce

import allo
from allo.ir.types import uint32

N = 20  # matches the memref<20xi32> shapes in the output below


def test_vadd():
    def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
        return A + B

    s = allo.customize(kernel)
    print(s.module)

Buggy output

#map = affine_map<(d0) -> (d0)>
module {
  func.func @kernel(%arg0: memref<20xi32>, %arg1: memref<20xi32>) -> memref<20xi32> attributes {itypes = "uu", otypes = "u"} {
    %alloc = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : memref<20xi32>) outs(%alloc : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_0 = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : memref<20xi32>) outs(%alloc_0 : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_1 = memref.alloc() : memref<20xi33>
    linalg.add {op_name = "add_0"} ins(%alloc, %alloc_0 : memref<20xi33>, memref<20xi33>) outs(%alloc_1 : memref<20xi33>)
    %alloc_2 = memref.alloc() {unsigned} : memref<20xi32>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%alloc_1 : memref<20xi33>) outs(%alloc_2 : memref<20xi32>) {
    ^bb0(%in: i33, %out: i32):
      %0 = arith.trunci %in : i33 to i32
      linalg.yield %0 : i32
    }
    return %alloc_2 : memref<20xi32>
  }
}

In short, when this is lowered to affine, it manifests as excessive copying at the beginning of the program, and our AMC flow is very sensitive to this.

What should really happen is to notice that the type the addition result is bound to is the same as the input type, and simply emit an add with signature (i32, i32) -> i32 using normal wraparound semantics.
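Under that scheme, the whole kernel could collapse to a single i32 add with no widening loops. A hand-written sketch of the desired IR (not actual compiler output) might look like:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @kernel(%arg0: memref<20xi32>, %arg1: memref<20xi32>) -> memref<20xi32> attributes {itypes = "uu", otypes = "u"} {
    %alloc = memref.alloc() {unsigned} : memref<20xi32>
    // arith's addition wraps on overflow, which matches unsigned modular arithmetic,
    // so no extui/trunci copy loops are needed.
    linalg.add {op_name = "add_0"} ins(%arg0, %arg1 : memref<20xi32>, memref<20xi32>) outs(%alloc : memref<20xi32>)
    return %alloc : memref<20xi32>
  }
}
```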

@matth2k matth2k added the bug Something isn't working label Dec 12, 2023
@chhzh123
Member

Thanks for bringing this issue up. Allo has a strong type system, which is why it must guarantee that intermediate results will not overflow. I think we can either (1) check the input and output data types and bypass the type-extension rule when they align, or (2) fuse those linalg operations into one.

As a workaround, you can explicitly traverse each element of the arrays using for loops, so no linalg operations are built.
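As a pure-Python model (not Allo code) of what such a hand-written element-wise loop computes, with the 32-bit wraparound the thread argues is already correct:

```python
def vadd_u32(A, B):
    """Element-wise add with 32-bit unsigned wraparound.

    Models a hand-written loop `C[i] = A[i] + B[i]` over uint32 data;
    no widening to i33 is needed because the result is truncated
    back to 32 bits anyway.
    """
    MASK = (1 << 32) - 1
    return [(a + b) & MASK for a, b in zip(A, B)]

print(vadd_u32([1, 0xFFFFFFFF], [2, 1]))  # [3, 0]
```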

@andrewb1999

Your explanation makes sense @chhzh123, but I wonder what optimization HLS does to avoid this issue. Does it just fully unroll the copy loops so the extension and truncation become free? If so, maybe we can do a similar optimization in AMC to avoid this issue altogether.

@chhzh123
Member

I think Vivado/Vitis HLS only unrolls loops with small loop bounds automatically; otherwise, we need to write an explicit unroll pragma to inform HLS. However, unrolling may incur excessive resource usage. I think the best approach is still to fuse the loops into one.
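A fused version could, for example, perform the extension, addition, and truncation inside a single linalg.generic body, so only one loop nest remains. A hand-written sketch (not the output of an existing pass):

```mlir
#map = affine_map<(d0) -> (d0)>
// One loop nest replaces the two extension loops, the add, and the truncation loop.
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
    ins(%arg0, %arg1 : memref<20xi32>, memref<20xi32>) outs(%alloc : memref<20xi32>) {
^bb0(%in: i32, %in_0: i32, %out: i32):
  %0 = arith.extui %in : i32 to i33
  %1 = arith.extui %in_0 : i32 to i33
  %2 = arith.addi %0, %1 : i33
  %3 = arith.trunci %2 : i33 to i32
  linalg.yield %3 : i32
}
```

A later canonicalization could then fold the extui/addi/trunci chain into a plain i32 addition within the loop body.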

@zhangzhiru

zhangzhiru commented Dec 15, 2023

I think the main hiccup is that we are lowering to linalg, which is less expressive than imperative programs. So we have to extend both input vectors to int33 first, then add them, and finally truncate back to int32. To clean things up, we really need an extra pass that removes the unnecessary extend and truncate. Another option is to insert a primitive that fuses the loops, so another optimization pass at a lower level can finish the job. This is not a good solution, though.
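The cancellation such a cleanup pass would exploit is easy to check: for any two i32 inputs, extending to i33, adding, and truncating back gives exactly the wraparound i32 sum. A quick pure-Python check (not compiler code):

```python
def ext_add_trunc(a, b):
    # extui i32 -> i33: the values already fit, so widening leaves them unchanged
    wide = (a + b) & ((1 << 33) - 1)   # add in i33 (cannot overflow for two i32 inputs)
    return wide & ((1 << 32) - 1)      # trunci i33 -> i32

def wraparound_add(a, b):
    return (a + b) & ((1 << 32) - 1)   # plain i32 add with wraparound

# The two agree on all i32 inputs; spot-check a few, including the extremes.
for a, b in [(0, 0), (1, 2), (0xFFFFFFFF, 1), (0xFFFFFFFF, 0xFFFFFFFF)]:
    assert ext_add_trunc(a, b) == wraparound_add(a, b)
```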

@matth2k
Author

matth2k commented Jan 14, 2024

I have the fix implemented within our AMC backend, but I will eventually come up with a more universal solution and submit it as a separate PR to Allo.
