diff --git a/Cargo.lock b/Cargo.lock index 3c784f24..0da7669e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1540,6 +1540,7 @@ dependencies = [ "mongodb", "mongodb-agent-common", "mongodb-support", + "proptest", "serde", "serde_json", "these", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index caafce44..e96b70cc 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -18,3 +18,6 @@ serde_json = { version = "1.0.113", features = ["raw_value"] } thiserror = "1.0.57" tokio = { version = "1.36.0", features = ["full"] } these = "2.0.0" + +[dev-dependencies] +proptest = "1" \ No newline at end of file diff --git a/crates/cli/proptest-regressions/introspection/type_unification.txt b/crates/cli/proptest-regressions/introspection/type_unification.txt new file mode 100644 index 00000000..77460802 --- /dev/null +++ b/crates/cli/proptest-regressions/introspection/type_unification.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc 45028da671f86113f58b8ec86468ec593b8e33488eecb154950098054ee15675 # shrinks to c = TypeUnificationContext { object_type_name: "", field_name: "" }, t = ArrayOf(Scalar(Null)) +cc e7368f0503761c52e2ce47fa2e64454ecd063f2e019c511759162d0be049e665 # shrinks to c = TypeUnificationContext { object_type_name: "", field_name: "" }, t = Nullable(Nullable(Scalar(Double))) +cc bd6f440b7ea7e51d8c369e802b8cbfbc0c3f140c01cd6b54d9c61e6d84d7e77d # shrinks to c = TypeUnificationContext { object_type_name: "", field_name: "" }, t = Nullable(Scalar(Null)) +cc d16279848ea51c4be376436423d342afd077a737efcab03ba2d29d5a0dee9df2 # shrinks to left = {"": Scalar(Double)}, right = {"": Scalar(Decimal)}, shared = {} +cc fc85c97eeccb12e144f548fe65fd262d4e7b1ec9c799be69fd30535aa032e26d # shrinks to ta = Nullable(Scalar(Null)), tb = Nullable(Scalar(Undefined)) diff --git a/crates/cli/src/introspection/sampling.rs b/crates/cli/src/introspection/sampling.rs index 8e86fb77..ca2e0e32 100644 --- a/crates/cli/src/introspection/sampling.rs +++ b/crates/cli/src/introspection/sampling.rs @@ -29,7 +29,7 @@ pub async fn sample_schema_from_db( let collection_name = collection_spec.name; let collection_schema = sample_schema_from_collection(&collection_name, sample_size, config).await?; - schema = unify_schema(schema, collection_schema)?; + schema = unify_schema(schema, collection_schema); } Ok(schema) } @@ -161,3 +161,102 @@ fn make_field_type( Bson::DbPointer(_) => scalar(DbPointer), } } + +#[cfg(test)] +mod tests { + use configuration::schema::{ObjectField, ObjectType, Type}; + use mongodb::bson::doc; + use mongodb_support::BsonScalarType; + + use crate::introspection::type_unification::{TypeUnificationContext, TypeUnificationError}; + + use super::make_object_type; + + #[test] + fn simple_doc() -> Result<(), anyhow::Error> { + let object_name = "foo"; + let doc = doc! 
{"my_int": 1, "my_string": "two"}; + let result = make_object_type(object_name, &doc); + + let expected = Ok(vec![ObjectType { + name: object_name.to_owned(), + fields: vec![ + ObjectField { + name: "my_int".to_owned(), + r#type: Type::Scalar(BsonScalarType::Int), + description: None, + }, + ObjectField { + name: "my_string".to_owned(), + r#type: Type::Scalar(BsonScalarType::String), + description: None, + }, + ], + description: None, + }]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn array_of_objects() -> Result<(), anyhow::Error> { + let object_name = "foo"; + let doc = doc! {"my_array": [{"foo": 42, "bar": ""}, {"bar": "wut", "baz": 3.77}]}; + let result = make_object_type(object_name, &doc); + + let expected = Ok(vec![ + ObjectType { + name: "foo_my_array".to_owned(), + fields: vec![ + ObjectField { + name: "foo".to_owned(), + r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Int))), + description: None, + }, + ObjectField { + name: "bar".to_owned(), + r#type: Type::Scalar(BsonScalarType::String), + description: None, + }, + ObjectField { + name: "baz".to_owned(), + r#type: Type::Nullable(Box::new(Type::Scalar(BsonScalarType::Double))), + description: None, + }, + ], + description: None, + }, + ObjectType { + name: object_name.to_owned(), + fields: vec![ObjectField { + name: "my_array".to_owned(), + r#type: Type::ArrayOf(Box::new(Type::Object("foo_my_array".to_owned()))), + description: None, + }], + description: None, + }, + ]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn non_unifiable_array_of_objects() -> Result<(), anyhow::Error> { + let object_name = "foo"; + let doc = doc! 
{"my_array": [{"foo": 42, "bar": ""}, {"bar": 17, "baz": 3.77}]}; + let result = make_object_type(object_name, &doc); + + let expected = Err(TypeUnificationError::ScalarType( + TypeUnificationContext::new("foo_my_array", "bar"), + BsonScalarType::String, + BsonScalarType::Int, + )); + assert_eq!(expected, result); + + Ok(()) + } +} diff --git a/crates/cli/src/introspection/type_unification.rs b/crates/cli/src/introspection/type_unification.rs index b435e54e..b3ac3179 100644 --- a/crates/cli/src/introspection/type_unification.rs +++ b/crates/cli/src/introspection/type_unification.rs @@ -17,7 +17,7 @@ use std::{ }; use thiserror::Error; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct TypeUnificationContext { object_type_name: String, field_name: String, @@ -42,7 +42,7 @@ impl Display for TypeUnificationContext { } } -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq, Eq)] pub enum TypeUnificationError { ScalarType(TypeUnificationContext, BsonScalarType, BsonScalarType), ObjectType(String, String), @@ -84,6 +84,17 @@ pub fn unify_type( (Type::Scalar(Undefined), type_b) => Ok(type_b), (type_a, Type::Scalar(Undefined)) => Ok(type_a), + // A Nullable type will unify with another type iff the underlying type is unifiable. + // The resulting type will be Nullable. + (Type::Nullable(nullable_type_a), type_b) => { + let result_type = unify_type(context, *nullable_type_a, type_b)?; + Ok(make_nullable(result_type)) + } + (type_a, Type::Nullable(nullable_type_b)) => { + let result_type = unify_type(context, type_a, *nullable_type_b)?; + Ok(make_nullable(result_type)) + } + // Union of any type with Null is the Nullable version of that type (Type::Scalar(Null), type_b) => Ok(make_nullable(type_b)), (type_a, Type::Scalar(Null)) => Ok(make_nullable(type_a)), @@ -114,25 +125,29 @@ pub fn unify_type( Ok(Type::ArrayOf(Box::new(elem_type))) } - // A Nullable type will unify with another type iff the underlying type is unifiable. 
- // The resulting type will be Nullable. - (Type::Nullable(nullable_type_a), type_b) => { - let result_type = unify_type(context, *nullable_type_a, type_b)?; - Ok(make_nullable(result_type)) - } - (type_a, Type::Nullable(nullable_type_b)) => { - let result_type = unify_type(context, type_a, *nullable_type_b)?; - Ok(make_nullable(result_type)) - } - // Anything else is a unification error. (type_a, type_b) => Err(TypeUnificationError::TypeKind(type_a, type_b)), } + .map(normalize_type) +} + +fn normalize_type(t: Type) -> Type { + match t { + Type::Scalar(s) => Type::Scalar(s), + Type::Object(o) => Type::Object(o), + Type::ArrayOf(a) => Type::ArrayOf(Box::new(normalize_type(*a))), + Type::Nullable(n) => match *n { + Type::Scalar(BsonScalarType::Null) => Type::Scalar(BsonScalarType::Null), + Type::Nullable(t) => normalize_type(Type::Nullable(t)), + t => Type::Nullable(Box::new(normalize_type(t))), + }, + } } fn make_nullable(t: Type) -> Type { match t { Type::Nullable(t) => Type::Nullable(t), + Type::Scalar(BsonScalarType::Null) => Type::Scalar(BsonScalarType::Null), t => Type::Nullable(Box::new(t)), } } @@ -177,7 +192,7 @@ fn unify_object_type( }) } -/// The types of two `ObjectField`s. +/// Unify the types of two `ObjectField`s. /// If the types are not unifiable then return an error. fn unify_object_field( object_type_name: &str, @@ -214,7 +229,7 @@ pub fn unify_object_types( } /// Unify two schemas. Assumes that the schemas describe mutually exclusive sets of collections. 
-pub fn unify_schema(schema_a: Schema, schema_b: Schema) -> TypeUnificationResult { +pub fn unify_schema(schema_a: Schema, schema_b: Schema) -> Schema { let collections = schema_a .collections .into_iter() @@ -225,8 +240,202 @@ pub fn unify_schema(schema_a: Schema, schema_b: Schema) -> TypeUnificationResult .into_iter() .chain(schema_b.object_types) .collect(); - Ok(Schema { + Schema { collections, object_types, - }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::{HashMap, HashSet}; + + use super::{ + normalize_type, unify_object_type, unify_type, ObjectField, ObjectType, + TypeUnificationContext, TypeUnificationError, + }; + use configuration::schema::Type; + use mongodb_support::BsonScalarType; + use proptest::{collection::hash_map, prelude::*}; + + #[test] + fn test_unify_scalar() -> Result<(), anyhow::Error> { + let context = TypeUnificationContext::new("foo", "bar"); + let expected = Ok(Type::Scalar(BsonScalarType::Int)); + let actual = unify_type( + context, + Type::Scalar(BsonScalarType::Int), + Type::Scalar(BsonScalarType::Int), + ); + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_unify_scalar_error() -> Result<(), anyhow::Error> { + let context = TypeUnificationContext::new("foo", "bar"); + let expected = Err(TypeUnificationError::ScalarType( + context.clone(), + BsonScalarType::Int, + BsonScalarType::String, + )); + let actual = unify_type( + context, + Type::Scalar(BsonScalarType::Int), + Type::Scalar(BsonScalarType::String), + ); + assert_eq!(expected, actual); + Ok(()) + } + + prop_compose! 
{
+        fn arb_type_unification_context()(object_type_name in any::<String>(), field_name in any::<String>()) -> TypeUnificationContext {
+            TypeUnificationContext { object_type_name, field_name }
+        }
+    }
+
+    fn arb_bson_scalar_type() -> impl Strategy<Value = BsonScalarType> {
+        prop_oneof![
+            Just(BsonScalarType::Double),
+            Just(BsonScalarType::Decimal),
+            Just(BsonScalarType::Int),
+            Just(BsonScalarType::Long),
+            Just(BsonScalarType::String),
+            Just(BsonScalarType::Date),
+            Just(BsonScalarType::Timestamp),
+            Just(BsonScalarType::BinData),
+            Just(BsonScalarType::ObjectId),
+            Just(BsonScalarType::Bool),
+            Just(BsonScalarType::Null),
+            Just(BsonScalarType::Regex),
+            Just(BsonScalarType::Javascript),
+            Just(BsonScalarType::JavascriptWithScope),
+            Just(BsonScalarType::MinKey),
+            Just(BsonScalarType::MaxKey),
+            Just(BsonScalarType::Undefined),
+            Just(BsonScalarType::DbPointer),
+            Just(BsonScalarType::Symbol),
+        ]
+    }
+
+    fn arb_type() -> impl Strategy<Value = Type> {
+        let leaf = prop_oneof![
+            arb_bson_scalar_type().prop_map(Type::Scalar),
+            any::<String>().prop_map(Type::Object)
+        ];
+        leaf.prop_recursive(3, 10, 10, |inner| {
+            prop_oneof![
+                inner.clone().prop_map(|t| Type::ArrayOf(Box::new(t))),
+                inner.prop_map(|t| Type::Nullable(Box::new(t)))
+            ]
+        })
+    }
+
+    fn swap_error(err: TypeUnificationError) -> TypeUnificationError {
+        match err {
+            TypeUnificationError::ScalarType(c, a, b) => TypeUnificationError::ScalarType(c, b, a),
+            TypeUnificationError::ObjectType(a, b) => TypeUnificationError::ObjectType(b, a),
+            TypeUnificationError::TypeKind(a, b) => TypeUnificationError::TypeKind(b, a),
+        }
+    }
+
+    fn is_nullable(t: &Type) -> bool {
+        matches!(t, Type::Scalar(BsonScalarType::Null) | Type::Nullable(_))
+    }
+
+    proptest! {
+        #[test]
+        fn test_type_unifies_with_itself_and_normalizes(t in arb_type()) {
+            let c = TypeUnificationContext::new("", "");
+            let u = unify_type(c, t.clone(), t.clone());
+            prop_assert_eq!(Ok(normalize_type(t)), u)
+        }
+    }
+
+    proptest!
{
+        #[test]
+        fn test_unify_type_is_commutative(ta in arb_type(), tb in arb_type()) {
+            let c = TypeUnificationContext::new("", "");
+            let result_a_b = unify_type(c.clone(), ta.clone(), tb.clone());
+            let result_b_a = unify_type(c, tb, ta);
+            prop_assert_eq!(result_a_b, result_b_a.map_err(swap_error))
+        }
+    }
+
+    proptest! {
+        #[test]
+        fn test_unify_type_is_associative(ta in arb_type(), tb in arb_type(), tc in arb_type()) {
+            let c = TypeUnificationContext::new("", "");
+            let result_lr = unify_type(c.clone(), ta.clone(), tb.clone()).and_then(|tab| unify_type(c.clone(), tab, tc.clone()));
+            let result_rl = unify_type(c.clone(), tb, tc).and_then(|tbc| unify_type(c, ta, tbc));
+            if let Ok(tlr) = result_lr {
+                prop_assert_eq!(Ok(tlr), result_rl)
+            } else if result_rl.is_ok() {
+                panic!("Err, Ok")
+            }
+        }
+    }
+
+    proptest! {
+        #[test]
+        fn test_undefined_is_left_identity(t in arb_type()) {
+            let c = TypeUnificationContext::new("", "");
+            let u = unify_type(c, Type::Scalar(BsonScalarType::Undefined), t.clone());
+            prop_assert_eq!(Ok(normalize_type(t)), u)
+        }
+    }
+
+    proptest! {
+        #[test]
+        fn test_undefined_is_right_identity(t in arb_type()) {
+            let c = TypeUnificationContext::new("", "");
+            let u = unify_type(c, t.clone(), Type::Scalar(BsonScalarType::Undefined));
+            prop_assert_eq!(Ok(normalize_type(t)), u)
+        }
+    }
+
+    fn type_hash_map() -> impl Strategy<Value = HashMap<String, Type>> {
+        hash_map(".*", arb_type(), 0..10)
+    }
+
+    proptest!
{
+        #[test]
+        fn test_object_type_unification(left in type_hash_map(), right in type_hash_map(), shared in type_hash_map()) {
+            let mut left_fields = left.clone();
+            let mut right_fields: HashMap<String, Type> = right.clone().into_iter().filter(|(k, _)| !left_fields.contains_key(k)).collect();
+            for (k, v) in shared.clone() {
+                left_fields.insert(k.clone(), v.clone());
+                right_fields.insert(k, v);
+            }
+
+            let name = "foo";
+            let left_object = ObjectType {
+                name: name.to_owned(),
+                fields: left_fields.into_iter().map(|(k, v)| ObjectField{name: k, r#type: v, description: None}).collect(),
+                description: None
+            };
+            let right_object = ObjectType {
+                name: name.to_owned(),
+                fields: right_fields.into_iter().map(|(k, v)| ObjectField{name: k, r#type: v, description: None}).collect(),
+                description: None
+            };
+            let result = unify_object_type(left_object, right_object);
+            match result {
+                Err(err) => panic!("Got error result {err}"),
+                Ok(ot) => {
+                    for field in &ot.fields {
+                        // Any fields not shared between the two input types should be nullable.
+                        if !shared.contains_key(&field.name) {
+                            assert!(is_nullable(&field.r#type), "Found a non-shared field that is not nullable")
+                        }
+                    }
+
+                    // All input fields must appear in the result.
+                    let fields: HashSet<String> = ot.fields.into_iter().map(|f| f.name).collect();
+                    assert!(left.into_keys().chain(right.into_keys()).chain(shared.into_keys()).all(|k| fields.contains(&k)),
+                            "Missing field in result type")
+                }
+            }
+        }
+    }
+}