Not saving inp for sum_to
coreylowman committed Apr 5, 2023
1 parent ffd6253 commit 085ec54
Showing 3 changed files with 8 additions and 8 deletions.
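
In short: the `sum_to` backward kernels never read the input's data, only its metadata (element count, shape, strides), so this commit swaps the `&Tensor<Src, E, Self>` argument for a `&GhostTensor<Src, E, Self>` and replaces `inp.data.len()` with `inp.len`. Since the backward op no longer holds the full tensor, the input's data buffer can be dropped after the forward pass. Below is a minimal sketch of such a metadata-only handle, leaning on the crate's `Shape`, `DeviceStorage`, and `UniqueId` types; only `len`, `shape`, and `strides` are confirmed by the diffs, the rest is assumed for illustration.

use std::marker::PhantomData;

// Hypothetical sketch, not the crate's real definition: a "ghost" keeps only
// the metadata the backward kernels touch.
pub struct GhostTensor<S: Shape, E, D: DeviceStorage> {
    pub(crate) id: UniqueId,         // assumed: lets Gradients find the matching buffer
    pub(crate) len: usize,           // physical element count (replaces inp.data.len())
    pub(crate) shape: S,             // logical shape: inp.shape.num_elements(), .concrete()
    pub(crate) strides: S::Concrete, // used by the CUDA reduction config
    pub(crate) dev: D,
    marker: PhantomData<E>,
}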
6 changes: 3 additions & 3 deletions src/tensor_ops/sum_to/cpu_kernel.rs
@@ -1,6 +1,6 @@
 use crate::{
     shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
-    tensor::{Cpu, Tensor, ZerosTensor},
+    tensor::{Cpu, Tensor, ZerosTensor, GhostTensor},
     tensor_ops::utilities::reduction_utils::index_for_reductions,
 };
@@ -39,7 +39,7 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
     fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
         &self,
         _dst: Dst,
-        inp: &Tensor<Src, E, Self>,
+        inp: &GhostTensor<Src, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err>
@@ -49,7 +49,7 @@ impl<E: Dtype> super::SumKernel<E> for Cpu {
         if Dst::NUM_DIMS == 0 {
             debug_assert_eq!(grad_out.len(), 1);
             let v = grad_out[0];
-            let scale = E::from_usize(inp.shape.num_elements() / inp.data.len()).unwrap();
+            let scale = E::from_usize(inp.shape.num_elements() / inp.len).unwrap();
             for i in grad_inp.iter_mut() {
                 *i += v * scale;
             }
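
The `scale` factor in that hunk exists because a broadcast input has more logical elements (`inp.shape.num_elements()`) than physical ones (`inp.len`), and the ghost still carries both counts. A small worked example of just the arithmetic, using plain numbers rather than the crate's API:

fn main() {
    // A scalar broadcast to shape [2, 3]: 6 logical elements, 1 physical element.
    let logical: usize = 2 * 3; // what inp.shape.num_elements() would report
    let physical: usize = 1;    // what inp.len would report
    let scale = logical / physical;
    assert_eq!(scale, 6);
    // Backward of a full sum then adds grad_out[0] * 6 into the single
    // physical slot, covering all six logical positions it was broadcast to.
}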
6 changes: 3 additions & 3 deletions src/tensor_ops/sum_to/cuda_kernel.rs
@@ -1,6 +1,6 @@
 use crate::{
     shapes::*,
-    tensor::{launch_cfg, Cuda, Tensor},
+    tensor::{launch_cfg, Cuda, Tensor, GhostTensor},
     tensor_ops::reduction_utils::*,
 };
@@ -80,7 +80,7 @@ where
     fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
         &self,
         dst: Dst,
-        inp: &Tensor<Src, E, Self>,
+        inp: &GhostTensor<Src, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err>
@@ -91,7 +91,7 @@ where

         let out_strides: Src::Concrete =
             BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
-        let physical_numel = inp.data.len();
+        let physical_numel = inp.len;
         let elems_per_thread = E::from_usize(reduction_elems_per_thread::<_, Src>(
             inp.shape.concrete(),
             inp.strides,
4 changes: 2 additions & 2 deletions src/tensor_ops/sum_to/mod.rs
@@ -16,7 +16,7 @@ pub trait SumKernel<E: Dtype>: DeviceStorage {
     fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
         &self,
         dst: Dst,
-        inp: &Tensor<Src, E, Self>,
+        inp: &GhostTensor<Src, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err>
@@ -71,7 +71,7 @@ impl<S: Shape, E: Dtype, D: SumKernel<E>, T: Tape<E, D>> SumTo for Tensor<S, E,
         grads.try_alloc_for(&inp_ghost)?;
         grads.try_alloc_for(&out_ghost)?;
         let (grad_inp, grad_out) = grads.mut_and_ref(&inp_ghost, &out_ghost);
-        inp.device.backward(dst, &inp, grad_inp, grad_out)
+        inp.device.backward(dst, &inp_ghost, grad_inp, grad_out)
     });
     Ok(out.put_tape(tape))
 }
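
Why does passing `&inp_ghost` stop the input from being saved when the closure still says `inp.device`? Assuming the crate builds on Rust edition 2021, a `move` closure captures only the fields it actually uses, so the stored backward op now holds `inp.device` (plus the ghosts) rather than all of `inp`, and the input's data buffer can be freed once the forward pass returns. A standalone sketch of that capture rule, with types invented for the demo rather than taken from dfdx:

struct Device;
impl Device {
    fn backward(&self) {
        println!("backward ran");
    }
}

struct Tensor {
    data: Vec<f32>,
    device: Device,
}

fn make_backward_op(inp: Tensor) -> impl FnOnce() {
    // Edition 2021: this move closure captures only `inp.device`; `inp.data`
    // is left behind and dropped when this function returns, so the stored
    // op does not keep the big buffer alive.
    move || inp.device.backward()
}

fn main() {
    let t = Tensor {
        data: vec![0.0; 1 << 20],
        device: Device,
    };
    let op = make_backward_op(t); // data buffer freed here, op kept for later
    op();
}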
