Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/interface/blas1/rotmg.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,22 @@ template typename SB_Handle::event_t _rotmg(
BufferIterator<${DATA_TYPE}> _y1, BufferIterator<${DATA_TYPE}> _param,
const typename SB_Handle::event_t& dependencies);

// Explicit instantiation of _rotmg for the overload in which _d1, _d2, _x1
// and _param are BufferIterator<${DATA_TYPE}> while _y1 is passed by value as
// a plain ${DATA_TYPE} scalar.
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, BufferIterator<${DATA_TYPE}> _d1,
BufferIterator<${DATA_TYPE}> _d2, BufferIterator<${DATA_TYPE}> _x1,
${DATA_TYPE} _y1, BufferIterator<${DATA_TYPE}> _param,
const typename SB_Handle::event_t& dependencies);

#ifdef SB_ENABLE_USM
// Explicit instantiation of _rotmg for the overload in which every operand
// (_d1, _d2, _x1, _y1, _param) is a raw ${DATA_TYPE} pointer. Only compiled
// when SB_ENABLE_USM is defined.
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
${DATA_TYPE} * _x1, ${DATA_TYPE} * _y1, ${DATA_TYPE} * _param,
const typename SB_Handle::event_t& dependencies);

// Explicit instantiation of _rotmg for the overload in which _d1, _d2, _x1
// and _param are raw ${DATA_TYPE} pointers while _y1 is passed by value as a
// plain ${DATA_TYPE} scalar. Only compiled when SB_ENABLE_USM is defined.
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
${DATA_TYPE} * _x1, ${DATA_TYPE} _y1, ${DATA_TYPE} * _param,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
33 changes: 28 additions & 5 deletions src/interface/blas1_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,37 @@ typename sb_handle_t::event_t _rotmg(
auto d1_view = make_vector_view(_d1, inc, vector_size);
auto d2_view = make_vector_view(_d2, inc, vector_size);
auto x1_view = make_vector_view(_x1, inc, vector_size);
auto y1_view = make_vector_view(_y1, inc, vector_size);
auto param_view = make_vector_view(_param, inc, param_size);

auto operation =
Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view, y1_view, param_view);
auto ret = sb_handle.execute(operation, _dependencies);
if constexpr (std::is_arithmetic_v<container_3_t>) {
constexpr helper::AllocType mem_type = std::is_pointer_v<container_0_t>
? helper::AllocType::usm
: helper::AllocType::buffer;
auto _y1_tmp = blas::helper::allocate<mem_type, container_3_t>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your PR, @s-Nick.

I think that _y1_tmp needs to be deallocated when everything is done.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review @pgorlani

Yes, you are right. I addressed it in a4e833c and updated it in c128d67 to avoid host_task usage, which causes some issues.

1, sb_handle.get_queue());

return ret;
auto copy_y1 = blas::helper::copy_to_device(sb_handle.get_queue(), &_y1,
_y1_tmp, 1, _dependencies);

auto y1_view = make_vector_view(_y1_tmp, inc, vector_size);
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
y1_view, param_view);

auto operator_event =
sb_handle.execute(operation, typename sb_handle_t::event_t{copy_y1});
if constexpr (mem_type != helper::AllocType::buffer) {
// This wait is necessary so that the temporary memory created above can be
// freed without using a host_task
operator_event[0].wait();
sycl::free(_y1_tmp, sb_handle.get_queue());
}
return operator_event;
} else {
auto y1_view = make_vector_view(_y1, inc, vector_size);
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
y1_view, param_view);
return sb_handle.execute(operation, _dependencies);
}
}

/**
Expand Down
Loading