-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding median proc-block adding more hints #65
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's awesome to see you chugging away with your own proc blocks!
I think we should try to add some proper error handling to the kernel()
function so we don't blow up the moment we leave the happy path.
I think I also found a bug in the implementation. Can you pull the business logic out into some sort of fn median(values: &ArrayView1<f64>) -> Option<f64>
function and add some unit tests?
median/src/lib.rs
Outdated
fn kernel(id: String) -> Result<(), KernelError> { | ||
let ctx = KernelContext::for_node(&id).unwrap(); | ||
|
||
let TensorResult { | ||
element_type, | ||
buffer, | ||
dimensions, | ||
} = ctx.get_input_tensor("samples").unwrap(); | ||
|
||
let samples: ArrayView1<f64> = match element_type { | ||
ElementType::F64 => buffer | ||
.view(&dimensions) | ||
.unwrap() | ||
.into_dimensionality() | ||
.unwrap(), | ||
_ => panic!("Handle invalid element type"), | ||
}; | ||
let mut median_slice: Vec<f64> = samples.to_slice().unwrap().to_vec(); | ||
median_slice.sort_by(|a, b| a.partial_cmp(b).unwrap()); | ||
|
||
let median = (median_slice.len() as f32 / 2.0) as i32 as usize; | ||
|
||
ctx.set_output_tensor( | ||
"median", | ||
TensorParam { | ||
element_type: ElementType::F64, | ||
dimensions: &[1], | ||
buffer: [samples[median]].as_bytes(), | ||
}, | ||
); | ||
|
||
Ok(()) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a hell of a lot of unwrap()
s in here. Is there any chance you can translate them into errors so the proc-block fails gracefully instead of ending up fubar.
As we implement more proc-blocks I'll introduce some convenience methods or abstractions to make error handling cleaner, but in the meantime the modulo implementation should give you some inspiration.
median/src/lib.rs
Outdated
let samples: ArrayView1<f64> = match element_type { | ||
ElementType::F64 => buffer | ||
.view(&dimensions) | ||
.unwrap() | ||
.into_dimensionality() | ||
.unwrap(), | ||
_ => panic!("Handle invalid element type"), | ||
}; | ||
let mut median_slice: Vec<f64> = samples.to_slice().unwrap().to_vec(); | ||
median_slice.sort_by(|a, b| a.partial_cmp(b).unwrap()); | ||
|
||
let median = (median_slice.len() as f32 / 2.0) as i32 as usize; | ||
|
||
ctx.set_output_tensor( | ||
"median", | ||
TensorParam { | ||
element_type: ElementType::F64, | ||
dimensions: &[1], | ||
buffer: [samples[median]].as_bytes(), | ||
}, | ||
); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to do the sorting in place. That way we don't copy the entire input tensor (the to_vec()
) unnecessarily.
Maybe something like this?
let samples: ArrayView1<f64> = match element_type { | |
ElementType::F64 => buffer | |
.view(&dimensions) | |
.unwrap() | |
.into_dimensionality() | |
.unwrap(), | |
_ => panic!("Handle invalid element type"), | |
}; | |
let mut median_slice: Vec<f64> = samples.to_slice().unwrap().to_vec(); | |
median_slice.sort_by(|a, b| a.partial_cmp(b).unwrap()); | |
let median = (median_slice.len() as f32 / 2.0) as i32 as usize; | |
ctx.set_output_tensor( | |
"median", | |
TensorParam { | |
element_type: ElementType::F64, | |
dimensions: &[1], | |
buffer: [samples[median]].as_bytes(), | |
}, | |
); | |
let samples: ArrayViewMut1<f64> = match element_type { | |
ElementType::F64 => buffer | |
.view_mut(&dimensions) | |
.unwrap() | |
.into_dimensionality() | |
.unwrap(), | |
_ => panic!("Handle invalid element type"), | |
}; | |
samples.as_slice_mut().unwrap().sort_by(|a, b| a.partial_cmp(b).unwrap()); | |
let median = (median_slice.len() as f32 / 2.0) as i32 as usize; | |
ctx.set_output_tensor( | |
"median", | |
TensorParam { | |
element_type: ElementType::F64, | |
dimensions: &[1], | |
buffer: [samples[median]].as_bytes(), | |
}, | |
); |
Also, the original code stored the sorted values in median_slice
then proceeded to extract the median using samples[median]
... Which is probably not what you wanted 😅
No description provided.