1- from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, cast
2- from typing_extensions import NotRequired, Literal
3- import pprint
4- import os
1+ from typing import Dict, List, Optional, TypeVar, Union, cast
2+ from typing_extensions import Literal
3+
54import baml_py
6- from pydantic import BaseModel, ValidationError, create_model
75
8- from . import partial_types, types
6+ from . import _baml
97from .types import Checked, Check
10- from .type_builder import TypeBuilder
118from .parser import LlmResponseParser, LlmStreamParser
129from .async_request import AsyncHttpRequest, AsyncHttpStreamRequest
1310from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME
1411
15- OutputType = TypeVar('OutputType')
16-
1712
18- # Define the TypedDict with optional parameters having default values
19- class BamlCallOptions(TypedDict, total=False):
20- tb: NotRequired[TypeBuilder]
21- client_registry: NotRequired[baml_py.baml_py.ClientRegistry]
22- collector: NotRequired[Union[baml_py.baml_py.Collector, List[baml_py.baml_py.Collector]]]
23- env: NotRequired[Dict[str, Optional[str]]]
24-
25- def env_vars_to_dict(overrides: Dict[str, Optional[str]]) -> Dict[str, str]:
26- base = os.environ.copy()
27- for k, v in overrides.items():
28- if v is not None:
29- base[k] = v
30- else:
31- base.pop(k, None)
32- return base
13+ OutputType = TypeVar('OutputType')
3314
3415
3516class BamlAsyncClient:
3617 __runtime: baml_py.BamlRuntime
3718 __ctx_manager: baml_py.BamlCtxManager
3819 __stream_client: "BamlStreamClient"
39- __http_request: " AsyncHttpRequest"
40- __http_stream_request: " AsyncHttpStreamRequest"
20+ __http_request: AsyncHttpRequest
21+ __http_stream_request: AsyncHttpStreamRequest
4122 __llm_response_parser: LlmResponseParser
4223 __llm_stream_parser: LlmStreamParser
43- __baml_options: BamlCallOptions
24+ __baml_options: _baml. BamlCallOptions
4425
45- def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[BamlCallOptions] = None):
26+ def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[_baml. BamlCallOptions] = None):
4627 self.__runtime = runtime
4728 self.__ctx_manager = ctx_manager
4829 self.__stream_client = BamlStreamClient(self.__runtime, self.__ctx_manager, baml_options)
@@ -54,7 +35,7 @@ class BamlAsyncClient:
5435
5536 def with_options(
5637 self,
57- tb: Optional[TypeBuilder] = None,
38+ tb: Optional[_baml.type_builder. TypeBuilder] = None,
5839 client_registry: Optional[baml_py.baml_py.ClientRegistry] = None,
5940 collector: Optional[Union[baml_py.baml_py.Collector, List[baml_py.baml_py.Collector]]] = None,
6041 env: Optional[Dict[str, Optional[str]]] = None,
@@ -103,9 +84,9 @@ class BamlAsyncClient:
10384 {% for (name , type , default_value ) in fn .args -%}
10485 {{name}}: {{type}}{% if let Some (d ) = default_value %} = {{d}}{% endif %} ,
10586 {% - endfor %}
106- baml_options: BamlCallOptions = {},
87+ baml_options: _baml. BamlCallOptions = {},
10788 ) -> {{fn.return_type}}:
108- options: BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
89+ options: _baml. BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
10990
11091 __tb__ = options.get("tb", None)
11192 if __tb__ is not None:
@@ -115,7 +96,7 @@ class BamlAsyncClient:
11596 __cr__ = options.get("client_registry", None)
11697 collector = options.get("collector", None)
11798 collectors = collector if isinstance(collector, list) else [collector] if collector is not None else []
118- env = env_vars_to_dict(options.get("env", {}))
99+ env = _baml. env_vars_to_dict(options.get("env", {}))
119100 raw = await self.__runtime.call_function(
120101 "{{fn.name}}",
121102 {
@@ -129,15 +110,15 @@ class BamlAsyncClient:
129110 collectors,
130111 env,
131112 )
132- return cast({{fn.return_type}}, raw.cast_to(types, types, partial_types, False))
113+ return cast({{fn.return_type}}, raw.cast_to(_baml. types, _baml. types, _baml. partial_types, False))
133114 {% endfor %}
134115
135116
136117class BamlStreamClient:
137118 __runtime: baml_py.BamlRuntime
138119 __ctx_manager: baml_py.BamlCtxManager
139- __baml_options: BamlCallOptions
140- def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[BamlCallOptions] = None):
120+ __baml_options: _baml. BamlCallOptions
121+ def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[_baml. BamlCallOptions] = None):
141122 self.__runtime = runtime
142123 self.__ctx_manager = ctx_manager
143124 self.__baml_options = baml_options or {}
@@ -148,9 +129,9 @@ class BamlStreamClient:
148129 {% for (name , type , default_value ) in fn .args -%}
149130 {{name}}: {{type}}{% if let Some (d ) = default_value %} = {{d}}{% endif %} ,
150131 {% - endfor %}
151- baml_options: BamlCallOptions = {},
132+ baml_options: _baml. BamlCallOptions = {},
152133 ) -> baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}]:
153- options: BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
134+ options: _baml. BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
154135 __tb__ = options.get("tb", None)
155136 if __tb__ is not None:
156137 tb = __tb__._tb # type: ignore (we know how to use this private attribute)
@@ -159,7 +140,7 @@ class BamlStreamClient:
159140 __cr__ = options.get("client_registry", None)
160141 collector = options.get("collector", None)
161142 collectors = collector if isinstance(collector, list) else [collector] if collector is not None else []
162- env = env_vars_to_dict(options.get("env", {}))
143+ env = _baml. env_vars_to_dict(options.get("env", {}))
163144 raw = self.__runtime.stream_function(
164145 "{{fn.name}}",
165146 {
@@ -177,8 +158,8 @@ class BamlStreamClient:
177158
178159 return baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}](
179160 raw,
180- lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, types, partial_types, True)),
181- lambda x: cast({{fn.return_type}}, x.cast_to(types, types, partial_types, False)),
161+ lambda x: cast({{fn.partial_return_type}}, x.cast_to(_baml. types, _baml. types, _baml. partial_types, True)),
162+ lambda x: cast({{fn.return_type}}, x.cast_to(_baml. types, _baml. types, _baml. partial_types, False)),
182163 self.__ctx_manager.get(),
183164 )
184165 {% endfor %}
0 commit comments