-
-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Implement reshape CUDA kernel * Use atomic add to handle broadcasted arrays * Implement striding logic
- Loading branch information
Showing
2 changed files
with
131 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,90 @@ | ||
use crate::shapes::{Dtype, HasSameNumelAs, Shape}; | ||
use crate::tensor::Cuda; | ||
use crate::{ | ||
shapes::{Dtype, HasSameNumelAs, Shape}, | ||
tensor::cuda::{Cuda, CudaArray}, | ||
tensor_ops::ops::{BinaryKernel, UnaryKernel}, | ||
}; | ||
use cudarc::device::{AsKernelParam, CudaSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits}; | ||
use std::sync::Arc; | ||
|
||
impl<E: Dtype> super::ReshapeKernel<E> for Cuda { | ||
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape.ptx")); | ||
const MODULE_NAME: &str = "reshape"; | ||
const FWD_FN_NAME: &str = "reshape_forward"; | ||
const BWD_FN_NAME: &str = "reshape_backward"; | ||
const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; | ||
|
||
impl super::ReshapeKernel<f32> for Cuda { | ||
fn forward<Src: Shape, Dst: Shape>( | ||
&self, | ||
dst: Dst, | ||
inp: &Self::Storage<Src, E>, | ||
) -> Result<Self::Storage<Dst, E>, Self::Err> | ||
inp: &Self::Storage<Src, f32>, | ||
) -> Result<Self::Storage<Dst, f32>, Self::Err> | ||
where | ||
Src: HasSameNumelAs<Dst>, | ||
{ | ||
todo!() | ||
if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) { | ||
self.dev | ||
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?; | ||
} | ||
|
||
let numel = inp.data.len(); | ||
let mut storage = self.dev.alloc_zeros_async::<f32>(numel)?; | ||
|
||
let inp_dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?; | ||
let dst_dims: CudaSlice<usize> = self.dev.take_async(dst.concrete().into())?; | ||
let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?; | ||
let dst_strides: CudaSlice<usize> = self.dev.take_async(dst.strides().into())?; | ||
|
||
let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); | ||
let cfg = LaunchConfig::for_num_elems(numel as u32); | ||
let params = ( | ||
numel, // const size_t numel, | ||
inp.data.as_ref(), // const float *inp, | ||
Src::NUM_DIMS, // const size_t inp_num_dims, | ||
&inp_dims, // const size_t *inp_dims, | ||
&inp_strides, // const size_t *inp_strides, | ||
&mut storage, // float *out | ||
Dst::NUM_DIMS, // const size_t out_num_dims, | ||
&dst_dims, // const size_t *out_dims, | ||
&dst_strides // const size_t *out_strides, | ||
); | ||
unsafe { fwd_fn.launch_async(cfg, params) }?; | ||
|
||
Ok(CudaArray { | ||
data: Arc::new(storage), | ||
shape: dst, | ||
strides: dst.strides(), | ||
}) | ||
} | ||
|
||
fn backward<Src: Shape, Dst: Shape>( | ||
&self, | ||
grad_inp: &mut Self::Storage<Src, E>, | ||
grad_out: &Self::Storage<Dst, E>, | ||
grad_inp: &mut Self::Storage<Src, f32>, | ||
grad_out: &Self::Storage<Dst, f32>, | ||
) -> Result<(), Self::Err> | ||
where | ||
Src: HasSameNumelAs<Dst>, | ||
{ | ||
todo!() | ||
let bwd_fn = self.dev.get_func(MODULE_NAME, BWD_FN_NAME).unwrap(); | ||
let numel = grad_inp.data.len(); | ||
|
||
let inp_dims: CudaSlice<usize> = self.dev.take_async(grad_inp.shape.concrete().into())?; | ||
let out_dims: CudaSlice<usize> = self.dev.take_async(grad_out.shape.concrete().into())?; | ||
let inp_strides: CudaSlice<usize> = self.dev.take_async(grad_inp.strides.into())?; | ||
let out_strides: CudaSlice<usize> = self.dev.take_async(grad_out.strides.into())?; | ||
|
||
let cfg = LaunchConfig::for_num_elems(numel as u32); | ||
let params = ( | ||
numel, // const size_t numel, | ||
Arc::make_mut(&mut grad_inp.data), // float *grad_inp, | ||
Src::NUM_DIMS, // const size_t inp_num_dims, | ||
&inp_dims, // const size_t *inp_dims, | ||
&inp_strides, // const size_t *inp_strides, | ||
grad_out.data.as_ref(), // const float *grad_out, | ||
Dst::NUM_DIMS, // const size_t out_num_dims, | ||
&out_dims, // const size_t *out_dims, | ||
&out_strides, // const size_t *out_strides | ||
); | ||
unsafe { bwd_fn.launch_async(cfg, params) }?; | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Maps a flat row-major element index `idx` to the byte-less element offset
// of the same logical element in a strided array described by `dims` and
// `strides`.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int offset = 0;
    // Walk dimensions innermost-to-outermost, peeling one coordinate off
    // `idx` at each step and scaling it by that dimension's stride.
    for (int d = (int)num_dims - 1; d >= 0; d--) {
        offset += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    return offset;
}
|
||
// One thread per logical element: translate the flat index through each
// side's dims/strides and copy inp -> out. Either side may be
// non-contiguous.
extern "C" __global__ void reshape_forward(
    const size_t numel,
    const float *inp,
    const size_t inp_num_dims,
    const size_t *inp_dims,
    const size_t *inp_strides,
    float *out,
    const size_t out_num_dims,
    const size_t *out_dims,
    const size_t *out_strides
) {
    const unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < numel) {
        const unsigned int src_i = get_strided_index(tid, inp_num_dims, inp_dims, inp_strides);
        const unsigned int dst_i = get_strided_index(tid, out_num_dims, out_dims, out_strides);
        out[dst_i] = inp[src_i];
    }
}
|
||
// Adjoint of reshape_forward: scatter grad_out back into grad_inp through the
// same index translation. atomicAdd is required because several threads can
// map to the same grad_inp slot when the input was broadcasted
// (zero strides).
extern "C" __global__ void reshape_backward(
    const size_t numel,
    float *grad_inp,
    const size_t inp_num_dims,
    const size_t *inp_dims,
    const size_t *inp_strides,
    const float *grad_out,
    const size_t out_num_dims,
    const size_t *out_dims,
    const size_t *out_strides
) {
    const unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < numel) {
        const unsigned int src_i = get_strided_index(tid, inp_num_dims, inp_dims, inp_strides);
        const unsigned int dst_i = get_strided_index(tid, out_num_dims, out_dims, out_strides);
        atomicAdd(grad_inp + src_i, grad_out[dst_i]);
    }
}