@@ -0,0 +1,120 @@
import six

from multipledispatch import Dispatcher

import ibis.expr.datatypes as dt


class TypeTranslationContext(object):
    """A tag class to allow alteration of the way a particular type is
    translated.

    Notes
    -----
    This is used to raise an exception when INT64 types are encountered to
    avoid surprising results due to BigQuery's handling of INT64 types in
    JavaScript UDFs.
    """
    __slots__ = ()


class UDFContext(TypeTranslationContext):
    __slots__ = ()


ibis_type_to_bigquery_type = Dispatcher('ibis_type_to_bigquery_type')


@ibis_type_to_bigquery_type.register(six.string_types)
def trans_string_default(datatype):
    return ibis_type_to_bigquery_type(dt.dtype(datatype))


@ibis_type_to_bigquery_type.register(dt.DataType)
def trans_default(t):
    return ibis_type_to_bigquery_type(t, TypeTranslationContext())


@ibis_type_to_bigquery_type.register(six.string_types, TypeTranslationContext)
def trans_string_context(datatype, context):
    return ibis_type_to_bigquery_type(dt.dtype(datatype), context)


@ibis_type_to_bigquery_type.register(dt.Floating, TypeTranslationContext)
def trans_float64(t, context):
    return 'FLOAT64'


@ibis_type_to_bigquery_type.register(dt.Integer, TypeTranslationContext)
def trans_integer(t, context):
    return 'INT64'


@ibis_type_to_bigquery_type.register(
    dt.UInt64, (TypeTranslationContext, UDFContext)
)
def trans_lossy_integer(t, context):
    raise TypeError(
        'Conversion from uint64 to BigQuery integer type (int64) is lossy'
    )


@ibis_type_to_bigquery_type.register(dt.Array, TypeTranslationContext)
def trans_array(t, context):
    return 'ARRAY<{}>'.format(
        ibis_type_to_bigquery_type(t.value_type, context))


@ibis_type_to_bigquery_type.register(dt.Struct, TypeTranslationContext)
def trans_struct(t, context):
    return 'STRUCT<{}>'.format(
        ', '.join(
            '{} {}'.format(
                name,
                ibis_type_to_bigquery_type(dt.dtype(type), context)
            ) for name, type in zip(t.names, t.types)
        )
    )


@ibis_type_to_bigquery_type.register(dt.Date, TypeTranslationContext)
def trans_date(t, context):
    return 'DATE'


@ibis_type_to_bigquery_type.register(dt.Timestamp, TypeTranslationContext)
def trans_timestamp(t, context):
    if t.timezone is not None:
        raise TypeError('BigQuery does not support timestamps with timezones')
    return 'TIMESTAMP'


@ibis_type_to_bigquery_type.register(dt.DataType, TypeTranslationContext)
def trans_type(t, context):
    return str(t).upper()


@ibis_type_to_bigquery_type.register(dt.Integer, UDFContext)
def trans_integer_udf(t, context):
    # JavaScript does not have integers, only a Number class. BigQuery
    # doesn't behave as expected with INT64 inputs or outputs.
    raise TypeError(
        'BigQuery does not support INT64 as an argument type or a return type '
        'for UDFs. Replace INT64 with FLOAT64 in your UDF signature and '
        'cast all INT64 inputs to FLOAT64.'
    )


@ibis_type_to_bigquery_type.register(dt.Decimal, TypeTranslationContext)
def trans_numeric(t, context):
    if (t.precision, t.scale) != (38, 9):
        raise TypeError(
            'BigQuery only supports decimal types with precision of 38 and '
            'scale of 9'
        )
    return 'NUMERIC'


@ibis_type_to_bigquery_type.register(dt.Decimal, UDFContext)
def trans_numeric_udf(t, context):
    raise TypeError('Decimal types are not supported in BigQuery UDFs')
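
A minimal usage sketch of the dispatcher above, assuming the module lands as ibis.bigquery.datatypes: strings and DataType instances are both accepted, a default TypeTranslationContext is supplied automatically, and a UDFContext turns INT64 translation into an error.

import ibis.expr.datatypes as dt
from ibis.bigquery.datatypes import ibis_type_to_bigquery_type, UDFContext

# Default context: strings are parsed with dt.dtype and forwarded.
assert ibis_type_to_bigquery_type(dt.int64) == 'INT64'
assert ibis_type_to_bigquery_type('array<string>') == 'ARRAY<STRING>'

# UDF context: INT64 is rejected because JavaScript has no 64-bit integers.
try:
    ibis_type_to_bigquery_type(dt.int64, UDFContext())
except TypeError as e:
    print(e)  # BigQuery does not support INT64 as an argument type ...
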
@@ -0,0 +1,84 @@
import pytest

from multipledispatch.conflict import ambiguities

import ibis.expr.datatypes as dt
from ibis.bigquery.datatypes import (
    ibis_type_to_bigquery_type, UDFContext, TypeTranslationContext
)


def test_no_ambiguities():
    ambs = ambiguities(ibis_type_to_bigquery_type.funcs)
    assert not ambs


@pytest.mark.parametrize(
    ('datatype', 'expected'),
    [
        (dt.float32, 'FLOAT64'),
        (dt.float64, 'FLOAT64'),
        (dt.uint8, 'INT64'),
        (dt.uint16, 'INT64'),
        (dt.uint32, 'INT64'),
        (dt.int8, 'INT64'),
        (dt.int16, 'INT64'),
        (dt.int32, 'INT64'),
        (dt.int64, 'INT64'),
        (dt.string, 'STRING'),
        (dt.Array(dt.int64), 'ARRAY<INT64>'),
        (dt.Array(dt.string), 'ARRAY<STRING>'),
        (
            dt.Struct.from_tuples([
                ('a', dt.int64),
                ('b', dt.string),
                ('c', dt.Array(dt.string)),
            ]),
            'STRUCT<a INT64, b STRING, c ARRAY<STRING>>'
        ),
        (dt.date, 'DATE'),
        (dt.timestamp, 'TIMESTAMP'),
        pytest.mark.xfail(
            (dt.timestamp(timezone='US/Eastern'), 'TIMESTAMP'),
            raises=TypeError,
            reason='Not supported in BigQuery'
        ),
        ('array<struct<a: string>>', 'ARRAY<STRUCT<a STRING>>'),
        pytest.mark.xfail(
            (dt.Decimal(38, 9), 'NUMERIC'),
            raises=TypeError,
            reason='Not supported in BigQuery'
        ),
    ]
)
def test_simple(datatype, expected):
    context = TypeTranslationContext()
    assert ibis_type_to_bigquery_type(datatype, context) == expected


@pytest.mark.parametrize('datatype', [dt.uint64, dt.Decimal(8, 3)])
def test_simple_failure_mode(datatype):
    with pytest.raises(TypeError):
        ibis_type_to_bigquery_type(datatype)


@pytest.mark.parametrize(
    ('type', 'expected'),
    [
        pytest.mark.xfail((dt.int64, 'INT64'), raises=TypeError),
        pytest.mark.xfail(
            (dt.Array(dt.int64), 'ARRAY<INT64>'),
            raises=TypeError
        ),
        pytest.mark.xfail(
            (
                dt.Struct.from_tuples([('a', dt.Array(dt.int64))]),
                'STRUCT<a ARRAY<INT64>>'
            ),
            raises=TypeError,
        )
    ]
)
def test_ibis_type_to_bigquery_type_udf(type, expected):
    context = UDFContext()
    assert ibis_type_to_bigquery_type(type, context) == expected
@@ -0,0 +1 @@
from ibis.bigquery.udf.api import udf  # noqa: F401
@@ -0,0 +1,236 @@
import collections
import inspect
import itertools

import ibis.expr.rules as rlz
import ibis.expr.datatypes as dt

from ibis.compat import functools
from ibis.expr.signature import Argument as Arg

from ibis.bigquery.compiler import BigQueryUDFNode, compiles

from ibis.bigquery.udf.core import PythonToJavaScriptTranslator
from ibis.bigquery.datatypes import ibis_type_to_bigquery_type, UDFContext


__all__ = 'udf',


_udf_name_cache = collections.defaultdict(itertools.count)


def create_udf_node(name, fields):
    """Create a new UDF node type.

    Parameters
    ----------
    name : str
        The name of the UDF node
    fields : OrderedDict
        Mapping of class member name to definition

    Returns
    -------
    result : type
        A new BigQueryUDFNode subclass
    """
    definition = next(_udf_name_cache[name])
    external_name = '{}_{:d}'.format(name, definition)
    return type(external_name, (BigQueryUDFNode,), fields)
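
Because _udf_name_cache keys a fresh itertools.count per function name, redefining a UDF under the same Python name yields a new node type with a distinct external name. A sketch under that assumption (the fields mapping here is a stand-in for the full argument/output_type mapping the real caller builds):

fields = lambda: collections.OrderedDict([('__slots__', ('js',))])
assert create_udf_node('my_len', fields()).__name__ == 'my_len_0'
assert create_udf_node('my_len', fields()).__name__ == 'my_len_1'
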
def udf(input_type, output_type, strict=True, libraries=None):
    '''Define a UDF for BigQuery.

    Parameters
    ----------
    input_type : List[DataType]
    output_type : DataType
    strict : bool
        Whether or not to put a ``'use strict';`` string at the beginning of
        the UDF. Setting this to ``False`` is probably a bad idea.
    libraries : List[str]
        A list of Google Cloud Storage URIs containing JavaScript source
        code. Note that any symbols (functions, classes, variables, etc.)
        that are exposed in these JavaScript files will be visible inside
        the UDF.

    Returns
    -------
    wrapper : Callable
        The wrapped function

    Notes
    -----
    ``INT64`` is not supported as an argument type or a return type, as per
    `the BigQuery documentation
    <https://cloud.google.com/bigquery/docs/reference/standard-sql/user-defined-functions#sql-type-encodings-in-javascript>`_.

    Examples
    --------
    >>> from ibis.bigquery.api import udf
    >>> import ibis.expr.datatypes as dt
    >>> @udf(input_type=[dt.double], output_type=dt.double)
    ... def add_one(x):
    ...     return x + 1
    >>> print(add_one.js)
    CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
    RETURNS FLOAT64
    LANGUAGE js AS """
    'use strict';
    function add_one(x) {
        return (x + 1);
    }
    return add_one(x);
    """;
    >>> @udf(input_type=[dt.double, dt.double],
    ...      output_type=dt.Array(dt.double))
    ... def my_range(start, stop):
    ...     def gen(start, stop):
    ...         curr = start
    ...         while curr < stop:
    ...             yield curr
    ...             curr += 1
    ...     result = []
    ...     for value in gen(start, stop):
    ...         result.append(value)
    ...     return result
    >>> print(my_range.js)
    CREATE TEMPORARY FUNCTION my_range_0(start FLOAT64, stop FLOAT64)
    RETURNS ARRAY<FLOAT64>
    LANGUAGE js AS """
    'use strict';
    function my_range(start, stop) {
        function* gen(start, stop) {
            let curr = start;
            while ((curr < stop)) {
                yield curr;
                curr += 1;
            }
        }
        let result = [];
        for (let value of gen(start, stop)) {
            result.push(value);
        }
        return result;
    }
    return my_range(start, stop);
    """;
    >>> @udf(
    ...     input_type=[dt.double, dt.double],
    ...     output_type=dt.Struct.from_tuples([
    ...         ('width', 'double'), ('height', 'double')
    ...     ])
    ... )
    ... def my_rectangle(width, height):
    ...     class Rectangle:
    ...         def __init__(self, width, height):
    ...             self.width = width
    ...             self.height = height
    ...
    ...         @property
    ...         def area(self):
    ...             return self.width * self.height
    ...
    ...         def perimeter(self):
    ...             return 2 * (self.width + self.height)
    ...
    ...     return Rectangle(width, height)
    >>> print(my_rectangle.js)
    CREATE TEMPORARY FUNCTION my_rectangle_0(width FLOAT64, height FLOAT64)
    RETURNS STRUCT<width FLOAT64, height FLOAT64>
    LANGUAGE js AS """
    'use strict';
    function my_rectangle(width, height) {
        class Rectangle {
            constructor(width, height) {
                this.width = width;
                this.height = height;
            }
            get area() {
                return (this.width * this.height);
            }
            perimeter() {
                return (2 * (this.width + this.height));
            }
        }
        return (new Rectangle(width, height));
    }
    return my_rectangle(width, height);
    """;
    '''
    if libraries is None:
        libraries = []

    def wrapper(f):
        if not callable(f):
            raise TypeError('f must be callable, got {}'.format(f))

        signature = inspect.signature(f)
        parameter_names = signature.parameters.keys()

        udf_node_fields = collections.OrderedDict([
            (name, Arg(rlz.value(type)))
            for name, type in zip(parameter_names, input_type)
        ] + [
            (
                'output_type',
                lambda self, output_type=output_type: rlz.shape_like(
                    self.args, dtype=output_type
                )
            ),
            ('__slots__', ('js',)),
        ])

        udf_node = create_udf_node(f.__name__, udf_node_fields)

        @compiles(udf_node)
        def compiles_udf_node(t, expr):
            return '{}({})'.format(
                udf_node.__name__,
                ', '.join(map(t.translate, expr.op().args))
            )

        type_translation_context = UDFContext()
        return_type = ibis_type_to_bigquery_type(
            dt.dtype(output_type), type_translation_context)
        bigquery_signature = ', '.join(
            '{name} {type}'.format(
                name=name,
                type=ibis_type_to_bigquery_type(
                    dt.dtype(type), type_translation_context)
            ) for name, type in zip(parameter_names, input_type)
        )
        source = PythonToJavaScriptTranslator(f).compile()
        js = '''\
CREATE TEMPORARY FUNCTION {external_name}({signature})
RETURNS {return_type}
LANGUAGE js AS """
{strict}{source}
return {internal_name}({args});
"""{libraries};'''.format(
            external_name=udf_node.__name__,
            internal_name=f.__name__,
            return_type=return_type,
            source=source,
            signature=bigquery_signature,
            strict=repr('use strict') + ';\n' if strict else '',
            args=', '.join(parameter_names),
            libraries=(
                '\nOPTIONS (\n    library={}\n)'.format(
                    repr(list(libraries))
                ) if libraries else ''
            )
        )

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            node = udf_node(*args, **kwargs)
            node.js = js
            return node.to_expr()

        wrapped.__signature__ = signature
        wrapped.js = js
        return wrapped

    return wrapper
@@ -0,0 +1,70 @@
import ast

import toolz


class NameFinder:
    """Helper class to find the unique names in an AST."""

    __slots__ = ()

    def find(self, node):
        typename = type(node).__name__
        method = getattr(self, 'find_{}'.format(typename), None)
        if method is None:
            fields = getattr(node, '_fields', None)
            if fields is None:
                return
            for field in fields:
                value = getattr(node, field)
                for result in self.find(value):
                    yield result
        else:
            for result in method(node):
                yield result

    def find_Name(self, node):
        # TODO not sure if this is robust to scope changes
        yield node

    def find_list(self, node):
        return list(toolz.concat(map(self.find, node)))

    def find_Call(self, node):
        if not isinstance(node.func, ast.Name):
            fields = node._fields
        else:
            fields = [field for field in node._fields if field != 'func']
        return toolz.concat(map(
            self.find, (getattr(node, field) for field in fields)
        ))


def find_names(node):
    """Return the unique :class:`ast.Name` instances in an AST.

    Parameters
    ----------
    node : ast.AST

    Returns
    -------
    unique_names : List[ast.Name]

    Examples
    --------
    >>> import ast
    >>> node = ast.parse('a + b')
    >>> names = find_names(node)
    >>> names  # doctest: +ELLIPSIS
    [<_ast.Name object at 0x...>, <_ast.Name object at 0x...>]
    >>> names[0].id
    'a'
    >>> names[1].id
    'b'
    """
    return list(toolz.unique(
        filter(None, NameFinder().find(node)),
        key=lambda node: (node.id, type(node.ctx))
    ))
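
One behavior of find_Call worth noting: when the callee is a plain name, the func field is skipped, so callee names are never reported as data references, and duplicates collapse to one entry per (id, context type). A small sketch:

import ast
from ibis.bigquery.udf.find import find_names

call = ast.parse('f(x, y) + x', mode='eval').body
assert [name.id for name in find_names(call)] == ['x', 'y']
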
@@ -0,0 +1,58 @@
import ast


def matches(value, pattern):
    """Check whether `value` matches `pattern`.

    Parameters
    ----------
    value : ast.AST
    pattern : ast.AST

    Returns
    -------
    matched : bool
    """
    # types must match exactly
    if type(value) != type(pattern):
        return False

    # primitive value, such as None, True, False, etc.
    if not isinstance(value, ast.AST) and not isinstance(pattern, ast.AST):
        return value == pattern

    fields = [
        (field, getattr(pattern, field))
        for field in pattern._fields if hasattr(pattern, field)
    ]
    for field_name, field_value in fields:
        if not matches(getattr(value, field_name), field_value):
            return False
    return True
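
Because only the fields actually set on the pattern node are compared, a pattern built with a subset of fields acts as a wildcard on the rest. A sketch:

import ast

value = ast.parse('s.length', mode='eval').body
pattern = ast.Attribute(attr='length')  # `value` and `ctx` left unset
assert matches(value, pattern)
assert not matches(value, ast.Attribute(attr='size'))
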
class Rewriter:
    """AST pattern matching to enable rewrite rules.

    Attributes
    ----------
    funcs : List[Tuple[ast.AST, Callable[[ast.expr], ast.expr]]]
    """
    def __init__(self):
        self.funcs = []

    def register(self, pattern):
        def wrapper(f):
            self.funcs.append((pattern, f))
            return f
        return wrapper

    def __call__(self, node):
        # TODO: more efficient way of doing this?
        for pattern, func in self.funcs:
            if matches(node, pattern):
                return func(node)
        return node


rewrite = Rewriter()
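
A rule is a pattern plus a callback: the first registered pattern that matches wins, and unmatched nodes pass through unchanged. A hypothetical rule in the spirit of the Python-to-JavaScript translation (the translator's real rules live in its own module):

import ast


@rewrite.register(ast.Call(func=ast.Name(id='len', ctx=ast.Load())))
def rewrite_len(node):
    # Rewrite ``len(x)`` into JavaScript-style ``x.length``.
    arg, = node.args
    return ast.Attribute(value=arg, attr='length', ctx=ast.Load())


node = ast.parse('len(s)', mode='eval').body
assert isinstance(rewrite(node), ast.Attribute)
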
@@ -0,0 +1,83 @@
import ast

from ibis.bigquery.udf.find import find_names
from ibis.util import is_iterable


def parse_expr(expr):
    body = parse_stmt(expr)
    return body.value


def parse_stmt(stmt):
    body, = ast.parse(stmt).body
    return body


def eq(left, right):
    if type(left) != type(right):
        return False

    if is_iterable(left) and is_iterable(right):
        return all(map(eq, left, right))

    if not isinstance(left, ast.AST) and not isinstance(right, ast.AST):
        return left == right

    assert hasattr(left, '_fields') and hasattr(right, '_fields')
    return left._fields == right._fields and all(
        eq(getattr(left, left_name), getattr(right, right_name))
        for left_name, right_name in zip(left._fields, right._fields)
    )


def var(id):
    return ast.Name(id=id, ctx=ast.Load())


def store(id):
    return ast.Name(id=id, ctx=ast.Store())


def test_find_BinOp():
    expr = parse_expr('a + 1')
    found = find_names(expr)
    assert len(found) == 1
    assert eq(found[0], var('a'))


def test_find_dup_names():
    expr = parse_expr('a + 1 * a')
    found = find_names(expr)
    assert len(found) == 1
    assert eq(found[0], var('a'))


def test_find_Name():
    expr = parse_expr('b')
    found = find_names(expr)
    assert len(found) == 1
    assert eq(found[0], var('b'))


def test_find_Tuple():
    expr = parse_expr('(a, (b, 1), (((c,),),))')
    found = find_names(expr)
    assert len(found) == 3
    assert eq(found, [var('a'), var('b'), var('c')])


def test_find_Compare():
    expr = parse_expr('a < b < c == e + (f, (gh,))')
    found = find_names(expr)
    assert len(found) == 6
    assert eq(
        found,
        [var('a'), var('b'), var('c'), var('e'), var('f'), var('gh')]
    )


def test_find_ListComp():
    expr = parse_expr('[i for i in range(n) if i < 2]')
    found = find_names(expr)
    assert all(isinstance(f, ast.Name) for f in found)
    assert eq(found, [var('i'), store('i'), var('n')])
@@ -0,0 +1,254 @@
import os

import pytest

import pandas as pd
import pandas.util.testing as tm

import ibis
import ibis.expr.datatypes as dt

pytest.importorskip('google.cloud.bigquery')

pytestmark = pytest.mark.bigquery

from ibis.bigquery.api import udf  # noqa: E402

PROJECT_ID = os.environ.get('GOOGLE_BIGQUERY_PROJECT_ID', 'ibis-gbq')
DATASET_ID = 'testing'


@pytest.fixture(scope='module')
def client():
    ga = pytest.importorskip('google.auth')

    try:
        return ibis.bigquery.connect(PROJECT_ID, DATASET_ID)
    except ga.exceptions.DefaultCredentialsError:
        pytest.skip("no credentials found, skipping")


@pytest.fixture(scope='module')
def alltypes(client):
    t = client.table('functional_alltypes')
    expr = t[t.bigint_col.isin([10, 20])].limit(10)
    return expr


@pytest.fixture(scope='module')
def df(alltypes):
    return alltypes.execute()


def test_udf(client, alltypes, df):
    @udf(input_type=[dt.double, dt.double], output_type=dt.double)
    def my_add(a, b):
        return a + b

    expr = my_add(alltypes.double_col, alltypes.double_col)
    result = expr.execute()
    assert not result.empty

    expected = (df.double_col + df.double_col).rename('tmp')
    tm.assert_series_equal(
        result.value_counts().sort_index(),
        expected.value_counts().sort_index()
    )


def test_udf_with_struct(client, alltypes, df):
    @udf(
        input_type=[dt.double, dt.double],
        output_type=dt.Struct.from_tuples([
            ('width', dt.double),
            ('height', dt.double)
        ])
    )
    def my_struct_thing(a, b):
        class Rectangle:
            def __init__(self, width, height):
                self.width = width
                self.height = height
        return Rectangle(a, b)

    assert my_struct_thing.js == '''\
CREATE TEMPORARY FUNCTION my_struct_thing_0(a FLOAT64, b FLOAT64)
RETURNS STRUCT<width FLOAT64, height FLOAT64>
LANGUAGE js AS """
'use strict';
function my_struct_thing(a, b) {
    class Rectangle {
        constructor(width, height) {
            this.width = width;
            this.height = height;
        }
    }
    return (new Rectangle(a, b));
}
return my_struct_thing(a, b);
""";'''

    expr = my_struct_thing(alltypes.double_col, alltypes.double_col)
    result = expr.execute()
    assert not result.empty

    expected = pd.Series(
        [{'width': c, 'height': c} for c in df.double_col],
        name='tmp'
    )
    tm.assert_series_equal(result, expected)


def test_udf_compose(client, alltypes, df):
    @udf([dt.double], dt.double)
    def add_one(x):
        return x + 1.0

    @udf([dt.double], dt.double)
    def times_two(x):
        return x * 2.0

    t = alltypes
    expr = times_two(add_one(t.double_col))
    result = expr.execute()
    expected = ((df.double_col + 1.0) * 2.0).rename('tmp')
    tm.assert_series_equal(result, expected)


def test_udf_scalar(client):
    @udf([dt.double, dt.double], dt.double)
    def my_add(x, y):
        return x + y

    expr = my_add(1, 2)
    result = client.execute(expr)
    assert result == 3


def test_multiple_calls_has_one_definition(client):
    @udf([dt.string], dt.double)
    def my_str_len(s):
        return s.length

    s = ibis.literal('abcd')
    expr = my_str_len(s) + my_str_len(s)
    sql = client.compile(expr)
    expected = '''\
CREATE TEMPORARY FUNCTION my_str_len_0(s STRING)
RETURNS FLOAT64
LANGUAGE js AS """
'use strict';
function my_str_len(s) {
    return s.length;
}
return my_str_len(s);
""";
SELECT my_str_len_0('abcd') + my_str_len_0('abcd') AS `tmp`'''
    assert sql == expected
    result = client.execute(expr)
    assert result == 8.0


def test_udf_libraries(client):
    @udf(
        [dt.Array(dt.string)],
        dt.double,
        # whatever symbols are exported in the library are visible inside
        # the UDF; in this case lodash defines _ and we use that here
        libraries=['gs://ibis-testing-libraries/lodash.min.js']
    )
    def string_length(strings):
        return _.sum(_.map(strings, lambda x: x.length))  # noqa: F821

    raw_data = ['aaa', 'bb', 'c']
    data = ibis.literal(raw_data)
    expr = string_length(data)
    result = client.execute(expr)
    expected = sum(map(len, raw_data))
    assert result == expected


def test_udf_with_len(client):
    @udf([dt.string], dt.double)
    def my_str_len(x):
        return len(x)

    @udf([dt.Array(dt.string)], dt.double)
    def my_array_len(x):
        return len(x)

    assert client.execute(my_str_len('aaa')) == 3
    assert client.execute(my_array_len(['aaa', 'bb'])) == 2


def test_multiple_calls_redefinition(client):
    @udf([dt.string], dt.double)
    def my_len(s):
        return s.length

    s = ibis.literal('abcd')
    expr = my_len(s) + my_len(s)

    @udf([dt.string], dt.double)
    def my_len(s):
        return s.length + 1

    expr = expr + my_len(s)

    sql = client.compile(expr)
    expected = '''\
CREATE TEMPORARY FUNCTION my_len_0(s STRING)
RETURNS FLOAT64
LANGUAGE js AS """
'use strict';
function my_len(s) {
    return s.length;
}
return my_len(s);
""";
CREATE TEMPORARY FUNCTION my_len_1(s STRING)
RETURNS FLOAT64
LANGUAGE js AS """
'use strict';
function my_len(s) {
    return (s.length + 1);
}
return my_len(s);
""";
SELECT (my_len_0('abcd') + my_len_0('abcd')) + my_len_1('abcd') AS `tmp`'''
    assert sql == expected


@pytest.mark.parametrize(
    ('argument_type', 'return_type'),
    [
        # invalid argument type, valid return type
        pytest.mark.xfail((dt.int64, dt.float64), raises=TypeError),
        # valid argument type, invalid return type
        pytest.mark.xfail((dt.float64, dt.int64), raises=TypeError),
        # complex argument type, valid return type
        pytest.mark.xfail((dt.Array(dt.int64), dt.float64), raises=TypeError),
        # valid argument type, complex invalid return type
        pytest.mark.xfail(
            (dt.float64, dt.Array(dt.int64)), raises=TypeError),
        # both invalid
        pytest.mark.xfail(
            (dt.Array(dt.Array(dt.int64)), dt.int64), raises=TypeError),
        # struct type with nested integer, valid return type
        pytest.mark.xfail(
            (dt.Struct.from_tuples([('x', dt.Array(dt.int64))]), dt.float64),
            raises=TypeError,
        )
    ]
)
def test_udf_int64(client, argument_type, return_type):
    @udf([argument_type], return_type)
    def my_int64_add(x):
        return 1.0
@@ -0,0 +1,205 @@
import six
import itertools

import ibis.util as util
import ibis.expr.rules as rlz

from ibis.compat import PY2
from collections import OrderedDict

try:
    from cytoolz import unique
except ImportError:
    from toolz import unique


_undefined = object()  # marker for missing argument


class Argument(object):
    """Argument definition."""

    if PY2:
        # required to maintain definition order in Annotated metaclass
        _counter = itertools.count()
        __slots__ = '_serial', 'validator', 'default'
    else:
        __slots__ = 'validator', 'default'

    def __init__(self, validator, default=_undefined):
        """Argument constructor

        Parameters
        ----------
        validator : Union[Callable[[arg], coerced], Type, Tuple[Type]]
            Function which handles validation and/or coercion of the given
            argument.
        default : Union[Any, Callable[[], str]]
            Used when a missing (None) value is passed for validation. Note
            that the default value (except for None) must also pass the
            inner validator. If a callable is passed, it will be executed
            just before the inner validator, and its return value will be
            treated as the default.
        """
        if PY2:
            self._serial = next(self._counter)

        self.default = default
        if isinstance(validator, type):
            self.validator = rlz.instance_of(validator)
        elif isinstance(validator, tuple):
            assert util.all_of(validator, type)
            self.validator = rlz.instance_of(validator)
        elif callable(validator):
            self.validator = validator
        else:
            raise TypeError('Argument validator must be a callable, type or '
                            'tuple of types, given: {}'.format(validator))

    def __eq__(self, other):
        return (
            self.validator == other.validator and
            self.default == other.default
        )

    @property
    def optional(self):
        return self.default is not _undefined

    def validate(self, value=_undefined, name=None):
        """
        Parameters
        ----------
        value : Any, default undefined
            Raises TypeError if the argument is mandatory but no value has
            been given.
        name : Optional[str]
            Argument name for the error message
        """
        if self.optional:
            if value is _undefined or value is None:
                if self.default is None:
                    return None
                elif util.is_function(self.default):
                    value = self.default()
                else:
                    value = self.default
        elif value is _undefined:
            if name is not None:
                name = ' `{}`'.format(name)
            else:
                name = ''
            raise TypeError('Missing required value for argument' + name)

        return self.validator(value)

    __call__ = validate  # syntactic sugar
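
A sketch of the validator forms Argument accepts and how defaults behave (the names here are illustrative):

length = Argument(int, default=0)
assert length.validate(3) == 3   # a type becomes an rlz.instance_of validator
assert length.validate() == 0    # a missing value falls back to the default
assert length.optional

upper = Argument(lambda value: value.upper())
assert upper.validate('abc') == 'ABC'   # any callable can validate/coerce
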
class TypeSignature(OrderedDict):

    __slots__ = ()

    @classmethod
    def from_dtypes(cls, dtypes):
        return cls(('_{}'.format(i), Argument(rlz.value(dtype)))
                   for i, dtype in enumerate(dtypes))

    def validate(self, *args, **kwargs):
        result = []
        for i, (name, argument) in enumerate(self.items()):
            if i < len(args):
                if name in kwargs:
                    raise TypeError(
                        'Got multiple values for argument {}'.format(name)
                    )
                value = argument.validate(args[i], name=name)
            elif name in kwargs:
                value = argument.validate(kwargs[name], name=name)
            else:
                value = argument.validate(name=name)

            result.append((name, value))

        return result

    __call__ = validate  # syntactic sugar

    def names(self):
        return tuple(self.keys())
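
from_dtypes builds a purely positional signature with synthetic _0, _1, ... names; validate then pairs each declared name with its validated value in declaration order. A sketch:

import ibis.expr.datatypes as dt

sig = TypeSignature.from_dtypes([dt.double, dt.string])
assert sig.names() == ('_0', '_1')

validated = sig.validate(1.0, 'a')  # rlz.value coerces to ibis exprs
assert [name for name, _ in validated] == ['_0', '_1']
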
class AnnotableMeta(type):

    if PY2:
        @staticmethod
        def _precedes(arg1, arg2):
            """Comparator helper for sorting name-argument pairs"""
            return cmp(arg1[1]._serial, arg2[1]._serial)  # noqa: F821
    else:
        @classmethod
        def __prepare__(metacls, name, bases, **kwds):
            return OrderedDict()

    def __new__(meta, name, bases, dct):
        slots, signature = [], TypeSignature()

        for parent in bases:
            # inherit parent slots
            if hasattr(parent, '__slots__'):
                slots += parent.__slots__
            # inherit from parent signatures
            if hasattr(parent, 'signature'):
                signature.update(parent.signature)

        # finally apply definitions from the currently created class
        if PY2:
            # on python 2 we cannot maintain definition order
            attribs, arguments = {}, []
            for k, v in dct.items():
                if isinstance(v, Argument):
                    arguments.append((k, v))
                else:
                    attribs[k] = v

            # so we need to sort arguments based on their unique counter
            signature.update(sorted(arguments, cmp=meta._precedes))
        else:
            # thanks to __prepare__ attrs are already ordered
            attribs = {}
            for k, v in dct.items():
                if isinstance(v, Argument):
                    # so we can set directly
                    signature[k] = v
                else:
                    attribs[k] = v

        # if slots or signature are defined, no inheritance happens
        signature = attribs.get('signature', signature)
        slots = attribs.get('__slots__', tuple(slots)) + signature.names()

        attribs['signature'] = signature
        attribs['__slots__'] = tuple(unique(slots))

        return super(AnnotableMeta, meta).__new__(meta, name, bases, attribs)


@six.add_metaclass(AnnotableMeta)
class Annotable(object):

    __slots__ = ()

    def __init__(self, *args, **kwargs):
        for name, value in self.signature.validate(*args, **kwargs):
            setattr(self, name, value)
        self._validate()

    def _validate(self):
        pass

    @property
    def args(self):
        return tuple(getattr(self, name) for name in self.signature.names())

    @property
    def argnames(self):
        return self.signature.names()
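
Tying the pieces together, a sketch of declaring a node class with the metaclass machinery above (the class and arguments are illustrative): Arguments declared on the class body become an ordered, slotted signature, and __init__ validates positionally or by keyword.

class Interval(Annotable):
    start = Argument(int)
    stop = Argument(int)
    step = Argument(int, default=1)


bounds = Interval(1, 10)
assert bounds.argnames == ('start', 'stop', 'step')
assert bounds.args == (1, 10, 1)
assert Interval(1, stop=5, step=2).step == 2
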