diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index e47c101..69473ef 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -154,13 +154,13 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyZOH = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + Layout>{})); // Val layout, 4 vals per read using GmemTiledCopyActiveMask = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + Layout>{})); // Val layout, 4 vals per read using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{},