Re-using tensor storage when possible #664

Merged: 22 commits, Apr 5, 2023
4 changes: 2 additions & 2 deletions benches/batchnorm2d.rs
@@ -21,10 +21,10 @@ fn main() {

let dev: Dev = Default::default();
let mut m = dev.build_module::<Model, Dtype>();
let mut grads = m.alloc_grads();

loop {
let img: Tensor<InputShape, Dtype, _> = dev.sample_normal();
let grads = m.alloc_grads();

let start = Instant::now();
let out = m.forward_mut(img.traced(grads));
@@ -33,7 +33,7 @@ fn main() {
let fwd_dur = start.elapsed();

let start = Instant::now();
let _ = loss.backward();
grads = loss.backward();
dev.synchronize();
let bwd_dur = start.elapsed();
println!("fwd={:?} bwd={:?}", fwd_dur, bwd_dur);
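
This benchmark change is the caller-side half of the PR: gradient buffers are allocated once before the loop, moved into the tape by `traced(grads)`, and handed back by `backward()` for the next iteration. A self-contained sketch of that pattern, using a hypothetical toy model and a placeholder squared-mean loss in place of the benchmark's elided definitions:

use dfdx::prelude::*;

// Sketch only: `Model` and `InputShape` are hypothetical stand-ins for the
// aliases the benchmark defines in its collapsed lines.
type Model = (Linear<8, 16>, ReLU, Linear<16, 1>);
type InputShape = Rank2<32, 8>;

fn main() {
    let dev = AutoDevice::default();
    let mut m = dev.build_module::<Model, f32>();
    // Allocate gradient buffers once, outside the hot loop.
    let mut grads = m.alloc_grads();
    for _ in 0..10 {
        let x: Tensor<InputShape, f32, _> = dev.sample_normal();
        // The tape takes ownership of `grads` for this iteration.
        let out = m.forward_mut(x.traced(grads));
        // Placeholder loss; the real benchmark's loss lines are collapsed above.
        let loss = out.square().mean();
        // `backward()` returns the gradient buffers so they can be reused.
        grads = loss.backward();
    }
}
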
2 changes: 1 addition & 1 deletion examples/model-transformer.rs
@@ -1,7 +1,7 @@
//! Demonstrates how to use a transformer module on nightly rust.

fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;

let dev = AutoDevice::default();
type Model = Transformer<16, 4, 3, 3, 8>;
2 changes: 1 addition & 1 deletion examples/nightly-conv-net.rs
@@ -4,7 +4,7 @@

#[cfg(feature = "nightly")]
fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;

type Model = (
(Conv2D<3, 4, 3>, ReLU),
2 changes: 1 addition & 1 deletion examples/nightly-resnet18.rs
@@ -2,7 +2,7 @@

#[cfg(feature = "nightly")]
fn main() {
use dfdx::{prelude::*, tensor::AutoDevice};
use dfdx::prelude::*;
use std::time::Instant;

type BasicBlock<const C: usize> = Residual<(
12 changes: 4 additions & 8 deletions src/nn/linear.rs
@@ -129,10 +129,8 @@ impl<'a, B: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
type Error = D::Err;

fn try_forward(&self, input: Tensor<(B, Const<M>), E, D, T>) -> Result<Self::Output, D::Err> {
self.beta
.retaped::<T>()
.try_broadcast_like(input.shape())?
.try_add(input)
let shape = *input.shape();
input.try_add(self.beta.retaped::<T>().try_broadcast_like(&shape)?)
}
}

@@ -146,10 +144,8 @@ impl<'a, B: Dim, S: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
&self,
input: Tensor<(B, S, Const<M>), E, D, T>,
) -> Result<Self::Output, D::Err> {
self.beta
.retaped::<T>()
.try_broadcast_like(input.shape())?
.try_add(input)
let shape = *input.shape();
input.try_add(self.beta.retaped::<T>().try_broadcast_like(&shape)?)
}
}

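
The reordering above makes the owned, tape-carrying `input` the receiver of `try_add` instead of the freshly broadcast `beta`. In line with the PR's storage-reuse theme, the likely intent is that an elementwise op can write into its left operand's buffer when nothing else references it, whereas a broadcast view cannot be reused. A library-agnostic sketch of that idea, with toy types rather than dfdx's tensors:

use std::sync::Arc;

// Toy elementwise add that reuses the left-hand buffer when it is uniquely owned.
fn add_reusing_lhs(lhs: Arc<Vec<f32>>, rhs: &[f32]) -> Arc<Vec<f32>> {
    match Arc::try_unwrap(lhs) {
        // Unique owner: mutate in place, no new allocation.
        Ok(mut buf) => {
            for (a, b) in buf.iter_mut().zip(rhs) {
                *a += b;
            }
            Arc::new(buf)
        }
        // Still shared: fall back to allocating a fresh output.
        Err(shared) => Arc::new(shared.iter().zip(rhs).map(|(a, b)| a + b).collect()),
    }
}

fn main() {
    let lhs = Arc::new(vec![1.0, 2.0, 3.0]);
    let out = add_reusing_lhs(lhs, &[10.0, 20.0, 30.0]);
    assert_eq!(*out, vec![11.0, 22.0, 33.0]);
}
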
8 changes: 6 additions & 2 deletions src/tensor/cpu/device.rs
@@ -63,8 +63,8 @@ impl HasErr for Cpu {
impl DeviceStorage for Cpu {
type Vec<E: Unit> = Vec<E>;

fn try_alloc_grad<E: Unit>(&self, other: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_zeros(other.len())
fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_zeros(len)
}

fn random_u64(&self) -> u64 {
@@ -78,6 +78,10 @@ }
}
}

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize {
v.len()
}

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
let mut buf = Vec::with_capacity(tensor.shape.num_elements());
let mut iter = tensor.iter();
9 changes: 6 additions & 3 deletions src/tensor/cuda/device.rs
@@ -116,15 +116,18 @@ impl HasErr for Cuda {
impl DeviceStorage for Cuda {
type Vec<E: Unit> = CudaSlice<E>;

fn try_alloc_grad<E: Unit>(&self, other: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
let grad = self.dev.alloc_zeros(other.len())?;
Ok(grad)
fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err> {
Ok(self.dev.alloc_zeros(len)?)
}

fn random_u64(&self) -> u64 {
self.cpu.random_u64()
}

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize {
v.len()
}

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E> {
let buf: Vec<E> = tensor.data.try_clone().unwrap().try_into().unwrap();
debug_assert_eq!(buf.len(), tensor.data.len());
43 changes: 43 additions & 0 deletions src/tensor/ghost.rs
@@ -0,0 +1,43 @@
use crate::{shapes::*, tensor::*};

/// Holds all the information a [Tensor] does, except without
/// holding a reference to the data storage.
///
/// This can help reduce memory usage by decreasing the reference
/// count on tensor data, meaning the data can be re-used more often.
pub struct GhostTensor<S: Shape, E: Unit, D: DeviceStorage> {
pub(crate) id: UniqueId,
pub(crate) len: usize,
pub(crate) shape: S,
pub(crate) strides: S::Concrete,
pub(crate) dev: D,
marker: std::marker::PhantomData<E>,
}

impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensor<S, E, D, T> {
/// Creates a ghost tensor that doesn't hold a reference
/// to the tensor's data.
pub(crate) fn ghost(&self) -> GhostTensor<S, E, D> {
GhostTensor {
id: self.id,
len: self.device.len(&self.data),
shape: self.shape,
strides: self.strides,
dev: self.device.clone(),
marker: std::marker::PhantomData,
}
}
}

impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::HasErr for GhostTensor<S, E, D> {
type Err = D::Err;
}

impl<S: Shape, E: Unit, D: DeviceStorage> super::storage_traits::AllocGrad
for GhostTensor<S, E, D>
{
type Gradient = D::Vec<E>;
fn try_alloc_grad(&self) -> Result<Self::Gradient, D::Err> {
self.dev.try_alloc_len(self.len)
}
}
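
For intuition about why a metadata-only handle helps: once a backward closure captures a `GhostTensor` instead of the full `Tensor`, it no longer keeps the storage alive, so the buffer's reference count can drop and the allocation becomes reusable sooner. A self-contained toy version of the idea (the `Toy*` types below are illustrative only, not dfdx types):

use std::sync::Arc;

// A "tensor" is metadata plus shared storage; the "ghost" keeps only the metadata.
struct ToyTensor {
    id: u64,
    shape: Vec<usize>,
    data: Arc<Vec<f32>>,
}

struct ToyGhost {
    id: u64,
    len: usize,
    shape: Vec<usize>,
}

impl ToyTensor {
    fn ghost(&self) -> ToyGhost {
        ToyGhost {
            id: self.id,
            len: self.data.len(),
            shape: self.shape.clone(),
        }
    }
}

fn main() {
    let t = ToyTensor {
        id: 7,
        shape: vec![2, 3],
        data: Arc::new(vec![0.0; 6]),
    };
    let g = t.ghost();
    // Dropping `t` releases its reference to the storage; a backward op holding
    // only `g` still knows enough to allocate a matching gradient buffer.
    drop(t);
    let grad = vec![0.0f32; g.len];
    assert_eq!(grad.len(), 6);
    println!("ghost id={} shape={:?}", g.id, g.shape);
}
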
29 changes: 17 additions & 12 deletions src/tensor/gradients.rs
@@ -4,6 +4,7 @@
use std::collections::{BTreeMap, BTreeSet};
use std::{boxed::Box, vec::Vec};

use super::ghost::GhostTensor;
use super::{
storage_traits::{AllocGrad, DeviceStorage},
unique_id, Tensor, UniqueId,
@@ -50,12 +51,16 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
&mut self,
t: &Tensor<S, E, D>,
) -> Result<&mut D::Vec<E>, D::Err> {
self.try_alloc_for(t)?;
Ok(self.get_mut(t))
let ghost = t.ghost();
self.try_alloc_for(&ghost)?;
Ok(self.get_mut(&ghost))
}

/// Inserts a gradient for `t`
pub(crate) fn try_alloc_for<S: Shape>(&mut self, t: &Tensor<S, E, D>) -> Result<(), D::Err> {
pub(crate) fn try_alloc_for<S: Shape>(
&mut self,
t: &GhostTensor<S, E, D>,
) -> Result<(), D::Err> {
if let std::collections::btree_map::Entry::Vacant(e) = self.gradient_by_id.entry(t.id) {
e.insert(t.try_alloc_grad()?);
}
@@ -89,14 +94,14 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_mut<S: Shape, T>(&mut self, t: &Tensor<S, E, D, T>) -> &mut D::Vec<E> {
pub(crate) fn get_mut<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &mut D::Vec<E> {
self.gradient_by_id.get_mut(&t.id).unwrap()
}

/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_ref<S: Shape, T>(&mut self, t: &Tensor<S, E, D, T>) -> &D::Vec<E> {
pub(crate) fn get_ref<S: Shape>(&mut self, t: &GhostTensor<S, E, D>) -> &D::Vec<E> {
self.gradient_by_id.get(&t.id).unwrap()
}

@@ -123,8 +128,8 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// **Panics** if `l` and `r` have the same id.
pub(crate) fn mut_and_ref<L: Shape, R: Shape>(
&mut self,
l: &Tensor<L, E, D>,
r: &Tensor<R, E, D>,
l: &GhostTensor<L, E, D>,
r: &GhostTensor<R, E, D>,
) -> (&mut D::Vec<E>, &D::Vec<E>) {
assert_ne!(l.id, r.id);
let l_ptr = self.get_mut(l) as *mut _;
@@ -137,9 +142,9 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
/// Borrows a triplet of gradients `(&mut L1, &mut L2, &R)`.
pub(crate) fn muts_and_ref<L1: Shape, L2: Shape, R: Shape>(
&mut self,
l1: &Tensor<L1, E, D>,
l2: &Tensor<L2, E, D>,
r: &Tensor<R, E, D>,
l1: &GhostTensor<L1, E, D>,
l2: &GhostTensor<L2, E, D>,
r: &GhostTensor<R, E, D>,
) -> (&mut D::Vec<E>, &mut D::Vec<E>, &D::Vec<E>) {
assert_ne!(l1.id, l2.id);
assert_ne!(l1.id, r.id);
@@ -156,8 +161,8 @@ impl<E: Unit, D: DeviceStorage> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<Tensor<L, E, D>>,
r: &Tensor<R, E, D>,
ls: &Vec<GhostTensor<L, E, D>>,
r: &GhostTensor<R, E, D>,
) -> (Vec<&mut D::Vec<E>>, &D::Vec<E>) {
for i in 0..ls.len() {
assert_ne!(ls[i].id, r.id);
2 changes: 2 additions & 0 deletions src/tensor/mod.rs
@@ -127,6 +127,7 @@
pub(crate) mod cpu;
#[cfg(feature = "cuda")]
pub(crate) mod cuda;
mod ghost;
mod gradients;
mod masks;
#[cfg(feature = "numpy")]
@@ -138,6 +139,7 @@ mod unique_id;
pub(crate) mod storage_traits;
mod tensor_impls;

pub(crate) use ghost::GhostTensor;
pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage};

pub use cpu::{Cpu, CpuError};
8 changes: 7 additions & 1 deletion src/tensor/storage_traits.rs
@@ -25,10 +25,16 @@ pub trait DeviceStorage: 'static + std::fmt::Debug + Default + Clone + HasErr {
fn random_u64(&self) -> u64;

/// Allocates a gradient for the given nd array
fn try_alloc_grad<E: Unit>(&self, storage: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err>;
fn try_alloc_grad<E: Unit>(&self, storage: &Self::Vec<E>) -> Result<Self::Vec<E>, Self::Err> {
self.try_alloc_len(self.len(storage))
}

fn try_alloc_len<E: Unit>(&self, len: usize) -> Result<Self::Vec<E>, Self::Err>;

fn tensor_to_vec<S: Shape, E: Unit, T>(&self, tensor: &Tensor<S, E, Self, T>) -> Vec<E>;

fn len<E: Unit>(&self, v: &Self::Vec<E>) -> usize;

/// Blocks until all work on the device is complete. Useful for benchmarking.
fn synchronize(&self) {
self.try_synchronize().unwrap()
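
The trait change is a standard Rust refactor: the required surface shrinks to `try_alloc_len` and `len`, and the old `try_alloc_grad` becomes a provided default built on them, so both tensors and ghost tensors can share one allocation path. A toy mirror of the pattern (the names below are stand-ins, not the real `DeviceStorage` trait):

// Two required methods, one provided default that reuses them.
trait Storage {
    type Buf;
    type Err;

    fn try_alloc_len(&self, len: usize) -> Result<Self::Buf, Self::Err>;
    fn len(&self, buf: &Self::Buf) -> usize;

    // The old entry point survives as a default, so call sites don't change.
    fn try_alloc_grad(&self, other: &Self::Buf) -> Result<Self::Buf, Self::Err> {
        self.try_alloc_len(self.len(other))
    }
}

struct CpuLike;

impl Storage for CpuLike {
    type Buf = Vec<f32>;
    type Err = std::convert::Infallible;

    fn try_alloc_len(&self, len: usize) -> Result<Self::Buf, Self::Err> {
        Ok(vec![0.0; len])
    }

    fn len(&self, buf: &Self::Buf) -> usize {
        buf.len()
    }
}

fn main() {
    let dev = CpuLike;
    let data = dev.try_alloc_len(4).unwrap();
    // Same length as `data`, fresh zeroed buffer, via the provided default.
    let grad = dev.try_alloc_grad(&data).unwrap();
    assert_eq!(grad.len(), 4);
}
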
1 change: 1 addition & 0 deletions src/tensor_ops/abs/cpu_kernel.rs
@@ -3,6 +3,7 @@ use num_traits::Float;

impl<F: Float> UnaryDerivative<F> for super::AbsKernelOp {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, x: &F) -> F {
x.abs()
16 changes: 15 additions & 1 deletion src/tensor_ops/add/cpu_kernel.rs
@@ -2,22 +2,32 @@ use crate::tensor_ops::cpu_kernels::{BinaryDerivative, UnaryDerivative};
use num_traits::Float;

impl<F: Float> BinaryDerivative<F> for super::BinaryAddKernelOp {
const HAS_CONST_DF: bool = true;
#[inline(always)]
fn f(&self, &x: &F, &y: &F) -> F {
x + y
}
#[inline(always)]
fn dfdx(&self, _: &F, _: &F) -> F {
F::one()
self.const_dfdx()
}
#[inline(always)]
fn dfdy(&self, _: &F, _: &F) -> F {
self.const_dfdy()
}
#[inline(always)]
fn const_dfdx(&self) -> F {
F::one()
}
#[inline(always)]
fn const_dfdy(&self) -> F {
F::one()
}
}

impl<F: Float> UnaryDerivative<F> for super::ScalarAddKernelOp<F> {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = true;
#[inline(always)]
fn f(&self, &x: &F) -> F {
x + self.scalar
@@ -26,4 +36,8 @@ impl<F: Float> UnaryDerivative<F> for super::ScalarAddKernelOp<F> {
fn df(&self, _: &F) -> F {
F::one()
}
#[inline(always)]
fn const_df(&self) -> F {
F::one()
}
}
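
The new `HAS_CONST_DF` flag and `const_df*` methods let a backward kernel skip reading the forward operands entirely when the derivative is a constant, as it is for addition. A standalone sketch of that dispatch (not dfdx's real kernel loop, just the idea):

// A unary op advertises whether its derivative is constant, and the backward
// loop avoids touching the input buffer when it is.
trait UnaryDeriv {
    const HAS_CONST_DF: bool;
    fn df(&self, x: f32) -> f32;
    fn const_df(&self) -> f32 {
        unreachable!("only valid when HAS_CONST_DF is true")
    }
}

struct ScalarAdd;

impl UnaryDeriv for ScalarAdd {
    const HAS_CONST_DF: bool = true;
    fn df(&self, _x: f32) -> f32 {
        self.const_df()
    }
    fn const_df(&self) -> f32 {
        1.0
    }
}

fn backward<Op: UnaryDeriv>(op: &Op, x: &[f32], grad_out: &[f32], grad_x: &mut [f32]) {
    if Op::HAS_CONST_DF {
        // Constant derivative: no need to load `x` at all.
        let d = op.const_df();
        for (gx, go) in grad_x.iter_mut().zip(grad_out) {
            *gx += d * go;
        }
    } else {
        for ((gx, go), xi) in grad_x.iter_mut().zip(grad_out).zip(x) {
            *gx += op.df(*xi) * go;
        }
    }
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    let go = [0.5, 0.5, 0.5];
    let mut gx = [0.0; 3];
    backward(&ScalarAdd, &x, &go, &mut gx);
    assert_eq!(gx, [0.5, 0.5, 0.5]);
}
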
8 changes: 4 additions & 4 deletions src/tensor_ops/add/cuda_kernel.rs
@@ -8,18 +8,18 @@ unsafe impl cudarc::driver::DeviceRepr for Binary {}
const SCALAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/scalar_add.ptx"));
const BINARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));

cuda_unary!(Scalar<f32>, f32, SCALAR_PTX, "sadd_fwd_f32", "sadd_bwd_f32");
cuda_unary!(Scalar<f64>, f64, SCALAR_PTX, "sadd_fwd_f64", "sadd_bwd_f64");
cuda_unary!(const_df() Scalar<f32>, f32, SCALAR_PTX, "sadd_fwd_f32", "sadd_bwd_f32");
cuda_unary!(const_df() Scalar<f64>, f64, SCALAR_PTX, "sadd_fwd_f64", "sadd_bwd_f64");
cuda_binary!(
Binary,
const_df() Binary,
f32,
BINARY_PTX,
"badd_fwd_f32",
"badd_bwd_lhs_f32",
"badd_bwd_rhs_f32"
);
cuda_binary!(
Binary,
const_df() Binary,
f64,
BINARY_PTX,
"badd_fwd_f64",
1 change: 1 addition & 0 deletions src/tensor_ops/bce/cpu_kernel.rs
@@ -2,6 +2,7 @@ use crate::tensor_ops::cpu_kernels::BinaryDerivative;
use num_traits::Float;

impl<F: Float> BinaryDerivative<F> for super::BCEKernelOp {
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &logit: &F, &prob: &F) -> F {
logit.max(F::zero()) - logit * prob + (F::one() + (-logit.abs()).exp()).ln()
13 changes: 8 additions & 5 deletions src/tensor_ops/choose/mod.rs
@@ -73,14 +73,17 @@ impl<
let (rhs, rhs_tape) = rhs.split_tape();

let out = lhs.device.forward(&self, &lhs, &rhs)?;
let phantom_out = out.clone();

let lhs_ghost = lhs.ghost();
let rhs_ghost = rhs.ghost();
let out_ghost = out.ghost();
let mut tape = tape.merge(rhs_tape);
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&lhs)?;
grads.try_alloc_for(&rhs)?;
grads.try_alloc_for(&phantom_out)?;
let (grad_lhs, grad_rhs, grad_out) = grads.muts_and_ref(&lhs, &rhs, &phantom_out);
grads.try_alloc_for(&lhs_ghost)?;
grads.try_alloc_for(&rhs_ghost)?;
grads.try_alloc_for(&out_ghost)?;
let (grad_lhs, grad_rhs, grad_out) =
grads.muts_and_ref(&lhs_ghost, &rhs_ghost, &out_ghost);
lhs.device
.backward(&self, &lhs, grad_lhs, &rhs, grad_rhs, grad_out)
});
1 change: 1 addition & 0 deletions src/tensor_ops/clamp/cpu_kernel.rs
@@ -3,6 +3,7 @@ use num_traits::{clamp, Float};

impl<F: Float + PartialOrd> UnaryDerivative<F> for super::ClampKernelOp<F> {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &x: &F) -> F {
clamp(x, self.min, self.max)