Skip to content

Commit

Permalink
impl BuildModule for ZeroSizedModule (#470)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Feb 21, 2023
1 parent fc38a83 commit 1f89f7f
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 51 deletions.
14 changes: 1 addition & 13 deletions src/nn/activations.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};

use super::module::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
use super::module::{Module, NonMutableModule, ZeroSizedModule};

macro_rules! activation_impls {
($struct_name:ident, $func_name:ident, #[$docstring:meta]) => {
Expand All @@ -11,12 +11,6 @@ macro_rules! activation_impls {
impl ZeroSizedModule for $struct_name {}
impl NonMutableModule for $struct_name {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for $struct_name {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<D>> Module<Tensor<S, E, D, T>>
for $struct_name
{
Expand Down Expand Up @@ -47,12 +41,6 @@ pub struct Softmax;
impl ZeroSizedModule for Softmax {}
impl NonMutableModule for Softmax {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for Softmax {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<Ax: Axes, S: Shape<LastAxis = Ax> + ReduceShape<Ax>, E: Dtype, D: Device<E>, T: Tape<D>>
Module<Tensor<S, E, D, T>> for Softmax
{
Expand Down
14 changes: 1 addition & 13 deletions src/nn/dropout.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{gradients::*, shapes::*, tensor::Tensor, tensor_ops::*};

use super::{BuildModule, Module, ModuleMut, ZeroSizedModule};
use super::{Module, ModuleMut, ZeroSizedModule};

/// Does nothing as a [Module], and calls [dropout()] as [ModuleMut] with probability `1.0 / N`.
///
Expand Down Expand Up @@ -45,12 +45,6 @@ pub struct DropoutOneIn<const N: usize>;

impl<const N: usize> ZeroSizedModule for DropoutOneIn<N> {}

impl<const N: usize, D: Device<E>, E: Dtype> BuildModule<D, E> for DropoutOneIn<N> {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<const N: usize, S: Shape, E: Dtype, D: Device<E>> Module<Tensor<S, E, D, NoneTape>>
for DropoutOneIn<N>
{
Expand Down Expand Up @@ -123,12 +117,6 @@ impl Default for Dropout {

impl ZeroSizedModule for Dropout {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for Dropout {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<S: Shape, E: Dtype, D: Device<E>> Module<Tensor<S, E, D, NoneTape>> for Dropout {
type Output = Tensor<S, E, D, NoneTape>;
/// Does nothing.
Expand Down
13 changes: 3 additions & 10 deletions src/nn/flatten.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#[allow(unused)]
use crate::{gradients::Tape, shapes::*, tensor::Tensor, tensor_ops::*};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
use super::{NonMutableModule, ZeroSizedModule};

/// **Requires Nightly** Flattens 3d tensors to 1d, and 4d tensors to 2d.
#[derive(Default, Clone, Copy)]
Expand All @@ -11,15 +10,9 @@ pub struct Flatten2D;
impl ZeroSizedModule for Flatten2D {}
impl NonMutableModule for Flatten2D {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for Flatten2D {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

#[cfg(feature = "nightly")]
impl<const C: usize, const H: usize, const W: usize, D: Device<E>, E: Dtype, T: Tape<D>>
Module<Tensor<Rank3<C, H, W>, E, D, T>> for Flatten2D
super::Module<Tensor<Rank3<C, H, W>, E, D, T>> for Flatten2D
where
Rank3<C, H, W>: HasSameNumelAs<Rank1<{ C * H * W }>>,
{
Expand All @@ -31,7 +24,7 @@ where

#[cfg(feature = "nightly")]
impl<const B: usize, const C: usize, const H: usize, const W: usize, D, E: Dtype, T: Tape<D>>
Module<Tensor<Rank4<B, C, H, W>, E, D, T>> for Flatten2D
super::Module<Tensor<Rank4<B, C, H, W>, E, D, T>> for Flatten2D
where
D: Device<E>,
Rank4<B, C, H, W>: HasSameNumelAs<Rank2<B, { C * H * W }>>,
Expand Down
6 changes: 6 additions & 0 deletions src/nn/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ impl<T: ZeroSizedModule + BuildModule<D, E>, D: DeviceStorage, E: Dtype> BuildOn
type Built = T;
}

impl<E: Dtype, D: DeviceStorage, T: ZeroSizedModule> BuildModule<D, E> for T {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<T: ZeroSizedModule, D: DeviceStorage, E: Dtype> ResetParams<D, E> for T {
fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
Ok(())
Expand Down
8 changes: 0 additions & 8 deletions src/nn/pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#[cfg(feature = "nightly")]
use crate::tensor_ops::{ConstAvgPool2D, ConstMaxPool2D, ConstMinPool2D};

use crate::{shapes::Dtype, tensor_ops::Device};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};

Expand Down Expand Up @@ -41,12 +39,6 @@ macro_rules! impl_pools {
impl<const K: usize, const S: usize, const P: usize> ZeroSizedModule for $PoolTy<K, S, P> {}
impl<const K: usize, const S: usize, const P: usize> NonMutableModule for $PoolTy<K, S, P> {}

impl<const K: usize, const S: usize, const P: usize, D: Device<E>, E: Dtype> BuildModule<D, E> for $PoolTy<K, S, P> {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

#[cfg(feature = "nightly")]
impl<const K: usize, const S: usize, const P: usize, Img: $Trait<K, S, P>> Module<Img>
for $PoolTy<K, S, P>
Expand Down
8 changes: 1 addition & 7 deletions src/nn/pool_global.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{gradients::*, shapes::*, tensor::*, tensor_ops::*};

use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
use super::{Module, NonMutableModule, ZeroSizedModule};

/// Applies average pooling over an entire image, fully reducing the height and width
/// dimensions:
Expand Down Expand Up @@ -61,12 +61,6 @@ macro_rules! impl_pools {
impl ZeroSizedModule for $PoolTy {}
impl NonMutableModule for $PoolTy {}

impl<D: Device<E>, E: Dtype> BuildModule<D, E> for $PoolTy {
fn try_build(_: &D) -> Result<Self, <D>::Err> {
Ok(Default::default())
}
}

impl<C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<D>>
Module<Tensor<(C, H, W), E, D, T>> for $PoolTy
{
Expand Down

0 comments on commit 1f89f7f

Please sign in to comment.