Skip to content

Commit 7736f22

Browse files
authored
Fix Tracing bug and LSP Crash (#2030)
Fix an issue where BAML_LOG wouldn't be printed to the console in async contexts in python due to the context stack not being propagated correctly. Log an error when we run into issues with the tracing spans Fix an issue where LSP would crash when sending diagnostics for a baml file not in a baml_src.
1 parent 97b2408 commit 7736f22

57 files changed

Lines changed: 4185 additions & 607 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

engine/baml-lib/baml/tests/validation_files/prompt_fiddle_example.baml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ client<llm> Claude {
199199
client<llm> Gemini {
200200
provider google-ai
201201
options {
202-
model "gemini-1.5-pro-001"
202+
model "gemini-1.5-pro"
203203
api_key env.GOOGLE_API_KEY
204204
}
205205
}

engine/baml-runtime/src/lib.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ impl BamlRuntime {
591591
.finish_call(call, ctx, None)
592592
{
593593
Ok(id) => {}
594-
Err(e) => log::debug!("Error during logging: {}", e),
594+
Err(e) => baml_log::error!("Error during logging: {}", e),
595595
}
596596
#[cfg(target_arch = "wasm32")]
597597
match self
@@ -601,7 +601,7 @@ impl BamlRuntime {
601601
.await
602602
{
603603
Ok(id) => {}
604-
Err(e) => log::debug!("Error during logging: {}", e),
604+
Err(e) => log::error!("Error during logging: {}", e),
605605
}
606606
}
607607

@@ -689,6 +689,7 @@ impl BamlRuntime {
689689

690690
log::trace!("Calling function: {}", function_name);
691691
log::debug!("collectors: {:#?}", &collectors);
692+
692693
let call = self
693694
.tracer_wrapper
694695
.get_or_create_tracer(&env_vars)
@@ -854,7 +855,7 @@ impl BamlRuntime {
854855
.finish_baml_call(call, ctx, &response)
855856
{
856857
Ok(id) => {}
857-
Err(e) => log::debug!("Error during logging: {}", e),
858+
Err(e) => baml_log::error!("Error during logging: {}", e),
858859
}
859860
#[cfg(target_arch = "wasm32")]
860861
match self
@@ -864,7 +865,7 @@ impl BamlRuntime {
864865
.await
865866
{
866867
Ok(id) => {}
867-
Err(e) => log::debug!("Error during logging: {}", e),
868+
Err(e) => log::error!("Error during logging: {}", e),
868869
}
869870

870871
(response, curr_call_id)

engine/baml-runtime/src/tracing/mod.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ cfg_if! {
4949
#[derive(Debug, Clone)]
5050
pub struct TracingCall {
5151
pub call_id: Uuid,
52+
pub function_name: String,
5253
pub new_call_id_stack: Vec<baml_ids::FunctionCallId>,
5354
params: BamlMap<String, BamlValue>,
5455
start_time: web_time::SystemTime,
@@ -406,16 +407,25 @@ impl BamlTracer {
406407
self.trace_stats.guard().start();
407408
let (call_id, call_stack, last_tags, global_tags) = ctx.enter(function_name);
408409

409-
log::trace!(" Entering call {:#?} in {:?}", call_id, function_name);
410+
log::trace!(
411+
"\n{}------------------- Entering {:?}, ctx chain {:#?}",
412+
" ".repeat(ctx.context_depth()),
413+
function_name,
414+
ctx
415+
);
416+
410417
let call = TracingCall {
411418
call_id,
419+
function_name: function_name.to_string(),
412420
new_call_id_stack: call_stack.clone(),
413421
params: params.clone(),
414422
start_time: web_time::SystemTime::now(),
415423
// Note these tags are the ones currently on the stack. While the function runs we may register
416424
// more tags with set_tags(). Those are picked up via a diff event (SetTags)
417425
tags: last_tags.clone(),
418426
};
427+
// println!("---- {} ctx {:#?}", function_name, ctx);
428+
// baml_log::info!("---- {} ctx {:#?}", function_name, ctx);
419429

420430
// This must happen before the first event is sent.
421431
if let Some(collectors) = collectors {
@@ -506,8 +516,9 @@ impl BamlTracer {
506516
);
507517
};
508518
log::trace!(
509-
"Finishing call: {:#?} {}\nevent chain {:?}",
510-
call,
519+
"\n{}------------------- Finishing call: {:#?} {}\nevent chain {:#?}",
520+
" ".repeat(ctx.context_depth()),
521+
call.function_name,
511522
call_id,
512523
event_chain
513524
);
@@ -626,8 +637,8 @@ impl BamlTracer {
626637
};
627638

628639
log::trace!(
629-
"Finishing baml call: {:#?} {}\nevent chain {:?}",
630-
call,
640+
"Finishing baml call: {:#?} {}\nevent chain {:#?}",
641+
call.function_name,
631642
call_id,
632643
event_chain
633644
);

engine/baml-runtime/src/tracingv2/publisher/publisher.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,10 @@ impl TracePublisher {
469469
Ok(())
470470
}
471471
Err(e) => {
472-
baml_log::error!("Failed to upload batch of {} events: {}", batch.len(), e);
472+
log::info!("Failed to upload batch of {} events: {}", batch.len(), e);
473473
// If batch size is at or below minimum, give up
474474
if batch.len() <= min_batch_size {
475-
baml_log::error!(
475+
log::info!(
476476
"Failed to upload single/minimum batch of {} events: {}",
477477
batch.len(),
478478
e
@@ -507,15 +507,15 @@ impl TracePublisher {
507507
Ok(())
508508
}
509509
(Err(e1), Ok(())) => {
510-
baml_log::error!("First half failed: {}", e1);
510+
log::info!("First half failed: {}", e1);
511511
Err(e1)
512512
}
513513
(Ok(()), Err(e2)) => {
514-
baml_log::error!("Second half failed: {}", e2);
514+
log::info!("Second half failed: {}", e2);
515515
Err(e2)
516516
}
517517
(Err(e1), Err(e2)) => {
518-
baml_log::debug!("Both halves failed - first: {}, second: {}", e1, e2);
518+
log::debug!("Both halves failed - first: {}, second: {}", e1, e2);
519519
Err(e1) // Return the first error
520520
}
521521
}
@@ -574,7 +574,7 @@ impl TracePublisher {
574574
{
575575
Ok(response) => response,
576576
Err(e) => {
577-
baml_log::debug!("Failed to upload trace events: {}", e);
577+
log::debug!("Failed to upload trace events: {}", e);
578578
return Err(e.into());
579579
}
580580
};

engine/baml-runtime/src/tracingv2/storage/storage.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ fn build_function_log(
264264
// Build each LLMCall or LLMStreamCall
265265
let mut calls = Vec::new();
266266
for (_rid, call_acc) in calls_map {
267-
println!("### _rid: {:?}", _rid);
268267
let (client, provider) = parse_llm_client_and_provider(call_acc.llm_request.as_ref());
269268
let start_t = call_acc.timestamp_first_seen.unwrap_or(start_ms);
270269
let end_t = call_acc.timestamp_last_seen.unwrap_or(start_t);

engine/language_client_codegen/src/python/templates/async_client.py.j2

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ from typing_extensions import Literal
44
import baml_py
55

66
from . import _baml
7+
from ._baml import BamlCallOptions
78
from .types import Checked, Check
89
from .parser import LlmResponseParser, LlmStreamParser
910
from .async_request import AsyncHttpRequest, AsyncHttpStreamRequest
@@ -104,7 +105,7 @@ class BamlAsyncClient:
104105
"{{name}}": {{name}},
105106
{%- endfor %}
106107
},
107-
self.__ctx_manager.get(),
108+
self.__ctx_manager.clone_context(),
108109
tb,
109110
__cr__,
110111
collectors,
@@ -167,4 +168,4 @@ class BamlStreamClient:
167168

168169
b = BamlAsyncClient(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
169170

170-
__all__ = ["b"]
171+
__all__ = ["b", "BamlCallOptions"]

engine/language_client_codegen/src/python/templates/sync_client.py.j2

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ from typing_extensions import Literal
44
import baml_py
55

66
from . import _baml
7+
from ._baml import BamlCallOptions
78
from .types import Checked, Check
89
from .parser import LlmResponseParser, LlmStreamParser
910
from .sync_request import HttpRequest, HttpStreamRequest
@@ -166,4 +167,4 @@ class BamlStreamClient:
166167

167168
b = BamlSyncClient(DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX)
168169

169-
__all__ = ["b"]
170+
__all__ = ["b", "BamlCallOptions"]

engine/language_client_python/python_src/baml_py/ctx_manager.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .baml_py import BamlLogEvent, RuntimeContextManager, BamlRuntime, BamlSpan
1111
import atexit
1212
import threading
13+
from typing import Dict
1314

1415
F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any])
1516

@@ -26,11 +27,28 @@ def current_thread_id() -> int:
2627
return current_thread.ident or 0
2728

2829

30+
prev_ctx_manager: typing.Optional["CtxManager"] = None
31+
32+
2933
class CtxManager:
34+
def __new__(cls, *args, **kwargs):
35+
if prev_ctx_manager is not None:
36+
return prev_ctx_manager
37+
return super().__new__(cls)
38+
3039
def __init__(self, rt: BamlRuntime):
40+
global prev_ctx_manager
41+
if prev_ctx_manager is not None:
42+
self.rt = prev_ctx_manager.rt
43+
self.ctx = prev_ctx_manager.ctx
44+
return
45+
46+
prev_ctx_manager = self
47+
3148
self.rt = rt
49+
3250
self.ctx = contextvars.ContextVar[typing.Dict[int, RuntimeContextManager]](
33-
"baml_ctx", default={current_thread_id(): rt.create_context_manager()}
51+
"baml_ctx", default={}
3452
)
3553
atexit.register(self.rt.flush)
3654

@@ -71,20 +89,35 @@ def get(self) -> RuntimeContextManager:
7189
return self.__ctx()
7290

7391
def start_trace_sync(
74-
self, name: str, args: typing.Dict[str, typing.Any], env_vars: typing.Dict[str, str]
92+
self,
93+
name: str,
94+
args: typing.Dict[str, typing.Any],
95+
env_vars: typing.Dict[str, str],
7596
) -> BamlSpan:
97+
# Clone the current context before creating the span
7698
mng = self.__ctx()
7799
return BamlSpan.new(self.rt, name, args, mng, env_vars)
78100

79101
def start_trace_async(
80-
self, name: str, args: typing.Dict[str, typing.Any], env_vars: typing.Dict[str, str]
102+
self,
103+
name: str,
104+
args: typing.Dict[str, typing.Any],
105+
env_vars: typing.Dict[str, str],
81106
) -> BamlSpan:
82107
mng = self.__ctx()
83108
cln = mng.deep_clone()
84109
self.ctx.set({current_thread_id(): cln})
85110
return BamlSpan.new(self.rt, name, args, cln, env_vars)
86111

87-
def end_trace(self, span: BamlSpan, response: typing.Any, env_vars: typing.Dict[str, str]) -> None:
112+
def clone_context(self) -> RuntimeContextManager:
113+
mng = self.__ctx()
114+
cln = mng.deep_clone()
115+
self.ctx.set({current_thread_id(): cln})
116+
return cln
117+
118+
def end_trace(
119+
self, span: BamlSpan, response: typing.Any, env_vars: typing.Dict[str, str]
120+
) -> None:
88121
span.finish(response, self.__ctx(), env_vars)
89122

90123
def flush(self) -> None:

engine/language_server/src/server/api/diagnostics.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ pub fn publish_session_lsp_diagnostics(
8787
) -> Result<()> {
8888
// let keys = session.index().documents.keys();
8989
let path = file_url.to_file_path().unwrap_or(PathBuf::new());
90+
if !file_url.to_string().contains("baml_src") {
91+
return Ok(());
92+
}
9093
let project = session
9194
.get_or_create_project(&path)
9295
.expect("We just ensured the session is valid");

integ-tests/baml_src/clients.baml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,11 @@ client<llm> GPT35AzureFailed {
139139
api_key env.AZURE_OPENAI_API_KEY
140140
}
141141
}
142-
142+
143143
client<llm> Gemini {
144144
provider google-ai
145145
options {
146-
model gemini-1.5-pro-001
146+
model gemini-1.5-pro
147147
api_key env.GOOGLE_API_KEY
148148
safetySettings {
149149
category HARM_CATEGORY_HATE_SPEECH

0 commit comments

Comments
 (0)