From 6e4c10f154b535e698efe79a0038274475c9df5e Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Thu, 27 Jul 2023 12:52:41 +0100 Subject: [PATCH 1/3] feat: Lint with black --- .github/workflows/lint.yml | 22 ++++++++++++++++++++++ Makefile | 6 +++++- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..c9fc421 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,22 @@ +name: Lint with Black + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Install Black + run: pip install black + + - name: Run Black + run: black --check . diff --git a/Makefile b/Makefile index f464aab..cf4b3af 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,6 @@ test: - pytest . \ No newline at end of file + pytest . + +fmt: + pip install -q black + black . From a042842f7181be851ae7f74bd9c248962f98e283 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Thu, 27 Jul 2023 12:54:47 +0100 Subject: [PATCH 2/3] better 'on' --- .github/workflows/lint.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c9fc421..6d2bb3d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,10 @@ name: Lint with Black -on: [push, pull_request] +on: + pull_request: + push: + branches: + - main jobs: build: From f18a7092a3a1bf42b895b52467297026b398f145 Mon Sep 17 00:00:00 2001 From: Kemal Hadimli Date: Thu, 27 Jul 2023 12:54:57 +0100 Subject: [PATCH 3/3] fmt --- cloudquery/sdk/docs/generator.py | 36 +-- cloudquery/sdk/internal/memdb/__init__.py | 1 - cloudquery/sdk/internal/memdb/memdb.py | 28 ++- .../internal/servers/plugin_v3/__init__.py | 3 +- .../sdk/internal/servers/plugin_v3/plugin.py | 18 +- cloudquery/sdk/plugin/__init__.py | 2 +- cloudquery/sdk/scalar/binary.py | 43 ++-- cloudquery/sdk/scalar/bool.py | 49 ++-- cloudquery/sdk/scalar/date32.py | 50 ++-- cloudquery/sdk/scalar/float64.py | 45 ++-- cloudquery/sdk/scalar/int64.py | 45 ++-- cloudquery/sdk/scalar/scalar.py | 11 +- cloudquery/sdk/scalar/scalar_factory.py | 18 +- cloudquery/sdk/scalar/uuid.py | 43 ++-- cloudquery/sdk/scheduler/scheduler.py | 229 +++++++++++------- cloudquery/sdk/scheduler/table_resolver.py | 36 +-- cloudquery/sdk/schema/__init__.py | 1 + cloudquery/sdk/schema/arrow.py | 13 +- cloudquery/sdk/schema/column.py | 42 +++- cloudquery/sdk/schema/resource.py | 30 +-- cloudquery/sdk/schema/table.py | 18 +- cloudquery/sdk/serve/plugin.py | 66 +++-- cloudquery/sdk/transformers/__init__.py | 1 - cloudquery/sdk/types/uuid.py | 13 +- tests/docs/test_generator.py | 79 +++--- tests/internal/memdb/memdb.py | 4 +- tests/scalar/binary.py | 15 +- tests/scalar/bool.py | 15 +- tests/scalar/float64.py | 12 +- tests/scalar/int64.py | 12 +- tests/scalar/uuid.py | 13 +- tests/scheduler/scheduler.py | 27 ++- tests/serve/plugin.py | 29 ++- 33 files changed, 615 insertions(+), 432 deletions(-) diff --git a/cloudquery/sdk/docs/generator.py b/cloudquery/sdk/docs/generator.py index 5509753..7bc569e 100644 --- a/cloudquery/sdk/docs/generator.py +++ b/cloudquery/sdk/docs/generator.py @@ -24,7 +24,7 @@ def to_dict(self): "title": self.title, "description": self.description, "columns": [col.to_dict() for col in self.columns], - "relations": [rel.to_dict() for rel in self.relations] + "relations": [rel.to_dict() for rel in self.relations], } @@ -40,7 +40,7 @@ def to_dict(self): "name": self.name, "type": self.type, "is_primary_key": self.is_primary_key, - "is_incremental_key": self.is_incremental_key + "is_incremental_key": self.is_incremental_key, } @@ -58,9 +58,9 @@ def generate(self, directory: str, format: str): def _generate_json(self, directory: str): json_tables = self._jsonify_tables(self._tables) - buffer = bytes(json.dumps(json_tables, indent=2, ensure_ascii=False), 'utf-8') + buffer = bytes(json.dumps(json_tables, indent=2, ensure_ascii=False), "utf-8") output_path = pathlib.Path(directory) / "__tables.json" - with output_path.open('wb') as f: + with output_path.open("wb") as f: f.write(buffer) return None @@ -88,10 +88,12 @@ def _jsonify_tables(self, tables): def _generate_markdown(self, directory: str): env = jinja2.Environment() - env.globals['indent_to_depth'] = self._indent_to_depth - env.globals['all_tables_entry'] = self._all_tables_entry + env.globals["indent_to_depth"] = self._indent_to_depth + env.globals["all_tables_entry"] = self._all_tables_entry all_tables_template = env.from_string(ALL_TABLES) - rendered_all_tables = all_tables_template.render(plugin_name=self._plugin_name, tables=self._tables) + rendered_all_tables = all_tables_template.render( + plugin_name=self._plugin_name, tables=self._tables + ) formatted_all_tables = self._format_markdown(rendered_all_tables) with open(os.path.join(directory, "README.md"), "w") as f: @@ -111,9 +113,9 @@ def _render_table(self, directory: str, env: jinja2.Environment, table: Table): def _all_tables_entry(self, table: Table): env = jinja2.Environment() - env.globals['indent_to_depth'] = self._indent_to_depth - env.globals['all_tables_entry'] = self._all_tables_entry - env.globals['indent_table_to_depth'] = self._indent_table_to_depth + env.globals["indent_to_depth"] = self._indent_to_depth + env.globals["all_tables_entry"] = self._all_tables_entry + env.globals["indent_table_to_depth"] = self._indent_table_to_depth entry_template = env.from_string(ALL_TABLES_ENTRY) return entry_template.render(table=table) @@ -129,15 +131,15 @@ def _indent_table_to_depth(table: Table) -> str: @staticmethod def _indent_to_depth(text: str, depth: int) -> str: indentation = depth * 4 # You can adjust the number of spaces as needed - lines = text.split('\n') - indented_lines = [(' ' * indentation) + line for line in lines] - return '\n'.join(indented_lines) + lines = text.split("\n") + indented_lines = [(" " * indentation) + line for line in lines] + return "\n".join(indented_lines) @staticmethod def _format_markdown(text: str) -> str: - re_match_newlines = re.compile(r'\n{3,}') - re_match_headers = re.compile(r'(#{1,6}.+)\n+') + re_match_newlines = re.compile(r"\n{3,}") + re_match_headers = re.compile(r"(#{1,6}.+)\n+") - text = re_match_newlines.sub(r'\n\n', text) - text = re_match_headers.sub(r'\1\n\n', text) + text = re_match_newlines.sub(r"\n\n", text) + text = re_match_headers.sub(r"\1\n\n", text) return text diff --git a/cloudquery/sdk/internal/memdb/__init__.py b/cloudquery/sdk/internal/memdb/__init__.py index eb11793..5aafe4c 100644 --- a/cloudquery/sdk/internal/memdb/__init__.py +++ b/cloudquery/sdk/internal/memdb/__init__.py @@ -1,2 +1 @@ - from .memdb import MemDB diff --git a/cloudquery/sdk/internal/memdb/memdb.py b/cloudquery/sdk/internal/memdb/memdb.py index f7ce987..0147247 100644 --- a/cloudquery/sdk/internal/memdb/memdb.py +++ b/cloudquery/sdk/internal/memdb/memdb.py @@ -1,4 +1,3 @@ - from cloudquery.sdk import plugin from cloudquery.sdk import message from cloudquery.sdk import schema @@ -8,17 +7,20 @@ NAME = "memdb" VERSION = "development" + class MemDB(plugin.Plugin): def __init__(self) -> None: - super().__init__(NAME, VERSION) - self._tables: List[schema.Table] = [] - self._memory_db : Dict[str, pa.record] = { - "test_table": pa.record_batch([pa.array([1, 2, 3])], names=["test_column"]) - } - - def get_tables(self, options : plugin.TableOptions = None) -> List[plugin.Table]: - return self._tables - - def sync(self, options: plugin.SyncOptions) -> Generator[message.SyncMessage, None, None]: - for table, record in self._memory_db.items(): - yield message.SyncInsertMessage(record) + super().__init__(NAME, VERSION) + self._tables: List[schema.Table] = [] + self._memory_db: Dict[str, pa.record] = { + "test_table": pa.record_batch([pa.array([1, 2, 3])], names=["test_column"]) + } + + def get_tables(self, options: plugin.TableOptions = None) -> List[plugin.Table]: + return self._tables + + def sync( + self, options: plugin.SyncOptions + ) -> Generator[message.SyncMessage, None, None]: + for table, record in self._memory_db.items(): + yield message.SyncInsertMessage(record) diff --git a/cloudquery/sdk/internal/servers/plugin_v3/__init__.py b/cloudquery/sdk/internal/servers/plugin_v3/__init__.py index 9955984..9ad49dd 100644 --- a/cloudquery/sdk/internal/servers/plugin_v3/__init__.py +++ b/cloudquery/sdk/internal/servers/plugin_v3/__init__.py @@ -1,2 +1 @@ - -from .plugin import PluginServicer \ No newline at end of file +from .plugin import PluginServicer diff --git a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py index 706995c..62ee483 100644 --- a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py +++ b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py @@ -23,7 +23,9 @@ def Init(self, request: plugin_pb2.Init.Request, context): return plugin_pb2.Init.Response() def GetTables(self, request: plugin_pb2.GetTables.Request, context): - tables = self._plugin.get_tables(TableOptions(tables=request.tables, skip_tables=request.skip_tables)) + tables = self._plugin.get_tables( + TableOptions(tables=request.tables, skip_tables=request.skip_tables) + ) schema = tables_to_arrow_schemas(tables) tablesBytes = [] for s in schema: @@ -51,13 +53,15 @@ def Sync(self, request, context): writer.write_batch(msg.record) writer.close() buf = sink.getvalue().to_pybytes() - yield plugin_pb2.Sync.Response(insert=plugin_pb2.Sync.MessageInsert( - record=buf - )) + yield plugin_pb2.Sync.Response( + insert=plugin_pb2.Sync.MessageInsert(record=buf) + ) elif isinstance(msg, SyncMigrateTableMessage): - yield plugin_pb2.Sync.Response(migrate_table=plugin_pb2.Sync.MessageMigrateTable( - table=msg.table.to_arrow_schema().serialize().to_pybytes() - )) + yield plugin_pb2.Sync.Response( + migrate_table=plugin_pb2.Sync.MessageMigrateTable( + table=msg.table.to_arrow_schema().serialize().to_pybytes() + ) + ) else: # unknown sync message type raise NotImplementedError() diff --git a/cloudquery/sdk/plugin/__init__.py b/cloudquery/sdk/plugin/__init__.py index cd947b5..194433f 100644 --- a/cloudquery/sdk/plugin/__init__.py +++ b/cloudquery/sdk/plugin/__init__.py @@ -1 +1 @@ -from .plugin import Plugin, Table, TableOptions, SyncOptions \ No newline at end of file +from .plugin import Plugin, Table, TableOptions, SyncOptions diff --git a/cloudquery/sdk/scalar/binary.py b/cloudquery/sdk/scalar/binary.py index 03d99de..0311a32 100644 --- a/cloudquery/sdk/scalar/binary.py +++ b/cloudquery/sdk/scalar/binary.py @@ -1,34 +1,35 @@ from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError from .scalar import NULL_VALUE + class Binary(Scalar): def __init__(self, valid: bool = False, value: bytes = None): - self._valid = valid - self._value = value + self._valid = valid + self._value = value def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == Binary: - return self._value == scalar._value and self._valid == scalar._valid - return False - + if scalar is None: + return False + if type(scalar) == Binary: + return self._value == scalar._value and self._valid == scalar._valid + return False + def __str__(self) -> str: - return str(self._value) if self._valid else NULL_VALUE + return str(self._value) if self._valid else NULL_VALUE @property def value(self): - return self._value - + return self._value + def set(self, scalar): - if scalar is None: - return + if scalar is None: + return - if type(scalar) == bytes: - self._valid = True - self._value = scalar - elif type(scalar) == str: - self._valid = True - self._value = scalar.encode() - else: - raise ScalarInvalidTypeError("Invalid type for Binary scalar") + if type(scalar) == bytes: + self._valid = True + self._value = scalar + elif type(scalar) == str: + self._valid = True + self._value = scalar.encode() + else: + raise ScalarInvalidTypeError("Invalid type for Binary scalar") diff --git a/cloudquery/sdk/scalar/bool.py b/cloudquery/sdk/scalar/bool.py index e50983a..42c7b21 100644 --- a/cloudquery/sdk/scalar/bool.py +++ b/cloudquery/sdk/scalar/bool.py @@ -1,10 +1,10 @@ - from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError, NULL_VALUE from typing import Any + def parse_string_to_bool(input_string): - true_strings = ['true', 't', 'yes', 'y', '1'] - false_strings = ['false', 'f', 'no', 'n', '0'] + true_strings = ["true", "t", "yes", "y", "1"] + false_strings = ["false", "f", "no", "n", "0"] lower_input = input_string.lower() @@ -15,34 +15,35 @@ def parse_string_to_bool(input_string): else: raise ScalarInvalidTypeError("Invalid boolean string: {}".format(input_string)) + class Bool(Scalar): def __init__(self, valid: bool = False, value: bool = False) -> None: - self._valid = valid - self._value = value - + self._valid = valid + self._value = value + def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == Bool: - return self._value == scalar._value and self._valid == scalar._valid - return False + if scalar is None: + return False + if type(scalar) == Bool: + return self._value == scalar._value and self._valid == scalar._valid + return False def __str__(self) -> str: - return str(self._value) if self._valid else NULL_VALUE + return str(self._value) if self._valid else NULL_VALUE @property def value(self): - return self._value - + return self._value + def set(self, value: Any): - if value is None: - return + if value is None: + return - if type(value) == bool: - self._value = value - elif type(value) == str: - self._value = parse_string_to_bool(value) - else: - raise ScalarInvalidTypeError("Invalid type for Bool scalar") - - self._valid = True + if type(value) == bool: + self._value = value + elif type(value) == str: + self._value = parse_string_to_bool(value) + else: + raise ScalarInvalidTypeError("Invalid type for Bool scalar") + + self._valid = True diff --git a/cloudquery/sdk/scalar/date32.py b/cloudquery/sdk/scalar/date32.py index c6be5fe..cbc8bca 100644 --- a/cloudquery/sdk/scalar/date32.py +++ b/cloudquery/sdk/scalar/date32.py @@ -1,38 +1,38 @@ - from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError, NULL_VALUE from datetime import datetime, time from typing import Any + class Date32(Scalar): def __init__(self, valid: bool = False, value: bool = False) -> None: - self._valid = valid - self._value = value - + self._valid = valid + self._value = value + def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == Date32: - return self._value == scalar._value and self._valid == scalar._valid - return False + if scalar is None: + return False + if type(scalar) == Date32: + return self._value == scalar._value and self._valid == scalar._valid + return False def __str__(self) -> str: - return str(self._value) if self._valid else NULL_VALUE - + return str(self._value) if self._valid else NULL_VALUE + @property def value(self): - return self._value - + return self._value + def set(self, value: Any): - if value is None: - return + if value is None: + return - if type(value) == datetime: - self._value = value - elif type(value) == str: - self._value = datetime.strptime(value, "%Y-%m-%d") - elif type(value) == time: - self._value = datetime.combine(datetime.today(), value) - else: - raise ScalarInvalidTypeError("Invalid type for Bool scalar") - - self._valid = True + if type(value) == datetime: + self._value = value + elif type(value) == str: + self._value = datetime.strptime(value, "%Y-%m-%d") + elif type(value) == time: + self._value = datetime.combine(datetime.today(), value) + else: + raise ScalarInvalidTypeError("Invalid type for Bool scalar") + + self._valid = True diff --git a/cloudquery/sdk/scalar/float64.py b/cloudquery/sdk/scalar/float64.py index fa7f7fc..2b9b86e 100644 --- a/cloudquery/sdk/scalar/float64.py +++ b/cloudquery/sdk/scalar/float64.py @@ -1,34 +1,35 @@ from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError + class Float64(Scalar): def __init__(self, valid: bool = False, value: float = None): - self._valid = valid - self._value = value + self._valid = valid + self._value = value def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == Float64: - return self._value == scalar._value and self._valid == scalar._valid - return False + if scalar is None: + return False + if type(scalar) == Float64: + return self._value == scalar._value and self._valid == scalar._valid + return False @property def value(self): - return self._value - + return self._value + def set(self, value): - if value is None: - return + if value is None: + return - if type(value) == int: - self._value = float(value) - elif type(value) == float: - self._value = value - elif type(value) == str: - try: + if type(value) == int: self._value = float(value) - except ValueError: - raise ScalarInvalidTypeError("Invalid type for Float64 scalar") - else: - raise ScalarInvalidTypeError("Invalid type for Binary scalar") - self._valid = True + elif type(value) == float: + self._value = value + elif type(value) == str: + try: + self._value = float(value) + except ValueError: + raise ScalarInvalidTypeError("Invalid type for Float64 scalar") + else: + raise ScalarInvalidTypeError("Invalid type for Binary scalar") + self._valid = True diff --git a/cloudquery/sdk/scalar/int64.py b/cloudquery/sdk/scalar/int64.py index faec568..9a7d77c 100644 --- a/cloudquery/sdk/scalar/int64.py +++ b/cloudquery/sdk/scalar/int64.py @@ -1,34 +1,37 @@ from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError + class Int64(Scalar): def __init__(self, valid: bool = False, value: float = None): - self._valid = valid - self._value = value + self._valid = valid + self._value = value def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == Int64: - return self._value == scalar._value and self._valid == scalar._valid - return False + if scalar is None: + return False + if type(scalar) == Int64: + return self._value == scalar._value and self._valid == scalar._valid + return False @property def value(self): - return self._value + return self._value def set(self, value): - if value is None: - return + if value is None: + return - if type(value) == int: - self._value = value - elif type(value) == float: - self._value = int(value) - elif type(value) == str: - try: + if type(value) == int: + self._value = value + elif type(value) == float: self._value = int(value) - except ValueError as e: - raise ScalarInvalidTypeError("Invalid type for Int64 scalar") from e - else: - raise ScalarInvalidTypeError("Invalid type {} for Int64 scalar".format(type(value))) - self._valid = True + elif type(value) == str: + try: + self._value = int(value) + except ValueError as e: + raise ScalarInvalidTypeError("Invalid type for Int64 scalar") from e + else: + raise ScalarInvalidTypeError( + "Invalid type {} for Int64 scalar".format(type(value)) + ) + self._valid = True diff --git a/cloudquery/sdk/scalar/scalar.py b/cloudquery/sdk/scalar/scalar.py index a8891fc..379a912 100644 --- a/cloudquery/sdk/scalar/scalar.py +++ b/cloudquery/sdk/scalar/scalar.py @@ -1,14 +1,15 @@ - NULL_VALUE = "null" + class ScalarInvalidTypeError(Exception): - pass + pass + class Scalar: @property def is_valid(self) -> bool: - return self._valid - + return self._valid + @property def value(self): - raise NotImplementedError("Scalar value not implemented") + raise NotImplementedError("Scalar value not implemented") diff --git a/cloudquery/sdk/scalar/scalar_factory.py b/cloudquery/sdk/scalar/scalar_factory.py index d7e7303..27f7f6a 100644 --- a/cloudquery/sdk/scalar/scalar_factory.py +++ b/cloudquery/sdk/scalar/scalar_factory.py @@ -1,15 +1,15 @@ - import pyarrow as pa from .scalar import ScalarInvalidTypeError from .int64 import Int64 + class ScalarFactory: - def __init__(self): - pass + def __init__(self): + pass - def new_scalar(self, dt): - dt_id = dt.id - if dt_id == pa.types.lib.Type_INT64: - return Int64() - else: - raise ScalarInvalidTypeError("Invalid type {} for scalar".format(dt)) + def new_scalar(self, dt): + dt_id = dt.id + if dt_id == pa.types.lib.Type_INT64: + return Int64() + else: + raise ScalarInvalidTypeError("Invalid type {} for scalar".format(dt)) diff --git a/cloudquery/sdk/scalar/uuid.py b/cloudquery/sdk/scalar/uuid.py index ef309a8..c97e4da 100644 --- a/cloudquery/sdk/scalar/uuid.py +++ b/cloudquery/sdk/scalar/uuid.py @@ -1,33 +1,36 @@ import uuid from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError + class UUID(Scalar): def __init__(self, valid: bool = False, value: uuid.UUID = None): - self._valid = valid - self._value = value + self._valid = valid + self._value = value def __eq__(self, scalar: Scalar) -> bool: - if scalar is None: - return False - if type(scalar) == UUID: - return self._value == scalar._value and self._valid == scalar._valid - return False + if scalar is None: + return False + if type(scalar) == UUID: + return self._value == scalar._value and self._valid == scalar._valid + return False @property def value(self): - return self._value + return self._value def set(self, value): - if value is None: - return + if value is None: + return - if type(value) == uuid.UUID: - self._value = value - elif type(value) == str: - try: - self._value = uuid.UUID(value) - except ValueError as e: - raise ScalarInvalidTypeError("Invalid type for UUID scalar") from e - else: - raise ScalarInvalidTypeError("Invalid type {} for UUID scalar".format(type(value))) - self._valid = True + if type(value) == uuid.UUID: + self._value = value + elif type(value) == str: + try: + self._value = uuid.UUID(value) + except ValueError as e: + raise ScalarInvalidTypeError("Invalid type for UUID scalar") from e + else: + raise ScalarInvalidTypeError( + "Invalid type {} for UUID scalar".format(type(value)) + ) + self._valid = True diff --git a/cloudquery/sdk/scheduler/scheduler.py b/cloudquery/sdk/scheduler/scheduler.py index 6f4ea99..a2cb270 100644 --- a/cloudquery/sdk/scheduler/scheduler.py +++ b/cloudquery/sdk/scheduler/scheduler.py @@ -1,11 +1,14 @@ - from typing import List, Generator, Any import queue import time import structlog from enum import Enum from cloudquery.sdk.schema import Table, Resource -from cloudquery.sdk.message import SyncMessage, SyncInsertMessage, SyncMigrateTableMessage +from cloudquery.sdk.message import ( + SyncMessage, + SyncInsertMessage, + SyncMigrateTableMessage, +) from concurrent import futures from typing import Generator from .table_resolver import TableResolver @@ -21,105 +24,161 @@ def __init__(self, maxsize, *args, **kwargs): class TableResolverStarted: - def __init__(self, count=1) -> None: - self._count = count - - @property - def count(self): - return self._count + def __init__(self, count=1) -> None: + self._count = count + + @property + def count(self): + return self._count class TableResolverFinished: - def __init__(self) -> None: - pass + def __init__(self) -> None: + pass class Scheduler: - def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3, logger=None): + def __init__( + self, concurrency: int, queue_size: int = 0, max_depth: int = 3, logger=None + ): self._queue = queue.Queue() self._max_depth = max_depth if logger is None: - self._logger = structlog.get_logger() + self._logger = structlog.get_logger() if concurrency <= 0: - raise ValueError("concurrency must be greater than 0") + raise ValueError("concurrency must be greater than 0") if max_depth <= 0: raise ValueError("max_depth must be greater than 0") - self._queue_size = queue_size if queue_size > 0 else concurrency * QUEUE_PER_WORKER - self._pools : List[ThreadPoolExecutorWithQueueSizeLimit] = [] + self._queue_size = ( + queue_size if queue_size > 0 else concurrency * QUEUE_PER_WORKER + ) + self._pools: List[ThreadPoolExecutorWithQueueSizeLimit] = [] current_depth_concurrency = concurrency current_depth_queue_size = queue_size for _ in range(max_depth + 1): - self._pools.append(ThreadPoolExecutorWithQueueSizeLimit(maxsize=current_depth_queue_size,max_workers=current_depth_concurrency)) - current_depth_concurrency = current_depth_concurrency // 2 if current_depth_concurrency > 1 else 1 - current_depth_queue_size = current_depth_queue_size // 2 if current_depth_queue_size > 1 else 1 - + self._pools.append( + ThreadPoolExecutorWithQueueSizeLimit( + maxsize=current_depth_queue_size, + max_workers=current_depth_concurrency, + ) + ) + current_depth_concurrency = ( + current_depth_concurrency // 2 if current_depth_concurrency > 1 else 1 + ) + current_depth_queue_size = ( + current_depth_queue_size // 2 if current_depth_queue_size > 1 else 1 + ) + def shutdown(self): - for pool in self._pools: - pool.shutdown() + for pool in self._pools: + pool.shutdown() + + def resolve_resource( + self, resolver: TableResolver, client, parent: Resource, item: Any + ) -> Resource: + resource = Resource(resolver.table, parent, item) + resolver.pre_resource_resolve(client, resource) + for column in resolver.table.columns: + resolver.resolve_column(client, resource, column.name) + resolver.post_resource_resolve(client, resource) + return resource + + def resolve_table( + self, + resolver: TableResolver, + depth: int, + client, + parent_item: Resource, + res: queue.Queue, + ): + table_resolvers_started = 0 + try: + if depth == 0: + self._logger.info( + "table resolver started", table=resolver.table.name, depth=depth + ) + else: + self._logger.debug( + "table resolver started", table=resolver.table.name, depth=depth + ) + total_resources = 0 + for item in resolver.resolve(client, parent_item): + resource = self.resolve_resource(resolver, client, parent_item, item) + res.put(SyncInsertMessage(resource.to_arrow_record())) + for child_resolvers in resolver.child_resolvers: + self._pools[depth + 1].submit( + self.resolve_table, + child_resolvers, + depth + 1, + client, + resource, + res, + ) + table_resolvers_started += 1 + total_resources += 1 + if depth == 0: + self._logger.info( + "table resolver finished successfully", + table=resolver.table.name, + depth=depth, + ) + else: + self._logger.debug( + "table resolver finished successfully", + table=resolver.table.name, + depth=depth, + ) + except Exception as e: + self._logger.error( + "table resolver finished with error", + table=resolver.table.name, + depth=depth, + exception=e, + ) + finally: + res.put(TableResolverStarted(count=table_resolvers_started)) + res.put(TableResolverFinished()) - def resolve_resource(self, resolver: TableResolver, client, parent: Resource, item: Any) -> Resource: - resource = Resource(resolver.table, parent, item) - resolver.pre_resource_resolve(client, resource) - for column in resolver.table.columns: - resolver.resolve_column(client, resource, column.name) - resolver.post_resource_resolve(client, resource) - return resource + def _sync( + self, + client, + resolvers: List[TableResolver], + res: queue.Queue, + deterministic_cq_id=False, + ): + total_table_resolvers = 0 + try: + for resolver in resolvers: + clients = resolver.multiplex(client) + for client in clients: + self._pools[0].submit( + self.resolve_table, resolver, 0, client, None, res + ) + total_table_resolvers += 1 + finally: + res.put(TableResolverStarted(total_table_resolvers)) - def resolve_table(self, resolver: TableResolver, depth: int, client, parent_item: Resource, res: queue.Queue): - table_resolvers_started = 0 - try: - if depth == 0: - self._logger.info("table resolver started", table=resolver.table.name, depth=depth) - else: - self._logger.debug("table resolver started", table=resolver.table.name, depth=depth) - total_resources = 0 - for item in resolver.resolve(client, parent_item): - resource = self.resolve_resource(resolver, client, parent_item, item) - res.put(SyncInsertMessage(resource.to_arrow_record())) - for child_resolvers in resolver.child_resolvers: - self._pools[depth + 1].submit(self.resolve_table, child_resolvers, depth + 1, client, resource, res) - table_resolvers_started += 1 - total_resources += 1 - if depth == 0: - self._logger.info("table resolver finished successfully", table=resolver.table.name, depth=depth) - else: - self._logger.debug("table resolver finished successfully", table=resolver.table.name, depth=depth) - except Exception as e: - self._logger.error("table resolver finished with error", table=resolver.table.name, depth=depth, exception=e) - finally: - res.put(TableResolverStarted(count=table_resolvers_started)) - res.put(TableResolverFinished()) - - def _sync(self, client, resolvers: List[TableResolver], res: queue.Queue, deterministic_cq_id=False): - total_table_resolvers = 0 - try: + def sync( + self, client, resolvers: List[TableResolver], deterministic_cq_id=False + ) -> Generator[SyncMessage, None, None]: + res = queue.Queue() for resolver in resolvers: - clients = resolver.multiplex(client) - for client in clients: - self._pools[0].submit(self.resolve_table, resolver, 0, client, None, res) - total_table_resolvers += 1 - finally: - res.put(TableResolverStarted(total_table_resolvers)) - - def sync(self, client, resolvers: List[TableResolver], deterministic_cq_id=False) -> Generator[SyncMessage, None, None]: - res = queue.Queue() - for resolver in resolvers: - yield SyncMigrateTableMessage(schema=resolver.table.to_arrow_schema()) - thread = futures.ThreadPoolExecutor() - thread.submit(self._sync, client, resolvers, res, deterministic_cq_id) - total_table_resolvers = 0 - finished_table_resolvers = 0 - while True: - message = res.get() - if type(message) == TableResolverStarted: - total_table_resolvers += message.count - if total_table_resolvers == finished_table_resolvers: - break - continue - elif type(message) == TableResolverFinished: - finished_table_resolvers += 1 - if total_table_resolvers == finished_table_resolvers: - break - continue - yield message - thread.shutdown() + yield SyncMigrateTableMessage(schema=resolver.table.to_arrow_schema()) + thread = futures.ThreadPoolExecutor() + thread.submit(self._sync, client, resolvers, res, deterministic_cq_id) + total_table_resolvers = 0 + finished_table_resolvers = 0 + while True: + message = res.get() + if type(message) == TableResolverStarted: + total_table_resolvers += message.count + if total_table_resolvers == finished_table_resolvers: + break + continue + elif type(message) == TableResolverFinished: + finished_table_resolvers += 1 + if total_table_resolvers == finished_table_resolvers: + break + continue + yield message + thread.shutdown() diff --git a/cloudquery/sdk/scheduler/table_resolver.py b/cloudquery/sdk/scheduler/table_resolver.py index f1a50c9..84cdc25 100644 --- a/cloudquery/sdk/scheduler/table_resolver.py +++ b/cloudquery/sdk/scheduler/table_resolver.py @@ -1,37 +1,37 @@ - from cloudquery.sdk.schema.table import Table from cloudquery.sdk.schema import Resource -from typing import Any,Generator +from typing import Any, Generator + class TableResolver: def __init__(self, table: Table, child_resolvers=[]) -> None: self._table = table self._child_resolvers = child_resolvers - + @property def table(self) -> Table: - return self._table + return self._table @property def child_resolvers(self): - return self._child_resolvers + return self._child_resolvers def multiplex(self, client): - return [client] - + return [client] + def resolve(self, client, parent_resource) -> Generator[Any, None, None]: - raise NotImplementedError() + raise NotImplementedError() def pre_resource_resolve(self, client, resource): - return - + return + def resolve_column(self, client, resource: Resource, column_name: str): - if type(resource.item) is dict: - if column_name in resource.item: - resource.set(column_name, resource.item[column_name]) - else: - if hasattr(resource.item, column_name): - resource.set(column_name, resource.item[column_name]) - + if type(resource.item) is dict: + if column_name in resource.item: + resource.set(column_name, resource.item[column_name]) + else: + if hasattr(resource.item, column_name): + resource.set(column_name, resource.item[column_name]) + def post_resource_resolve(self, client, resource): - return + return diff --git a/cloudquery/sdk/schema/__init__.py b/cloudquery/sdk/schema/__init__.py index 7aeb5bb..3e3ecc7 100644 --- a/cloudquery/sdk/schema/__init__.py +++ b/cloudquery/sdk/schema/__init__.py @@ -1,4 +1,5 @@ from .column import Column from .table import Table, tables_to_arrow_schemas from .resource import Resource + # from .table_resolver import TableReso diff --git a/cloudquery/sdk/schema/arrow.py b/cloudquery/sdk/schema/arrow.py index 98d876d..3800326 100644 --- a/cloudquery/sdk/schema/arrow.py +++ b/cloudquery/sdk/schema/arrow.py @@ -1,10 +1,9 @@ - -METADATA_UNIQUE = "cq:extension:unique" -METADATA_PRIMARY_KEY = "cq:extension:primary_key" +METADATA_UNIQUE = "cq:extension:unique" +METADATA_PRIMARY_KEY = "cq:extension:primary_key" METADATA_CONSTRAINT_NAME = "cq:extension:constraint_name" -METADATA_INCREMENTAL = "cq:extension:incremental" +METADATA_INCREMENTAL = "cq:extension:incremental" -METADATA_TRUE = "true" -METADATA_FALSE = "false" -METADATA_TABLE_NAME = "cq:table_name" +METADATA_TRUE = "true" +METADATA_FALSE = "false" +METADATA_TABLE_NAME = "cq:table_name" METADATA_TABLE_DESCRIPTION = "cq:table_description" diff --git a/cloudquery/sdk/schema/column.py b/cloudquery/sdk/schema/column.py index 9df08bb..55362e7 100644 --- a/cloudquery/sdk/schema/column.py +++ b/cloudquery/sdk/schema/column.py @@ -6,22 +6,35 @@ class Column: - def __init__(self, name: str, type: pa.DataType, - description: str = '', primary_key: bool = False, not_null: bool = False, - incremental_key: bool = False, unique: bool = False) -> None: + def __init__( + self, + name: str, + type: pa.DataType, + description: str = "", + primary_key: bool = False, + not_null: bool = False, + incremental_key: bool = False, + unique: bool = False, + ) -> None: self.name = name self.type = type self.description = description self.primary_key = primary_key self.not_null = not_null self.incremental_key = incremental_key - self.unique = unique + self.unique = unique def to_arrow_field(self): metadata = { - arrow.METADATA_PRIMARY_KEY: arrow.METADATA_TRUE if self.primary_key else arrow.METADATA_FALSE, - arrow.METADATA_UNIQUE: arrow.METADATA_TRUE if self.unique else arrow.METADATA_FALSE, - arrow.METADATA_INCREMENTAL: arrow.METADATA_TRUE if self.incremental_key else arrow.METADATA_FALSE, + arrow.METADATA_PRIMARY_KEY: arrow.METADATA_TRUE + if self.primary_key + else arrow.METADATA_FALSE, + arrow.METADATA_UNIQUE: arrow.METADATA_TRUE + if self.unique + else arrow.METADATA_FALSE, + arrow.METADATA_INCREMENTAL: arrow.METADATA_TRUE + if self.incremental_key + else arrow.METADATA_FALSE, } return pa.field(self.name, self.type, metadata=metadata) @@ -30,7 +43,14 @@ def from_arrow_field(field: pa.Field) -> Column: metadata = field.metadata primary_key = metadata.get(arrow.METADATA_PRIMARY_KEY) == arrow.METADATA_TRUE unique = metadata.get(arrow.METADATA_UNIQUE) == arrow.METADATA_TRUE - incremental_key = metadata.get(arrow.METADATA_INCREMENTAL) == arrow.METADATA_TRUE - return Column(field.name, field.type, - primary_key=primary_key, not_null=not field.nullable, unique=unique, - incremental_key=incremental_key) + incremental_key = ( + metadata.get(arrow.METADATA_INCREMENTAL) == arrow.METADATA_TRUE + ) + return Column( + field.name, + field.type, + primary_key=primary_key, + not_null=not field.nullable, + unique=unique, + incremental_key=incremental_key, + ) diff --git a/cloudquery/sdk/schema/resource.py b/cloudquery/sdk/schema/resource.py index 12121a3..7798d0c 100644 --- a/cloudquery/sdk/schema/resource.py +++ b/cloudquery/sdk/schema/resource.py @@ -1,27 +1,29 @@ - from .table import Table from typing import Any import pyarrow as pa from cloudquery.sdk.scalar import ScalarFactory + class Resource: def __init__(self, table: Table, parent, item: Any) -> None: - self._table = table - self._parent = parent - self._item = item - factory = ScalarFactory() - self._data = [factory.new_scalar(i.type) for i in table.columns] - + self._table = table + self._parent = parent + self._item = item + factory = ScalarFactory() + self._data = [factory.new_scalar(i.type) for i in table.columns] + @property def item(self): - return self._item + return self._item def set(self, column_name: str, value: Any): - index_column = self._table.index_column(column_name) - self._data[index_column].set(value) - + index_column = self._table.index_column(column_name) + self._data[index_column].set(value) + def to_list_of_arr(self): - return [[self._data[i].value] for i, _ in enumerate(self._table.columns)] - + return [[self._data[i].value] for i, _ in enumerate(self._table.columns)] + def to_arrow_record(self): - return pa.record_batch(self.to_list_of_arr(), schema=self._table.to_arrow_schema()) + return pa.record_batch( + self.to_list_of_arr(), schema=self._table.to_arrow_schema() + ) diff --git a/cloudquery/sdk/schema/table.py b/cloudquery/sdk/schema/table.py index fd159f1..5cee64e 100644 --- a/cloudquery/sdk/schema/table.py +++ b/cloudquery/sdk/schema/table.py @@ -7,12 +7,22 @@ from cloudquery.sdk.schema import arrow from .column import Column + class Client: pass + class Table: - def __init__(self, name: str, columns: List[Column], title: str = "", description: str = "", - parent: Table = None, relations: List[Table] = None, is_incremental: bool = False) -> None: + def __init__( + self, + name: str, + columns: List[Column], + title: str = "", + description: str = "", + parent: Table = None, + relations: List[Table] = None, + is_incremental: bool = False, + ) -> None: self.name = name self.columns = columns self.title = title @@ -24,10 +34,10 @@ def __init__(self, name: str, columns: List[Column], title: str = "", descriptio self.is_incremental = is_incremental def multiplex(self, client) -> List[Table]: - raise [client] + raise [client] def resolver(self, client: Client, parent=None) -> Generator[Any]: - raise NotImplementedError() + raise NotImplementedError() def index_column(self, column_name: str) -> int: for i, column in enumerate(self.columns): diff --git a/cloudquery/sdk/serve/plugin.py b/cloudquery/sdk/serve/plugin.py index 55e7232..f51503d 100644 --- a/cloudquery/sdk/serve/plugin.py +++ b/cloudquery/sdk/serve/plugin.py @@ -19,26 +19,45 @@ def get_logger(args): log = structlog.get_logger() return log + class PluginCommand: def __init__(self, plugin: Plugin): self._plugin = plugin def run(self, args): parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(dest='command', required=True) + subparsers = parser.add_subparsers(dest="command", required=True) serve_parser = subparsers.add_parser("serve", help="Start plugin server") - serve_parser.add_argument("--log-level", type=str, default="info", - choices=["trace", "debug", "info", "warn", "error"], help="log level") - serve_parser.add_argument("--log-format", type=str, default="text", choices=["text", "json"]) - serve_parser.add_argument("--address", type=str, default="localhost:7777", - help="address to serve on. can be tcp: 'localhost:7777' or unix socket: '/tmp/plugin.rpc.sock'") - serve_parser.add_argument("--network", type=str, default="tcp", choices=["tcp", "unix"], - help="network to serve on. can be tcp or unix") - - doc_parser = subparsers.add_parser("doc", formatter_class=argparse.RawTextHelpFormatter, - help="Generate documentation for tables", - description="""Generate documentation for tables. + serve_parser.add_argument( + "--log-level", + type=str, + default="info", + choices=["trace", "debug", "info", "warn", "error"], + help="log level", + ) + serve_parser.add_argument( + "--log-format", type=str, default="text", choices=["text", "json"] + ) + serve_parser.add_argument( + "--address", + type=str, + default="localhost:7777", + help="address to serve on. can be tcp: 'localhost:7777' or unix socket: '/tmp/plugin.rpc.sock'", + ) + serve_parser.add_argument( + "--network", + type=str, + default="tcp", + choices=["tcp", "unix"], + help="network to serve on. can be tcp or unix", + ) + + doc_parser = subparsers.add_parser( + "doc", + formatter_class=argparse.RawTextHelpFormatter, + help="Generate documentation for tables", + description="""Generate documentation for tables. If format is markdown, a destination directory will be created (if necessary) containing markdown files. Example: @@ -47,10 +66,15 @@ def run(self, args): If format is JSON, a destination directory will be created (if necessary) with a single json file called __tables.json. Example: doc --format json . -""") +""", + ) doc_parser.add_argument("directory", type=str) - doc_parser.add_argument("--format", type=str, default="json", - help="output format. one of: {}".format(",".join(DOC_FORMATS))) + doc_parser.add_argument( + "--format", + type=str, + default="json", + help="output format. one of: {}".format(",".join(DOC_FORMATS)), + ) parsed_args = parser.parse_args(args) if parsed_args.command == "serve": @@ -65,18 +89,22 @@ def _serve(self, args): logger = get_logger(args) self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) discovery_pb2_grpc.add_DiscoveryServicer_to_server( - DiscoveryServicer([3]), self._server) + DiscoveryServicer([3]), self._server + ) plugin_pb2_grpc.add_PluginServicer_to_server( - PluginServicer(self._plugin, logger), self._server) + PluginServicer(self._plugin, logger), self._server + ) self._server.add_insecure_port(args.address) print("Starting server. Listening on " + args.address) self._server.start() self._server.wait_for_termination() - + def stop(self): self._server.stop(5) def _generate_docs(self, args): print("Generating docs in format: " + args.format) - generator = Generator(self._plugin.name(), self._plugin.get_tables(tables=["*"], skip_tables=[])) + generator = Generator( + self._plugin.name(), self._plugin.get_tables(tables=["*"], skip_tables=[]) + ) generator.generate(args.directory, args.format) diff --git a/cloudquery/sdk/transformers/__init__.py b/cloudquery/sdk/transformers/__init__.py index a26fb95..b762412 100644 --- a/cloudquery/sdk/transformers/__init__.py +++ b/cloudquery/sdk/transformers/__init__.py @@ -1,2 +1 @@ - from .transformers import transform_list_of_dict diff --git a/cloudquery/sdk/types/uuid.py b/cloudquery/sdk/types/uuid.py index 0516a38..940a04a 100644 --- a/cloudquery/sdk/types/uuid.py +++ b/cloudquery/sdk/types/uuid.py @@ -2,6 +2,7 @@ import pyarrow import sys + class UuidType(pa.PyExtensionType): def __init__(self): pa.PyExtensionType.__init__(self, pa.binary(16)) @@ -10,12 +11,12 @@ def __reduce__(self): return UuidType, () def __arrow_ext_serialize__(self): - # since we don't have a parameterized type, we don't need extra - # metadata to be deserialized - return b'uuid-serialized' + # since we don't have a parameterized type, we don't need extra + # metadata to be deserialized + return b"uuid-serialized" @classmethod def __arrow_ext_deserialize__(self, storage_type, serialized): - # return an instance of this subclass given the serialized - # metadata. - return UuidType() + # return an instance of this subclass given the serialized + # metadata. + return UuidType() diff --git a/tests/docs/test_generator.py b/tests/docs/test_generator.py index ad5b228..2bcae3f 100644 --- a/tests/docs/test_generator.py +++ b/tests/docs/test_generator.py @@ -25,26 +25,49 @@ def update_snapshot(name, content): class T(unittest.TestCase): def test_docs_generator_markdown(self): tables = [ - Table(name="test_table", title="Test Table", columns=[ - Column("string", pa.string(), primary_key=True), - Column("int32", pa.int32()), - ]), - Table(name="test_table_composite_pk", title="Composite PKs", is_incremental=True, columns=[ - Column("pk1", pa.string(), primary_key=True, incremental_key=True), - Column("pk2", pa.string(), primary_key=True, incremental_key=True), - Column("int32", pa.int32()), - ]), - Table(name="test_table_relations", title="Table Relations", is_incremental=True, columns=[ - Column("pk1", pa.string(), primary_key=True), - ], relations=[ - Table(name="test_table_child", title="Child Table", columns=[ + Table( + name="test_table", + title="Test Table", + columns=[ + Column("string", pa.string(), primary_key=True), + Column("int32", pa.int32()), + ], + ), + Table( + name="test_table_composite_pk", + title="Composite PKs", + is_incremental=True, + columns=[ + Column("pk1", pa.string(), primary_key=True, incremental_key=True), + Column("pk2", pa.string(), primary_key=True, incremental_key=True), + Column("int32", pa.int32()), + ], + ), + Table( + name="test_table_relations", + title="Table Relations", + is_incremental=True, + columns=[ Column("pk1", pa.string(), primary_key=True), - Column("fk1", pa.string()), - ], relations=[ - Table(name="test_table_grandchild", title="Grandchild Table", columns=[ - Column("pk1", pa.string(), primary_key=True)]) - ]) - ]), + ], + relations=[ + Table( + name="test_table_child", + title="Child Table", + columns=[ + Column("pk1", pa.string(), primary_key=True), + Column("fk1", pa.string()), + ], + relations=[ + Table( + name="test_table_grandchild", + title="Grandchild Table", + columns=[Column("pk1", pa.string(), primary_key=True)], + ) + ], + ) + ], + ), ] # set parent relations @@ -57,14 +80,16 @@ def test_docs_generator_markdown(self): files = glob.glob(os.path.join(d, "*.md")) file_names = [os.path.basename(f) for f in files] - assert sorted(file_names) == sorted([ - "README.md", - "test_table_composite_pk.md", - "test_table.md", - "test_table_relations.md", - "test_table_child.md", - "test_table_grandchild.md", - ]) + assert sorted(file_names) == sorted( + [ + "README.md", + "test_table_composite_pk.md", + "test_table.md", + "test_table_relations.md", + "test_table_child.md", + "test_table_grandchild.md", + ] + ) updated_snapshots = False for file_name in file_names: diff --git a/tests/internal/memdb/memdb.py b/tests/internal/memdb/memdb.py index f72a29f..3b0afca 100644 --- a/tests/internal/memdb/memdb.py +++ b/tests/internal/memdb/memdb.py @@ -1,11 +1,11 @@ - from cloudquery.sdk.internal import memdb from cloudquery.sdk.plugin import SyncOptions + def test_memdb(): p = memdb.MemDB() p.init(None) msgs = [] for msg in p.sync(SyncOptions()): msgs.append(msg) - assert len(msgs) == 1 \ No newline at end of file + assert len(msgs) == 1 diff --git a/tests/scalar/binary.py b/tests/scalar/binary.py index 837f889..560877e 100644 --- a/tests/scalar/binary.py +++ b/tests/scalar/binary.py @@ -2,12 +2,15 @@ from cloudquery.sdk.scalar import Binary -@pytest.mark.parametrize("input_value,expected_scalar", [ - (b'123', Binary(True, b'123')), - (b'', Binary(True, b'')), - (None, Binary()), - (bytes([1,2,3]), Binary(True, b'\x01\x02\x03')), -]) +@pytest.mark.parametrize( + "input_value,expected_scalar", + [ + (b"123", Binary(True, b"123")), + (b"", Binary(True, b"")), + (None, Binary()), + (bytes([1, 2, 3]), Binary(True, b"\x01\x02\x03")), + ], +) def test_binary_set(input_value, expected_scalar): b = Binary() b.set(input_value) diff --git a/tests/scalar/bool.py b/tests/scalar/bool.py index 6491bc9..47d2b6a 100644 --- a/tests/scalar/bool.py +++ b/tests/scalar/bool.py @@ -2,12 +2,15 @@ from cloudquery.sdk.scalar import Bool -@pytest.mark.parametrize("input_value,expected_scalar", [ - (True, Bool(True, True)), - (False, Bool(True, False)), - ("true", Bool(True, True)), - ("false", Bool(True, False)), -]) +@pytest.mark.parametrize( + "input_value,expected_scalar", + [ + (True, Bool(True, True)), + (False, Bool(True, False)), + ("true", Bool(True, True)), + ("false", Bool(True, False)), + ], +) def test_binary_set(input_value, expected_scalar): b = Bool() b.set(input_value) diff --git a/tests/scalar/float64.py b/tests/scalar/float64.py index 87e5fe0..7eb2aef 100644 --- a/tests/scalar/float64.py +++ b/tests/scalar/float64.py @@ -1,10 +1,14 @@ import pytest from cloudquery.sdk.scalar import Float64 -@pytest.mark.parametrize("input_value,expected_scalar", [ - (1, Float64(True, float(1))), - ("1", Float64(True, float(1))), -]) + +@pytest.mark.parametrize( + "input_value,expected_scalar", + [ + (1, Float64(True, float(1))), + ("1", Float64(True, float(1))), + ], +) def test_binary_set(input_value, expected_scalar): b = Float64() b.set(input_value) diff --git a/tests/scalar/int64.py b/tests/scalar/int64.py index 9204027..47becf6 100644 --- a/tests/scalar/int64.py +++ b/tests/scalar/int64.py @@ -1,10 +1,14 @@ import pytest from cloudquery.sdk.scalar import Int64 -@pytest.mark.parametrize("input_value,expected_scalar", [ - (1, Int64(True, float(1))), - ("1", Int64(True, float(1))), -]) + +@pytest.mark.parametrize( + "input_value,expected_scalar", + [ + (1, Int64(True, float(1))), + ("1", Int64(True, float(1))), + ], +) def test_binary_set(input_value, expected_scalar): b = Int64() b.set(input_value) diff --git a/tests/scalar/uuid.py b/tests/scalar/uuid.py index c5d8ca1..7dadfaa 100644 --- a/tests/scalar/uuid.py +++ b/tests/scalar/uuid.py @@ -2,9 +2,16 @@ import uuid from cloudquery.sdk.scalar import UUID -@pytest.mark.parametrize("input_value,expected_scalar", [ - ("550e8400-e29b-41d4-a716-446655440000", UUID(True, uuid.UUID("550e8400-e29b-41d4-a716-446655440000"))), -]) + +@pytest.mark.parametrize( + "input_value,expected_scalar", + [ + ( + "550e8400-e29b-41d4-a716-446655440000", + UUID(True, uuid.UUID("550e8400-e29b-41d4-a716-446655440000")), + ), + ], +) def test_binary_set(input_value, expected_scalar): b = UUID() b.set(input_value) diff --git a/tests/scheduler/scheduler.py b/tests/scheduler/scheduler.py index 8d5a70a..a60321c 100644 --- a/tests/scheduler/scheduler.py +++ b/tests/scheduler/scheduler.py @@ -1,4 +1,3 @@ - from typing import Any, List, Generator import pyarrow as pa import pytest @@ -7,35 +6,39 @@ from cloudquery.sdk.message import SyncMessage from cloudquery.sdk.schema.table import Table + class SchedulerTestTable(Table): def __init__(self): - super().__init__("test_table", [ - Column("test_column", pa.int64()) - ]) + super().__init__("test_table", [Column("test_column", pa.int64())]) + class SchedulerTestChildTable(Table): def __init__(self): - super().__init__("test_child_table", [ - Column("test_child_column", pa.int64()) - ]) + super().__init__("test_child_table", [Column("test_child_column", pa.int64())]) + class SchedulerTestTableResolver(TableResolver): def __init__(self) -> None: - super().__init__(SchedulerTestTable(), child_resolvers=[SchedulerTestChildTableResolver()]) - + super().__init__( + SchedulerTestTable(), child_resolvers=[SchedulerTestChildTableResolver()] + ) + def resolve(self, client, parent_resource) -> Generator[Any, None, None]: - yield {"test_column": 1} + yield {"test_column": 1} + class SchedulerTestChildTableResolver(TableResolver): def __init__(self) -> None: super().__init__(SchedulerTestChildTable()) - + def resolve(self, client, parent_resource) -> Generator[Any, None, None]: - yield {"test_child_column": 2} + yield {"test_child_column": 2} + class TestClient: pass + def test_scheduler(): client = TestClient() s = Scheduler(10) diff --git a/tests/serve/plugin.py b/tests/serve/plugin.py index 2dce255..9270620 100644 --- a/tests/serve/plugin.py +++ b/tests/serve/plugin.py @@ -1,4 +1,3 @@ - import random import grpc import time @@ -17,20 +16,20 @@ def test_plugin_serve(): pool.submit(cmd.run, ["serve", "--address", f"[::]:{port}"]) time.sleep(1) try: - with grpc.insecure_channel(f'localhost:{port}') as channel: - stub = plugin_pb2_grpc.PluginStub(channel) - response = stub.GetName(plugin_pb2.GetName.Request()) - assert response.name == "memdb" - - response = stub.GetVersion(plugin_pb2.GetVersion.Request()) - assert response.version == "development" + with grpc.insecure_channel(f"localhost:{port}") as channel: + stub = plugin_pb2_grpc.PluginStub(channel) + response = stub.GetName(plugin_pb2.GetName.Request()) + assert response.name == "memdb" + + response = stub.GetVersion(plugin_pb2.GetVersion.Request()) + assert response.version == "development" - response = stub.Init(plugin_pb2.Init.Request(spec=b"")) - assert response is not None + response = stub.Init(plugin_pb2.Init.Request(spec=b"")) + assert response is not None - response = stub.GetTables(plugin_pb2.GetTables.Request()) - print(response.tables) - assert response.tables is not None + response = stub.GetTables(plugin_pb2.GetTables.Request()) + print(response.tables) + assert response.tables is not None finally: - cmd.stop() - pool.shutdown() + cmd.stop() + pool.shutdown()