Query on multiple Pandas DataFrame (#89)
* Add dataframe.query() to run SQL on multiple DataFrames

* Add a DataFrame join test

* Support passing a DataFrame as a table
auxten committed Aug 17, 2023
1 parent e503e5b commit c330249
Showing 7 changed files with 222 additions and 58 deletions.
8 changes: 6 additions & 2 deletions README-zh.md
@@ -77,9 +77,13 @@ chdb.query('select * from file("data.parquet", Parquet)', 'Dataframe')
```python
import chdb.dataframe as cdf
import pandas as pd
-tbl = cdf.Table(dataframe=pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']}))
-ret_tbl = tbl.query('select * from __table__')
+# Join 2 DataFrames
+df1 = pd.DataFrame({'a': [1, 2, 3], 'b': ["one", "two", "three"]})
+df2 = pd.DataFrame({'c': [1, 2, 3], 'd': ["①", "②", "③"]})
+ret_tbl = cdf.query(sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
+                    tbl1=df1, tbl2=df2)
print(ret_tbl)
+# Query on the DataFrame Table
print(ret_tbl.query('select b, sum(a) from __table__ group by b'))
```
</details>
8 changes: 6 additions & 2 deletions README.md
@@ -82,9 +82,13 @@ chdb.query('select * from file("data.parquet", Parquet)', 'Dataframe')
```python
import chdb.dataframe as cdf
import pandas as pd
-tbl = cdf.Table(dataframe=pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']}))
-ret_tbl = tbl.query('select * from __table__')
+# Join 2 DataFrames
+df1 = pd.DataFrame({'a': [1, 2, 3], 'b': ["one", "two", "three"]})
+df2 = pd.DataFrame({'c': [1, 2, 3], 'd': ["①", "②", "③"]})
+ret_tbl = cdf.query(sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
+                    tbl1=df1, tbl2=df2)
print(ret_tbl)
+# Query on the DataFrame Table
print(ret_tbl.query('select b, sum(a) from __table__ group by b'))
```
</details>
6 changes: 5 additions & 1 deletion chdb/dataframe/__init__.py
@@ -11,4 +11,8 @@
if pd.__version__[0] < '2':
    print('Please upgrade pandas to version 2.0.0 or higher to have better performance')

-from .query import *
+from .query import Table, pandas_read_parquet  # noqa: C0413
+
+query = Table.queryStatic
+
+__all__ = ['Table', 'query', 'pandas_read_parquet']
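
The `query = Table.queryStatic` alias is what makes the README examples above work. A minimal sketch of the resulting module-level API (the column name and SQL here are illustrative, not taken from the diff):

```python
import pandas as pd
import chdb.dataframe as cdf

df = pd.DataFrame({'a': [1, 2, 3]})
# cdf.query is Table.queryStatic: each __name__ placeholder in the SQL
# is resolved from the keyword argument of the same name, which may be
# a chdb Table or a plain pandas DataFrame.
print(cdf.query(sql="select sum(a) from __df__", df=df))
```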
207 changes: 155 additions & 52 deletions chdb/dataframe/query.py
@@ -1,26 +1,29 @@
import os
import tempfile
+from io import BytesIO
+import re
import pandas as pd
import pyarrow as pa
-from io import BytesIO
from chdb import query as chdb_query


-class Table(object):
+class Table:
    """
    Table is a wrapper around multiple formats of data buffer: a parquet file path,
    parquet bytes, or a pandas dataframe.
    If use_memfd is True, it will try memfd_create to create a temp file in memory, which is
    only available on Linux. If that fails, it will fall back to tempfile.mkstemp to create a temp file.
    """

-    def __init__(self,
-                 parquet_path: str = None,
-                 temp_parquet_path: str = None,
-                 parquet_memoryview: memoryview = None,
-                 dataframe: pd.DataFrame = None,
-                 arrow_table: pa.Table = None,
-                 use_memfd: bool = False):
+    def __init__(
+        self,
+        parquet_path: str = None,
+        temp_parquet_path: str = None,
+        parquet_memoryview: memoryview = None,
+        dataframe: pd.DataFrame = None,
+        arrow_table: pa.Table = None,
+        use_memfd: bool = False,
+    ):
        """
        Initialize a Table object with one of: a parquet file path, parquet bytes, a pandas dataframe, or
        an arrow table.
@@ -33,11 +36,11 @@ def __init__(self,
        self.use_memfd = use_memfd

    def __del__(self):
-        try:
-            if self._temp_parquet_path is not None:
+        if self._temp_parquet_path is not None:
+            try:
                os.remove(self._temp_parquet_path)
-        except:
-            pass
+            except OSError:
+                pass

    def to_pandas(self) -> pd.DataFrame:
        if self._dataframe is None:
@@ -63,55 +66,53 @@ def flush_to_disk(self):
            return

        if self._dataframe is not None:
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                self._dataframe.to_parquet(tmp)
-                self._temp_parquet_path = tmp.name
-            del self._dataframe
-            self._dataframe = None
+            self._df_to_disk(self._dataframe)
+            self._dataframe = None
        elif self._arrow_table is not None:
-            import pyarrow.parquet as pq
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                pq.write_table(self._arrow_table, tmp.name)
-                self._temp_parquet_path = tmp.name
-            del self._arrow_table
-            self._arrow_table = None
+            self._arrow_table_to_disk(self._arrow_table)
+            self._arrow_table = None
        elif self._parquet_memoryview is not None:
-            # copy memoryview to temp file
-            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
-                tmp.write(self._parquet_memoryview.tobytes())
-                self._temp_parquet_path = tmp.name
-            self._parquet_memoryview.release()
-            del self._parquet_memoryview
-            self._parquet_memoryview = None
+            self._memoryview_to_disk(self._parquet_memoryview)
+            self._parquet_memoryview = None
        else:
            raise ValueError("No data in Table object")

    def _df_to_disk(self, df):
        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
            df.to_parquet(tmp)
            self._temp_parquet_path = tmp.name

    def _arrow_table_to_disk(self, arrow_table):
        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
            pa.parquet.write_table(arrow_table, tmp.name)
            self._temp_parquet_path = tmp.name

    def _memoryview_to_disk(self, memoryview):
        # copy memoryview to temp file
        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
            tmp.write(memoryview.tobytes())
            self._temp_parquet_path = tmp.name

    def __repr__(self):
        return repr(self.to_pandas())

    def __str__(self):
        return str(self.to_pandas())

-    def query(self, sql, **kwargs) -> "Table":
+    def query(self, sql: str, **kwargs) -> "Table":
        """
        Query the current Table object and return a new Table object.
        The `FROM` table name in the SQL should always be `__table__`, e.g.:
        `SELECT * FROM __table__ WHERE ...`
        """
-        # check if "__table__" is in sql
-        if "__table__" not in sql:
-            raise ValueError("SQL should always contain `FROM __table__`")
+        self._validate_sql(sql)

-        if self._parquet_path is not None:  # if we have a parquet file path, running the chdb query on it directly is faster
-            # replace "__table__" with file("self._parquet_path", Parquet)
-            new_sql = sql.replace("__table__", f"file(\"{self._parquet_path}\", Parquet)")
-            res = chdb_query(new_sql, "Parquet", **kwargs)
-            return Table(parquet_memoryview=res.get_memview())
+        if (
+            self._parquet_path is not None
+        ):  # if we have a parquet file path, running the chdb query on it directly is faster
+            return self._query_on_path(self._parquet_path, sql, **kwargs)
        elif self._temp_parquet_path is not None:
-            # replace "__table__" with file("self._temp_parquet_path", Parquet)
-            new_sql = sql.replace("__table__", f"file(\"{self._temp_parquet_path}\", Parquet)")
-            res = chdb_query(new_sql, "Parquet", **kwargs)
-            return Table(parquet_memoryview=res.get_memview())
+            return self._query_on_path(self._temp_parquet_path, sql, **kwargs)
        elif self._parquet_memoryview is not None:
            return self.queryParquetBuffer(sql, **kwargs)
        elif self._dataframe is not None:
@@ -121,6 +122,15 @@ def query(self, sql, **kwargs) -> "Table":
        else:
            raise ValueError("Table object is not initialized correctly")

    def _query_on_path(self, path, sql, **kwargs):
        new_sql = sql.replace("__table__", f'file("{path}", Parquet)')
        res = chdb_query(new_sql, "Parquet", **kwargs)
        return Table(parquet_memoryview=res.get_memview())

    def _validate_sql(self, sql):
        if "__table__" not in sql:
            raise ValueError("SQL should always contain `FROM __table__`")

    def queryParquetBuffer(self, sql: str, **kwargs) -> "Table":
        if "__table__" not in sql:
            raise ValueError("SQL should always contain `FROM __table__`")
@@ -139,6 +149,8 @@ def queryParquetBuffer(self, sql: str, **kwargs) -> "Table":
        ffd.flush()
        ret = self._run_on_temp(parquet_fd, temp_path, sql=sql, fmt="Parquet", **kwargs)
        ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
        return ret

    def queryArrowTable(self, sql: str, **kwargs) -> "Table":
@@ -159,6 +171,8 @@ def queryArrowTable(self, sql: str, **kwargs) -> "Table":
        ffd.flush()
        ret = self._run_on_temp(arrow_fd, temp_path, sql=sql, fmt="Arrow", **kwargs)
        ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
        return ret

    def queryDF(self, sql: str, **kwargs) -> "Table":
@@ -174,19 +188,104 @@ def queryDF(self, sql: str, **kwargs) -> "Table":
        if parquet_fd == -1:
            parquet_fd, temp_path = tempfile.mkstemp()
        ffd = os.fdopen(parquet_fd, "wb")
-        self._dataframe.to_parquet(ffd, engine='pyarrow', compression=None)
+        self._dataframe.to_parquet(ffd, engine="pyarrow", compression=None)
        ffd.flush()
        ret = self._run_on_temp(parquet_fd, temp_path, sql=sql, fmt="Parquet", **kwargs)
        ffd.close()
+        if temp_path is not None:
+            os.remove(temp_path)
        return ret

-    def _run_on_temp(self, fd: int, temp_path: str = None, sql: str = None, fmt: str = "Parquet", **kwargs) -> "Table":
    @staticmethod
    def queryStatic(sql: str, **kwargs) -> "Table":
        """
        Query on multiple Tables, using Table variables as the table names in the SQL, e.g.:
            table1 = Table(...)
            table2 = Table(...)
            query("SELECT * FROM __table1__ JOIN __table2__ ON ...", table1=table1, table2=table2)
        """
        ansiTablePattern = re.compile(r"__([a-zA-Z][a-zA-Z0-9_]*)__")
        temp_paths = []
        ffds = []

        def replace_table_name(match):
            tableName = match.group(1)
            if tableName not in kwargs:
                raise ValueError(f"Table {tableName} should be passed as a parameter")

            tbl = kwargs[tableName]
            # if tbl is a DataFrame, convert it to a Table
            if isinstance(tbl, pd.DataFrame):
                tbl = Table(dataframe=tbl)
            elif not isinstance(tbl, Table):
                raise ValueError(f"Table {tableName} should be an instance of Table or DataFrame")

            if tbl._parquet_path is not None:
                return f'file("{tbl._parquet_path}", Parquet)'

            if tbl._temp_parquet_path is not None:
                return f'file("{tbl._temp_parquet_path}", Parquet)'

            temp_path = None
            data_fd = -1

            if tbl.use_memfd:
                data_fd = memfd_create()

            if data_fd == -1:
                data_fd, temp_path = tempfile.mkstemp()
                temp_paths.append(temp_path)

            ffd = os.fdopen(data_fd, "wb")
            ffds.append(ffd)

            if tbl._parquet_memoryview is not None:
                ffd.write(tbl._parquet_memoryview.tobytes())
                ffd.flush()
                os.lseek(data_fd, 0, os.SEEK_SET)
                return f'file("/dev/fd/{data_fd}", Parquet)'

            if tbl._dataframe is not None:
                ffd.write(tbl._dataframe.to_parquet(engine="pyarrow", compression=None))
                ffd.flush()
                os.lseek(data_fd, 0, os.SEEK_SET)
                return f'file("/dev/fd/{data_fd}", Parquet)'

            if tbl._arrow_table is not None:
                with pa.RecordBatchFileWriter(ffd, tbl._arrow_table.schema) as writer:
                    writer.write_table(tbl._arrow_table)
                ffd.flush()
                os.lseek(data_fd, 0, os.SEEK_SET)
                return f'file("/dev/fd/{data_fd}", Arrow)'

            raise ValueError(f"Table {tableName} is not initialized correctly")

        sql = ansiTablePattern.sub(replace_table_name, sql)
        res = chdb_query(sql, "Parquet")

        for fd in ffds:
            fd.close()

        for tmp_path in temp_paths:
            os.remove(tmp_path)

        return Table(parquet_memoryview=res.get_memview())

+    def _run_on_temp(
+        self,
+        fd: int,
+        temp_path: str = None,
+        sql: str = None,
+        fmt: str = "Parquet",
+        **kwargs,
+    ) -> "Table":
# replace "__table__" with file("temp_path", Parquet) or file("/dev/fd/{parquet_fd}", Parquet)
if temp_path is not None:
new_sql = sql.replace("__table__", f"file(\"{temp_path}\", {fmt})")
new_sql = sql.replace("__table__", f'file("{temp_path}", {fmt})')
else:
os.lseek(fd, 0, os.SEEK_SET)
new_sql = sql.replace("__table__", f"file(\"/dev/fd/{fd}\", {fmt})")
new_sql = sql.replace("__table__", f'file("/dev/fd/{fd}", {fmt})')
res = chdb_query(new_sql, "Parquet", **kwargs)
return Table(parquet_memoryview=res.get_memview())

@@ -212,10 +311,14 @@ def memfd_create(name: str = None) -> int:
if __name__ == "__main__":
    import argparse

-    parser = argparse.ArgumentParser(description='Run SQL on parquet file')
-    parser.add_argument('parquet_path', type=str, help='path to parquet file')
-    parser.add_argument('sql', type=str, help='SQL to run')
-    parser.add_argument('--use-memfd', action='store_true', help='use memfd_create to create file descriptor')
+    parser = argparse.ArgumentParser(description="Run SQL on parquet file")
+    parser.add_argument("parquet_path", type=str, help="path to parquet file")
+    parser.add_argument("sql", type=str, help="SQL to run")
+    parser.add_argument(
+        "--use-memfd",
+        action="store_true",
+        help="use memfd_create to create file descriptor",
+    )
    args = parser.parse_args()

    table = Table(parquet_path=args.parquet_path, use_memfd=args.use_memfd)
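
The `use_memfd` path above relies on Linux's `memfd_create` plus the `/dev/fd/{fd}` trick: the parquet bytes live in an anonymous in-memory file, and ClickHouse reads them back through the `file()` table function. A minimal, illustrative sketch of that pattern (not chdb's exact implementation; assumes Linux with glibc ≥ 2.27):

```python
import ctypes
import os
import tempfile


def memfd_create(name: str = "tbl") -> int:
    """Return a memory-backed fd on Linux, or -1 if unavailable."""
    try:
        libc = ctypes.CDLL(None, use_errno=True)
        return libc.memfd_create(name.encode(), 0)  # returns -1 on failure
    except AttributeError:  # libc predates memfd_create (glibc < 2.27)
        return -1


fd = memfd_create()
if fd == -1:
    fd, temp_path = tempfile.mkstemp()  # disk-backed fallback
# After writing parquet bytes to fd and seeking back to 0, the data is
# addressable in SQL as file("/dev/fd/{fd}", Parquet).
os.close(fd)
```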
2 changes: 1 addition & 1 deletion tests/test_gc.py
@@ -36,6 +36,6 @@ def test_gc(self):
        gc.collect()
        self.assertEqual(mv3.tobytes(), b'123,"adbcdefg"\n')
        self.assertEqual(len(mv3), 15)

if __name__ == '__main__':
    unittest.main()
47 changes: 47 additions & 0 deletions tests/test_joindf.py
@@ -0,0 +1,47 @@
#!python3

import unittest
import pandas as pd
from chdb import dataframe as cdf


class TestJoinDf(unittest.TestCase):
    def test_1df(self):
        df1 = pd.DataFrame({"a": [1, 2, 3], "b": [b"one", b"two", b"three"]})
        cdf1 = cdf.Table(dataframe=df1)
        ret1 = cdf.query(sql="select * from __tbl1__", tbl1=cdf1)
        self.assertEqual(str(ret1), str(df1))

    def test_2df(self):
        df1 = pd.DataFrame({"a": [1, 2, 3], "b": ["one", "two", "three"]})
        df2 = pd.DataFrame({"c": [1, 2, 3], "d": ["①", "②", "③"]})
        ret_tbl = cdf.query(
            sql="select * from __tbl1__ t1 join __tbl2__ t2 on t1.a = t2.c",
            tbl1=df1,
            tbl2=df2,
        )
        self.assertEqual(
            str(ret_tbl),
            str(
                pd.DataFrame(
                    {
                        "a": [1, 2, 3],
                        "b": [b"one", b"two", b"three"],
                        "c": [1, 2, 3],
                        "d": [b"\xe2\x91\xa0", b"\xe2\x91\xa1", b"\xe2\x91\xa2"],
                    }
                )
            ),
        )

        ret_tbl2 = ret_tbl.query(
            "select b, a+c s from __table__ order by s"
        )
        self.assertEqual(
            str(ret_tbl2),
            str(pd.DataFrame({"b": [b"one", b"two", b"three"], "s": [2, 4, 6]})),
        )


if __name__ == "__main__":
    unittest.main()
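
Taken together, the tests above exercise the same flow the README advertises. A short end-to-end sketch of the API this commit introduces (the file name and the `customer_id` column are illustrative, not from the diff):

```python
import pandas as pd
import chdb.dataframe as cdf

# Mix a parquet-backed Table and an in-memory DataFrame in one query.
orders = cdf.Table(parquet_path="sales.parquet")  # hypothetical file
customers = pd.DataFrame({"id": [1, 2], "name": ["Ann", "Bob"]})

ret = cdf.query(
    sql="select c.name, count(*) as n from __orders__ o "
        "join __customers__ c on o.customer_id = c.id group by c.name",
    orders=orders,
    customers=customers,
)
print(ret)
# The result is itself a Table, so it can be queried again via __table__:
print(ret.query("select max(n) from __table__"))
```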