1010from .baml_py import BamlLogEvent , RuntimeContextManager , BamlRuntime , BamlSpan
1111import atexit
1212import threading
13+ from typing import Dict
1314
1415F = typing .TypeVar ("F" , bound = typing .Callable [..., typing .Any ])
1516
@@ -26,11 +27,28 @@ def current_thread_id() -> int:
2627 return current_thread .ident or 0
2728
2829
30+ prev_ctx_manager : typing .Optional ["CtxManager" ] = None
31+
32+
2933class CtxManager :
34+ def __new__ (cls , * args , ** kwargs ):
35+ if prev_ctx_manager is not None :
36+ return prev_ctx_manager
37+ return super ().__new__ (cls )
38+
3039 def __init__ (self , rt : BamlRuntime ):
40+ global prev_ctx_manager
41+ if prev_ctx_manager is not None :
42+ self .rt = prev_ctx_manager .rt
43+ self .ctx = prev_ctx_manager .ctx
44+ return
45+
46+ prev_ctx_manager = self
47+
3148 self .rt = rt
49+
3250 self .ctx = contextvars .ContextVar [typing .Dict [int , RuntimeContextManager ]](
33- "baml_ctx" , default = {current_thread_id (): rt . create_context_manager () }
51+ "baml_ctx" , default = {}
3452 )
3553 atexit .register (self .rt .flush )
3654
@@ -71,20 +89,35 @@ def get(self) -> RuntimeContextManager:
7189 return self .__ctx ()
7290
7391 def start_trace_sync (
74- self , name : str , args : typing .Dict [str , typing .Any ], env_vars : typing .Dict [str , str ]
92+ self ,
93+ name : str ,
94+ args : typing .Dict [str , typing .Any ],
95+ env_vars : typing .Dict [str , str ],
7596 ) -> BamlSpan :
97+ # Clone the current context before creating the span
7698 mng = self .__ctx ()
7799 return BamlSpan .new (self .rt , name , args , mng , env_vars )
78100
79101 def start_trace_async (
80- self , name : str , args : typing .Dict [str , typing .Any ], env_vars : typing .Dict [str , str ]
102+ self ,
103+ name : str ,
104+ args : typing .Dict [str , typing .Any ],
105+ env_vars : typing .Dict [str , str ],
81106 ) -> BamlSpan :
82107 mng = self .__ctx ()
83108 cln = mng .deep_clone ()
84109 self .ctx .set ({current_thread_id (): cln })
85110 return BamlSpan .new (self .rt , name , args , cln , env_vars )
86111
87- def end_trace (self , span : BamlSpan , response : typing .Any , env_vars : typing .Dict [str , str ]) -> None :
112+ def clone_context (self ) -> RuntimeContextManager :
113+ mng = self .__ctx ()
114+ cln = mng .deep_clone ()
115+ self .ctx .set ({current_thread_id (): cln })
116+ return cln
117+
118+ def end_trace (
119+ self , span : BamlSpan , response : typing .Any , env_vars : typing .Dict [str , str ]
120+ ) -> None :
88121 span .finish (response , self .__ctx (), env_vars )
89122
90123 def flush (self ) -> None :
0 commit comments