Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make datetime fieldtypes timezone aware #78

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
index: str = "records",
http_compress: Union[str, bool] = True,
selector: Union[None, Selector, CompiledSelector] = None,
**kwargs
**kwargs,
) -> None:
self.index = index
self.uri = uri
Expand Down
5 changes: 3 additions & 2 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
import sys
import warnings
from datetime import datetime
from datetime import datetime, timezone
from itertools import zip_longest
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
from urllib.parse import parse_qsl, urlparse
Expand Down Expand Up @@ -44,6 +44,7 @@
from .whitelist import WHITELIST, WHITELIST_TREE

log = logging.getLogger(__package__)
_utcnow = functools.partial(datetime.now, timezone.utc)

RECORD_VERSION = 1
RESERVED_FIELDS = OrderedDict(
Expand Down Expand Up @@ -422,7 +423,7 @@ def _generate_record_class(name: str, fields: Tuple[Tuple[str, str]]) -> type:
_globals = {
"Record": Record,
"RECORD_VERSION": RECORD_VERSION,
"_utcnow": datetime.utcnow,
"_utcnow": _utcnow,
"_zip_longest": zip_longest,
}
for field in all_fields.values():
Expand Down
77 changes: 53 additions & 24 deletions flow/record/fieldtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import os
import pathlib
import re
import sys
from binascii import a2b_hex, b2a_hex
from datetime import datetime as _dt
from datetime import timezone
from posixpath import basename, dirname
from typing import Any, Tuple
from typing import Any, Optional, Tuple
from zoneinfo import ZoneInfo

try:
import urlparse
Expand All @@ -22,6 +24,12 @@
RE_STRIP_NANOSECS = re.compile(r"(\.\d{6})\d+")
NATIVE_UNICODE = isinstance("", str)

UTC = timezone.utc
ISO_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
ISO_FORMAT_WITH_MS = "%Y-%m-%dT%H:%M:%S.%f%z"

PY_311 = sys.version_info >= (3, 11, 0)

PATH_POSIX = 0
PATH_WINDOWS = 1

Expand All @@ -32,6 +40,27 @@
path_type = pathlib.PurePath


def flow_record_tz(default_tz: str = "UTC") -> Optional[ZoneInfo]:
"""Return a ZoneInfo object based on the ``FLOW_RECORD_TZ`` environment variable.
yunzheng marked this conversation as resolved.
Show resolved Hide resolved

Args:
default_tz: default timezone if ``FLOW_RECORD_TZ`` is not set (default: UTC)
yunzheng marked this conversation as resolved.
Show resolved Hide resolved

Returns:
None if ``FLOW_RECORD_TZ=NONE`` otherwise ``ZoneInfo(FLOW_RECORD_TZ)``
"""
tz = os.environ.get("FLOW_RECORD_TZ", default_tz)
if tz.upper() == "NONE":
return None
return ZoneInfo(tz)
yunzheng marked this conversation as resolved.
Show resolved Hide resolved


# The environment variable ``FLOW_RECORD_TZ`` affects the display of datetime fields.
#
# The timezone to use when displaying datetime fields. By default this is UTC.
DISPLAY_TZINFO = flow_record_tz("UTC")


def defang(value: str) -> str:
"""Defangs the value to make URLs or ip addresses unclickable"""
value = re.sub("^http://", "hxxp://", value, flags=re.IGNORECASE)
Expand Down Expand Up @@ -238,24 +267,24 @@ def __new__(cls, *args, **kwargs):
# String constructor is used for example in JsonRecordAdapter
# Note: ISO 8601 is fully implemented in fromisoformat() from Python 3.11 and onwards.
# Until then, we need to manually detect timezone info and handle it.
if any(z in arg[19:] for z in ["Z", "+", "-"]):
if "." in arg[19:]:
try:
return cls.strptime(arg, "%Y-%m-%dT%H:%M:%S.%f%z")
except ValueError:
# Sometimes nanoseconds need to be stripped
return cls.strptime(re.sub(RE_STRIP_NANOSECS, "\\1", arg), "%Y-%m-%dT%H:%M:%S.%f%z")
return cls.strptime(arg, "%Y-%m-%dT%H:%M:%S%z")
if not PY_311 and any(z in arg[19:] for z in ["Z", "+", "-"]):
spec = ISO_FORMAT_WITH_MS if "." in arg[19:] else ISO_FORMAT
try:
obj = cls.strptime(arg, spec)
except ValueError:
# Sometimes nanoseconds need to be stripped
obj = cls.strptime(re.sub(RE_STRIP_NANOSECS, "\\1", arg), spec)
else:
try:
return cls.fromisoformat(arg)
obj = cls.fromisoformat(arg)
except ValueError:
# Sometimes nanoseconds need to be stripped
return cls.fromisoformat(re.sub(RE_STRIP_NANOSECS, "\\1", arg))
obj = cls.fromisoformat(re.sub(RE_STRIP_NANOSECS, "\\1", arg))
elif isinstance(arg, (int, float_type)):
return cls.utcfromtimestamp(arg)
obj = cls.fromtimestamp(arg, UTC)
elif isinstance(arg, (_dt,)):
return _dt.__new__(
tzinfo = UTC if arg.tzinfo is None else arg.tzinfo
yunzheng marked this conversation as resolved.
Show resolved Hide resolved
obj = _dt.__new__(
cls,
arg.year,
arg.month,
Expand All @@ -264,24 +293,24 @@ def __new__(cls, *args, **kwargs):
arg.minute,
arg.second,
arg.microsecond,
arg.tzinfo,
tzinfo,
)
else:
obj = _dt.__new__(cls, *args, **kwargs)

return _dt.__new__(cls, *args, **kwargs)

def __eq__(self, other):
# Avoid TypeError: can't compare offset-naive and offset-aware datetimes
# naive datetimes are treated as UTC in flow.record instead of local time
ts1 = self.timestamp() if self.tzinfo else self.replace(tzinfo=timezone.utc).timestamp()
ts2 = other.timestamp() if other.tzinfo else other.replace(tzinfo=timezone.utc).timestamp()
return ts1 == ts2
# Ensure we always return a timezone aware datetime. Treat naive datetimes as UTC
if obj.tzinfo is None:
obj = obj.replace(tzinfo=UTC)
return obj

def _pack(self):
return self

def __str__(self):
return self.astimezone(DISPLAY_TZINFO).isoformat(" ") if DISPLAY_TZINFO else self.isoformat(" ")

def __repr__(self):
result = str(self)
return result
return str(self)

def __hash__(self):
return _dt.__hash__(self)
Expand Down
2 changes: 1 addition & 1 deletion flow/record/jsonpacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def pack_obj(self, obj):
}
return serial
if isinstance(obj, datetime):
serial = obj.strftime("%Y-%m-%dT%H:%M:%S.%f")
serial = obj.isoformat()
return serial
if isinstance(obj, fieldtypes.digest):
return {
Expand Down
15 changes: 9 additions & 6 deletions flow/record/packer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import functools
import warnings
from datetime import datetime, timezone

import msgpack

Expand Down Expand Up @@ -29,6 +29,8 @@
RECORD_PACK_TYPE_VARINT = 0x11
RECORD_PACK_TYPE_GROUPEDRECORD = 0x12

UTC = timezone.utc


def identifier_to_str(identifier):
if isinstance(identifier, tuple) and len(identifier) == 2:
Expand Down Expand Up @@ -61,9 +63,11 @@ def register(self, desc, notify=False):
def pack_obj(self, obj, unversioned=False):
packed = None

if isinstance(obj, datetime.datetime):
t = obj.utctimetuple()[:6] + (obj.microsecond,)
packed = (RECORD_PACK_TYPE_DATETIME, t)
if isinstance(obj, datetime):
if obj.tzinfo is None or obj.tzinfo == UTC:
packed = (RECORD_PACK_TYPE_DATETIME, (*obj.timetuple()[:6], obj.microsecond))
else:
packed = (RECORD_PACK_TYPE_DATETIME, (obj.isoformat(),))

elif isinstance(obj, int):
neg = obj < 0
Expand Down Expand Up @@ -102,8 +106,7 @@ def unpack_obj(self, t, data):
subtype, value = self.unpack(data)

if subtype == RECORD_PACK_TYPE_DATETIME:
dt = fieldtypes.datetime(*value)
return dt
return fieldtypes.datetime(*value)

if subtype == RECORD_PACK_TYPE_VARINT:
neg, h = value
Expand Down
4 changes: 2 additions & 2 deletions flow/record/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(self, path_template=None, name=None):

def rotate_existing_file(self, path):
if os.path.exists(path):
now = datetime.datetime.utcnow()
now = datetime.datetime.now(datetime.timezone.utc)
src = os.path.realpath(path)

src_dir = os.path.dirname(src)
Expand Down Expand Up @@ -226,7 +226,7 @@ def record_stream_for_path(self, path):
return self.writer

def write(self, record):
ts = record._generated or datetime.datetime.utcnow()
ts = record._generated or datetime.datetime.now(datetime.timezone.utc)
path = self.path_template.format(name=self.name, record=record, ts=ts)
rs = self.record_stream_for_path(path)
rs.write(record)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [
]
dependencies = [
"msgpack>=0.5.2",
"backports.zoneinfo[tzdata]; python_version<'3.9'",
yunzheng marked this conversation as resolved.
Show resolved Hide resolved
]
dynamic = ["version"]

Expand Down
4 changes: 2 additions & 2 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_records(count=100):
)

for i in range(count):
embedded = TestRecordEmbedded(datetime.datetime.utcnow())
embedded = TestRecordEmbedded(datetime.datetime.now(datetime.timezone.utc))
yield TestRecord(number=i, record=embedded)


Expand All @@ -33,4 +33,4 @@ def generate_plain_records(count=100):
)

for i in range(count):
yield TestRecord(number=i, dt=datetime.datetime.utcnow())
yield TestRecord(number=i, dt=datetime.datetime.now(datetime.timezone.utc))
Loading