Making Conv2D unbiased by default, and adding Bias2D module #494

Merged · 2 commits · Feb 26, 2023
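
In short: `Conv2D` now learns only a kernel, and the bias becomes an explicit, separate module. A before/after sketch of what user code looks like (channel and kernel counts here are illustrative, not taken from the diff):

```rust
// Before this PR: Conv2D<3, 8, 3> bundled a weight *and* a bias tensor.
// After this PR: Conv2D is unbiased; a biased conv is an explicit tuple.
type UnbiasedConv = Conv2D<3, 8, 3>;
type BiasedConv = (Conv2D<3, 8, 3>, Bias2D<8>);
```

Sequential tuples are already how dfdx composes modules, so the biased form needs no new machinery.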
src/nn/bias2d.rs (new file, 103 additions, 0 deletions)
@@ -0,0 +1,103 @@
+use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};
+
+use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice};
+
+pub mod builder {
+    #[derive(Debug)]
+    pub struct Bias2D<const CHAN: usize>;
+}
+
+impl<const C: usize, E: Dtype, D: Device<E>> BuildOnDevice<D, E> for builder::Bias2D<C>
+where
+    Bias2D<C, E, D>: BuildModule<D, E>,
+{
+    type Built = Bias2D<C, E, D>;
+    fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
+        Self::Built::try_build(device)
+    }
+}
+
+/// Adds a learnable 1d bias to 3d and 4d inputs. Can be used with [crate::nn::modules::Conv2D]
+/// to create a Biased conv.
+///
+/// Example:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// const NUM_CHANS: usize = 5;
+/// type Model = Bias2D<NUM_CHANS>;
+/// let model = dev.build_module::<Model, f32>();
+///
+/// // 3d input
+/// let x: Tensor<Rank3<NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+///
+/// // 4d input
+/// let x: Tensor<Rank4<10, NUM_CHANS, 2, 3>, f32, _> = dev.sample_normal();
+/// model.forward(x);
+/// ```
+#[derive(Clone, Debug)]
+pub struct Bias2D<const C: usize, E: Dtype, D: DeviceStorage> {
+    pub bias: Tensor<Rank1<C>, E, D>,
+}
+
+impl<const C: usize, E: Dtype, D: Device<E>> BuildModule<D, E> for Bias2D<C, E, D> {
+    fn try_build(device: &D) -> Result<Self, <D>::Err> {
+        Ok(Self {
+            bias: device.try_zeros()?,
+        })
+    }
+}
+
+impl<const C: usize, E: Dtype, D: DeviceStorage> NonMutableModule for Bias2D<C, E, D> {}
+
+impl<const C: usize, E: Dtype, D1: Device<E>, D2: Device<E>> ToDevice<D2> for Bias2D<C, E, D1> {
+    type Output = Bias2D<C, E, D2>;
+
+    fn to_device(&self, device: &D2) -> Self::Output {
+        Bias2D {
+            bias: self.bias.to_device(device),
+        }
+    }
+}
+
+impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for Bias2D<C, E, D> {
+    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(visitor: &mut V) -> Result<(), V::Err> {
+        visitor.visit_tensor(
+            "beta",
+            |s| &s.bias,
+            |s| &mut s.bias,
+            TensorOptions::reset_to_zeros(),
+        )
+    }
+}
+
+impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<D>>
+    Module<Tensor<(Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>
+{
+    type Output = Tensor<(Const<C>, H, W), E, D, T>;
+    type Error = D::Err;
+
+    fn try_forward(
+        &self,
+        input: Tensor<(Const<C>, H, W), E, D, T>,
+    ) -> Result<Self::Output, D::Err> {
+        let s = *input.shape();
+        input.try_add(self.bias.retaped::<T>().try_broadcast_like(&s)?)
+    }
+}
+
+impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<D>>
+    Module<Tensor<(B, Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>
+{
+    type Output = Tensor<(B, Const<C>, H, W), E, D, T>;
+    type Error = D::Err;
+
+    fn try_forward(
+        &self,
+        input: Tensor<(B, Const<C>, H, W), E, D, T>,
+    ) -> Result<Self::Output, D::Err> {
+        let s = *input.shape();
+        input.try_add(self.bias.retaped::<T>().try_broadcast_like(&s)?)
+    }
+}
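
Both `Module` impls above do the same arithmetic: the `Rank1<C>` bias is broadcast across the spatial (and, for 4d inputs, batch) dimensions, then added elementwise. A dependency-free sketch of the 3d case, just to pin down the semantics of the `retaped` + `try_broadcast_like` + `try_add` chain:

```rust
// out[c][h][w] = x[c][h][w] + bias[c]
// Plain-array illustration of the broadcast; dfdx performs the same
// computation on-device without explicit loops.
fn bias2d_forward_3d<const C: usize, const H: usize, const W: usize>(
    mut x: [[[f32; W]; H]; C],
    bias: [f32; C],
) -> [[[f32; W]; H]; C] {
    for c in 0..C {
        for h in 0..H {
            for w in 0..W {
                // One learnable value per channel, shared across H and W.
                x[c][h][w] += bias[c];
            }
        }
    }
    x
}
```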
src/nn/conv.rs (17 additions, 73 deletions)
@@ -1,9 +1,9 @@
 use num_traits::Float;
 use rand_distr::uniform::SampleUniform;
 
-use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*};
+use crate::{shapes::*, tensor::*, tensor_ops::*};
 
-use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice};
+use super::{tensor_collection::*, BuildModule, BuildOnDevice, NonMutableModule, ToDevice};
 
 pub mod builder {
     #[derive(Debug)]
@@ -29,9 +29,15 @@ where
     }
 }
 
-/// **Requires Nightly** Performs 2d convolutions on 3d and 4d images.
+/// **Requires Nightly** Performs *unbiased* 2d convolutions on 3d and 4d images.
 ///
-/// **Pytorch Equivalent**: `torch.nn.Conv2d`
+/// **Pytorch Equivalent**: `torch.nn.Conv2d(..., bias=False)`
+///
+/// To create a biased conv, combine with [crate::nn::modules::Bias2D]:
+/// ```ignore
+/// # use dfdx::prelude::*;
+/// type BiasedConv = (Conv2D<3, 5, 4>, Bias2D<5>);
+/// ```
 ///
 /// Generics:
 /// - `IN_CHAN`: The number of input channels in an image.
@@ -50,7 +56,6 @@ pub struct Conv2D<
     D: DeviceStorage,
 > {
     pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
-    pub bias: Tensor<Rank1<OUT_CHAN>, E, D>,
 }
 
 impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
@@ -68,15 +73,6 @@ where
                 let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt();
                 t.try_fill_with_distr(rand_distr::Uniform::new(-b, b))
             }),
-        )?;
-        visitor.visit_tensor(
-            "bias",
-            |s| &s.bias,
-            |s| &mut s.bias,
-            TensorOptions::reset_with(|t| {
-                let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt();
-                t.try_fill_with_distr(rand_distr::Uniform::new(-b, b))
-            }),
         )
     }
 }
@@ -92,7 +88,6 @@ where
         let bound = E::ONE / k.sqrt();
         Ok(Self {
             weight: device.try_sample(rand_distr::Uniform::new(-bound, bound))?,
-            bias: device.try_sample(rand_distr::Uniform::new(-bound, bound))?,
         })
     }
 }
@@ -109,87 +104,39 @@
     fn to_device(&self, device: &D2) -> Self::Output {
         Conv2D {
             weight: self.weight.to_device(device),
-            bias: self.bias.to_device(device),
         }
     }
 }
 
 #[cfg(feature = "nightly")]
 impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
-    Module<Img> for Conv2D<C, O, K, S, P, E, D>
+    super::Module<Img> for Conv2D<C, O, K, S, P, E, D>
 where
     E: Dtype,
     D: Device<E>,
     Img: TryConv2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
-    for<'a> Bias2D<'a, O, E, D>: Module<Img::Output, Output = Img::Output, Error = D::Err>,
 {
     type Output = Img::Output;
     type Error = D::Err;
 
     fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
-        Bias2D { beta: &self.bias }.try_forward(x.try_conv2d_to(self.weight.clone())?)
+        x.try_conv2d_to(self.weight.clone())
     }
 }
 
-impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
-    ModuleMut<Img> for Conv2D<I, O, K, S, P, E, D>
+impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
+    NonMutableModule for Conv2D<I, O, K, S, P, E, D>
 where
     E: Dtype,
-    D: Device<E>,
-    Self: Module<Img, Error = D::Err>,
-{
-    type Output = <Self as Module<Img>>::Output;
-    type Error = D::Err;
-
-    fn try_forward_mut(&mut self, input: Img) -> Result<Self::Output, D::Err> {
-        self.try_forward(input)
-    }
-}
-
-#[derive(Clone, Debug)]
-struct Bias2D<'a, const C: usize, E: Dtype, D: DeviceStorage> {
-    beta: &'a Tensor<Rank1<C>, E, D>,
-}
-
-impl<'a, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<D>>
-    Module<Tensor<(Const<C>, H, W), E, D, T>> for Bias2D<'a, C, E, D>
-{
-    type Output = Tensor<(Const<C>, H, W), E, D, T>;
-    type Error = D::Err;
-
-    fn try_forward(
-        &self,
-        input: Tensor<(Const<C>, H, W), E, D, T>,
-    ) -> Result<Self::Output, D::Err> {
-        self.beta
-            .retaped::<T>()
-            .try_broadcast_like(input.shape())?
-            .try_add(input)
-    }
-}
-
-impl<'a, B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<D>>
-    Module<Tensor<(B, Const<C>, H, W), E, D, T>> for Bias2D<'a, C, E, D>
+    D: DeviceStorage,
 {
-    type Output = Tensor<(B, Const<C>, H, W), E, D, T>;
-    type Error = D::Err;
-
-    fn try_forward(
-        &self,
-        input: Tensor<(B, Const<C>, H, W), E, D, T>,
-    ) -> Result<Self::Output, D::Err> {
-        self.beta
-            .retaped::<T>()
-            .try_broadcast_like(input.shape())?
-            .try_add(input)
-    }
-}
 
 #[cfg(feature = "nightly")]
 #[cfg(test)]
 mod tests {
     use crate::{
-        nn::DeviceBuildExt,
+        nn::{DeviceBuildExt, Module, ModuleMut},
         optim::*,
         tensor::{AsArray, SampleTensor, ZerosTensor},
         tests::*,
@@ -258,18 +205,15 @@ mod tests {
         let mut m = dev.build_module::<Conv2D<2, 4, 3>, TestDtype>();
 
         let weight_init = m.weight.clone();
-        let bias_init = m.bias.clone();
 
         let mut opt = Sgd::new(&m, Default::default());
-        let out = m.forward(dev.sample_normal::<Rank4<8, 2, 28, 28>>().trace());
+        let out = m.forward(dev.sample_normal::<Rank4<8, 2, 28, 28>>().traced());
        let g = out.square().mean().backward();
 
         assert_ne!(g.get(&m.weight).array(), [[[[0.0; 3]; 3]; 2]; 4]);
-        assert_ne!(g.get(&m.bias).array(), [0.0; 4]);
 
         opt.update(&mut m, g).expect("unused params");
 
         assert_ne!(weight_init.array(), m.weight.array());
-        assert_ne!(bias_init.array(), m.bias.array());
     }
 }
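
Note the shape of the conv.rs change: the hand-written `ModuleMut` impl and the private borrowing `Bias2D` helper are deleted, and `Conv2D` instead opts into the `NonMutableModule` marker (just as the new `Bias2D` does in bias2d.rs). The marker presumably feeds a blanket impl that derives the mutable forward from the immutable one; here is a self-contained sketch of that pattern, with stand-in traits rather than dfdx's exact definitions:

```rust
// Stand-ins for dfdx's Module / ModuleMut traits, for illustration only.
pub trait Module<In> {
    type Output;
    fn forward(&self, input: In) -> Self::Output;
}
pub trait ModuleMut<In> {
    type Output;
    fn forward_mut(&mut self, input: In) -> Self::Output;
}

/// Marker: this module's forward pass never mutates its state.
pub trait NonMutableModule {}

// Blanket impl: every NonMutableModule gets ModuleMut for free by
// delegating forward_mut to the immutable forward.
impl<In, M: NonMutableModule + Module<In>> ModuleMut<In> for M {
    type Output = <M as Module<In>>::Output;
    fn forward_mut(&mut self, input: In) -> Self::Output {
        self.forward(input)
    }
}
```

This also explains the test-side tweak: once conv.rs stops importing `Module` and `ModuleMut` at the top of the file, the tests must pull them in themselves via `nn::{DeviceBuildExt, Module, ModuleMut}`.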
src/nn/mod.rs (3 additions, 0 deletions)
@@ -111,6 +111,7 @@ pub mod tensor_collection;
 mod activations;
 mod add_into;
 mod batchnorm2d;
+mod bias2d;
 mod conv;
 mod dropout;
 mod embedding;
@@ -143,6 +144,7 @@ pub mod modules {
     pub use super::activations::*;
     pub use super::add_into::AddInto;
     pub use super::batchnorm2d::BatchNorm2D;
+    pub use super::bias2d::Bias2D;
     #[cfg(feature = "nightly")]
     pub use super::conv::Conv2D;
     pub use super::dropout::{Dropout, DropoutOneIn};
@@ -168,6 +170,7 @@ pub mod builders {
     pub use super::activations::*;
     pub use super::add_into::AddInto;
     pub use super::batchnorm2d::builder::BatchNorm2D;
+    pub use super::bias2d::builder::Bias2D;
     #[cfg(feature = "nightly")]
     pub use super::conv::builder::Conv2D;
     pub use super::dropout::{Dropout, DropoutOneIn};
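
Finally, mod.rs wires the new file in twice: `builders::Bias2D<C>` is the shape-only marker type handed to `build_module`, while `modules::Bias2D<C, E, D>` is the built struct that owns the bias tensor. A usage sketch mirroring the doc-test in bias2d.rs:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // The builder type carries only the channel count; build_module
    // resolves it to modules::Bias2D<4, f32, Cpu> with a zeroed bias.
    let model = dev.build_module::<Bias2D<4>, f32>();
    let x: Tensor<Rank3<4, 8, 8>, f32, _> = dev.sample_normal();
    let _y = model.forward(x);
}
```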