diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
index e6ab2eb2..25efc27e 100644
--- a/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
@@ -26,11 +26,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let buf = std::sync::Arc::get_mut(&mut c.data).unwrap();
         while i < n {
             for _ in 0..a_n {
-                buf[i] = a.data[a_idx.next().unwrap()];
+                (*buf)[i] = a.data[a_idx.next().unwrap()];
                 i += 1;
             }
             for _ in 0..b_n {
-                buf[i] = b.data[b_idx.next().unwrap()];
+                (*buf)[i] = b.data[b_idx.next().unwrap()];
                 i += 1;
             }
         }
@@ -59,11 +59,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let n = grad_out.len();
         while i < n {
             for _ in 0..a_n {
-                grad_a[a_idx.next().unwrap()] += grad_out[i];
+                (*grad_a)[a_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
             for _ in 0..b_n {
-                grad_b[b_idx.next().unwrap()] += grad_out[i];
+                (*grad_b)[b_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
         }
diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
index 7462fd2b..9165efba 100644
--- a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
+++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
@@ -19,7 +19,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_along(Axis::<0>);
+/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
 /// ```
 ///
 /// Along Axis 1:
@@ -28,7 +28,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_along(Axis::<1>);
+/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
 /// ```
 ///
 /// # [usize] dims
@@ -38,7 +38,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
 /// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
-/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
+/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
 /// ```
 ///
 /// Along Axis 1:
@@ -47,7 +47,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
 /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
-/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
+/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
 /// ```
 pub trait TryConcatTensorAlong<Ax>: Sized {
     type Output;
diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs
index 453457f4..38a03d14 100644
--- a/dfdx-core/src/tensor_ops/mod.rs
+++ b/dfdx-core/src/tensor_ops/mod.rs
@@ -200,6 +200,8 @@ mod sigmoid;
 mod sin;
 mod slice;
 mod softmax;
+mod split_shape_along;
+mod split_tensor_along;
 mod sqrt;
 mod square;
 mod stack;
@@ -267,6 +269,8 @@ pub use sigmoid::sigmoid;
 pub use sin::sin;
 pub use slice::slice;
 pub use softmax::softmax;
+pub use split_shape_along::TrySplitShapeAlong;
+pub use split_tensor_along::TrySplitTensorAlong;
 pub use sqrt::sqrt;
 pub use square::square;
 pub use stack::{AddDim, TryStack};
diff --git a/dfdx-core/src/tensor_ops/split_shape_along/mod.rs b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs
new file mode 100644
index 00000000..1421e12f
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_shape_along/mod.rs
@@ -0,0 +1,158 @@
+use crate::{shapes::*, tensor::*};
+
+/// Split a shape in two along a given axis.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
+/// assert_eq!(a, (3, Const::<3>));
+/// assert_eq!(b, (4, Const::<3>));
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1);
+/// assert_eq!(a, (Const::<7>, 2));
+/// assert_eq!(b, (Const::<7>, 1));
+/// ```
+pub trait TrySplitShapeAlong<Ax, A: Dim, B: Dim>: Shape {
+    type Output;
+
+    /// Splits self along the given axis.
+    fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output {
+        self.try_split_shape_along(ax, a, b).unwrap()
+    }
+    /// Fallibly splits self along the given axis.
+    fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
+}
+
+macro_rules! impl_split {
+    ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        impl<A: Dim, B: Dim, AB: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TrySplitShapeAlong<Axis<$Ax>, A, B>
+            for
+            (
+                $($Head, )*
+                AB,
+                $($Tail, )*
+            )
+        where
+            ($($Head, )* A, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+            ($($Head, )* B, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+        {
+            type Output =
+            (
+                ($($Head, )* A, $($Tail, )*),
+                ($($Head, )* B, $($Tail, )*),
+            );
+
+            fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result<Self::Output, Error> {
+                let dims = self.concrete();
+                let mut lhs_dims = dims;
+                let mut rhs_dims = dims;
+                lhs_dims[$Ax] = a.size();
+                rhs_dims[$Ax] = b.size();
+                assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]);
+
+                Ok((
+                    <($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(),
+                    <($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(),
+                ))
+            }
+        }
+    };
+}
+
+impl_split!(0, 1, [], []);
+impl_split!(0, 2, [], [D1]);
+impl_split!(0, 3, [], [D1, D2]);
+impl_split!(0, 4, [], [D1, D2, D3]);
+impl_split!(0, 5, [], [D1, D2, D3, D4]);
+impl_split!(0, 6, [], [D1, D2, D3, D4, D5]);
+
+impl_split!(1, 2, [D0], []);
+impl_split!(1, 3, [D0], [D2]);
+impl_split!(1, 4, [D0], [D2, D3]);
+impl_split!(1, 5, [D0], [D2, D3, D4]);
+impl_split!(1, 6, [D0], [D2, D3, D4, D5]);
+
+impl_split!(2, 3, [D0, D1], []);
+impl_split!(2, 4, [D0, D1], [D3]);
+impl_split!(2, 5, [D0, D1], [D3, D4]);
+impl_split!(2, 6, [D0, D1], [D3, D4, D5]);
+
+impl_split!(3, 4, [D0, D1, D2], []);
+impl_split!(3, 5, [D0, D1, D2], [D4]);
+impl_split!(3, 6, [D0, D1, D2], [D4, D5]);
+
+impl_split!(4, 5, [D0, D1, D2, D3], []);
+impl_split!(4, 6, [D0, D1, D2, D3], [D5]);
+
+impl_split!(5, 6, [D0, D1, D2, D3, D4], []);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_split_shape() {
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (Const<5>, Const<5>) = (Const, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (Const<3>, Const<5>) = (Const, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        #[cfg(feature = "nightly")]
+        {
+            let a: (Const<5>, Const<5>) = (Const, Const);
+            let b: (Const<3>, Const<5>) = (Const, Const);
+            assert_eq!(
+                (Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+                (a, b)
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic = "left: 8\n right: 7"]
+    fn test_split_shape_fails() {
+        let a: (usize, Const<5>) = (4, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0);
+    }
+}
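Note on the macro above: since Rust has no variadic generics, `impl_split!` stamps out one impl per (axis, rank) pair, threading the untouched dimensions through as the `$Head`/`$Tail` lists. As a reading aid (not code from this patch), a hand-expansion of `impl_split!(1, 3, [D0], [D2])`, in the scope of the module's imports, would look roughly like:

```rust
// Hand-expanded sketch of impl_split!(1, 3, [D0], [D2]) for illustration:
// split the middle dim of (D0, AB, D2) into (D0, A, D2) and (D0, B, D2).
impl<A: Dim, B: Dim, AB: Dim, D0: Dim, D2: Dim> TrySplitShapeAlong<Axis<1>, A, B> for (D0, AB, D2)
where
    (D0, A, D2): Shape<Concrete = <Self as Shape>::Concrete>,
    (D0, B, D2): Shape<Concrete = <Self as Shape>::Concrete>,
{
    type Output = ((D0, A, D2), (D0, B, D2));

    fn try_split_shape_along(self, _: Axis<1>, a: A, b: B) -> Result<Self::Output, Error> {
        let dims = self.concrete();
        let mut lhs_dims = dims;
        let mut rhs_dims = dims;
        lhs_dims[1] = a.size();
        rhs_dims[1] = b.size();
        // the two pieces must add back up to the original dim
        assert_eq!(dims[1], lhs_dims[1] + rhs_dims[1]);
        Ok((
            <(D0, A, D2)>::from_concrete(&lhs_dims).unwrap(),
            <(D0, B, D2)>::from_concrete(&rhs_dims).unwrap(),
        ))
    }
}
```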
diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
new file mode 100644
index 00000000..3e2fa5e1
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
@@ -0,0 +1,99 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{cpu::NdIndex, *},
+};
+
+impl<E: Dtype> super::SplitAlongKernel<E> for Cpu {
+    fn forward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        ax: usize,
+        ab: &Tensor<AB, E, Self>,
+        a: &mut Tensor<A, E, Self>,
+        b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        let mut a_n = 1;
+        let mut b_n = 1;
+        {
+            let a_idx = NdIndex::new(a.shape, a.strides);
+            let b_idx = NdIndex::new(b.shape, b.strides);
+            for i in ax..A::NUM_DIMS {
+                a_n *= a_idx.shape[i];
+                b_n *= b_idx.shape[i];
+            }
+        }
+
+        let n_ab = ab.data.len();
+
+        let buf_a = std::sync::Arc::get_mut(&mut a.data).unwrap();
+        let buf_b = std::sync::Arc::get_mut(&mut b.data).unwrap();
+
+        let mut i = 0;
+        let mut k = 0;
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i < n_ab {
+            for j in 0..a_n {
+                (*buf_a)[j + k * a_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            for j in 0..b_n {
+                (*buf_b)[j + k * b_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            k += 1;
+        }
+        Ok(())
+    }
+
+    fn backward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        ax: usize,
+        ab: &GhostTensor<AB, E, Self>,
+        grad_ab: &mut Self::Vec,
+        a: &GhostTensor<A, E, Self>,
+        b: &GhostTensor<B, E, Self>,
+        a_or_b: AorB,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        let a_idx = NdIndex::new(a.shape, a.strides);
+        let b_idx = NdIndex::new(b.shape, b.strides);
+
+        let mut a_n = 1;
+        let mut b_n = 1;
+        for i in ax..A::NUM_DIMS {
+            a_n *= a_idx.shape[i];
+            b_n *= b_idx.shape[i];
+        }
+
+        let mut i = 0;
+        let mut j = 0;
+        let n = grad_ab.len();
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i + j < n {
+            match a_or_b {
+                AorB::A => {
+                    for _ in 0..a_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                    for _ in 0..b_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                }
+                AorB::B => {
+                    for _ in 0..a_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                    for _ in 0..b_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                }
+            };
+        }
+
+        Ok(())
+    }
+}
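The forward pass above walks the concatenated buffer in alternating chunks: `a_n` elements belong to `a`, then `b_n` elements belong to `b`, repeated once per prefix index `k` (in the real kernel the source index additionally goes through `NdIndex` to honor `ab`'s strides). A standalone sketch of the same walk over plain vectors, with made-up sizes and no dfdx types, splitting a row-major 2x5 buffer along axis 1 into a 2x2 and a 2x3 part (so `a_n = 2`, `b_n = 3`):

```rust
// Illustration only: the chunked walk used by `forward`, on plain Vecs.
fn main() {
    let ab: Vec<f32> = (0..10).map(|x| x as f32).collect(); // row-major 2x5
    let (a_n, b_n) = (2, 3); // elements of `a` and `b` per outer step
    let (mut a, mut b) = (vec![0.0; 4], vec![0.0; 6]);
    let (mut i, mut k) = (0, 0);
    while i < ab.len() {
        // copy one row's worth of `a` elements...
        for j in 0..a_n {
            a[j + k * a_n] = ab[i];
            i += 1;
        }
        // ...then one row's worth of `b` elements
        for j in 0..b_n {
            b[j + k * b_n] = ab[i];
            i += 1;
        }
        k += 1;
    }
    assert_eq!(a, [0.0, 1.0, 5.0, 6.0]); // rows [0, 1] and [5, 6]
    assert_eq!(b, [2.0, 3.0, 4.0, 7.0, 8.0, 9.0]); // rows [2, 3, 4] and [7, 8, 9]
}
```

The backward pass walks the same pattern, but writes `grad_out` into one side's region of `grad_ab` while merely advancing the index over the other side's region.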
diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
new file mode 100644
index 00000000..515f0365
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
@@ -0,0 +1,31 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{Cuda, Error, GhostTensor, Tensor},
+};
+use cudarc::types::CudaTypeName;
+
+impl<E: Dtype + CudaTypeName> super::SplitAlongKernel<E> for Cuda {
+    fn forward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &Tensor<AB, E, Self>,
+        _a: &mut Tensor<A, E, Self>,
+        _b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    fn backward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &GhostTensor<AB, E, Self>,
+        _grad_ab: &mut Self::Vec,
+        _a: &GhostTensor<A, E, Self>,
+        _b: &GhostTensor<B, E, Self>,
+        _a_or_b: AorB,
+        _grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs
new file mode 100644
index 00000000..ac619301
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/mod.rs
@@ -0,0 +1,275 @@
+use super::split_shape_along::TrySplitShapeAlong;
+use crate::{shapes::*, tensor::*};
+
+pub(crate) mod cpu_kernel;
+#[cfg(feature = "cuda")]
+pub(crate) mod cuda_kernel;
+#[cfg(feature = "webgpu")]
+mod webgpu_kernel;
+
+/// Split a tensor in two along a given axis.
+///
+/// This is the reverse of [TryConcatTensorAlong::concat_tensor_along].
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<Rank2<5, 3>, f32, _> = dev.zeros();
+/// let (a, b, _tape): (
+///     Tensor<Rank2<2, 3>, f32, _>,
+///     Tensor<Rank2<3, 3>, f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<0>, Const::<2>, Const::<3>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<Rank2<3, 5>, f32, _> = dev.zeros();
+/// let (a, b, _tape): (
+///     Tensor<Rank2<3, 2>, f32, _>,
+///     Tensor<Rank2<3, 3>, f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<1>, Const::<2>, Const::<3>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<(usize, Const::<4>), f32, _> = dev.zeros_like(&(5, Const));
+/// let (a, b, _tape): (
+///     Tensor<(usize, Const::<4>), f32, _>,
+///     Tensor<(usize, Const::<4>), f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<0>, 2, 3);
+/// let a: Tensor<Rank2<2, 4>, f32, _> = a.realize();
+/// let b: Tensor<Rank2<3, 4>, f32, _> = b.realize();
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let ab: Tensor<(Const::<4>, usize), f32, _> = dev.zeros_like(&(Const, 5));
+/// let (a, b, _tape): (
+///     Tensor<(Const::<4>, usize), f32, _>,
+///     Tensor<(Const::<4>, usize), f32, _>,
+///     _
+/// ) = ab.split_tensor_along(Axis::<1>, 2, 3);
+/// let a: Tensor<Rank2<4, 2>, f32, _> = a.realize();
+/// let b: Tensor<Rank2<4, 3>, f32, _> = b.realize();
+/// ```
+pub trait TrySplitTensorAlong<Ax, A: Dim, B: Dim>: Sized {
+    type Output;
+
+    /// Splits self along the given axis.
+    fn split_tensor_along(self, ax: Ax, a: A, b: B) -> Self::Output {
+        self.try_split_tensor_along(ax, a, b).unwrap()
+    }
+    /// Fallibly splits self along the given axis.
+    fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
+}
+
+#[derive(Debug, Clone)]
+pub enum AorB {
+    A,
+    B,
+}
+
+pub trait SplitAlongKernel<E: Dtype>: Storage<E> {
+    fn forward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        ax: usize,
+        ab: &Tensor<AB, E, Self>,
+        a: &mut Tensor<A, E, Self>,
+        b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error>;
+
+    #[allow(clippy::too_many_arguments)]
+    fn backward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        ax: usize,
+        ab: &GhostTensor<AB, E, Self>,
+        grad_ab: &mut Self::Vec,
+        a: &GhostTensor<A, E, Self>,
+        b: &GhostTensor<B, E, Self>,
+        a_or_b: AorB,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error>;
+}
+
+impl<Ax, A, B, AS, BS, AB, E: Dtype, D, T: Tape<E, D>> TrySplitTensorAlong<Ax, A, B>
+    for Tensor<AB, E, D, T>
+where
+    Ax: Axes<Array = [isize; 1]>,
+    A: Dim,
+    B: Dim,
+    AS: Shape,
+    BS: Shape,
+    AB: Shape + TrySplitShapeAlong<Ax, A, B, Output = (AS, BS)>,
+    D: SplitAlongKernel<E> + ZerosTensor<E>,
+{
+    type Output = (Tensor<AS, E, D, T>, Tensor<BS, E, D, T>, T);
+
+    fn try_split_tensor_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error> {
+        let device = self.device.clone();
+        let (a_shape, b_shape) = (*self.shape()).try_split_shape_along(ax, a, b)?;
+        let ax = Ax::as_array()[0] as usize;
+
+        let (ab, tape) = self.split_tape();
+
+        let mut at: Tensor<AS, E, D> = device.try_zeros_like(&a_shape)?;
+        let mut bt: Tensor<BS, E, D> = device.try_zeros_like(&b_shape)?;
+
+        ab.device.forward(ax, &ab, &mut at, &mut bt)?;
+
+        let mut ta = T::default();
+        let mut tb = T::default();
+
+        let device_b = device.clone();
+
+        let ab_ghost = ab.ghost();
+        let a_ghost = at.ghost();
+        let b_ghost = bt.ghost();
+        ta.add_backward_op(move |grads| {
+            grads.try_alloc_for(&ab_ghost)?;
+            grads.try_alloc_for(&a_ghost)?;
+            let (ab_grad, a_grad) = grads.mut_and_ref(&ab_ghost, &a_ghost);
+            device.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::A, a_grad)
+        });
+
+        let ab_ghost = ab.ghost();
+        let a_ghost = at.ghost();
+        let b_ghost = bt.ghost();
+        tb.add_backward_op(move |grads| {
+            grads.try_alloc_for(&ab_ghost)?;
+            grads.try_alloc_for(&b_ghost)?;
+            let (ab_grad, b_grad) = grads.mut_and_ref(&ab_ghost, &b_ghost);
+            device_b.backward(ax, &ab_ghost, ab_grad, &a_ghost, &b_ghost, AorB::B, b_grad)
+        });
+
+        let at = at.put_tape(ta);
+        let bt = bt.put_tape(tb);
+        Ok((at, bt, tape))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_split_ax_0() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<5, 3, 4>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(usize, Const<3>, Const<4>)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<0>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap();
+        let b = b.try_realize::<(Const<3>, Const<3>, Const<4>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        assert_eq!(ab_arr[0], a_arr[0]);
+        assert_eq!(ab_arr[1], a_arr[1]);
+        assert_eq!(ab_arr[2], b_arr[0]);
+        assert_eq!(ab_arr[3], b_arr[1]);
+        assert_eq!(ab_arr[4], b_arr[2]);
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<0>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+
+    #[test]
+    fn test_split_ax_1() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<2, 5, 4>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(Const<2>, usize, Const<4>)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<1>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<2>, Const<4>)>().unwrap();
+        let b = b.try_realize::<(Const<2>, Const<3>, Const<4>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        for i in 0..2 {
+            assert_eq!(ab_arr[i][0], a_arr[i][0]);
+            assert_eq!(ab_arr[i][1], a_arr[i][1]);
+            assert_eq!(ab_arr[i][2], b_arr[i][0]);
+            assert_eq!(ab_arr[i][3], b_arr[i][1]);
+            assert_eq!(ab_arr[i][4], b_arr[i][2]);
+        }
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<1>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        println!("{:?}", concat_grads.get(&ab).array());
+        println!("{:?}", ab_grads.get(&ab).array());
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+
+    #[test]
+    fn test_split_ax_2() {
+        let dev: TestDevice = Default::default();
+        let ab: Tensor<Rank3<2, 3, 5>, TestDtype, _> = dev.sample_normal();
+        let ab_dyn = ab
+            .leaky_trace()
+            .try_realize::<(Const<2>, Const<3>, usize)>()
+            .unwrap();
+        let (a, b, _tape) = ab_dyn.split_tensor_along(Axis::<2>, 2, 3);
+        let a = a.try_realize::<(Const<2>, Const<3>, Const<2>)>().unwrap();
+        let b = b.try_realize::<(Const<2>, Const<3>, Const<3>)>().unwrap();
+        let ab_arr = ab.array();
+        let a_arr = a.array();
+        let b_arr = b.array();
+        println!("{a_arr:?}");
+        println!("{b_arr:?}");
+        println!("{ab_arr:?}");
+
+        for i in 0..2 {
+            for j in 0..3 {
+                assert_eq!(ab_arr[i][j][0], a_arr[i][j][0]);
+                assert_eq!(ab_arr[i][j][1], a_arr[i][j][1]);
+                assert_eq!(ab_arr[i][j][2], b_arr[i][j][0]);
+                assert_eq!(ab_arr[i][j][3], b_arr[i][j][1]);
+                assert_eq!(ab_arr[i][j][4], b_arr[i][j][2]);
+            }
+        }
+
+        let ab_concat = (a, b).concat_tensor_along(Axis::<2>);
+        assert_eq!(ab.array(), ab_concat.array());
+        let concat_grads = ab_concat.exp().sum().backward();
+        let ab_grads = ab.leaky_trace().exp().sum().backward();
+
+        println!("{:?}", concat_grads.get(&ab).array());
+        println!("{:?}", ab_grads.get(&ab).array());
+
+        assert_close_to_tensor!(concat_grads.get(&ab), ab_grads.get(&ab));
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs
new file mode 100644
index 00000000..be1923dd
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/split_tensor_along/webgpu_kernel.rs
@@ -0,0 +1,27 @@
+use super::AorB;
+use crate::{shapes::*, tensor::*};
+
+impl<E: Dtype> super::SplitAlongKernel<E> for Webgpu {
+    fn forward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &Tensor<AB, E, Self>,
+        _a: &mut Tensor<A, E, Self>,
+        _b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    fn backward<A: Shape, B: Shape, AB: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &GhostTensor<AB, E, Self>,
+        _grad_ab: &mut Self::Vec,
+        _a: &GhostTensor<A, E, Self>,
+        _b: &GhostTensor<B, E, Self>,
+        _a_or_b: AorB,
+        _grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+}
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..0869b6c1 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -21,6 +21,9 @@ pub trait Device<E: Dtype>:
     + super::super::concat_tensor_along::ConcatAlongKernel<E>
     + super::super::concat_tensor_along::ConcatAlongKernel<usize>
 
+    // splits
+    + super::super::split_tensor_along::SplitAlongKernel<E>
+
     // optimizers
     + super::super::adam::AdamKernel<E>
     + super::super::sgd::SgdKernel<E>
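End to end, the new op round-trips with `concat_tensor_along`, as the tests above verify. A minimal usage sketch under the same prelude as the doc examples (values and sizes are illustrative; note that only the `Cpu` kernel is implemented in this patch, the CUDA and WebGPU kernels are still `todo!()`):

```rust
use dfdx_core::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // Split a (usize, Const<4>) tensor into a 2-row and a 3-row part...
    let ab: Tensor<(usize, Const<4>), f32, _> =
        dev.tensor_from_vec((0..20).map(|x| x as f32).collect(), (5, Const));
    let (a, b, _tape) = ab.clone().split_tensor_along(Axis::<0>, 2, 3);
    // ...and concatenating the parts restores the original tensor.
    let round_trip = (a, b).concat_tensor_along(Axis::<0>);
    assert_eq!(ab.as_vec(), round_trip.as_vec());
}
```

The third element of the returned tuple is the tape taken off the input tensor; `a` and `b` each carry a fresh tape whose backward op writes that output's gradient into its region of the input's gradient.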