Skip to content

nicksenger/glowstick

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

21 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

glowstick

This crate makes working with tensors in Rust safe, easy, and fun by tracking their shapes in the type system!

Example usage with candle:

use candle::{DType, Device};  
use glowstick::{Shape2, num::{U1, U2}, debug_tensor};
use glowstick_candle::{Tensor, matmul};

let a: Tensor<Shape2<U2, U1>> = Tensor::zeros(DType::F32, &Device::Cpu).expect("tensor A");
let b: Tensor<Shape2<U1, U2>> = Tensor::zeros(DType::F32, &Device::Cpu).expect("tensor B");

let c = matmul!(a, b).expect("matmul");
//debug_tensor!(c); // [glowstick shape]: (RANK<_2>, (DIM<_2>, DIM<_2>))

Several operations are available:

use candle::{DType, Device};  
use glowstick::{num::{U0, U1, U2, U4, U3, U64, U5, U8}, Shape2, Shape4};
use glowstick_candle::{Tensor, conv2d, squeeze, unsqueeze, narrow, reshape, transpose, flatten, broadcast_add};

#[allow(unused)]
use glowstick::debug_tensor;

let my_tensor: Tensor<Shape2<U8, U8>> = Tensor::zeros(DType::F32, &Device::Cpu).expect("tensor");
//debug_tensor!(my_tensor); // [glowstick shape]: (RANK<_2>, (DIM<_8>, DIM<_8>))

let reshaped = reshape!(my_tensor, [U64]).expect("reshape"); 
//debug_tensor!(reshaped); // [glowstick shape]: (RANK<_1>, (DIM<_64>))

let unsqueezed = unsqueeze!(reshaped, U0, U2).expect("unsqueeze");
//debug_tensor!(unsqueezed); // [glowstick shape]: (RANK<_3>, (DIM<_1>, DIM<_64>, DIM<_1>))

let squeezed = squeeze!(unsqueezed, U0, U2).expect("squeeze");
//debug_tensor!(squeezed); // [glowstick shape]: (RANK<_1>, (DIM<_64>))

let narrowed = narrow!(squeezed, U0: [U8, U5]).expect("narrow");
//debug_tensor!(narrowed); // [glowstick shape]: (RANK<_1>, (DIM<_5>))

let expanded = broadcast_add!(Tensor::<Shape4<U2, U5, U2, U1>>::zeros(DType::F32, &Device::Cpu).unwrap(), narrowed).expect("add");
//debug_tensor!(expanded); // [glowstick shape]: (RANK<_4>, (DIM<_2>, DIM<_5>, DIM<_2>, DIM<_5>))

let swapped = transpose!(expanded, U1: U2).expect("swap");
//debug_tensor!(swapped); // [glowstick shape]: (RANK<_2>, (DIM<_2>, DIM<_5>, DIM<_5>))

let kernel: Tensor<Shape4<U4, U2, U3, U3>> = Tensor::zeros(DType::F32, &Device::Cpu).expect("kernel");
let conv = conv2d!(swapped, kernel, U0, U1, U1, 1).expect("conv2d");
//debug_tensor!(conv); // [glowstick shape]: (RANK<_4>, (DIM<_2>, DIM<_4>, DIM<_3>, DIM<_3>))

let flattened = flatten!(conv, [U1, U2]).expect("flatten");
//debug_tensor!(swapped); // [glowstick shape]: (RANK<_3>, (DIM<_2>, DIM<_12>, DIM<_3>))

assert_eq!(flattened.inner().dims(), [2, 12, 3]);

For examples of more extensive usage and integration with popular Rust ML frameworks like candle and Burn, check out the examples directory.

The project is currently pre-1.0: breaking changes will be made!

Features

  • Express tensor shapes as types
  • Support for dynamic dimensions (gradual typing)
  • Human-readable error messages (sort of)
  • Manually check type-level shapes (debug_tensor!(_))
  • Support for all ONNX operations

About

Gradual typing for tensor shapes in Rust

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages