Skip to content

Commit

Permalink
Make atom type a make_2d_copy argument (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandschulz authored Apr 18, 2024
1 parent 53147de commit b2746a2
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ namespace cute
auto [y, x] = src.data().coord_;
XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data());
}
};


template <class GTensor>
struct Copy_Traits<XE_2D_SAVE, GTensor>
{
// using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
using ThrID = Layout<_1>;
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>; // hacky: does vec of 8
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBits>>; // TODO: is _1 correct?
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;

GTensor tensor;

template <class TS, class SLayout,
class TD, class DLayout>
Expand All @@ -52,11 +69,11 @@ namespace cute
}
};

template <class GEngine, class GLayout>
template <class Copy, class GEngine, class GLayout>
auto make_xe_2d_copy(Tensor<GEngine, GLayout> gtensor)
{
using GTensor = Tensor<GEngine, GLayout>;
using Traits = Copy_Traits<XE_2D_LOAD, GTensor>;
using Traits = Copy_Traits<Copy, GTensor>;
Traits traits{gtensor};
return Copy_Atom<Traits, typename GEngine::value_type>{traits};
}
Expand Down

0 comments on commit b2746a2

Please sign in to comment.