The project is no longer maintained. Development has moved to rust-typed-tensor.
Inspired by Tensor Considered Harmful, this project builds a tensor type with named dimensions, featuring automatic dimension inference and compile-time bounds checking. It is built on top of tch-rs, a Rust binding for PyTorch, and typenum, which provides compile-time numbers.
The project is still in the alpha stage and is not intended for production use. Contributions are welcome!
There is no schedule for publishing on crates.io. Instead, add the git repository to your Cargo.toml:
tch-typed-tensor = { git = "https://github.com/jerry73204/tch-typed-tensor.git", branch = "master" }
The project depends on tch-rs, which requires extra environment setup before cargo build works.
Please read the tch-rs README for details.
The project makes heavy use of trait-based type-level constructions. Before getting started, it is worth looking at Rust HLists (heterogeneous lists) and the frunk project. Familiarity with the PyTorch API also helps, and you may occasionally need to consult the tch-rs reference.
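If HLists are new to you, the following sketch shows the core idea in plain Rust without depending on frunk. The types HNil and HCons and the trait HList are defined locally for illustration (frunk uses similar names, but this is not its API): each element may have a different type, and the list's length is computed from its type at compile time.

```rust
// Empty list.
struct HNil;

// Non-empty list: a head of any type followed by another list.
struct HCons<H, T> {
    head: H,
    tail: T,
}

// The length is an associated constant derived from the type alone.
trait HList {
    const LEN: usize;
}

impl HList for HNil {
    const LEN: usize = 0;
}

impl<H, T: HList> HList for HCons<H, T> {
    const LEN: usize = 1 + T::LEN;
}

// Reads the length without inspecting any runtime value.
fn len<L: HList>(_list: &L) -> usize {
    L::LEN
}

fn main() {
    // The type is HCons<u8, HCons<&str, HCons<f64, HNil>>>.
    let list = HCons { head: 1u8, tail: HCons { head: "two", tail: HCons { head: 3.0f64, tail: HNil } } };
    println!("length = {}, second = {}", len(&list), list.tail.head); // prints "length = 3, second = two"
}
```

The same recursive pattern (a base impl for the empty list plus an impl for head and tail) underlies the type-level dimension lists used throughout this crate.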
The tensor type design moves most tensor properties, including dimensions, data type and device, into the type system. As a result, tensor operations are type-checked, and the compiler can infer result types automatically.
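To illustrate what moving properties into types buys, here is a simplified, standard-library-only sketch. The marker types Float, Double, Cpu and Cuda and the Tensor struct are hypothetical stand-ins, not this crate's definitions: kind and device exist only as PhantomData, so passing tensors of different kinds or devices to add() is rejected by the type checker, and conversions such as to_kind() return a tensor of a different type. For simplicity the toy stores every element as f64 regardless of its kind.

```rust
use std::marker::PhantomData;

// Hypothetical marker types standing in for data types and devices.
struct Float;
struct Double;
struct Cpu;
struct Cuda;

trait Kind {
    const NAME: &'static str;
}
impl Kind for Float {
    const NAME: &'static str = "Float";
}
impl Kind for Double {
    const NAME: &'static str = "Double";
}

trait Device {
    const NAME: &'static str;
}
impl Device for Cpu {
    const NAME: &'static str = "Cpu";
}
impl Device for Cuda {
    const NAME: &'static str = "Cuda";
}

// Toy tensor: values live at runtime, while kind and device exist only in the type.
struct Tensor<K, D> {
    data: Vec<f64>,
    _marker: PhantomData<(K, D)>,
}

impl<K: Kind, D: Device> Tensor<K, D> {
    fn zeros(len: usize) -> Self {
        Tensor { data: vec![0.0; len], _marker: PhantomData }
    }

    // Changing the kind or device yields a value of a different type.
    fn to_kind<K2: Kind>(self) -> Tensor<K2, D> {
        Tensor { data: self.data, _marker: PhantomData }
    }

    fn to_device<D2: Device>(self) -> Tensor<K, D2> {
        Tensor { data: self.data, _marker: PhantomData }
    }

    fn describe(&self) -> String {
        format!("{} tensor on {} with {} elements", K::NAME, D::NAME, self.data.len())
    }
}

// Both arguments must share the same kind and device to type-check.
fn add<K: Kind, D: Device>(a: &Tensor<K, D>, b: &Tensor<K, D>) -> Tensor<K, D> {
    Tensor {
        data: a.data.iter().zip(&b.data).map(|(x, y)| x + y).collect(),
        _marker: PhantomData,
    }
}

fn main() {
    let x = Tensor::<Float, Cpu>::zeros(4);
    let y = x.to_kind::<Double>().to_device::<Cuda>();
    println!("{}", y.describe()); // prints "Double tensor on Cuda with 4 elements"
    let z = add(&y, &Tensor::<Double, Cuda>::zeros(4));
    println!("{}", z.describe());
    // add(&y, &Tensor::<Float, Cpu>::zeros(4)); // rejected: kinds and devices differ
}
```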
Dimensions are named types defined by the make_dims! macro, rather than integer ordinals.
use tch_typed_tensor::{
    make_dims,
    DimListType,
    tensor::NamedTensor,
    kind::{Double, Float},
    device::{Cpu, Cuda},
};
use typenum::consts::*;
// make_dims! macro defines a list of dimension names
make_dims! {Batch, Channel, Height, Width}
fn main() {
    // Creates a float tensor with shape [32, 3, 480, 640] on the CPU
    let cpu_tensor = NamedTensor::<
        DimListType! {(Batch, U32), (Channel, U3), (Height, U480), (Width, U640)}, // dimensions
        Float, // data type
        Cpu    // device
    >::zeros();
    // Converts the data type; dimensions and device stay the same
    let double_tensor: NamedTensor<_, Double, _> = cpu_tensor.to_kind::<Double>();
    // Moves the tensor to the first CUDA device
    let cuda_tensor: NamedTensor<_, _, Cuda<U0>> = cpu_tensor.to_device::<Cuda<U0>>();
}
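Conceptually, a dimension name is just a zero-sized type, and a dimension list pairs each name with a compile-time size. The sketch below is a hypothetical, self-contained re-implementation of that idea: its make_dims! macro, Dim, Cons/Nil and the Shape trait are illustrative and are not the crate's actual definitions, and it uses const generics where the crate uses typenum.

```rust
use std::marker::PhantomData;

// Every dimension name knows how to print itself.
trait DimName {
    const NAME: &'static str;
}

// Hypothetical re-implementation of make_dims!: one zero-sized type per name.
macro_rules! make_dims {
    ($($name:ident),* $(,)?) => {
        $(
            struct $name;
            impl DimName for $name {
                const NAME: &'static str = stringify!($name);
            }
        )*
    };
}

make_dims! {Batch, Channel, Height, Width}

// A (name, size) pair; const generics stand in for typenum sizes such as U32.
struct Dim<Name, const SIZE: usize>(PhantomData<Name>);

// Type-level list of dimensions.
struct Nil;
struct Cons<H, T>(PhantomData<(H, T)>);

// Recovers the shape at runtime purely from the type.
trait Shape {
    fn shape() -> Vec<(&'static str, usize)>;
}

impl Shape for Nil {
    fn shape() -> Vec<(&'static str, usize)> {
        Vec::new()
    }
}

impl<N: DimName, T: Shape, const SIZE: usize> Shape for Cons<Dim<N, SIZE>, T> {
    fn shape() -> Vec<(&'static str, usize)> {
        let mut dims = vec![(N::NAME, SIZE)];
        dims.extend(T::shape());
        dims
    }
}

// Plays the role of DimListType! {(Batch, U32), (Channel, U3), (Height, U480), (Width, U640)}.
type Bchw = Cons<Dim<Batch, 32>, Cons<Dim<Channel, 3>, Cons<Dim<Height, 480>, Cons<Dim<Width, 640>, Nil>>>>;

fn main() {
    // prints [("Batch", 32), ("Channel", 3), ("Height", 480), ("Width", 640)]
    println!("{:?}", <Bchw as Shape>::shape());
}
```

Because names are distinct types, swapping Height and Width produces a different type, which an integer axis index could not catch.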
The type design keeps bounds checking in mind. For example, it verifies at compile time that the index passed to select() is within the bounds of the selected dimension; an out-of-range index is a compile error.
let tensor = NamedTensor::<
DimListType! {(Batch, U32), (Channel, U3), (Height, U480), (Width, U640)},
Double,
Cpu
>::zeros();
// The return type is automatically inferred
let sub1: NamedTensor<
DimListType! {(Batch, _), (Height, _), (Width, _)},
_,
_
> = tensor.select::<U1, Channel, _>();
// This is more compact syntax
let sub2 = tensor.select::<U1, Channel, _>();
// Compile error: Channel has size U3, so valid indices are U0 to U2
// let sub3 = tensor.select::<U3, Channel, _>(); // compile error!
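One way to get a compile-time bound check in plain Rust is a const generic index combined with an associated-constant assertion. This is a sketch under stated assumptions rather than how the crate implements select() (the crate uses typenum): the hypothetical InBounds check only evaluates successfully when I < N. Because the check fires when select::<I> is instantiated, the error may only be reported by a full cargo build rather than cargo check.

```rust
// Compile-time bounds check: evaluating OK fails unless I < N.
struct InBounds<const I: usize, const N: usize>;

impl<const I: usize, const N: usize> InBounds<I, N> {
    const OK: () = assert!(I < N, "select index out of bounds");
}

// A toy one-dimensional tensor whose length N is part of its type.
struct Vector<const N: usize> {
    data: [f32; N],
}

impl<const N: usize> Vector<N> {
    // The index is a const generic, mirroring select::<U1, Channel, _>().
    fn select<const I: usize>(&self) -> f32 {
        // Referencing the constant forces the check for this particular I.
        let () = InBounds::<I, N>::OK;
        self.data[I]
    }
}

fn main() {
    let v = Vector { data: [0.5, 1.5, 2.5] };
    println!("{}", v.select::<1>()); // prints "1.5"
    // v.select::<3>(); // build error: select index out of bounds
}
```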
Dimensions are automatically inferred in any tensor operation, so there is no need to specify the returned dimensions explicitly.
The return type can be omitted, or partially specified, e.g.
DimListType! {(Batch, _), (Height, _), (Width, _), (Channel, _)},
in which case it works as a static assertion.
let bchw_tensor = NamedTensor::<
DimListType! {(Batch, U32), (Channel, U3), (Height, U480), (Width, U640)},
Double,
Cpu
>::zeros();
// Change order of dimensions
let bhwc_tensor1: NamedTensor<
DimListType! {(Batch, _), (Height, _), (Width, _), (Channel, _)},
_,
_
> = bchw_tensor.transpose::<TListType! {Batch, Height, Width, Channel}, _>();
// Or use more compact syntax instead
let bhwc_tensor2 = bchw_tensor.transpose::<TListType! {Batch, Height, Width, Channel}, _>();
// Compile error if you miss a dimension here.
// let _ = bchw_tensor.transpose::<TListType! {Batch, Height, Width}, _>(); // compile error!
// Dimension inference also works for reduction operations
let sum_tensor: NamedTensor<
DimListType! {(Batch, U32), (Channel, U3)},
_,
_
> = bhwc_tensor1.sum::<NoKeepDim, Double, TListType! {Width, Height}, _>();
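The inference above can be modelled with type-level lists, in the style of frunk's Plucker and Sculptor traits. In this self-contained sketch (all names are hypothetical; the crate's real traits differ), Remove deletes a dimension while the compiler infers its position, which is what a reduction without keepdim needs, and Permute reorders the list the way transpose does. A permutation that misses a dimension leaves a non-empty remainder and fails to compile, matching the commented-out transpose above.

```rust
use std::marker::PhantomData;

trait DimName {
    const NAME: &'static str;
}

// One zero-sized type per dimension name.
macro_rules! dims {
    ($($name:ident),*) => {
        $(struct $name; impl DimName for $name { const NAME: &'static str = stringify!($name); })*
    };
}
dims!(Batch, Channel, Height, Width);

// Type-level list of dimension names.
struct Nil;
struct Cons<H, T>(PhantomData<(H, T)>);

// Position witnesses: Here is the head, There<I> is position I of the tail.
struct Here;
struct There<I>(PhantomData<I>);

// Removes Target from a list; the compiler infers the position Index.
trait Remove<Target, Index> {
    type Output;
}
impl<Target, Tail> Remove<Target, Here> for Cons<Target, Tail> {
    type Output = Tail;
}
impl<Target, Head, Tail, I> Remove<Target, There<I>> for Cons<Head, Tail>
where
    Tail: Remove<Target, I>,
{
    type Output = Cons<Head, <Tail as Remove<Target, I>>::Output>;
}

// Reorders a list into Order; type-checks only if Order names every dimension exactly once.
trait Permute<Order, Indices> {
    type Output;
}
impl Permute<Nil, Nil> for Nil {
    type Output = Nil;
}
impl<L, D, Rest, I, Is> Permute<Cons<D, Rest>, Cons<I, Is>> for L
where
    L: Remove<D, I>,
    <L as Remove<D, I>>::Output: Permute<Rest, Is>,
{
    type Output = Cons<D, <<L as Remove<D, I>>::Output as Permute<Rest, Is>>::Output>;
}

// Recovers the dimension names from the type.
trait Names {
    fn names() -> Vec<&'static str>;
}
impl Names for Nil {
    fn names() -> Vec<&'static str> { Vec::new() }
}
impl<H: DimName, T: Names> Names for Cons<H, T> {
    fn names() -> Vec<&'static str> {
        let mut names = vec![H::NAME];
        names.extend(T::names());
        names
    }
}

// A dimensions-only stand-in for a tensor: operations change only its type.
struct Dims<L>(PhantomData<L>);

impl<L> Dims<L> {
    fn new() -> Self { Dims(PhantomData) }
    // Like a reduction without keepdim: the output type drops Target.
    fn reduce<Target, I>(self) -> Dims<<L as Remove<Target, I>>::Output>
    where L: Remove<Target, I> { Dims(PhantomData) }
    // Like transpose: the output type follows Order.
    fn permute<Order, Is>(self) -> Dims<<L as Permute<Order, Is>>::Output>
    where L: Permute<Order, Is> { Dims(PhantomData) }
    fn names(&self) -> Vec<&'static str> where L: Names { L::names() }
}

type Bchw = Cons<Batch, Cons<Channel, Cons<Height, Cons<Width, Nil>>>>;
type Bhwc = Cons<Batch, Cons<Height, Cons<Width, Cons<Channel, Nil>>>>;

fn main() {
    // The annotation is optional; written out, it acts as a static assertion.
    let bhwc: Dims<Bhwc> = Dims::<Bchw>::new().permute::<Bhwc, _>();
    // Dims::<Bchw>::new().permute::<Cons<Batch, Nil>, _>(); // compile error: dimensions left over
    let reduced = bhwc.reduce::<Width, _>().reduce::<Height, _>();
    println!("{:?}", reduced.names()); // prints ["Batch", "Channel"]
}
```

Inference works because, for a list with distinct names, only one chain of Here/There witnesses satisfies each Remove bound, so the `_` placeholders are resolved uniquely.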
License: Apache 2.0