Adds Storage and Gradient view/mutating methods; adds grads clamping and clipping

- Added `dfdx::nn_traits::WithGrads` trait and `dfdx_derives::WithGrads` proc macro, based on `ZeroGrads`.
- Added `dfdx_core::tensor::WithStorage` trait.
- Changed some methods of `Gradients`:
  - Exposed `get_mut` as `pub`.
  - Exposed `get_ref` as `pub` and relaxed its receiver from `&mut self` to `&self`.
- Added gradient clamping and clipping methods (see the usage sketch below).
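
A minimal usage sketch of the new clamping/clipping helpers, assuming a `model` built with dfdx, an already-computed `loss`, and `f32` gradients (the `model` and `loss` names are illustrative, not part of this commit):

```rust
// Illustrative only: `model` and `loss` are assumed to already exist.
let mut grads = loss.backward();

// Element-wise clipping: force every gradient value into [-1.0, 1.0].
// (This may change the gradient direction.)
model.grads_clip_value(&mut grads, 1.0);

// Norm-based clipping: accumulate the squared global norm across all
// parameters, then rescale so the norm does not exceed 1.0.
// (This preserves the gradient direction.)
let mut norm_sq = 0.0f32;
model.grads_norm_squared(&grads, &mut norm_sq);
model.grads_clip_norm(&mut grads, norm_sq.sqrt(), 1.0);
```
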
swfsql committed Dec 14, 2023
1 parent 4615ac1 commit 81f644d
Showing 31 changed files with 536 additions and 28 deletions.
129 changes: 129 additions & 0 deletions dfdx-core/src/nn_traits/mod.rs
@@ -113,6 +113,135 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
}
}

/// Something that can view or mutate a [Gradients] object.
pub trait WithGrads<E: Dtype, D: Device<E>> {
/// View the gradient values for each parameter.
fn grads_element_view<F: FnMut(&E)>(&self, grads: &Gradients<E, D>, f: F) {
self.try_grads_element_view(grads, f).unwrap()
}
/// View the gradient values for each parameter.
fn try_grads_element_view<F: FnMut(&E)>(
&self,
grads: &Gradients<E, D>,
f: F,
) -> Result<(), Error>;
/// View the gradient values for each tensor (unique id).
fn grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) {
self.try_grads_view(grads, f).unwrap()
}
/// View the gradient values for each tensor (unique id).
fn try_grads_view<F: FnMut(&[E])>(&self, grads: &Gradients<E, D>, f: F) -> Result<(), Error>;
/// Mutate the gradient values for each parameter.
fn grads_element_map<F: FnMut(E) -> E>(&self, grads: &mut Gradients<E, D>, f: F) {
self.try_grads_element_map(grads, f).unwrap()
}
/// Mutate the gradient values for each parameter.
fn try_grads_element_map<F: FnMut(E) -> E>(
&self,
grads: &mut Gradients<E, D>,
f: F,
) -> Result<(), crate::tensor::Error>;
/// Mutate the gradient values for each tensor (unique id).
fn grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(&self, grads: &mut Gradients<E, D>, f: F) {
self.try_grads_map(grads, f).unwrap()
}
/// Mutate the gradient values for each tensor (unique id).
fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
grads: &mut Gradients<E, D>,
f: F,
) -> Result<(), crate::tensor::Error>;
/// Changes the gradient values for each parameter to be between `min` and `max`.
///
/// Note that this may change the "direction" of your gradients.
fn grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E)
where
E: std::cmp::PartialOrd + Clone,
{
self.try_grads_clamp(grads, min, max).unwrap()
}
/// Changes the gradient values for each parameter to be between `min` and `max`.
///
/// Note that this may change the "direction" of your gradients.
fn try_grads_clamp(&self, grads: &mut Gradients<E, D>, min: E, max: E) -> Result<(), Error>
where
E: std::cmp::PartialOrd + Clone,
{
self.try_grads_element_map(grads, |e| {
if e < min {
min
} else if e > max {
max
} else {
e
}
})
}
/// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
///
/// Note that this may change the "direction" of your gradients.
fn grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E)
where
E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
{
self.try_grads_clip_value(grads, threshold).unwrap()
}
/// Changes the gradient values for each parameter to be between `-threshold` and `+threshold`.
///
/// Note that this may change the "direction" of your gradients.
fn try_grads_clip_value(&self, grads: &mut Gradients<E, D>, threshold: E) -> Result<(), Error>
where
E: std::cmp::PartialOrd + std::ops::Neg<Output = E> + Clone,
{
self.try_grads_clamp(grads, -threshold, threshold)
}
/// Accumulates into `acc` the squared value for the gradients.
///
/// After the accumulation, taking the sqrt of `acc` results in the gradients norm.
fn grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E)
where
E: num_traits::Zero + std::ops::Mul<Output = E> + num_traits::Float,
{
self.try_grads_norm_squared(grads, acc).unwrap()
}
/// Accumulates into `acc` the squared value for the gradients.
///
/// After the accumulation, taking the sqrt of `acc` results in the gradients norm.
fn try_grads_norm_squared(&self, grads: &Gradients<E, D>, acc: &mut E) -> Result<(), Error>
where
E: std::ops::Mul<Output = E> + num_traits::Float,
{
self.try_grads_element_view(grads, |e| *acc += *e * *e)
}
/// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`.
///
/// Note that this doesn't change the "direction" of your gradients.
fn grads_clip_norm(&self, grads: &mut Gradients<E, D>, norm: E, norm_threshold: E)
where
E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
{
self.try_grads_clip_norm(grads, norm, norm_threshold)
.unwrap()
}
/// Given a `norm` for all of the gradient values, scales down all gradients so their norm is not higher than `norm_threshold`.
///
/// Note that this doesn't change the "direction" of your gradients.
fn try_grads_clip_norm(
&self,
grads: &mut Gradients<E, D>,
norm: E,
norm_threshold: E,
) -> Result<(), Error>
where
E: Clone + std::cmp::PartialOrd + std::ops::Mul<Output = E> + std::ops::Div<Output = E>,
{
if norm > norm_threshold {
self.try_grads_element_map(grads, |e| norm_threshold * e / norm)?
}
Ok(())
}
}
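
The two-step norm clipping above (accumulate squared elements, then scale every element by `norm_threshold / norm` once the norm exceeds the threshold) can be illustrated on a plain slice. This is a standalone sketch mirroring the default method bodies, not code from this commit:

```rust
// Standalone illustration of the math behind `try_grads_norm_squared` and
// `try_grads_clip_norm`; a plain `&mut [f32]` stands in for gradient storage.
fn clip_norm(grads: &mut [f32], norm_threshold: f32) {
    // Same accumulation as `try_grads_norm_squared`, followed by a sqrt.
    let norm = grads.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > norm_threshold {
        for g in grads.iter_mut() {
            // Scaling by `norm_threshold / norm` shrinks the vector's length
            // to exactly `norm_threshold` without changing its direction.
            *g = norm_threshold * *g / norm;
        }
    }
}

fn main() {
    let mut g = [3.0f32, 4.0]; // norm = 5.0
    clip_norm(&mut g, 1.0);
    assert!((g[0] - 0.6).abs() < 1e-6 && (g[1] - 0.8).abs() < 1e-6);
}
```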

#[cfg(feature = "safetensors")]
/// Something that can be saved to a .safetensors file.
pub trait SaveSafeTensors {
19 changes: 19 additions & 0 deletions dfdx-core/src/nn_traits/tuples.rs
@@ -67,6 +67,25 @@ macro_rules! tuple_impls {
}
}

impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::WithGrads<Elem, Dev>),+> crate::nn_traits::WithGrads<Elem, Dev> for ($($name,)+) {
fn try_grads_element_view<F: FnMut(&Elem)>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_element_view(grads, &mut f)?;)+
Ok(())
}
fn try_grads_view<F: FnMut(&[Elem])>(&self, grads: &crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_view(grads, &mut f)?;)+
Ok(())
}
fn try_grads_element_map<F: FnMut(Elem) -> Elem>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_element_map(grads, &mut f)?;)+
Ok(())
}
fn try_grads_map<F: FnMut(Vec<Elem>) -> Option<Vec<Elem>>>(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>, mut f: F) -> Result<(), Error> {
$(self.$idx.try_grads_map(grads, &mut f)?;)+
Ok(())
}
}

/*This macro expands like this for a 4-tuple:
impl<
45 changes: 45 additions & 0 deletions dfdx-core/src/nn_traits/vecs.rs
@@ -58,6 +58,51 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_tra
}
}

impl<E: Dtype, D: Device<E>, T: crate::nn_traits::WithGrads<E, D>> crate::nn_traits::WithGrads<E, D>
for Vec<T>
{
fn try_grads_element_view<F: FnMut(&E)>(
&self,
grads: &crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_element_view(grads, &mut f)?;
}
Ok(())
}
fn try_grads_view<F: FnMut(&[E])>(
&self,
grads: &crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_view(grads, &mut f)?;
}
Ok(())
}
fn try_grads_element_map<F: FnMut(E) -> E>(
&self,
grads: &mut crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_element_map(grads, &mut f)?;
}
Ok(())
}
fn try_grads_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
grads: &mut crate::tensor::Gradients<E, D>,
mut f: F,
) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_grads_map(grads, &mut f)?;
}
Ok(())
}
}

#[cfg(feature = "safetensors")]
impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
fn write_safetensors(
42 changes: 42 additions & 0 deletions dfdx-core/src/tensor/cpu/allocate.rs
@@ -78,6 +78,48 @@ impl<E: Unit> ZeroFillStorage<E> for Cpu {
}
}

impl<E: Unit> WithStorage<E> for Cpu {
/// View the values by each element (in-place).
fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
for e in storage.iter() {
f(e);
}
Ok(())
}
/// View the values by a [Vec] (in-place).
fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
f(storage.data.as_slice());
Ok(())
}
/// Mutates the values by each element (in-place).
fn try_element_map<F: FnMut(E) -> E>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
for e in storage.iter_mut() {
let fe = f(*e);
*e = fe;
}
Ok(())
}
/// Mutates a clone of the values (not in-place).
///
/// If `Some` is returned, replaces the changed values back into the object.
/// Otherwise if `None` is returned, the changed values are discarded and the object stays intact.
fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let storage_copy = storage.data.clone();
if let Some(fstorage) = f(storage_copy) {
storage.data.copy_from_slice(&fstorage);
}
Ok(())
}
}
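
The `Some`/`None` write-back contract of `try_map` can be shown in isolation. A minimal sketch with a plain `Vec<f32>` (not dfdx code): the closure receives a copy of the data, returning `Some` writes the result back, returning `None` discards it:

```rust
// Sketch of the `try_map` write-back contract used above.
fn map_in_place(data: &mut Vec<f32>, mut f: impl FnMut(Vec<f32>) -> Option<Vec<f32>>) {
    let copy = data.clone();
    if let Some(changed) = f(copy) {
        // Replace the original values only when the closure returned `Some`.
        data.copy_from_slice(&changed);
    }
}

fn main() {
    let mut v = vec![1.0f32, -2.0, 3.0];
    // Returning `Some`: the mapped values are written back.
    map_in_place(&mut v, |xs| Some(xs.into_iter().map(|x| x.max(0.0)).collect()));
    assert_eq!(v, vec![1.0, 0.0, 3.0]);
    // Returning `None`: the change is discarded and `v` stays intact.
    map_in_place(&mut v, |_| None);
    assert_eq!(v, vec![1.0, 0.0, 3.0]);
}
```

As in the `Cpu` impl above, the returned `Vec` is assumed to keep the original length.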

impl<E: Unit> OnesTensor<E> for Cpu {
fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
47 changes: 47 additions & 0 deletions dfdx-core/src/tensor/cuda/allocate.rs
@@ -60,6 +60,53 @@ impl<E: Unit> ZeroFillStorage<E> for Cuda {
}
}

impl<E: Unit> WithStorage<E> for Cuda {
/// View a copy of the values by each element (not in-place).
fn try_element_view<F: FnMut(&E)>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
for e in v.iter() {
f(e);
}
Ok(())
}
/// View a copy of the values by a [Vec] (not in-place).
fn try_view<F: FnMut(&[E])>(&self, storage: &Self::Vec, mut f: F) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
f(v.as_slice());
Ok(())
}
/// Mutates a copy of the values by each element (not in-place).
/// Then the values in Cuda memory are replaced by the changed values.
fn try_element_map<F: FnMut(E) -> E>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let mut v = self.dev.dtoh_sync_copy(storage)?;
for e in v.iter_mut() {
let fe = (&mut f)(*e);
*e = fe;
}
self.dev.htod_copy_into(v, storage)?;
Ok(())
}
/// Mutates a copy of the values (not in-place).
///
/// If `Some` is returned, the values in Cuda memory are replaced by the changed values.
/// Otherwise if `None` is returned, the values in Cuda memory are left intact.
fn try_map<F: FnMut(Vec<E>) -> Option<Vec<E>>>(
&self,
storage: &mut Self::Vec,
mut f: F,
) -> Result<(), Error> {
let v = self.dev.dtoh_sync_copy(storage)?;
if let Some(fv) = (&mut f)(v) {
self.dev.htod_copy_into(fv, storage)?;
}
Ok(())
}
}

impl<E: Unit> OnesTensor<E> for Cuda
where
Cpu: OnesTensor<E>,
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor/gradients.rs
@@ -86,14 +86,14 @@ impl<E, D: Storage<E>> Gradients<E, D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec {
pub fn get_mut<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &mut D::Vec {
self.gradient_by_id.get_mut(&t.id()).unwrap()
}

/// Returns an immutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_ref<S: Shape>(&mut self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
pub fn get_ref<S: Shape>(&self, t: &impl Tensorlike<S, E, D>) -> &D::Vec {
self.gradient_by_id.get(&t.id()).unwrap()
}
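
With `get_ref` now public and taking `&self`, a specific tensor's raw gradient buffer can be read without exclusive access to the `Gradients` object. A hypothetical fragment (the `model.weight` and `loss` names are assumed, not part of this diff):

```rust
// Hypothetical: `model.weight` is a parameter tensor and `loss.backward()`
// produced `grads`.
let grads = loss.backward();
let weight_grad = grads.get_ref(&model.weight); // only needs `&Gradients` now
// Mutable access via the newly-public `get_mut` still requires `&mut Gradients`.
```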

2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/mod.rs
@@ -160,7 +160,7 @@ mod tensor_impls;

pub use error::Error;
pub(crate) use ghost::GhostTensor;
pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};
pub(crate) use storage_traits::{OneFillStorage, WithStorage, ZeroFillStorage};
pub use tensorlike::Tensorlike;

pub use cpu::Cpu;