
Add ToDevice and OnDevice to simplify nn api (#388) #394

Merged (18 commits) on Jan 26, 2023

Conversation


@nkoppel nkoppel commented Jan 24, 2023

Partial implementation of a solution for #388, as discussed here.

@nkoppel nkoppel marked this pull request as draft January 24, 2023 15:05
@nkoppel nkoppel marked this pull request as ready for review January 24, 2023 17:56

nkoppel commented Jan 24, 2023

Ok, this should now be a full implementation.

@nkoppel nkoppel changed the title Add OnDevice type alias to simplify nn api (#388) Add ToDevice and OnDevice to simplify nn api (#388) Jan 25, 2023

nkoppel commented Jan 25, 2023

I've now renamed OnDeviceTrait to ToDevice and added a to_device method which allows copying a module from one device to another, but I haven't updated the documentation yet.
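The shape of the renamed trait can be sketched with self-contained mock types (the `Cpu`, `Cuda`, and `Tensor` below are illustrative stand-ins, not dfdx's real implementations):

```rust
// Illustrative sketch of a ToDevice trait with a to_device copy method.
// Cpu, Cuda, and Tensor here are toy stand-ins, not dfdx's real types.
use std::marker::PhantomData;

struct Cpu;
struct Cuda;

// A toy tensor that just owns its data and remembers its device type.
struct Tensor<D> {
    data: Vec<f32>,
    device: PhantomData<D>,
}

trait ToDevice<D> {
    type Output;
    fn to_device(&self, device: &D) -> Self::Output;
}

// Copying between any two devices: in the real crate this would move the
// buffer across device memory; here we just clone the Vec.
impl<D1, D2> ToDevice<D2> for Tensor<D1> {
    type Output = Tensor<D2>;
    fn to_device(&self, _device: &D2) -> Tensor<D2> {
        Tensor { data: self.data.clone(), device: PhantomData }
    }
}

fn main() {
    let cuda = Cuda;
    let t = Tensor::<Cpu> { data: vec![1.0, 2.0, 3.0], device: PhantomData };
    // The output type is chosen by the device argument's type.
    let t2: Tensor<Cuda> = t.to_device(&cuda);
    assert_eq!(t2.data, vec![1.0, 2.0, 3.0]);
}
```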


@coreylowman coreylowman left a comment


Looks pretty good! Going to play around with a bunch of user errors to see what the error messages look like; I imagine using associated types should help make them clear.

@@ -192,6 +194,20 @@ impl<S: Shape, E: Unit, D: SampleTensor<E>, T> Tensor<S, E, D, T> {
}
}

impl<S: Shape, E: Dtype + Unit, D1: Device<E>, T, D2: Device<E>> ToDevice<D2>
coreylowman (Owner):

Suggested change
impl<S: Shape, E: Dtype + Unit, D1: Device<E>, T, D2: Device<E>> ToDevice<D2>
impl<S: Shape, E: Dtype + Unit, D1: DeviceStorage, T, D2: DeviceStorage> ToDevice<D2>

You may need to add some restrictions for vec copying; I can't remember what trait those need on top of the base DeviceStorage.

src/nn/module.rs Outdated
Comment on lines 120 to 150
pub trait ToDevice<D> {
type Output;
fn to_device(&self, device: &D) -> Self::Output;
}

/// A type alias that yields the type of a module `M` as it would exist on device `D`. This can be
/// very useful when creating sequential networks that need to be parameterized by a device.
///
/// Examples:
/// ```rust
/// # use dfdx::nn::*;
/// type MLP<D> = OnDevice<(Linear<5, 10>, ReLU, Linear<10, 1>), D>;
/// ```
///
/// ```rust
/// # use dfdx::prelude::*;
/// #
/// // All modules exist on the cpu by default
/// type CpuMLP = (Linear<5, 10>, ReLU, Linear<10, 1>);
/// type MLP<D> = OnDevice<CpuMLP, D>;
/// # #[cfg(feature = "cuda")]
/// type CudaMLP = OnDevice<CpuMLP, Cuda>;
/// ```
pub type OnDevice<M, D> = <M as ToDevice<D>>::Output;

/// Equivalent to `OnDevice<M, Cuda>`
#[cfg(feature = "cuda")]
pub type OnCuda<M> = OnDevice<M, crate::prelude::Cuda>;

/// Equivalent to `OnDevice<M, Cpu>`
pub type OnCpu<M> = OnDevice<M, Cpu>;
coreylowman (Owner):

We can move all this into tensor/tensor_impls.rs since that is lower level than nn

coreylowman (Owner) commented:

Okay, here are some observations from playing around.

Step 1: existing build_module works as expected

use dfdx::prelude::*;
type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);
fn main() {
    let dev: Cpu = Default::default();
    let m: MLP = dev.build_module();
}

Step 2: changing the device type fails, as we might expect

However, this is an inconsistency in the API: if this fails, then so should using Cpu.

error[E0277]: the trait bound `(dfdx::nn::Linear<5, 10>, dfdx::nn::ReLU, dfdx::nn::Linear<10, 5>): ResetParams<dfdx::tensor::Cuda, f32>` is not satisfied
  --> examples/tmp.rs:5:18
   |
5  |     let m: MLP = dev.build_module();
   |                  ^^^ ------------ required by a bound introduced by this call
   |                  |
   |                  the trait `ResetParams<dfdx::tensor::Cuda, f32>` is not implemented for `(dfdx::nn::Linear<5, 10>, dfdx::nn::ReLU, dfdx::nn::Linear<10, 5>)`

Using OnCuda & OnCpu works

use dfdx::prelude::*;
type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);
fn main() {
    let dev: Cuda = Default::default();
    let m: OnDevice<MLP, Cuda> = dev.build_module();
}

But if you specify OnDevice & a function generic, it fails?

use dfdx::prelude::*;
type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);
fn main() {
    let dev: Cuda = Default::default();
    let m: OnDevice<MLP, Cuda> = dev.build_module::<MLP>();

}

fails with

error[E0308]: mismatched types
 --> examples/tmp.rs:5:34
  |
5 |     let m: OnDevice<MLP, Cuda> = dev.build_module::<MLP>();
  |            -------------------   ^^^^^^^^^^^^^^^^^^^^^^^^^ expected struct `dfdx::tensor::Cuda`, found struct `dfdx::tensor::Cpu`
  |            |
  |            expected due to this
  |
  = note: expected tuple `(dfdx::nn::Linear<dfdx::tensor::Cuda, _, _>, dfdx::nn::ReLU, dfdx::nn::Linear<dfdx::tensor::Cuda, _, _>)`
             found tuple `(dfdx::nn::Linear<dfdx::tensor::Cpu, _, _>, dfdx::nn::ReLU, dfdx::nn::Linear<dfdx::tensor::Cpu, _, _>)`


nkoppel commented Jan 25, 2023

But if you specify OnDevice & a function generic, it fails?

That is expected, because MLP == OnCpu<MLP>, while MLP != OnCuda<MLP>.
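This identity can be demonstrated with a minimal self-contained sketch (the `Cpu`, `Cuda`, and `Linear` types below are mock stand-ins, not dfdx's real ones):

```rust
// Minimal sketch of why OnDevice<M, Cpu> == M when modules default to Cpu,
// while OnDevice<M, Cuda> is a distinct type. All types here are mock
// stand-ins for illustration, not dfdx's real implementations.
use std::any::type_name;
use std::marker::PhantomData;

struct Cpu;
struct Cuda;

// A module parameterized by its device, defaulting to Cpu.
struct Linear<const I: usize, const O: usize, D = Cpu>(PhantomData<D>);

trait ToDevice<D> {
    type Output;
}

impl<const I: usize, const O: usize, D1, D2> ToDevice<D2> for Linear<I, O, D1> {
    type Output = Linear<I, O, D2>;
}

type OnDevice<M, D> = <M as ToDevice<D>>::Output;

fn main() {
    // Linear<5, 10> is shorthand for Linear<5, 10, Cpu>, so mapping it
    // onto the Cpu is the identity...
    assert_eq!(
        type_name::<OnDevice<Linear<5, 10>, Cpu>>(),
        type_name::<Linear<5, 10>>(),
    );
    // ...while mapping it onto Cuda yields a different type, which is why
    // annotating OnDevice<MLP, Cuda> while calling build_module::<MLP>()
    // cannot type-check.
    assert_ne!(
        type_name::<OnDevice<Linear<5, 10>, Cuda>>(),
        type_name::<Linear<5, 10>>(),
    );
}
```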


nkoppel commented Jan 25, 2023

I've tried modifying ModuleBuilder like so:

pub trait IsSameType<T> {}
impl<T> IsSameType<T> for T {}

/// Extension trait for [Device] that can build anything that implements [ResetParams].
pub trait ModuleBuilder<E: Dtype>: Device<E> {
    fn build_module<M1, M2>(&self) -> M2
        where M1: ToDevice<Self> + IsSameType<OnCpu<M2>>,
              M2: ResetParams<Self, E> + ToDevice<Cpu> + IsSameType<OnDevice<M1, Self>>
    {
        ResetParams::build(self)
    }
    fn try_build<M1>(&self) -> Result<M1, Self::Err>
        where M1: ResetParams<Self, E>
    {
        ResetParams::try_build(self)
    }
}

but the compiler doesn't seem to be smart enough to infer the type of M1, and fails with:

error[E0284]: type annotations needed
   --> src/nn/linear.rs:165:41
    |
165 |         let m: Linear<2000, 1, _> = dev.build_module();
    |                                         ^^^^^^^^^^^^ cannot infer type of the type parameter `M1` declared on the associated function `build_module`
    |
    = note: cannot satisfy `<_ as tensor_impls::ToDevice<cuda::device::Cuda>>::Output == linear::Linear<2000, 1, cuda::device::Cuda>`
help: consider specifying the generic arguments
    |
165 |         let m: Linear<2000, 1, _> = dev.build_module::<M1, linear::Linear<2000, 1, cuda::device::Cuda>>();
    |                                                     +++++++++++++++++++++++++++++++++++++++++++++++++++

coreylowman (Owner) commented:

Yeah, after a lot of playing around, it's because the output type is an associated type, so the compiler would have to infer the original type from it. I don't know why it doesn't do that, since it seems possible, but 🤷

I'm thinking of moving away from allowing dev.build_module() and instead moving to calling a method on the type:

fn main() {
    type Dev = Cpu;
    type Model = (Linear<3, 5>, ReLU, Linear<5, 3>);
    type DeviceModel<D> = (Linear<3, 5, D>, ReLU, Linear<5, 3, D>);

    let dev: Dev = Default::default();

    let q = Model::build_on(&dev);
    let q: DeviceModel<Dev> = Model::build_on(&dev);
    let q: DeviceModel<_> = Model::build_on(&dev);
    let q: (Linear<3, 5, _>, ReLU, Linear<5, 3, _>) = Model::build_on(&dev);
}

because Rust won't infer the type for us here.

The nice thing about this is that if someone does let q: Model = Model::build_on(&dev) with the wrong device (e.g. other than Cpu), the error is actually helpful:

error[E0308]: mismatched types
   --> examples\building.rs:120:65
    |
120 |     let q: (Linear<3, 5>, ReLU, Linear<5, 3>) = Model::build_on(&dev);      
    |                                                 --------------- ^^^^ expected struct `Cpu`, found struct `Cuda`
    |                                                 |
    |                                                 arguments to this function are incorrect
    |
    = note: expected reference `&Cpu`
               found reference `&Cuda`
note: associated function defined here
   --> examples\building.rs:62:8
    |
62  |     fn build_on(dev: &D) -> Self::Built;
    |        ^^^^^^^^


nkoppel commented Jan 26, 2023

We pretty much already have that in the ResetParams trait, so all we would need to do is to phase out ModuleBuilder and update the examples/documentation.
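The direction discussed above can be sketched with mock types (the `BuildOn` trait, `Cpu`, and `Linear` below are illustrative stand-ins, not dfdx's actual ResetParams API):

```rust
// Hedged sketch of the "call a method on the type" building style discussed
// above. BuildOn, Cpu, and Linear are mock stand-ins, not dfdx's real API.
use std::marker::PhantomData;

#[derive(Default)]
struct Cpu;

// A module parameterized by its device, defaulting to Cpu.
struct Linear<const I: usize, const O: usize, D = Cpu> {
    device: PhantomData<D>,
}

// Building is driven by the module type: Model::build_on(&dev) derives the
// built type from the device argument, so a device mismatch surfaces as a
// plain "expected &Cpu, found &Cuda" argument error.
trait BuildOn<D> {
    type Built;
    fn build_on(dev: &D) -> Self::Built;
}

impl<const I: usize, const O: usize, D> BuildOn<D> for Linear<I, O> {
    type Built = Linear<I, O, D>;
    fn build_on(_dev: &D) -> Linear<I, O, D> {
        Linear { device: PhantomData }
    }
}

fn main() {
    let dev: Cpu = Default::default();
    // The built type follows from the device argument; no annotation needed.
    let _m = <Linear<3, 5> as BuildOn<Cpu>>::build_on(&dev);
    // An explicit annotation also type-checks.
    let _m2: Linear<3, 5, Cpu> = Linear::<3, 5>::build_on(&dev);
}
```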

@@ -192,6 +192,83 @@ impl<S: Shape, E: Unit, D: SampleTensor<E>, T> Tensor<S, E, D, T> {
}
}

/// A trait which allows a [Module] to be copied to another [Device] and to be used with the
coreylowman (Owner):

Suggested change
/// A trait which allows a [Module] to be copied to another [Device] and to be used with the
/// Something that can be copied to another `Device`. To be used with the

to get rid of doc warnings

@coreylowman coreylowman left a comment

nice!
