Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: search_link fails when txt contains parentheses #22892

Merged
merged 8 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions frappe/desk/reportview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import json

from sql_metadata import Parser

import frappe
import frappe.permissions
from frappe import _
Expand Down Expand Up @@ -91,7 +93,10 @@ def validate_fields(data):
wildcard = update_wildcard_field_param(data)

for field in list(data.fields or []):
fieldname = extract_fieldname(field)
fieldname = extract_fieldnames(field)[0]
if not fieldname:
raise_invalid_field(fieldname)

if is_standard(fieldname):
continue

Expand Down Expand Up @@ -173,23 +178,16 @@ def is_standard(fieldname):
)


def extract_fieldname(field):
for text in (",", "/*", "#"):
if text in field:
raise_invalid_field(field)
def extract_fieldnames(field):
parser = Parser(f"select {field}, _frappe_dummy from _dummy")
columns = [col for col in parser.columns if col != "_frappe_dummy"]

fieldname = field
for sep in (" as ", " AS "):
if sep in fieldname:
fieldname = fieldname.split(sep, 1)[0]
if not columns:
f = field.lower()
if "count(" in f or "sum(" in f or "avg(" in f:
return ["*"]

# certain functions allowed, extract the fieldname from the function
if fieldname.startswith("count(") or fieldname.startswith("sum(") or fieldname.startswith("avg("):
if not fieldname.strip().endswith(")"):
raise_invalid_field(field)
fieldname = fieldname.split("(", 1)[1][:-1]

return fieldname
return columns


def get_meta_and_docfield(fieldname, data):
Expand Down Expand Up @@ -236,13 +234,13 @@ def get_parenttype_and_fieldname(field, data):
parts = field.split(".")
parenttype = parts[0]
fieldname = parts[1]
if parenttype.startswith("`tab"):
# `tabChild DocType`.`fieldname`
parenttype = parenttype[4:-1]
fieldname = fieldname.strip("`")
df = frappe.get_meta(data.doctype).get_field(parenttype)
if not df:
# tabChild DocType.fieldname
parenttype = parenttype[3:]
else:
# tablefield.fieldname
parenttype = frappe.get_meta(data.doctype).get_field(parenttype).options
parenttype = df.options
else:
parenttype = data.doctype
fieldname = field.strip("`")
Expand Down
39 changes: 16 additions & 23 deletions frappe/model/db_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
STRICT_FIELD_PATTERN = re.compile(r".*/\*.*")
STRICT_UNION_PATTERN = re.compile(r".*\s(union).*\s")
ORDER_GROUP_PATTERN = re.compile(r".*[^a-z0-9-_ ,`'\"\.\(\)].*")
FN_PARAMS_PATTERN = re.compile(r".*?\((.*)\).*")
SPECIAL_FIELD_CHARS = frozenset(("(", "`", ".", "'", '"', "*"))


Expand Down Expand Up @@ -626,6 +625,8 @@ def apply_fieldlevel_read_permissions(self):
- Query: fields=["*"]
- Result: fields=["title", ...] // will also include Frappe's meta field like `name`, `owner`, etc.
"""
from frappe.desk.reportview import extract_fieldnames

if self.flags.ignore_permissions:
return

Expand All @@ -638,23 +639,18 @@ def apply_fieldlevel_read_permissions(self):
)

for i, field in enumerate(self.fields):
if "distinct" in field.lower():
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
if _fn := FN_PARAMS_PATTERN.findall(field):
column = _fn[0].replace("distinct ", "").replace("DISTINCT ", "").replace("`", "")
# field: 'distinct name'
# column: 'name'
else:
column = field.split(" ", 2)[1].replace("`", "")
else:
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
column = field.split("(")[-1].split(")", 1)[0]
column = strip_alias(column).replace("`", "")
# field: 'count(distinct `tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
# field: 'count(`tabPhoto`.name) as total_count'
# column: 'tabPhoto.name'
columns = extract_fieldnames(field)
if not columns:
continue

if column == "*" and not in_function("*", field):
asterisk_fields.append(i)
column = columns[0]
if column == "*":
if "*" in field and not in_function("*", field):
asterisk_fields.append(i)
continue

# handle pseudo columns
Expand Down Expand Up @@ -693,12 +689,9 @@ def apply_fieldlevel_read_permissions(self):
elif "(" in field:
if "*" in field:
continue
elif _params := FN_PARAMS_PATTERN.findall(field):
params = (x.strip() for x in _params[0].split(","))
for param in params:
if not (
not param or param in permitted_fields or param.isnumeric() or "'" in param or '"' in param
):
else:
for column in columns:
if not column in permitted_fields:
self.remove_field(i)
break
continue
Expand Down
62 changes: 61 additions & 1 deletion frappe/tests/test_reportview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# License: MIT. See LICENSE

import frappe
from frappe.desk.reportview import export_query
from frappe.desk.reportview import export_query, extract_fieldnames
from frappe.tests.utils import FrappeTestCase


Expand Down Expand Up @@ -32,3 +32,63 @@ def test_csv(self):
for row in reader:
self.assertEqual(int(row["Is Single"]), 1)
self.assertEqual(row["Module"], "Core")

def test_extract_fieldname(self):
self.assertEqual(
extract_fieldnames("count(distinct `tabPhoto`.name) as total_count")[0], "tabPhoto.name"
)

self.assertEqual(extract_fieldnames("owner")[0], "owner")

self.assertEqual(extract_fieldnames("module")[0], "module")

self.assertEqual(extract_fieldnames("count(`tabPhoto`.name) as total_count")[0], "tabPhoto.name")

self.assertEqual(extract_fieldnames("count(distinct `tabPhoto`.name)")[0], "tabPhoto.name")

self.assertEqual(extract_fieldnames("count(`tabPhoto`.name)")[0], "tabPhoto.name")

self.assertEqual(
extract_fieldnames("count(distinct `tabJob Applicant`.name) as total_count")[0],
"tabJob Applicant.name",
)

self.assertEqual(
extract_fieldnames("(1 / nullif(locate('a', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)

self.assertEqual(
extract_fieldnames("(1 / nullif(locate('(a)', `tabAddress`.`name`), 0)) as `_relevance`")[0],
"tabAddress.name",
)

self.assertEqual(
extract_fieldnames("EXTRACT(MONTH FROM date_column) AS month")[0], "date_column"
)

self.assertEqual(extract_fieldnames("COUNT(*) AS count")[0], "*")

self.assertEqual(extract_fieldnames("COUNT(1) AS count")[0], "*")

self.assertEqual(extract_fieldnames("COUNT(1) AS count, SUM(1) AS sum")[0], "*")

self.assertEqual(
extract_fieldnames("first_name + ' ' + last_name AS full_name"), ["first_name", "last_name"]
)

self.assertEqual(
extract_fieldnames("CONCAT(first_name, ' ', last_name) AS full_name"),
["first_name", "last_name"],
)

self.assertEqual(
extract_fieldnames("CONCAT(id, '/', name, '/', age, '/', marks) AS student"),
["id", "name", "age", "marks"],
)

self.assertEqual(extract_fieldnames("tablefield.fiedname")[0], "tablefield.fiedname")

self.assertEqual(
extract_fieldnames("`tabChild DocType`.`fiedname`")[0], "tabChild DocType.fiedname"
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ dependencies = [
"semantic-version~=2.10.0",
"sentry-sdk~=1.37.1",
"sqlparse~=0.4.4",
"sql_metadata~=2.9.0",
"tenacity~=8.2.2",
"terminaltables~=3.1.10",
"traceback-with-variables~=2.0.4",
Expand Down
Loading