Skip to content

Commit

Permalink
Use explicit multi_ptr in vec store (#343)
Browse files Browse the repository at this point in the history
* SYCL-BLAS breaks with SYCL-2020 because the vec store function no longer
implicitly converts a raw pointer to a multi_ptr.
* This is due to a new decoration template parameter in SYCL-2020.
* This PR:
  * Converts raw pointers to multi_ptr for use with vec.store
  * Is compatible with SYCL-1.2.1
  * Is compatible with SYCL-2020 implementations that set the default
  multi_ptr decoration parameter value to "legacy".
    * This is not defined in the spec, but it is a common implementation detail
    since it helps with SYCL-1.2.1 backwards compatibility.
* Tested with DPC++ nightly 2022/10/17 (SYCL-2020 multi_ptr, does not
compile before this change), and DPC++ nightly 2022/06/25 (SYCL-1.2.1
style multi_ptr).
  • Loading branch information
hjabird committed Oct 19, 2022
1 parent fc1fe35 commit dfa0579
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/operations/blas3/gemm_interleaved.hpp
Expand Up @@ -74,7 +74,7 @@ SYCL_BLAS_INLINE void load(cl::sycl::vec<T, Dim> &packet, PtrT ptr) {
template <address_t Address = address_t::global_space, class T, int Dim,
class PtrT>
SYCL_BLAS_INLINE void store(const cl::sycl::vec<T, Dim> &packet, PtrT ptr) {
packet.template store<Address>(0, ptr);
packet.template store<Address>(0, cl::sycl::multi_ptr<T, Address>(ptr));
}

} // namespace internal
Expand Down
3 changes: 2 additions & 1 deletion src/operations/blas3/gemm_load_store.hpp
Expand Up @@ -120,7 +120,8 @@ struct Packetize {
static SYCL_BLAS_INLINE typename std::enable_if<!trans>::type store(
PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
packet.template store<address_t::local_space>(0, dest);
packet.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<value_t, address_t::local_space>(dest));
}
};

Expand Down
3 changes: 2 additions & 1 deletion src/operations/blas3/gemm_local.hpp
Expand Up @@ -478,7 +478,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
0, cl::sycl::multi_ptr<const element_t, address_t::private_space>(reg));
out_vec *= alpha_;

out_vec.template store<address_t::global_space>(0, out_ptr);
out_vec.template store<address_t::global_space>(
0, cl::sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
}
/*!
* @brief Store the computed gemm result to the C matrix
Expand Down
16 changes: 11 additions & 5 deletions src/operations/blas3/gemm_no_local_full_vec.hpp
Expand Up @@ -325,7 +325,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
out_vec *= beta_;

out_vec.template store<address_t::private_space>(
0, reg_res + i * item_rows + j * packet_size);
0, cl::sycl::multi_ptr<element_t, address_t::private_space>(
reg_res + i * item_rows + j * packet_size));
}
}
C += ldc * (need_check_boundary || !trans_b ? wg_cols
Expand Down Expand Up @@ -593,7 +594,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
}
}
auto out_reg = &reg[(i * row_iters + j) * work_per_load];
in_vec.template store<address_t::private_space>(0, out_reg);
in_vec.template store<address_t::private_space>(
0,
cl::sycl::multi_ptr<element_t, address_t::private_space>(out_reg));
}
}
}
Expand Down Expand Up @@ -740,7 +743,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
}
}
}
in_vec.template store<address_t::private_space>(0, reg);
in_vec.template store<address_t::private_space>(
0, cl::sycl::multi_ptr<element_t, address_t::private_space>(reg));
}

/*!
Expand Down Expand Up @@ -802,7 +806,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
}
}
}
in_vec.template store<address_t::private_space>(0, reg);
in_vec.template store<address_t::private_space>(
0, cl::sycl::multi_ptr<element_t, address_t::private_space>(reg));
}
/*!
* @brief The following function computes the partial GEMM for the input
Expand Down Expand Up @@ -926,7 +931,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
out_vec *= alpha_;

out_vec.template store<address_t::global_space>(
0, C + j * wg_rows * packet_size);
0, cl::sycl::multi_ptr<element_t, address_t::global_space>(
C + j * wg_rows * packet_size));
}
}
C += ldc * (check_block || !trans_b ? wg_cols : item_cols / packet_size);
Expand Down
9 changes: 6 additions & 3 deletions src/operations/blas3/gemm_no_local_partial_vec.hpp
Expand Up @@ -471,7 +471,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
0,
cl::sycl::multi_ptr<const element_t, address_t::global_space>(ptr));
}
in_vec.template store<address_t::private_space>(0, reg);
in_vec.template store<address_t::private_space>(
0, cl::sycl::multi_ptr<element_t, address_t::private_space>(reg));

// Move pointers and update index for next load
ptr += ld;
Expand Down Expand Up @@ -514,7 +515,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
0, cl::sycl::multi_ptr<const element_t, address_t::private_space>(reg));
out_vec *= alpha_;

out_vec.template store<address_t::global_space>(0, out_ptr);
out_vec.template store<address_t::global_space>(
0, cl::sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
}

/*!
Expand Down Expand Up @@ -558,7 +560,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
out_vec *= alpha_;

out_vec.template store<address_t::global_space>(
0, C + j * wg_rows * a_packet_size);
0, cl::sycl::multi_ptr<element_t, address_t::global_space>(
C + j * wg_rows * a_packet_size));
}
}
C += ((i + 1) % b_packet_size == 0
Expand Down

0 comments on commit dfa0579

Please sign in to comment.