Skip to content

Commit

Permalink
Add split_tensor_along method
Browse files Browse the repository at this point in the history
- Add `TrySplitShapeAlong` and `TrySplitTensorAlong`.
- Minor linting and docs fix.

TODO
- Check if the tape should be returned. If not, it can be removed from the interface.
- Add cuda kernel.
- Consider a different interface, where it could get split in more than two tensors - possibly stated on a vec.
  In this way it could get closer to the pytorch interface (chunks).
  • Loading branch information
swfsql committed Feb 1, 2024
1 parent 4722a99 commit 693b699
Show file tree
Hide file tree
Showing 9 changed files with 604 additions and 8 deletions.
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
let buf = std::sync::Arc::get_mut(&mut c.data).unwrap();
while i < n {
for _ in 0..a_n {
buf[i] = a.data[a_idx.next().unwrap()];
(*buf)[i] = a.data[a_idx.next().unwrap()];
i += 1;
}
for _ in 0..b_n {
buf[i] = b.data[b_idx.next().unwrap()];
(*buf)[i] = b.data[b_idx.next().unwrap()];
i += 1;
}
}
Expand Down Expand Up @@ -59,11 +59,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
let n = grad_out.len();
while i < n {
for _ in 0..a_n {
grad_a[a_idx.next().unwrap()] += grad_out[i];
(*grad_a)[a_idx.next().unwrap()] += grad_out[i];
i += 1;
}
for _ in 0..b_n {
grad_b[b_idx.next().unwrap()] += grad_out[i];
(*grad_b)[b_idx.next().unwrap()] += grad_out[i];
i += 1;
}
}
Expand Down
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_along(Axis::<0>);
/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
/// ```
///
/// Along Axis 1:
Expand All @@ -28,7 +28,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_along(Axis::<1>);
/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
/// ```
///
/// # [usize] dims
Expand All @@ -38,7 +38,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
/// ```
///
/// Along Axis 1:
Expand All @@ -47,7 +47,7 @@ mod webgpu_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
/// ```
pub trait TryConcatTensorAlong<Ax>: Sized {
type Output;
Expand Down
4 changes: 4 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ mod sigmoid;
mod sin;
mod slice;
mod softmax;
mod split_shape_along;
mod split_tensor_along;
mod sqrt;
mod square;
mod stack;
Expand Down Expand Up @@ -267,6 +269,8 @@ pub use sigmoid::sigmoid;
pub use sin::sin;
pub use slice::slice;
pub use softmax::softmax;
pub use split_shape_along::TrySplitShapeAlong;
pub use split_tensor_along::TrySplitTensorAlong;
pub use sqrt::sqrt;
pub use square::square;
pub use stack::{AddDim, TryStack};
Expand Down
158 changes: 158 additions & 0 deletions dfdx-core/src/tensor_ops/split_shape_along/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use crate::{shapes::*, tensor::*};

/// Split a shape in two along a given axis.
///
/// # [Const] dims **requires nightly**
///
/// Along Axis 0:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>);
/// ```
///
/// Along Axis 1:
/// ```ignore
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>);
/// ```
///
/// # [usize] dims
/// Along Axis 0:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
/// assert_eq!(a, (3, Const::<3>));
/// assert_eq!(b, (4, Const::<3>));
/// ```
///
/// Along Axis 1:
/// ```rust
/// # use dfdx_core::prelude::*;
/// # let dev: Cpu = Default::default();
/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1);
/// assert_eq!(a, (Const::<7>, 2));
/// assert_eq!(b, (Const::<7>, 1));
/// ```
pub trait TrySplitShapeAlong<Ax, A: Dim, B: Dim>: Shape {
type Output;

/// Splits self along the given axis.
fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output {
self.try_split_shape_along(ax, a, b).unwrap()
}
/// Fallibly splits self along the given axis.
fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
}

macro_rules! impl_split {
($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
impl<A: Dim, B: Dim, AB:Dim, $($Head: Dim, )* $($Tail: Dim, )*> TrySplitShapeAlong<Axis<$Ax>, A, B>
for
(
$($Head, )*
AB,
$($Tail, )*
)
where
($($Head, )* A, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
($($Head, )* B, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
{
type Output =
(
($($Head, )* A, $($Tail, )*),
($($Head, )* B, $($Tail, )*),
);

fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result<Self::Output, Error> {
let dims = self.concrete();
let mut lhs_dims = dims;
let mut rhs_dims = dims;
lhs_dims[$Ax] = a.size();
rhs_dims[$Ax] = b.size();
assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]);

Ok((
<($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(),
<($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(),
))
}
}
};
}

impl_split!(0, 1, [], []);
impl_split!(0, 2, [], [D1]);
impl_split!(0, 3, [], [D1, D2]);
impl_split!(0, 4, [], [D1, D2, D3]);
impl_split!(0, 5, [], [D1, D2, D3, D4]);
impl_split!(0, 6, [], [D1, D2, D3, D4, D5]);

impl_split!(1, 2, [D0], []);
impl_split!(1, 3, [D0], [D2]);
impl_split!(1, 4, [D0], [D2, D3]);
impl_split!(1, 5, [D0], [D2, D3, D4]);
impl_split!(1, 6, [D0], [D2, D3, D4, D5]);

impl_split!(2, 3, [D0, D1], []);
impl_split!(2, 4, [D0, D1], [D3]);
impl_split!(2, 5, [D0, D1], [D3, D4]);
impl_split!(2, 6, [D0, D1], [D3, D4, D5]);

impl_split!(3, 4, [D0, D1, D2], []);
impl_split!(3, 5, [D0, D1, D2], [D4]);
impl_split!(3, 6, [D0, D1, D2], [D4, D5]);

impl_split!(4, 5, [D0, D1, D2, D3], []);
impl_split!(4, 6, [D0, D1, D2, D3], [D5]);

impl_split!(5, 6, [D0, D1, D2, D3, D4], []);

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_split_shape() {
let a: (usize, Const<5>) = (5, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

let a: (Const<5>, Const<5>) = (Const, Const);
let b: (usize, Const<5>) = (3, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

let a: (usize, Const<5>) = (5, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!(
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);

#[cfg(feature = "nightly")]
{
let a: (Const<5>, Const<5>) = (Const, Const);
let b: (Const<3>, Const<5>) = (Const, Const);
assert_eq!(
(Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
(a, b)
);
}
}

#[test]
#[should_panic = "left: 8\n right: 7"]
fn test_split_shape_fails() {
let a: (usize, Const<5>) = (4, Const);
let b: (usize, Const<5>) = (3, Const);
(8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0);
}
}
99 changes: 99 additions & 0 deletions dfdx-core/src/tensor_ops/split_tensor_along/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use super::AorB;
use crate::{
shapes::*,
tensor::{cpu::NdIndex, *},
};

impl<E: Dtype> super::SplitAlongKernel<E> for Cpu {
fn forward<AB: Shape, A: Shape, B: Shape>(
&self,
ax: usize,
ab: &Tensor<AB, E, Self>,
a: &mut Tensor<A, E, Self>,
b: &mut Tensor<B, E, Self>,
) -> Result<(), Error> {
let mut a_n = 1;
let mut b_n = 1;
{
let a_idx = NdIndex::new(a.shape, a.strides);
let b_idx = NdIndex::new(b.shape, b.strides);
for i in ax..A::NUM_DIMS {
a_n *= a_idx.shape[i];
b_n *= b_idx.shape[i];
}
}

let n_ab = ab.data.len();

let buf_a = std::sync::Arc::get_mut(&mut a.data).unwrap();
let buf_b = std::sync::Arc::get_mut(&mut b.data).unwrap();

let mut i = 0;
let mut k = 0;
let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
while i < n_ab {
for j in 0..a_n {
(*buf_a)[j + k * a_n] = ab.data[ab_idx.next().unwrap()];
i += 1;
}
for j in 0..b_n {
(*buf_b)[j + k * b_n] = ab.data[ab_idx.next().unwrap()];
i += 1;
}
k += 1;
}
Ok(())
}

fn backward<AB: Shape, A: Shape, B: Shape>(
&self,
ax: usize,
ab: &GhostTensor<AB, E, Self>,
grad_ab: &mut Self::Vec,
a: &GhostTensor<A, E, Self>,
b: &GhostTensor<B, E, Self>,
a_or_b: AorB,
grad_out: &Self::Vec,
) -> Result<(), Error> {
let a_idx = NdIndex::new(a.shape, a.strides);
let b_idx = NdIndex::new(b.shape, b.strides);

let mut a_n = 1;
let mut b_n = 1;
for i in ax..A::NUM_DIMS {
a_n *= a_idx.shape[i];
b_n *= b_idx.shape[i];
}

let mut i = 0;
let mut j = 0;
let n = grad_ab.len();
let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
while i + j < n {
match a_or_b {
AorB::A => {
for _ in 0..a_n {
(*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
i += 1;
}
for _ in 0..b_n {
ab_idx.next().unwrap();
j += 1;
}
}
AorB::B => {
for _ in 0..a_n {
ab_idx.next().unwrap();
j += 1;
}
for _ in 0..b_n {
(*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
i += 1;
}
}
};
}

Ok(())
}
}
31 changes: 31 additions & 0 deletions dfdx-core/src/tensor_ops/split_tensor_along/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use super::AorB;
use crate::{
shapes::*,
tensor::{Cuda, Error, GhostTensor, Tensor},
};
use cudarc::types::CudaTypeName;

impl<E: Dtype + CudaTypeName> super::SplitAlongKernel<E> for Cuda {
fn forward<AB: Shape, A: Shape, B: Shape>(
&self,
_ax: usize,
_ab: &Tensor<AB, E, Self>,
_a: &mut Tensor<A, E, Self>,
_b: &mut Tensor<B, E, Self>,
) -> Result<(), Error> {
todo!()
}

fn backward<AB: Shape, A: Shape, B: Shape>(
&self,
_ax: usize,
_ab: &GhostTensor<AB, E, Self>,
_grad_ab: &mut Self::Vec,
_a: &GhostTensor<A, E, Self>,
_b: &GhostTensor<B, E, Self>,
_a_or_b: AorB,
_grad_out: &Self::Vec,
) -> Result<(), Error> {
todo!()
}
}

0 comments on commit 693b699

Please sign in to comment.