diff --git a/trustfall/examples/hackernews/adapter.rs b/trustfall/examples/hackernews/adapter.rs index 784b778c..77803e98 100644 --- a/trustfall/examples/hackernews/adapter.rs +++ b/trustfall/examples/hackernews/adapter.rs @@ -5,18 +5,17 @@ use std::collections::HashSet; use hn_api::{types::Item, HnClient}; use trustfall::{ provider::{ - field_property, resolve_coercion_with, resolve_neighbors_with, resolve_property_with, - BasicAdapter, ContextIterator, ContextOutcomeIterator, EdgeParameters, VertexIterator, + field_property, resolve_coercion_using_schema, resolve_neighbors_with, + resolve_property_with, BasicAdapter, ContextIterator, ContextOutcomeIterator, + EdgeParameters, VertexIterator, }, - FieldValue, Schema, + FieldValue, }; -use crate::vertex::Vertex; +use crate::{vertex::Vertex, SCHEMA}; lazy_static! { static ref CLIENT: HnClient = HnClient::init().expect("HnClient instantiated"); - static ref SCHEMA: Schema = - Schema::parse(include_str!("hackernews.graphql")).expect("valid schema"); } #[derive(Debug, Clone, Default)] @@ -360,18 +359,9 @@ impl BasicAdapter<'static> for HackerNewsAdapter { fn resolve_coercion( &self, contexts: ContextIterator<'static, Self::Vertex>, - type_name: &str, + _type_name: &str, coerce_to_type: &str, ) -> ContextOutcomeIterator<'static, Self::Vertex, bool> { - match (type_name, coerce_to_type) { - ("Item", "Job") => resolve_coercion_with(contexts, |v| v.as_job().is_some()), - ("Item", "Story") => resolve_coercion_with(contexts, |v| v.as_story().is_some()), - ("Item", "Comment") => resolve_coercion_with(contexts, |v| v.as_comment().is_some()), - ("Item", "Poll") => resolve_coercion_with(contexts, |v| v.as_poll().is_some()), - ("Item", "PollOption") => { - resolve_coercion_with(contexts, |v| v.as_poll_option().is_some()) - } - _ => unreachable!(), - } + resolve_coercion_using_schema(contexts, &SCHEMA, coerce_to_type) } } diff --git a/trustfall/src/lib.rs b/trustfall/src/lib.rs index 6ab10d4b..4696cf65 100644 --- a/trustfall/src/lib.rs +++ b/trustfall/src/lib.rs @@ -47,7 +47,8 @@ pub mod provider { // Helpers for common operations when building adapters. pub use trustfall_core::interpreter::helpers::{ - resolve_coercion_with, resolve_neighbors_with, resolve_property_with, + resolve_coercion_using_schema, resolve_coercion_with, resolve_neighbors_with, + resolve_property_with, resolve_typename, }; pub use trustfall_core::{accessor_property, field_property}; diff --git a/trustfall_core/src/interpreter/helpers.rs b/trustfall_core/src/interpreter/helpers.rs index 55b2189a..e50caf33 100644 --- a/trustfall_core/src/interpreter/helpers.rs +++ b/trustfall_core/src/interpreter/helpers.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::{collections::BTreeSet, fmt::Debug}; use crate::{ir::FieldValue, schema::Schema}; @@ -71,6 +71,35 @@ pub fn resolve_coercion_with<'vertex, Vertex: Debug + Clone + 'vertex>( })) } +/// Helper for implementing [`BasicAdapter::resolve_coercion`] and equivalents. +/// +/// Uses the schema to look up all the subtypes of the coercion target type. +/// Then uses the [`Typename`] trait to look up the exact runtime type of each vertex +/// and checks if it's equal or a subtype of the coercion target type. +/// +/// [`BasicAdapter::resolve_coercion`]: super::basic_adapter::BasicAdapter::resolve_coercion +pub fn resolve_coercion_using_schema<'vertex, Vertex: Debug + Clone + Typename + 'vertex>( + contexts: ContextIterator<'vertex, Vertex>, + schema: &'vertex Schema, + coerce_to_type: &str, +) -> ContextOutcomeIterator<'vertex, Vertex, bool> { + // If the vertex's typename is one of these types, + // then the coercion's result is `true`. + let subtypes: BTreeSet<_> = schema + .subtypes(coerce_to_type) + .unwrap_or_else(|| panic!("type {coerce_to_type} is not part of this schema")) + .collect(); + + Box::new(contexts.map(move |ctx| match ctx.active_vertex.as_ref() { + None => (ctx, false), + Some(vertex) => { + let typename = vertex.typename(); + let can_coerce = subtypes.contains(typename); + (ctx, can_coerce) + } + })) +} + /// Helper for making property resolver functions based on fields. /// /// Generally used with [`resolve_property_with`]. diff --git a/trustfall_core/src/schema/mod.rs b/trustfall_core/src/schema/mod.rs index 10ef4d71..6b022ff7 100644 --- a/trustfall_core/src/schema/mod.rs +++ b/trustfall_core/src/schema/mod.rs @@ -229,10 +229,10 @@ directive @transform(op: String!) on FIELD /// If the named type is defined, iterate through the names of its subtypes including itself. /// Otherwise, return None. - pub fn subtypes<'slf, 'a: 'slf>( + pub fn subtypes<'a, 'slf: 'a>( &'slf self, type_name: &'a str, - ) -> Option + 'slf> { + ) -> Option + 'a> { if !self.vertex_types.contains_key(type_name) { return None; }