Add ToDevice and OnDevice to simplify nn api (#388) #394
Conversation
Ok, this should now be a full implementation.

I've now renamed OnDeviceTrait to ToDevice and added a to_device method, which allows copying a module from one device to another, but I haven't updated the documentation yet.
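For readers following along, here is a rough sketch of the intended usage (this is an editorial illustration assuming the `cuda` feature and the `build_module`/`OnCuda` names that appear later in this thread, not code from the PR):

```rust
use dfdx::prelude::*;

type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);

fn main() {
    // Build on the Cpu first (modules default to the Cpu device)...
    let dev: Cpu = Default::default();
    let m: MLP = dev.build_module();

    // ...then copy the parameters onto a Cuda device with `to_device`.
    #[cfg(feature = "cuda")]
    {
        let cuda: Cuda = Default::default();
        let m_cuda: OnCuda<MLP> = m.to_device(&cuda);
    }
}
```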
Looks pretty good! Going to play around with a bunch of user errors to see what the error messages look like; I imagine using associated types should help make them clear.
src/tensor/tensor_impls.rs (outdated)

```diff
@@ -192,6 +194,20 @@ impl<S: Shape, E: Unit, D: SampleTensor<E>, T> Tensor<S, E, D, T> {
     }
 }

+impl<S: Shape, E: Dtype + Unit, D1: Device<E>, T, D2: Device<E>> ToDevice<D2>
```
Suggested change:

```diff
-impl<S: Shape, E: Dtype + Unit, D1: Device<E>, T, D2: Device<E>> ToDevice<D2>
+impl<S: Shape, E: Dtype + Unit, D1: DeviceStorage, T, D2: DeviceStorage> ToDevice<D2>
```
You may need to add some restrictions for vec copying; I can't remember what trait those need on top of the base DeviceStorage.
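To make that concrete, here is one hypothetical shape such a restriction could take (every name in this sketch is invented for illustration and is not dfdx's actual API):

```rust
use dfdx::prelude::*;

/// Hypothetical: a device whose tensor storage can round-trip through a
/// host-side Vec -- the minimal capability a cross-device copy needs.
pub trait HasVecStorage<E: Unit>: DeviceStorage + Sized {
    /// Read a tensor's data out into a host-side Vec.
    fn storage_to_vec<S: Shape, T>(&self, t: &Tensor<S, E, Self, T>) -> Vec<E>;
    /// Allocate a tensor on this device from host-side data.
    fn storage_from_vec<S: Shape>(&self, shape: S, data: Vec<E>) -> Tensor<S, E, Self>;
}
```

The blanket `ToDevice` impl would then bound `D1` and `D2` on something like `HasVecStorage<E>` rather than bare `DeviceStorage`.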
src/nn/module.rs (outdated)
````rust
pub trait ToDevice<D> {
    type Output;
    fn to_device(&self, device: &D) -> Self::Output;
}

/// A type alias that yields the type of a module `M` as it would exist on device `D`. This can be
/// very useful when creating sequential networks that need to be parameterized by a device.
///
/// Examples:
/// ```rust
/// # use dfdx::nn::*;
/// type MLP<D> = OnDevice<(Linear<5, 10>, ReLU, Linear<10, 1>), D>;
/// ```
///
/// ```rust
/// # use dfdx::prelude::*;
/// #
/// // All modules exist on the cpu by default
/// type CpuMLP = (Linear<5, 10>, ReLU, Linear<10, 1>);
/// type MLP<D> = OnDevice<CpuMLP, D>;
/// # #[cfg(feature = "cuda")]
/// type CudaMLP = OnDevice<CpuMLP, Cuda>;
/// ```
pub type OnDevice<M, D> = <M as ToDevice<D>>::Output;

/// Equivalent to `OnDevice<M, Cuda>`
#[cfg(feature = "cuda")]
pub type OnCuda<M> = OnDevice<M, crate::prelude::Cuda>;

/// Equivalent to `OnDevice<M, Cpu>`
pub type OnCpu<M> = OnDevice<M, Cpu>;
````
We can move all this into tensor/tensor_impls.rs, since that is lower level than nn.
Okay, here's some observations from playing around.

**Step 1: existing `build_module` works as expected**

```rust
use dfdx::prelude::*;

type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);

fn main() {
    let dev: Cpu = Default::default();
    let m: MLP = dev.build_module();
}
```

**Step 2: changing the device type fails, as we might expect**

However, this is an inconsistency with the API - if this fails, then so should using Cpu.

```
error[E0277]: the trait bound `(dfdx::nn::Linear<5, 10>, dfdx::nn::ReLU, dfdx::nn::Linear<10, 5>): ResetParams<dfdx::tensor::Cuda, f32>` is not satisfied
 --> examples/tmp.rs:5:18
  |
5 |     let m: MLP = dev.build_module();
  |                  ^^^ ------------ required by a bound introduced by this call
  |                  |
  |                  the trait `ResetParams<dfdx::tensor::Cuda, f32>` is not implemented for `(dfdx::nn::Linear<5, 10>, dfdx::nn::ReLU, dfdx::nn::Linear<10, 5>)`
```

**Using `OnCuda` & `OnCpu` works**

```rust
use dfdx::prelude::*;

type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);

fn main() {
    let dev: Cuda = Default::default();
    let m: OnDevice<MLP, Cuda> = dev.build_module();
}
```

**But if you specify `OnDevice` & a function generic, it fails?**

```rust
use dfdx::prelude::*;

type MLP = (Linear<5, 10>, ReLU, Linear<10, 5>);

fn main() {
    let dev: Cuda = Default::default();
    let m: OnDevice<MLP, Cuda> = dev.build_module::<MLP>();
}
```

This fails with:

```
error[E0308]: mismatched types
 --> examples/tmp.rs:5:34
  |
5 |     let m: OnDevice<MLP, Cuda> = dev.build_module::<MLP>();
  |            -------------------   ^^^^^^^^^^^^^^^^^^^^^^^^^ expected struct `dfdx::tensor::Cuda`, found struct `dfdx::tensor::Cpu`
  |            |
  |            expected due to this
  |
  = note: expected tuple `(dfdx::nn::Linear<dfdx::tensor::Cuda, _, _>, dfdx::nn::ReLU, dfdx::nn::Linear<dfdx::tensor::Cuda, _, _>)`
             found tuple `(dfdx::nn::Linear<dfdx::tensor::Cpu, _, _>, dfdx::nn::ReLU, dfdx::nn::Linear<dfdx::tensor::Cpu, _, _>)`
```
That is expected, because `build_module::<MLP>()` explicitly requests `MLP`, which is the Cpu version of the model, so it can't match the annotated `OnDevice<MLP, Cuda>` type.
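(If so, the turbofish presumably needs the device-mapped type rather than the Cpu-default alias; the following is an untested editorial guess at the form that would type-check, not code from the PR:)

```rust
let m: OnDevice<MLP, Cuda> = dev.build_module::<OnDevice<MLP, Cuda>>();
```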
I've tried modifying `ModuleBuilder` like so:

```rust
pub trait IsSameType<T> {}
impl<T> IsSameType<T> for T {}

/// Extension trait for [Device] that can build anything that implements [ResetParams].
pub trait ModuleBuilder<E: Dtype>: Device<E> {
    fn build_module<M1, M2>(&self) -> M2
    where
        M1: ToDevice<Self> + IsSameType<OnCpu<M2>>,
        M2: ResetParams<Self, E> + ToDevice<Cpu> + IsSameType<OnDevice<M1, Self>>,
    {
        ResetParams::build(self)
    }

    fn try_build<M1>(&self) -> Result<M1, Self::Err>
    where
        M1: ResetParams<Self, E>,
    {
        ResetParams::try_build(self)
    }
}
```

but the compiler doesn't seem to be smart enough to figure out the type of `M1`, and fails with:

```
error[E0284]: type annotations needed
   --> src/nn/linear.rs:165:41
    |
165 |         let m: Linear<2000, 1, _> = dev.build_module();
    |                                         ^^^^^^^^^^^^ cannot infer type of the type parameter `M1` declared on the associated function `build_module`
    |
    = note: cannot satisfy `<_ as tensor_impls::ToDevice<cuda::device::Cuda>>::Output == linear::Linear<2000, 1, cuda::device::Cuda>`
help: consider specifying the generic arguments
    |
165 |         let m: Linear<2000, 1, _> = dev.build_module::<M1, linear::Linear<2000, 1, cuda::device::Cuda>>();
    |                                                       +++++++++++++++++++++++++++++++++++++++++++++++++++
```
Yeah, after a lot of playing around, it's because the output type is an associated type, so it would have to infer the original type. I don't know why it doesn't do that, since it seems possible, but 🤷. I'm thinking of moving away from allowing `dev.build_module()` and toward something like `Model::build_on(&dev)`, because rust won't infer the architecture type for us here:

```rust
fn main() {
    type Dev = Cpu;
    type Model = (Linear<3, 5>, ReLU, Linear<5, 3>);
    type DeviceModel<D> = (Linear<3, 5, D>, ReLU, Linear<5, 3, D>);

    let dev: Dev = Default::default();
    let q = Model::build_on(&dev);
    let q: DeviceModel<Dev> = Model::build_on(&dev);
    let q: DeviceModel<_> = Model::build_on(&dev);
    let q: (Linear<3, 5, _>, ReLU, Linear<5, 3, _>) = Model::build_on(&dev);
}
```

The nice thing about this is that if someone does annotate the wrong device, the error message is clear:

```
error[E0308]: mismatched types
   --> examples\building.rs:120:65
    |
120 |     let q: (Linear<3, 5>, ReLU, Linear<5, 3>) = Model::build_on(&dev);
    |                                                 --------------- ^^^^ expected struct `Cpu`, found struct `Cuda`
    |                                                 |
    |                                                 arguments to this function are incorrect
    |
    = note: expected reference `&Cpu`
              found reference `&Cuda`
note: associated function defined here
   --> examples\building.rs:62:8
    |
62  |     fn build_on(dev: &D) -> Self::Built;
    |        ^^^^^^^^
```
We pretty much already have that in the ResetParams trait, so all we would need to do is phase out ModuleBuilder and update the examples/documentation.
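Continuing the hypothetical `BuildOn` sketch above, the glue might look something like this (an editorial sketch: bounds are assumed, and the dtype is fixed to `f32` for brevity):

```rust
// Hypothetical blanket impl: `build_on` just forwards to `ResetParams`,
// which already knows how to allocate and initialize a module on `dev`,
// so `ModuleBuilder` could be phased out without any new machinery.
impl<D: Device<f32>, M: ToDevice<D>> BuildOn<D> for M
where
    OnDevice<M, D>: ResetParams<D, f32>,
{
    type Built = OnDevice<M, D>;

    fn build_on(dev: &D) -> Self::Built {
        ResetParams::build(dev)
    }
}
```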
src/tensor/tensor_impls.rs (outdated)

```diff
@@ -192,6 +192,83 @@ impl<S: Shape, E: Unit, D: SampleTensor<E>, T> Tensor<S, E, D, T> {
     }
 }

+/// A trait which allows a [Module] to be copied to another [Device] and to be used with the
```
Suggested change, to get rid of doc warnings (the bracketed `[Module]`/`[Device]` intra-doc links don't resolve here):

```diff
-/// A trait which allows a [Module] to be copied to another [Device] and to be used with the
+/// Something that can be copied to another `Device`. To be used with the
```
nice!
Partial implementation of a solution for #388, as discussed here.