Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify sum_to cuda kernel to not need atomic adds in backwards #367

Merged
merged 2 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/tensor_ops/sum_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ impl super::SumKernel<f32> for Cuda {

let mut storage = self.dev.alloc_zeros_async::<f32>(dst.num_elements())?;

let inp_numel = inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let physical_numel = inp.data.len();
let virtual_numel = inp.shape.num_elements();
let elems_per_thread = (virtual_numel / physical_numel) as f32;

let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
physical_numel, // const size_t numel,
Src::NUM_DIMS, // const size_t num_dims,
elems_per_thread, // const float elems_per_thread,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
&inp_strides, // const size_t *inp_strides,
Expand Down Expand Up @@ -72,11 +76,15 @@ impl super::SumKernel<f32> for Cuda {
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&grad_out.shape, grad_out.strides);
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = grad_inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let physical_numel = grad_inp.data.len();
let virtual_numel = grad_inp.shape.num_elements();
let elems_per_thread = (virtual_numel / physical_numel) as f32;

let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
physical_numel, // const size_t numel,
Src::NUM_DIMS, // const size_t num_dims,
elems_per_thread, // const float elems_per_thread,
&dims, // const size_t *dims,
Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
&inp_strides, // const size_t *inp_strides,
Expand Down
40 changes: 28 additions & 12 deletions src/tensor_ops/sum_to/sum_to.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,69 @@ __device__ unsigned int get_strided_index(
return strided_i;
}

__device__ unsigned int get_unstrided_index(
const unsigned int strided_i,
const size_t num_dims,
const size_t *dims,
const size_t *strides
) {
unsigned int idx = 0;
for (unsigned int d = 0; d < num_dims; d++) {
idx *= dims[d];
idx += strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d];
}
return idx;
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
extern "C" __global__ void sum_to_forward(
const size_t numel,
const size_t num_dims,
const float elems_per_thread,
const size_t *dims,
const float *inp,
const size_t *inp_strides,
float *out,
const size_t *out_strides
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
if (inp_i >= numel) {
return;
}

unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides);
auto tmp = inp[inp_strided_i];
auto tmp = inp[inp_i];

unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);
atomicAdd(out + out_strided_i, tmp);
unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides);
unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides);
atomicAdd(out + out_i, tmp * elems_per_thread);
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
extern "C" __global__ void sum_to_backward(
const size_t numel,
const size_t num_dims,
const float elems_per_thread,
const size_t *dims,
float *grad_inp,
const size_t *inp_strides,
const float *grad_out,
const size_t *out_strides
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
if (inp_i >= numel) {
return;
}

unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);
auto tmp = grad_out[out_strided_i];
unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides);
unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides);
auto tmp = grad_out[out_i];

// NOTE: since size of output is less than input, only 1 thread will be writing to inp
// at a time. this means we don't have to worry about multiple concurrent writes
// like we do with forward.
unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides);
grad_inp[inp_strided_i] += tmp;
grad_inp[inp_i] += tmp * elems_per_thread;
}