diff --git a/cloudquery/sdk/internal/memdb/memdb.py b/cloudquery/sdk/internal/memdb/memdb.py index 78f3ac9..4f422f5 100644 --- a/cloudquery/sdk/internal/memdb/memdb.py +++ b/cloudquery/sdk/internal/memdb/memdb.py @@ -1,7 +1,7 @@ from cloudquery.sdk import plugin from cloudquery.sdk import message from cloudquery.sdk import schema -from typing import List, Generator, Any, Dict +from typing import List, Generator, Dict import pyarrow as pa NAME = "memdb" @@ -24,13 +24,13 @@ def sync( for table, record in self._db.items(): yield message.SyncInsertMessage(record) - def write(self, msg_iterator: Generator[message.WriteMessage, None, None]) -> None: - for msg in msg_iterator: - if type(msg) == message.WriteMigrateTableMessage: + def write(self, writer: Generator[message.WriteMessage, None, None]) -> None: + for msg in writer: + if isinstance(msg, message.WriteMigrateTableMessage): if msg.table.name not in self._db: self._db[msg.table.name] = msg.table self._tables[msg.table.name] = msg.table - elif type(msg) == message.WriteInsertMessage: + elif isinstance(msg, message.WriteInsertMessage): table = schema.Table.from_arrow_schema(msg.record.schema) self._db[table.name] = msg.record else: diff --git a/cloudquery/sdk/scalar/binary.py b/cloudquery/sdk/scalar/binary.py index 0617c5e..fb31662 100644 --- a/cloudquery/sdk/scalar/binary.py +++ b/cloudquery/sdk/scalar/binary.py @@ -6,7 +6,7 @@ class Binary(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Binary: + if isinstance(scalar, Binary): return self._value == scalar._value and self._valid == scalar._valid return False @@ -23,13 +23,13 @@ def set(self, value: any): self._value = value.value return - if type(value) == bytes: + if isinstance(value, bytes): self._valid = True self._value = value - elif type(value) == str: + elif isinstance(value, str): self._valid = True self._value = value.encode() else: raise ScalarInvalidTypeError( - "Invalid type {} for Binary scalar".format(type(value)) + f"Invalid type {type(value)} for Binary scalar" ) diff --git a/cloudquery/sdk/scalar/bool.py b/cloudquery/sdk/scalar/bool.py index 6867998..fc8823c 100644 --- a/cloudquery/sdk/scalar/bool.py +++ b/cloudquery/sdk/scalar/bool.py @@ -10,17 +10,17 @@ def parse_string_to_bool(input_string): if lower_input in true_strings: return True - elif lower_input in false_strings: + if lower_input in false_strings: return False else: - raise ScalarInvalidTypeError("Invalid boolean string: {}".format(input_string)) + raise ScalarInvalidTypeError(f"Invalid boolean string: {input_string}") class Bool(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Bool: + if isinstance(scalar, Bool): return self._value == scalar._value and self._valid == scalar._valid return False @@ -38,13 +38,11 @@ def set(self, value: Any): self._value = value.value return - if type(value) == bool: + if isinstance(value, bool): self._value = value - elif type(value) == str: + elif isinstance(value, str): self._value = parse_string_to_bool(value) else: - raise ScalarInvalidTypeError( - "Invalid type {} for Bool scalar".format(type(value)) - ) + raise ScalarInvalidTypeError(f"Invalid type {type(value)} for Bool scalar") self._valid = True diff --git a/cloudquery/sdk/scalar/date32.py b/cloudquery/sdk/scalar/date32.py index 24597f0..148f0a1 100644 --- a/cloudquery/sdk/scalar/date32.py +++ b/cloudquery/sdk/scalar/date32.py @@ -7,7 +7,7 @@ class Date32(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Date32: + if isinstance(scalar, Date32): return self._value == scalar._value and self._valid == scalar._valid return False @@ -25,15 +25,15 @@ def set(self, value: Any): self._value = value.value return - if type(value) == datetime: + if isinstance(value, datetime): self._value = value - elif type(value) == str: + elif isinstance(value, str): self._value = datetime.strptime(value, "%Y-%m-%d") - elif type(value) == time: + elif isinstance(value, time): self._value = datetime.combine(datetime.today(), value) else: raise ScalarInvalidTypeError( - "Invalid type {} for Date32 scalar".format(type(value)) + f"Invalid type {type(value)} for Date32 scalar" ) self._valid = True diff --git a/cloudquery/sdk/scalar/date64.py b/cloudquery/sdk/scalar/date64.py index ba88953..3c0dbc0 100644 --- a/cloudquery/sdk/scalar/date64.py +++ b/cloudquery/sdk/scalar/date64.py @@ -7,7 +7,7 @@ class Date64(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Date64: + if isinstance(scalar, Date64): return self._value == scalar._value and self._valid == scalar._valid return False @@ -25,15 +25,15 @@ def set(self, value: Any): self._value = value.value return - if type(value) == datetime: + if isinstance(value, datetime): self._value = value - elif type(value) == str: + elif isinstance(value, str): self._value = datetime.strptime(value, "%Y-%m-%d") - elif type(value) == time: + elif isinstance(value, time): self._value = datetime.combine(datetime.today(), value) else: raise ScalarInvalidTypeError( - "Invalid type {} for Date64 scalar".format(type(value)) + f"Invalid type {type(value)} for Date64 scalar" ) self._valid = True diff --git a/cloudquery/sdk/scalar/float.py b/cloudquery/sdk/scalar/float.py index 49ee15d..33eadb3 100644 --- a/cloudquery/sdk/scalar/float.py +++ b/cloudquery/sdk/scalar/float.py @@ -9,7 +9,7 @@ def __init__(self, valid: bool = False, value: any = None, bitwidth: int = 64): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Float and self._bitwidth == scalar.bitwidth: + if isinstance(scalar, Float) and self._bitwidth == scalar.bitwidth: return self._value == scalar._value and self._valid == scalar._valid return False @@ -31,19 +31,19 @@ def set(self, value: any): self._value = value.value return - if type(value) == int: + if isinstance(value, int): self._value = float(value) - elif type(value) == float: + elif isinstance(value, float): self._value = value - elif type(value) == str: + elif isinstance(value, str): try: self._value = float(value) - except ValueError: + except ValueError as e: raise ScalarInvalidTypeError( - "Invalid value for Float{} scalar".format(self._bitwidth) - ) + f"Invalid value for Float{self._bitwidth} scalar" + ) from e else: raise ScalarInvalidTypeError( - "Invalid type {} for Float{} scalar".format(type(value), self._bitwidth) + f"Invalid type {type(value)} for Float{self._bitwidth} scalar" ) self._valid = True diff --git a/cloudquery/sdk/scalar/int.py b/cloudquery/sdk/scalar/int.py index 8133429..3e57e1b 100644 --- a/cloudquery/sdk/scalar/int.py +++ b/cloudquery/sdk/scalar/int.py @@ -11,7 +11,7 @@ def __init__(self, valid: bool = False, value: any = None, bitwidth: int = 64): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Int and self._bitwidth == scalar.bitwidth: + if isinstance(scalar, Int) and self._bitwidth == scalar.bitwidth: return self._value == scalar._value and self._valid == scalar._valid return False @@ -33,24 +33,24 @@ def set(self, value: any): self._value = value.value return - if type(value) == int: + if isinstance(value, int): val = value - elif type(value) == float: + elif isinstance(value, float): val = int(value) - elif type(value) == str: + elif isinstance(value, str): try: val = int(value) except ValueError as e: raise ScalarInvalidTypeError( - "Invalid type for Int{} scalar".format(self._bitwidth) + f"Invalid type for Int{self._bitwidth} scalar" ) from e else: raise ScalarInvalidTypeError( - "Invalid type {} for Int{} scalar".format(type(value), self._bitwidth) + f"Invalid type {type(value)} for Int{self._bitwidth} scalar" ) if val < self._min or val >= self._max: raise ScalarInvalidTypeError( - "Invalid Int{} scalar with value {}".format(self._bitwidth, val) + f"Invalid Int{self._bitwidth} scalar with value {val}" ) self._value = val self._valid = True diff --git a/cloudquery/sdk/scalar/json.py b/cloudquery/sdk/scalar/json.py index b06bb79..6160db8 100644 --- a/cloudquery/sdk/scalar/json.py +++ b/cloudquery/sdk/scalar/json.py @@ -6,7 +6,7 @@ class JSON(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == JSON: + if isinstance(scalar, JSON): return self._value == scalar._value and self._valid == scalar._valid return False @@ -18,7 +18,7 @@ def set(self, value: any): if value is None: return - if type(value) == str or type(value) == bytes: + if isinstance(value, (str, bytes)): # test if it is a valid json json.loads(value) self._value = value diff --git a/cloudquery/sdk/scalar/list.py b/cloudquery/sdk/scalar/list.py index 66560d3..0d0f1df 100644 --- a/cloudquery/sdk/scalar/list.py +++ b/cloudquery/sdk/scalar/list.py @@ -1,7 +1,7 @@ from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError from .scalar import NULL_VALUE from .vector import Vector -from typing import Any, Type, Union +from typing import Any, Type class List(Scalar): @@ -10,12 +10,8 @@ def __init__(self, scalar_type: Type[Scalar]): self._value = Vector(scalar_type) self._type = scalar_type - def __eq__(self, other: Union[None, "List"]) -> bool: - if other is None: - return False - if type(self) != type(other): - return False - if self._valid != other._valid: + def __eq__(self, other: "List") -> bool: + if (not isinstance(other, self.__class__)) or self._valid != other._valid: return False return self._value == other._value @@ -27,29 +23,29 @@ def type(self): def value(self): return self._value - def set(self, val: Any): - if val is None: + def set(self, value: Any): + if value is None: self._valid = False self._value = Vector() return - if isinstance(val, Scalar) and type(val) == self._type: - if not val.is_valid: + if isinstance(value, self._type): + if not value.is_valid: self._valid = False self._value = Vector() return - return self.set([val.value]) + return self.set([value.value]) - if isinstance(val, (list, tuple)): + if isinstance(value, (list, tuple)): self._value = Vector() - for item in val: + for item in value: scalar = self._type() scalar.set(item) self._value.append(scalar) self._valid = True return - raise ScalarInvalidTypeError("Invalid type {} for List".format(type(val))) + raise ScalarInvalidTypeError(f"Invalid type {type(value)} for List") def __str__(self) -> str: if not self._valid: diff --git a/cloudquery/sdk/scalar/scalar_factory.py b/cloudquery/sdk/scalar/scalar_factory.py index 8bbf1d4..1e7bc51 100644 --- a/cloudquery/sdk/scalar/scalar_factory.py +++ b/cloudquery/sdk/scalar/scalar_factory.py @@ -1,3 +1,5 @@ +from functools import partial + import pyarrow as pa from cloudquery.sdk.types import UUIDType, JSONType @@ -22,74 +24,41 @@ def __init__(self): def new_scalar(self, dt: pa.DataType): dt_id = dt.id - if dt_id == pa.types.lib.Type_INT64: - return Int(bitwidth=64) - elif dt_id == pa.types.lib.Type_INT32: - return Int(bitwidth=32) - elif dt_id == pa.types.lib.Type_INT16: - return Int(bitwidth=16) - elif dt_id == pa.types.lib.Type_INT8: - return Int(bitwidth=8) - elif dt_id == pa.types.lib.Type_UINT64: - return Uint(bitwidth=64) - elif dt_id == pa.types.lib.Type_UINT32: - return Uint(bitwidth=32) - elif dt_id == pa.types.lib.Type_UINT16: - return Uint(bitwidth=16) - elif dt_id == pa.types.lib.Type_UINT8: - return Uint(bitwidth=8) - elif ( - dt_id == pa.types.lib.Type_BINARY - or dt_id == pa.types.lib.Type_LARGE_BINARY - or dt_id == pa.types.lib.Type_FIXED_SIZE_BINARY - ): - return Binary() - elif dt_id == pa.types.lib.Type_BOOL: - return Bool() - elif dt_id == pa.types.lib.Type_DATE64: - return Date64() - elif dt_id == pa.types.lib.Type_DATE32: - return Date32() - # elif dt_id == pa.types.lib.Type_DECIMAL256: - # return () - # elif dt_id == pa.types.lib.Type_DECIMAL128: - # return () - # elif dt_id == pa.types.lib.Type_DICTIONARY: - # return () - # elif dt_id == pa.types.lib.Type_DURATION: - # return () - elif dt_id == pa.types.lib.Type_DOUBLE: - return Float(bitwidth=64) - elif dt_id == pa.types.lib.Type_FLOAT: - return Float(bitwidth=32) - elif dt_id == pa.types.lib.Type_HALF_FLOAT: - return Float(bitwidth=16) - # elif dt_id == pa.types.lib.Type_INTERVAL_MONTH_DAY_NANO: - # return () - elif ( - dt_id == pa.types.lib.Type_LIST - or dt_id == pa.types.lib.Type_LARGE_LIST - or dt_id == pa.types.lib.Type_FIXED_SIZE_LIST - ): - item = ScalarFactory.new_scalar(dt.field(0).type) - return List(type(item)) - # elif dt_id == pa.types.lib.Type_MAP: - # return () - elif ( - dt_id == pa.types.lib.Type_STRING or dt_id == pa.types.lib.Type_LARGE_STRING - ): - return String() - # elif dt_id == pa.types.lib.Type_STRUCT: - # return () - # elif dt_id == pa.types.lib.Type_TIME32: - # return () - # elif dt_id == pa.types.lib.Type_TIME64: - # return () - elif dt_id == pa.types.lib.Type_TIMESTAMP: - return Timestamp() - elif dt == UUIDType(): + type_id__map = { + pa.types.lib.Type_INT64: partial(Int, bitwidth=64), + pa.types.lib.Type_INT32: partial(Int, bitwidth=32), + pa.types.lib.Type_INT16: partial(Int, bitwidth=16), + pa.types.lib.Type_INT8: partial(Int, bitwidth=8), + pa.types.lib.Type_UINT64: partial(Uint, bitwidth=64), + pa.types.lib.Type_UINT32: partial(Uint, bitwidth=32), + pa.types.lib.Type_UINT16: partial(Uint, bitwidth=16), + pa.types.lib.Type_UINT8: partial(Uint, bitwidth=8), + pa.types.lib.Type_BINARY: Binary, + pa.types.lib.Type_LARGE_BINARY: Binary, + pa.types.lib.Type_FIXED_SIZE_BINARY: Binary, + pa.types.lib.Type_BOOL: Bool, + pa.types.lib.Type_DATE64: Date64, + pa.types.lib.Type_DATE32: Date32, + pa.types.lib.Type_DOUBLE: partial(Float, bitwidth=64), + pa.types.lib.Type_FLOAT: partial(Float, bitwidth=32), + pa.types.lib.Type_HALF_FLOAT: partial(Float, bitwidth=16), + pa.types.lib.Type_LIST: List, + pa.types.lib.Type_LARGE_LIST: List, + pa.types.lib.Type_FIXED_SIZE_LIST: List, + pa.types.lib.Type_STRING: String, + pa.types.lib.Type_LARGE_STRING: String, + pa.types.lib.Type_TIMESTAMP: Timestamp, + } + # Built-in Types + if dt_id in type_id__map: + scalar_type = type_id__map[dt_id] + if scalar_type == List: + item = self.new_scalar(dt.field(0).type) + return scalar_type(type(item)) + return scalar_type() + # Extension Types - Can't do the same trick as above as they don't have `id`s and they are not hashable. :( + if dt == UUIDType(): return UUID() - elif dt == JSONType(): + if dt == JSONType(): return JSON() - else: - raise ScalarInvalidTypeError("Invalid type {} for scalar".format(dt)) + raise ScalarInvalidTypeError(f"Invalid type {dt} for scalar") diff --git a/cloudquery/sdk/scalar/string.py b/cloudquery/sdk/scalar/string.py index 9fd3a79..c6b1da6 100644 --- a/cloudquery/sdk/scalar/string.py +++ b/cloudquery/sdk/scalar/string.py @@ -5,7 +5,7 @@ class String(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == String: + if isinstance(scalar, String): return self._value == scalar._value and self._valid == scalar._valid return False @@ -22,10 +22,10 @@ def set(self, value: any): self._value = value.value return - if type(value) == str: + if isinstance(value, str): self._valid = True self._value = value else: raise ScalarInvalidTypeError( - "Invalid type {} for String scalar".format(type(value)) + f"Invalid type {type(value)} for String scalar" ) diff --git a/cloudquery/sdk/scalar/timestamp.py b/cloudquery/sdk/scalar/timestamp.py index 8009f0f..1f14780 100644 --- a/cloudquery/sdk/scalar/timestamp.py +++ b/cloudquery/sdk/scalar/timestamp.py @@ -8,7 +8,7 @@ class Timestamp(Scalar): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Timestamp: + if isinstance(scalar, Timestamp): return self._value == scalar._value and self._valid == scalar._valid return False @@ -27,15 +27,15 @@ def set(self, value: Any): if isinstance(value, pd.Timestamp): self._value = value - elif type(value) == datetime: + elif isinstance(value, datetime): self._value = pd.to_datetime(value) - elif type(value) == str: + elif isinstance(value, str): self._value = pd.to_datetime(value) - elif type(value) == time: + elif isinstance(value, time): self._value = pd.to_datetime(datetime.combine(datetime.today(), value)) else: raise ScalarInvalidTypeError( - "Invalid type {} for Timestamp scalar".format(type(value)) + f"Invalid type {type(value)} for Timestamp scalar" ) self._valid = True diff --git a/cloudquery/sdk/scalar/uint.py b/cloudquery/sdk/scalar/uint.py index 51aeab9..ea876d4 100644 --- a/cloudquery/sdk/scalar/uint.py +++ b/cloudquery/sdk/scalar/uint.py @@ -10,7 +10,7 @@ def __init__(self, valid: bool = False, value: any = None, bitwidth: int = 64): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == Uint: + if isinstance(scalar, Uint): return ( self._bitwidth == scalar.bitwidth and self._value == scalar._value @@ -36,25 +36,23 @@ def set(self, value: any): self._value = value.value return - if type(value) == int: + if isinstance(value, int): val = value - elif type(value) == float: + elif isinstance(value, float): val = int(value) - elif type(value) == str: + elif isinstance(value, str): try: val = int(value) except ValueError as e: raise ScalarInvalidTypeError( - "Invalid value for Int{} scalar".format(self._bitwidth) + f"Invalid value for Int{self._bitwidth} scalar" ) from e else: raise ScalarInvalidTypeError( - "Invalid type {} for Int{} scalar".format(type(value), self._bitwidth) + f"Invalid type {type(value)} for Int{self._bitwidth} scalar" ) if val < 0 or val >= self._max: - raise ScalarInvalidTypeError( - "Invalid Uint{} scalar".format(self._bitwidth, val) - ) + raise ScalarInvalidTypeError(f"Invalid Uint{self._bitwidth} scalar {val}") self._value = val self._valid = True diff --git a/cloudquery/sdk/scalar/uuid.py b/cloudquery/sdk/scalar/uuid.py index 39462a8..4e49f9f 100644 --- a/cloudquery/sdk/scalar/uuid.py +++ b/cloudquery/sdk/scalar/uuid.py @@ -9,7 +9,7 @@ def __init__(self, valid: bool = False, value: uuid.UUID = None): def __eq__(self, scalar: Scalar) -> bool: if scalar is None: return False - if type(scalar) == UUID: + if isinstance(scalar, UUID): return self._value == scalar._value and self._valid == scalar._valid return False @@ -27,15 +27,13 @@ def set(self, value: any): self._value = value.value return - if type(value) == uuid.UUID: + if isinstance(value, uuid.UUID): self._value = value - elif type(value) == str: + elif isinstance(value, str): try: self._value = uuid.UUID(value) except ValueError as e: raise ScalarInvalidTypeError("Invalid type for UUID scalar") from e else: - raise ScalarInvalidTypeError( - "Invalid type {} for UUID scalar".format(type(value)) - ) + raise ScalarInvalidTypeError(f"Invalid type {type(value)} for UUID scalar") self._valid = True diff --git a/cloudquery/sdk/scheduler/__init__.py b/cloudquery/sdk/scheduler/__init__.py index 85396e3..9cdffee 100644 --- a/cloudquery/sdk/scheduler/__init__.py +++ b/cloudquery/sdk/scheduler/__init__.py @@ -1,2 +1,3 @@ from .scheduler import Scheduler from .table_resolver import TableResolver +from .table_resolver import Client diff --git a/cloudquery/sdk/scheduler/scheduler.py b/cloudquery/sdk/scheduler/scheduler.py index 8f8d111..cfd37b2 100644 --- a/cloudquery/sdk/scheduler/scheduler.py +++ b/cloudquery/sdk/scheduler/scheduler.py @@ -1,6 +1,6 @@ import queue from concurrent import futures -from typing import List, Generator, Any +from typing import List, Generator, Any, Optional import structlog @@ -17,7 +17,7 @@ class ThreadPoolExecutorWithQueueSizeLimit(futures.ThreadPoolExecutor): def __init__(self, maxsize, *args, **kwargs): - super(ThreadPoolExecutorWithQueueSizeLimit, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._work_queue = queue.Queue(maxsize=maxsize) @@ -84,7 +84,7 @@ def resolve_table( resolver: TableResolver, depth: int, client: Client, - parent_item: Resource, + parent_item: Optional[Resource], res: queue.Queue, ): try: @@ -108,7 +108,7 @@ def resolve_table( resource = self.resolve_resource( resolver, client, parent_item, item ) - except Exception as e: + except Exception: self._logger.error( "failed to resolve resource", client_id=client.id(), @@ -145,7 +145,7 @@ def resolve_table( resources=total_resources, depth=depth, ) - except Exception as e: + except Exception: self._logger.error( "table resolver finished with error", client_id=client.id(), @@ -184,12 +184,12 @@ def sync( finished_table_resolvers = 0 while True: message = res.get() - if type(message) == TableResolverStarted: + if isinstance(message, TableResolverStarted): total_table_resolvers += 1 if total_table_resolvers == finished_table_resolvers: break continue - elif type(message) == TableResolverFinished: + if isinstance(message, TableResolverFinished): finished_table_resolvers += 1 if total_table_resolvers == finished_table_resolvers: break diff --git a/cloudquery/sdk/scheduler/table_resolver.py b/cloudquery/sdk/scheduler/table_resolver.py index d1ffd77..be13d7c 100644 --- a/cloudquery/sdk/scheduler/table_resolver.py +++ b/cloudquery/sdk/scheduler/table_resolver.py @@ -1,6 +1,6 @@ from cloudquery.sdk.schema.table import Table from cloudquery.sdk.schema import Resource -from typing import Any, Generator, List +from typing import Any, Generator, List, Optional class Client: @@ -8,12 +8,10 @@ def id(self) -> str: raise NotImplementedError() -from cloudquery.sdk.schema import Resource -from cloudquery.sdk.schema.table import Table - - class TableResolver: - def __init__(self, table: Table, child_resolvers=[]) -> None: + def __init__(self, table: Table, child_resolvers: Optional[List] = None) -> None: + if child_resolvers is None: + child_resolvers = [] self._table = table self._child_resolvers = child_resolvers @@ -35,7 +33,7 @@ def pre_resource_resolve(self, client: Client, resource): return def resolve_column(self, client: Client, resource: Resource, column_name: str): - if type(resource.item) is dict: + if isinstance(resource.item, dict): if column_name in resource.item: resource.set(column_name, resource.item[column_name]) else: diff --git a/cloudquery/sdk/schema/column.py b/cloudquery/sdk/schema/column.py index 9be427c..4473d57 100644 --- a/cloudquery/sdk/schema/column.py +++ b/cloudquery/sdk/schema/column.py @@ -31,7 +31,7 @@ def __repr__(self) -> str: return f"Column(name={self.name}, type={self.type}, description={self.description}, primary_key={self.primary_key}, not_null={self.not_null}, incremental_key={self.incremental_key}, unique={self.unique})" def __eq__(self, __value: object) -> bool: - if type(__value) == Column: + if isinstance(__value, Column): return ( self.name == __value.name and self.type == __value.type diff --git a/cloudquery/sdk/schema/table.py b/cloudquery/sdk/schema/table.py index 7198d63..0c740c8 100644 --- a/cloudquery/sdk/schema/table.py +++ b/cloudquery/sdk/schema/table.py @@ -35,8 +35,8 @@ def __init__( self.relations = relations self.is_incremental = is_incremental - def multiplex(self, client) -> List[Table]: - raise [client] + def multiplex(self, client: Client) -> List[Client]: + return [client] def index_column(self, column_name: str) -> int: for i, column in enumerate(self.columns): @@ -44,6 +44,10 @@ def index_column(self, column_name: str) -> int: return i raise ValueError(f"Column {column_name} not found") + @property + def resolver(self): + raise NotImplementedError + @property def primary_keys(self): return [column.name for column in self.columns if column.primary_key] diff --git a/cloudquery/sdk/transformers/openapi.py b/cloudquery/sdk/transformers/openapi.py index 300acce..0ba28a8 100644 --- a/cloudquery/sdk/transformers/openapi.py +++ b/cloudquery/sdk/transformers/openapi.py @@ -6,22 +6,20 @@ def oapi_type_to_arrow_type(field) -> pa.DataType: oapi_type = field.get("type") - if oapi_type == "string": - return pa.string() - elif oapi_type == "number": - return pa.int64() - elif oapi_type == "integer": - return pa.int64() - elif oapi_type == "boolean": - return pa.bool_() - elif oapi_type == "array": - return JSONType() - elif oapi_type == "object": - return JSONType() + type_map = { + "string": pa.string, + "number": pa.int64, + "integer": pa.int64, + "boolean": pa.bool_, + "array": JSONType, + "object": JSONType, + } + _type = pa.string + if oapi_type in type_map: + _type = type_map[oapi_type] elif oapi_type is None and "$ref" in field: - return JSONType() - else: - return pa.string() + _type = JSONType + return _type() def get_column_by_name(columns: List[Column], name: str) -> Optional[Column]: @@ -31,14 +29,18 @@ def get_column_by_name(columns: List[Column], name: str) -> Optional[Column]: return None -def oapi_definition_to_columns(definition: Dict, override_columns=[]) -> List[Column]: +def oapi_definition_to_columns( + definition: Dict, override_columns: Optional[List] = None +) -> List[Column]: columns = [] for key, value in definition["properties"].items(): column_type = oapi_type_to_arrow_type(value) column = Column( name=key, type=column_type, description=value.get("description") ) - override_column = get_column_by_name(override_columns, key) + override_column = get_column_by_name( + override_columns if override_columns is not None else [], key + ) if override_column is not None: column.type = override_column.type column.primary_key = override_column.primary_key diff --git a/tests/serve/plugin.py b/tests/serve/plugin.py index 1619e53..26c212c 100644 --- a/tests/serve/plugin.py +++ b/tests/serve/plugin.py @@ -5,7 +5,6 @@ from concurrent import futures from cloudquery.sdk.schema import Table, Column from cloudquery.sdk import serve -from cloudquery.sdk import message from cloudquery.plugin_v3 import plugin_pb2_grpc, plugin_pb2, arrow from cloudquery.sdk.internal.memdb import MemDB