Skip to content
This repository was archived by the owner on Dec 25, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion device/keyswitch/dyadmult.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void broadcast_keys(sycl::queue& q,
const unsigned int KEYS_LEN = tp_MAX_RNS_MODULUS_SIZE * 2;
auto kernelLambda = [=]()
[[intel::kernel_args_restrict]] [[intel::max_global_work_dim(0)]] {
for (size_t i = 0; i < batch_size; i++) {
for (int i = 0; i < batch_size; i++) {
unsigned params_size = tt_ch_keyswitch_params::read();
for (int i = 0; i < params_size; i++) {
uint256_t keys1 = k_switch_keys1[i];
Expand Down
4 changes: 2 additions & 2 deletions device/keyswitch/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ sycl::event load(sycl::queue& q, sycl::event* inDepsEv,
temp_pipe::write(cur_moduli);
});
STEP(decomp_index, decomp_modulus_size);

for (int n = 0; n < coeff_count; n++) {
uint coeff_count_tmp = coeff_count;
for (uint n = 0; n < coeff_count_tmp; n++) {
Unroller<0, NUM_CORES>::Step([&](auto COREID) {
using temp_pipe =
typename ch_intt_elements_in::template PipeAt<
Expand Down
1 change: 1 addition & 0 deletions device/keyswitch/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@

#define STEP(n, max) n = n == (max - 1) ? 0 : n + 1
#define STEP2(n, max) n = n == ((max)-1) ? -1 : n + 1
#define STEP3(n, max) n = n == (max) ? -1 : n + 1
6 changes: 4 additions & 2 deletions device/keyswitch/twiddle_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ void dispatch_twiddle_factors(sycl::queue& q,
typename tt_ch_twiddle_factor_rep::template PipeAt<
NTT_ENGINES - 2>;
twPipe8::write(tf, success);
if (success) STEP2(ntt2_index, ntt2_decomp_size / VEC);
short max_tmp = ntt2_decomp_size / VEC - 1;
if (success) STEP3(ntt2_index, max_tmp);
}
// write intt1
TwiddleFactor_t intt1_tf;
Expand All @@ -253,7 +254,8 @@ void dispatch_twiddle_factors(sycl::queue& q,
}
if (intt1_index >= 0) {
tt_ch_intt1_twiddle_factor_rep::write(intt1_tf, success);
if (success) STEP2(intt1_index, intt1_decomp_size / VEC);
short max_tmp = intt1_decomp_size / VEC - 1;
if (success) STEP3(intt1_index, max_tmp);
}

// write intt2
Expand Down
2 changes: 1 addition & 1 deletion host/inc/fpga.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ class Device {

// KeySwitch section
sycl::queue keyswitch_queues_[KEYSWITCH_NUM_KERNELS];
sycl::event KeySwitch_events_write_[2][128];
sycl::event KeySwitch_events_write_[2][1024];
sycl::event KeySwitch_events_enqueue_[2][2];
std::unordered_map<uint64_t**, KeySwitchMemKeys<uint256_t>*> keys_map_;
static int device_id_;
Expand Down
37 changes: 30 additions & 7 deletions host/src/fpga.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@
namespace intel {
namespace hexl {
namespace fpga {

// helper function to explicitly copy host data to device.
static sycl::event copy_buffer_to_device(sycl::queue& q,
sycl::buffer<uint64_t>& buf) {
sycl::host_accessor host_acc(buf);
uint64_t* host_ptr = host_acc.get_pointer();
sycl::event e = q.submit([&](sycl::handler& h) {
auto acc_dev = buf.get_access<sycl::access::mode::discard_write>(h);
h.copy(host_ptr, acc_dev);
});
return e;
}

// utility function for copying input data batch for KeySwitch

const char* keyswitch_kernel_name[] = {"load", "store"};
Expand Down Expand Up @@ -257,10 +270,10 @@ FPGAObject_KeySwitch::FPGAObject_KeySwitch(sycl::queue& p_q,
aligned_alloc(HOST_MEM_ALIGNMENT, size_out * sizeof(uint64_t)));
mem_t_target_iter_ptr_ = new sycl::buffer<uint64_t>(
sycl::range(size_in),
{sycl::property::buffer::mem_channel{MEM_CHANNEL_K2}});
{sycl::property::buffer::mem_channel{MEM_CHANNEL_K1}});
mem_KeySwitch_results_ = new sycl::buffer<sycl::ulong2>(
sycl::range(size_out / 2),
{sycl::property::buffer::mem_channel{MEM_CHANNEL_K2}});
{sycl::property::buffer::mem_channel{MEM_CHANNEL_K1}});
mem_t_target_iter_ptr_->set_write_back(false);
mem_KeySwitch_results_->set_write_back(false);
}
Expand Down Expand Up @@ -1261,18 +1274,28 @@ void Device::enqueue_input_data_KeySwitch(FPGAObject_KeySwitch* fpga_obj) {
keyswitch_queues_[KEYSWITCH_LOAD], *(keys->k_switch_keys_1_),
*(keys->k_switch_keys_2_), *(keys->k_switch_keys_3_),
fpga_obj->in_objs_.size());
const auto& start_ocl = std::chrono::high_resolution_clock::now();

int obj_id = KeySwitch_id_ % 2;
copyKeySwitchBatch(fpga_obj, obj_id);

// copy_buffer_to_device() and wait() is a utility to force blocked write,
// and to facilitate performance measure on FPGA.
// The release is to support streaming, and blocking write will slow things
// down.
// KeySwitch_events_write_[obj_id][0] = copy_buffer_to_device(
// keyswitch_queues_[KEYSWITCH_LOAD],
// *(fpga_obj->mem_t_target_iter_ptr_));
// KeySwitch_events_write_[obj_id][0].wait();

// =============== Launch keyswitch kernel ==============================
unsigned rmem = 0;
if (RWMEM_FLAG) {
rmem = 1;
}
const auto& start_ocl = std::chrono::high_resolution_clock::now();
KeySwitch_events_enqueue_[obj_id][0] =
(*(KeySwitch_kernel_container_->load))(
keyswitch_queues_[KEYSWITCH_LOAD],
nullptr /* KeySwitch_events_write_[obj_id] */,
keyswitch_queues_[KEYSWITCH_LOAD], nullptr,
*(fpga_obj->mem_t_target_iter_ptr_), modulus_meta_, fpga_obj->n_,
fpga_obj->decomp_modulus_size_, fpga_obj->n_batch_,
(*(invn_t*)(void*)&invn_), rmem);
Expand Down Expand Up @@ -1540,9 +1563,9 @@ bool Device::process_output_KeySwitch() {
*(fpga_obj->mem_KeySwitch_results_), fpga_obj->n_batch_,
fpga_obj->n_, fpga_obj->decomp_modulus_size_, modulus_meta_, rmem,
wmem);

const auto& end_ocl = std::chrono::high_resolution_clock::now();
keyswitch_queues_[KEYSWITCH_STORE].wait();
const auto& end_ocl = std::chrono::high_resolution_clock::now();

const auto& start_io = std::chrono::high_resolution_clock::now();
if (KeySwitch_id_ > 0) {
KeySwitch_read_output();
Expand Down
4 changes: 2 additions & 2 deletions host/src/fpga_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ static uint64_t g_batch_size_intt = get_batch_size_intt();
static uint64_t get_batch_size_KeySwitch() {
char* env = getenv("BATCH_SIZE_KEYSWITCH");
uint64_t size = env ? strtoul(env, NULL, 10) : 1;
if (size > 128) {
if (size > 1024) {
std::cerr << "Error: BATCH_SIZE_KEYSWITCH is " << size << std::endl;
std::cerr << " Maxiaml supported BATCH_SIZE_KEYSWITCH is 128."
std::cerr << " Maxiaml supported BATCH_SIZE_KEYSWITCH is 1024."
<< std::endl;
exit(1);
}
Expand Down