Simplify upscale cuda kernels #680

Merged (2 commits) on Apr 8, 2023
src/tensor_ops/upscale2d/cuda_kernel.rs (13 additions, 22 deletions)

@@ -1,21 +1,21 @@
 use crate::{
     shapes::*,
-    tensor::{Cuda, Tensor},
+    tensor::{launch_cfg, Cuda, Tensor},
 };
 
 use std::sync::Arc;
 
-use cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig};
+use cudarc::driver::{DeviceRepr, LaunchAsync};
 
 use super::{Bilinear, NearestNeighbor, UpscaleMethod};
 
 const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/upscale2d.ptx"));
 
 unsafe impl DeviceRepr for super::Upscale2DOp {}
 
-fn make_4d<S: Shape>(strides: S::Concrete, pad: usize) -> [usize; 4] {
+fn make_4d<S: Shape>(strides: S::Concrete) -> [usize; 4] {
     match S::NUM_DIMS {
-        3 => [pad, strides[0], strides[1], strides[2]],
+        3 => [0, strides[0], strides[1], strides[2]],
         4 => [strides[0], strides[1], strides[2], strides[3]],
         _ => panic!("Only implemented for 3d & 4d arrays"),
     }
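
A note on the hardcoded zero: it is the leading (batch) stride that pads a 3d tensor up to 4d. A minimal sketch (not taken from the PR) of why a zero stride is enough, mirroring the inp_i offset computation that the kernels in this diff perform:

    // With a 3d (Channels, Height, Width) tensor padded to [0, s_c, s_h, s_w],
    // the batch term b * inp_strides[0] is always 0, so one 4d offset formula
    // serves 3d and 4d inputs unchanged.
    __device__ size_t inp_offset(const size_t *inp_strides, size_t b, size_t c, size_t ih, size_t iw) {
        return b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3];
    }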
@@ -56,16 +56,14 @@ where
                 .load_ptx(PTX_SRC.into(), Self::FWD, &[Self::FWD, Self::BWD])?;
         }
 
-        let inp_strides = self.dev.htod_copy(make_4d::<I>(inp.strides, 0).into())?;
-        let out_strides = self.dev.htod_copy(make_4d::<O>(out.strides, 0).into())?;
+        let strides = self.dev.htod_copy(make_4d::<I>(inp.strides).into())?;
         let fwd_fn = self.dev.get_func(Self::FWD, Self::FWD).unwrap();
-        let cfg = LaunchConfig::for_num_elems(out.shape().num_elements() as u32);
+        let cfg = launch_cfg(out.shape().num_elements() as u32);
         let params = (
-            op, // const Pool2dOp op,
-            &inp_strides, // const size_t *inp_strides,
-            &out_strides, // const size_t *out_strides,
-            inp.data.as_ref(), // const float *inp,
-            Arc::make_mut(&mut out.data), // float *out
+            op,
+            &strides,
+            inp.data.as_ref(),
+            Arc::make_mut(&mut out.data),
         );
         unsafe { fwd_fn.launch(cfg, params) }?;
         Ok(())
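
For readers skimming the excerpt: the four-element params tuple lines up positionally with the kernel signature that the UPSCALE_OP macro generates at the bottom of upscale2d.cu (op, inp_strides, inp, out). Below is a minimal sketch of that contract and of why no output strides are passed (presumably because the output tensor is freshly allocated and contiguous, so the flat thread index can address it directly). The Upscale2dOp field list and the bounds check are inferred from the op.* uses elsewhere in this diff, not copied from the repository:

    // Sketch only: assumed shape of the forward launch contract.
    struct Upscale2dOp {
        size_t batch, chan, h_in, w_in, h_out, w_out; // fields inferred from this diff
    };

    extern "C" __global__ void upscale2d_fwd_sketch(
        const Upscale2dOp op,      // <- op
        const size_t *inp_strides, // <- &strides
        const float *inp,          // <- inp.data.as_ref()
        float *out                 // <- Arc::make_mut(&mut out.data)
    ) {
        // The grid is sized from out.shape().num_elements(), i.e. one thread per
        // output element; the flat index i doubles as the offset into the
        // contiguous output, so only the input needs a strides array.
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= op.batch * op.chan * op.h_out * op.w_out) return;
        // ... decompose i into (b, c, oh, ow), gather from inp via inp_strides,
        // and write the result straight to out[i].
    }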
@@ -78,17 +76,10 @@ where
         out: &Tensor<O, E, Self>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err> {
-        let inp_strides = self.dev.htod_copy(make_4d::<I>(inp.strides, 0).into())?;
-        let out_strides = self.dev.htod_copy(make_4d::<O>(out.strides, 0).into())?;
+        let strides = self.dev.htod_copy(make_4d::<I>(inp.strides).into())?;
         let bwd_fn = self.dev.get_func(Self::FWD, Self::BWD).unwrap();
-        let cfg = LaunchConfig::for_num_elems(out.shape().num_elements() as u32);
-        let params = (
-            op, // const Pool2dOp op,
-            &inp_strides, // const size_t *inp_strides,
-            &out_strides, // const size_t *out_strides,
-            grad_inp, // float *grad_inp,
-            grad_out, // const float *grad_out
-        );
+        let cfg = launch_cfg(out.shape().num_elements() as u32);
+        let params = (op, &strides, grad_inp, grad_out);
         unsafe { bwd_fn.launch(cfg, params) }?;
         Ok(())
     }
src/tensor_ops/upscale2d/upscale2d.cu (14 additions, 22 deletions)

@@ -13,7 +13,6 @@ template<typename T>
 __device__ void nearest_upscale2d_fwd(
     const Upscale2dOp op,
     const size_t *inp_strides,
-    const size_t *out_strides,
     const T *inp, // 4d (Batch, Channels, Height, Width)
     T *out // 4d (Batch, Channels, HeightOut, WidthOut)
 ) {
@@ -34,8 +33,8 @@ __device__ void nearest_upscale2d_fwd(
     idx /= op.chan;
     const size_t b = idx % op.batch;
 
-    size_t ih = min(static_cast<size_t>(h_scale * oh), op.h_out - 1);
-    size_t iw = min(static_cast<size_t>(w_scale * ow), op.w_out - 1);
+    size_t ih = min(static_cast<size_t>(h_scale * oh), op.h_in - 1);
+    size_t iw = min(static_cast<size_t>(w_scale * ow), op.w_in - 1);
 
     size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3];
 
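Since the hunk above starts mid-way through unpacking the flat output index, here is a sketch of the full decomposition that the visible idx % op.chan and idx % op.batch lines are the tail of (the extents are the op fields used in this diff; the exact code sits outside the shown hunk). It also makes the fix plain: ih and iw index the input, so they must be clamped to op.h_in - 1 and op.w_in - 1 rather than to the output extents.

    // Sketch, not copied from the file: unpack the flat output index into
    // (b, c, oh, ow) given the output extents.
    __device__ void unpack_out_index(size_t idx, size_t w_out, size_t h_out, size_t chan, size_t batch,
                                     size_t &b, size_t &c, size_t &oh, size_t &ow) {
        ow = idx % w_out; idx /= w_out;
        oh = idx % h_out; idx /= h_out;
        c = idx % chan;   idx /= chan;
        b = idx % batch;
    }
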
@@ -46,7 +45,6 @@ template<typename T>
 __device__ void nearest_upscale2d_bwd(
     const Upscale2dOp op,
     const size_t *inp_strides,
-    const size_t *out_strides,
     T *grad_inp,
     const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut)
 ) {
@@ -67,8 +65,8 @@ __device__ void nearest_upscale2d_bwd(
     idx /= op.chan;
     const size_t b = idx % op.batch;
 
-    size_t ih = min(static_cast<size_t>(h_scale * oh), op.h_out - 1);
-    size_t iw = min(static_cast<size_t>(w_scale * ow), op.w_out - 1);
+    size_t ih = min(static_cast<size_t>(h_scale * oh), op.h_in - 1);
+    size_t iw = min(static_cast<size_t>(w_scale * ow), op.w_in - 1);
 
     size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3];
     atomicAdd(grad_inp + inp_i, grad_out[i]);
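
A short aside on the atomicAdd above: upscaling is a many-to-one mapping in reverse, so several output positions send gradient to the same input cell and must accumulate atomically. A worked sketch, assuming h_scale = h_in / h_out (the scale computation is not shown in this hunk):

    // With h_in = 2 and h_out = 4, output rows {0, 1} both read input row 0 and
    // rows {2, 3} both read row 1, so two threads write the same grad_inp cell.
    __device__ size_t nearest_src_row(size_t oh, size_t h_in, size_t h_out) {
        float h_scale = (float)h_in / (float)h_out;               // 0.5 in the example (assumed definition)
        return min(static_cast<size_t>(h_scale * oh), h_in - 1);  // 0, 0, 1, 1 for oh = 0..3
    }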
@@ -78,7 +76,6 @@ template<typename T>
 __device__ void bilinear_upscale2d_fwd(
     const Upscale2dOp op,
     const size_t *inp_strides,
-    const size_t *out_strides,
     const T *inp, // 4d (Batch, Channels, Height, Width)
     T *out // 4d (Batch, Channels, HeightOut, WidthOut)
 ) {
@@ -98,12 +95,11 @@ __device__ void bilinear_upscale2d_fwd(
     const size_t c = idx % op.chan;
     idx /= op.chan;
     const size_t b = idx % op.batch;
-    idx /= op.batch;
 
-    size_t y0 = min(static_cast<size_t>(h_scale * oh), op.h_out - 1);
-    size_t y1 = min(y0 + 1, op.h_out - 1);
-    size_t x0 = min(static_cast<size_t>(w_scale * ow), op.w_out - 1);
-    size_t x1 = min(x0 + 1, op.w_out - 1);
+    size_t y0 = min(static_cast<size_t>(h_scale * oh), op.h_in - 1);
+    size_t y1 = min(y0 + 1, op.h_in - 1);
+    size_t x0 = min(static_cast<size_t>(w_scale * ow), op.w_in - 1);
+    size_t x1 = min(x0 + 1, op.w_in - 1);
 
     T hs = h_scale * oh - y0;
     T ws = w_scale * ow - x0;

Review comment from the owner (PR author) on lines +99 to +102, the new y0/y1/x0/x1 clamps:

    @nkoppel this was the issue
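
Spelling out the issue that comment points at: y1 and x1 index the input, but the old code clamped them against the output extents, which are larger when upscaling, so y1/x1 could land one row or column past the end of the input buffer. A worked sketch (h_scale = h_in / h_out is an assumption about how the scale is derived):

    // Example: h_in = 2, h_out = 4, oh = 3.
    __device__ void bilinear_src_rows(size_t oh, size_t h_in, size_t h_out, size_t &y0, size_t &y1) {
        float h_scale = (float)h_in / (float)h_out;             // 0.5 in the example (assumed)
        y0 = min(static_cast<size_t>(h_scale * oh), h_in - 1);  // (size_t)1.5 = 1
        y1 = min(y0 + 1, h_in - 1);                             // new clamp: 1, the last valid row
        // old clamp: min(y0 + 1, h_out - 1) = min(2, 3) = 2, one row past the input.
    }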
@@ -122,7 +118,6 @@ template<typename T>
 __device__ void bilinear_upscale2d_bwd(
     const Upscale2dOp op,
     const size_t *inp_strides,
-    const size_t *out_strides,
     T *grad_inp, // 4d (Batch, Channels, Height, Width)
     const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut)
 ) {
@@ -142,12 +137,11 @@ __device__ void bilinear_upscale2d_bwd(
     const size_t c = idx % op.chan;
     idx /= op.chan;
     const size_t b = idx % op.batch;
-    idx /= op.batch;
 
-    size_t y0 = min(static_cast<size_t>(h_scale * oh), op.h_out - 1);
-    size_t y1 = min(y0 + 1, op.h_out - 1);
-    size_t x0 = min(static_cast<size_t>(w_scale * ow), op.w_out - 1);
-    size_t x1 = min(x0 + 1, op.w_out - 1);
+    size_t y0 = min(static_cast<size_t>(h_scale * oh), op.h_in - 1);
+    size_t y1 = min(y0 + 1, op.h_in - 1);
+    size_t x0 = min(static_cast<size_t>(w_scale * ow), op.w_in - 1);
+    size_t x1 = min(x0 + 1, op.w_in - 1);
 
     T hs = h_scale * oh - y0;
     T ws = w_scale * ow - x0;
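
The accumulation itself sits in the collapsed part of this hunk. As a hedged sketch of what y0/y1/x0/x1 and the hs/ws weights presumably feed, a standard bilinear backward scatters grad_out[i] to the four neighbouring input cells, continuing the kernel above (go and base are local names introduced here, not taken from the file):

    // Sketch of the (not shown) accumulation: standard bilinear backward.
    T go = grad_out[i];
    size_t base = b * inp_strides[0] + c * inp_strides[1];
    atomicAdd(grad_inp + base + y0 * inp_strides[2] + x0 * inp_strides[3], go * (1 - hs) * (1 - ws));
    atomicAdd(grad_inp + base + y0 * inp_strides[2] + x1 * inp_strides[3], go * (1 - hs) * ws);
    atomicAdd(grad_inp + base + y1 * inp_strides[2] + x0 * inp_strides[3], go * hs * (1 - ws));
    atomicAdd(grad_inp + base + y1 * inp_strides[2] + x1 * inp_strides[3], go * hs * ws);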
@@ -166,20 +160,18 @@ __device__ void bilinear_upscale2d_bwd(
 extern "C" __global__ void fwd( \
     const Upscale2dOp op, \
     const size_t *inp_strides, \
-    const size_t *out_strides, \
     const TYPENAME *inp, \
     TYPENAME *out \
 ) { \
-    fwd_FN(op, inp_strides, out_strides, inp, out); \
+    fwd_FN(op, inp_strides, inp, out); \
 } \
 extern "C" __global__ void bwd( \
     const Upscale2dOp op, \
     const size_t *inp_strides, \
-    const size_t *out_strides, \
     TYPENAME *grad_inp, \
     const TYPENAME *grad_out \
 ) { \
-    bwd_FN(op, inp_strides, out_strides, grad_inp, grad_out); \
+    bwd_FN(op, inp_strides, grad_inp, grad_out); \
 }
 
 UPSCALE_OP(
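
The macro head and its invocations fall outside the shown hunk, so the real exported kernel names are not visible here; the Rust side loads them by name via get_func(Self::FWD, ...). As a purely hypothetical illustration of the pattern, one expansion could look like this (the symbol name is invented):

    // Hypothetical expansion for one dtype; nearest_upscale2d_fwd_f32 is an
    // invented name, not the real exported symbol.
    extern "C" __global__ void nearest_upscale2d_fwd_f32(
        const Upscale2dOp op,
        const size_t *inp_strides,
        const float *inp,
        float *out
    ) {
        nearest_upscale2d_fwd(op, inp_strides, inp, out); // templated device fn defined above
    }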