allow convtrans2d to use dynamic dimensions (#639)
nkoppel committed Mar 30, 2023
1 parent 922ccd3 commit a6914f2
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions src/tensor_ops/convtrans2d/mod.rs
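The heart of the change is the transposed-convolution output-size rule, d_out = d * stride + kernel - stride - 2 * padding, which previously existed only as the const-generic expression Const<{ D * S + K - S - 2 * P }> and is now also evaluated at runtime for usize dimensions. A minimal standalone sketch of that arithmetic (hypothetical helper name, not part of the commit):

fn convtrans2d_out_dim(d: usize, kernel: usize, stride: usize, padding: usize) -> Option<usize> {
    // d * stride + kernel - stride - 2 * padding, with checked_sub so an
    // invalid stride/padding combination yields None instead of wrapping.
    (d * stride + kernel).checked_sub(stride + 2 * padding)
}

fn main() {
    // An 8x8 input with a 4x4 kernel, stride 2, padding 1 upsamples to 16x16.
    assert_eq!(convtrans2d_out_dim(8, 4, 2, 1), Some(16));
    // Padding too large for the input size: no valid output dimension.
    assert_eq!(convtrans2d_out_dim(1, 2, 1, 4), None);
}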
@@ -76,8 +76,10 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: DeviceStorage {
     ) -> Result<(), Self::Err>;
 }
 
-pub trait ConvTransAlgebra<const K: usize, const S: usize, const P: usize>: ConstDim {
-    type Convolved: ConstDim;
+pub trait ConvTransAlgebra<const K: usize, const S: usize, const P: usize>: Dim {
+    type Convolved: Dim;
+
+    fn convolve_dim(&self) -> Self::Convolved;
 }
 
 impl<const D: usize, const K: usize, const S: usize, const P: usize> ConvTransAlgebra<K, S, P>
@@ -86,6 +88,18 @@ where
     Const<{ D * S + K - S - 2 * P }>: Sized,
 {
     type Convolved = Const<{ D * S + K - S - 2 * P }>;
+
+    fn convolve_dim(&self) -> Self::Convolved {
+        Default::default()
+    }
+}
+
+impl<const K: usize, const S: usize, const P: usize> ConvTransAlgebra<K, S, P> for usize {
+    type Convolved = usize;
+
+    fn convolve_dim(&self) -> Self::Convolved {
+        (self * S + K).checked_sub(S + 2 * P).unwrap()
+    }
 }
 
 pub trait TryConvTrans2DTo<F, const S: usize, const P: usize>: HasErr {
@@ -118,40 +132,34 @@ impl<T, F> TryConvTrans2D<F> for T {}
 
 impl<
         const C: usize,
-        const H: usize,
-        const W: usize,
+        H: Dim + ConvTransAlgebra<K, S, P>,
+        W: Dim + ConvTransAlgebra<K, S, P>,
         const O: usize,
         const K: usize,
         const S: usize,
         const P: usize,
        E: Dtype,
         D: ConvTrans2DKernel<E> + ZerosTensor<E>,
         T: 'static + Tape<E, D>,
-    > TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> for Tensor<Rank3<C, H, W>, E, D, T>
-where
-    Const<H>: ConvTransAlgebra<K, S, P>,
-    Const<W>: ConvTransAlgebra<K, S, P>,
+    > TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P>
+    for Tensor<(Const<C>, H, W), E, D, T>
 {
-    type Output = Tensor<
-        (
-            Const<O>,
-            <Const<H> as ConvTransAlgebra<K, S, P>>::Convolved,
-            <Const<W> as ConvTransAlgebra<K, S, P>>::Convolved,
-        ),
-        E,
-        D,
-        T,
-    >;
+    type Output = Tensor<(Const<O>, H::Convolved, W::Convolved), E, D, T>;
 
     fn try_convtrans2d_to(
         self,
         filters: Tensor<Rank4<O, C, K, K>, E, D>,
     ) -> Result<Self::Output, Self::Err> {
-        let op = ConvTrans2DOp::new(S, P, K, [1, C, H, W], O);
+        let h = self.shape.1;
+        let w = self.shape.2;
+
+        let op = ConvTrans2DOp::new(S, P, K, [1, C, h.size(), w.size()], O);
         let (lhs, ltape) = self.split_tape();
         let (rhs, rtape) = filters.split_tape();
         let mut tape = ltape.merge(rtape);
-        let mut out = lhs.device.try_zeros()?;
+        let mut out = lhs
+            .device
+            .try_zeros_like(&(Const, h.convolve_dim(), w.convolve_dim()))?;
         lhs.device.forward(op, &lhs, &rhs, &mut out)?;
         let phantom_out = out.clone();
         tape.try_alloc_grad(&lhs)?;
Expand All @@ -169,8 +177,8 @@ where
impl<
B: Dim,
const C: usize,
const H: usize,
const W: usize,
H: Dim + ConvTransAlgebra<K, S, P>,
W: Dim + ConvTransAlgebra<K, S, P>,
const O: usize,
const K: usize,
const S: usize,
@@ -179,33 +187,23 @@ impl<
         D: ConvTrans2DKernel<E> + ZerosTensor<E>,
         T: 'static + Tape<E, D>,
     > TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P>
-    for Tensor<(B, Const<C>, Const<H>, Const<W>), E, D, T>
-where
-    Const<H>: ConvTransAlgebra<K, S, P>,
-    Const<W>: ConvTransAlgebra<K, S, P>,
+    for Tensor<(B, Const<C>, H, W), E, D, T>
 {
-    type Output = Tensor<
-        (
-            B,
-            Const<O>,
-            <Const<H> as ConvTransAlgebra<K, S, P>>::Convolved,
-            <Const<W> as ConvTransAlgebra<K, S, P>>::Convolved,
-        ),
-        E,
-        D,
-        T,
-    >;
+    type Output = Tensor<(B, Const<O>, H::Convolved, W::Convolved), E, D, T>;
     fn try_convtrans2d_to(
         self,
         filters: Tensor<Rank4<O, C, K, K>, E, D>,
     ) -> Result<Self::Output, Self::Err> {
+        let h = self.shape.2;
+        let w = self.shape.3;
+
         let batch = self.shape().0;
-        let op = ConvTrans2DOp::new(S, P, K, [batch.size(), C, H, W], O);
+        let op = ConvTrans2DOp::new(S, P, K, [batch.size(), C, h.size(), w.size()], O);
         let (lhs, ltape) = self.split_tape();
         let (rhs, rtape) = filters.split_tape();
         let mut out =
             lhs.device
-                .try_zeros_like(&(batch, Const, Default::default(), Default::default()))?;
+                .try_zeros_like(&(batch, Const, h.convolve_dim(), w.convolve_dim()))?;
         let mut tape = ltape.merge(rtape);
         lhs.device.forward(op, &lhs, &rhs, &mut out)?;
         let phantom_out = out.clone();
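With the Const and usize impls in place, a height and width known only at runtime flow through the same shape algebra as compile-time dimensions. A self-contained sketch of the runtime half (the trait and formula mirror the diff above; the numbers and call sites are illustrative only):

trait ConvTransAlgebra<const K: usize, const S: usize, const P: usize> {
    type Convolved;
    fn convolve_dim(&self) -> Self::Convolved;
}

// Runtime impl as in the diff: the same output-size formula as the
// const-generic impl, evaluated at runtime; unwrap() panics if the
// kernel/stride/padding combination would make the output size negative.
impl<const K: usize, const S: usize, const P: usize> ConvTransAlgebra<K, S, P> for usize {
    type Convolved = usize;

    fn convolve_dim(&self) -> Self::Convolved {
        (self * S + K).checked_sub(S + 2 * P).unwrap()
    }
}

fn main() {
    // Spatial size known only at runtime, e.g. read from an image file.
    let (h, w): (usize, usize) = (30, 42);
    // 4x4 kernel, stride 2, padding 1: 30x42 upsamples to 60x84.
    let out_h = <usize as ConvTransAlgebra<4, 2, 1>>::convolve_dim(&h);
    let out_w = <usize as ConvTransAlgebra<4, 2, 1>>::convolve_dim(&w);
    assert_eq!((out_h, out_w), (60, 84));
}

In the tensor impls, this is exactly what h.convolve_dim() and w.convolve_dim() compute when try_zeros_like allocates the output.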
