Skip to content

Commit

Permalink
Clean up the types a little
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo committed Nov 24, 2023
1 parent 8a556f1 commit 9f2ee8e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tardyai/src/models/resnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,16 @@ pub type Resnet34Body = (
);

pub type Resnet34<const NUM_CLASSES: usize> = (Resnet34Body, Head<NUM_CLASSES>);
type Resnet34Built<const NUM_CLASSES: usize, E> =
<Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built;

pub struct Resnet34Model<const NUM_CLASSES: usize, E>
where
E: Dtype + SafeDtype,
Resnet34<NUM_CLASSES>: BuildOnDevice<AutoDevice, E>,
AutoDevice: Device<E>,
{
pub model: <Resnet34<NUM_CLASSES> as BuildOnDevice<AutoDevice, E>>::Built,
pub model: Resnet34Built<NUM_CLASSES, E>,
}

impl<E, const N: usize> Resnet34Model<N, E>
Expand All @@ -157,13 +159,12 @@ where
let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
let tensors = SafeTensors::deserialize(&buffer).unwrap();

let _ = <<Resnet34<N> as BuildOnDevice<AutoDevice, E>>::Built as TensorCollection<
E,
AutoDevice,
>>::iter_tensors(&mut RecursiveWalker {
m: &mut self.model,
f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
})?;
let _ = <Resnet34Built<N, E> as TensorCollection<E, AutoDevice>>::iter_tensors(
&mut RecursiveWalker {
m: &mut self.model,
f: &mut NamedTensorVisitor::new(&RESNET34_LAYERS, &tensors),
},
)?;

Ok(())
}
Expand Down

0 comments on commit 9f2ee8e

Please sign in to comment.