Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions optd-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@

use clap::{Parser, Subcommand};
use colored::Colorize;
use optd::catalog::Catalog;
use optd::catalog::iceberg::memory_catalog;
use optd::dsl::analyzer::hir::{CoreData, HIR, Udf, Value};
use optd::dsl::compile::{Config, compile_hir};
use optd::dsl::engine::{Continuation, Engine, EngineResponse};
use optd::dsl::utils::errors::{CompileError, Diagnose};
use optd::dsl::utils::retriever::{MockRetriever, Retriever};
use optd::dsl::utils::retriever::MockRetriever;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::runtime::Builder;
Expand All @@ -66,22 +65,17 @@ enum Commands {
RunFunctions(Config),
}

/// An unimplemented user-defined function.
///
/// All three inputs are ignored. The function prints a notice to stdout and
/// returns a [`Value`] wrapping `CoreData::None`.
pub fn unimplemented_udf(
    _args: &[Value],
    _catalog: &dyn Catalog,
    _retriever: &dyn Retriever,
) -> Value {
    // Build the placeholder result first; it has no side effects.
    let placeholder = CoreData::<Value>::None;
    println!("This user-defined function is unimplemented!");
    Value::new(placeholder)
}

fn main() -> Result<(), Vec<CompileError>> {
let cli = Cli::parse();

let mut udfs = HashMap::new();
let udf = Udf {
func: unimplemented_udf,
func: Arc::new(|_, _, _| {
Box::pin(async move {
println!("This user-defined function is unimplemented!");
Value::new(CoreData::<Value>::None)
})
}),
};
udfs.insert("unimplemented_udf".to_string(), udf.clone());

Expand Down
42 changes: 36 additions & 6 deletions optd/src/demo/demo.opt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
data Physical
data PhysicalProperties
data Statistics
data LogicalProperties
// Using the folded (evaluated) value as the logical property is not very interesting,
// but it guarantees that every expression in the same group has the same properties.
data LogicalProperties(folded: I64)

data Logical =
| Add(left: Logical, right: Logical)
Expand All @@ -10,21 +11,42 @@ data Logical =
| Div(left: Logical, right: Logical)
\ Const(val: I64)

data Physical =
| PhysicalAdd(left: Physical, right: Physical)
| PhysicalSub(left: Physical, right: Physical)
| PhysicalMult(left: Physical, right: Physical)
| PhysicalDiv(left: Physical, right: Physical)
\ PhysicalConst(val: I64)

// This will be the input plan that will be optimized.
// Result is: ((1 - 2) * (3 / 4)) + ((5 - 6) * (7 / 8)) = 0
fn input(): Logical =
Add(
Mult(
Sub(Const(1), Const(2)),
Div(Const(3), Const(4))
),
Mult(
Sub(Const(5), Const(6)),
Sub(Const(1), Const(2)),
Div(Const(7), Const(8))
)
)

// TODO(Alexis): This should be $ really, make costing and derive consistent with each other.
// External function to allow the retrieval of properties.
fn properties(op: Logical*): LogicalProperties

// FIXME: This should be $ really (or other), make costing and derive consistent with each other.
// Also, be careful not to fork in there! And make it a required function in the analyzer.
fn derive(op: Logical) = LogicalProperties
fn derive(op: Logical*) = match op
| Add(left, right) ->
LogicalProperties(left.properties()#folded + right.properties()#folded)
| Sub(left, right) ->
LogicalProperties(left.properties()#folded - right.properties()#folded)
| Mult(left, right) ->
LogicalProperties(left.properties()#folded * right.properties()#folded)
| Div(left, right) ->
LogicalProperties(left.properties()#folded / right.properties()#folded)
\ Const(val) -> LogicalProperties(val)

[transformation]
fn (op: Logical*) mult_commute(): Logical? = match op
Expand Down Expand Up @@ -55,4 +77,12 @@ fn (op: Logical*) const_fold_sub(): Logical? = match op
fn (op: Logical*) const_fold_div(): Logical? = match op
| Div(Const(a), Const(b)) ->
if b == 0 then none else Const(a / b)
\ _ -> none
\ _ -> none

[implementation]
fn (op: Logical*) to_physical(props: PhysicalProperties?) = match op
| Add(left, right) -> PhysicalAdd(left.to_physical(props), right.to_physical(props))
| Sub(left, right) -> PhysicalSub(left.to_physical(props), right.to_physical(props))
| Mult(left, right) -> PhysicalMult(left.to_physical(props), right.to_physical(props))
| Div(left, right) -> PhysicalDiv(left.to_physical(props), right.to_physical(props))
\ Const(val) -> PhysicalConst(val)
63 changes: 51 additions & 12 deletions optd/src/demo/mod.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,53 @@
use crate::{
catalog::iceberg::memory_catalog,
catalog::{Catalog, iceberg::memory_catalog},
dsl::{
analyzer::hir::Value,
analyzer::hir::{CoreData, LogicalOp, Materializable, Udf, Value},
compile::{Config, compile_hir},
engine::{Continuation, Engine, EngineResponse},
utils::retriever::MockRetriever,
utils::retriever::{MockRetriever, Retriever},
},
memo::MemoryMemo,
optimizer::{OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical},
optimizer::{ClientRequest, OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical},
};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{sync::mpsc, time::timeout};
use tokio::{
sync::mpsc,
time::{sleep, timeout},
};

/// Retrieves the properties of the group that a logical operator belongs to.
///
/// The first element of `args` must be a logical operator. Its group id is
/// extracted and handed to `retriever.get_properties`, whose result is returned
/// unchanged. The catalog is not used.
///
/// # Panics
///
/// Panics if `args` is empty, if the first argument is not a logical operator,
/// or if it is a materialized operator that carries no group id.
pub async fn properties(
    args: Vec<Value>,
    _catalog: Arc<dyn Catalog>,
    retriever: Arc<dyn Retriever>,
) -> Value {
    // Only the (copyable) group id is needed, so borrow the argument rather than
    // cloning the whole value. The borrow is confined to this block so that it
    // does not live across the `.await` below.
    let group_id = {
        let arg = args
            .first()
            .expect("`properties` expects a logical operator as its first argument");
        match &arg.data {
            CoreData::Logical(Materializable::Materialized(LogicalOp { group_id, .. })) => {
                group_id.expect("materialized logical operator has no group id")
            }
            CoreData::Logical(Materializable::UnMaterialized(group_id)) => *group_id,
            _ => panic!("Expected a logical plan"),
        }
    };

    retriever.get_properties(group_id).await
}

async fn run_demo() {
// Compile the HIR.
let config = Config::new("src/demo/demo.opt".into());
let udfs = HashMap::new();

// Create a properties UDF.
let properties_udf = Udf {
func: Arc::new(|args, catalog, retriever| {
Box::pin(async move { properties(args, catalog, retriever).await })
}),
};

// Create the UDFs HashMap.
let mut udfs = HashMap::new();
udfs.insert("properties".to_string(), properties_udf);

// Compile with the config and UDFs.
let hir = compile_hir(config, udfs).unwrap();

// Create necessary components.
Expand All @@ -35,15 +67,15 @@ async fn run_demo() {
let optimize_channel = Optimizer::launch(memo, catalog, hir);
let (tx, mut rx) = mpsc::channel(1);
optimize_channel
.send(OptimizeRequest {
plan: logical_plan,
physical_tx: tx,
})
.send(ClientRequest::Optimize(OptimizeRequest {
plan: logical_plan.clone(),
physical_tx: tx.clone(),
}))
.await
.unwrap();

// Timeout after 2 seconds.
let timeout_duration = Duration::from_secs(2);
// Timeout after 5 seconds.
let timeout_duration = Duration::from_secs(5);
let result = timeout(timeout_duration, async {
while let Some(response) = rx.recv().await {
println!("Received response: {:?}", response);
Expand All @@ -55,6 +87,13 @@ async fn run_demo() {
Ok(_) => println!("Finished receiving responses."),
Err(_) => println!("Timed out after 5 seconds."),
}

// Dump the memo (debug utility).
optimize_channel
.send(ClientRequest::DumpMemo)
.await
.unwrap();
sleep(Duration::from_secs(10)).await;
}

#[cfg(test)]
Expand Down
26 changes: 7 additions & 19 deletions optd/src/dsl/analyzer/from_ast/converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,12 @@ impl ASTConverter {
#[cfg(test)]
mod converter_tests {
use super::*;
use crate::catalog::Catalog;
use crate::dsl::analyzer::from_ast::from_ast;
use crate::dsl::analyzer::hir::{CoreData, FunKind};
use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind};
use crate::dsl::parser::ast::{self, Adt, Function, Item, Module, Type as AstType};
use crate::dsl::utils::retriever::Retriever;
use crate::dsl::utils::span::{Span, Spanned};
use std::sync::Arc;

// Helper functions to create test items
fn create_test_span() -> Span {
Expand Down Expand Up @@ -382,19 +381,15 @@ mod converter_tests {
let ext_func = create_simple_function("external_function", false);
let module = create_module_with_functions(vec![ext_func]);

// Dummy UDF used only to check linking: it ignores its inputs, prints a
// greeting to stdout, and returns a `Value` wrapping `CoreData::None`.
pub fn external_function(
    _args: &[Value],
    _catalog: &dyn Catalog,
    _retriever: &dyn Retriever,
) -> Value {
    // Building the result has no side effects, so it can happen before the print.
    let nothing = CoreData::<Value>::None;
    println!("Hello from UDF!");
    Value::new(nothing)
}

// Link the dummy function.
let mut udfs = HashMap::new();
let udf = Udf {
func: external_function,
func: Arc::new(|_, _, _| {
Box::pin(async move {
println!("Hello from UDF!");
Value::new(CoreData::None)
})
}),
};
udfs.insert("external_function".to_string(), udf);

Expand All @@ -408,13 +403,6 @@ mod converter_tests {
// Check that the function is in the context.
let func_val = hir.context.lookup("external_function");
assert!(func_val.is_some());

// Verify it is the same function pointer.
if let CoreData::Function(FunKind::Udf(udf)) = &func_val.unwrap().data {
assert_eq!(udf.func as usize, external_function as usize);
} else {
panic!("Expected UDF function");
}
}

#[test]
Expand Down
29 changes: 19 additions & 10 deletions optd/src/dsl/analyzer/hir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ use crate::dsl::utils::retriever::Retriever;
use crate::dsl::utils::span::Span;
use context::Context;
use map::Map;
use std::fmt::Debug;
use std::fmt::{self, Debug};
use std::pin::Pin;
use std::{collections::HashMap, sync::Arc};

pub(crate) mod context;
Expand Down Expand Up @@ -72,22 +73,30 @@ impl TypedSpan {
}
}

#[derive(Debug, Clone)]
/// Type aliases for user-defined functions (UDFs).
///
/// `UdfFutureOutput` is the boxed, pinned, `Send` future returned by a UDF; it
/// resolves to a [`Value`].
type UdfFutureOutput = Pin<Box<dyn Future<Output = Value> + Send>>;
/// Signature of a UDF callable: it takes the argument values, a shared catalog,
/// and a shared retriever, and returns a [`UdfFutureOutput`]. The callable
/// itself is required to be `Send + Sync`.
type UdfFunction =
dyn Fn(Vec<Value>, Arc<dyn Catalog>, Arc<dyn Retriever>) -> UdfFutureOutput + Send + Sync;

#[derive(Clone)]
pub struct Udf {
/// The function pointer to the user-defined function.
///
/// Note that [`Value`]s passed to and returned from this UDF do not have associated metadata.
pub func: fn(&[Value], &dyn Catalog, &dyn Retriever) -> Value,
pub func: Arc<UdfFunction>,
}

impl fmt::Debug for Udf {
    /// Renders every UDF as the fixed placeholder `[udf]`; the wrapped
    /// callable is not inspected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[udf]")
    }
}

impl Udf {
pub fn call(
pub async fn call(
&self,
values: &[Value],
catalog: &dyn Catalog,
retriever: &dyn Retriever,
catalog: Arc<dyn Catalog>,
retriever: Arc<dyn Retriever>,
) -> Value {
(self.func)(values, catalog, retriever)
(self.func)(values.to_vec(), catalog, retriever).await
}
}

Expand Down
46 changes: 26 additions & 20 deletions optd/src/dsl/engine/eval/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ impl<O: Clone + Send + 'static> Engine<O> {
Arc::new(move |arg_values| {
Box::pin(capture!([udf, catalog, deriver, k], async move {
// Call the UDF with the argument values.
let result = udf.call(&arg_values, catalog.as_ref(), deriver.as_ref());
let result = udf.call(&arg_values, catalog, deriver).await;

// Pass the result to the continuation.
k(result).await
Expand Down Expand Up @@ -1143,18 +1143,22 @@ mod tests {

// Define a Rust UDF that calculates the sum of array elements
let sum_function = Value::new(CoreData::Function(FunKind::Udf(Udf {
func: |args, _catalog, _deriver| match &args[0].data {
CoreData::Array(elements) => {
let mut sum = 0;
for elem in elements {
if let CoreData::Literal(Literal::Int64(value)) = &elem.data {
sum += value;
func: Arc::new(|args, _catalog, _retriever| {
Box::pin(async move {
match &args[0].data {
CoreData::Array(elements) => {
let mut sum: i64 = 0;
for elem in elements {
if let CoreData::Literal(Literal::Int64(value)) = &elem.data {
sum += value;
}
}
Value::new(CoreData::Literal(Literal::Int64(sum)))
}
_ => panic!("Expected array argument"),
}
Value::new(CoreData::Literal(Literal::Int64(sum)))
}
_ => panic!("Expected array argument"),
},
})
}),
})));

ctx.bind("sum".to_string(), sum_function);
Expand Down Expand Up @@ -1271,16 +1275,18 @@ mod tests {
ctx.bind(
"get".to_string(),
Value::new(CoreData::Function(FunKind::Udf(Udf {
func: |args, _catalog, _deriver| {
if args.len() != 2 {
panic!("get function requires 2 arguments");
}
func: Arc::new(|args, _catalog, _retriever| {
Box::pin(async move {
if args.len() != 2 {
panic!("get function requires 2 arguments");
}

match &args[0].data {
CoreData::Map(map) => map.get(&args[1]),
_ => panic!("First argument must be a map"),
}
},
match &args[0].data {
CoreData::Map(map) => map.get(&args[1]),
_ => panic!("First argument must be a map"),
}
})
}),
}))),
);

Expand Down
Loading