Reworking crate level documentation #644

Merged · 2 commits · Mar 30, 2023
227 changes: 149 additions & 78 deletions src/lib.rs
@@ -1,104 +1,175 @@
//! Ergonomics & safety focused deep learning in Rust. Main features include:
//! 1. Tensor library with shapes up to 6d!
//! 2. Shapes with both compile-time and runtime-sized dimensions (e.g. `Tensor<(usize, Const<10>)>` and `Tensor<Rank2<5, 10>>`).
//! 3. A large library of tensor operations (including `matmul`, `conv2d`, and much more).
//! a. All tensor operations are shape- and type-checked at compile time!
//! 4. Ergonomic neural network building blocks (like `Linear`, `Conv2D`, and `Transformer`).
//! 5. Standard deep learning optimizers such as `Sgd`, `Adam`, `AdamW`, `RMSprop`, and more.
//! 6. Reverse mode auto differentiation implementation.
//! 7. Serialization to/from `.npy` and `.npz` for transferring models to/from python.
//!
//! # dfdx
//!
//! dfdx is a cuda accelerated tensor and neural network library, written
//! entirely in rust!
//!
//! Additionally, it can track compile-time shapes across tensor operations,
//! ensuring that all your neural networks are checked **at compile time**.
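//!
//! For example, a matmul between mismatched shapes is rejected before the
//! program ever runs. Here's a small illustrative sketch (the failing line is
//! left commented out so that the snippet compiles):
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
//! let b: Tensor<Rank2<4, 5>, f32, _> = dev.zeros();
//! // the inner dimensions (3 vs 4) don't match, so this would not compile:
//! // let c = a.matmul(b);
//! ```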
//!
//! The following sections provide some high level core concepts & examples, and
//! there is more detailed documentation in each of dfdx's submodules.
//!
//! See [feature_flags] for details on feature flags.
//!
//! # Shapes & Tensors
//!
//! *See [shapes] and [tensor] for more information.*
//!
//! At its core a [`tensor::Tensor`] is just an nd-array. Just like
//! rust arrays, it has two parts:
//! 1. Shape
//! 2. Dtype
//!
//! dfdx represents shapes as **tuples** of dimensions ([`shapes::Dim`]),
//! where a dimension can either be known at:
//! 1. Compile time [`shapes::Const<M>`]
//! 2. Run time [`usize`]
//!
//! You can freely mix and match these dimensions together. Here are some
//! example shapes:
//! - `()` - unit shape
//! - `(usize,)` - 1d shape with a runtime known dimension
//! - `(usize, Const<5>)` - 2d shape with both types of dimensions
//! - `(Const<3>, usize, Const<5>)` - 3d shape!
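//!
//! Since a shape is literally such a tuple, you can also write one out as a
//! value. A minimal sketch (the bindings are only illustrative; `zeros_like`
//! takes a shape here, mirroring the `ones_like` call further down this page):
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! // a 3d shape mixing compile-time and runtime dimensions
//! let shape: (Const<3>, usize, Const<5>) = (Const, 7, Const);
//! let t: Tensor<(Const<3>, usize, Const<5>), f32, _> = dev.zeros_like(&shape);
//! ```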
//!
//! Here are some comparisons between representing nd arrays in rust vs dfdx:
//!
//! | rust array | dfdx `Tensor` |
//! | --- | --- |
//! | `f32` | `Tensor<(), f32, ...>` |
//! | `[u32; 5]` | `Tensor<Rank1<5>, u32, ...>` |
//! | `[[u8; 3]; 2]` | `Tensor<Rank2<2, 3>, u8, ...>` |
//! | `Vec<[bool; 5]>` | `Tensor<(usize, Const<5>), bool, ...>` |
//!
//! The `Rank1`, `Rank2` shapes used above are actually type aliases for
//! when **all dimensions are compile time**:
//! - [`shapes::Rank0`] is just `()`.
//! - [`shapes::Rank1<M>`] is `(Const<M>, )`
//! - [`shapes::Rank2<M, N>`] is `(Const<M>, Const<N>)`
//!
//! # Allocating tensors with Devices
//!
//! *See [tensor] for more information.*
//!
//! Devices are used to allocate tensors (and neural networks!). They are akin
//! to [std::alloc::GlobalAlloc] in rust - they just allocate memory.
//! They are also used to execute tensor ops, which we will get to later on.
//!
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory
//!
//! Both devices implement [Default], and you can also create them with a
//! specific seed and device ordinal.
//!
//! Here's how you might use a device:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! let x = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
//! let y: Tensor<Rank2<2, 3>, f32, Cpu> = dev.ones();
//! // Runtime shape
//! let z: Tensor<(usize, Const<3>), f32, _> = dev.ones_like(&(10, Const));
//! let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
//! ```
//!
//! # Tensor Operations (tip of the iceberg)
//!
//! *See [tensor_ops] for more information*
//!
//! Once you've instantiated tensors with a device, you can start doing operations on them!
//! There are **many many** operations; here are a few core ones and how they relate
//! to their numpy/pytorch counterparts:
//!
//! | Operation | dfdx | numpy | pytorch |
//! | --- | --- | --- | --- |
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
//! | Broadcasting | [tensor_ops::BroadcastTo] | implicit/`np.broadcast` | implicit/`torch.broadcast_to` |
//! | Permute | [tensor_ops::PermuteTo] | `np.transpose(...)` | `torch.permute` |
//! | Where | [tensor_ops::ChooseFrom] | `np.where` | `torch.where` |
//! | Reshape | [tensor_ops::ReshapeTo] | `np.reshape(shape)` | `a.reshape(shape)` |
//! | View | [tensor_ops::ReshapeTo] | `np.view(...)` | `a.view(...)` |
//! | Roll | [tensor_ops::Roll] | `np.rollaxis(...)` | `a.roll(...)` |
//! | Stack | [tensor_ops::TryStack] | `np.stack` | `torch.stack` |
//! | Concat | [tensor_ops::TryConcat] | `np.concatenate` | `torch.concat` |
//!
//! and **much much more!**
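//!
//! To give a flavor, here is a small sketch chaining a few of the ops above
//! (the specific shapes and values are arbitrary):
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! let a: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
//! let b: Tensor<Rank2<3, 4>, f32, _> = dev.sample_normal();
//! // gemm: (2, 3) x (3, 4) -> (2, 4), shape-checked at compile time
//! let c = a.matmul(b);
//! // unary & binary ops are plain methods/operators
//! let d = c.clone().square() + c;
//! // broadcast in a new leading axis, then permute the axes around
//! let e: Tensor<Rank3<5, 2, 4>, f32, _> = d.broadcast();
//! let f: Tensor<Rank3<4, 2, 5>, f32, _> = e.permute();
//! ```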
//!
//! # Neural networks
//!
//! *See [nn] for more information.*
//!
//! Neural networks are composed of building blocks that you can chain together. In
//! dfdx, sequential neural networks are represented by **tuples**! For example,
//! the following two networks are identical:
//!
//! | dfdx | pytorch |
//! | --- | --- |
//! | `(Linear<3, 5>, ReLU, Linear<5, 10>)` | `nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 10))` |
//! | `((Conv2D<3, 2, 1>, Tanh), Conv2D<3, 2, 1>)` | `nn.Sequential(nn.Sequential(nn.Conv2d(3, 2, 1), nn.Tanh()), nn.Conv2d(3, 2, 1))` |
//!
//! To build a neural network, you of course need a device:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! let model = dev.build_module::<Model, f32>();
//! ```
//!
//! Note two things:
//! 1. We are using [nn::DeviceBuildExt] to instantiate the model
//! 2. We **need** to pass a dtype (in this case f32) to create the model.
//!
//! You can then pass tensors into the model with [nn::Module::forward()]:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! # let model = dev.build_module::<Model, f32>();
//! // tensor with runtime batch dimension of 10
//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_normal_like(&(10, Const));
//! let y = model.forward(x);
//! ```
//!
//! # Optimizers and Gradients
//!
//! *See [optim] for more information*
//!
//! dfdx supports a number of the standard optimizers:
//!
//! | Optimizer | dfdx | pytorch |
//! | --- | --- | --- |
//! | SGD | [optim::Sgd] | `torch.optim.SGD` |
//! | Adam | [optim::Adam] | `torch.optim.Adam` |
//! | AdamW | [optim::Adam] with [optim::WeightDecay::Decoupled] | `torch.optim.AdamW` |
//! | RMSprop | [optim::RMSprop] | `torch.optim.RMSprop` |
//!
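//! For instance, here is a sketch of the AdamW row above using [optim::Adam]'s
//! config (treat the hyperparameter values as placeholders):
//! ```rust
//! # use dfdx::{prelude::*, optim::*};
//! # let dev: Cpu = Default::default();
//! # let model = dev.build_module::<Linear<3, 5>, f32>();
//! let opt = Adam::new(&model, AdamConfig {
//!     lr: 1e-3,
//!     weight_decay: Some(WeightDecay::Decoupled(1e-2)),
//!     ..Default::default()
//! });
//! ```
//!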
//! You can use optimizers to optimize neural networks (or even tensors!). Here's
//! a simple example of how to do this with [nn::ZeroGrads]:
//! ```rust
//! # use dfdx::{prelude::*, optim::*};
//! # let dev: Cpu = Default::default();
//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! let mut model = dev.build_module::<Model, f32>();
//! // 1. allocate gradients for the model
//! let mut grads = model.alloc_grads();
//! // 2. create our optimizer
//! let mut opt = Sgd::new(&model, Default::default());
//! // 3. trace gradients through forward pass
//! let x: Tensor<Rank2<10, 3>, f32, _> = dev.sample_normal();
//! let y = model.forward_mut(x.traced(grads));
//! // 4. compute loss & run backpropagation
//! let loss = y.square().mean();
//! grads = loss.backward();
//! // 5. apply gradients
//! opt.update(&mut model, &grads);
//! ```

#![cfg_attr(all(feature = "no-std", not(feature = "std")), no_std)]
35 changes: 23 additions & 12 deletions src/tensor_ops/broadcast_to.rs
@@ -1,19 +1,30 @@
use crate::{shapes::*, tensor::*};

/// Broadcast self into a new shape.
///
/// **pytorch equivalent** `torch.broadcast_to`.
///
/// Use shape generic or output type to dictate what shape you want:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 7>, f32, _> = dev.zeros();
/// // broadcast axis 1
/// let _: Tensor<Rank3<3, 5, 7>, _, _> = a.clone().broadcast();
/// // broadcast axis 0 and axis 2
/// let _ = a.clone().broadcast::<Rank4<1, 3, 5, 7>, _>();
/// ```
///
/// Use the axes generic to disambiguate:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank1<1>, f32, _> = dev.zeros();
/// // It's ambiguous what axes to broadcast here - explicitly say axes 0 and 2
/// let _: Tensor<Rank3<1, 1, 1>, _, _> = a.clone().broadcast::<_, Axes2<0, 2>>();
/// ```
pub trait BroadcastTo: HasErr + HasShape {
/// Broadcast into shape `Dst` along axes `Ax`.
fn broadcast<Dst: ConstShape, Ax: Axes>(self) -> Self::WithShape<Dst>
where
Self::Shape: BroadcastShapeTo<Dst, Ax>,
10 changes: 10 additions & 0 deletions src/tensor_ops/choose/mod.rs
@@ -28,6 +28,16 @@ pub trait ChooseKernel<E: Dtype>: DeviceStorage {
}

/// Choose values from two tensors using a boolean mask. Equivalent to `torch.where` from pytorch.
///
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let cond: Tensor<Rank1<3>, bool, _> = dev.tensor([true, false, true]);
/// let a: Tensor<Rank1<3>, f32, _> = dev.tensor([1.0, 2.0, 3.0]);
/// let b: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, -2.0, -3.0]);
/// let c = cond.choose(a, b);
/// assert_eq!(c.array(), [1.0, -2.0, 3.0]);
/// ```
pub trait ChooseFrom<Lhs, Rhs>: HasErr {
type Output;

31 changes: 22 additions & 9 deletions src/tensor_ops/permute_to.rs
@@ -1,15 +1,28 @@
use crate::{shapes::*, tensor::*};

/// Changes order of dimensions/axes in a tensor.
///
/// **pytorch equivalent**: `torch.permute`.
///
/// Option 1: Specifying shape generic:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
/// let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<Rank2<3, 2>, _>();
/// assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
/// ```
///
/// Option 2: Specifying axes generic:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
/// let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<_, Axes2<1, 0>>();
/// assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
/// ```
pub trait PermuteTo: HasErr + HasShape {
/// Permutes the tensor.
fn permute<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
where
Self::Shape: PermuteShapeTo<Dst, Ax>,