Skip to content

Commit

Permalink
fix: search_link fails when txt contains parentheses (#22892)
Browse files Browse the repository at this point in the history
* fix: search_link fails when txt contains parentheses

* fix: updating regex to replace number params also

* chore: replacing regex with sqlparse

* chore: not including fields like count(1) in asterisk_fields

* fix: owner/module not identified as column

* chore: lint fix and removing exception

* refactor: better function name

---------

Co-authored-by: Ankush Menat <ankush@frappe.io>
  • Loading branch information
ssuda and ankush committed Jan 16, 2024
1 parent 7123f50 commit 642e9f4
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 45 deletions.
40 changes: 19 additions & 21 deletions frappe/desk/reportview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import json

from sql_metadata import Parser

import frappe
import frappe.permissions
from frappe import _
Expand Down Expand Up @@ -91,7 +93,10 @@ def validate_fields(data):
wildcard = update_wildcard_field_param(data)

for field in list(data.fields or []):
fieldname = extract_fieldname(field)
fieldname = extract_fieldnames(field)[0]
if not fieldname:
raise_invalid_field(fieldname)

if is_standard(fieldname):
continue

Expand Down Expand Up @@ -173,23 +178,16 @@ def is_standard(fieldname):
)


def extract_fieldname(field):
for text in (",", "/*", "#"):
if text in field:
raise_invalid_field(field)
def extract_fieldnames(field):
parser = Parser(f"select {field}, _frappe_dummy from _dummy")
columns = [col for col in parser.columns if col != "_frappe_dummy"]

fieldname = field
for sep in (" as ", " AS "):
if sep in fieldname:
fieldname = fieldname.split(sep, 1)[0]
if not columns:
f = field.lower()
if "count(" in f or "sum(" in f or "avg(" in f:
return ["*"]

# certain functions allowed, extract the fieldname from the function
if fieldname.startswith("count(") or fieldname.startswith("sum(") or fieldname.startswith("avg("):
if not fieldname.strip().endswith(")"):
raise_invalid_field(field)
fieldname = fieldname.split("(", 1)[1][:-1]

return fieldname
return columns


def get_meta_and_docfield(fieldname, data):
Expand Down Expand Up @@ -236,13 +234,13 @@ def get_parenttype_and_fieldname(field, data):
parts = field.split(".")
parenttype = parts[0]
fieldname = parts[1]
if parenttype.startswith("`tab"):
# `tabChild DocType`.`fieldname`
parenttype = parenttype[4:-1]
fieldname = fieldname.strip("`")
df = frappe.get_meta(data.doctype).get_field(parenttype)
if not df:
# tabChild DocType.fieldname
parenttype = parenttype[3:]
else:
# tablefield.fieldname
parenttype = frappe.get_meta(data.doctype).get_field(parenttype).options
parenttype = df.options
else:
parenttype = data.doctype
fieldname = field.strip("`")
Expand Down
39 changes: 16 additions & 23 deletions frappe/model/db_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
STRICT_FIELD_PATTERN = re.compile(r".*/\*.*")
STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s")
ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*")
FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*")
SPECIAL_FIELD_CHARS = frozenset(("(", "`", ".", "'", '"', "*"))


Expand Down Expand Up @@ -626,6 +625,8 @@ def apply_fieldlevel_read_permissions(self):
- Query: fields=["*"]
- Result: fields=["title", ...] // will also include Frappe's meta field like `name`, `owner`, etc.
"""
from frappe.desk.reportview import extract_fieldnames

if self.flags.ignore_permissions:
return

Expand All @@ -638,23 +639,18 @@ def apply_fieldlevel_read_permissions(self):
)

for i, field in enumerate(self.fields):
if "distinct" in field.lower():
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
if _fn := FN_PARAMS_PATTERN.findall(field):
column = _fn[0].replace("distinct ", "").replace("DISTINCT ", "").replace("`", "")
# field: 'distinct name'
# column: 'name'
else:
column = field.split(" ", 2)[1].replace("`", "")
else:
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
column = field.split("(")[-1].split(")", 1)[0]
column = strip_alias(column).replace("`", "")
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
columns = extract_fieldnames(field)
if not columns:
continue

if column == "*" and not in_function("*", field):
asterisk_fields.append(i)
column = columns[0]
if column == "*":
if "*" in field and not in_function("*", field):
asterisk_fields.append(i)
continue

# handle pseudo columns
Expand Down Expand Up @@ -693,12 +689,9 @@ def apply_fieldlevel_read_permissions(self):
elif "(" in field:
if "*" in field:
continue
elif _params := FN_PARAMS_PATTERN.findall(field):
params = (x.strip() for x in _params[0].split(","))
for param in params:
if not (
not param or param in permitted_fields or param.isnumeric() or "'" in param or '"' in param
):
else:
for column in columns:
if not column in permitted_fields:
self.remove_field(i)
break
continue
Expand Down
62 changes: 61 additions & 1 deletion frappe/tests/test_reportview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License: MIT. See LICENSE

import frappe
from frappe.desk.reportview import export_query
from frappe.desk.reportview import export_query, extract_fieldnames
from frappe.tests.utils import FrappeTestCase


Expand Down Expand Up @@ -32,3 +32,63 @@ def test_csv(self):
for row in reader:
self.assertEqual(int(row["Is Single"]), 1)
self.assertEqual(row["Module"], "Core")

def test_extract_fieldname(self):
self.assertEqual(
extract_fieldnames("count(distinct `tabPhoto`.name) as total_count")[0], "tabPhoto.name"
)

self.assertEqual(extract_fieldnames("owner")[0], "owner")

self.assertEqual(extract_fieldnames("module")[0], "module")

self.assertEqual(extract_fieldnames("count(`tabPhoto`.name) as total_count")[0], "tabPhoto.name")

self.assertEqual(extract_fieldnames("count(distinct `tabPhoto`.name)")[0], "tabPhoto.name")

self.assertEqual(extract_fieldnames("count(`tabPhoto`.name)")[0], "tabPhoto.name")

self.assertEqual(
extract_fieldnames("count(distinct `tabJob Applicant`.name) as total_count")[0],
"tabJob Applicant.name",
)

self.assertEqual(
extract_fieldnames("(1 / nullif(locate('a', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)

self.assertEqual(
extract_fieldnames("(1 / nullif(locate('(a)', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)

self.assertEqual(
extract_fieldnames("EXTRACT(MONTH FROM date_column) AS month")[0], "date_column"
)

self.assertEqual(extract_fieldnames("COUNT(*) AS count")[0], "*")

self.assertEqual(extract_fieldnames("COUNT(1) AS count")[0], "*")

self.assertEqual(extract_fieldnames("COUNT(1) AS count, SUM(1) AS sum")[0], "*")

self.assertEqual(
extract_fieldnames("first_name + ' ' + last_name AS full_name"), ["first_name", "last_name"]
)

self.assertEqual(
extract_fieldnames("CONCAT(first_name, ' ', last_name) AS full_name"),
["first_name", "last_name"],
)

self.assertEqual(
extract_fieldnames("CONCAT(id, '/', name, '/', age, '/', marks) AS student"),
["id", "name", "age", "marks"],
)

self.assertEqual(extract_fieldnames("tablefield.fiedname")[0], "tablefield.fiedname")

self.assertEqual(
extract_fieldnames("`tabChild DocType`.`fiedname`")[0], "tabChild DocType.fiedname"
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ dependencies = [
"semantic-version~=2.10.0",
"sentry-sdk~=1.37.1",
"sqlparse~=0.4.4",
"sql_metadata~=2.9.0",
"tenacity~=8.2.2",
"terminaltables~=3.1.10",
"traceback-with-variables~=2.0.4",
Expand Down

0 comments on commit 642e9f4

Please sign in to comment.