Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,7 @@ def create_table(self, full_name: str, klass: Dataclass):
def _schema_for(cls, klass: Dataclass):
fields = []
for f in dataclasses.fields(klass):
field_type = f.type
# workaround rare (Python?) issue where f.type is the type name instead of the type itself
# this seems to happen when the dataclass is first used from a file importing it
if isinstance(field_type, str):
try:
field_type = __builtins__[field_type]
except TypeError as e:
logger.warning(f"Could not load type {field_type}", exc_info=e)
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if field_type not in cls._builtin_type_mapping:
Expand All @@ -89,6 +82,17 @@ def _schema_for(cls, klass: Dataclass):
fields.append(f"{f.name} {spark_type}{not_null}")
return ", ".join(fields)

@classmethod
def _field_type(cls, field: dataclasses.Field):
# workaround rare (Python?) issue where f.type is the type name instead of the type itself
# this seems to happen when the dataclass is first used from a file importing it
if isinstance(field.type, str):
try:
return __builtins__[field.type]
except TypeError as e:
logger.warning(f"Could not load type {field.type}", exc_info=e)
return field.type

@classmethod
def _filter_none_rows(cls, rows, klass):
if len(rows) == 0:
Expand Down Expand Up @@ -168,12 +172,12 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
self.execute(sql)

@staticmethod
def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
@classmethod
def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
data = []
for f in fields:
value = getattr(row, f.name)
field_type = f.type
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if value is None:
Expand Down