Skip to content

Commit

Permalink
fix skipahead compile error on older cuda versions
Browse files Browse the repository at this point in the history
  • Loading branch information
jfc4050 committed Dec 20, 2022
1 parent d1a2103 commit b9337bd
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,10 @@ struct AttentionKernel {
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
(query_start + thread_i) * p.num_keys + (iter_key_start + thread_start_j),
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys +
(iter_key_start + thread_start_j)),
&curand_state);

const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);

// apply dropout scaling to elements this thread is responsible for,
Expand Down

0 comments on commit b9337bd

Please sign in to comment.