Skip to content

Commit

Permalink
Make python3 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Sep 18, 2017
1 parent 7fd9770 commit c11223b
Show file tree
Hide file tree
Showing 25 changed files with 156 additions and 125 deletions.
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ language: python
cache: pip
python:
- "2.7"
# - "3.6" # enable when compliant
- "3.6"
install:
- pip install -r dev-requirements.txt
- pip install -r requirements.txt
- pip install -e .
script:
- flake8 graphql_compiler/
Expand Down
18 changes: 10 additions & 8 deletions graphql_compiler/compiler/blocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright 2017 Kensho Technologies, Inc.
import six

from .expressions import Expression
from .helpers import (CompilerEntity, ensure_unicode_string, safe_quoted_string,
validate_marked_location, validate_safe_string)
Expand Down Expand Up @@ -52,7 +54,7 @@ def __init__(self, start_class):
def validate(self):
"""Ensure that the QueryRoot block is valid."""
if not (isinstance(self.start_class, set) and
all(isinstance(x, basestring) for x in self.start_class)):
all(isinstance(x, six.string_types) for x in self.start_class)):
raise TypeError(u'Expected set of basestring start_class, got: {} {}'.format(
type(self.start_class).__name__, self.start_class))

Expand Down Expand Up @@ -95,7 +97,7 @@ def __init__(self, target_class):
def validate(self):
"""Ensure that the CoerceType block is valid."""
if not (isinstance(self.target_class, set) and
all(isinstance(x, basestring) for x in self.target_class)):
all(isinstance(x, six.string_types) for x in self.target_class)):
raise TypeError(u'Expected set of basestring target_class, got: {} {}'.format(
type(self.target_class).__name__, self.target_class))

Expand Down Expand Up @@ -123,7 +125,7 @@ def __init__(self, fields):
"""
self.fields = {
ensure_unicode_string(key): value
for key, value in fields.iteritems()
for key, value in six.iteritems(fields)
}

# All key values are normalized to unicode before being passed to the parent constructor,
Expand All @@ -137,7 +139,7 @@ def validate(self):
raise TypeError(u'Expected dict fields, got: {} {}'.format(
type(self.fields).__name__, self.fields))

for key, value in self.fields.iteritems():
for key, value in six.iteritems(self.fields):
validate_safe_string(key)
if not isinstance(value, Expression):
raise TypeError(
Expand All @@ -148,7 +150,7 @@ def visit_and_update_expressions(self, visitor_fn):
"""Create an updated version (if needed) of the ConstructResult via the visitor pattern."""
new_fields = {}

for key, value in self.fields.iteritems():
for key, value in six.iteritems(self.fields):
new_value = value.visit_and_update(visitor_fn)
if new_value is not value:
new_fields[key] = new_value
Expand Down Expand Up @@ -253,7 +255,7 @@ def __init__(self, direction, edge_name, optional=False):

def validate(self):
"""Ensure that the Traverse block is valid."""
if not isinstance(self.direction, basestring):
if not isinstance(self.direction, six.string_types):
raise TypeError(u'Expected basestring direction, got: {} {}'.format(
type(self.direction).__name__, self.direction))

Expand Down Expand Up @@ -317,7 +319,7 @@ def __init__(self, direction, edge_name, depth):

def validate(self):
"""Ensure that the Traverse block is valid."""
if not isinstance(self.direction, basestring):
if not isinstance(self.direction, six.string_types):
raise TypeError(u'Expected basestring direction, got: {} {}'.format(
type(self.direction).__name__, self.direction))

Expand All @@ -344,7 +346,7 @@ def to_gremlin(self):

recurse_steps = [
recurse_base + (recurse_traversal * i)
for i in xrange(self.depth + 1)
for i in six.moves.xrange(self.depth + 1)
]
return template.format(recurse=','.join(recurse_steps))

Expand Down
13 changes: 7 additions & 6 deletions graphql_compiler/compiler/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from graphql.type.definition import (GraphQLInterfaceType, GraphQLList, GraphQLObjectType,
GraphQLUnionType)
from graphql.validation import validate
import six

from . import blocks, expressions
from ..exceptions import GraphQLCompilationError, GraphQLParsingError, GraphQLValidationError
Expand Down Expand Up @@ -225,15 +226,15 @@ def _process_output_source_directive(schema, current_schema_type, ast,

def _validate_property_directives(directives):
"""Validate the directives that appear at a property field."""
for directive_name in directives.iterkeys():
for directive_name in six.iterkeys(directives):
if directive_name in vertex_only_directives:
raise GraphQLCompilationError(
u'Found vertex-only directive {} set on property.'.format(directive_name))


def _validate_vertex_directives(directives):
"""Validate the directives that appear at a vertex field."""
for directive_name in directives.iterkeys():
for directive_name in six.iterkeys(directives):
if directive_name in property_only_directives:
raise GraphQLCompilationError(
u'Found property-only directive {} set on vertex.'.format(directive_name))
Expand Down Expand Up @@ -529,7 +530,7 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
def _validate_fold_has_outputs(fold_data, outputs):
# At least one output in the outputs list must point to the fold_data,
# or the scope corresponding to fold_data had no @outputs and is illegal.
for output in outputs.values():
for output in six.itervalues(outputs):
if output['fold'] is fold_data:
return True

Expand Down Expand Up @@ -694,7 +695,7 @@ def _compile_root_ast_to_ir(schema, ast):

# Ensure the GraphQL query root doesn't have any vertex directives
# that are disallowed on the root node.
directives_present_at_root = set(_get_directives(base_ast).iterkeys())
directives_present_at_root = set(six.iterkeys(_get_directives(base_ast)))
disallowed_directives = directives_present_at_root & vertex_directives_prohibited_on_root
if disallowed_directives:
raise GraphQLCompilationError(u'Found prohibited directives on root vertex: '
Expand All @@ -710,7 +711,7 @@ def _compile_root_ast_to_ir(schema, ast):
basic_blocks.append(_compile_output_step(outputs_context))
output_metadata = {
name: OutputMetadata(type=value['type'], optional=value['optional'])
for name, value in outputs_context.iteritems()
for name, value in six.iteritems(outputs_context)
}

return basic_blocks, output_metadata, context['inputs'], context['location_types']
Expand All @@ -732,7 +733,7 @@ def _compile_output_step(outputs):
u'one field with the @output directive.')

output_fields = {}
for output_name, output_context in outputs.iteritems():
for output_name, output_context in six.iteritems(outputs):
location = output_context['location']
optional = output_context['optional']
graphql_type = output_context['type']
Expand Down
15 changes: 8 additions & 7 deletions graphql_compiler/compiler/expressions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2017 Kensho Technologies, Inc.
from graphql import GraphQLList, GraphQLNonNull
import six

from ..exceptions import GraphQLCompilationError
from ..schema import GraphQLDate, GraphQLDateTime
Expand Down Expand Up @@ -75,7 +76,7 @@ def validate(self):
return

# Literal safe strings are correctly representable and supported.
if isinstance(self.value, basestring):
if isinstance(self.value, six.string_types):
validate_safe_string(self.value)
return

Expand All @@ -99,12 +100,12 @@ def _to_output_code(self):
return u'true'
elif self.value is False:
return u'false'
elif isinstance(self.value, basestring):
elif isinstance(self.value, six.string_types):
return safe_quoted_string(self.value)
elif isinstance(self.value, list):
if len(self.value) == 0:
return '[]'
elif all(isinstance(x, basestring) for x in self.value):
elif all(isinstance(x, six.string_types) for x in self.value):
list_contents = ', '.join(safe_quoted_string(x) for x in sorted(self.value))
return '[' + list_contents + ']'
else:
Expand Down Expand Up @@ -179,7 +180,7 @@ def to_match(self):
# We don't want the dollar sign as part of the variable name.
variable_with_no_dollar_sign = self.variable_name[1:]

match_variable_name = '{%s}' % (unicode(variable_with_no_dollar_sign),)
match_variable_name = '{%s}' % (six.text_type(variable_with_no_dollar_sign),)

# We can't directly pass a Date or DateTime object, so we have to pass it as a string
# and then parse it inline. For date format parameter meanings, see:
Expand All @@ -203,7 +204,7 @@ def to_gremlin(self):
elif GraphQLDateTime.is_same_type(self.inferred_type):
return u'Date.parse("{}", {})'.format(STANDARD_DATETIME_FORMAT, self.variable_name)
else:
return unicode(self.variable_name)
return six.text_type(self.variable_name)

def __eq__(self, other):
"""Return True if the given object is equal to this one, and False otherwise."""
Expand Down Expand Up @@ -234,7 +235,7 @@ def validate(self):
def to_match(self):
"""Return a unicode object with the MATCH representation of this LocalField."""
self.validate()
return unicode(self.field_name)
return six.text_type(self.field_name)

def to_gremlin(self):
"""Return a unicode object with the Gremlin representation of this expression."""
Expand Down Expand Up @@ -616,7 +617,7 @@ def __init__(self, operator, left, right):

def validate(self):
"""Validate that the BinaryComposition is correctly representable."""
if not isinstance(self.operator, unicode):
if not isinstance(self.operator, six.text_type):
raise TypeError(u'Expected unicode operator, got: {} {}'.format(
type(self.operator).__name__, self.operator))

Expand Down
32 changes: 13 additions & 19 deletions graphql_compiler/compiler/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import string

from graphql import GraphQLEnumType, GraphQLNonNull, GraphQLScalarType, GraphQLString, is_type
import six

from ..exceptions import GraphQLCompilationError


VARIABLE_ALLOWED_CHARS = frozenset(unicode(string.ascii_letters + string.digits + '_'))
VARIABLE_ALLOWED_CHARS = frozenset(six.text_type(string.ascii_letters + string.digits + '_'))


def get_ast_field_name(ast):
Expand Down Expand Up @@ -57,9 +58,9 @@ def is_graphql_type(graphql_type):

def ensure_unicode_string(value):
"""Ensure the value is a basestring, and return it as unicode."""
if not isinstance(value, basestring):
if not isinstance(value, six.string_types):
raise TypeError(u'Expected basestring value, got: {}'.format(value))
return unicode(value)
return six.text_type(value)


def get_uniquely_named_objects_by_name(object_list):
Expand Down Expand Up @@ -97,7 +98,7 @@ def validate_safe_string(value):
# The following strings are explicitly allowed, despite having otherwise-illegal chars.
legal_strings_with_special_chars = frozenset({'@rid', '@class', '@this', '%'})

if not isinstance(value, basestring):
if not isinstance(value, six.string_types):
raise TypeError(u'Expected basestring value, got: {} {}'.format(
type(value).__name__, value))

Expand All @@ -122,6 +123,7 @@ def validate_marked_location(location):
raise GraphQLCompilationError(u'Cannot mark location at a field: {}'.format(location))


@six.python_2_unicode_compatible
class Location(object):
def __init__(self, query_path, field=None, visit_counter=1):
"""Create a new Location object.
Expand Down Expand Up @@ -155,7 +157,7 @@ def __init__(self, query_path, field=None, visit_counter=1):
if not isinstance(query_path, tuple):
raise TypeError(u'Expected query_path to be a tuple, was: '
u'{} {}'.format(type(query_path).__name__, query_path))
if field and not isinstance(field, basestring):
if field and not isinstance(field, six.string_types):
raise TypeError(u'Expected field to be None or basestring, was: '
u'{} {}'.format(type(field).__name__, field))

Expand All @@ -182,7 +184,7 @@ def at_vertex(self):

def navigate_to_subpath(self, child):
"""Return a new Location object at a child vertex of the current Location's vertex."""
if not isinstance(child, basestring):
if not isinstance(child, six.string_types):
raise TypeError(u'Expected child to be a basestring, was: {}'.format(child))
if self.field:
raise AssertionError(u'Currently at a field, cannot go to child: {}'.format(self))
Expand All @@ -196,16 +198,12 @@ def revisit(self):

def get_location_name(self):
"""Return a tuple of a unique name of the Location, and the current field name (or None)."""
mark_name = u'__'.join(self.query_path) + u'___' + unicode(self.visit_counter)
mark_name = u'__'.join(self.query_path) + u'___' + six.text_type(self.visit_counter)
return (mark_name, self.field)

def __unicode__(self):
"""Return a human-readable unicode representation of the Location object."""
return u'Location({}, {}, {})'.format(self.query_path, self.field, self.visit_counter)

def __str__(self):
"""Return a human-readable str representation of the Location object."""
return self.__unicode__().encode('utf-8')
return u'Location({}, {}, {})'.format(self.query_path, self.field, self.visit_counter)

def __repr__(self):
"""Return a human-readable str representation of the Location object."""
Expand All @@ -226,11 +224,11 @@ def __hash__(self):
return hash(self.query_path) ^ hash(self.field) ^ hash(self.visit_counter)


@six.python_2_unicode_compatible
@six.add_metaclass(ABCMeta)
class CompilerEntity(object):
"""An abstract compiler entity. Can represent things like basic blocks and expressions."""

__metaclass__ = ABCMeta

def __init__(self, *args, **kwargs):
"""Construct a new CompilerEntity."""
self._print_args = args
Expand All @@ -240,7 +238,7 @@ def validate(self):
"""Ensure that the CompilerEntity is valid."""
pass

def __unicode__(self):
def __str__(self):
"""Return a human-readable unicode representation of this CompilerEntity."""
printed_args = []
if self._print_args:
Expand All @@ -253,10 +251,6 @@ def __unicode__(self):
args=self._print_args,
kwargs=self._print_kwargs)

def __str__(self):
"""Return a human-readable str representation of this CompilerEntity."""
return self.__unicode__().encode('utf-8')

def __repr__(self):
"""Return a human-readable str representation of the CompilerEntity object."""
return self.__str__()
Expand Down
2 changes: 1 addition & 1 deletion graphql_compiler/compiler/ir_lowering_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2017 Kensho Technologies, Inc.
"""Language-independent IR lowering and optimization functions."""

from funcy import pairwise
from funcy.py2 import pairwise

from .blocks import (Backtrack, CoerceType, ConstructResult, Filter, MarkLocation, OutputSource,
QueryRoot, Recurse, Traverse)
Expand Down
5 changes: 3 additions & 2 deletions graphql_compiler/compiler/ir_lowering_gremlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
to simplify the final code generation step.
"""
from graphql.type import GraphQLInterfaceType, GraphQLObjectType, GraphQLUnionType
import six

from ..exceptions import GraphQLCompilationError
from .blocks import Backtrack, CoerceType, Filter, Traverse
Expand All @@ -28,7 +29,7 @@ def lower_coerce_type_block_type_data(ir_blocks, type_equivalence_hints):
allowed_value_type_spec = GraphQLUnionType

# Validate that the type_equivalence_hints parameter has correct types.
for key, value in type_equivalence_hints.iteritems():
for key, value in six.iteritems(type_equivalence_hints):
if (not isinstance(key, allowed_key_type_spec) or
not isinstance(value, allowed_value_type_spec)):
msg = (u'Invalid type equivalence hints received! Hint {} ({}) -> {} ({}) '
Expand All @@ -43,7 +44,7 @@ def lower_coerce_type_block_type_data(ir_blocks, type_equivalence_hints):
# a dict of type name -> set of names of equivalent types, which can be used more readily.
equivalent_type_names = {
key.name: {x.name for x in value.types}
for key, value in type_equivalence_hints.iteritems()
for key, value in six.iteritems(type_equivalence_hints)
}

new_ir_blocks = []
Expand Down
4 changes: 3 additions & 1 deletion graphql_compiler/compiler/ir_lowering_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
us to convert this Expression into other Expressions, using data already present in the IR,
to simplify the final code generation step.
"""
import six

from .blocks import Backtrack, MarkLocation, QueryRoot, Traverse
from .expressions import (BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral,
Literal, TernaryConditional, TrueLiteral)
Expand Down Expand Up @@ -261,7 +263,7 @@ def _flatten_location_translations(location_translations):
location_translations: dict of Location -> Location, where the key translates to the value.
Mutated in place for efficiency and simplicity of implementation.
"""
sources_to_process = set(location_translations.iterkeys())
sources_to_process = set(six.iterkeys(location_translations))

def _update_translation(source):
"""Return the proper (fully-flattened) translation for the given location."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
For details, see:
https://github.com/orientechnologies/orientdb/issues/7225
"""
import funcy
import funcy.py2 as funcy

from ..blocks import Filter, Recurse, Traverse
from ..expressions import BinaryComposition, Literal, LocalField
Expand Down
Loading

0 comments on commit c11223b

Please sign in to comment.