Skip to content

Commit

Permalink
ENH: Map full range of supported precision values.
Browse files Browse the repository at this point in the history
Postgres maps float(1) through float(24) to real, and float(25) to float(53) to
double precision. Map those ranges to float32 and float64, respectively.
  • Loading branch information
Eddie Hebert authored and llllllllll committed Oct 18, 2017
1 parent e8ee377 commit 4dafafa
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
42 changes: 31 additions & 11 deletions odo/backends/sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function

import numbers
import os
import re
import subprocess
Expand Down Expand Up @@ -129,15 +130,34 @@ class MSSQLTimestamp(mssql.TIMESTAMP):
postgresql.base.DOUBLE_PRECISION
}

# Maps the value of a precision types `precision` attribute to the desired
# dtype.
# e.g. the value returned by
# `postgresql.base.DOUBLE_PRECISION(precision=53).precision`
# maps to `float64`.
precision_to_dtype = {
24: float32,
53: float64
}

def precision_to_dtype(precision):
"""
Maps a float or double precision attribute to the desired dtype.
The mappings are as follows:
[1, 24] -> float32
[25, 53] -> float64
Values outside of those ranges raise a ``ValueError``.
Parameter
---------
precision : int
A double or float precision. e.g. the value returned by
`postgresql.base.DOUBLE_PRECISION(precision=53).precision`
Returns
-------
dtype : datashape.dtype (float32|float64)
The dtype to use for columns of the specified precision.
"""
if isinstance(precision, numbers.Integral):
if 1 <= precision <= 24:
return float32
elif 25 <= precision <= 53:
return float64
raise ValueError("{} is not a supported precision".format(precision))


# interval types are special cased in discover_typeengine so remove them from
Expand Down Expand Up @@ -220,8 +240,8 @@ def discover_typeengine(typ):
'second_precision=%d, day_precision=%d' %
(typ.second_precision, typ.day_precision))
return datashape.TimeDelta(unit=units)
if type(typ) in precision_types and typ.precision in precision_to_dtype:
return precision_to_dtype[typ.precision]
if type(typ) in precision_types and typ.precision is not None:
return precision_to_dtype(typ.precision)
if typ in revtypes:
return dshape(revtypes[typ])[0]
if type(typ) in revtypes:
Expand Down
17 changes: 17 additions & 0 deletions odo/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def test_discover_oracle_intervals(freq):
(sa.types.NullType, string),
(sa.REAL, float32),
(sa.Float, float64),
(sa.Float(precision=8), float32),
(sa.Float(precision=24), float32),
(sa.Float(precision=42), float64),
(sa.Float(precision=53), float64),
),
)
Expand All @@ -315,6 +317,21 @@ def test_types(typ, dtype):
assert_dshape_equal(discover(t), expected)


@pytest.mark.parametrize(
'typ', (
sa.Float(precision=-1),
sa.Float(precision=0),
sa.Float(precision=54)
)
)
def test_unsupported_precision(typ):
t = sa.Table('t', sa.MetaData(), sa.Column('value', typ))
with pytest.raises(ValueError) as err:
discover(t)
assert str(err.value) == "{} is not a supported precision".format(
typ.precision)


def test_mssql_types():
typ = sa.dialects.mssql.BIT()
t = sa.Table('t', sa.MetaData(), sa.Column('bit', typ))
Expand Down

0 comments on commit 4dafafa

Please sign in to comment.