Skip to content

Commit

Permalink
Merge pull request from GHSA-4rx6-g5vg-5f3j
Browse files Browse the repository at this point in the history
* Replace recursions with heap allocations

* Some corrections [skip ci]

* Add recursive nested fragments test case

* Docs and small corrections

* Corrections

Co-authored-by: Kai Ren <tyranron@gmail.com>
  • Loading branch information
ilslv and tyranron committed Jul 28, 2022
1 parent 6d6c71f commit 2b609ee
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 101 deletions.
68 changes: 43 additions & 25 deletions juniper/src/validation/rules/no_fragment_cycles.rs
Expand Up @@ -7,19 +7,6 @@ use crate::{
value::ScalarValue,
};

pub struct NoFragmentCycles<'a> {
current_fragment: Option<&'a str>,
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
fragment_order: Vec<&'a str>,
}

struct CycleDetector<'a> {
visited: HashSet<&'a str>,
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
path_indices: HashMap<&'a str, usize>,
errors: Vec<RuleError>,
}

pub fn factory<'a>() -> NoFragmentCycles<'a> {
NoFragmentCycles {
current_fragment: None,
Expand All @@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> {
}
}

pub struct NoFragmentCycles<'a> {
current_fragment: Option<&'a str>,
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
fragment_order: Vec<&'a str>,
}

impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a>
where
S: ScalarValue,
Expand All @@ -38,14 +31,12 @@ where
let mut detector = CycleDetector {
visited: HashSet::new(),
spreads: &self.spreads,
path_indices: HashMap::new(),
errors: Vec::new(),
};

for frag in &self.fragment_order {
if !detector.visited.contains(frag) {
let mut path = Vec::new();
detector.detect_from(frag, &mut path);
detector.detect_from(frag);
}
}

Expand Down Expand Up @@ -91,19 +82,46 @@ where
}
}

type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>);

struct CycleDetector<'a> {
visited: HashSet<&'a str>,
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
errors: Vec<RuleError>,
}

impl<'a> CycleDetector<'a> {
fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) {
fn detect_from(&mut self, from: &'a str) {
let mut to_visit = Vec::new();
to_visit.push((from, Vec::new(), HashMap::new()));

while let Some((from, path, path_indices)) = to_visit.pop() {
to_visit.extend(self.detect_from_inner(from, path, path_indices));
}
}

/// This function should be called only inside [`Self::detect_from()`], as
/// it's a recursive function using heap instead of a stack. So, instead of
/// the recursive call, we return a [`Vec`] that is visited inside
/// [`Self::detect_from()`].
fn detect_from_inner(
&mut self,
from: &'a str,
path: Vec<&'a Spanning<&'a str>>,
mut path_indices: HashMap<&'a str, usize>,
) -> Vec<CycleDetectorState<'a>> {
self.visited.insert(from);

if !self.spreads.contains_key(from) {
return;
return Vec::new();
}

self.path_indices.insert(from, path.len());
path_indices.insert(from, path.len());

let mut to_visit = Vec::new();
for node in &self.spreads[from] {
let name = &node.item;
let index = self.path_indices.get(name).cloned();
let name = node.item;
let index = path_indices.get(name).cloned();

if let Some(index) = index {
let err_pos = if index < path.len() {
Expand All @@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> {

self.errors
.push(RuleError::new(&error_message(name), &[err_pos.start]));
} else if !self.visited.contains(name) {
} else {
let mut path = path.clone();
path.push(node);
self.detect_from(name, path);
path.pop();
to_visit.push((name, path, path_indices.clone()));
}
}

self.path_indices.remove(from);
to_visit
}
}

Expand Down
48 changes: 35 additions & 13 deletions juniper/src/validation/rules/no_undefined_variables.rs
Expand Up @@ -12,13 +12,6 @@ pub enum Scope<'a> {
Fragment(&'a str),
}

pub struct NoUndefinedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

pub fn factory<'a>() -> NoUndefinedVariables<'a> {
NoUndefinedVariables {
defined_variables: HashMap::new(),
Expand All @@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> {
}
}

pub struct NoUndefinedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

impl<'a> NoUndefinedVariables<'a> {
fn find_undef_vars(
&'a self,
Expand All @@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> {
unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>,
) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) =
self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited)
{
to_visit.push(spreads);
}
}
}
}

/// This function should be called only inside [`Self::find_undef_vars()`],
/// as it's a recursive function using heap instead of a stack. So, instead
/// of the recursive call, we return a [`Vec`] that is visited inside
/// [`Self::find_undef_vars()`].
fn find_undef_vars_inner(
&'a self,
scope: &Scope<'a>,
defined: &HashSet<&'a str>,
unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'a Vec<&'a str>> {
if visited.contains(scope) {
return;
return None;
}

visited.insert(scope.clone());
Expand All @@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> {
}
}

if let Some(spreads) = self.spreads.get(scope) {
for spread in spreads {
self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited);
}
}
self.spreads.get(scope)
}
}

Expand Down
45 changes: 30 additions & 15 deletions juniper/src/validation/rules/no_unused_fragments.rs
Expand Up @@ -13,12 +13,6 @@ pub enum Scope<'a> {
Fragment(&'a str),
}

pub struct NoUnusedFragments<'a> {
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
defined_fragments: HashSet<Spanning<&'a str>>,
current_scope: Option<Scope<'a>>,
}

pub fn factory<'a>() -> NoUnusedFragments<'a> {
NoUnusedFragments {
spreads: HashMap::new(),
Expand All @@ -27,22 +21,43 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> {
}
}

pub struct NoUnusedFragments<'a> {
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
defined_fragments: HashSet<Spanning<&'a str>>,
current_scope: Option<Scope<'a>>,
}

impl<'a> NoUnusedFragments<'a> {
fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
fn find_reachable_fragments(&'a self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
let mut to_visit = Vec::new();
if let Scope::Fragment(name) = *from {
if result.contains(name) {
return;
} else {
result.insert(name);
}
to_visit.push(name);
}

if let Some(spreads) = self.spreads.get(from) {
for spread in spreads {
self.find_reachable_fragments(&Scope::Fragment(spread), result)
while let Some(from) = to_visit.pop() {
if let Some(next) = self.find_reachable_fragments_inner(from, result) {
to_visit.extend(next);
}
}
}

/// This function should be called only inside
/// [`Self::find_reachable_fragments()`], as it's a recursive function using
/// heap instead of a stack. So, instead of the recursive call, we return a
/// [`Vec`] that is visited inside [`Self::find_reachable_fragments()`].
fn find_reachable_fragments_inner(
&'a self,
from: &'a str,
result: &mut HashSet<&'a str>,
) -> Option<&'a Vec<&'a str>> {
if result.contains(from) {
return None;
} else {
result.insert(from);
}

self.spreads.get(&Scope::Fragment(from))
}
}

impl<'a, S> Visitor<'a, S> for NoUnusedFragments<'a>
Expand Down
50 changes: 36 additions & 14 deletions juniper/src/validation/rules/no_unused_variables.rs
Expand Up @@ -12,13 +12,6 @@ pub enum Scope<'a> {
Fragment(&'a str),
}

pub struct NoUnusedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

pub fn factory<'a>() -> NoUnusedVariables<'a> {
NoUnusedVariables {
defined_variables: HashMap::new(),
Expand All @@ -28,16 +21,49 @@ pub fn factory<'a>() -> NoUnusedVariables<'a> {
}
}

pub struct NoUnusedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

impl<'a> NoUnusedVariables<'a> {
fn find_used_vars(
&self,
&'a self,
from: &Scope<'a>,
defined: &HashSet<&'a str>,
used: &mut HashSet<&'a str>,
visited: &mut HashSet<Scope<'a>>,
) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.find_used_vars_inner(from, defined, used, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) =
self.find_used_vars_inner(&Scope::Fragment(spread), defined, used, visited)
{
to_visit.push(spreads);
}
}
}
}

/// This function should be called only inside [`Self::find_used_vars()`],
/// as it's a recursive function using heap instead of a stack. So, instead
/// of the recursive call, we return a [`Vec`] that is visited inside
/// [`Self::find_used_vars()`].
fn find_used_vars_inner(
&'a self,
from: &Scope<'a>,
defined: &HashSet<&'a str>,
used: &mut HashSet<&'a str>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'a Vec<&'a str>> {
if visited.contains(from) {
return;
return None;
}

visited.insert(from.clone());
Expand All @@ -50,11 +76,7 @@ impl<'a> NoUnusedVariables<'a> {
}
}

if let Some(spreads) = self.spreads.get(from) {
for spread in spreads {
self.find_used_vars(&Scope::Fragment(spread), defined, used, visited);
}
}
self.spreads.get(from)
}
}

Expand Down

0 comments on commit 2b609ee

Please sign in to comment.