Skip to content

Commit

Permalink
Add GeoPandas support (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermeleobas committed Sep 14, 2023
1 parent 1816d10 commit ce7a4bf
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 16 deletions.
24 changes: 22 additions & 2 deletions heavyai/_pandas_loaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import functools
import math

import numpy as np
Expand Down Expand Up @@ -32,7 +33,9 @@
gpd = None


GEO_TYPE_NAMES = ['POINT', 'LINESTRING', 'POLYGON', 'MULTIPOLYGON']
GEO_TYPE_NAMES = ['POINT', 'MULTIPOINT',
'LINESTRING', 'MULTILINESTRING',
'POLYGON', 'MULTIPOLYGON']
GEO_TYPE_ID = [
v[1] for v in TDatumType._NAMES_TO_VALUES.items() if v[0] in GEO_TYPE_NAMES
]
Expand Down Expand Up @@ -97,8 +100,12 @@ def get_mapd_type_from_object(data):
return 'ARRAY/{}'.format(get_mapd_dtype(pd.Series(val)))
elif isinstance(val, shapely.geometry.Point):
return 'POINT'
elif isinstance(val, shapely.geometry.MultiPoint):
return 'MULTIPOINT'
elif isinstance(val, shapely.geometry.LineString):
return 'LINESTRING'
elif isinstance(val, shapely.geometry.MultiLineString):
return 'MULTILINESTRING'
elif isinstance(val, shapely.geometry.Polygon):
return 'POLYGON'
elif isinstance(val, shapely.geometry.MultiPolygon):
Expand Down Expand Up @@ -235,7 +242,7 @@ def _serialize_arrow_payload(data, table_metadata, preserve_index=True):
if isinstance(data, pd.DataFrame):

# detect if there are categorical columns in dataframe
cols = data.select_dtypes(include=['category']).columns
cols = data.select_dtypes(include=['category', 'object']).columns

# if there are categorical columns, make a copy before casting
# to avoid mutating input data
Expand All @@ -246,6 +253,19 @@ def _serialize_arrow_payload(data, table_metadata, preserve_index=True):
else:
data_ = data

# convert geo columns to WKT representation
for col in table_metadata:
if col.type in GEO_TYPE_NAMES:
try:
fn = functools.partial(shapely.to_wkt,
rounding_precision=-1,
trim=False,
output_dimension=2)
data_[col.name] = data_[col.name].apply(fn)
except TypeError:
msg = (f"Column '{col.name}' is not a geometry column. "
"Please check your input data.")
raise ValueError(msg)
data = pa.RecordBatch.from_pandas(data_, preserve_index=preserve_index)

stream = pa.BufferOutputStream()
Expand Down
4 changes: 4 additions & 0 deletions heavyai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def _parse_tdf_gpu(tdf):
'TIME': 'int_col',
'STR': 'str_col',
'POINT': 'str_col',
'MULTIPOINT': 'str_col',
'LINESTRING': 'str_col',
'MULTILINESTRING': 'str_col',
'POLYGON': 'str_col',
'MULTIPOLYGON': 'str_col',
'TINYINT': 'int_col',
Expand All @@ -202,7 +204,9 @@ def _parse_tdf_gpu(tdf):
'TIME': -9223372036854775808,
'STR': '',
'POINT': '',
'MULTIPOINT': '',
'LINESTRING': '',
'MULTILINESTRING': '',
'POLYGON': '',
'MULTIPOLYGON': '',
'TINYINT': -128,
Expand Down
17 changes: 16 additions & 1 deletion heavyai/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from ._loaders import _build_input_rows
from ._transforms import _change_dashboard_sources
from .ipc import load_buffer, shmdt
from ._pandas_loaders import build_row_desc, _serialize_arrow_payload
from ._pandas_loaders import build_row_desc, _serialize_arrow_payload, GEO_TYPE_NAMES
from . import _pandas_loaders
from ._mutators import set_tdf, get_tdf
from types import MethodType
from packaging.version import Version


class Connection(heavydb.Connection):
Expand Down Expand Up @@ -269,8 +270,18 @@ def load_table_arrow(
load_table
load_table_columnar
load_table_rowwise
Notes
-----
Use ``load_table_columnar`` to load geometry data if ``heavydb < 7.0``
"""
metadata = self.get_table_details(table_name)
for col in metadata:
if col.type in GEO_TYPE_NAMES and self.get_version() < Version('7.0'):
# prevent the server from crashing
msg = (f'Cannot use `load_table_arrow` with column of type "{col.type}". '
'Use `load_table_columnar` or `load_table_rowwise` instead.')
raise ValueError(msg)
payload = _serialize_arrow_payload(
data, metadata, preserve_index=preserve_index
)
Expand Down Expand Up @@ -638,6 +649,10 @@ def render_vega(self, vega, compression_level=1):
rendered_vega = RenderedVega(result)
return rendered_vega

def get_version(self):
    """Return the server version as a ``Version`` object.

    The version string reported by the client is cut at the first ``-``
    and only the leading part is parsed.
    """
    reported = self._client.get_version()
    release, _, _ = reported.partition("-")
    return Version(release)


class RenderedVega:
def __init__(self, render_result):
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,15 @@ def _tests_table_no_nulls(n_samples):
point_ = pd.read_csv("tests/data/points_10000.zip", header=None).values
point_ = np.squeeze(point_)

mpoint_ = pd.read_csv("tests/data/mpoint_10000.zip", header=None).values
mpoint_ = np.squeeze(mpoint_)

line_ = pd.read_csv("tests/data/lines_10000.zip", header=None).values
line_ = np.squeeze(line_)

mline_ = pd.read_csv("tests/data/mline_10000.zip", header=None).values
mline_ = np.squeeze(mline_)

mpoly_ = pd.read_csv("tests/data/mpoly_10000.zip", header=None).values
mpoly_ = np.squeeze(mpoly_)

Expand All @@ -168,7 +174,9 @@ def _tests_table_no_nulls(n_samples):
'time_': time_,
'text_': text_,
'point_': point_,
'mpoint_': mpoint_,
'line_': line_,
'mline_': mline_,
'mpoly_': mpoly_,
'poly_': poly_,
}
Expand Down
Binary file added tests/data/mline_10000.zip
Binary file not shown.
Binary file added tests/data/mpoint_10000.zip
Binary file not shown.
18 changes: 16 additions & 2 deletions tests/test_data_no_nulls_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_create_load_table_no_nulls_sql_execute(self, con, method):
"""
df_in = _tests_table_no_nulls(10000)
df_in.drop(
columns=["point_", "line_", "mpoly_", "poly_"], inplace=True
columns=["point_", "mpoint_", "line_", "mline_", "mpoly_", "poly_"], inplace=True
)
con.execute("drop table if exists test_data_no_nulls;")
con.load_table("test_data_no_nulls", df_in, method=method)
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_create_load_table_no_nulls_select_ipc(self, con, method):
# need to drop unsupported columns from df_in
df_in = _tests_table_no_nulls(10000)
df_in.drop(
columns=["point_", "line_", "mpoly_", "poly_"], inplace=True
columns=["point_", "mpoint_", "line_", "mline_", "mpoly_", "poly_"], inplace=True
)

con.execute("drop table if exists test_data_no_nulls_ipc;")
Expand Down Expand Up @@ -205,7 +205,9 @@ def test_load_table_geospatial_no_nulls(self, con, method):
time_ time,
text_ text encoding dict(32),
point_ point,
mpoint_ multipoint,
line_ linestring,
mline_ multilinestring,
mpoly_ multipolygon,
poly_ polygon
)"""
Expand Down Expand Up @@ -266,12 +268,24 @@ def test_load_table_geospatial_no_nulls(self, con, method):
[x.equals_exact(y, 0.000001) for x, y in zip(point_in, point_out)]
)

mpoint_in = [wkt.loads(x) for x in df_in["mpoint_"]]
mpoint_out = [wkt.loads(x) for x in df_out["mpoint_"]]
assert all(
[x.equals_exact(y, 0.000001) for x, y in zip(mpoint_in, mpoint_out)]
)

line_in = [wkt.loads(x) for x in df_in["line_"]]
line_out = [wkt.loads(x) for x in df_out["line_"]]
assert all(
[x.equals_exact(y, 0.000001) for x, y in zip(line_in, line_out)]
)

mline_in = [wkt.loads(x) for x in df_in["mline_"]]
mline_out = [wkt.loads(x) for x in df_out["mline_"]]
assert all(
[x.equals_exact(y, 0.000001) for x, y in zip(mline_in, mline_out)]
)

mpoly_in = [wkt.loads(x) for x in df_in["mpoly_"]]
mpoly_out = [wkt.loads(x) for x in df_out["mpoly_"]]
assert all(
Expand Down
137 changes: 128 additions & 9 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@
from pandas.api.types import is_object_dtype, is_categorical_dtype
import pandas.testing as tm
import shapely
from shapely.geometry import Point, LineString, Polygon, MultiPolygon
from shapely.geometry import (
Point,
MultiPoint,
LineString,
MultiLineString,
Polygon,
MultiPolygon
)
import textwrap
from .conftest import no_gpu
from packaging.version import Version
from .conftest import no_gpu, _tests_table_no_nulls
from .data import dashboard_metadata

heavydb_host = os.environ.get('HEAVYDB_HOST', 'localhost')
Expand All @@ -36,14 +44,15 @@

def _cursor2df(cursor):
col_types = {c.name: c.type_code for c in cursor.description}

TDatumTypeGeo = [TDatumType.POINT, TDatumType.LINESTRING,
TDatumType.POLYGON, TDatumType.MULTIPOLYGON]
for typ in ("MULTIPOINT", "MULTILINESTRING"):
if hasattr(TDatumType, typ):
TDatumTypeGeo.append(getattr(TDatumType, typ))

has_geodata = {
k: v
in [
TDatumType.POINT,
TDatumType.LINESTRING,
TDatumType.POLYGON,
TDatumType.MULTIPOLYGON,
]
k: v in TDatumTypeGeo
for k, v in col_types.items()
}
col_names = list(col_types.keys())
Expand Down Expand Up @@ -757,14 +766,87 @@ def test_load_empty_table_arrow(self, con):
'a POINT, b LINESTRING, c POLYGON, d MULTIPOLYGON',
id='geo_values',
),
pytest.param(
gpd.GeoDataFrame(
{
'a': [MultiPoint([(1, 2), (3, 4), (5, 6)]),
MultiPoint([(2, 1), (4, 3), (6, 5)])],
'b': [MultiLineString([[[0, 0], [1, 2]], [[4, 4], [5, 6]]]),
MultiLineString([[[0, 1], [1, 2]], [[2, 4], [4, 6]]])]
}
),
'a MULTIPOINT, b MULTILINESTRING',
id='geo_values_multi'
),
],
)
def test_load_table_columnar(self, con, tmp_table, df, table_fields):
    """Load ``df`` with ``load_table_columnar`` and check the round trip.

    The test is skipped when the table definition uses MULTIPOINT or
    MULTILINESTRING and ``TDatumType`` does not expose that type.
    """
    for geo_type in ("MULTIPOINT", "MULTILINESTRING"):
        if geo_type not in table_fields or hasattr(TDatumType, geo_type):
            continue
        pytest.skip(f'Missing type "{geo_type}" in pyheavydb')

    create_stmt = "create table {} ({});".format(tmp_table, table_fields)
    con.execute(create_stmt)
    con.load_table_columnar(tmp_table, df)

    cursor = con.execute('select * from {}'.format(tmp_table))
    fetched = _cursor2df(cursor)
    pd.testing.assert_frame_equal(df, fetched)

@pytest.mark.parametrize(
    'column, typ',
    [
        ('point_', 'POINT'),
        ('mpoint_', 'MULTIPOINT'),
        ('line_', 'LINESTRING'),
        ('mline_', 'MULTILINESTRING'),
        ('poly_', 'POLYGON'),
        ('mpoly_', 'MULTIPOLYGON'),
    ],
)
def test_load_table_arrow_geo(self, con, column, typ):
    """Load one geometry column through ``load_table_arrow`` and read it back.

    Skipped when ``TDatumType`` lacks ``typ`` or when the server reports a
    version lower than 7.0.
    """
    if not hasattr(TDatumType, typ):
        pytest.skip(f'Missing type "{typ}" in pyheavydb')

    if con.get_version() < Version("7.0"):
        pytest.skip(f'Requires heavydb version 7.0, got {con.get_version()}')

    con.execute("drop table if exists test_geo")
    con.execute(f"create table test_geo ({column} {typ})")

    # Fixture data holds WKT strings; turn them into shapely geometries
    # before handing the frame to the arrow loader.
    source = _tests_table_no_nulls(10000)
    sent_geoms = source[column].apply(shapely.from_wkt)
    sent = gpd.GeoDataFrame({column: sent_geoms})

    con.load_table_arrow("test_geo", sent)

    fetched = pd.read_sql("select * from test_geo;", con)
    received_geoms = fetched[column].apply(shapely.wkt.loads)
    received = gpd.GeoDataFrame({column: received_geoms})

    expected = gpd.GeoSeries(sent[column])
    actual = gpd.GeoSeries(received[column])
    assert expected.geom_almost_equals(actual, decimal=1).all()

@pytest.mark.parametrize(
    'col, defn',
    [
        ('point_', 'POINT'),
        ('mpoint_', 'MULTIPOINT'),
        ('line_', 'LINESTRING'),
        ('mline_', 'MULTILINESTRING'),
        ('poly_', 'POLYGON'),
        ('mpoly_', 'MULTIPOLYGON'),
    ],
)
def test_load_table_arrow_geo_error(self, con, col, defn):
    """Expect ``load_table_arrow`` to raise ``ValueError`` for a plain frame.

    The input is the un-converted fixture column (no shapely conversion is
    applied here), loaded into a table whose column is a geometry type.
    Skipped when the server reports a version lower than 7.0.
    """
    if con.get_version() < Version("7.0"):
        pytest.skip(f'Requires heavydb version 7.0, got {con.get_version()}')

    con.execute("drop table if exists test_geo")
    con.execute(f"create table test_geo ({col} {defn})")

    plain_frame = _tests_table_no_nulls(10000).filter([col])
    msg = (f"Column '{col}' is not a geometry column. Please check your "
           "input data.")

    with pytest.raises(ValueError, match=msg):
        con.load_table_arrow("test_geo", plain_frame)

def test_load_infer(self, con):

con.execute("drop table if exists baz;")
Expand Down Expand Up @@ -1539,3 +1621,40 @@ def test_dashboard_duplication_remap(self, con):
]['dashboard']['dataSources'].items():
for col in val['columnMetadata']:
assert col['table'] == new_dashboard_name

@pytest.mark.parametrize('func', ('ST_AsText', 'ST_AsWkt', 'ST_AsBinary', 'ST_AsWkb'))
@pytest.mark.parametrize('column, typ', [
    ('point_', 'POINT'),
    ('mpoint_', 'MULTIPOINT'),
    ('line_', 'LINESTRING'),
    ('mline_', 'MULTILINESTRING'),
    ('poly_', 'POLYGON'),
    ('mpoly_', 'MULTIPOLYGON')])
def test_AsText_AsBinary(self, con, func, column, typ):
    """Round-trip a geometry column through a server-side ``ST_As*`` call.

    The fixture column is loaded row-wise into ``test_geo``, selected back
    through ``func``, parsed as WKT or WKB depending on ``func``, and
    compared with the input using ``geom_almost_equals(decimal=1)``.

    Skipped when ``TDatumType`` lacks ``typ`` or when the server reports
    that it has no function matching ``func``. Any other database error is
    re-raised so the test fails with the real cause.
    """
    if not hasattr(TDatumType, typ):
        pytest.skip(f'Missing type "{typ}" in pyheavydb')

    con.execute("drop table if exists test_geo")

    con.execute(f"create table test_geo ({column} {typ})")
    df_in = _tests_table_no_nulls(10000).filter([column])
    con.load_table("test_geo", df_in, method='rows')

    query = f'select {func}({column}) as "{column}" from test_geo'
    try:
        df_out = pd.read_sql(query, con)
    except pd.errors.DatabaseError as err:
        err_msg = f'No match found for function signature {func}'
        # str(err) avoids assuming args[0] exists and is a string.
        if err_msg in str(err):
            pytest.skip(f'Server does not have {func}')
        # Previously an unrelated error fell through here and surfaced
        # later as a NameError on ``df_out``; propagate it instead.
        raise

    assert len(df_in[column]) == len(df_out[column])

    # format df_in/out to shapely WKB/WKT format
    series_in = gpd.GeoSeries(df_in[column].apply(shapely.wkt.loads))
    if func in ('ST_AsText', 'ST_AsWkt'):
        series_out = gpd.GeoSeries(df_out[column].apply(shapely.wkt.loads))
    else:
        series_out = gpd.GeoSeries(df_out[column].apply(shapely.wkb.loads))

    assert series_in.geom_almost_equals(series_out, decimal=1).all()

0 comments on commit ce7a4bf

Please sign in to comment.