Implement reshape cuda kernel (resolves #336) (#356)
* Implement reshape cuda kernel

* Use atomic add to handle broadcasted arrays

* Implement striding logic
nkoppel committed Jan 12, 2023
1 parent 673ab8c commit b9d376a
Showing 2 changed files with 131 additions and 9 deletions.
82 changes: 73 additions & 9 deletions src/tensor_ops/reshape_to/cuda_kernel.rs
@@ -1,26 +1,90 @@
-use crate::shapes::{Dtype, HasSameNumelAs, Shape};
-use crate::tensor::Cuda;
+use crate::{
+    shapes::{Dtype, HasSameNumelAs, Shape},
+    tensor::cuda::{Cuda, CudaArray},
+    tensor_ops::ops::{BinaryKernel, UnaryKernel},
+};
+use cudarc::device::{AsKernelParam, CudaSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+use std::sync::Arc;
 
-impl<E: Dtype> super::ReshapeKernel<E> for Cuda {
+const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape.ptx"));
+const MODULE_NAME: &str = "reshape";
+const FWD_FN_NAME: &str = "reshape_forward";
+const BWD_FN_NAME: &str = "reshape_backward";
+const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME];
+
+impl super::ReshapeKernel<f32> for Cuda {
     fn forward<Src: Shape, Dst: Shape>(
         &self,
         dst: Dst,
-        inp: &Self::Storage<Src, E>,
-    ) -> Result<Self::Storage<Dst, E>, Self::Err>
+        inp: &Self::Storage<Src, f32>,
+    ) -> Result<Self::Storage<Dst, f32>, Self::Err>
     where
         Src: HasSameNumelAs<Dst>,
     {
-        todo!()
+        if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) {
+            self.dev
+                .load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
+        }
+
+        let numel = inp.data.len();
+        let mut storage = self.dev.alloc_zeros_async::<f32>(numel)?;
+
+        let inp_dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?;
+        let dst_dims: CudaSlice<usize> = self.dev.take_async(dst.concrete().into())?;
+        let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?;
+        let dst_strides: CudaSlice<usize> = self.dev.take_async(dst.strides().into())?;
+
+        let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap();
+        let cfg = LaunchConfig::for_num_elems(numel as u32);
+        let params = (
+            numel,             // const size_t numel,
+            inp.data.as_ref(), // const float *inp,
+            Src::NUM_DIMS,     // const size_t inp_num_dims,
+            &inp_dims,         // const size_t *inp_dims,
+            &inp_strides,      // const size_t *inp_strides,
+            &mut storage,      // float *out
+            Dst::NUM_DIMS,     // const size_t out_num_dims,
+            &dst_dims,         // const size_t *out_dims,
+            &dst_strides,      // const size_t *out_strides,
+        );
+        unsafe { fwd_fn.launch_async(cfg, params) }?;
+
+        Ok(CudaArray {
+            data: Arc::new(storage),
+            shape: dst,
+            strides: dst.strides(),
+        })
     }
 
     fn backward<Src: Shape, Dst: Shape>(
         &self,
-        grad_inp: &mut Self::Storage<Src, E>,
-        grad_out: &Self::Storage<Dst, E>,
+        grad_inp: &mut Self::Storage<Src, f32>,
+        grad_out: &Self::Storage<Dst, f32>,
     ) -> Result<(), Self::Err>
     where
         Src: HasSameNumelAs<Dst>,
     {
-        todo!()
+        let bwd_fn = self.dev.get_func(MODULE_NAME, BWD_FN_NAME).unwrap();
+        let numel = grad_inp.data.len();
+
+        let inp_dims: CudaSlice<usize> = self.dev.take_async(grad_inp.shape.concrete().into())?;
+        let out_dims: CudaSlice<usize> = self.dev.take_async(grad_out.shape.concrete().into())?;
+        let inp_strides: CudaSlice<usize> = self.dev.take_async(grad_inp.strides.into())?;
+        let out_strides: CudaSlice<usize> = self.dev.take_async(grad_out.strides.into())?;
+
+        let cfg = LaunchConfig::for_num_elems(numel as u32);
+        let params = (
+            numel,                             // const size_t numel,
+            Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
+            Src::NUM_DIMS,                     // const size_t inp_num_dims,
+            &inp_dims,                         // const size_t *inp_dims,
+            &inp_strides,                      // const size_t *inp_strides,
+            grad_out.data.as_ref(),            // const float *grad_out,
+            Dst::NUM_DIMS,                     // const size_t out_num_dims,
+            &out_dims,                         // const size_t *out_dims,
+            &out_strides,                      // const size_t *out_strides
+        );
+        unsafe { bwd_fn.launch_async(cfg, params) }?;
+        Ok(())
     }
 }
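
Editor's note: both wrappers upload each tensor's concrete dims and strides and let the kernel translate a flat element index into a memory offset. That translation is done by the get_strided_index device function in reshape.cu (next file), which is what lets the forward kernel read non-contiguous inputs through inp_strides. A minimal CPU-side Rust port of that indexing step, with shapes chosen here purely for illustration (not part of the commit):

    // CPU-side port of the get_strided_index device function from reshape.cu:
    // peel off the coordinate along each dimension (innermost first) with
    // idx % dims[d], and weight it by that dimension's stride.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided_i = 0;
        for d in (0..dims.len()).rev() {
            strided_i += (idx % dims[d]) * strides[d];
            idx /= dims[d];
        }
        strided_i
    }

    fn main() {
        // Contiguous 2x3 tensor: dims [2, 3], strides [3, 1].
        // Logical element 4 has coordinates (1, 1), so its offset is 1*3 + 1*1 = 4.
        assert_eq!(get_strided_index(4, &[2, 3], &[3, 1]), 4);

        // A 2x3 transposed view of a contiguous 3x2 buffer has strides [1, 2]:
        // the same logical element (1, 1) lives at buffer offset 1*1 + 1*2 = 3.
        assert_eq!(get_strided_index(4, &[2, 3], &[1, 2]), 3);
    }
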
58 changes: 58 additions & 0 deletions src/tensor_ops/reshape_to/reshape.cu
@@ -0,0 +1,58 @@
+__device__ unsigned int get_strided_index(
+    unsigned int idx,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides
+) {
+    unsigned int strided_i = 0;
+    for (unsigned int d = 0; d < num_dims; d++) {
+        unsigned int dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+extern "C" __global__ void reshape_forward(
+    const size_t numel,
+    const float *inp,
+    const size_t inp_num_dims,
+    const size_t *inp_dims,
+    const size_t *inp_strides,
+    float *out,
+    const size_t out_num_dims,
+    const size_t *out_dims,
+    const size_t *out_strides
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+
+    unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
+    unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
+
+    out[out_i] = inp[inp_i];
+}
+
+extern "C" __global__ void reshape_backward(
+    const size_t numel,
+    float *grad_inp,
+    const size_t inp_num_dims,
+    const size_t *inp_dims,
+    const size_t *inp_strides,
+    const float *grad_out,
+    const size_t out_num_dims,
+    const size_t *out_dims,
+    const size_t *out_strides
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+
+    unsigned int inp_i = get_strided_index(i, inp_num_dims, inp_dims, inp_strides);
+    unsigned int out_i = get_strided_index(i, out_num_dims, out_dims, out_strides);
+
+    atomicAdd(grad_inp + inp_i, grad_out[out_i]);
+}
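
Editor's note on the atomicAdd in reshape_backward: a broadcasted tensor is typically represented with stride 0 along the broadcast dimension, so several threads resolve to the same grad_inp offset and their contributions must be summed; a plain store would let one thread overwrite the others. A sequential Rust sketch of that accumulation, with dims, strides, and values chosen here for illustration (not taken from the commit):

    // Same index helper as in the previous sketch.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided_i = 0;
        for d in (0..dims.len()).rev() {
            strided_i += (idx % dims[d]) * strides[d];
            idx /= dims[d];
        }
        strided_i
    }

    fn main() {
        // grad_out of a contiguous 2x3 output, all ones. The input was a row of 3
        // broadcast up to 2x3, so inp_strides = [0, 1] and grad_inp has 3 slots.
        let (out_dims, out_strides) = ([2usize, 3], [3usize, 1]);
        let (inp_dims, inp_strides) = ([2usize, 3], [0usize, 1]);
        let grad_out = [1.0f32; 6];
        let mut grad_inp = [0.0f32; 3];

        for i in 0..out_dims.iter().product::<usize>() {
            let inp_i = get_strided_index(i, &inp_dims, &inp_strides);
            let out_i = get_strided_index(i, &out_dims, &out_strides);
            // A plain assignment here would keep only the last contribution; the
            // CUDA kernel uses atomicAdd for the same reason this loop uses +=.
            grad_inp[inp_i] += grad_out[out_i];
        }

        // Each of the 3 broadcast slots received 2 contributions.
        assert_eq!(grad_inp, [2.0f32; 3]);
    }
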
