Skip to content

Commit

Permalink
work in progress for idaholab#386, include missing aggregations in AP…
Browse files Browse the repository at this point in the history
…I bucket queries
  • Loading branch information
mmguero committed Jan 25, 2024
1 parent 627c6fd commit 4cdf138
Showing 1 changed file with 84 additions and 70 deletions.
154 changes: 84 additions & 70 deletions api/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import urllib3
import warnings

from collections import defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Iterable
from datetime import datetime
from flask import Flask, jsonify, request
Expand Down Expand Up @@ -152,6 +152,15 @@
field_type_map['time'] = 'date'
field_type_map['timestamp'] = 'date'

# Default "missing" placeholder for each field mapping type, so the value used
# for an aggregation's missing= parameter agrees with the type of the field.
# Any type not listed here falls back to the string '-'.
# TODO: do I need to handle weird ones like "date" and "geo"?
missing_field_map = defaultdict(
    lambda: '-',
    {
        'double': 0.0,
        'float': 0.0,
        'integer': 0,
        'ip': '0.0.0.0',
        'long': 0,
    },
)

urllib3.disable_warnings()
warnings.filterwarnings(
"ignore",
Expand Down Expand Up @@ -212,32 +221,16 @@
)


def deep_get(d, keys, default=None):
    """Safely walk a nested mapping along a list of keys.

    Args:
        d: The outermost mapping (may be None).
        keys: A list of keys to follow, outermost first. An empty list
            returns ``d`` itself (or ``default`` when ``d`` is None).
        default: Value returned when the path cannot be followed to a
            non-None value.

    Returns:
        The value found at the end of the key path, or ``default`` when any
        step is missing, is None, or is not something that supports ``.get``.
    """
    # Kept as an assertion so callers see the same failure type as before.
    assert isinstance(keys, list)

    current = d
    for key in keys:
        # An intermediate value that is None or not mapping-like (str, int,
        # list, ...) cannot be descended into: treat the path as absent
        # rather than raising AttributeError.
        getter = getattr(current, "get", None)
        if getter is None:
            return default
        current = getter(key)

    return default if current is None else current


def get_iterable(x):
    """Return ``x`` unchanged if it is a non-string iterable, else ``(x,)``.

    Strings are deliberately treated as single values rather than as
    sequences of characters.
    """
    if isinstance(x, str) or not isinstance(x, Iterable):
        return (x,)
    return x


def random_id(length=20):
    """Return a random string of ``length`` ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)


def get_request_arguments(req):
arguments = {}
if 'POST' in get_iterable(req.method):
if 'POST' in malcolm_utils.get_iterable(req.method):
if (data := req.get_json() if req.is_json else None) and isinstance(data, dict):
arguments.update(data)
if 'GET' in get_iterable(req.method):
if 'GET' in malcolm_utils.get_iterable(req.method):
arguments.update(request.args)
if debugApi:
print(f"{req.method} {req.path} arguments: {json.dumps(arguments)}")
Expand Down Expand Up @@ -342,7 +335,7 @@ def urls_for_field(fieldname, start_time=None, end_time=None):
translated = []

if databaseMode != malcolm_utils.DatabaseMode.ElasticsearchRemote:
for field in get_iterable(fieldname):
for field in malcolm_utils.get_iterable(fieldname):
for url_regex_pair in fields_to_urls:
if (len(url_regex_pair) == 2) and re.search(url_regex_pair[0], field, flags=re.IGNORECASE):
for url in url_regex_pair[1]:
Expand Down Expand Up @@ -370,7 +363,7 @@ def doctype_from_args(args):
return doctype
network|host
"""
return deep_get(args, ["doctype"], app.config["DOCTYPE_DEFAULT"])
return malcolm_utils.deep_get(args, ["doctype"], app.config["DOCTYPE_DEFAULT"])


def index_from_args(args):
Expand Down Expand Up @@ -502,7 +495,7 @@ def filtervalues(search, args):
# field != value
s = s.exclude(
"terms",
**{fieldname[1:]: get_iterable(filtervalue)},
**{fieldname[1:]: malcolm_utils.get_iterable(filtervalue)},
)
else:
# field exists ("is not null")
Expand All @@ -513,7 +506,7 @@ def filtervalues(search, args):
# field == value
s = s.filter(
"terms",
**{fieldname: get_iterable(filtervalue)},
**{fieldname: malcolm_utils.get_iterable(filtervalue)},
)
else:
# field does not exist ("is null")
Expand Down Expand Up @@ -550,43 +543,60 @@ def aggfields(fieldnames, current_request, urls=None):
global SearchClass

args = get_request_arguments(current_request)
idx = index_from_args(args)
s = SearchClass(
using=databaseClient,
index=index_from_args(args),
index=idx,
).extra(size=0)
start_time_ms, end_time_ms, s = filtertime(s, args)
filters, s = filtervalues(s, args)
bucket_limit = int(deep_get(args, ["limit"], app.config["RESULT_SET_LIMIT"]))
bucket_limit = int(malcolm_utils.deep_get(args, ["limit"], app.config["RESULT_SET_LIMIT"]))
last_bucket = s.aggs
aggCount = 0
for fname in get_iterable(fieldnames):
aggCount += 1
# TODO: missing string needs to match the type of the field
# "error": "RequestError: RequestError(400, 'search_phase_execution_exception', \"'__missing__' is not an IP string literal.\")"

for fname in malcolm_utils.get_iterable(fieldnames):
# Get the field mapping type for this field, and map it to a good default "missing"
# (empty bucket) label for the bucket missing= parameter below
mapping = databaseClient.indices.get_field_mapping(
fname,
index=idx,
)
missing_val = (
missing_field_map[
next(
iter(
malcolm_utils.dictsearch(
mapping[next(iter(OrderedDict(sorted(mapping.items(), reverse=True))))], 'type'
)
),
None,
)
]
if (mapping and isinstance(mapping, dict))
else missing_field_map[None]
)

# chain on the aggregation for the next field
last_bucket = last_bucket.bucket(
f"values_{aggCount}",
fname,
"terms",
field=fname,
size=bucket_limit,
missing="__missing__",
missing=missing_val,
)

response = s.execute()

top_bucket_name = next(iter(malcolm_utils.get_iterable(fieldnames)))
result_dict = {
top_bucket_name: response.aggregations.to_dict().get(top_bucket_name, {}),
'range': (start_time_ms // 1000, end_time_ms // 1000),
'filter': filters,
'fields': malcolm_utils.get_iterable(fieldnames),
}
if (urls is not None) and (len(urls) > 0):
return jsonify(
values=response.aggregations.to_dict().get("values_1", {}),
range=(start_time_ms // 1000, end_time_ms // 1000),
filter=filters,
fields=get_iterable(fieldnames),
urls=urls,
)
else:
return jsonify(
values=response.aggregations.to_dict().get("values_1", {}),
range=(start_time_ms // 1000, end_time_ms // 1000),
filter=filters,
fields=get_iterable(fieldnames),
)
result_dict['urls'] = urls

return jsonify(result_dict)


@app.route(
Expand Down Expand Up @@ -650,7 +660,7 @@ def document():
s = SearchClass(
using=databaseClient,
index=index_from_args(args),
).extra(size=int(deep_get(args, ["limit"], app.config["RESULT_SET_LIMIT"])))
).extra(size=int(malcolm_utils.deep_get(args, ["limit"], app.config["RESULT_SET_LIMIT"])))
start_time_ms, end_time_ms, s = filtertime(s, args, default_from="1970-1-1", default_to="now")
filters, s = filtervalues(s, args)
return jsonify(
Expand Down Expand Up @@ -712,7 +722,7 @@ def fields():

args = get_request_arguments(request)

templateName = deep_get(args, ["template"], app.config["MALCOLM_TEMPLATE"])
templateName = malcolm_utils.deep_get(args, ["template"], app.config["MALCOLM_TEMPLATE"])
arkimeFields = (templateName == app.config["MALCOLM_TEMPLATE"]) and (doctype_from_args(args) == 'network')

fields = defaultdict(dict)
Expand All @@ -725,12 +735,12 @@ def fields():
index=index_from_args(args),
).extra(size=5000)
for hit in [x['_source'] for x in s.execute().to_dict().get('hits', {}).get('hits', [])]:
if (fieldname := deep_get(hit, ['dbField2'])) and (fieldname not in fields):
if (fieldname := malcolm_utils.deep_get(hit, ['dbField2'])) and (fieldname not in fields):
if debugApi:
hit['source'] = 'arkime'
fields[fieldname] = {
'description': deep_get(hit, ['help']),
'type': field_type_map[deep_get(hit, ['type'])],
'description': malcolm_utils.deep_get(hit, ['help']),
'type': field_type_map[malcolm_utils.deep_get(hit, ['type'])],
}
if debugApi:
fields[fieldname]['original'] = [hit]
Expand All @@ -746,35 +756,39 @@ def fields():
verify=opensearchSslVerify,
).json()

for template in deep_get(getTemplateResponseJson, ["index_templates"]):
for template in malcolm_utils.deep_get(getTemplateResponseJson, ["index_templates"]):
# top-level fields
for fieldname, fieldinfo in deep_get(
for fieldname, fieldinfo in malcolm_utils.deep_get(
template,
["index_template", "template", "mappings", "properties"],
).items():
if debugApi:
fieldinfo['source'] = f'opensearch.{templateName}'
if 'type' in fieldinfo:
fields[fieldname]['type'] = field_type_map[deep_get(fieldinfo, ['type'])]
fields[fieldname]['type'] = field_type_map[malcolm_utils.deep_get(fieldinfo, ['type'])]
if debugApi:
fields[fieldname]['original'] = fields[fieldname].get('original', []) + [fieldinfo]

# descendant component fields
for componentName in get_iterable(deep_get(template, ["index_template", "composed_of"])):
for componentName in malcolm_utils.get_iterable(
malcolm_utils.deep_get(template, ["index_template", "composed_of"])
):
getComponentResponseJson = requests.get(
f'{opensearchUrl}/_component_template/{componentName}',
auth=opensearchReqHttpAuth,
verify=opensearchSslVerify,
).json()
for component in get_iterable(deep_get(getComponentResponseJson, ["component_templates"])):
for fieldname, fieldinfo in deep_get(
for component in malcolm_utils.get_iterable(
malcolm_utils.deep_get(getComponentResponseJson, ["component_templates"])
):
for fieldname, fieldinfo in malcolm_utils.deep_get(
component,
["component_template", "template", "mappings", "properties"],
).items():
if debugApi:
fieldinfo['source'] = f'opensearch.{templateName}.{componentName}'
if 'type' in fieldinfo:
fields[fieldname]['type'] = field_type_map[deep_get(fieldinfo, ['type'])]
fields[fieldname]['type'] = field_type_map[malcolm_utils.deep_get(fieldinfo, ['type'])]
if debugApi:
fields[fieldname]['original'] = fields[fieldname].get('original', []) + [fieldinfo]

Expand All @@ -793,12 +807,12 @@ def fields():
auth=opensearchReqHttpAuth,
verify=opensearchSslVerify,
).json()['fields']:
if fieldname := deep_get(field, ['name']):
if fieldname := malcolm_utils.deep_get(field, ['name']):
if debugApi:
field['source'] = 'dashboards'
field_types = deep_get(field, ['esTypes'], [])
field_types = malcolm_utils.deep_get(field, ['esTypes'], [])
fields[fieldname]['type'] = field_type_map[
field_types[0] if len(field_types) > 0 else deep_get(fields[fieldname], ['type'])
field_types[0] if len(field_types) > 0 else malcolm_utils.deep_get(fields[fieldname], ['type'])
]
if debugApi:
fields[fieldname]['original'] = fields[fieldname].get('original', []) + [field]
Expand Down Expand Up @@ -939,7 +953,7 @@ def event():
data = get_request_arguments(request)
nowTimeStr = datetime.now().astimezone(pytz.utc).isoformat().replace('+00:00', 'Z')
if 'alert' in data:
alert[app.config["MALCOLM_NETWORK_INDEX_TIME_FIELD"]] = deep_get(
alert[app.config["MALCOLM_NETWORK_INDEX_TIME_FIELD"]] = malcolm_utils.deep_get(
data,
[
'alert',
Expand All @@ -949,7 +963,7 @@ def event():
nowTimeStr,
)
alert['firstPacket'] = alert[app.config["MALCOLM_NETWORK_INDEX_TIME_FIELD"]]
alert['lastPacket'] = deep_get(
alert['lastPacket'] = malcolm_utils.deep_get(
data,
[
'alert',
Expand All @@ -969,23 +983,23 @@ def event():
alert['event']['dataset'] = 'alerting'
alert['event']['module'] = 'alerting'
alert['event']['url'] = '/dashboards/app/alerting#/dashboard'
alertId = deep_get(
alertId = malcolm_utils.deep_get(
data,
[
'alert',
'alert',
],
)
alert['event']['id'] = alertId if alertId else random_id()
if alertBody := deep_get(
if alertBody := malcolm_utils.deep_get(
data,
[
'alert',
'body',
],
):
alert['event']['original'] = alertBody
if triggerName := deep_get(
if triggerName := malcolm_utils.deep_get(
data,
[
'alert',
Expand All @@ -994,7 +1008,7 @@ def event():
],
):
alert['event']['reason'] = triggerName
if monitorName := deep_get(
if monitorName := malcolm_utils.deep_get(
data,
[
'alert',
Expand All @@ -1005,7 +1019,7 @@ def event():
alert['rule'] = {}
alert['rule']['name'] = monitorName
if alertSeverity := str(
deep_get(
malcolm_utils.deep_get(
data,
[
'alert',
Expand All @@ -1019,15 +1033,15 @@ def event():
alert['event']['risk_score_norm'] = sevnum
alert['event']['severity'] = sevnum
alert['event']['severity_tags'] = 'Alert'
if alertResults := deep_get(
if alertResults := malcolm_utils.deep_get(
data,
[
'alert',
'results',
],
):
if len(alertResults) > 0:
if hitCount := deep_get(alertResults[0], ['hits', 'total', 'value'], 0):
if hitCount := malcolm_utils.deep_get(alertResults[0], ['hits', 'total', 'value'], 0):
alert['event']['hits'] = hitCount

docDateStr = dateparser.parse(alert[app.config["MALCOLM_NETWORK_INDEX_TIME_FIELD"]]).strftime('%y%m%d')
Expand Down

0 comments on commit 4cdf138

Please sign in to comment.