Skip to content

Commit

Permalink
Enable more rust DB parameter tests and fix union tag in rust
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Jan 22, 2024
1 parent 69481a0 commit 8615350
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,6 @@ std::unique_ptr<LogicalType> LogicalTypeUtils::parseMapType(const std::string& t

std::unique_ptr<LogicalType> LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) {
auto unionFields = parseStructTypeInfo(trimmedStr);
auto unionTagField = StructField(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
unionFields.insert(unionFields.begin(), std::move(unionTagField));
return LogicalType::UNION(std::move(unionFields));
}

Expand Down Expand Up @@ -918,6 +915,9 @@ std::unique_ptr<LogicalType> LogicalType::RDF_VARIANT(std::unique_ptr<StructType
}

std::unique_ptr<LogicalType> LogicalType::UNION(std::vector<StructField>&& fields) {
// TODO(Ziy): Use UINT8 to represent tag value.
fields.insert(fields.begin(), StructField(UnionType::TAG_FIELD_NAME,
std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE)));
return std::unique_ptr<LogicalType>(
new LogicalType(LogicalTypeID::UNION, std::make_unique<StructTypeInfo>(std::move(fields))));
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/value/nested.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ uint32_t NestedVal::getChildrenSize(const Value* val) {

Value* NestedVal::getChildVal(const Value* val, uint32_t idx) {
if (idx > val->childrenSize) {
throw RuntimeException("NestedVal::getChildPointer index out of bound.");
throw RuntimeException("NestedVal::getChildVal index out of bound.");

Check warning on line 15 in src/common/types/value/nested.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/value/nested.cpp#L15

Added line #L15 was not covered by tests
}
return val->children[idx].get();
}
Expand Down
3 changes: 0 additions & 3 deletions src/function/vector_union_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ std::unique_ptr<FunctionBindData> UnionValueFunction::bindFunc(
const binder::expression_vector& arguments, kuzu::function::Function* /*function*/) {
KU_ASSERT(arguments.size() == 1);
std::vector<StructField> fields;
// TODO(Ziy): Use UINT8 to represent tag value.
fields.emplace_back(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
if (arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[0], LogicalType(LogicalTypeID::STRING));
Expand Down
3 changes: 3 additions & 0 deletions tools/rust_api/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ fn build_bundled_cmake() -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
.define("BUILD_SHELL", "OFF")
.define("BUILD_SINGLE_FILE_HEADER", "OFF")
.define("AUTO_UPDATE_GRAMMAR", "OFF");
if get_target() == "debug" {
build.define("ENABLE_RUNTIME_CHECKS", "ON");
}
if cfg!(windows) {
build.generator("Ninja");
build.cxxflag("/EHsc");
Expand Down
2 changes: 0 additions & 2 deletions tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
LogicalType::Union { types } => {
let mut builder = ffi::create_type_list();
let mut names = vec![];
names.push("tag".to_string());
builder.pin_mut().insert((&LogicalType::Int64).into());
for (name, typ) in types {
names.push(name.clone());
builder.pin_mut().insert(typ.into());
Expand Down
17 changes: 6 additions & 11 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ impl TryFrom<&ffi::Value> for Value {
} else {
unreachable!()
};
debug_assert!(ffi::value_get_children_size(value) == 1);
let value: Value = ffi::value_get_child(value, 0).try_into()?;
Ok(Value::Union {
types,
Expand Down Expand Up @@ -778,11 +779,7 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
let typ: LogicalType = LogicalType::Struct {
fields: value
.iter()
.map(|(name, value)| {
// Unwrap is safe since we already converted when inserting into the
// builder
(name.clone(), Into::<LogicalType>::into(value))
})
.map(|(name, value)| (name.clone(), Into::<LogicalType>::into(value)))
.collect(),
};

Expand Down Expand Up @@ -1102,14 +1099,14 @@ mod tests {

database_tests! {
// Passing these values as arguments is not yet implemented in kuzu:
// db_var_list_string: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[]",
// db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
// db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
// db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
// db_union: Value::Union {
// types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
// value: Box::new(Value::Int8(-127))
// }, "UNION(Num INT8, duration INTERVAL)",
db_var_list_string: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[]",
db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
db_struct:
Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
"STRUCT(item STRING, count INT32)",
Expand Down Expand Up @@ -1234,9 +1231,7 @@ mod tests {
Ok(())
}

// TODO: This should be added back after we fix create rel.
#[test]
#[ignore]
fn test_recursive_rel() -> Result<()> {
let temp_dir = tempfile::TempDir::new()?;
let db = Database::new(temp_dir.path(), SystemConfig::default())?;
Expand Down

0 comments on commit 8615350

Please sign in to comment.