Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Adding Tensor::try_realize, and Tensor::realize no longer returns Result #758

Merged
merged 2 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/01-tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn main() {
// `realize` method helps us move between dynamic and known size for the dimensions,
// if the conversion is incompatible, it may result in runtime error
let a: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(2, 3));
let _: Tensor<(usize, Const<3>), f32, _> = a.realize().expect("`a` should have 3 columns");
let _: Tensor<(usize, Const<3>), f32, _> = a.try_realize().expect("`a` should have 3 columns");

// each of the creation methods also supports specifying the shape on the function
// note to change the dtype we specify the dtype as the 2nd generic parameter
Expand Down
4 changes: 2 additions & 2 deletions examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ fn main() {
a = a + 0.5;
let b: Tensor<(usize, Const<5>), f32, _> = dev.sample_uniform_like(&(3, Const));
// note the use of `realize`
let _: Tensor<(Const<3>, usize), f32, _> = a + b.realize().expect("`b` should have 3 rows");
let _: Tensor<(Const<3>, usize), f32, _> = a + b.try_realize().expect("`b` should have 3 rows");

// then we have things like matrix and vector multiplication:
let a: Tensor<(usize, Const<5>), f32, _> = dev.sample_normal_like(&(3, Const));
let b: Tensor<(usize, usize), f32, _> = dev.sample_normal_like(&(5, 7));
// if type inference is not possible, we explicitly provide the shape for `realize`
let _: Tensor<(usize, usize), f32, _> = a.matmul(
b.realize::<(Const<5>, usize)>()
b.try_realize::<(Const<5>, usize)>()
.expect("`b` should have 5 rows"),
);

Expand Down
2 changes: 1 addition & 1 deletion examples/13-housing-nn-in-struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Predictor {
let batched: Tensor<Rank2<1, 2>, _, _> = input.clone().broadcast();

// convert static size tensor to variable sized tensor
let batched_realized: Tensor<(usize, Const<2>), _, _> = batched.realize().unwrap();
let batched_realized: Tensor<(usize, Const<2>), _, _> = batched.try_realize().unwrap();
assert_eq!(batched_realized.shape(), &(1 as usize, Const::<2>));

// call predict on batches
Expand Down
12 changes: 3 additions & 9 deletions src/tensor_ops/attention_reshape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,9 @@ mod tests {

let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);

let q = q
.realize::<(Const<NUM_HEADS>, Const<1>, Const<HEAD_DIM>)>()
.unwrap();
let k = k
.realize::<(Const<NUM_HEADS>, Const<HEAD_DIM>, Const<4>)>()
.unwrap();
let v = v
.realize::<(Const<NUM_HEADS>, Const<4>, Const<HEAD_DIM>)>()
.unwrap();
let q = q.realize::<(Const<NUM_HEADS>, Const<1>, Const<HEAD_DIM>)>();
let k = k.realize::<(Const<NUM_HEADS>, Const<HEAD_DIM>, Const<4>)>();
let v = v.realize::<(Const<NUM_HEADS>, Const<4>, Const<HEAD_DIM>)>();

assert_close_to_literal!(q, [[[1.0; HEAD_DIM]; 1]; NUM_HEADS]);
assert_close_to_literal!(
Expand Down
31 changes: 20 additions & 11 deletions src/tensor_ops/concat_along/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ mod cuda_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize().unwrap();
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
/// ```
///
/// Along Axis 1:
Expand All @@ -44,7 +44,7 @@ mod cuda_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize().unwrap();
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// ```
pub trait TryConcatAlong<Ax>: Sized {
type Output;
Expand Down Expand Up @@ -192,11 +192,14 @@ mod tests {
let b: Tensor<Rank3<3, 3, 4>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(usize, Const<3>, Const<4>)>()
.try_realize::<(usize, Const<3>, Const<4>)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(usize, Const<3>, Const<4>)>()
.unwrap();
let b_dyn = b.clone().realize::<(usize, Const<3>, Const<4>)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<0>);
let c = c.realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
let c = c.try_realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand All @@ -222,11 +225,14 @@ mod tests {
let b: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(Const<2>, usize, Const<4>)>()
.try_realize::<(Const<2>, usize, Const<4>)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(Const<2>, usize, Const<4>)>()
.unwrap();
let b_dyn = b.clone().realize::<(Const<2>, usize, Const<4>)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<1>);
let c = c.realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
let c = c.try_realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand All @@ -251,11 +257,14 @@ mod tests {
let b: Tensor<Rank3<2, 3, 3>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(Const<2>, Const<3>, usize)>()
.try_realize::<(Const<2>, Const<3>, usize)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(Const<2>, Const<3>, usize)>()
.unwrap();
let b_dyn = b.clone().realize::<(Const<2>, Const<3>, usize)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<2>);
let c = c.realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
let c = c.try_realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand Down
143 changes: 78 additions & 65 deletions src/tensor_ops/realize_to.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,42 @@
use crate::{shapes::*, tensor::*};

/// Changes order of dimensions/axes
/// Realizes the concrete shape of the tensor as another compatable shape,
/// or returns the original tensor if the new shape's dimensions are incompatable.
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let a = a.realize::<(usize, usize)>();
/// let mut a = a.realize::<Rank2<2, 3>>();
/// match a.try_realize::<(usize, Const<4>)>() {
/// Ok(new) => println!("Shape was properly realized, returned new tensor"),
/// Err(old) => println!("Shape could not be realized, returned the original tensor"),
/// }
/// ```
pub trait RealizeTo: HasErr + HasShape {
/// Realizes the concrete shape of the tensor as another compatable shape,
/// or returns the original tensor if the new shape's dimensions are incompatable.
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let a = a.realize::<(usize, usize)>().unwrap();
/// let mut a = a.realize::<Rank2<2, 3>>().unwrap();
/// match a.realize::<(usize, Const<4>)>() {
/// Ok(new) => println!("Shape was properly realized, returned new tensor"),
/// Err(old) => println!("Shape could not be realized, returned the original tensor"),
/// }
/// ```
fn realize<Dst: Shape<Concrete = <<Self as HasShape>::Shape as Shape>::Concrete>>(
self,
) -> Self::WithShape<Dst>
where
Self::Shape: RealizeShapeTo<Dst>,
Self: std::fmt::Debug,
{
self.try_realize::<Dst>().unwrap()
}

/// Realizes the concrete shape of the tensor as another compatable shape,
/// or returns the original tensor if the new shape's dimensions are incompatable.
fn try_realize<Dst: Shape<Concrete = <<Self as HasShape>::Shape as Shape>::Concrete>>(
self,
) -> Result<Self::WithShape<Dst>, Self>
where
Self::Shape: RealizeShapeTo<Dst>;
}

impl<S: Shape, E: Dtype, D: DeviceStorage, T: Tape<E, D>> RealizeTo for Tensor<S, E, D, T> {
fn realize<Dst: Shape<Concrete = S::Concrete>>(self) -> Result<Self::WithShape<Dst>, Self>
fn try_realize<Dst: Shape<Concrete = S::Concrete>>(self) -> Result<Self::WithShape<Dst>, Self>
where
Self::Shape: RealizeShapeTo<Dst>,
{
Expand Down Expand Up @@ -51,45 +64,39 @@ mod tests {
fn test_realize_2d() {
let dev: TestDevice = Default::default();
let src: Tensor<Rank2<2, 3>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<2>, usize), TestDtype, _> =
src.clone().realize::<(Const<2>, usize)>().unwrap();
let dst = src.clone().realize::<(Const<2>, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, Const<3>), TestDtype, _> =
src.clone().realize::<(usize, Const<3>)>().unwrap();
let dst = src.clone().realize::<(usize, Const<3>)>();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, usize), TestDtype, _> =
src.clone().realize::<(usize, usize)>().unwrap();
let dst: Tensor<(usize, usize), TestDtype, _> = src.clone().realize::<(usize, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
src = src.realize::<(usize, Const<4>)>().unwrap_err();
src = src.realize::<(Const<1>, usize)>().unwrap_err();
src = src.realize::<(Const<2>, Const<4>)>().unwrap_err();
src = src.realize::<(Const<3>, Const<2>)>().unwrap_err();
src = src.try_realize::<(usize, Const<4>)>().unwrap_err();
src = src.try_realize::<(Const<1>, usize)>().unwrap_err();
src = src.try_realize::<(Const<2>, Const<4>)>().unwrap_err();
src = src.try_realize::<(Const<3>, Const<2>)>().unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}

#[test]
fn test_realize_3d() {
let dev: TestDevice = Default::default();
let src: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<3>, usize, Const<7>), TestDtype, _> = src
.clone()
.realize::<(Const<3>, usize, Const<7>)>()
.unwrap();
let dst = src.clone().realize::<(Const<3>, usize, Const<7>)>();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, Const<5>, usize), TestDtype, _> =
src.clone().realize::<(usize, Const<5>, usize)>().unwrap();
let dst = src.clone().realize::<(usize, Const<5>, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, usize, usize), TestDtype, _> =
src.clone().realize::<(usize, usize, usize)>().unwrap();
let dst = src.clone().realize::<(usize, usize, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
// Ensure we get back the original tensor on error
src = src.realize::<(usize, Const<2>, usize)>().unwrap_err();
src = src.realize::<(Const<3>, Const<1>, Const<7>)>().unwrap_err();
src = src.realize::<(usize, usize, Const<3>)>().unwrap_err();
src = src.try_realize::<(usize, Const<2>, usize)>().unwrap_err();
src = src
.try_realize::<(Const<3>, Const<1>, Const<7>)>()
.unwrap_err();
src = src.try_realize::<(usize, usize, Const<3>)>().unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}

Expand All @@ -99,29 +106,29 @@ mod tests {
let src: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<3>, usize, Const<7>, usize), TestDtype, _> = src
.clone()
.realize::<(Const<3>, usize, Const<7>, usize)>()
.try_realize::<(Const<3>, usize, Const<7>, usize)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, usize, usize, usize), TestDtype, _> = src
.clone()
.realize::<(usize, usize, usize, usize)>()
.try_realize::<(usize, usize, usize, usize)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, Const<5>, Const<7>, Const<9>), TestDtype, _> = src
.clone()
.realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.try_realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
src = src
.realize::<(usize, Const<2>, usize, Const<9>)>()
.try_realize::<(usize, Const<2>, usize, Const<9>)>()
.unwrap_err();
src = src
.realize::<(Const<3>, Const<1>, Const<7>, Const<9>)>()
.try_realize::<(Const<3>, Const<1>, Const<7>, Const<9>)>()
.unwrap_err();
src = src
.realize::<(usize, usize, Const<3>, usize)>()
.try_realize::<(usize, usize, Const<3>, usize)>()
.unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}
Expand All @@ -133,7 +140,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize)>()
.try_realize::<(usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -148,7 +155,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize, usize)>()
.try_realize::<(usize, usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -163,7 +170,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize, usize, usize)>()
.try_realize::<(usize, usize, usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -176,39 +183,45 @@ mod tests {
let dev: TestDevice = Default::default();

let x: Tensor<Rank2<3, 5>, TestDtype, _> = dev.sample_normal();
let x = x.realize::<(Const<3>, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>)>().unwrap();
let _ = x.realize::<(usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>)>().unwrap();
let _ = x.try_realize::<(usize, usize)>().unwrap();

let x: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.sample_normal();
let x = x.realize::<(Const<3>, Const<5>, usize)>().unwrap();
let x = x.realize::<(Const<3>, usize, Const<7>)>().unwrap();
let x = x.realize::<(usize, Const<5>, Const<7>)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>, usize)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>)>().unwrap();
let _ = x.realize::<(usize, usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, Const<5>, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, Const<7>)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, Const<7>)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, Const<7>)>().unwrap();
let _ = x.try_realize::<(usize, usize, usize)>().unwrap();

let x: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.sample_normal();
let x = x
.realize::<(Const<3>, Const<5>, Const<7>, usize)>()
.try_realize::<(Const<3>, Const<5>, Const<7>, usize)>()
.unwrap();
let x = x
.try_realize::<(Const<3>, Const<5>, usize, Const<9>)>()
.unwrap();
let x = x
.try_realize::<(Const<3>, usize, Const<7>, Const<9>)>()
.unwrap();
let x = x
.try_realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.unwrap();
let x = x
.realize::<(Const<3>, Const<5>, usize, Const<9>)>()
.try_realize::<(Const<3>, Const<5>, usize, usize)>()
.unwrap();
let x = x
.realize::<(Const<3>, usize, Const<7>, Const<9>)>()
.try_realize::<(Const<3>, usize, usize, Const<9>)>()
.unwrap();
let x = x
.realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.try_realize::<(usize, usize, Const<7>, Const<9>)>()
.unwrap();
let x = x.realize::<(Const<3>, Const<5>, usize, usize)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize, Const<9>)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>, Const<9>)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>, usize, usize)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>, usize)>().unwrap();
let x = x.realize::<(usize, usize, usize, Const<9>)>().unwrap();
let _ = x.realize::<(usize, usize, usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, Const<7>, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, usize, Const<9>)>().unwrap();
let _ = x.try_realize::<(usize, usize, usize, usize)>().unwrap();
}
}
Loading
Loading