Skip to content

Commit

Permalink
Merge pull request #520 from opencybersecurityalliance/refactor_cond
Browse files Browse the repository at this point in the history
refactor SQL conditional translation
  • Loading branch information
pcoccoli committed May 16, 2024
2 parents 8c88862 + 0f5a496 commit 0db6807
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 44 deletions.
39 changes: 19 additions & 20 deletions packages/kestrel_core/src/kestrel/interface/codegen/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,6 @@
}


@typechecked
def _render_comp(comp: FComparison) -> BinaryExpression:
col: ColumnClause = column(comp.field)
if comp.op == StrCompOp.NMATCHES:
return ~comp2func[comp.op](col, comp.value)
return comp2func[comp.op](col, comp.value)


@typechecked
def _render_multi_comp(comps: MultiComp) -> BooleanClauseList:
op = and_ if comps.op == ExpOp.AND else or_
return reduce(op, map(_render_comp, comps.comps))


@typechecked
class SqlTranslator:
def __init__(
Expand All @@ -87,19 +73,32 @@ def __init__(
# SQLAlchemy statement object
self.query: Select = select("*").select_from(from_obj)

@typechecked
def _render_comp(self, comp: FComparison) -> BinaryExpression:
col: ColumnClause = column(comp.field)
if comp.op == StrCompOp.NMATCHES:
return ~comp2func[comp.op](col, comp.value)
return comp2func[comp.op](col, comp.value)

@typechecked
def _render_multi_comp(self, comps: MultiComp) -> BooleanClauseList:
op = and_ if comps.op == ExpOp.AND else or_
return reduce(op, map(self._render_comp, comps.comps))

@typechecked
def _render_exp(self, exp: BoolExp) -> BooleanClauseList:
if isinstance(exp.lhs, BoolExp):
lhs = self._render_exp(exp.lhs)
elif isinstance(exp.lhs, MultiComp):
lhs = _render_multi_comp(exp.lhs)
lhs = self._render_multi_comp(exp.lhs)
else:
lhs = _render_comp(exp.lhs)
lhs = self._render_comp(exp.lhs)
if isinstance(exp.rhs, BoolExp):
rhs = self._render_exp(exp.rhs)
elif isinstance(exp.rhs, MultiComp):
rhs = _render_multi_comp(exp.rhs)
rhs = self._render_multi_comp(exp.rhs)
else:
rhs = _render_comp(exp.rhs)
rhs = self._render_comp(exp.rhs)
return and_(lhs, rhs) if exp.op == ExpOp.AND else or_(lhs, rhs)

def add_Filter(self, filt: Filter) -> None:
Expand All @@ -120,9 +119,9 @@ def add_Filter(self, filt: Filter) -> None:
if isinstance(exp, BoolExp):
comp = self._render_exp(exp)
elif isinstance(exp, MultiComp):
comp = _render_multi_comp(exp)
comp = self._render_multi_comp(exp)
else:
comp = _render_comp(exp)
comp = self._render_comp(exp)
self.query = self.query.where(comp)

def add_ProjectAttrs(self, proj: ProjectAttrs) -> None:
Expand Down
1 change: 0 additions & 1 deletion packages/kestrel_interface_opensearch/tests/test_ossql.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def _remove_nl(s):
]
)
def test_opensearch_translator(iseq, sql):
cols = '`CommandLine` AS `cmd_line`, `Image` AS `file.path`, `ProcessId` AS `pid`, `ParentProcessId` AS `parent_process.pid`'
if ProjectEntity in {type(i) for i in iseq}:
cols = '`CommandLine` AS `cmd_line`, `Image` AS `file.path`, `ProcessId` AS `pid`, `ParentProcessId` AS `parent_process.pid`'
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from pandas import DataFrame, read_sql
import sqlalchemy
from sqlalchemy import and_, column, or_
from sqlalchemy.sql.elements import BooleanClauseList
from sqlalchemy import column, or_
from sqlalchemy.sql.expression import ColumnClause
from typeguard import typechecked

Expand Down Expand Up @@ -84,27 +83,6 @@ def _render_comp(self, comp: FComparison):
translated_comps.append(tmp)
return reduce(or_, translated_comps)

@typechecked
def _render_multi_comp(self, comps: MultiComp):
op = and_ if comps.op == ExpOp.AND else or_
return reduce(op, map(self._render_comp, comps.comps))

# This is copied verbatim from sql.py but we need to supply our own _render_comp
def _render_exp(self, exp: BoolExp) -> BooleanClauseList:
if isinstance(exp.lhs, BoolExp):
lhs = self._render_exp(exp.lhs)
elif isinstance(exp.lhs, MultiComp):
lhs = self._render_multi_comp(exp.lhs)
else:
lhs = self._render_comp(exp.lhs)
if isinstance(exp.rhs, BoolExp):
rhs = self._render_exp(exp.rhs)
elif isinstance(exp.rhs, MultiComp):
rhs = self._render_multi_comp(exp.rhs)
else:
rhs = self._render_comp(exp.rhs)
return and_(lhs, rhs) if exp.op == ExpOp.AND else or_(lhs, rhs)

@typechecked
def _add_filter(self) -> Optional[str]:
if not self.filt:
Expand Down

0 comments on commit 0db6807

Please sign in to comment.