Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/datachain/delta.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import hashlib
from collections.abc import Sequence
from copy import copy
from functools import wraps
from typing import TYPE_CHECKING, TypeVar

from attrs import frozen

import datachain
from datachain.dataset import DatasetDependency, DatasetRecord
from datachain.error import DatasetNotFoundError
from datachain.project import Project
from datachain.query.dataset import Step, step_result

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Concatenate

from typing_extensions import ParamSpec

from datachain.catalog import Catalog
from datachain.lib.dc import DataChain
from datachain.query.dataset import QueryGenerator

P = ParamSpec("P")

Expand Down Expand Up @@ -43,11 +49,38 @@ def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
return _inner


@frozen
class _RegenerateSystemColumnsStep(Step):
    """Query step that passes the current query through the warehouse's
    system-column regeneration helper.

    The step carries only the catalog whose warehouse performs the
    regeneration; it has no other configuration.
    """

    catalog: "Catalog"

    def hash_inputs(self) -> str:
        """Return the SHA-256 hex digest of a fixed marker string.

        The digest does not depend on any attribute of the step, so every
        instance reports the same value.
        """
        marker = b"regenerate_sys_columns"
        return hashlib.sha256(marker).hexdigest()

    def apply(self, query_generator: "QueryGenerator", temp_tables: list[str]):
        """Build the step result on top of the regenerated selectable.

        ``temp_tables`` is accepted as part of the step signature but is not
        used here.
        """
        base_query = query_generator.select()
        warehouse = self.catalog.warehouse
        # Existing columns are kept; ``regenerate_columns=None`` is passed
        # through unchanged (its exact meaning is defined by the warehouse).
        rebuilt = warehouse._regenerate_system_columns(
            base_query,
            keep_existing_columns=True,
            regenerate_columns=None,
        )

        def _select_only(*columns):
            # Narrow the regenerated selectable to the requested columns.
            return rebuilt.with_only_columns(*columns)

        return step_result(_select_only, rebuilt.selected_columns)


def _append_steps(dc: "DataChain", other: "DataChain"):
    """Return a clone of ``dc`` extended with the steps of ``other``.

    The clone first receives a ``_RegenerateSystemColumnsStep`` bound to its
    session catalog, and then a copy of every step recorded on ``other``
    (steps are the modification methods applied to a chain, such as filters
    and mappers). Finally the clone takes over ``other``'s signals schema.
    """
    combined = dc.clone()

    # The regenerate step goes in before any of the other chain's steps.
    regenerate_step = _RegenerateSystemColumnsStep(
        catalog=combined.session.catalog,
    )
    combined._query.steps.append(regenerate_step)

    replayed_steps = other._query.steps.copy()
    combined._query.steps += replayed_steps

    combined.signals_schema = other.signals_schema
    return combined
Expand Down
37 changes: 37 additions & 0 deletions tests/func/test_delta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid

import pytest
import regex as re
Expand Down Expand Up @@ -224,6 +225,42 @@ def test_delta_update_unsafe(test_session):
}


def test_delta_replay_regenerates_system_columns(test_session):
# Saves the same chain twice -- first without delta, then with delta on
# "measurement_id" -- and checks the result still holds both measurement ids.
# Dataset names get a random 8-hex-character suffix from uuid4.
source_name = f"regen_source_{uuid.uuid4().hex[:8]}"
result_name = f"regen_result_{uuid.uuid4().hex[:8]}"

# Source dataset: two rows, both with an empty "err" value.
dc.read_values(
measurement_id=[1, 2],
err=["", ""],
num=[1, 2],
session=test_session,
).save(source_name)
Comment on lines +228 to +237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding assertions to verify the presence and correctness of regenerated system columns.

Add assertions to confirm that regenerated system columns, such as sys__id, exist and contain correct values after replay.


def build_chain(delta: bool):
    """Build the chain under test on top of the saved source dataset.

    With ``delta`` set, the dataset is read with ``delta=True`` and
    ``delta_on="measurement_id"``; otherwise only the session is passed.
    The chain keeps rows whose ``err`` is empty, drops ``err``, adds
    ``double`` (``num * 2``) and finally drops ``num``.
    """
    options = {"session": test_session}
    if delta:
        options["delta"] = True
        options["delta_on"] = "measurement_id"

    source = dc.read_dataset(source_name, **options)
    clean_rows = source.filter(C.err == "").select_except("err")
    doubled = clean_rows.map(double=lambda num: num * 2, output=int)
    return doubled.select_except("num")

# First save: plain (non-delta) chain creates the result dataset.
build_chain(delta=False).save(result_name)

# Second save: same result name, now with delta enabled on "measurement_id".
# NOTE(review): per the test name this is the delta replay path that should
# regenerate system columns -- confirm against the delta implementation.
build_chain(delta=True).save(
result_name,
delta=True,
delta_on="measurement_id",
)

# The source was not changed between the two saves, so the result is expected
# to contain exactly the two original measurement ids.
assert set(
dc.read_dataset(result_name, session=test_session).to_values("measurement_id")
) == {1, 2}


def test_delta_update_from_storage(test_session, tmp_dir, tmp_path):
ds_name = "delta_ds"
path = tmp_dir.as_uri()
Expand Down
Loading