Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
  • Loading branch information
serena-ruan committed May 14, 2024
1 parent a9f093b commit 3ba6764
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
1 change: 1 addition & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,4 +605,5 @@ def get(self):
"MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", False
)

#: Whether the current context is in databricks rag serving.
DATABRICKS_RAG_SERVING = _BooleanEnvironmentVariable("DATABRICKS_RAG_SERVING", False)
32 changes: 16 additions & 16 deletions mlflow/langchain/langchain_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_databricks_vector_search_key,
)
from mlflow.pyfunc.context import Context, set_prediction_context
from mlflow.tracing.fluent import get_trace
from mlflow.tracing.export.inference_table import pop_trace
from mlflow.utils.autologging_utils import ExceptionSafeAbstractClass

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +86,7 @@ def _default_converter(o):
return o.isoformat()
raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

trace = get_trace(self._request_id)
trace = pop_trace(self._request_id)
return json.dumps(trace, default=_default_converter)

def _get_span_by_run_id(self, run_id: UUID) -> Optional[LiveSpan]:
Expand All @@ -104,25 +104,25 @@ def _start_span(
attributes: Optional[Dict[str, Any]] = None,
) -> LiveSpan:
"""Start MLflow Span (or Trace if it is root component)"""
parent = self._get_parent_span(parent_run_id)
if parent:
span = self._mlflow_client.start_span(
name=span_name,
request_id=parent.request_id,
parent_id=parent.span_id,
span_type=span_type,
inputs=inputs,
attributes=attributes,
)
else:
# When parent_run_id is None, this is root component so start trace
with set_prediction_context(self._prediction_context):
with set_prediction_context(self._prediction_context):
parent = self._get_parent_span(parent_run_id)
if parent:
span = self._mlflow_client.start_span(
name=span_name,
request_id=parent.request_id,
parent_id=parent.span_id,
span_type=span_type,
inputs=inputs,
attributes=attributes,
)
else:
# When parent_run_id is None, this is root component so start trace
span = self._mlflow_client.start_trace(
name=span_name, span_type=span_type, inputs=inputs, attributes=attributes
)
self._request_id = span.request_id

self._run_span_mapping[str(run_id)] = span
self._run_span_mapping[str(run_id)] = span
return span

def _get_parent_span(self, parent_run_id) -> Optional[LiveSpan]:
Expand Down
5 changes: 3 additions & 2 deletions mlflow/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
get_databricks_run_url,
is_in_databricks_model_serving_environment,
)
from mlflow.utils.exception_utils import get_stacktrace
from mlflow.utils.logging_utils import eprint
from mlflow.utils.mlflow_tags import (
MLFLOW_LOGGED_ARTIFACTS,
Expand Down Expand Up @@ -592,7 +593,7 @@ def start_trace(

return mlflow_span
except Exception as e:
_logger.warning(f"Failed to start span {name}: {e}")
_logger.warning(f"Failed to start trace {name}: {get_stacktrace(e)}")
if _MLFLOW_TESTING.get():
raise
return NoOpSpan()
Expand Down Expand Up @@ -789,7 +790,7 @@ def start_span(
trace_manager.register_span(span)
return span
except Exception as e:
_logger.warning(f"Failed to start span {name}: {e}")
_logger.warning(f"Failed to start span {name}: {get_stacktrace(e)}")
if _MLFLOW_TESTING.get():
raise
return NoOpSpan()
Expand Down

0 comments on commit 3ba6764

Please sign in to comment.