Skip to content

Commit

Permalink
Add to_avro for AvroConverter
Browse files Browse the repository at this point in the history
`AvroConverter` can now convert `RecapType`s to Avro schemas.

The current implementation supports:

* Simple types
* Complex types
* Logical types
* Compact notation (`{"type": "string"}` -> `"string"`)

The implementation does not support:

* ProxyType

Closes #303
  • Loading branch information
criccomini committed Jul 3, 2023
1 parent 6627441 commit 8f34ab4
Show file tree
Hide file tree
Showing 2 changed files with 447 additions and 24 deletions.
202 changes: 200 additions & 2 deletions recap/converters/avro.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any

from recap.types import (
BoolType,
Expand All @@ -22,6 +23,205 @@ class AvroConverter:
def __init__(self) -> None:
self.registry = RecapTypeRegistry()

def from_recap(self, recap_type: RecapType) -> dict[str, Any]:
"""
Convert a Recap type to an Avro schema.
:param recap_type: The Recap type to convert.
:return: The Avro schema.
"""

avro_schema = self._from_recap(recap_type)
if isinstance(avro_schema, str):
avro_schema = {"type": avro_schema}
return avro_schema

def _from_recap(self, recap_type: RecapType) -> dict[str, Any] | str:
"""
Convert a Recap type to an Avro schema. Can return a string if the
schema is simple.
:param recap_type: The Recap type to convert.
:return: The Avro schema or simple string type.
"""

avro_schema = {}

if recap_type.doc:
avro_schema["doc"] = recap_type.doc

if ((name := recap_type.alias) and not isinstance(recap_type, ProxyType)) or (
name := recap_type.extra_attrs.get("name")
):
avro_schema["name"] = name

if "default" in recap_type.extra_attrs:
avro_schema["default"] = recap_type.extra_attrs["default"]

match recap_type:
case RecapType(logical=str()):
avro_schema |= self._from_recap_logical(recap_type)
case NullType():
avro_schema["type"] = "null"
case BoolType():
avro_schema["type"] = "boolean"
case IntType(bits=int(bits), signed=True) if bits <= 32:
avro_schema["type"] = "int"
case IntType(bits=int(bits), signed=bool(signed)) if (
bits <= 32 and not signed
) or (bits <= 64 and signed):
avro_schema["type"] = "long"
case IntType(bits=int(bits), signed=bool(signed)):
precision = len(str(2 ** (bits - (1 if signed else 0)))) - 1
avro_schema |= {
"type": "bytes",
"logicalType": "decimal",
"precision": precision,
}
case FloatType(bits=int(bits)):
avro_schema["type"] = "double" if bits == 64 else "float"
case StringType(bytes_=int(bytes_)) if bytes_ <= 9_223_372_036_854_775_807:
avro_schema["type"] = "string"
case BytesType(
bytes_=int(bytes_),
variable=True,
) if bytes_ <= 9_223_372_036_854_775_807:
avro_schema["type"] = "bytes"
case BytesType(
bytes_=int(bytes_),
variable=False,
) if bytes_ <= 9_223_372_036_854_775_807:
avro_schema["type"] = "fixed"
avro_schema["size"] = bytes_
case StructType(fields=list(struct_fields)):
avro_schema["type"] = "record"
record_fields = []
for struct_field in struct_fields:
record_field = {}
field_type = self._from_recap(struct_field)
if isinstance(field_type, dict):
if name := field_type.pop("name", None):
record_field["name"] = name
if default := field_type.pop("default", None):
record_field["default"] = default
if doc := field_type.pop("doc", None):
record_field["doc"] = doc
if len(field_type) == 1 and "type" in field_type:
# Convert {"type": "type"} to "type". Have to do
# this here since we're popping stuff out of the
# nested dict. We might end up with a simple type
# afterwards.
record_field["type"] = field_type["type"]
else:
record_field["type"] = field_type
record_fields.append(record_field)
avro_schema["fields"] = record_fields
case EnumType(symbols=list(symbols)):
avro_schema["type"] = "enum"
avro_schema["symbols"] = symbols
case ListType(values=values):
avro_schema["type"] = "array"
avro_schema["items"] = self._from_recap(values)
case MapType(
keys=StringType(),
values=values,
): # Avro only supports string keys
avro_schema["type"] = "map"
avro_schema["values"] = self._from_recap(values)
case UnionType(types=list(types)):
avro_schema["type"] = [self._from_recap(t) for t in types]
case _:
raise ValueError(f"Unsupported Recap type: {recap_type}")

if len(avro_schema) == 1 and isinstance(avro_schema["type"], str):
# Convert {"type": "type"} to "type"
avro_schema = avro_schema["type"]

return avro_schema

def _from_recap_logical(self, recap_type: RecapType) -> dict[str, Any]:
match recap_type:
case BytesType(
bytes_=int(bytes_),
variable=bool(),
logical=str(logical),
) if (
bytes_ <= 9_223_372_036_854_775_807
and logical == "build.recap.Decimal"
and isinstance(recap_type.extra_attrs.get("precision"), int)
):
return {
"type": "bytes",
"logicalType": "decimal",
"precision": recap_type.extra_attrs["precision"],
"scale": recap_type.extra_attrs.get("scale", 0),
}
case IntType(bits=int(bits), signed=True, logical=str(logical)) if (
bits <= 32 and logical == "build.recap.Date"
):
return {
"type": "int",
"logicalType": "date",
}
case IntType(bits=int(bits), signed=True, logical=str(logical)) if (
bits <= 32
and logical == "build.recap.Time"
and recap_type.extra_attrs.get("unit") == "millisecond"
):
return {
"type": "int",
"logicalType": "time-millis",
}
case IntType(bits=int(bits), signed=True, logical=str(logical)) if (
bits <= 64
and logical == "build.recap.Time"
and recap_type.extra_attrs.get("unit") == "microsecond"
):
return {
"type": "long",
"logicalType": "time-micros",
}
case IntType(bits=int(bits), signed=True, logical=str(logical)) if (
bits <= 64
and logical == "build.recap.Timestamp"
and recap_type.extra_attrs.get("unit") == "millisecond"
):
return {
"type": "long",
"logicalType": "timestamp-millis",
}
case IntType(bits=int(bits), signed=True, logical=str(logical)) if (
bits <= 64
and logical == "build.recap.Timestamp"
and recap_type.extra_attrs.get("unit") == "microsecond"
):
return {
"type": "long",
"logicalType": "timestamp-micros",
}
case BytesType(
bytes_=12,
variable=False,
logical=str(logical),
) if (
logical == "build.recap.Interval"
and recap_type.extra_attrs.get("unit") == "millisecond"
):
return {
"type": "fixed",
"logicalType": "duration",
"size": 12,
}
case StringType(bytes_=int(bytes_), logical=str(logical)) if (
bytes_ <= 9_223_372_036_854_775_807 and logical == "build.recap.UUID"
):
return {
"type": "string",
"logicalType": "uuid",
}
case _:
raise ValueError(f"Unsupported Recap logical type: {recap_type}")

def to_recap(self, avro_schema_str: str) -> StructType:
avro_schema = json.loads(avro_schema_str)
recap_schema = self._parse(avro_schema)
Expand Down Expand Up @@ -152,7 +352,6 @@ def _parse_logical(
bits=64,
signed=True,
unit="millisecond",
timezone="UTC",
**extra_attrs,
)
case "timestamp-micros":
Expand All @@ -161,7 +360,6 @@ def _parse_logical(
bits=64,
signed=True,
unit="microsecond",
timezone="UTC",
**extra_attrs,
)
case "duration":
Expand Down
Loading

0 comments on commit 8f34ab4

Please sign in to comment.