Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions diskann-benchmark-runner/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,16 @@ impl App {
// List the available benchmarks.
Commands::Benchmarks {} => {
writeln!(output, "Registered Benchmarks:")?;
for (name, method) in benchmarks.methods() {
writeln!(output, " {}: {}", name, method.signatures()[0])?;
for (name, description) in benchmarks.names() {
let mut lines = description.lines();
if let Some(first) = lines.next() {
writeln!(output, " {}: {}", name, first)?;
for line in lines {
writeln!(output, " {}", line)?;
}
} else {
writeln!(output, " {}: <no description>", name)?;
}
}
}
Commands::Skeleton => {
Expand All @@ -130,23 +138,10 @@ impl App {
let run = Jobs::load(input_file, inputs)?;
// Check if we have a match for each benchmark.
for job in run.jobs().iter() {
if !benchmarks.has_match(job) {
const MAX_METHODS: usize = 3;
if let Err(mismatches) = benchmarks.debug(job, MAX_METHODS) {
let repr = serde_json::to_string_pretty(&job.serialize()?)?;

const MAX_METHODS: usize = 3;
let mismatches = match benchmarks.debug(job, MAX_METHODS) {
// Debug should return `Err` if there is not a match.
// Returning `Ok(())` here indicates an internal error with the
// dispatcher.
Ok(()) => {
return Err(anyhow::Error::msg(format!(
"experienced internal error while debugging:\n{}",
repr
)))
}
Err(m) => m,
};

writeln!(
output,
"Could not find a match for the following input:\n\n{}\n",
Expand All @@ -165,7 +160,7 @@ impl App {
writeln!(output)?;

return Err(anyhow::Error::msg(
"could not find find a benchmark for all inputs",
"could not find a benchmark for all inputs",
));
}
}
Expand Down
140 changes: 140 additions & 0 deletions diskann-benchmark-runner/src/benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

use serde::Serialize;

use crate::{
dispatcher::{FailureScore, MatchScore},
Any, Checkpoint, Input, Output,
};

/// A registered benchmark.
///
/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will
/// first be validated with the benchmark using [`try_match`](Self::try_match). Only
/// successful matches will be passed to [`run`](Self::run).
///
/// None of the methods take a `self` receiver: a benchmark is described entirely by its
/// type, so implementors do not need to carry any state.
pub trait Benchmark {
/// The [`Input`] type this benchmark matches against.
///
/// The type must be `'static`.
type Input: Input + 'static;

/// The concrete type of the results generated by this benchmark.
///
/// Results must be [`Serialize`]; the crate-internal type-erased wrapper converts them
/// into a `serde_json::Value` after a successful [`run`](Self::run).
type Output: Serialize;

/// Return whether or not this benchmark is compatible with `input`.
///
/// On success, returns `Ok(MatchScore)`. [`MatchScore`]s of all benchmarks will be
/// collected and the benchmark with the lowest final score will be selected.
///
/// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure.
///
/// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`]
/// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
/// are encouraged to generate ranked [`FailureScore`]s to assist in user-level debugging.
fn try_match(input: &Self::Input) -> Result<MatchScore, FailureScore>;

/// Return descriptive information about the benchmark.
///
/// The text is written to the formatter `f`.
///
/// If `input` is `None`, then high level information about the benchmark should be relayed.
/// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what
/// was expected should be generated to help users.
///
/// Implementations should propagate any error reported while writing to `f`.
fn description(
f: &mut std::fmt::Formatter<'_>,
input: Option<&Self::Input>,
) -> std::fmt::Result;

/// Run the benchmark with `input`.
///
/// All prints should be directed to `output`. The `checkpoint` is provided so
/// long-running benchmarks can periodically save output to prevent data loss due to
/// an early error.
///
/// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`.
///
/// On success, returns the benchmark results as [`Self::Output`]; failures are reported
/// through the returned `anyhow::Result`.
fn run(
input: &Self::Input,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<Self::Output>;
}

//////////////
// Internal //
//////////////

/// Object-safe trait for type-erased benchmarks stored in the registry.
///
/// This mirrors the statically-typed [`Benchmark`] trait, but every method takes `&self`
/// and receives its input as a type-erased [`Any`] instead of a concrete input type.
pub(crate) trait DynBenchmark {
/// Score how well this benchmark matches the type-erased `input`.
///
/// Returns `Ok(MatchScore)` for a compatible input and `Err(FailureScore)` otherwise.
fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;

/// Write a description of this benchmark to `f`.
///
/// `input` is `None` when a general description is requested and `Some` when the
/// description should be produced with respect to a particular input.
fn description(&self, f: &mut std::fmt::Formatter<'_>, input: Option<&Any>)
-> std::fmt::Result;

/// Run the benchmark on the type-erased `input`.
///
/// `checkpoint` and `output` are handed to the benchmark being run. On success, the
/// results are returned as a `serde_json::Value`.
fn run(
&self,
input: &Any,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<serde_json::Value>;
}

/// Zero-sized adapter that carries a benchmark type `T` without storing a value of it.
///
/// The wrapper only records `T` at the type level (via `PhantomData`), so it can be
/// created for any `T` with [`Wrapper::new`].
///
/// `Debug`, `Clone` and `Copy` are implemented by hand rather than derived: the derive
/// macros would add `T: Debug`, `T: Clone` and `T: Copy` bounds even though no `T` is
/// ever stored, needlessly restricting which benchmark types can be wrapped.
pub(crate) struct Wrapper<T>(std::marker::PhantomData<T>);

impl<T> Wrapper<T> {
    /// Create a new wrapper for the benchmark type `T`.
    pub(crate) fn new() -> Self {
        Self(std::marker::PhantomData)
    }
}

impl<T> std::fmt::Debug for Wrapper<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived implementation; `PhantomData<T>` is `Debug` for all `T`.
        f.debug_tuple("Wrapper").field(&self.0).finish()
    }
}

impl<T> Clone for Wrapper<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Wrapper<T> {}

/// The score given to unsuccessful downcasts in [`DynBenchmark::try_match`].
///
/// It is returned when the type-erased input cannot be downcast to the benchmark's
/// input type, so the benchmark's own `try_match` is never consulted.
const MATCH_FAIL: FailureScore = FailureScore(10_000);

impl<T> DynBenchmark for Wrapper<T>
where
T: Benchmark,
{
fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
if let Some(cast) = input.downcast_ref::<T::Input>() {
T::try_match(cast)
} else {
Err(MATCH_FAIL)
}
}

fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&Any>,
) -> std::fmt::Result {
match input {
Some(input) => match input.downcast_ref::<T::Input>() {
Some(cast) => T::description(f, Some(cast)),
None => write!(
f,
"expected tag \"{}\" - instead got \"{}\"",
T::Input::tag(),
input.tag(),
),
},
None => {
writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
T::description(f, None)
}
}
}

fn run(
&self,
input: &Any,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<serde_json::Value> {
match input.downcast_ref::<T::Input>() {
Some(input) => {
let result = T::run(input, checkpoint, output)?;
Ok(serde_json::to_value(result)?)
}
None => Err(anyhow::anyhow!("INTERNAL ERROR: invalid downcast!")),
}
}
}
153 changes: 2 additions & 151 deletions diskann-benchmark-runner/src/dispatcher/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::fmt::{self, Display, Formatter};
/// Successful matches from [`DispatchRule`] will return `MatchScores`.
///
/// A lower numerical value indicates a better match for purposes of overload resolution.
///
/// Scores are totally ordered by their numerical value, so they can be compared, sorted
/// and passed to `min`/`max` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MatchScore(pub u32);

impl Display for MatchScore {
Expand All @@ -21,7 +21,7 @@ impl Display for MatchScore {
///
/// A lower numerical value indicates a better match, which can help when compiling a
/// list of considered and rejected candidates.
///
/// Scores are totally ordered by their numerical value, so rejected candidates can be
/// sorted directly by their score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FailureScore(pub u32);

impl Display for FailureScore {
Expand Down Expand Up @@ -218,155 +218,6 @@ impl<'a, T: Sized> DispatchRule<&'a mut T> for &'a T {
}
}

/// # Lifetime Mapping
///
/// The types in signatures for dispatches need to be `'static` due to Rust.
/// However, it is nice to allow objects with lifetimes to cross the dispatcher boundary.
///
/// The `Map` trait facilitates this by allowing `'static` types to have an optional
/// lifetime attached as a generic associated type.
///
/// This associated type is what is actually given to dispatcher methods.
///
/// ## Example
///
/// To pass a `Vec` across a dispatcher boundary, we can use the [`Type`] helper:
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Type};
///
/// let mut d = Dispatcher1::<&'static str, Type<Vec<f32>>>::new();
/// d.register::<_, Type<Vec<f32>>>("method", |_: Vec<f32>| "called");
/// assert_eq!(d.call(vec![1.0]), Some("called"));
/// ```
///
/// This is a bit tedious to write every time, so instead types can implement [`Map`] for
/// themselves:
///
/// ```
/// use diskann_benchmark_runner::{self_map, dispatcher::{Dispatcher1}};
///
/// struct MyNum(f32);
/// self_map!(MyNum);
///
/// // Now, `MyNum` can be used directly in dispatcher signatures.
/// let mut d = Dispatcher1::<f32, MyNum>::new();
/// d.register::<_, MyNum>("method", |n: MyNum| n.0);
/// assert_eq!(d.call(MyNum(0.0)), Some(0.0));
/// ```
///
/// ## See Also:
///
/// * [`Ref`]: Mapping References
/// * [`MutRef`]: Mapping Mutable References
/// * [`Type`]: Mapper for generic types
/// * [`crate::self_map!`]: Allow types to represent themselves in dispatcher signatures.
///
pub trait Map: 'static {
/// The actual type provided to the dispatcher, with an optional additional lifetime.
///
/// The lifetime parameter `'a` lets the mapped type borrow data (for example `&'a T`)
/// even though the implementing type itself is `'static`.
type Type<'a>;
}

/// Allow references to cross dispatcher boundaries as shown in the following example:
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Ref};
///
/// let mut d = Dispatcher1::<*const f32, Ref<[f32]>>::new();
/// d.register::<_, Ref<[f32]>>("method", |data: &[f32]| data.as_ptr());
///
/// let v = vec![1.0, 2.0];
/// assert_eq!(d.call(&v), Some(v.as_ptr()));
/// ```
///
/// `Ref<T>` is a marker used only in signatures: it holds no `T`, just a `PhantomData`.
pub struct Ref<T: ?Sized + 'static>(std::marker::PhantomData<T>);

// A `Ref<T>` slot in a signature maps to a shared reference `&'a T`.
impl<T: ?Sized> Map for Ref<T> {
type Type<'a> = &'a T;
}

/// Allow mutable references to cross dispatcher boundaries as shown below.
///
/// ```
/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, MutRef};
///
/// let mut d = Dispatcher1::<(), MutRef<Vec<f32>>>::new();
/// d.register::<_, MutRef<Vec<f32>>>("method", |v: &mut Vec<f32>| v.push(0.0));
///
/// let mut v = Vec::new();
/// d.call(&mut v).unwrap();
/// assert_eq!(&v, &[0.0]);
/// ```
///
/// `MutRef<T>` is a marker used only in signatures: it holds no `T`, just a `PhantomData`.
pub struct MutRef<T: ?Sized + 'static>(std::marker::PhantomData<T>);
// A `MutRef<T>` slot in a signature maps to an exclusive reference `&'a mut T`.
impl<T: ?Sized> Map for MutRef<T> {
type Type<'a> = &'a mut T;
}

/// Mapper for generic owned types: a `Type<T>` slot in a signature maps to `T` itself.
///
/// `Type<T>` is a marker used only in signatures: it holds no `T`, just a `PhantomData`.
pub struct Type<T: 'static>(std::marker::PhantomData<T>);
// The mapped type ignores the lifetime `'a`: the value is passed as `T`.
impl<T> Map for Type<T> {
type Type<'a> = T;
}

/// Implement [`Map`](crate::dispatcher::Map) for a type so that it maps to itself.
///
/// The generated implementation sets `Type<'a>` to the type passed to the macro, which
/// allows that type to be used directly in dispatcher signatures.
#[macro_export]
macro_rules! self_map {
($($type:tt)*) => {
impl $crate::dispatcher::Map for $($type)* {
type Type<'a> = $($type)*;
}
}
}

// Built-in types that map to themselves: `bool`, the unsigned and signed integer types,
// `String`, and the floating point types.
self_map!(bool);
self_map!(usize);
self_map!(u8);
self_map!(u16);
self_map!(u32);
self_map!(u64);
self_map!(u128);
self_map!(i8);
self_map!(i16);
self_map!(i32);
self_map!(i64);
self_map!(i128);
self_map!(String);
self_map!(f32);
self_map!(f64);

/// Reasons for a method call mismatch.
///
/// The name of the associated method can be queried using [`Self::method`] and the
/// per-argument reasons are obtained from [`Self::mismatches`].
///
/// `N` is the number of arguments; `'a` is the lifetime of the borrowed method name and
/// of the data captured by the mismatch reasons.
pub struct ArgumentMismatch<'a, const N: usize> {
    pub(crate) method: &'a str,
    pub(crate) mismatches: [Option<Box<dyn std::fmt::Display + 'a>>; N],
}

impl<'a, const N: usize> ArgumentMismatch<'a, N> {
    /// Return the name of the associated method.
    ///
    /// The returned string carries the full lifetime `'a` of the stored name, so it
    /// remains usable after this `ArgumentMismatch` has been dropped.
    pub fn method(&self) -> &'a str {
        self.method
    }

    /// Return the reasons for method match failure.
    ///
    /// The returned array contains one entry per argument. An entry is `None` if that
    /// argument matched the input value.
    ///
    /// If the argument did not match the input value, then the corresponding
    /// [`std::fmt::Display`] object can be used to retrieve the reason.
    pub fn mismatches(&self) -> &[Option<Box<dyn std::fmt::Display + 'a>>; N] {
        &self.mismatches
    }
}

/// Displayable signature of an argument type.
///
/// Wraps a plain function pointer that writes the signature text into a formatter;
/// formatting a `Signature` with `Display` simply invokes that function.
pub struct Signature(pub(crate) fn(&mut Formatter<'_>) -> std::fmt::Result);

impl std::fmt::Display for Signature {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Function pointers are `Copy`, so pull the writer out before calling it.
        let render = self.0;
        render(f)
    }
}

///////////
// Tests //
///////////
Expand Down
Loading
Loading