Skip to content

Commit

Permalink
feat: frappe.db.sql results as_iterator (backport #19810) (backport
Browse files Browse the repository at this point in the history
#24346) (#24562)

* feat: `frappe.db.sql` results `as_iterator` (backport #19810) (#24346)

* feat: `frappe.db.sql` results as iterator

- Also avoid self.last_result that holds on to large result set reference.

(cherry picked from commit 588157d)

# Conflicts:
#	frappe/database/database.py

* perf: avoid duplicate copies of result set

When as_list or as_dict conversion is done, we hold on to the original result set until the
next query is performed. This can be HUGE for large queries.

(cherry picked from commit d5b2706)

* test: add perf test for references

(cherry picked from commit 03b6d8a)

* chore: conflict

* perf: Unbuffered cursors for large result sets (#24365)

If you're reading 1000s of rows from MySQL, the default behaviour is to
read all of them in memory at once.

One of the use cases for reading large rows is reporting, where a lot of
data is read and then processed in Python. The rows that have been read are however not
used again, but still consume memory until the entire function exits.

SSCursor (Server Side Cursor) allows fetching one row at a time.

Note: This is slower than fetching everything at once AND has risk of
connection loss. So, don't use this as a crutch. If possible rewrite
code so processing is done in SQL.

---------

Co-authored-by: Ankush Menat <ankush@frappe.io>
(cherry picked from commit 99a3a35)

# Conflicts:
#	frappe/database/database.py
#	frappe/database/mariadb/database.py
#	pyproject.toml

* chore: conflicts

* chore: remove test for dead functionality

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Ankush Menat <ankush@frappe.io>
  • Loading branch information
mergify[bot] and ankush committed Jan 29, 2024
1 parent 41d2fe2 commit 7f3a12b
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 32 deletions.
89 changes: 75 additions & 14 deletions frappe/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import traceback
from contextlib import contextmanager, suppress
from time import time
from typing import TYPE_CHECKING, Any, Union

from pypika.terms import Criterion, NullValue

Expand All @@ -31,11 +32,20 @@
from frappe.utils import cint, get_datetime, get_table_name, getdate, now, sbool
from frappe.utils.deprecations import deprecated, deprecation_warning

if TYPE_CHECKING:
from psycopg2 import connection as PostgresConnection
from psycopg2 import cursor as PostgresCursor
from pymysql.connections import Connection as MariadbConnection
from pymysql.cursors import Cursor as MariadbCursor


IFNULL_PATTERN = re.compile(r"ifnull\(", flags=re.IGNORECASE)
INDEX_PATTERN = re.compile(r"\s*\([^)]+\)\s*")
SINGLE_WORD_PATTERN = re.compile(r'([`"]?)(tab([A-Z]\w+))\1')
MULTI_WORD_PATTERN = re.compile(r'([`"])(tab([A-Z]\w+)( [A-Z]\w+)+)\1')

SQL_ITERATOR_BATCH_SIZE = 100


class Database:
"""
Expand Down Expand Up @@ -112,8 +122,8 @@ def setup_type_map(self):
def connect(self):
"""Connects to a database as set in `site_config.json`."""
self.cur_db_name = self.user
self._conn = self.get_connection()
self._cursor = self._conn.cursor()
self._conn: Union["MariadbConnection", "PostgresConnection"] = self.get_connection()
self._cursor: Union["MariadbCursor", "PostgresCursor"] = self._conn.cursor()
frappe.local.rollback_observers = []

try:
Expand Down Expand Up @@ -144,6 +154,9 @@ def _transform_query(self, query: Query, values: QueryValues) -> tuple:
def _transform_result(self, result: list[tuple]) -> list[tuple]:
return result

def _clean_up(self):
	"""Database-specific hook to release references to the last result set.

	No-op in the base class; backends override this (e.g. the MariaDB backend
	clears pymysql's internal result attributes) so that large result sets can
	be garbage collected as soon as the caller has consumed them.
	"""
	pass

def sql(
self,
query: Query,
Expand All @@ -159,6 +172,7 @@ def sql(
explain=False,
run=True,
pluck=False,
as_iterator=False,
):
"""Execute a SQL query and fetch all rows.
Expand All @@ -172,7 +186,12 @@ def sql(
:param as_utf8: Encode values as UTF 8.
:param auto_commit: Commit after executing the query.
:param update: Update this dict to all rows (if returned `as_dict`).
:param run: Returns query without executing it if False.
:param run: Return query without executing it if False.
:param pluck: Get the plucked field only.
:param explain: Print `EXPLAIN` in error log.
:param as_iterator: Returns iterator over results instead of fetching all results at once.
This should be used with unbuffered cursor as default cursors used by pymysql and postgres
buffer the results internally. See `Database.unbuffered_cursor`.
Examples:
# return customer names as dicts
Expand Down Expand Up @@ -267,10 +286,14 @@ def sql(
if not self._cursor.description:
return ()

self.last_result = self._transform_result(self._cursor.fetchall())
if as_iterator:
return self._return_as_iterator(pluck=pluck, as_dict=as_dict, as_list=as_list, update=update)

last_result = self._transform_result(self._cursor.fetchall())
if pluck:
return [r[0] for r in self.last_result]
last_result = [r[0] for r in last_result]
self._clean_up()
return last_result

if as_utf8:
deprecation_warning("as_utf8 parameter is deprecated and will be removed in version 15.")
Expand All @@ -279,14 +302,37 @@ def sql(

# scrub output if required
if as_dict:
ret = self.fetch_as_dict(formatted, as_utf8)
last_result = self.fetch_as_dict(last_result, as_utf8=as_utf8)
if update:
for r in ret:
for r in last_result:
r.update(update)
return ret
elif as_list or as_utf8:
return self.convert_to_lists(self.last_result, formatted, as_utf8)
return self.last_result
elif as_list:
last_result = self.convert_to_lists(last_result, as_utf8=as_utf8)

self._clean_up()
return last_result

def _return_as_iterator(self, *, pluck, as_dict, as_list, update):
	"""Yield rows from the already-executed cursor in batches of `SQL_ITERATOR_BATCH_SIZE`.

	:param pluck: Yield only the first column of each row.
	:param as_dict: Yield each row as a `frappe._dict` keyed by column names.
	:param as_list: Yield each row as a plain list.
	:param update: Dict merged into every row (only effective with `as_dict`).

	Calls `self._clean_up()` once the cursor is exhausted so backends can drop
	internal references to the result set.
	"""
	# Validate flags up-front: the original per-batch check silently yielded
	# nothing for invalid flag combinations when the result set was empty.
	if not (pluck or as_dict or as_list):
		frappe.throw(_("`as_iterator` only works with `as_list=True` or `as_dict=True`"))

	# Column names are invariant across batches; compute them once.
	keys = [column[0] for column in self._cursor.description] if as_dict else None

	while result := self._transform_result(self._cursor.fetchmany(SQL_ITERATOR_BATCH_SIZE)):
		if pluck:
			for row in result:
				yield row[0]

		elif as_dict:
			for row in result:
				row = frappe._dict(zip(keys, row))
				if update:
					row.update(update)
				yield row

		else:  # as_list
			for row in result:
				yield list(row)

	self._clean_up()

def _log_query(
self,
Expand Down Expand Up @@ -404,9 +450,8 @@ def check_implicit_commit(self, query):
):
raise ImplicitCommitError("This statement can cause implicit commit")

def fetch_as_dict(self, formatted=0, as_utf8=0) -> list[frappe._dict]:
"""Internal. Converts results to dict."""
result = self.last_result
def fetch_as_dict(self, result, as_utf8=False) -> list[frappe._dict]:
"""Internal. Convert results to dict."""
if result:
keys = [column[0] for column in self._cursor.description]

Expand Down Expand Up @@ -1375,6 +1420,22 @@ def enqueue_jobs_after_commit():
def rename_column(self, doctype: str, old_column_name: str, new_column_name: str):
raise NotImplementedError

@contextmanager
def unbuffered_cursor(self):
	"""Temporarily switch the connection to an unbuffered (server-side) cursor.

	Combined with `sql(..., as_iterator=True)` this keeps memory usage O(1)
	while streaming large result sets.

	NOTE: The entire result set MUST be consumed inside the context; the
	underlying cursor is swapped back on exit, after which the remaining
	rows are unreachable.

	Usage:
	        with frappe.db.unbuffered_cursor():
	                for row in frappe.db.sql("query with huge result", as_iterator=True):
	                        continue  # Do some processing.

	Backends must override this; the base class has no unbuffered mode.
	"""
	raise NotImplementedError


@contextmanager
def savepoint(catch: type | tuple[type, ...] = Exception):
Expand Down
20 changes: 20 additions & 0 deletions frappe/database/mariadb/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from contextlib import contextmanager

import pymysql
from pymysql.constants import ER, FIELD_TYPE
Expand Down Expand Up @@ -204,6 +205,13 @@ def log_query(self, query, values, debug, explain):
self._log_query(self.last_query, debug, explain, query)
return self.last_query

def _clean_up(self):
	"""Erase pymysql's internal references to the last result set.

	PERF: pymysql keeps the fetched rows alive on both the cursor and the
	connection; clearing these attributes lets the GC reclaim large result
	sets as soon as the caller has consumed them.
	"""
	cursor = self._cursor
	cursor._result = None
	cursor._rows = None
	cursor.connection._result = None

@staticmethod
def escape(s, percent=True):
"""Excape quotes and percent in given string."""
Expand Down Expand Up @@ -444,3 +452,15 @@ def get_tables(self, cached=True):
frappe.cache().set_value("db_tables", tables)

return tables

@contextmanager
def unbuffered_cursor(self):
	"""Temporarily swap the connection's cursor for a server-side (unbuffered) one.

	SSCursor fetches rows from the server one at a time instead of buffering
	the whole result set in memory; use together with `sql(..., as_iterator=True)`.
	The original cursor is restored and the SSCursor closed on exit.
	"""
	from pymysql.cursors import SSCursor

	original_cursor = self._cursor
	# Create the new cursor BEFORE entering try/finally: in the original code a
	# failure in `self._conn.cursor(SSCursor)` left `new_cursor` unbound, so the
	# finally block raised NameError and masked the real error.
	new_cursor = self._cursor = self._conn.cursor(SSCursor)
	try:
		yield
	finally:
		self._cursor = original_cursor
		new_cursor.close()
52 changes: 34 additions & 18 deletions frappe/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from frappe.core.utils import find
from frappe.custom.doctype.custom_field.custom_field import create_custom_field
from frappe.database import savepoint
from frappe.database.database import Database, get_query_execution_timeout
from frappe.database.database import get_query_execution_timeout
from frappe.database.utils import FallBackDateTimeStr
from frappe.query_builder import Field
from frappe.query_builder.functions import Concat_ws
Expand Down Expand Up @@ -769,23 +769,6 @@ def test_cleared_cache(self):
cached_doc = frappe.get_cached_doc(self.todo2.doctype, self.todo2.name)
self.assertEqual(cached_doc.description, description)

def test_update_alias(self):
args = (self.todo1.doctype, self.todo1.name, "description", "Updated by `test_update_alias`")
kwargs = {
"for_update": False,
"modified": None,
"modified_by": None,
"update_modified": True,
"debug": False,
}

self.assertTrue("return self.set_value(" in inspect.getsource(frappe.db.update))

with patch.object(Database, "set_value") as set_value:
frappe.db.update(*args, **kwargs)
set_value.assert_called_once()
set_value.assert_called_with(*args, **kwargs)

@classmethod
def tearDownClass(cls):
frappe.db.rollback()
Expand Down Expand Up @@ -944,3 +927,36 @@ def inner():

outer()
self.assertEqual(write_connection, db_id())


class TestSqlIterator(FrappeTestCase):
	"""Ensure `frappe.db.sql(..., as_iterator=True)` yields exactly the same rows
	as the default buffered call, for every supported output mode."""

	def test_db_sql_iterator(self):
		# Queries with different column counts and an explicit LIMIT, to
		# exercise the iterator's internal batching.
		test_queries = [
			"select * from `tabCountry` order by name",
			"select code from `tabCountry` order by name",
			"select code from `tabCountry` order by name limit 5",
		]

		for query in test_queries:
			# as_dict mode: list of frappe._dict rows.
			self.assertEqual(
				frappe.db.sql(query, as_dict=True),
				list(frappe.db.sql(query, as_dict=True, as_iterator=True)),
				msg=f"{query=} results not same as iterator",
			)

			# pluck mode: first column of each row only.
			self.assertEqual(
				frappe.db.sql(query, pluck=True),
				list(frappe.db.sql(query, pluck=True, as_iterator=True)),
				msg=f"{query=} results not same as iterator",
			)

			# as_list mode: each row as a plain list.
			self.assertEqual(
				frappe.db.sql(query, as_list=True),
				list(frappe.db.sql(query, as_list=True, as_iterator=True)),
				msg=f"{query=} results not same as iterator",
			)

	@run_only_if(db_type_is.MARIADB)
	def test_unbuffered_cursor(self):
		# Re-run the same comparisons while a server-side (unbuffered) cursor
		# is active — MariaDB only, as only that backend implements it.
		with frappe.db.unbuffered_cursor():
			self.test_db_sql_iterator()
17 changes: 17 additions & 0 deletions frappe/tests/test_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
>>> get_controller("User")
"""
import gc
import sys
import time
import unittest
from unittest.mock import patch
Expand Down Expand Up @@ -161,3 +163,18 @@ def test_no_ifnull_checks(self):
query = frappe.get_all("DocType", {"autoname": ("is", "set")}, run=0).lower()
self.assertNotIn("coalesce", query)
self.assertNotIn("ifnull", query)

def test_no_stale_ref_sql(self):
	"""frappe.db.sql should not hold any internal references to result set.

	pymysql stores results internally. If your code reads a lot and doesn't make another
	query, for that entire duration there's copy of result consuming memory in internal
	attributes of pymysql.

	We clear it manually, this test ensures that it actually works.
	"""

	query = "select * from tabUser"
	for kwargs in ({}, {"as_dict": True}, {"as_list": True}):
		result = frappe.db.sql(query, **kwargs)
		# sys.getrefcount counts its own argument, so 2 means `result` is the
		# only live reference. Do NOT introduce extra bindings to `result`
		# here — that would change the refcount and break this assertion.
		self.assertEqual(sys.getrefcount(result), 2)  # Note: This always returns +1
		# Nothing else (in particular, no pymysql internals) may still refer
		# to the result set.
		self.assertFalse(gc.get_referrers(result))
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ dependencies = [
"Jinja2~=3.1.2",
"Pillow~=10.0.1",
"PyJWT~=2.4.0",
# We depend on internal attributes,
# do NOT add loose requirements on PyMySQL versions.
"PyMySQL==1.0.3",
"PyPDF2~=2.1.0",
"PyPika~=0.48.9",
Expand Down

0 comments on commit 7f3a12b

Please sign in to comment.