Skip to content

Commit

Permalink
Merge ed30d68 into 785e007
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Apr 9, 2020
2 parents 785e007 + ed30d68 commit 52e750d
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 112 deletions.
113 changes: 34 additions & 79 deletions geoalchemy2/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
from .exc import ArgumentError


class _SpatialElement(functions.Function):
if PY3:
BinasciiError = binascii.Error
else:
BinasciiError = TypeError


class HasFunction(object):
pass


class _SpatialElement(HasFunction):
"""
The base class for :class:`geoalchemy2.elements.WKTElement` and
:class:`geoalchemy2.elements.WKBElement`.
Expand All @@ -40,11 +50,6 @@ def __init__(self, data, srid=-1, extended=False):
self.srid = srid
self.data = data
self.extended = extended
if self.extended:
args = [self.geom_from_extended_version, self.data]
else:
args = [self.geom_from, self.data, self.srid]
functions.Function.__init__(self, *args)

def __str__(self):
return self.desc
Expand Down Expand Up @@ -72,7 +77,6 @@ def __getattr__(self, name):
# We create our own _FunctionGenerator here, and use it in place of
# SQLAlchemy's "func" object. This is to be able to "bind" the
# function to the SQL expression. See also GenericFunction above.

func_ = functions._FunctionGenerator(expr=self)
return getattr(func_, name)

Expand All @@ -81,39 +85,19 @@ def __getstate__(self):
'srid': self.srid,
'data': str(self),
'extended': self.extended,
'name': self.name,
}
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.srid = state['srid']
self.extended = state['extended']
self.data = self._data_from_desc(state['data'])
args = [self.name, self.data]
if not self.extended:
args.append(self.srid)
# we need to call Function.__init__ to properly initialize SQLAlchemy's
# internal states
functions.Function.__init__(self, *args)

@staticmethod
def _data_from_desc(desc):
raise NotImplementedError()


# Default handlers are required for SQLAlchemy < 1.1
# See more details in https://github.com/geoalchemy/geoalchemy2/issues/213
@compiles(_SpatialElement)
def compile_spatialelement_default(element, compiler, **kw):
return "{}({})".format(element.name,
compiler.process(element.clauses, **kw))


@compiles(_SpatialElement, 'sqlite')
def compile_spatialelement_sqlite(element, compiler, **kw):
return "{}({})".format(element.name.lstrip("ST_"),
compiler.process(element.clauses, **kw))


class WKTElement(_SpatialElement):
"""
Instances of this class wrap a WKT or EWKT value.
Expand Down Expand Up @@ -220,77 +204,48 @@ def _data_from_desc(desc):
return binascii.unhexlify(desc)


class RasterElement(FunctionElement):
class RasterElement(_SpatialElement):
"""
Instances of this class wrap a ``raster`` value. Raster values read
from the database are converted to instances of this type. In
most cases you won't need to create ``RasterElement`` instances
yourself.
"""

name = 'raster'
geom_from_extended_version = 'raster'

def __init__(self, data):
self.data = data
FunctionElement.__init__(self, self.data)

def __str__(self):
return self.desc # pragma: no cover

def __repr__(self):
return "<%s at 0x%x; %r>" % \
(self.__class__.__name__, id(self), self.desc) # pragma: no cover
# read srid from the WKB (binary or hexadecimal format)
# The WKB structure is documented in the file
# raster/doc/RFC2-WellKnownBinaryFormat of the PostGIS sources.
try:
bin_data = binascii.unhexlify(data[:114])
except BinasciiError:
bin_data = data
data = str(binascii.hexlify(data).decode(encoding='utf-8'))
byte_order = bin_data[0]
srid = bin_data[53:57]
if not PY3:
byte_order = bytearray(byte_order)[0]
srid = struct.unpack('<I' if byte_order else '>I', srid)[0]
_SpatialElement.__init__(self, data, srid, True)

@property
def desc(self):
"""
This element's description string.
"""
desc = binascii.hexlify(self.data)
if PY3:
# hexlify returns a bytes object on py3
desc = str(desc, encoding="utf-8")

if len(desc) < 30:
return desc

return desc[:30] + '...' # pragma: no cover

def __getattr__(self, name):
#
# This is how things like ocean.rast.ST_Value(...) creates
# SQL expressions of this form:
#
# ST_Value(:ST_GeomFromWKB_1), :param_1)
#

# We create our own _FunctionGenerator here, and use it in place of
# SQLAlchemy's "func" object. This is to be able to "bind" the
# function to the SQL expression. See also GenericFunction.

func_ = functions._FunctionGenerator(expr=self)
return getattr(func_, name)

return self.data

@compiles(RasterElement)
def compile_RasterElement(element, compiler, **kw):
"""
This function makes sure the :class:`geoalchemy2.elements.RasterElement`
contents are correctly casted to the ``raster`` type before using it.
The other elements in this module don't need such a function because
they are derived from :class:`functions.Function`. For the
:class:`geoalchemy2.elements.RasterElement` class however it would not be
of any use to have it compile to ``raster('...')`` so it is compiled to
``'...'::raster`` by this function.
"""
return "%s::raster" % compiler.process(element.clauses)
@staticmethod
def _data_from_desc(desc):
return desc


class CompositeElement(FunctionElement):
"""
Instances of this class wrap a Postgres composite type.
"""

def __init__(self, base, field, type_):
self.name = field
self.type = to_instance(type_)
Expand Down
13 changes: 12 additions & 1 deletion geoalchemy2/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from sqlalchemy.ext.compiler import compiles

from . import types
from . import elements


class GenericFunction(functions.GenericFunction):
Expand Down Expand Up @@ -90,8 +91,18 @@ class ST_TransScale(GenericFunction):

def __init__(self, *args, **kwargs):
expr = kwargs.pop('expr', None)
args = list(args)
if expr is not None:
args = (expr,) + args
args = [expr] + args
for idx, elem in enumerate(args):
if isinstance(elem, elements.HasFunction):
if elem.extended:
func_name = elem.geom_from_extended_version
func_args = [elem.data]
else:
func_name = elem.geom_from
func_args = [elem.data, elem.srid]
args[idx] = getattr(functions.func, func_name)(*func_args)
functions.GenericFunction.__init__(self, *args, **kwargs)


Expand Down
46 changes: 34 additions & 12 deletions geoalchemy2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,16 @@ class _GISType(UserDefinedType):
geometry/geography columns. """

def __init__(self, geometry_type='GEOMETRY', srid=-1, dimension=2,
spatial_index=True, management=False, use_typmod=None):
spatial_index=True, management=False, use_typmod=None,
from_text=None, name=None):
geometry_type, srid = self.check_ctor_args(
geometry_type, srid, dimension, management, use_typmod)
self.geometry_type = geometry_type
self.srid = srid
if name is not None:
self.name = name
if from_text is not None:
self.from_text = from_text
self.dimension = dimension
self.spatial_index = spatial_index
self.management = management
Expand All @@ -151,7 +156,12 @@ def column_expression(self, col):
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return WKBElement(value, srid=self.srid, extended=self.extended)
kwargs = {}
if self.srid > 0:
kwargs['srid'] = self.srid
if self.extended is not None:
kwargs['extended'] = self.extended
return self.ElementType(value, **kwargs)
return process

def bind_expression(self, bindvalue):
Expand Down Expand Up @@ -245,6 +255,8 @@ class Geometry(_GISType):
""" The "as binary" function to use. Used by the parent class'
``column_expression`` method. """

ElementType = WKBElement


class Geography(_GISType):
"""
Expand All @@ -270,8 +282,10 @@ class Geography(_GISType):
""" The "as binary" function to use. Used by the parent class'
``column_expression`` method. """

ElementType = WKBElement


class Raster(UserDefinedType):
class Raster(_GISType):
"""
The Raster column type.
Expand All @@ -297,17 +311,25 @@ class Raster(UserDefinedType):
defined for raster columns.
"""

def __init__(self, spatial_index=True):
self.spatial_index = spatial_index
from_text = 'raster'
""" The "from text" raster constructor. Used by the parent class'
``bind_expression`` method. """

def get_col_spec(self):
return 'raster'
as_binary = 'raster'
""" The "as binary" function to use. Used by the parent class'
``column_expression`` method. """

def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return RasterElement(value)
return process
name = 'raster'
""" Type name used for defining raster columns in ``CREATE TABLE``. """

ElementType = RasterElement

def __init__(self, *args, **kwargs):
# Enforce default values
kwargs['geometry_type'] = None
kwargs['srid'] = -1
super(Raster, self).__init__(*args, **kwargs)
self.extended = None


class CompositeType(UserDefinedType):
Expand Down
22 changes: 22 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest


def skip_postgis1(postgis_version):
return pytest.mark.skipif(
postgis_version.startswith('1.'),
reason="requires PostGIS != 1",
)


def skip_postgis2(postgis_version):
return pytest.mark.skipif(
postgis_version.startswith('2.'),
reason="requires PostGIS != 2",
)


def skip_postgis3(postgis_version):
return pytest.mark.skipif(
postgis_version.startswith('3.'),
reason="requires PostGIS != 3",
)
46 changes: 38 additions & 8 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_pickle_unpickle(self):
assert unpickled.srid == 3
assert unpickled.extended is True
assert unpickled.data == 'POINT(1 2)'
assert unpickled.name == 'ST_GeomFromEWKT'
f = unpickled.ST_Buffer(2)
eq_sql(f, 'ST_Buffer('
'ST_GeomFromEWKT(:ST_GeomFromEWKT_1), '
Expand Down Expand Up @@ -92,7 +91,6 @@ def test_pickle_unpickle(self):
assert unpickled.srid == self._srid
assert unpickled.extended is True
assert unpickled.data == self._ewkt
assert unpickled.name == 'ST_GeomFromEWKT'
f = unpickled.ST_Buffer(2)
eq_sql(f, 'ST_Buffer('
'ST_GeomFromEWKT(:ST_GeomFromEWKT_1), '
Expand Down Expand Up @@ -215,7 +213,6 @@ def test_pickle_unpickle(self):
assert unpickled.srid == self._srid
assert unpickled.extended is True
assert unpickled.data == bytes_(self._bin)
assert unpickled.name == 'ST_GeomFromEWKB'
f = unpickled.ST_Buffer(2)
eq_sql(f, 'ST_Buffer('
'ST_GeomFromEWKB(:ST_GeomFromEWKB_1), '
Expand Down Expand Up @@ -273,15 +270,48 @@ def test_function_str(self):

class TestRasterElement():

rast_data = (
b'\x01\x00\x00\x01\x00\x9a\x99\x99\x99\x99\x99\xc9?\x9a\x99\x99\x99\x99\x99'
b'\xc9\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe6\x10\x00'
b'\x00\x05\x00\x05\x00D\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01'
b'\x01\x00\x00\x01\x01\x00\x00\x00\x01\x00\x00\x00\x00')

hex_rast_data = (
'01000001009a9999999999c93f9a9999999999c9bf0000000000000000000000000000f03'
'f00000000000000000000000000000000e610000005000500440001010101010101010100'
'010101000001010000000100000000')

def test_desc(self):
e = RasterElement(b'\x01\x02')
assert e.desc == '0102'
e = RasterElement(self.rast_data)
assert e.desc == self.hex_rast_data
assert e.srid == 4326
e = RasterElement(self.hex_rast_data)
assert e.desc == self.hex_rast_data
assert e.srid == 4326

def test_function_call(self):
e = RasterElement(b'\x01\x02')
e = RasterElement(self.rast_data)
f = e.ST_Height()
eq_sql(f, 'ST_Height(:raster_1::raster)')
assert f.compile().params == {u'raster_1': b'\x01\x02'}
eq_sql(f, 'ST_Height(raster(:raster_1))')
assert f.compile().params == {u'raster_1': self.hex_rast_data}

def test_pickle_unpickle(self):
import pickle
e = RasterElement(self.rast_data)
assert e.srid == 4326
assert e.extended is True
assert e.data == self.hex_rast_data
pickled = pickle.dumps(e)
unpickled = pickle.loads(pickled)
assert unpickled.srid == 4326
assert unpickled.extended is True
assert unpickled.data == self.hex_rast_data
f = unpickled.ST_Height()
eq_sql(f, 'ST_Height(raster(:raster_1))')
assert f.compile().params == {
u'raster_1': self.hex_rast_data,
}


class TestCompositeElement():
Expand Down
Loading

0 comments on commit 52e750d

Please sign in to comment.