Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #7785: Make fail-fast return list instead of single element #7908

Closed
wants to merge 8 commits into from
29 changes: 15 additions & 14 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,20 @@ def __init__(self, node):
self.node = node


def skip_result(node, message):
thread_id = threading.current_thread().name
return RunResult(
status=RunStatus.Skipped,
thread_id=thread_id,
execution_time=0,
timing=[],
message=message,
node=node,
adapter_response={},
failures=None,
)


class BaseRunner(metaclass=ABCMeta):
def __init__(self, config, adapter, node, node_index, num_nodes):
self.config = config
Expand Down Expand Up @@ -299,19 +313,6 @@ def from_run_result(self, result, start_time, timing_info):
failures=result.failures,
)

def skip_result(self, node, message):
thread_id = threading.current_thread().name
return RunResult(
status=RunStatus.Skipped,
thread_id=thread_id,
execution_time=0,
timing=[],
message=message,
node=node,
adapter_response={},
failures=None,
)

def compile_and_execute(self, manifest, ctx):
result = None
with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext():
Expand Down Expand Up @@ -486,7 +487,7 @@ def on_skip(self):
)
)

node_result = self.skip_result(self.node, error_message)
node_result = skip_result(self.node, error_message)
return node_result

def do_skip(self, cause=None):
Expand Down
62 changes: 32 additions & 30 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
import os
import time
from pathlib import Path
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from pathlib import Path
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet

from .printer import (
print_run_result_error,
print_run_end_messages,
)

from dbt.task.base import ConfiguredTask
import dbt.exceptions
import dbt.tracking
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
)
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
from dbt.events.contextvars import log_contextvars
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
Formatting,
Expand All @@ -36,24 +28,29 @@
EndRunResult,
NothingToDo,
)
from dbt.events.contextvars import log_contextvars
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import (
DbtInternalError,
NotImplementedError,
DbtRuntimeError,
FailFastError,
)

from dbt.flags import get_flags
from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
)
from dbt.parser.manifest import write_manifest
import dbt.tracking

import dbt.exceptions
from dbt.flags import get_flags
import dbt.utils
from dbt.task.base import ConfiguredTask, skip_result
from dbt.task.printer import (
print_run_result_error,
print_run_end_messages,
)

RESULT_FILE_NAME = "run_results.json"
RUNNING_STATE = DbtProcessState("running")
Expand Down Expand Up @@ -357,14 +354,19 @@ def execute_nodes(self):
fire_event(Formatting(""))

pool = ThreadPool(num_threads)

try:
self.run_queue(pool)

except FailFastError as failure:
self._cancel_connections(pool)
print_run_result_error(failure.result)
raise

# Mark all remaining queued/in-progress nodes as skipped
for node_id in self.job_queue.queued.union(self.job_queue.in_progress):
node = self.manifest.nodes[node_id]
self.node_results.append(skip_result(node, "Skipping due to fail-fast"))

raise
except KeyboardInterrupt:
self._cancel_connections(pool)
print_run_end_messages(self.node_results, keyboard_interrupt=True)
Expand Down Expand Up @@ -483,7 +485,7 @@ def interpret_results(cls, results):
NodeStatus.RuntimeErr,
NodeStatus.Error,
NodeStatus.Fail,
NodeStatus.Skipped, # propogate error message causing skip
NodeStatus.Skipped, # propagate error message causing skip
)
]
return len(failures) == 0
Expand Down Expand Up @@ -567,7 +569,7 @@ def create_schema(relation: BaseRelation) -> None:
create_futures.append(fut)

for create_future in as_completed(create_futures):
# trigger/re-raise any excceptions while creating schemas
# trigger/re-raise any exceptions while creating schemas
create_future.result()

def get_result(self, results, elapsed_time, generated_at):
Expand Down
40 changes: 24 additions & 16 deletions tests/functional/retry/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,6 @@ def test_run_operation(self, project):
results = run_dbt(["retry"], expect_pass=False)
assert {n.unique_id: n.status for n in results.results} == expected_statuses

def test_fail_fast(self, project):
result = run_dbt(["--warn-error", "build", "--fail-fast"], expect_pass=False)

assert result.status == RunStatus.Error
assert result.node.name == "sample_model"

results = run_dbt(["retry"], expect_pass=False)

assert len(results.results) == 1
assert results.results[0].status == RunStatus.Error
assert results.results[0].node.name == "sample_model"

result = run_dbt(["retry", "--fail-fast"], expect_pass=False)
assert result.status == RunStatus.Error
assert result.node.name == "sample_model"

def test_removed_file(self, project):
run_dbt(["build"], expect_pass=False)

Expand All @@ -180,3 +164,27 @@ def test_removed_file_leaf_node(self, project):
rm_file("models", "third_model.sql")
with pytest.raises(ValueError, match="Couldn't find model 'model.test.third_model'"):
run_dbt(["retry"], expect_pass=False)


class TestFailFast:
@pytest.fixture(scope="class")
def models(self):
return {
"second_model.sql": models__second_model,
"sample_model.sql": "-- depends_on: {{ ref('second_model') }}\n"
+ models__sample_model,
"union_model.sql": models__union_model,
}

def test_fail_fast(self, project):
results = run_dbt(["--fail-fast", "run"], expect_pass=False)
assert len(results.results) == 3

results = run_dbt(["retry"], expect_pass=False)
assert len(results.results) == 2

fixed_sql = "select 1 as id, 1 as foo"
write_file(fixed_sql, "models", "sample_model.sql")

results = run_dbt(["retry"])
assert len(results.results) == 2