From e6e916f2bb474f7a986aedf66a6adf9bd28db44d Mon Sep 17 00:00:00 2001
From: Eric Vergnaud
Date: Wed, 4 Sep 2024 11:23:18 +0200
Subject: [PATCH] also fix field type in _row_to_sql

---
Review note: _field_type now also catches KeyError. In an imported module
__builtins__ is a dict, so a string annotation that is not a builtin name
raises KeyError (not TypeError); it is now logged and the raw field.type is
returned, as the TypeError path already did.

 src/databricks/labs/lsql/backends.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py
index 4d32fb9a..6be86f7a 100644
--- a/src/databricks/labs/lsql/backends.py
+++ b/src/databricks/labs/lsql/backends.py
@@ -69,14 +69,7 @@ def create_table(self, full_name: str, klass: Dataclass):
     def _schema_for(cls, klass: Dataclass):
         fields = []
         for f in dataclasses.fields(klass):
-            field_type = f.type
-            # workaround rare (Python?) issue where f.type is the type name instead of the type itself
-            # this seems to happen when the dataclass is first used from a file importing it
-            if isinstance(field_type, str):
-                try:
-                    field_type = __builtins__[field_type]
-                except TypeError as e:
-                    logger.warning(f"Could not load type {field_type}", exc_info=e)
+            field_type = cls._field_type(f)
             if isinstance(field_type, UnionType):
                 field_type = field_type.__args__[0]
             if field_type not in cls._builtin_type_mapping:
@@ -89,6 +82,17 @@ def _schema_for(cls, klass: Dataclass):
             fields.append(f"{f.name} {spark_type}{not_null}")
         return ", ".join(fields)
 
+    @classmethod
+    def _field_type(cls, field: dataclasses.Field):
+        # workaround rare (Python?) issue where f.type is the type name instead of the type itself
+        # this seems to happen when the dataclass is first used from a file importing it
+        if isinstance(field.type, str):
+            try:
+                return __builtins__[field.type]
+            except (KeyError, TypeError) as e:  # KeyError: the name is not a builtin
+                logger.warning(f"Could not load type {field.type}", exc_info=e)
+        return field.type
+
     @classmethod
     def _filter_none_rows(cls, rows, klass):
         if len(rows) == 0:
@@ -168,12 +172,12 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
             sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
             self.execute(sql)
 
-    @staticmethod
-    def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
+    @classmethod
+    def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
         data = []
         for f in fields:
             value = getattr(row, f.name)
-            field_type = f.type
+            field_type = cls._field_type(f)
             if isinstance(field_type, UnionType):
                 field_type = field_type.__args__[0]
             if value is None: