Skip to content

Commit

Permalink
#55 reshape, and #87 additional work on nightly feature
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 20, 2022
1 parent d5a19df commit 148826c
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cargo-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: nightly
override: true
- uses: actions-rs/cargo@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/cargo-doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: nightly
override: true
- uses: actions-rs/cargo@v1
with:
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/cargo-fmt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: nightly
override: true
components: rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/cargo-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: nightly
override: true
components: rustfmt, clippy
- uses: actions-rs/cargo@v1
with:
command: test
command: +nightly test
2 changes: 1 addition & 1 deletion .github/workflows/clippy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: nightly
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ libc = { version = "0.2", optional = true }

[features]
default = []
nightly = []
cblas = ["dep:cblas-sys", "dep:libc"]
mkl-static-iomp = ["cblas"]
mkl-static-seq = ["cblas"]
Expand Down
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ fn main() -> Result<(), BuildError> {

// If on nightly, enable "nightly" feature
if version_meta().unwrap().channel == Channel::Nightly {
println!("cargo:rustc-cfg=nightly");
println!("cargo:rustc-cfg=feature=\"nightly\"");
}

Ok(())
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![cfg_attr(feature = "nightly", feature(generic_const_exprs))]

//! Ergonomics & safety focused deep learning in Rust. Main features include:
//! 1. Const generic tensor library with tensors up to 4d!
//! 2. A large library of tensor operations (matrix multiplication, arithmetic, activation functions, etc).
Expand Down
3 changes: 3 additions & 0 deletions src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,6 @@ pub use impl_sum::*;
pub use impl_sum_last::*;
pub use map::*;
pub use matmul::*;

#[cfg(feature = "nightly")]
mod reshape;
71 changes: 71 additions & 0 deletions src/tensor_ops/reshape.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use super::utils::move_tape_and_add_backward_op;
use crate::prelude::*;

/// Reshapes `t` into a differently shaped tensor. [T] and [R] must have the same
/// number of elements.
///
/// # Example
/// ```rust
/// # use dfdx::prelude::*;
/// let t = Tensor1D::new([1.0, 2.0, 3.0, 4.0]);
/// let r: Tensor2D<2, 2> = reshape(t);
/// assert_eq!(r.data(), &[[1.0, 2.0], [3.0, 4.0]]);
/// ```
pub fn reshape<T, R>(t: T) -> R
where
T: Tensor<Dtype = f32>,
R: Tensor<Tape = T::Tape, Dtype = f32>,
ConstEq<{ T::Array::NUM_ELEMENTS }, { R::Array::NUM_ELEMENTS }>: ConstTrue,
ConstEq<{ R::Array::NUM_ELEMENTS }, { T::Array::NUM_ELEMENTS }>: ConstTrue,
{
let mut result: R::NoTape = R::NoTape::zeros();
copy(t.data(), result.mut_data());
move_tape_and_add_backward_op(t, result, move |mut t, result, grads| {
let (t_grad, result_grad) = grads.mut_and_ref(&t, &result);
copy(result_grad, t.mut_data());
T::Device::add(t_grad, t.data());
})
}

fn copy<Lhs: CountElements, Rhs: CountElements<Dtype = Lhs::Dtype>>(lhs: &Lhs, rhs: &mut Rhs)
where
ConstEq<{ Lhs::NUM_ELEMENTS }, { Rhs::NUM_ELEMENTS }>: ConstTrue,
{
let l = lhs.ref_first_elem() as *const Lhs::Dtype;
let r = rhs.mut_first_elem() as *mut Lhs::Dtype;
unsafe {
std::ptr::copy_nonoverlapping(l, r, Lhs::NUM_ELEMENTS);
}
}

pub trait ConstTrue {}

pub struct ConstEq<const A: usize, const B: usize>;
impl<const N: usize> ConstTrue for ConstEq<N, N> {}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_0d_reshape() {
let a = Tensor0D::new(3.14);
let b: Tensor1D<1> = reshape(a.duplicate());
assert_eq!(b.data(), &[3.14]);

let c: Tensor2D<1, 1> = reshape(a.duplicate());
assert_eq!(c.data(), &[[3.14]]);
}

#[test]
fn test_1d_reshape() {
let a = Tensor1D::new([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
let b: Tensor2D<2, 3, OwnedTape> = reshape(a.trace());
assert_eq!(b.data(), &[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]);
let gradients = b.exp().mean().backward();
assert_eq!(
gradients.ref_gradient(&a),
&[0.18419516, 0.20356713, 0.22497648, 0.24863747, 0.2747869, 0.3036865]
)
}
}

0 comments on commit 148826c

Please sign in to comment.