Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cuda kernels for min_to/max_to #370

Merged
merged 2 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/tensor_ops/log_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>> Tensor<S, E, D, T> {

#[cfg(test)]
mod tests {
use crate::{shapes::Axis, tensor::*, tensor_ops::*, tests::TestDevice};
use crate::{shapes::Axis, tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_log_softmax_1d() {
Expand All @@ -62,15 +62,15 @@ mod tests {
[-4.4519143, -3.4519143, -2.4519143, -1.4519143, -0.4519143]
);
let g = r.mean().backward();
assert_eq!(
g.get(&a).array(),
[
assert_close(
&g.get(&a).array(),
&[
0.18834378,
0.16831508,
0.11387146,
-0.034121647,
-0.43640864
]
-0.43640864,
],
);
}

Expand All @@ -79,20 +79,20 @@ mod tests {
let dev: TestDevice = Default::default();
let a = dev.tensor([[-2.0, -1.0, 0.0], [1.0, 4.0, 7.0]]);
let r = a.trace().log_softmax::<Axis<1>>();
assert_eq!(
r.array(),
[
assert_close(
&r.array(),
&[
[-2.407606, -1.4076059, -0.40760595],
[-6.0509458, -3.0509458, -0.05094576]
]
[-6.0509458, -3.0509458, -0.05094576],
],
);
let g = r.mean().backward();
assert_eq!(
g.get(&a).array(),
[
assert_close(
&g.get(&a).array(),
&[
[0.12165138, 0.044302434, -0.1659538],
[0.16548885, 0.14300959, -0.30849844]
]
[0.16548885, 0.14300959, -0.30849844],
],
);
}
}
20 changes: 10 additions & 10 deletions src/tensor_ops/logsumexp_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>> LogSumExpTo for Tensor<S, E,
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::TestDevice;
use crate::tests::*;

#[test]
fn test_logsumexp_1d() {
Expand All @@ -60,9 +60,9 @@ mod tests {
let r = a.trace().logsumexp();
assert_eq!(r.array(), 2.4519143);
let g = r.backward();
assert_eq!(
g.get(&a).array(),
[0.011656231, 0.03168492, 0.08612854, 0.23412165, 0.6364086]
assert_close(
&g.get(&a).array(),
&[0.011656231, 0.03168492, 0.08612854, 0.23412165, 0.6364086],
);
}

Expand All @@ -71,14 +71,14 @@ mod tests {
let dev: TestDevice = Default::default();
let a = dev.tensor([[-2.0, -1.0, 0.0], [1.0, 4.0, 7.0]]);
let r = a.trace().logsumexp::<Rank1<2>, _>();
assert_eq!(r.array(), [0.40760595, 7.0509458]);
assert_close(&r.array(), &[0.40760595, 7.0509458]);
let g = r.mean().backward();
assert_eq!(
g.get(&a).array(),
[
assert_close(
&g.get(&a).array(),
&[
[0.045015287, 0.12236424, 0.33262047],
[0.0011778167, 0.023657078, 0.47516513]
]
[0.0011778167, 0.023657078, 0.47516513],
],
);
}
}
77 changes: 73 additions & 4 deletions src/tensor_ops/max_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use crate::{
shapes::{Axes, ReduceShapeTo, Shape},
tensor::Cuda,
shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape},
tensor::cuda::{Cuda, CudaArray},
};

use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};

use std::sync::Arc;

const MODULE_NAME: &str = "max_to";
const FWD_FN_NAME: &str = "max_to_forward";
const BWD_FN_NAME: &str = "max_to_backward";
const ALL_FN_NAMES: [&str; 3] = [FWD_FN_NAME, BWD_FN_NAME, "fill_with"];
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/max_to.ptx"));

impl super::MaxReduceKernel<f32> for Cuda {
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
Expand All @@ -12,7 +22,44 @@ impl super::MaxReduceKernel<f32> for Cuda {
where
Src: ReduceShapeTo<Dst, Ax>,
{
todo!()
if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) {
self.dev
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
}

let mut storage = self.dev.alloc_zeros_async::<f32>(dst.num_elements())?;
let fill_fn = self.dev.get_func(MODULE_NAME, "fill_with").unwrap();
unsafe {
fill_fn.launch_async(
LaunchConfig::for_num_elems(dst.num_elements() as u32),
(&mut storage, f32::NEG_INFINITY, dst.num_elements()),
)
}?;

let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?;
let out_strides = BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
&inp_strides, // const size_t *inp_strides,
&mut storage, // float *out,
&out_strides, // const size_t *out_strides
);
unsafe { fwd_fn.launch_async(cfg, params) }?;
Ok(CudaArray {
data: Arc::new(storage),
shape: dst,
strides: dst.strides(),
})
}

fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
Expand All @@ -25,6 +72,28 @@ impl super::MaxReduceKernel<f32> for Cuda {
where
Src: ReduceShapeTo<Dst, Ax>,
{
todo!()
let bwd_fn = self.dev.get_func(MODULE_NAME, BWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(grad_inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(grad_inp.strides.into())?;
let out_strides: Src::Concrete =
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&grad_out.shape, grad_out.strides);
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = grad_inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
&inp_strides, // const size_t *inp_strides,
out.data.as_ref(), // const float *out,
grad_out.data.as_ref(), // const float *grad_out,
&out_strides, // const size_t *out_strides
);
unsafe { bwd_fn.launch_async(cfg, params) }?;
Ok(())
}
}
81 changes: 81 additions & 0 deletions src/tensor_ops/max_to/max_to.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// atomicMax is not implemented for floats,
// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
//
// Why the bit-trick works: for non-negative IEEE-754 floats, the raw bit
// patterns compare in the same order as the float values, so an integer
// atomicMax on the reinterpreted bits is equivalent to a float max. For
// negative floats the bit-pattern order is reversed (sign-magnitude), so
// an UNSIGNED atomicMin on the bits selects the float maximum instead.
// The two branches are keyed on `value`, not on *addr, which is safe as
// long as every update to this address goes through this function.
//
// Returns the previous value at `addr` reinterpreted as a float
// (mirroring the return convention of the underlying atomics).
__device__ __forceinline__ float atomicMaxf(float * addr, float value) {
    if (value >= 0) {
        return __int_as_float(atomicMax((int *)addr, __float_as_int(value)));
    } else {
        return __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value)));
    }
}

// Maps a flat (contiguous) element index to a storage offset using the
// given dimension sizes and (possibly broadcasted) strides.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int offset = 0;
    // Walk the dimensions from innermost (last) to outermost (first),
    // peeling one coordinate off `idx` per step and weighting it by
    // that dimension's stride.
    for (unsigned int d = num_dims; d-- > 0; ) {
        offset += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    return offset;
}

// Sets every element of `buf` to `value`. Used by the host side to seed
// the reduction output with an identity element (e.g. -inf before a max
// reduce) ahead of the atomic reduction kernel.
extern "C" __global__ void fill_with(float *buf, float value, const size_t numel) {
    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < numel) {
        buf[tid] = value;
    }
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
//
// One thread per logical input element. All elements along the reduced
// axes map to the same output offset (their broadcast stride is 0), so
// the combine into `out` must be atomic.
extern "C" __global__ void max_to_forward(
    const size_t numel,
    const size_t num_dims,
    const size_t *dims,
    const float *inp,
    const size_t *inp_strides,
    float *out,
    const size_t *out_strides
) {
    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < numel) {
        unsigned int src = get_strided_index(tid, num_dims, dims, inp_strides);
        unsigned int dst = get_strided_index(tid, num_dims, dims, out_strides);
        atomicMaxf(out + dst, inp[src]);
    }
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
//
// Gradient of a max reduce: an input element receives the output's
// gradient iff it equals the reduced maximum at its output position.
// (If several elements tie for the maximum, each tied element receives
// the full gradient — that is what the equality test below produces.)
extern "C" __global__ void max_to_backward(
    const size_t numel,
    const size_t num_dims,
    const size_t *dims,
    const float *inp,
    float *grad_inp,
    const size_t *inp_strides,
    const float *out,
    const float *grad_out,
    const size_t *out_strides
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i >= numel) {
        return;
    }

    unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides);
    unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);

    // Fix: the original `auto tmp = ... : 0.0;` deduced `double` (the 0.0
    // literal promotes the whole conditional expression), which was then
    // implicitly narrowed back to f32 at the atomicAdd call. Stay in f32.
    float tmp = inp[inp_strided_i] == out[out_strided_i] ? grad_out[out_strided_i] : 0.0f;
    // Reduced axes share an input-gradient slot only across distinct
    // output positions' threads? No — distinct input elements may still
    // alias when inp_strides themselves are broadcasted, so accumulate
    // atomically to be safe.
    atomicAdd(grad_inp + inp_strided_i, tmp);
}
77 changes: 73 additions & 4 deletions src/tensor_ops/min_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
use crate::{
shapes::{Axes, ReduceShapeTo, Shape},
tensor::Cuda,
shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape},
tensor::cuda::{Cuda, CudaArray},
};

use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};

use std::sync::Arc;

const MODULE_NAME: &str = "min_to";
const FWD_FN_NAME: &str = "min_to_forward";
const BWD_FN_NAME: &str = "min_to_backward";
const ALL_FN_NAMES: [&str; 3] = [FWD_FN_NAME, BWD_FN_NAME, "fill_with"];
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/min_to.ptx"));

impl super::MinReduceKernel<f32> for Cuda {
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
Expand All @@ -12,7 +22,44 @@ impl super::MinReduceKernel<f32> for Cuda {
where
Src: ReduceShapeTo<Dst, Ax>,
{
todo!()
if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) {
self.dev
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
}

let mut storage = self.dev.alloc_zeros_async::<f32>(dst.num_elements())?;
let fill_fn = self.dev.get_func(MODULE_NAME, "fill_with").unwrap();
unsafe {
fill_fn.launch_async(
LaunchConfig::for_num_elems(dst.num_elements() as u32),
(&mut storage, f32::INFINITY, dst.num_elements()),
)
}?;

let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?;
let out_strides = BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
&inp_strides, // const size_t *inp_strides,
&mut storage, // float *out,
&out_strides, // const size_t *out_strides
);
unsafe { fwd_fn.launch_async(cfg, params) }?;
Ok(CudaArray {
data: Arc::new(storage),
shape: dst,
strides: dst.strides(),
})
}

fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
Expand All @@ -25,6 +72,28 @@ impl super::MinReduceKernel<f32> for Cuda {
where
Src: ReduceShapeTo<Dst, Ax>,
{
todo!()
let bwd_fn = self.dev.get_func(MODULE_NAME, BWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(grad_inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(grad_inp.strides.into())?;
let out_strides: Src::Concrete =
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&grad_out.shape, grad_out.strides);
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = grad_inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
&inp_strides, // const size_t *inp_strides,
out.data.as_ref(), // const float *out,
grad_out.data.as_ref(), // const float *grad_out,
&out_strides, // const size_t *out_strides
);
unsafe { bwd_fn.launch_async(cfg, params) }?;
Ok(())
}
}
Loading