From 09f8084cf9d49076ab99ba5a7c7b9a3eac58ec3d Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 5 Oct 2022 08:50:31 -0400 Subject: [PATCH] Renaming SelectTo, using SelectTo for batched select (#217) --- examples/10-tensor-index.rs | 2 +- src/tensor_ops/mod.rs | 2 +- src/tensor_ops/select.rs | 119 +++++++++++++++--------------------- 3 files changed, 50 insertions(+), 73 deletions(-) diff --git a/examples/10-tensor-index.rs b/examples/10-tensor-index.rs index 7293b6f40..3009a3580 100644 --- a/examples/10-tensor-index.rs +++ b/examples/10-tensor-index.rs @@ -2,7 +2,7 @@ use dfdx::arrays::HasArrayData; use dfdx::tensor::{tensor, Tensor2D, Tensor3D}; -use dfdx::tensor_ops::Select1; +use dfdx::tensor_ops::SelectTo; fn main() { let a: Tensor3D<3, 2, 3> = tensor([ diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 5d25435e2..ec0786902 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -102,7 +102,7 @@ //! //! # Selects/Indexing //! -//! Selecting or indexing into a tensor is done via [Select1::select()]. This traits enables +//! Selecting or indexing into a tensor is done via [SelectTo::select()]. This traits enables //! 2 behaviors for each axis of a given tensor: //! //! 1. Select exactly 1 element from that axis. diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index a64818c4e..00024666f 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -1,11 +1,9 @@ use super::utils::move_tape_and_add_backward_op; -use crate::devices::{ - BSelectAx1, Device, DeviceSelect, FillElements, SelectAx0, SelectAx1, SelectAx2, SelectAx3, -}; +use crate::devices::*; use crate::gradients::Tape; use crate::prelude::*; -/// Select values along a single axis `I` resulting in `T`. Equivalent +/// Select values along `Axes` resulting in `T`. Equivalent /// to `torch.select` and `torch.gather` from pytorch. /// /// There are two ways to select: @@ -14,7 +12,9 @@ use crate::prelude::*; /// 2. Select multiple values from an axis, which keeps the number /// of dimensions the same. You can select the same element multiple /// number of times. -pub trait Select1 { +/// +/// You can also select batches of data with this trait. +pub trait SelectTo { type Indices: Clone; /// Select sub elements using [Self::Indices]. @@ -44,12 +44,24 @@ pub trait Select1 { /// // is the new size of the 1st axis. /// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); /// ``` + /// + /// Selecting batch of values from a 1d tensor: + /// ```rust + /// # use dfdx::prelude::*; + /// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]); + ///``` + /// + /// Selecting batch of values from a 2d tensor: + /// ```rust + /// # use dfdx::prelude::*; + /// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]); + ///``` fn select(self, indices: &Self::Indices) -> T; } macro_rules! impl_select { - ($Axis:expr, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { -impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy { + ($Axes:ty, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { +impl<$(const $Dims: usize, )* H: Tape> SelectTo<$DstTy, $Axes> for $SrcTy { type Indices = $IndTy; fn select(self, indices: &Self::Indices) -> $DstTy { select::<_, _, _, $Mode>(self, indices) @@ -59,72 +71,37 @@ impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy { } // 1d -impl_select!(-1, SelectAx0, Tensor1D, usize, Tensor0D, {M}); -impl_select!(-1, SelectAx0, Tensor1D, [usize; Z], Tensor1D, {M, Z}); +impl_select!(Axis<0>, SelectAx0, Tensor1D, usize, Tensor0D, {M}); +impl_select!(Axis<0>, SelectAx0, Tensor1D, [usize; Z], Tensor1D, {M, Z}); // 2d -impl_select!(0, SelectAx0, Tensor2D, usize, Tensor1D, {M, N}); -impl_select!(0, SelectAx0, Tensor2D, [usize; Z], Tensor2D, {M, N, Z}); -impl_select!(-1, SelectAx1, Tensor2D, [usize; M], Tensor1D, {M, N}); -impl_select!(-1, SelectAx1, Tensor2D, [[usize; Z]; M], Tensor2D, {M, N, Z}); +impl_select!(Axis<0>, SelectAx0, Tensor2D, usize, Tensor1D, {M, N}); +impl_select!(Axis<0>, SelectAx0, Tensor2D, [usize; Z], Tensor2D, {M, N, Z}); +impl_select!(Axis<1>, SelectAx1, Tensor2D, [usize; M], Tensor1D, {M, N}); +impl_select!(Axis<1>, SelectAx1, Tensor2D, [[usize; Z]; M], Tensor2D, {M, N, Z}); // 3d -impl_select!(0, SelectAx0, Tensor3D, usize, Tensor2D, {M, N, O}); -impl_select!(0, SelectAx0, Tensor3D, [usize; Z], Tensor3D, {M, N, O, Z}); -impl_select!(1, SelectAx1, Tensor3D, [usize; M], Tensor2D, {M, N, O}); -impl_select!(1, SelectAx1, Tensor3D, [[usize; Z]; M], Tensor3D, {M, N, O, Z}); -impl_select!(-1, SelectAx2, Tensor3D, [[usize; N]; M], Tensor2D, {M, N, O}); -impl_select!(-1, SelectAx2, Tensor3D, [[[usize; Z]; N]; M], Tensor3D, {M, N, O, Z}); +impl_select!(Axis<0>, SelectAx0, Tensor3D, usize, Tensor2D, {M, N, O}); +impl_select!(Axis<0>, SelectAx0, Tensor3D, [usize; Z], Tensor3D, {M, N, O, Z}); +impl_select!(Axis<1>, SelectAx1, Tensor3D, [usize; M], Tensor2D, {M, N, O}); +impl_select!(Axis<1>, SelectAx1, Tensor3D, [[usize; Z]; M], Tensor3D, {M, N, O, Z}); +impl_select!(Axis<2>, SelectAx2, Tensor3D, [[usize; N]; M], Tensor2D, {M, N, O}); +impl_select!(Axis<2>, SelectAx2, Tensor3D, [[[usize; Z]; N]; M], Tensor3D, {M, N, O, Z}); // 4d -impl_select!(0, SelectAx0, Tensor4D, usize, Tensor3D, {M, N, O, P}); -impl_select!(0, SelectAx0, Tensor4D, [usize; Z], Tensor4D, {M, N, O, P, Z}); -impl_select!(1, SelectAx1, Tensor4D, [usize; M], Tensor3D, {M, N, O, P}); -impl_select!(1, SelectAx1, Tensor4D, [[usize; Z]; M], Tensor4D, {M, N, O, P, Z}); -impl_select!(2, SelectAx2, Tensor4D, [[usize; N]; M], Tensor3D, {M, N, O, P}); -impl_select!(2, SelectAx2, Tensor4D, [[[usize; Z]; N]; M], Tensor4D, {M, N, O, P, Z}); -impl_select!(-1, SelectAx3, Tensor4D, [[[usize; O]; N]; M], Tensor3D, {M, N, O, P}); -impl_select!(-1, SelectAx3, Tensor4D, [[[[usize; Z]; O]; N]; M], Tensor4D, {M, N, O, P, Z}); - -/// Select batched values from axis 0, resulting in `T`. Equivalent -/// to `torch.select` and `torch.gather` from pytorch. -pub trait SelectBatchAx0 { - type Indices; - - /// Select sub elements using [Self::Indices]. - /// The same element can be selected multiple times depending - /// on [Self::Indices]. - /// - /// This results in a tensor 1 dimension larger than self. - /// - /// Selecting batch of values from a 1d tensor: - /// ```rust - /// # use dfdx::prelude::*; - /// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]); - ///``` - /// - /// Selecting batch of values from a 2d tensor: - /// ```rust - /// # use dfdx::prelude::*; - /// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]); - ///``` - fn select_batch(self, indices: &Self::Indices) -> T; -} - -macro_rules! impl_select_batch { - ($SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { -impl<$(const $Dims: usize, )* H: Tape> SelectBatchAx0<$DstTy> for $SrcTy { - type Indices = $IndTy; - fn select_batch(self, indices: &Self::Indices) -> $DstTy { - select::<_, _, _, BSelectAx1>(self, indices) - } -} - }; -} - -impl_select_batch!(Tensor1D, [[usize; Z]; B], Tensor2D, {M, B, Z}); -impl_select_batch!(Tensor2D, [[usize; Z]; B], Tensor3D, {M, N, B, Z}); -impl_select_batch!(Tensor3D, [[usize; Z]; B], Tensor4D, {M, N, O, B, Z}); +impl_select!(Axis<0>, SelectAx0, Tensor4D, usize, Tensor3D, {M, N, O, P}); +impl_select!(Axis<0>, SelectAx0, Tensor4D, [usize; Z], Tensor4D, {M, N, O, P, Z}); +impl_select!(Axis<1>, SelectAx1, Tensor4D, [usize; M], Tensor3D, {M, N, O, P}); +impl_select!(Axis<1>, SelectAx1, Tensor4D, [[usize; Z]; M], Tensor4D, {M, N, O, P, Z}); +impl_select!(Axis<2>, SelectAx2, Tensor4D, [[usize; N]; M], Tensor3D, {M, N, O, P}); +impl_select!(Axis<2>, SelectAx2, Tensor4D, [[[usize; Z]; N]; M], Tensor4D, {M, N, O, P, Z}); +impl_select!(Axis<3>, SelectAx3, Tensor4D, [[[usize; O]; N]; M], Tensor3D, {M, N, O, P}); +impl_select!(Axis<3>, SelectAx3, Tensor4D, [[[[usize; Z]; O]; N]; M], Tensor4D, {M, N, O, P, Z}); + +// batched select +impl_select!(Axis<0>, BSelectAx1, Tensor1D, [[usize; Z]; B], Tensor2D, {M, B, Z}); +impl_select!(Axis<0>, BSelectAx1, Tensor2D, [[usize; Z]; B], Tensor3D, {M, N, B, Z}); +impl_select!(Axis<0>, BSelectAx1, Tensor3D, [[usize; Z]; B], Tensor4D, {M, N, O, B, Z}); pub(crate) fn select(t: T, indices: &I) -> R where @@ -162,9 +139,9 @@ mod tests { #[test] fn test_valid_select_batches() { - let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]); - let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]); - let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [0]]); + let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]); + let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]); + let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select(&[[0], [0]]); } #[test] @@ -255,7 +232,7 @@ mod tests { fn test_select_batch_backwards() { let mut rng = thread_rng(); let t: Tensor2D<4, 5> = TensorCreator::randn(&mut rng); - let r: Tensor3D<2, 3, 5, _> = t.trace().select_batch(&[[2, 0, 3], [0, 0, 3]]); + let r: Tensor3D<2, 3, 5, _> = t.trace().select(&[[2, 0, 3], [0, 0, 3]]); let r0: Tensor2D<3, 5> = t.clone().select(&[2, 0, 3]); let r1: Tensor2D<3, 5> = t.clone().select(&[0, 0, 3]); assert_close(&r.data()[0], r0.data());