Re-using tensor storage when possible (#664)
* Tmp commit of phantom tensors

* Fixing all kernels to use ghost tensors

* Adding backward without data

* Adding forward_reuse/backward_without_data for unary ops

* binary kernel reuse

* Reusing more for unary op

* Removing reference to output data

* Fixing nightly ops

* Refactor to use Result in ops

* Cuda check passing

* Marking kernels as const/df_uses_fx

* Updates

* Not saving inp for sum_to

* Updates recip

* Style

* Preferring to use RHS when possible in cpu kernel

* Preferring RHS for cuda as well
coreylowman committed Apr 5, 2023
1 parent 452d1d8 commit 3ec1042
Showing 78 changed files with 880 additions and 422 deletions.
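
The thread running through these commits is avoiding fresh allocations: when a kernel's output can live in a buffer the graph already owns (for instance, when the input's storage has no other owners), the result is written into that buffer instead of newly allocated memory. A minimal sketch of that reuse idea, using plain `Arc` reference counting rather than dfdx's actual kernel machinery (the function and types here are illustrative only):

```rust
use std::sync::Arc;

// Minimal sketch of the reuse idea (illustrative only, not dfdx's kernels):
// if this op is the sole owner of the input buffer, map it in place;
// otherwise fall back to allocating fresh storage.
fn unary_forward_reuse(inp: Arc<Vec<f32>>, f: impl Fn(f32) -> f32) -> Arc<Vec<f32>> {
    match Arc::try_unwrap(inp) {
        // Sole owner: take over the allocation and write the output into it.
        Ok(mut buf) => {
            for x in buf.iter_mut() {
                *x = f(*x);
            }
            Arc::new(buf)
        }
        // Still shared (e.g. another tensor or the tape holds it): copy out.
        Err(shared) => Arc::new(shared.iter().copied().map(&f).collect()),
    }
}

fn main() {
    let x = Arc::new(vec![1.0f32, -2.0, 3.0]);
    // `x` has a single owner here, so its buffer is reused for the output.
    let y = unary_forward_reuse(x, f32::abs);
    assert_eq!(y.as_slice(), &[1.0, 2.0, 3.0]);
}
```

The diffs below apply the same idea inside the unary and binary kernels, preferring to take over the LHS or RHS buffer when it is not referenced anywhere else.
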
4 changes: 2 additions & 2 deletions benches/batchnorm2d.rs
@@ -21,10 +21,10 @@ fn main() {

let dev: Dev = Default::default();
let mut m = dev.build_module::<Model, Dtype>();
let mut grads = m.alloc_grads();

loop {
let img: Tensor<InputShape, Dtype, _> = dev.sample_normal();
let grads = m.alloc_grads();

let start = Instant::now();
let out = m.forward_mut(img.traced(grads));
@@ -33,7 +33,7 @@ fn main() {
let fwd_dur = start.elapsed();

let start = Instant::now();
let _ = loss.backward();
grads = loss.backward();
dev.synchronize();
let bwd_dur = start.elapsed();
println!("fwd={:?} bwd={:?}", fwd_dur, bwd_dur);
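
The benchmark change reflects the companion API shift: the gradient buffer is allocated once before the loop, and `backward()` now returns it so the same allocation is reused every iteration. A toy round-trip of that pattern, with hypothetical types rather than dfdx's `Gradients`:

```rust
// Toy illustration (not dfdx's API) of the pattern the benchmark switches to:
// the gradient buffer is created once, moved into each step, and handed back
// by backward() so the same allocation is reused every iteration.
struct Grads(Vec<f32>);

fn backward(mut grads: Grads, grad_out: &[f32]) -> Grads {
    // Accumulate into the existing buffer instead of allocating a new one.
    for (g, go) in grads.0.iter_mut().zip(grad_out) {
        *g += *go;
    }
    grads
}

fn main() {
    let mut grads = Grads(vec![0.0; 4]); // allocated once, outside the loop
    for step in 0..3 {
        let grad_out = [1.0f32; 4];
        grads = backward(grads, &grad_out); // the same buffer comes back
        println!("step {step}: {:?}", grads.0);
    }
}
```
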
2 changes: 1 addition & 1 deletion examples/model-transformer.rs
@@ -1,7 +1,7 @@
//! Demonstrates how to use a transformer module on nightly rust.

fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;

let dev = AutoDevice::default();
type Model = Transformer<16, 4, 3, 3, 8>;
2 changes: 1 addition & 1 deletion examples/nightly-conv-net.rs
@@ -4,7 +4,7 @@

#[cfg(feature = "nightly")]
fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;

type Model = (
(Conv2D<3, 4, 3>, ReLU),
2 changes: 1 addition & 1 deletion examples/nightly-resnet18.rs
@@ -2,7 +2,7 @@

#[cfg(feature = "nightly")]
fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;
use std::time::Instant;

type BasicBlock<const C: usize> = Residual<(
12 changes: 4 additions & 8 deletions src/nn/linear.rs
@@ -129,10 +129,8 @@ impl<'a, B: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
type Error = D::Err;

fn try_forward(&self, input: Tensor<(B, Const<M>), E, D, T>) -> Result<Self::Output, D::Err> {
self.beta
.retaped::<T>()
.try_broadcast_like(input.shape())?
.try_add(input)
let shape = *input.shape();
input.try_add(self.beta.retaped::<T>().try_broadcast_like(&shape)?)
}
}

@@ -146,10 +144,8 @@ impl<'a, B: Dim, S: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
&self,
input: Tensor<(B, S, Const<M>), E, D, T>,
) -> Result<Self::Output, D::Err> {
self.beta
.retaped::<T>()
.try_broadcast_like(input.shape())?
.try_add(input)
let shape = *input.shape();
input.try_add(self.beta.retaped::<T>().try_broadcast_like(&shape)?)
}
}

8 changes: 6 additions & 2 deletions src/tensor/cpu/device.rs
@@ -63,8 +63,8 @@ impl HasErr for Cpu {
impl DeviceStorage for Cpu {
type Vec<E: Unit> = Vec<E>;

fn try_alloc_grad<E: Unit>(&self, other: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_zeros(other.len())
fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_zeros(len)
}

fn random_u64(&self) -> u64 {
@@ -78,6 +78,10 @@ impl DeviceStorage for Cpu {
}
}

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize {
v.len()
}

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
let mut buf = Vec::with_capacity(tensor.shape.num_elements());
let mut iter = tensor.iter();
9 changes: 6 additions & 3 deletions src/tensor/cuda/device.rs
@@ -116,15 +116,18 @@ impl HasErr for Cuda {
impl DeviceStorage for Cuda {
type Vec<E: Unit> = CudaSlice<E>;

fn try_alloc_grad<E: Unit>(&self, other: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
let grad = self.dev.alloc_zeros(other.len())?;
Ok(grad)
fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err> {
Ok(self.dev.alloc_zeros(len)?)
}

fn random_u64(&self) -> u64 {
self.cpu.random_u64()
}

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize {
v.len()
}

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
let buf: Vec<E> = tensor.data.try_clone().unwrap().try_into().unwrap();
debug_assert_eq!(buf.len(), tensor.data.len());
43 changes: 43 additions & 0 deletions src/tensor/ghost.rs
@@ -0,0 +1,43 @@
use crate::{shapes::*, tensor::*};

/// Holds all the information a [Tensor] does, except without
/// holding a reference to the data storage.
///
/// This can help reduce memory usage by decreasing the reference
/// count on tensor data, meaning the data can be re-used more often.
pub struct GhostTensor<S: Shape, E: Unit, D: DeviceStorage> {
pub(crate) id: UniqueId,
pub(crate) len: usize,
pub(crate) shape: S,
pub(crate) strides: S::Concrete,
pub(crate) dev: D,
marker: std::marker::PhantomData<E>,
}

impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensor<S, E, D, T> {
/// Creates a ghost tensor that doesn't hold a reference
/// to the tensor's data.
pub(crate) fn ghost(&self) -> GhostTensor<S, E, D> {
GhostTensor {
id: self.id,
len: self.device.len(&self.data),
shape: self.shape,
strides: self.strides,
dev: self.device.clone(),
marker: std::marker::PhantomData,
}
}
}

impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::HasErr for GhostTensor<S, E, D> {
type Err = D::Err;
}

impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::AllocGrad
for GhostTensor<S, E, D>
{
type Gradient = D::Vec<E>;
fn try_alloc_grad(&self) -> Result<Self::Gradient, D::Err> {
self.dev.try_alloc_len(self.len)
}
}
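
Because a `GhostTensor` records only the id, length, shape, and strides, a backward closure that merely needs to allocate or look up a gradient can capture the ghost instead of a clone of the tensor. That keeps the strong count on the tensor's data at one, which is what allows a later kernel to take over the buffer. A heavily simplified, self-contained illustration with hypothetical types (not the crate's):

```rust
use std::collections::BTreeMap;
use std::sync::Arc;

// Hypothetical, heavily simplified stand-ins for the real types, just to show
// why a backward closure should capture a ghost instead of the tensor itself.
#[derive(Clone)]
struct Tensor {
    id: usize,
    data: Arc<Vec<f32>>,
}

struct Ghost {
    id: usize,
    len: usize,
}

impl Tensor {
    fn ghost(&self) -> Ghost {
        Ghost { id: self.id, len: self.data.len() }
    }
}

fn main() {
    let out = Tensor { id: 7, data: Arc::new(vec![0.0; 4]) };

    // Capturing `out.clone()` here would hold a second strong reference to
    // `out.data` until backprop runs, blocking in-place reuse of that buffer.
    let out_ghost = out.ghost();
    let backward_op = move |grads: &mut BTreeMap<usize, Vec<f32>>| {
        // The id and length are all that is needed to allocate the gradient.
        grads.entry(out_ghost.id).or_insert_with(|| vec![0.0; out_ghost.len]);
    };

    // The data still has exactly one owner, so a later kernel could reuse it.
    assert_eq!(Arc::strong_count(&out.data), 1);

    let mut grads = BTreeMap::new();
    backward_op(&mut grads);
    assert_eq!(grads[&7].len(), 4);
}
```
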
29 changes: 17 additions & 12 deletions src/tensor/gradients.rs
@@ -4,6 +4,7 @@
use std::collections::{BTreeMap, BTreeSet};
use std::{boxed::Box, vec::Vec};

use super::ghost::GhostTensor;
use super::{
storage_traits::{AllocGrad, DeviceStorage},
unique_id, Tensor, UniqueId,
@@ -50,12 +51,16 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
&mut self,
t: &Tensor<S, E, D>,
) -> Result<&mut D::Vec<E>, D::Err> {
self.try_alloc_for(t)?;
Ok(self.get_mut(t))
let ghost = t.ghost();
self.try_alloc_for(&ghost)?;
Ok(self.get_mut(&ghost))
}

/// Inserts a gradient for `t`
pub(crate) fn try_alloc_for<S: Shape>(&mut self, t: &Tensor<S, E, D>) -> Result<(), D::Err> {
pub(crate) fn try_alloc_for<S: Shape>(
&mut self,
t: &GhostTensor<S, E, D>,
) -> Result<(), D::Err> {
if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id) {
e.insert(t.try_alloc_grad()?);
}
@@ -89,14 +94,14 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_mut<S: Shape, T>(&mut self, t: &Tensor<S, E, D, T>) -> &mut D::Vec<E> {
pub(crate) fn get_mut<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &mut D::Vec<E> {
self.gradient_by_id.get_mut(&t.id).unwrap()
}

/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_ref<S: Shape, T>(&mut self, t: &Tensor<S, E, D, T>) -> &D::Vec<E> {
pub(crate) fn get_ref<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &D::Vec<E> {
self.gradient_by_id.get(&t.id).unwrap()
}

@@ -123,8 +128,8 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// **Panics** if `l` and `r` have the same id.
pub(crate) fn mut_and_ref<L: Shape, R: Shape>(
&mut self,
l: &Tensor<L, E, D>,
r: &Tensor<R, E, D>,
l: &GhostTensor<L, E, D>,
r: &GhostTensor<R, E, D>,
) -> (&mut D::Vec<E>, &D::Vec<E>) {
assert_ne!(l.id, r.id);
let l_ptr = self.get_mut(l) as *mut _;
@@ -137,9 +142,9 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Borrows a triplet of gradients `(&mut L1, &mut L2, &R)`.
pub(crate) fn muts_and_ref<L1: Shape, L2: Shape, R: Shape>(
&mut self,
l1: &Tensor<L1, E, D>,
l2: &Tensor<L2, E, D>,
r: &Tensor<R, E, D>,
l1: &GhostTensor<L1, E, D>,
l2: &GhostTensor<L2, E, D>,
r: &GhostTensor<R, E, D>,
) -> (&mut D::Vec<E>, &mut D::Vec<E>, &D::Vec<E>) {
assert_ne!(l1.id, l2.id);
assert_ne!(l1.id, r.id);
@@ -156,8 +161,8 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<Tensor<L, E, D>>,
r: &Tensor<R, E, D>,
ls: &Vec<GhostTensor<L, E, D>>,
r: &GhostTensor<R, E, D>,
) -> (Vec<&mut D::Vec<E>>, &D::Vec<E>) {
for i in 0..ls.len() {
assert_ne!(ls[i].id, r.id);
2 changes: 2 additions & 0 deletions src/tensor/mod.rs
@@ -127,6 +127,7 @@
pub(crate) mod cpu;
#[cfg(feature = "cuda")]
pub(crate) mod cuda;
mod ghost;
mod gradients;
mod masks;
#[cfg(feature = "numpy")]
@@ -138,6 +139,7 @@ mod unique_id;
pub(crate) mod storage_traits;
mod tensor_impls;

pub(crate) use ghost::GhostTensor;
pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};

pub use cpu::{Cpu, CpuError};
8 changes: 7 additions & 1 deletion src/tensor/storage_traits.rs
@@ -25,10 +25,16 @@ pub trait DeviceStorage: 'static + std::fmt::Debug + Default + Clone + HasErr {
fn random_u64(&self) -> u64;

/// Allocates a gradient for the given nd array
fn try_alloc_grad<E: Unit>(&self, storage: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err>;
fn try_alloc_grad<E: Unit>(&self, storage: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_len(self.len(storage))
}

fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err>;

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E>;

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize;

/// Blocks until all work on device to complete. Useful for benchmarking.
fn synchronize(&self) {
self.try_synchronize().unwrap()
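
With `try_alloc_grad` turned into a provided method, a backend only has to supply `try_alloc_len` and `len`, which is exactly what the `Cpu` and `Cuda` diffs above do. A compressed sketch of that shape, restating a simplified trait locally (the real trait is generic over `E: Unit` and carries more methods):

```rust
// Simplified local restatement of the reshaped trait: only `try_alloc_len`
// and `len` are required; `try_alloc_grad` is now a provided method.
trait DeviceStorage {
    type Vec;
    type Err;

    fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Self::Err>;
    fn len(&self, v: &Self::Vec) -> usize;

    // Provided: a gradient is a zero-filled buffer the same length as the storage.
    fn try_alloc_grad(&self, storage: &Self::Vec) -> Result<Self::Vec, Self::Err> {
        self.try_alloc_len(self.len(storage))
    }
}

#[derive(Default, Clone, Debug)]
struct ToyCpu;

impl DeviceStorage for ToyCpu {
    type Vec = Vec<f32>;
    type Err = std::convert::Infallible;

    fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Self::Err> {
        Ok(vec![0.0; len])
    }

    fn len(&self, v: &Self::Vec) -> usize {
        v.len()
    }
}

fn main() {
    let dev = ToyCpu;
    let storage = dev.try_alloc_len(8).unwrap();
    let grad = dev.try_alloc_grad(&storage).unwrap();
    assert_eq!(dev.len(&grad), 8);
}
```
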
1 change: 1 addition & 0 deletions src/tensor_ops/abs/cpu_kernel.rs
@@ -3,6 +3,7 @@ use num_traits::Float;

impl<F: Float> UnaryDerivative<F> for super::AbsKernelOp {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, x: &F) -> F {
x.abs()
16 changes: 15 additions & 1 deletion src/tensor_ops/add/cpu_kernel.rs
@@ -2,22 +2,32 @@ use crate::tensor_ops::cpu_kernels::{BinaryDerivative, UnaryDerivative};
use num_traits::Float;

impl<F: Float> BinaryDerivative<F> for super::BinaryAddKernelOp {
const HAS_CONST_DF: bool = true;
#[inline(always)]
fn f(&self, &x: &F, &y: &F) -> F {
x + y
}
#[inline(always)]
fn dfdx(&self, _: &F, _: &F) -> F {
F::one()
self.const_dfdx()
}
#[inline(always)]
fn dfdy(&self, _: &F, _: &F) -> F {
self.const_dfdy()
}
#[inline(always)]
fn const_dfdx(&self) -> F {
F::one()
}
#[inline(always)]
fn const_dfdy(&self) -> F {
F::one()
}
}

impl<F: Float> UnaryDerivative<F> for super::ScalarAddKernelOp<F> {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = true;
#[inline(always)]
fn f(&self, &x: &F) -> F {
x + self.scalar
Expand All @@ -26,4 +36,8 @@ impl<F: Float> UnaryDerivative<F> for super::ScalarAddKernelOp<F> {
fn df(&self, _: &F) -> F {
F::one()
}
#[inline(always)]
fn const_df(&self) -> F {
F::one()
}
}
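
The new `HAS_CONST_DF`/`const_dfdx`/`const_dfdy` markers let the backward pass skip the saved input (and output) entirely when a kernel's derivative is a constant, as it is for addition. A simplified sketch of how such a flag might be consumed, using a local stand-in trait rather than the crate's `UnaryDerivative`:

```rust
// Local stand-in trait (not the crate's `UnaryDerivative`) showing what the
// flags buy: with HAS_CONST_DF the backward pass never touches the saved
// input, so the op does not need to keep the input alive at all.
trait UnaryDeriv<F> {
    const DF_USES_FX: bool;
    const HAS_CONST_DF: bool;
    fn f(&self, x: &F) -> F;
    fn df(&self, x: &F) -> F;
    fn const_df(&self) -> F;
}

struct ScalarAdd(f32);

impl UnaryDeriv<f32> for ScalarAdd {
    const DF_USES_FX: bool = false; // df does not depend on f(x)
    const HAS_CONST_DF: bool = true; // df is the constant 1
    fn f(&self, x: &f32) -> f32 {
        x + self.0
    }
    fn df(&self, _x: &f32) -> f32 {
        1.0
    }
    fn const_df(&self) -> f32 {
        1.0
    }
}

// `saved_x` may be `None` whenever the kernel reports a constant derivative.
fn unary_backward<K: UnaryDeriv<f32>>(
    k: &K,
    saved_x: Option<&[f32]>,
    grad_out: &[f32],
    grad_in: &mut [f32],
) {
    for (i, go) in grad_out.iter().enumerate() {
        let d = if K::HAS_CONST_DF {
            k.const_df()
        } else {
            k.df(&saved_x.expect("non-const kernels save their input")[i])
        };
        grad_in[i] += d * go;
    }
}

fn main() {
    let k = ScalarAdd(3.0);
    assert_eq!(k.f(&2.0), 5.0);
    assert!(!ScalarAdd::DF_USES_FX);
    let mut grad_in = [0.0f32, 0.0];
    // No saved input is passed because HAS_CONST_DF is true for addition.
    unary_backward(&k, None, &[1.0, 2.0], &mut grad_in);
    assert_eq!(grad_in, [1.0, 2.0]);
}
```
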
8 changes: 4 additions & 4 deletions src/tensor_ops/add/cuda_kernel.rs
@@ -8,18 +8,18 @@ unsafe impl cudarc::driver::DeviceRepr for Binary {}
const SCALAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/scalar_add.ptx"));
const BINARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));

cuda_unary!(Scalar<f32>, f32, SCALAR_PTX, "sadd_fwd_f32", "sadd_bwd_f32");
cuda_unary!(Scalar<f64>, f64, SCALAR_PTX, "sadd_fwd_f64", "sadd_bwd_f64");
cuda_unary!(const_df() Scalar<f32>, f32, SCALAR_PTX, "sadd_fwd_f32", "sadd_bwd_f32");
cuda_unary!(const_df() Scalar<f64>, f64, SCALAR_PTX, "sadd_fwd_f64", "sadd_bwd_f64");
cuda_binary!(
Binary,
const_df() Binary,
f32,
BINARY_PTX,
"badd_fwd_f32",
"badd_bwd_lhs_f32",
"badd_bwd_rhs_f32"
);
cuda_binary!(
Binary,
const_df() Binary,
f64,
BINARY_PTX,
"badd_fwd_f64",
1 change: 1 addition & 0 deletions src/tensor_ops/bce/cpu_kernel.rs
@@ -2,6 +2,7 @@ use crate::tensor_ops::cpu_kernels::BinaryDerivative;
use num_traits::Float;

impl<F: Float> BinaryDerivative<F> for super::BCEKernelOp {
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &logit: &F, &prob: &F) -> F {
logit.max(F::zero()) - logit * prob + (F::one() + (-logit.abs()).exp()).ln()
13 changes: 8 additions & 5 deletions src/tensor_ops/choose/mod.rs
@@ -73,14 +73,17 @@ impl<
let (rhs, rhs_tape) = rhs.split_tape();

let out = lhs.device.forward(&self, &lhs, &rhs)?;
let phantom_out = out.clone();

let lhs_ghost = lhs.ghost();
let rhs_ghost = rhs.ghost();
let out_ghost = out.ghost();
let mut tape = tape.merge(rhs_tape);
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&lhs)?;
grads.try_alloc_for(&rhs)?;
grads.try_alloc_for(&phantom_out)?;
let (grad_lhs, grad_rhs, grad_out) = grads.muts_and_ref(&lhs, &rhs, &phantom_out);
grads.try_alloc_for(&lhs_ghost)?;
grads.try_alloc_for(&rhs_ghost)?;
grads.try_alloc_for(&out_ghost)?;
let (grad_lhs, grad_rhs, grad_out) =
grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost);
lhs.device
.backward(&self, &lhs, grad_lhs, &rhs, grad_rhs, grad_out)
});
1 change: 1 addition & 0 deletions src/tensor_ops/clamp/cpu_kernel.rs
@@ -3,6 +3,7 @@ use num_traits::{clamp, Float};

impl<F: Float + PartialOrd> UnaryDerivative<F> for super::ClampKernelOp<F> {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &x: &F) -> F {
clamp(x, self.min, self.max)