Skip to content

Commit

Permalink
Forbid @Skip and @include directives in subscription root selection
Browse files Browse the repository at this point in the history
  • Loading branch information
benjie committed Oct 9, 2023
1 parent 7a6d055 commit 57d7b11
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 21 deletions.
60 changes: 51 additions & 9 deletions src/execution/collectFields.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { isSameSet } from '../jsutils/isSameSet.js';
import type { ObjMap } from '../jsutils/ObjMap.js';

import type {
DirectiveNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
Expand All @@ -26,7 +27,7 @@ import type { GraphQLSchema } from '../type/schema.js';

import { typeFromAST } from '../utilities/typeFromAST.js';

import { getDirectiveValues } from './values.js';
import { getArgumentValues, getDirectiveValues } from './values.js';

export interface DeferUsage {
label: string | undefined;
Expand Down Expand Up @@ -60,6 +61,7 @@ export interface CollectFieldsResult {
groupedFieldSet: GroupedFieldSet;
newGroupedFieldSetDetails: Map<DeferUsageSet, GroupedFieldSetDetails>;
newDeferUsages: ReadonlyArray<DeferUsage>;
forbiddenDirectiveInstances: ReadonlyArray<DirectiveNode>;
}

interface CollectFieldsContext {
Expand All @@ -72,6 +74,7 @@ interface CollectFieldsContext {
fieldsByTarget: Map<Target, AccumulatorMap<string, FieldNode>>;
newDeferUsages: Array<DeferUsage>;
visitedFragmentNames: Set<string>;
forbiddenDirectiveInstances: Array<DirectiveNode>;
}

/**
Expand Down Expand Up @@ -100,16 +103,28 @@ export function collectFields(
targetsByKey: new Map(),
newDeferUsages: [],
visitedFragmentNames: new Set(),
forbiddenDirectiveInstances: [],
};

collectFieldsImpl(context, operation.selectionSet);

return {
...buildGroupedFieldSets(context.targetsByKey, context.fieldsByTarget),
newDeferUsages: context.newDeferUsages,
forbiddenDirectiveInstances: context.forbiddenDirectiveInstances,
};
}

/**
* This variable is the empty variables used during the validation phase (where
* no variables exist) for field collection; if a `@skip` or `@include`
* directive is ever seen when `variableValues` is set to this, it should
* throw.
*/
export const VALIDATION_PHASE_EMPTY_VARIABLES: {
[variable: string]: any;
} = Object.freeze(Object.create(null));

/**
* Given an array of field nodes, collects all of the subfields of the passed
* in fields, and returns them at the end.
Expand Down Expand Up @@ -139,6 +154,7 @@ export function collectSubfields(
targetsByKey: new Map(),
newDeferUsages: [],
visitedFragmentNames: new Set(),
forbiddenDirectiveInstances: [],
};

for (const fieldDetails of fieldGroup.fields) {
Expand All @@ -155,6 +171,7 @@ export function collectSubfields(
fieldGroup.targets,
),
newDeferUsages: context.newDeferUsages,
forbiddenDirectiveInstances: context.forbiddenDirectiveInstances,
};
}

Expand All @@ -179,7 +196,7 @@ function collectFieldsImpl(
for (const selection of selectionSet.selections) {
switch (selection.kind) {
case Kind.FIELD: {
if (!shouldIncludeNode(variableValues, selection)) {
if (!shouldIncludeNode(context, variableValues, selection)) {
continue;
}
const key = getFieldEntryKey(selection);
Expand All @@ -200,7 +217,7 @@ function collectFieldsImpl(
}
case Kind.INLINE_FRAGMENT: {
if (
!shouldIncludeNode(variableValues, selection) ||
!shouldIncludeNode(context, variableValues, selection) ||
!doesFragmentConditionMatch(schema, selection, runtimeType)
) {
continue;
Expand Down Expand Up @@ -232,7 +249,7 @@ function collectFieldsImpl(
case Kind.FRAGMENT_SPREAD: {
const fragName = selection.name.value;

if (!shouldIncludeNode(variableValues, selection)) {
if (!shouldIncludeNode(context, variableValues, selection)) {
continue;
}

Expand Down Expand Up @@ -304,19 +321,44 @@ function getDeferValues(
* directives, where `@skip` has higher precedence than `@include`.
*/
function shouldIncludeNode(
context: CollectFieldsContext,
variableValues: { [variable: string]: unknown },
node: FragmentSpreadNode | FieldNode | InlineFragmentNode,
): boolean {
const skip = getDirectiveValues(GraphQLSkipDirective, node, variableValues);
const skipDirectiveNode = node.directives?.find(
(directive) => directive.name.value === GraphQLSkipDirective.name,
);
if (
skipDirectiveNode &&
variableValues === VALIDATION_PHASE_EMPTY_VARIABLES
) {
context.forbiddenDirectiveInstances.push(skipDirectiveNode);
return false;
}
const skip = skipDirectiveNode
? getArgumentValues(GraphQLSkipDirective, skipDirectiveNode, variableValues)
: undefined;
if (skip?.if === true) {
return false;
}

const include = getDirectiveValues(
GraphQLIncludeDirective,
node,
variableValues,
const includeDirectiveNode = node.directives?.find(
(directive) => directive.name.value === GraphQLIncludeDirective.name,
);
if (
includeDirectiveNode &&
variableValues === VALIDATION_PHASE_EMPTY_VARIABLES
) {
context.forbiddenDirectiveInstances.push(includeDirectiveNode);
return false;
}
const include = includeDirectiveNode
? getArgumentValues(
GraphQLIncludeDirective,
includeDirectiveNode,
variableValues,
)
: undefined;
if (include?.if === false) {
return false;
}
Expand Down
42 changes: 42 additions & 0 deletions src/validation/__tests__/SingleFieldSubscriptionsRule-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,48 @@ describe('Validate: Subscriptions with single field', () => {
]);
});

it('fails with @skip or @include directive', () => {
expectErrors(`
subscription RequiredRuntimeValidation($bool: Boolean!) {
newMessage @include(if: $bool) {
body
sender
}
disallowedSecondRootField @skip(if: $bool)
}
`).toDeepEqual([
{
message:
'Subscription "RequiredRuntimeValidation" must not use `@skip` or `@include` directives in the top level selection.',
locations: [
{ line: 3, column: 20 },
{ line: 7, column: 35 },
],
},
]);
});

it('fails with @skip or @include directive in anonymous subscription', () => {
expectErrors(`
subscription ($bool: Boolean!) {
newMessage @include(if: $bool) {
body
sender
}
disallowedSecondRootField @skip(if: $bool)
}
`).toDeepEqual([
{
message:
'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.',
locations: [
{ line: 3, column: 20 },
{ line: 7, column: 35 },
],
},
]);
});

it('skips if not subscription type', () => {
const emptySchema = buildSchema(`
type Query {
Expand Down
38 changes: 26 additions & 12 deletions src/validation/rules/SingleFieldSubscriptionsRule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import { Kind } from '../../language/kinds.js';
import type { ASTVisitor } from '../../language/visitor.js';

import type { FieldGroup } from '../../execution/collectFields.js';
import { collectFields } from '../../execution/collectFields.js';
import {
collectFields,
VALIDATION_PHASE_EMPTY_VARIABLES,
} from '../../execution/collectFields.js';

import type { ValidationContext } from '../ValidationContext.js';

Expand All @@ -23,7 +26,8 @@ function toNodes(fieldGroup: FieldGroup): ReadonlyArray<FieldNode> {
* Subscriptions must only include a non-introspection field.
*
* A GraphQL subscription is valid only if it contains a single root field and
* that root field is not an introspection field.
* that root field is not an introspection field. `@skip` and `@include`
* directives are forbidden.
*
* See https://spec.graphql.org/draft/#sec-Single-root-field
*/
Expand All @@ -37,23 +41,33 @@ export function SingleFieldSubscriptionsRule(
const subscriptionType = schema.getSubscriptionType();
if (subscriptionType) {
const operationName = node.name ? node.name.value : null;
const variableValues: {
[variable: string]: any;
} = Object.create(null);
const variableValues = VALIDATION_PHASE_EMPTY_VARIABLES;
const document = context.getDocument();
const fragments: ObjMap<FragmentDefinitionNode> = Object.create(null);
for (const definition of document.definitions) {
if (definition.kind === Kind.FRAGMENT_DEFINITION) {
fragments[definition.name.value] = definition;
}
}
const { groupedFieldSet } = collectFields(
schema,
fragments,
variableValues,
subscriptionType,
node,
);
const { groupedFieldSet, forbiddenDirectiveInstances } =
collectFields(
schema,
fragments,
variableValues,
subscriptionType,
node,
);
if (forbiddenDirectiveInstances.length > 0) {
context.reportError(
new GraphQLError(
operationName != null
? `Subscription "${operationName}" must not use \`@skip\` or \`@include\` directives in the top level selection.`
: 'Anonymous Subscription must not use `@skip` or `@include` directives in the top level selection.',
{ nodes: forbiddenDirectiveInstances },
),
);
return;
}
if (groupedFieldSet.size > 1) {
const fieldGroups = [...groupedFieldSet.values()];
const extraFieldGroups = fieldGroups.slice(1);
Expand Down

0 comments on commit 57d7b11

Please sign in to comment.