Refactor get_stored_state to perform join in memory
Signed-off-by: Marcel Coetzee <marcel@mooncoon.com>
Pipboyguy committed Jun 15, 2024
1 parent a11f0e4 commit aff0032
122 changes: 84 additions & 38 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -73,7 +73,9 @@


TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {
v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()
}
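
For context, the mapping pair above converts between dlt timestamp precision and Arrow timestamp units in both directions. A minimal standalone sketch of the round trip (not part of the commit):

    import pyarrow as pa

    TIMESTAMP_PRECISION_TO_UNIT = {0: "s", 3: "ms", 6: "us", 9: "ns"}
    UNIT_TO_TIMESTAMP_PRECISION = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}

    # Precision 6 (microseconds) maps to the Arrow "us" unit...
    arrow_type = pa.timestamp(TIMESTAMP_PRECISION_TO_UNIT[6])  # timestamp[us]
    # ...and the inverted dict recovers the precision from an Arrow type's unit.
    assert UNIT_TO_TIMESTAMP_PRECISION[arrow_type.unit] == 6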


class LanceDBTypeMapper(TypeMapper):
@@ -179,7 +181,9 @@ def upload_batch(
tbl.add(records, mode="overwrite")
elif write_disposition == "merge":
if not id_field_name:
raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
raise ValueError(
"To perform a merge update, 'id_field_name' must be specified."
)
tbl.merge_insert(
id_field_name
).when_matched_update_all().when_not_matched_insert_all().execute(records)
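
For reference, the merge branch above uses LanceDB's merge_insert builder to upsert by a key column. A minimal sketch with illustrative data (the database path, table name, and records are assumptions, not from the commit):

    import lancedb

    db = lancedb.connect("/tmp/upsert_demo")  # hypothetical local database path
    tbl = db.create_table("items", data=[{"id": "1", "value": "old"}], mode="overwrite")

    records = [{"id": "1", "value": "updated"}, {"id": "2", "value": "inserted"}]
    # Rows whose "id" matches an existing row are updated; the rest are inserted.
    tbl.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute(records)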
@@ -251,7 +255,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema:
)

@lancedb_error
def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table:
def create_table(
self, table_name: str, schema: TArrowSchema, mode: str = "create"
) -> Table:
"""Create a LanceDB Table from the provided LanceModel or PyArrow schema.
Args:
@@ -354,7 +360,9 @@ def update_stored_schema(
applied_update: TSchemaTables = {}

try:
schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)
schema_info = self.get_stored_schema_by_hash(
self.schema.stored_version_hash
)
except DestinationUndefinedEntity:
schema_info = None

@@ -403,35 +411,47 @@ def add_table_fields(

# Check if any of the new fields already exist in the table.
existing_fields = set(arrow_table.schema.names)
new_fields = [field for field in field_schemas if field.name not in existing_fields]
new_fields = [
field for field in field_schemas if field.name not in existing_fields
]

if not new_fields:
# All fields already present, skip.
return None

null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields]
null_arrays = [
pa.nulls(len(arrow_table), type=field.type) for field in new_fields
]

for field, null_array in zip(new_fields, null_arrays):
arrow_table = arrow_table.append_column(field, null_array)

try:
return self.db_client.create_table(table_name, arrow_table, mode="overwrite")
return self.db_client.create_table(
table_name, arrow_table, mode="overwrite"
)
except OSError:
# Error occurred while creating the table, skip.
return None
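
The add_table_fields change above widens an existing Arrow table by appending each missing column as an all-null array before the table is rewritten. A self-contained sketch of that PyArrow pattern (the table contents are illustrative):

    import pyarrow as pa

    arrow_table = pa.table({"id": [1, 2, 3]})
    new_field = pa.field("note", pa.string())

    # One null slot per existing row, typed to match the new field.
    null_array = pa.nulls(len(arrow_table), type=new_field.type)
    arrow_table = arrow_table.append_column(new_field, null_array)
    # arrow_table now has columns ("id", "note"), with "note" entirely null.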

def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
for table_name in only_tables or self.schema.tables:
exists, existing_columns = self.get_storage_table(table_name)
new_columns = self.schema.get_new_table_columns(table_name, existing_columns)
new_columns = self.schema.get_new_table_columns(
table_name, existing_columns
)
embedding_fields: List[str] = get_columns_names_with_prop(
self.schema.get_table(table_name), VECTORIZE_HINT
)
logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}")
logger.info(
f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}"
)
if len(new_columns) > 0:
if exists:
field_schemas: List[TArrowField] = [
make_arrow_field_schema(column["name"], column, self.type_mapper)
make_arrow_field_schema(
column["name"], column, self.type_mapper
)
for column in new_columns
]
fq_table_name = self.make_qualified_table_name(table_name)
@@ -444,7 +464,9 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
vector_field_name = self.vector_field_name
id_field_name = self.id_field_name
embedding_model_func = self.model_func
embedding_model_dimensions = self.config.embedding_model_dimensions
embedding_model_dimensions = (
self.config.embedding_model_dimensions
)
else:
embedding_fields = None
vector_field_name = None
@@ -479,7 +501,9 @@ def update_schema_in_storage(self) -> None:
"schema": json.dumps(self.schema.to_dict()),
}
]
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)
write_disposition = self.schema.get_table(self.schema.version_table_name).get(
"write_disposition"
)
@@ -492,35 +516,49 @@

@lancedb_error
def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
"""Loads compressed state from destination storage by finding a load ID that was completed."""
fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name)
fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
"""Retrieves the latest completed state for a pipeline."""
fq_state_table_name = self.make_qualified_table_name(
self.schema.state_table_name
)
fq_loads_table_name = self.make_qualified_table_name(
self.schema.loads_table_name
)

state_records = (
# Read both tables into memory as Arrow tables, applying pushdown predicates so we
# pull as little data into memory as possible.
state_table = (
self.db_client.open_table(fq_state_table_name)
.search()
.where(f'pipeline_name = "{pipeline_name}" ORDER BY _dlt_load_id DESC')
.to_list()
.where(f"pipeline_name = '{pipeline_name}'", prefilter=True)
.to_arrow()
)
loads_table = (
self.db_client.open_table(fq_loads_table_name)
.search()
.where("status = 0", prefilter=True)
.to_arrow()
)
if len(state_records) == 0:

# Join arrow tables in-memory.
joined_table: pa.Table = state_table.join(
loads_table, keys="_dlt_load_id", right_keys="load_id", join_type="inner"
).sort_by([("_dlt_id", "descending")])

if joined_table.num_rows == 0:
return None
for state in state_records:
load_id = state["_dlt_load_id"]
# If there is a load for this state which was successful, return the state.
if (
self.db_client.open_table(fq_loads_table_name)
.search()
.where(f'load_id = "{load_id}" AND status = 0')
.limit(1)
.to_list()
):
state["dlt_load_id"] = state.pop("_dlt_load_id")
return StateInfo(**{k: v for k, v in state.items() if k in StateInfo._fields})
return None

state = joined_table.take([0]).to_pylist()[0]
state["dlt_load_id"] = state.pop("_dlt_load_id")
state["created_at"] = pendulum.instance(state["created_at"])
return StateInfo(**{k: v for k, v in state.items() if k in StateInfo._fields})
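
The refactor replaces the old per-load-id lookup loop with a single in-memory Arrow join: both tables are filtered at the source (prefilter=True applies the predicate before any search work), then state rows are inner-joined against completed loads and the newest row is taken. A minimal standalone sketch of the join-and-pick step, with made-up rows:

    import pyarrow as pa

    state_table = pa.table({
        "_dlt_load_id": ["100", "101"],
        "_dlt_id": ["a", "b"],
        "pipeline_name": ["p", "p"],
    })
    loads_table = pa.table({"load_id": ["101"], "status": [0]})  # completed loads only

    # The inner join keeps only state rows whose load completed; sort newest-first
    # (the commit orders on _dlt_id, descending) and take row 0.
    joined = state_table.join(
        loads_table, keys="_dlt_load_id", right_keys="load_id", join_type="inner"
    ).sort_by([("_dlt_id", "descending")])

    latest = joined.take([0]).to_pylist()[0] if joined.num_rows else None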

@lancedb_error
def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
def get_stored_schema_by_hash(
self, schema_hash: str
) -> Optional[StorageSchemaInfo]:
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)

try:
response = (
@@ -537,7 +575,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
@lancedb_error
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)

try:
response = (
@@ -573,7 +613,9 @@ def complete_load(self, load_id: str) -> None:
"schema_version_hash": None, # Payload schema must match the target schema.
}
]
fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
fq_loads_table_name = self.make_qualified_table_name(
self.schema.loads_table_name
)
write_disposition = self.schema.get_table(self.schema.loads_table_name).get(
"write_disposition"
)
@@ -587,7 +629,9 @@ def complete_load(self, load_id: str) -> None:
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
def start_file_load(
self, table: TTableSchema, file_path: str, load_id: str
) -> LoadJob:
return LoadLanceDBJob(
self.schema,
table,
@@ -626,7 +670,9 @@ def __init__(
self.table_name: str = table_schema["name"]
self.fq_table_name: str = fq_table_name
self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema)
self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
self.embedding_fields: List[str] = get_columns_names_with_prop(
table_schema, VECTORIZE_HINT
)
self.embedding_model_func: TextEmbeddingFunction = model_func
self.embedding_model_dimensions: int = client_config.embedding_model_dimensions
self.id_field_name: str = client_config.id_field_name
