Skip to content

Commit

Permalink
Merge pull request #519 from opencybersecurityalliance/map_attrs
Browse files Browse the repository at this point in the history
frontend attr mapping
  • Loading branch information
pcoccoli committed May 15, 2024
2 parents 1255b9f + 2a4b185 commit 8c88862
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 72 deletions.
43 changes: 39 additions & 4 deletions packages/kestrel_core/src/kestrel/frontend/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from lark import Transformer, Token
from typeguard import typechecked

from kestrel.mapping.data_model import translate_comparison_to_ocsf
from kestrel.mapping.data_model import (
translate_comparison_to_ocsf,
translate_projection_to_ocsf,
)
from kestrel.utils import unescape_quoted_string
from kestrel.ir.filter import (
FExpression,
Expand All @@ -35,6 +38,7 @@
Construct,
DataSource,
Filter,
Instruction,
Limit,
Offset,
ProjectAttrs,
Expand Down Expand Up @@ -178,8 +182,9 @@ def statement(self, args):

def assignment(self, args):
    """Attach a new ``Variable`` node on top of an assignment's expression.

    Parameters:
        args: ``args[0]`` is the parsed item whose ``.value`` holds the
            variable name; ``args[1]`` is the ``(graph, root)`` pair
            produced for the right-hand side.

    Returns:
        The graph, with the variable node added after ``root``.
    """
    # TODO: move the var+var into expression in Lark
    graph, root = args[1]
    # Resolve the types first: the Variable node is built from the name
    # together with its entity type and native type, so it cannot be
    # constructed from the name alone.
    entity_type, native_type = self._get_type_from_predecessors(graph, root)
    variable_node = Variable(args[0].value, entity_type, native_type)
    graph.add_node(variable_node, root)
    return graph

Expand Down Expand Up @@ -211,10 +216,12 @@ def new(self, args):
graph = IRGraph()
if len(args) == 1:
# Try to get entity type from first entity
entity_type = None
data = args[0]
else:
entity_type = args[0].value
data = args[1]
data_node = Construct(data)
data_node = Construct(data, entity_type)
graph.add_node(data_node)
return graph, data_node

Expand Down Expand Up @@ -264,7 +271,9 @@ def get(self, args):
# add reference nodes if used in Filter
_add_reference_branches_for_filter(graph, filter_node)

projection_node = graph.add_node(ProjectEntity(mapped_entity_name), filter_node)
projection_node = graph.add_node(
ProjectEntity(mapped_entity_name, entity_name), filter_node
)
root = projection_node
if len(args) > 3:
for arg in args[3:]:
Expand Down Expand Up @@ -385,6 +394,16 @@ def offset_clause(self, args):

def disp(self, args):
    """Close a ``disp`` rule by appending a ``Return`` node to its graph.

    ``args[0]`` is a ``(graph, root)`` pair. When ``root`` is a
    ``ProjectAttrs`` node, its ``attrs`` are first replaced by the result
    of ``translate_projection_to_ocsf``. The graph is returned.
    """
    graph, root = args[0]
    _logger.debug("disp: root = %s", root)
    needs_attr_mapping = isinstance(root, ProjectAttrs)
    if needs_attr_mapping:
        # Map attrs to OCSF
        ocsf_type, source_type = self._get_type_from_predecessors(graph, root)
        _logger.debug(
            "Map %s attrs to OCSF %s in %s", source_type, ocsf_type, root
        )
        translated = translate_projection_to_ocsf(
            self.property_map, source_type, ocsf_type, root.attrs
        )
        root.attrs = translated
    graph.add_node(Return(), root)
    return graph

Expand All @@ -394,3 +413,19 @@ def explain(self, args):
explain = graph.add_node(Explain(), reference)
graph.add_node(Return(), explain)
return graph

def _get_type_from_predecessors(self, graph: IRGraph, root: Instruction):
    """Walk back from ``root`` until both an entity and a native type are known.

    Nodes are taken from the end of a work list and their predecessors are
    appended to it. Only ``Construct`` and ``Variable`` nodes contribute
    type information; every other node is just traversed.

    Returns:
        ``(entity_type, native_type)``; either may still be ``None`` when
        the work list runs out before both are found.
    """
    entity_type = native_type = None
    pending = [root]
    while pending:
        if native_type and entity_type:
            break
        node = pending.pop()
        _logger.debug("_get_type: curr = %s", node)
        pending.extend(graph.predecessors(node))
        if isinstance(node, Construct):
            # A Construct only knows the native name; look up its mapped
            # name and fall back to the native name when there is none.
            native_type = node.entity_type
            entity_type = self.entity_map.get(native_type, native_type)
        elif isinstance(node, Variable):
            native_type, entity_type = node.native_type, node.entity_type
    return entity_type, native_type
15 changes: 0 additions & 15 deletions packages/kestrel_core/src/kestrel/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,6 @@ def get_variables(self) -> Iterable[Variable]:
var_names = {v.name for v in self.get_nodes_by_type(Variable)}
return [self.get_variable(var_name) for var_name in var_names]

def add_variable(
    self, vx: Union[str, Variable], dependent_node: Instruction
) -> Variable:
    """Create new variable (if needed) and add to IRGraph

    A string is wrapped in a new ``Variable``; an existing ``Variable`` is
    used as it is. Either way the node is passed to ``add_node`` together
    with ``dependent_node``.

    Parameters:
        vx: variable name (str) or already created node (Variable)
        dependent_node: the instruction to which the variable refers

    Returns:
        The variable node created/added
    """
    # NOTE(review): the name is the only argument given to Variable here,
    # while the Variable dataclass shown in this change also declares
    # entity_type and native_type without defaults -- confirm that this
    # helper is still meant to exist alongside that definition.
    v = Variable(vx) if isinstance(vx, str) else vx
    return self.add_node(v, dependent_node)

def get_reference(self, ref_name: str) -> Reference:
"""Get a Kestrel reference by its name
Expand Down
4 changes: 4 additions & 0 deletions packages/kestrel_core/src/kestrel/ir/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def resolve_references(self, f: Callable[[ReferenceValue], Any]):
@dataclass(eq=False)
class ProjectEntity(SolePredecessorTransformingInstruction):
    """Projection onto one entity, recorded under two type names."""

    # First positional field; the frontend passes the mapped entity name.
    entity_type: str
    # Second positional field; the frontend passes the entity name it was
    # given (presumably the name before mapping -- confirm against caller).
    native_type: str


@dataclass(eq=False)
Expand Down Expand Up @@ -156,6 +157,8 @@ def __post_init__(self, uri: Optional[str], default_interface: Optional[str]):
@dataclass(eq=False)
class Variable(SolePredecessorTransformingInstruction):
name: str
entity_type: str
native_type: str
# required to dereference a variable that has been created multiple times
# the variable with the largest version will be used by dereference
version: int = 0
Expand Down Expand Up @@ -186,6 +189,7 @@ class Offset(SolePredecessorTransformingInstruction):
@dataclass(eq=False)
class Construct(SourceInstruction):
    """Source instruction that carries literal data and an optional entity type."""

    # List of mappings with string keys and str/int/bool values.
    data: List[Mapping[str, Union[str, int, bool]]]
    # Entity type of the data; None when no type was supplied.
    entity_type: Optional[str] = None
    # Defaults to the cache interface identifier.
    interface: str = CACHE_INTERFACE_IDENTIFIER


Expand Down
35 changes: 34 additions & 1 deletion packages/kestrel_core/src/kestrel/mapping/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,46 @@ def translate_projection_to_native(
return result


@typechecked
def translate_projection_to_ocsf(
    dmm: dict,
    native_type: Optional[str],
    entity_type: Optional[str],
    attrs: list,
) -> list:
    """Translate projected attribute names into OCSF field names.

    Each attribute is looked up in ``dmm`` under its own name first. If that
    finds nothing and ``native_type`` is given, the ``"<native_type>:<attr>"``
    key is tried. When neither lookup finds anything, the attribute name
    itself is used as the mapping.

    Parameters:
        dmm: mapping consulted for each attribute
        native_type: type name used to build the qualified lookup key, or
            None to skip that lookup
        entity_type: when given, a leading ``"<entity_type>."`` is stripped
            from every resulting field name
        attrs: attribute names to translate

    Returns:
        The translated field names, in attribute order; an attribute that
        maps to several fields contributes all of them.
    """
    result = []
    for attr in attrs:
        mapping = dmm.get(attr)
        # Fall back only when the direct lookup missed; a direct hit must be
        # kept rather than replaced by the bare attribute name.
        if not mapping:
            if native_type:
                mapping = dmm.get(f"{native_type}:{attr}", attr)  # FIXME: only for STIX
            else:
                mapping = attr
        ocsf_name = _get_from_mapping(mapping, "ocsf_field")
        if isinstance(ocsf_name, list):
            result.extend(ocsf_name)
        else:
            result.append(ocsf_name)
    if entity_type:
        # Need to prune the entity name
        prefix = f"{entity_type}."
        result = [
            field[len(prefix) :] if field.startswith(prefix) else field
            for field in result
        ]
    return result


@typechecked
def translate_dataframe(df: DataFrame, dmm: dict) -> DataFrame:
# Translate results into Kestrel OCSF data model
# The column names of df are already mapped
df = df.replace({np.nan: None})
for col in df.columns:
mapping = dpath.get(dmm, col, separator=".")
try:
mapping = dpath.get(dmm, col, separator=".")
except KeyError:
_logger.debug("No mapping for %s", col)
mapping = None
if isinstance(mapping, dict):
transformer_name = mapping.get("ocsf_value")
df[col] = run_transformer_on_series(transformer_name, df[col])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_evaluate_Construct():
, {"name": "firefox.exe", "pid": 201}
, {"name": "chrome.exe", "pid": 205}
]
ins = Construct(data)
ins = Construct(data, "process")
df = evaluate_source_instruction(ins)
assert df.equals(DataFrame(data))


def test_non_exist_eval():
    # Evaluating a Variable instruction is expected to raise
    # NotImplementedError. The Variable is built with name, entity type and
    # native type so that construction itself succeeds and the error comes
    # from the evaluation call.
    with pytest.raises(NotImplementedError):
        evaluate_transforming_instruction(Variable("asdf", "foo", "bar"), DataFrame())


def test_evaluate_Limit():
Expand Down
63 changes: 20 additions & 43 deletions packages/kestrel_core/tests/test_ir_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def test_get_node_by_id():
def test_get_nodes_by_type_and_attributes():
    # Three differently named variables hang off one data source; filtering
    # Variable nodes on name "asdf" must return exactly the first of them.
    g = IRGraph()
    s = g.add_datasource("stixshifter://abc")
    v1 = g.add_node(Variable("asdf", "x", "y"), s)
    g.add_node(Variable("qwer", "u", "v"), s)
    g.add_node(Variable("123", "i", "j"), s)
    ns = g.get_nodes_by_type_and_attributes(Variable, {"name": "asdf"})
    assert ns == [v1]

Expand All @@ -72,36 +72,13 @@ def test_get_returns():
assert len(g.get_sink_nodes()) == 3


def test_add_variable():
    # Exercises IRGraph.add_variable with a name and with a ready-made node.
    # NOTE(review): Variable is built from the name alone below, while the
    # Variable dataclass shown in this change also declares entity_type and
    # native_type -- confirm this test still belongs with that definition.
    g = IRGraph()
    s = g.add_datasource("stixshifter://abc")
    v1 = g.add_variable("asdf", s)
    assert len(g) == 2
    assert len(g.edges()) == 1

    # Adding the same name again grows the graph by one and adds one edge.
    v2 = g.add_variable("asdf", s)
    assert len(g) == 3
    assert len(g.edges()) == 2

    # A Variable object passed in comes back equal to itself, also when it
    # is added a second time.
    v = Variable("asdf")
    v3 = g.add_variable(v, s)
    assert v == v3
    v4 = g.add_variable(v, s)
    assert v3 == v4

    # Versions count up in the order the same-named variables were added;
    # the repeated add of ``v`` did not grow the graph further.
    assert v1.version == 0
    assert v2.version == 1
    assert v3.version == 2
    assert len(g) == 4
    assert len(g.edges()) == 3


def test_get_variables():
g = IRGraph()
s = g.add_datasource("stixshifter://abc")
v1 = g.add_variable("asdf", s)
v2 = g.add_variable("asdf", s)
v3 = g.add_variable("asdf", s)
v = Variable("asdf", "foo", "bar")
v1 = g.add_node(v, s)
v2 = g.add_node(v, s)
v3 = g.add_node(v, s)
vs = g.get_variables()
assert len(vs) == 1
assert vs[0].name == "asdf"
Expand All @@ -110,11 +87,11 @@ def test_get_variables():
def test_add_get_reference():
g = IRGraph()
s = g.add_node(DataSource("ss://ee"))
g.add_node(Variable("asdf"), s)
g.add_node(Variable("asdf", "foo", "bar"), s)
g.add_node(Reference("asdf"))
q1 = g.add_node(Reference("qwer"))
q2 = g.add_node(Reference("qwer"))
g.add_node(Variable("qwer"), s)
g.add_node(Variable("qwer", "foo", "bar"), s)
g.add_node(Reference("qwer"))
assert len(g) == 4
assert len(g.edges()) == 2
Expand Down Expand Up @@ -150,15 +127,15 @@ def test_deepcopy_graph():
def test_update_graph():
g = IRGraph()
s = g.add_datasource("stixshifter://abc")
v1 = g.add_variable("asdf", s)
v2 = g.add_variable("asdf", s)
v3 = g.add_variable("asdf", s)
v1 = g.add_node(Variable("asdf", "foo", "bar"), s)
v2 = g.add_node(Variable("asdf", "foo", "bar"), s)
v3 = g.add_node(Variable("asdf", "foo", "bar"), s)
r1 = g.add_return(v3)

g2 = IRGraph()
s2 = g2.add_datasource("stixshifter://abc")
v4 = g2.add_variable("asdf", g2.add_node(Reference("asdf")))
v5 = g2.add_variable("asdf", g2.add_node(TransformingInstruction(), s2))
v4 = g2.add_node(Variable("asdf", "foo", "bar"), g2.add_node(Reference("asdf")))
v5 = g2.add_node(Variable("asdf", "foo", "bar"), g2.add_node(TransformingInstruction(), s2))
r2 = g2.add_return(v5)

assert v1.version == 0
Expand Down Expand Up @@ -193,7 +170,7 @@ def test_serialization_deserialization():
g1 = IRGraph()
s = g1.add_node(DataSource("ss://ee"))
r = g1.add_node(Reference("asdf"))
v = g1.add_node(Variable("asdf"), s)
v = g1.add_node(Variable("asdf", "foo", "bar"), s)
j = g1.to_json()
g2 = IRGraph(j)
assert s in g2.nodes()
Expand All @@ -206,21 +183,21 @@ def test_find_cached_dependent_subgraph_of_node():
g = IRGraph()

a1 = g.add_node(DataSource("ss://ee"))
a2 = g.add_node(Variable("asdf"), a1)
a2 = g.add_node(Variable("asdf", "foo", "bar"), a1)
a3 = g.add_node(Instruction())
g.add_edge(a2, a3)
a4 = g.add_node(Variable("qwer"), a3)
a4 = g.add_node(Variable("qwer", "foo", "bar"), a3)

b1 = g.add_node(DataSource("ss://eee"))
b2 = g.add_node(Variable("asdfe"), b1)
b2 = g.add_node(Variable("asdfe", "foo", "bar"), b1)
b3 = g.add_node(Instruction())
g.add_edge(b2, b3)
b4 = g.add_node(Variable("qwere"), b3)
b4 = g.add_node(Variable("qwere", "foo", "bar"), b3)

c1 = g.add_node(Instruction())
g.add_edge(a4, c1)
g.add_edge(b4, c1)
c2 = g.add_node(Variable("zxcv"), c1)
c2 = g.add_node(Variable("zxcv", "foo", "bar"), c1)

g2 = g.find_cached_dependent_subgraph_of_node(c2, InMemoryCache())
assert networkx.utils.graphs_equal(g, g2)
Expand Down
10 changes: 5 additions & 5 deletions packages/kestrel_core/tests/test_ir_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@


def test_instruction_post_init():
    # A freshly built Variable must serialize with an "id" and with its
    # class name under "instruction".
    v = Variable("asdf", "foo", "bar")
    j = v.to_dict()
    assert "id" in j
    assert "instruction" in j
    assert j["instruction"] == "Variable"


def test_stable_id():
v = Variable("asdf")
v = Variable("asdf", "foo", "bar")
_id = v.id
v.name = "qwer"
assert v.id == _id
Expand All @@ -48,7 +48,7 @@ def test_eq():

def test_get_instruction_class():
    # Looking up "Variable" by name must yield the Variable class itself,
    # and that class must be instantiable with name, entity type and
    # native type.
    cls = get_instruction_class("Variable")
    v = cls("asdf", "foo", "bar")
    assert cls == Variable
    assert isinstance(v, Variable)

Expand Down Expand Up @@ -86,7 +86,7 @@ def test_construct():


def test_instruction_from_dict():
v = Variable("asdf")
v = Variable("asdf", "foo", "bar")
d = v.to_dict()
w = instruction_from_dict(d)
assert w == v
Expand All @@ -97,7 +97,7 @@ def test_instruction_from_dict():


def test_instruction_from_json():
v = Variable("asdf")
v = Variable("asdf", "foo", "bar")
j = v.to_json()
w = instruction_from_json(j)
assert w == v
1 change: 0 additions & 1 deletion packages/kestrel_core/tests/test_mapping_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def test_reverse_mapping_executable():
assert item["ocsf_value"] == "basename"



@pytest.mark.parametrize(
"dmm, field, op, value, expected_result",
[
Expand Down

0 comments on commit 8c88862

Please sign in to comment.