Skip to content
Permalink
Browse files Browse the repository at this point in the history
Co-authored-by: ilslv <ilya.solovyiov@gmail.com>
  • Loading branch information
tyranron and ilslv committed Jul 28, 2022
1 parent c650713 commit 8d28cdb
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 102 deletions.
1 change: 1 addition & 0 deletions integration_tests/juniper_tests/Cargo.toml
Expand Up @@ -7,6 +7,7 @@ publish = false
[dependencies]
derive_more = "0.99"
futures = "0.3"
itertools = "0.10"
juniper = { path = "../../juniper" }
juniper_subscriptions = { path = "../../juniper_subscriptions" }

Expand Down
56 changes: 56 additions & 0 deletions integration_tests/juniper_tests/src/cve_2022_31173.rs
@@ -0,0 +1,56 @@
//! Checks that long looping chain of fragments doesn't cause a stack overflow.
//!
//! ```graphql
//! # Fragment loop example
//! query {
//! ...a
//! }
//!
//! fragment a on Query {
//! ...b
//! }
//!
//! fragment b on Query {
//! ...a
//! }
//! ```

use std::iter;

use itertools::Itertools as _;
use juniper::{graphql_object, EmptyMutation, EmptySubscription, Variables};

struct Query;

#[graphql_object]
impl Query {
fn dummy() -> bool {
false
}
}

type Schema = juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>;

#[tokio::test]
async fn test() {
const PERM: &str = "abcefghijk";
const CIRCLE_SIZE: usize = 7500;

let query = iter::once(format!("query {{ ...{PERM} }} "))
.chain(
PERM.chars()
.permutations(PERM.len())
.map(|vec| vec.into_iter().collect::<String>())
.take(CIRCLE_SIZE)
.collect::<Vec<_>>()
.into_iter()
.circular_tuple_windows::<(_, _)>()
.map(|(cur, next)| format!("fragment {cur} on Query {{ ...{next} }} ")),
)
.collect::<String>();

let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new());
let _ = juniper::execute(&query, None, &schema, &Variables::new(), &())
.await
.unwrap_err();
}
2 changes: 2 additions & 0 deletions integration_tests/juniper_tests/src/lib.rs
Expand Up @@ -7,6 +7,8 @@ mod codegen;
#[cfg(test)]
mod custom_scalar;
#[cfg(test)]
mod cve_2022_31173;
#[cfg(test)]
mod explicit_null;
#[cfg(test)]
mod infallible_as_field_error;
Expand Down
1 change: 1 addition & 0 deletions juniper/CHANGELOG.md
@@ -1,5 +1,6 @@
# master

- Fix [CVE-2022-31173](https://github.com/graphql-rust/juniper/security/advisories/GHSA-4rx6-g5vg-5f3j).
- Fix incorrect error when explicit `null` provided for `null`able list input parameter. ([#1086](https://github.com/graphql-rust/juniper/pull/1086))

# [[0.15.9] 2022-02-02](https://github.com/graphql-rust/juniper/releases/tag/juniper-v0.15.9)
Expand Down
68 changes: 43 additions & 25 deletions juniper/src/validation/rules/no_fragment_cycles.rs
Expand Up @@ -7,19 +7,6 @@ use crate::{
value::ScalarValue,
};

pub struct NoFragmentCycles<'a> {
current_fragment: Option<&'a str>,
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
fragment_order: Vec<&'a str>,
}

struct CycleDetector<'a> {
visited: HashSet<&'a str>,
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
path_indices: HashMap<&'a str, usize>,
errors: Vec<RuleError>,
}

pub fn factory<'a>() -> NoFragmentCycles<'a> {
NoFragmentCycles {
current_fragment: None,
Expand All @@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> {
}
}

pub struct NoFragmentCycles<'a> {
current_fragment: Option<&'a str>,
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
fragment_order: Vec<&'a str>,
}

impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a>
where
S: ScalarValue,
Expand All @@ -38,14 +31,12 @@ where
let mut detector = CycleDetector {
visited: HashSet::new(),
spreads: &self.spreads,
path_indices: HashMap::new(),
errors: Vec::new(),
};

for frag in &self.fragment_order {
if !detector.visited.contains(frag) {
let mut path = Vec::new();
detector.detect_from(frag, &mut path);
detector.detect_from(frag);
}
}

Expand Down Expand Up @@ -91,19 +82,46 @@ where
}
}

type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>);

struct CycleDetector<'a> {
visited: HashSet<&'a str>,
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
errors: Vec<RuleError>,
}

impl<'a> CycleDetector<'a> {
fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) {
fn detect_from(&mut self, from: &'a str) {
let mut to_visit = Vec::new();
to_visit.push((from, Vec::new(), HashMap::new()));

while let Some((from, path, path_indices)) = to_visit.pop() {
to_visit.extend(self.detect_from_inner(from, path, path_indices));
}
}

/// This function should be called only inside [`Self::detect_from()`], as
/// it's a recursive function using heap instead of a stack. So, instead of
/// the recursive call, we return a [`Vec`] that is visited inside
/// [`Self::detect_from()`].
fn detect_from_inner(
&mut self,
from: &'a str,
path: Vec<&'a Spanning<&'a str>>,
mut path_indices: HashMap<&'a str, usize>,
) -> Vec<CycleDetectorState<'a>> {
self.visited.insert(from);

if !self.spreads.contains_key(from) {
return;
return Vec::new();
}

self.path_indices.insert(from, path.len());
path_indices.insert(from, path.len());

let mut to_visit = Vec::new();
for node in &self.spreads[from] {
let name = &node.item;
let index = self.path_indices.get(name).cloned();
let name = node.item;
let index = path_indices.get(name).cloned();

if let Some(index) = index {
let err_pos = if index < path.len() {
Expand All @@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> {

self.errors
.push(RuleError::new(&error_message(name), &[err_pos.start]));
} else if !self.visited.contains(name) {
} else {
let mut path = path.clone();
path.push(node);
self.detect_from(name, path);
path.pop();
to_visit.push((name, path, path_indices.clone()));
}
}

self.path_indices.remove(from);
to_visit
}
}

Expand Down
48 changes: 35 additions & 13 deletions juniper/src/validation/rules/no_undefined_variables.rs
Expand Up @@ -12,13 +12,6 @@ pub enum Scope<'a> {
Fragment(&'a str),
}

pub struct NoUndefinedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

pub fn factory<'a>() -> NoUndefinedVariables<'a> {
NoUndefinedVariables {
defined_variables: HashMap::new(),
Expand All @@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> {
}
}

pub struct NoUndefinedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}

impl<'a> NoUndefinedVariables<'a> {
fn find_undef_vars(
&'a self,
Expand All @@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> {
unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>,
) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) =
self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited)
{
to_visit.push(spreads);
}
}
}
}

/// This function should be called only inside [`Self::find_undef_vars()`],
/// as it's a recursive function using heap instead of a stack. So, instead
/// of the recursive call, we return a [`Vec`] that is visited inside
/// [`Self::find_undef_vars()`].
fn find_undef_vars_inner(
&'a self,
scope: &Scope<'a>,
defined: &HashSet<&'a str>,
unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'a Vec<&'a str>> {
if visited.contains(scope) {
return;
return None;
}

visited.insert(scope.clone());
Expand All @@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> {
}
}

if let Some(spreads) = self.spreads.get(scope) {
for spread in spreads {
self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited);
}
}
self.spreads.get(scope)
}
}

Expand Down
49 changes: 32 additions & 17 deletions juniper/src/validation/rules/no_unused_fragments.rs
Expand Up @@ -7,18 +7,12 @@ use crate::{
value::ScalarValue,
};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Scope<'a> {
Operation(Option<&'a str>),
Fragment(&'a str),
}

pub struct NoUnusedFragments<'a> {
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
defined_fragments: HashSet<Spanning<&'a str>>,
current_scope: Option<Scope<'a>>,
}

pub fn factory<'a>() -> NoUnusedFragments<'a> {
NoUnusedFragments {
spreads: HashMap::new(),
Expand All @@ -27,21 +21,42 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> {
}
}

pub struct NoUnusedFragments<'a> {
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
defined_fragments: HashSet<Spanning<&'a str>>,
current_scope: Option<Scope<'a>>,
}

impl<'a> NoUnusedFragments<'a> {
fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
if let Scope::Fragment(name) = *from {
fn find_reachable_fragments(&'a self, from: Scope<'a>, result: &mut HashSet<&'a str>) {
let mut to_visit = Vec::new();
to_visit.push(from);

while let Some(from) = to_visit.pop() {
if let Some(next) = self.find_reachable_fragments_inner(from, result) {
to_visit.extend(next.iter().map(|s| Scope::Fragment(s)));
}
}
}

/// This function should be called only inside
/// [`Self::find_reachable_fragments()`], as it's a recursive function using
/// heap instead of a stack. So, instead of the recursive call, we return a
/// [`Vec`] that is visited inside [`Self::find_reachable_fragments()`].
fn find_reachable_fragments_inner(
&'a self,
from: Scope<'a>,
result: &mut HashSet<&'a str>,
) -> Option<&'a Vec<&'a str>> {
if let Scope::Fragment(name) = from {
if result.contains(name) {
return;
return None;
} else {
result.insert(name);
}
}

if let Some(spreads) = self.spreads.get(from) {
for spread in spreads {
self.find_reachable_fragments(&Scope::Fragment(spread), result)
}
}
self.spreads.get(&from)
}
}

Expand All @@ -59,7 +74,7 @@ where
}) = *def
{
let op_name = name.as_ref().map(|s| s.item);
self.find_reachable_fragments(&Scope::Operation(op_name), &mut reachable);
self.find_reachable_fragments(Scope::Operation(op_name), &mut reachable);
}
}

Expand Down Expand Up @@ -96,7 +111,7 @@ where
) {
if let Some(ref scope) = self.current_scope {
self.spreads
.entry(scope.clone())
.entry(*scope)
.or_insert_with(Vec::new)
.push(spread.item.name.item);
}
Expand Down

0 comments on commit 8d28cdb

Please sign in to comment.