Skip to content

Commit

Permalink
Add more runtime shape checks (#454)
Browse files Browse the repository at this point in the history
* #445 Check broadcast shapes

* #448 runtime shape checks for select/gather

* Making inline always
  • Loading branch information
coreylowman authored Feb 15, 2023
1 parent 73905f1 commit 8b3531e
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 7 deletions.
25 changes: 25 additions & 0 deletions src/shapes/broadcasts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,27 @@ broadcast_to!(3, (N, O, P), 4, (M, N, O, P), Axis<0>);

/// Internal implementation for broadcasting strides
pub trait BroadcastStridesTo<S: Shape, Ax>: Shape + BroadcastShapeTo<S, Ax> {
fn check(&self, dst: &S);
fn broadcast_strides(&self, strides: Self::Concrete) -> S::Concrete;
}

impl<Src: Shape, Dst: Shape, Ax: Axes> BroadcastStridesTo<Dst, Ax> for Src
where
Self: BroadcastShapeTo<Dst, Ax>,
{
#[inline(always)]
fn check(&self, dst: &Dst) {
    let src_dims = self.concrete();
    let dst_dims = dst.concrete();
    // Walk every axis of dst. Axes listed in Ax are the ones being broadcast
    // (they have no counterpart in src); every other axis must line up with
    // the next unconsumed axis of src.
    let mut src_i = 0;
    for dst_i in 0..Dst::NUM_DIMS {
        let is_broadcast_axis = Ax::as_array().into_iter().any(|a| a == dst_i as isize);
        if !is_broadcast_axis {
            assert_eq!(dst_dims[dst_i], src_dims[src_i]);
            src_i += 1;
        }
    }
}

#[inline(always)]
fn broadcast_strides(&self, strides: Self::Concrete) -> Dst::Concrete {
let mut new_strides: Dst::Concrete = Default::default();
Expand Down Expand Up @@ -125,6 +139,17 @@ where
mod tests {
use super::*;

#[test]
fn test_check() {
    // Broadcasting (1,) -> (1, 2) along axis 1: the single non-broadcast
    // dim (1) matches, so check must not panic.
    let src = (1,);
    let dst = (1, 2);
    BroadcastStridesTo::<(usize, usize), Axis<1>>::check(&src, &dst);
}

#[test]
#[should_panic]
fn test_check_failures() {
    // (1,) cannot broadcast to (2, 2) along axis 1: the non-broadcast
    // dim 0 is 1 in src but 2 in dst, so check must panic.
    let src = (1,);
    let dst = (2, 2);
    BroadcastStridesTo::<(usize, usize), Axis<1>>::check(&src, &dst);
}

#[test]
fn test_no_conflict_reductions() {
let src = (1, Const::<2>, 3, Const::<4>);
Expand Down
29 changes: 29 additions & 0 deletions src/shapes/replace_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ use super::{
pub trait RemoveDimTo<Dst: Shape, Idx: Shape>: Shape {
type Ax: Axes<Array = [isize; 1]>;

/// All dimensions of idx should be the same as the dimensions of Self
#[inline(always)]
fn check(&self, idx: &Idx) {
    // An index tensor can never have more axes than the tensor it indexes.
    assert!(Idx::NUM_DIMS <= Self::NUM_DIMS);
    let shape = self.concrete();
    let idx_shape = idx.concrete();
    // Each axis of idx must agree with the corresponding leading axis of self.
    for axis in 0..Idx::NUM_DIMS {
        assert_eq!(shape[axis], idx_shape[axis]);
    }
}

#[inline]
fn remove(&self, _: Idx) -> Dst {
let ax = Self::Ax::as_array()[0] as usize;
Expand All @@ -28,6 +39,24 @@ pub trait RemoveDimTo<Dst: Shape, Idx: Shape>: Shape {
pub trait ReplaceDimTo<Dst: Shape, Idx: Shape>: Shape {
type Ax: Axes<Array = [isize; 1]>;

/// All dimensions of idx *up to last dimension* (which is new)
/// should be the same as the dimensions of Self
#[inline(always)]
fn check(&self, idx: &Idx) {
    if Self::NUM_DIMS == Dst::NUM_DIMS {
        // replace 1 dim case
        assert!(Idx::NUM_DIMS <= Self::NUM_DIMS);
        let src_dims = self.concrete();
        let idx_dims = idx.concrete();
        // Only the leading axes of idx must match self; the last axis of idx
        // is the replacement dim. `saturating_sub` guards the usize subtraction
        // against underflow when Idx::NUM_DIMS == 0 (the loop then simply does
        // nothing instead of panicking on `0 - 1` in debug builds).
        for i in 0..Idx::NUM_DIMS.saturating_sub(1) {
            assert_eq!(src_dims[i], idx_dims[i]);
        }
    } else {
        // batch replace case - we actually don't need to check this case
        // at all
    }
}

#[inline]
fn replace(&self, idx: Idx) -> Dst {
let ax = Self::Ax::as_array()[0] as usize;
Expand Down
15 changes: 11 additions & 4 deletions src/tensor_ops/broadcast_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ pub trait BroadcastTo: HasErr + HasShape {
/// // broadcast axis 0 and axis 2
/// let _ = a.clone().broadcast::<Rank4<1, 3, 5, 7>, _>();
/// ```
fn broadcast<Dst: Shape + Default, Ax: Axes>(self) -> Self::WithShape<Dst>
fn broadcast<Dst: ConstShape, Ax: Axes>(self) -> Self::WithShape<Dst>
where
Self::Shape: BroadcastShapeTo<Dst, Ax>,
{
self.try_broadcast_like(&Default::default()).unwrap()
}
/// Fallible version of [BroadcastTo::broadcast]
fn try_broadcast<Dst: Shape + Default, Ax: Axes>(
self,
) -> Result<Self::WithShape<Dst>, Self::Err>
fn try_broadcast<Dst: ConstShape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
where
Self::Shape: BroadcastShapeTo<Dst, Ax>,
{
Expand Down Expand Up @@ -76,6 +74,7 @@ impl<S: Shape, E: Dtype, D: BroadcastKernel<E>, T: Tape<D>> BroadcastTo for Tens
Self::Shape: BroadcastShapeTo<Dst, Ax>,
{
let (inp, mut tape) = self.split_tape();
inp.shape().check(dst);
let out = inp.device.upgrade(inp.device.forward(*dst, &inp.storage)?);
let phantom_out = out.clone();
tape.try_alloc_grad(&inp)?;
Expand All @@ -94,6 +93,14 @@ mod tests {
use crate::tensor_ops::*;
use crate::tests::*;

#[test]
#[should_panic]
fn test_broadcast_incorrect_dims() {
    let dev: TestDevice = Default::default();
    let src: Tensor<(usize,), TestDtype, _> = dev.zeros_like(&(5,));
    // The tensor's runtime dim is 5, but the broadcast target claims 7,
    // so the runtime shape check must panic.
    let dst_shape = (Const::<3>, 7);
    let _: Tensor<(Const<3>, usize), TestDtype, _> = src.broadcast_like(&dst_shape);
}

#[test]
fn test_valid_1d_broadcasts() {
let dev: TestDevice = Default::default();
Expand Down
6 changes: 3 additions & 3 deletions src/tensor_ops/reshape_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ pub trait ReshapeKernel<E: Dtype>: DeviceStorage {

/// **Requires Nightly** Change the shape of a tensor moving data around.
pub trait ReshapeTo: HasErr + HasShape {
fn reshape<Dst: Shape + Default>(self) -> Self::WithShape<Dst>
fn reshape<Dst: ConstShape>(self) -> Self::WithShape<Dst>
where
Self::Shape: HasSameNumelAs<Dst>,
{
self.try_reshape().unwrap()
}
fn try_reshape<Dst: Shape + Default>(self) -> Result<Self::WithShape<Dst>, Self::Err>
fn try_reshape<Dst: ConstShape>(self) -> Result<Self::WithShape<Dst>, Self::Err>
where
Self::Shape: HasSameNumelAs<Dst>;
}

impl<S: Shape, E: Dtype, D: ReshapeKernel<E>, T: Tape<D>> ReshapeTo for Tensor<S, E, D, T> {
fn try_reshape<Dst: Shape + Default>(self) -> Result<Self::WithShape<Dst>, Self::Err>
fn try_reshape<Dst: ConstShape>(self) -> Result<Self::WithShape<Dst>, Self::Err>
where
Self::Shape: HasSameNumelAs<Dst>,
{
Expand Down
53 changes: 53 additions & 0 deletions src/tensor_ops/select_and_gather/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl<Src: Shape, E: Dtype, D: RemoveDimKernel<E>, T: Tape<D>> SelectTo<D> for Te
where
Self::Shape: RemoveDimTo<Dst, Idx>,
{
self.shape().check(idx.shape());
let (inp, mut tape) = self.split_tape();
let storage = inp.device.forward(&inp.storage, &idx.storage)?;
let out = inp.device.upgrade(storage);
Expand Down Expand Up @@ -164,6 +165,7 @@ impl<Src: Shape, E: Dtype, D: ReplaceDimKernel<E>, T: Tape<D>> GatherTo<D>
where
Self::Shape: ReplaceDimTo<Dst, Idx>,
{
self.shape().check(idx.shape());
let (inp, mut tape) = self.split_tape();
let storage = inp.device.forward(&inp.storage, &idx.storage)?;
let out = inp.device.upgrade(storage);
Expand All @@ -184,6 +186,57 @@ mod tests {
use crate::tensor_ops::*;
use crate::tests::*;

#[test]
#[should_panic]
fn test_remove_wrong_index_shape_2d() {
    let dev: TestDevice = Default::default();
    let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3), rand_distr::StandardNormal);
    // The index must line up with the tensor's leading dim of 5; a
    // length-7 index fails the runtime shape check and panics.
    let bad_idx = dev.zeros_like(&(7,));
    let _ = t.trace().select(bad_idx);
}

#[test]
#[should_panic]
fn test_remove_wrong_index_shape_3d() {
    let dev: TestDevice = Default::default();
    let t: Tensor<_, TestDtype, _> = dev.sample_like(&(7, 5, 3), rand_distr::StandardNormal);
    // Index dims (7, 4) must match the tensor's leading dims (7, 5);
    // 4 != 5, so the shape check panics.
    let bad_idx = dev.zeros_like(&(7, 4));
    let _ = t.trace().select(bad_idx);
}

#[test]
#[should_panic]
fn test_remove_index_out_of_bounds() {
    let dev: TestDevice = Default::default();
    let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
    // Only indices 0..5 are valid for a Rank1<5> tensor; 7 is out of
    // bounds and must panic.
    let oob_idx = dev.tensor(7);
    let _ = t.trace().select(oob_idx);
}

#[test]
#[should_panic]
fn test_replace_wrong_index_shape_3d1() {
    let dev: TestDevice = Default::default();
    let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3, 1), rand_distr::StandardNormal);
    // Valid gather first: dim 0 (5) is replaced by the index length 7.
    let r = t.trace().gather(dev.zeros_like(&(7,)));
    assert_eq!(r.shape(), &(7, 3, 1));
    // NOTE(review): this assert_eq! runs inside a #[should_panic] test, so if
    // it ever failed the test would still "pass" for the wrong reason.
    // Consider moving the valid-case check into its own non-panicking test.
    // This gather's index shape (7, 4) fails the runtime shape check -> panic.
    let _ = t.trace().gather(dev.zeros_like(&(7, 4)));
}

#[test]
#[should_panic]
fn test_replace_wrong_index_shape_3d2() {
    let dev: TestDevice = Default::default();
    let t: Tensor<_, TestDtype, _> = dev.sample_like(&(5, 3, 1), rand_distr::StandardNormal);
    // Index shape (5, 4, 2) disagrees with the tensor's dims, so the
    // runtime shape check must panic.
    let bad_idx = dev.zeros_like(&(5, 4, 2));
    let _ = t.trace().gather(bad_idx);
}

#[test]
#[should_panic]
fn test_replace_index_out_of_bounds() {
    let dev: TestDevice = Default::default();
    let t: Tensor<Rank1<5>, TestDtype, _> = dev.sample_normal();
    // 7 and 6 are outside the valid range 0..5, so gathering them panics.
    let oob_idx = dev.tensor([7, 6, 1, 2]);
    let _ = t.trace().gather(oob_idx);
}

#[test]
fn test_remove_1d_backward() {
let dev: TestDevice = Default::default();
Expand Down

0 comments on commit 8b3531e

Please sign in to comment.