Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make datetime fieldtypes timezone aware #78

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
index: str = "records",
http_compress: Union[str, bool] = True,
selector: Union[None, Selector, CompiledSelector] = None,
**kwargs
**kwargs,
) -> None:
self.index = index
self.uri = uri
Expand Down
5 changes: 3 additions & 2 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
import sys
import warnings
from datetime import datetime
from datetime import datetime, timezone
from itertools import zip_longest
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
from urllib.parse import parse_qsl, urlparse
Expand Down Expand Up @@ -44,6 +44,7 @@
from .whitelist import WHITELIST, WHITELIST_TREE

log = logging.getLogger(__package__)
_utcnow = functools.partial(datetime.now, timezone.utc)

RECORD_VERSION = 1
RESERVED_FIELDS = OrderedDict(
Expand Down Expand Up @@ -422,7 +423,7 @@ def _generate_record_class(name: str, fields: Tuple[Tuple[str, str]]) -> type:
_globals = {
"Record": Record,
"RECORD_VERSION": RECORD_VERSION,
"_utcnow": datetime.utcnow,
"_utcnow": _utcnow,
"_zip_longest": zip_longest,
}
for field in all_fields.values():
Expand Down
94 changes: 62 additions & 32 deletions flow/record/fieldtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
from __future__ import annotations

import binascii
import math
import os
import pathlib
import re
import sys
import warnings
from binascii import a2b_hex, b2a_hex
from datetime import datetime as _dt
from datetime import timezone
from posixpath import basename, dirname
from typing import Any, Tuple

try:
import urlparse
except ImportError:
import urllib.parse as urlparse

import warnings
from typing import Any, Optional, Tuple
from urllib.parse import urlparse
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

from flow.record.base import FieldType

RE_NORMALIZE_PATH = re.compile(r"[\\/]+")
RE_STRIP_NANOSECS = re.compile(r"(\.\d{6})\d+")
NATIVE_UNICODE = isinstance("", str)

UTC = timezone.utc
ISO_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
ISO_FORMAT_WITH_MS = "%Y-%m-%dT%H:%M:%S.%f%z"

PY_311 = sys.version_info >= (3, 11, 0)

PATH_POSIX = 0
PATH_WINDOWS = 1

Expand All @@ -32,6 +37,31 @@
path_type = pathlib.PurePath


def flow_record_tz(*, default_tz: str = "UTC") -> "Optional[ZoneInfo | timezone]":
    """Return the timezone selected by the ``FLOW_RECORD_TZ`` environment variable.

    Args:
        default_tz: Timezone name to use if ``FLOW_RECORD_TZ`` is not set (default: UTC).

    Returns:
        ``None`` if the selected name is ``NONE`` (case-insensitive), otherwise
        ``ZoneInfo(name)``. If the name cannot be loaded as a ``ZoneInfo``, a warning
        is issued and ``timezone.utc`` is returned instead.
    """
    tz = os.environ.get("FLOW_RECORD_TZ", default_tz)
    if tz.upper() == "NONE":
        return None
    try:
        return ZoneInfo(tz)
    except (ZoneInfoNotFoundError, ValueError) as exc:
        # ZoneInfo raises ValueError (not ZoneInfoNotFoundError) for malformed keys such as
        # an empty string, an absolute path or a non-normalized path like "a/../b".
        # Treat those the same as an unknown zone so a bad environment value never
        # prevents this function from returning a usable timezone.
        warnings.warn(f"{exc!r}, falling back to timezone.utc")
        return timezone.utc


# The environment variable ``FLOW_RECORD_TZ`` controls the timezone used when displaying datetime fields.
#
# By default this is UTC. Set ``FLOW_RECORD_TZ=NONE`` to display datetime fields in their own timezone without conversion.
DISPLAY_TZINFO = flow_record_tz(default_tz="UTC")


def defang(value: str) -> str:
"""Defangs the value to make URLs or ip addresses unclickable"""
value = re.sub("^http://", "hxxp://", value, flags=re.IGNORECASE)
Expand Down Expand Up @@ -238,24 +268,24 @@ def __new__(cls, *args, **kwargs):
# String constructor is used for example in JsonRecordAdapter
# Note: ISO 8601 is fully implemented in fromisoformat() from Python 3.11 and onwards.
# Until then, we need to manually detect timezone info and handle it.
if any(z in arg[19:] for z in ["Z", "+", "-"]):
if "." in arg[19:]:
try:
return cls.strptime(arg, "%Y-%m-%dT%H:%M:%S.%f%z")
except ValueError:
# Sometimes nanoseconds need to be stripped
return cls.strptime(re.sub(RE_STRIP_NANOSECS, "\\1", arg), "%Y-%m-%dT%H:%M:%S.%f%z")
return cls.strptime(arg, "%Y-%m-%dT%H:%M:%S%z")
if not PY_311 and any(z in arg[19:] for z in ["Z", "+", "-"]):
spec = ISO_FORMAT_WITH_MS if "." in arg[19:] else ISO_FORMAT
try:
obj = cls.strptime(arg, spec)
except ValueError:
# Sometimes nanoseconds need to be stripped
obj = cls.strptime(re.sub(RE_STRIP_NANOSECS, "\\1", arg), spec)
else:
try:
return cls.fromisoformat(arg)
obj = cls.fromisoformat(arg)
except ValueError:
# Sometimes nanoseconds need to be stripped
return cls.fromisoformat(re.sub(RE_STRIP_NANOSECS, "\\1", arg))
obj = cls.fromisoformat(re.sub(RE_STRIP_NANOSECS, "\\1", arg))
elif isinstance(arg, (int, float_type)):
return cls.utcfromtimestamp(arg)
obj = cls.fromtimestamp(arg, UTC)
elif isinstance(arg, (_dt,)):
return _dt.__new__(
tzinfo = arg.tzinfo or UTC
obj = _dt.__new__(
cls,
arg.year,
arg.month,
Expand All @@ -264,24 +294,24 @@ def __new__(cls, *args, **kwargs):
arg.minute,
arg.second,
arg.microsecond,
arg.tzinfo,
tzinfo,
)
else:
obj = _dt.__new__(cls, *args, **kwargs)

return _dt.__new__(cls, *args, **kwargs)

def __eq__(self, other):
# Avoid TypeError: can't compare offset-naive and offset-aware datetimes
# naive datetimes are treated as UTC in flow.record instead of local time
ts1 = self.timestamp() if self.tzinfo else self.replace(tzinfo=timezone.utc).timestamp()
ts2 = other.timestamp() if other.tzinfo else other.replace(tzinfo=timezone.utc).timestamp()
return ts1 == ts2
# Ensure we always return a timezone aware datetime. Treat naive datetimes as UTC
if obj.tzinfo is None:
obj = obj.replace(tzinfo=UTC)
return obj

def _pack(self):
return self

def __str__(self):
return self.astimezone(DISPLAY_TZINFO).isoformat(" ") if DISPLAY_TZINFO else self.isoformat(" ")

def __repr__(self):
result = str(self)
return result
return str(self)

def __hash__(self):
return _dt.__hash__(self)
Expand Down Expand Up @@ -462,7 +492,7 @@ def _unpack(cls, data):

class uri(string, FieldType):
def __init__(self, value):
self._parsed = urlparse.urlparse(value)
self._parsed = urlparse(value)

@staticmethod
def normalize(path):
Expand Down
2 changes: 1 addition & 1 deletion flow/record/jsonpacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def pack_obj(self, obj):
}
return serial
if isinstance(obj, datetime):
serial = obj.strftime("%Y-%m-%dT%H:%M:%S.%f")
serial = obj.isoformat()
return serial
if isinstance(obj, fieldtypes.digest):
return {
Expand Down
15 changes: 9 additions & 6 deletions flow/record/packer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import functools
import warnings
from datetime import datetime, timezone

import msgpack

Expand Down Expand Up @@ -29,6 +29,8 @@
RECORD_PACK_TYPE_VARINT = 0x11
RECORD_PACK_TYPE_GROUPEDRECORD = 0x12

UTC = timezone.utc


def identifier_to_str(identifier):
if isinstance(identifier, tuple) and len(identifier) == 2:
Expand Down Expand Up @@ -61,9 +63,11 @@ def register(self, desc, notify=False):
def pack_obj(self, obj, unversioned=False):
packed = None

if isinstance(obj, datetime.datetime):
t = obj.utctimetuple()[:6] + (obj.microsecond,)
packed = (RECORD_PACK_TYPE_DATETIME, t)
if isinstance(obj, datetime):
if obj.tzinfo is None or obj.tzinfo == UTC:
packed = (RECORD_PACK_TYPE_DATETIME, (*obj.timetuple()[:6], obj.microsecond))
else:
packed = (RECORD_PACK_TYPE_DATETIME, (obj.isoformat(),))

elif isinstance(obj, int):
neg = obj < 0
Expand Down Expand Up @@ -102,8 +106,7 @@ def unpack_obj(self, t, data):
subtype, value = self.unpack(data)

if subtype == RECORD_PACK_TYPE_DATETIME:
dt = fieldtypes.datetime(*value)
return dt
return fieldtypes.datetime(*value)

if subtype == RECORD_PACK_TYPE_VARINT:
neg, h = value
Expand Down
4 changes: 2 additions & 2 deletions flow/record/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(self, path_template=None, name=None):

def rotate_existing_file(self, path):
if os.path.exists(path):
now = datetime.datetime.utcnow()
now = datetime.datetime.now(datetime.timezone.utc)
src = os.path.realpath(path)

src_dir = os.path.dirname(src)
Expand Down Expand Up @@ -226,7 +226,7 @@ def record_stream_for_path(self, path):
return self.writer

def write(self, record):
ts = record._generated or datetime.datetime.utcnow()
ts = record._generated or datetime.datetime.now(datetime.timezone.utc)
path = self.path_template.format(name=self.name, record=record, ts=ts)
rs = self.record_stream_for_path(path)
rs.write(record)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ classifiers = [
]
dependencies = [
"msgpack>=0.5.2",
"backports.zoneinfo[tzdata]; python_version<'3.9'",
yunzheng marked this conversation as resolved.
Show resolved Hide resolved
"tzdata; platform_system=='Windows'",
]
dynamic = ["version"]

Expand Down
4 changes: 2 additions & 2 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_records(count=100):
)

for i in range(count):
embedded = TestRecordEmbedded(datetime.datetime.utcnow())
embedded = TestRecordEmbedded(datetime.datetime.now(datetime.timezone.utc))
yield TestRecord(number=i, record=embedded)


Expand All @@ -33,4 +33,4 @@ def generate_plain_records(count=100):
)

for i in range(count):
yield TestRecord(number=i, dt=datetime.datetime.utcnow())
yield TestRecord(number=i, dt=datetime.datetime.now(datetime.timezone.utc))
Loading