Skip to content

Commit 505a9e9

Browse files
Reduce namespace pollution in codegen (#2016)
The more relative imports we add to our codegen, the more we pollute the namespace. This PR attempts to use all the internal baml codegen under a single namespace `_baml`. Related issues: - #2003
1 parent 149b742 commit 505a9e9

21 files changed

Lines changed: 16079 additions & 16138 deletions

engine/language_client_codegen/src/python/mod.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ use internal_baml_core::{
1717
use self::python_language_features::{PythonLanguageFeatures, ToPython};
1818
use crate::{dir_writer::FileCollector, field_type_attributes};
1919

20+
#[derive(askama::Template)]
21+
#[template(path = "_baml.py.j2", escape = "none")]
22+
struct BamlNamespace {}
23+
2024
#[derive(askama::Template)]
2125
#[template(path = "config.py.j2", escape = "none")]
2226
struct PythonConfig {}
@@ -127,6 +131,7 @@ pub(crate) fn generate(
127131
.add_template::<generate_types::PythonStreamTypes>("partial_types.py", (ir, generator))?;
128132
collector.add_template::<generate_types::PythonTypes>("types.py", (ir, generator))?;
129133
collector.add_template::<generate_types::TypeBuilder>("type_builder.py", (ir, generator))?;
134+
collector.add_template::<BamlNamespace>("_baml.py", (ir, generator))?;
130135
collector.add_template::<AsyncPythonClient>("async_client.py", (ir, generator))?;
131136
collector.add_template::<SyncPythonClient>("sync_client.py", (ir, generator))?;
132137
collector.add_template::<PythonGlobals>("globals.py", (ir, generator))?;
@@ -149,6 +154,14 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonConfig
149154
}
150155
}
151156

157+
impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for BamlNamespace {
158+
type Error = anyhow::Error;
159+
160+
fn try_from(_: (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result<Self> {
161+
Ok(BamlNamespace {})
162+
}
163+
}
164+
152165
impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonTracing {
153166
type Error = anyhow::Error;
154167

@@ -277,14 +290,14 @@ impl ToTypeReferenceInClientDefinition for FieldType {
277290
.map(|e| e.item.attributes.get("dynamic_type").is_some())
278291
.unwrap_or(false)
279292
{
280-
format!("Union[types.{name}, str]")
293+
format!("Union[_baml.types.{name}, str]")
281294
} else {
282-
format!("types.{name}")
295+
format!("_baml.types.{name}")
283296
}
284297
}
285298
FieldType::Literal(value) => to_python_literal(value),
286-
FieldType::RecursiveTypeAlias(name) => format!("types.{name}"),
287-
FieldType::Class(name) => format!("types.{name}"),
299+
FieldType::RecursiveTypeAlias(name) => format!("_baml.types.{name}"),
300+
FieldType::Class(name) => format!("_baml.types.{name}"),
288301
FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)),
289302
FieldType::Map(key, value) => {
290303
format!("Dict[{}, {}]", key.to_type_ref(ir), value.to_type_ref(ir))
@@ -330,9 +343,9 @@ impl ToTypeReferenceInClientDefinition for FieldType {
330343
let with_state = metadata.1.state;
331344
let constraints = metadata.0;
332345
let module_prefix = if is_partial_type {
333-
"partial_types."
346+
"_baml.partial_types."
334347
} else {
335-
"types."
348+
"_baml.types."
336349
};
337350

338351
let base_rep = match &base_type {
@@ -347,19 +360,19 @@ impl ToTypeReferenceInClientDefinition for FieldType {
347360
// wrap primitives in `Optional` when generating partial types,
348361
// although we should probably only do this when `!needed`.
349362
if false {
350-
format!("Union[types.{name}, str]")
363+
format!("Union[_baml.types.{name}, str]")
351364
} else {
352-
format!("Optional[Union[types.{name}, str]]")
365+
format!("Optional[Union[_baml.types.{name}, str]]")
353366
}
354367
} else {
355368
// Note: The `false` here preserves potentially bugged codegen
356369
// from before this commit. As the `false` implies, we always
357370
// wrap primitives in `Optional` when generating partial types,
358371
// although we should probably only do this when `!needed`.
359372
if false {
360-
format!("types.{name}")
373+
format!("_baml.types.{name}")
361374
} else {
362-
format!("Optional[types.{name}]")
375+
format!("Optional[_baml.types.{name}]")
363376
}
364377
}
365378
}
@@ -372,9 +385,9 @@ impl ToTypeReferenceInClientDefinition for FieldType {
372385
}
373386
FieldType::RecursiveTypeAlias(name) => {
374387
if needed {
375-
format!("types.{name}")
388+
format!("_baml.types.{name}")
376389
} else {
377-
format!("Optional[types.{name}]")
390+
format!("Optional[_baml.types.{name}]")
378391
}
379392
}
380393
FieldType::Literal(value) => {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
"""Private internal Baml namespace and utilities."""
3+
4+
import os
5+
from typing import TypedDict, Dict, Optional, Union, List
6+
from typing_extensions import NotRequired
7+
8+
import baml_py
9+
10+
from . import types
11+
from . import partial_types
12+
from . import type_builder
13+
14+
15+
class BamlCallOptions(TypedDict, total=False):
16+
"""Additional parameters for Baml function calls."""
17+
tb: NotRequired[type_builder.TypeBuilder]
18+
client_registry: NotRequired[baml_py.baml_py.ClientRegistry]
19+
collector: NotRequired[Union[baml_py.baml_py.Collector, List[baml_py.baml_py.Collector]]]
20+
env: NotRequired[Dict[str, Optional[str]]]
21+
22+
23+
class BamlCallOptionsModApi(TypedDict, total=False):
24+
"""Additional parameters for modular API calls (doesn't take a collector)."""
25+
tb: NotRequired[type_builder.TypeBuilder]
26+
client_registry: NotRequired[baml_py.baml_py.ClientRegistry]
27+
env: NotRequired[Dict[str, Optional[str]]]
28+
29+
30+
def env_vars_to_dict(overrides: Dict[str, Optional[str]]) -> Dict[str, str]:
31+
base = os.environ.copy()
32+
for k, v in overrides.items():
33+
if v is not None:
34+
base[k] = v
35+
else:
36+
base.pop(k, None)
37+
return base
38+
39+
40+
__all__ = [
41+
"types",
42+
"partial_types",
43+
"type_builder",
44+
]

engine/language_client_codegen/src/python/templates/async_client.py.j2

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,29 @@
1-
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, cast
2-
from typing_extensions import NotRequired, Literal
3-
import pprint
4-
import os
1+
from typing import Dict, List, Optional, TypeVar, Union, cast
2+
from typing_extensions import Literal
3+
54
import baml_py
6-
from pydantic import BaseModel, ValidationError, create_model
75

8-
from . import partial_types, types
6+
from . import _baml
97
from .types import Checked, Check
10-
from .type_builder import TypeBuilder
118
from .parser import LlmResponseParser, LlmStreamParser
129
from .async_request import AsyncHttpRequest, AsyncHttpStreamRequest
1310
from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME
1411

15-
OutputType = TypeVar('OutputType')
16-
1712

18-
# Define the TypedDict with optional parameters having default values
19-
class BamlCallOptions(TypedDict, total=False):
20-
tb: NotRequired[TypeBuilder]
21-
client_registry: NotRequired[baml_py.baml_py.ClientRegistry]
22-
collector: NotRequired[Union[baml_py.baml_py.Collector, List[baml_py.baml_py.Collector]]]
23-
env: NotRequired[Dict[str, Optional[str]]]
24-
25-
def env_vars_to_dict(overrides: Dict[str, Optional[str]]) -> Dict[str, str]:
26-
base = os.environ.copy()
27-
for k, v in overrides.items():
28-
if v is not None:
29-
base[k] = v
30-
else:
31-
base.pop(k, None)
32-
return base
13+
OutputType = TypeVar('OutputType')
3314

3415

3516
class BamlAsyncClient:
3617
__runtime: baml_py.BamlRuntime
3718
__ctx_manager: baml_py.BamlCtxManager
3819
__stream_client: "BamlStreamClient"
39-
__http_request: "AsyncHttpRequest"
40-
__http_stream_request: "AsyncHttpStreamRequest"
20+
__http_request: AsyncHttpRequest
21+
__http_stream_request: AsyncHttpStreamRequest
4122
__llm_response_parser: LlmResponseParser
4223
__llm_stream_parser: LlmStreamParser
43-
__baml_options: BamlCallOptions
24+
__baml_options: _baml.BamlCallOptions
4425

45-
def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[BamlCallOptions] = None):
26+
def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[_baml.BamlCallOptions] = None):
4627
self.__runtime = runtime
4728
self.__ctx_manager = ctx_manager
4829
self.__stream_client = BamlStreamClient(self.__runtime, self.__ctx_manager, baml_options)
@@ -54,7 +35,7 @@ class BamlAsyncClient:
5435

5536
def with_options(
5637
self,
57-
tb: Optional[TypeBuilder] = None,
38+
tb: Optional[_baml.type_builder.TypeBuilder] = None,
5839
client_registry: Optional[baml_py.baml_py.ClientRegistry] = None,
5940
collector: Optional[Union[baml_py.baml_py.Collector, List[baml_py.baml_py.Collector]]] = None,
6041
env: Optional[Dict[str, Optional[str]]] = None,
@@ -103,9 +84,9 @@ class BamlAsyncClient:
10384
{% for (name, type, default_value) in fn.args -%}
10485
{{name}}: {{type}}{% if let Some(d) = default_value %} = {{d}}{% endif %},
10586
{%- endfor %}
106-
baml_options: BamlCallOptions = {},
87+
baml_options: _baml.BamlCallOptions = {},
10788
) -> {{fn.return_type}}:
108-
options: BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
89+
options: _baml.BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
10990

11091
__tb__ = options.get("tb", None)
11192
if __tb__ is not None:
@@ -115,7 +96,7 @@ class BamlAsyncClient:
11596
__cr__ = options.get("client_registry", None)
11697
collector = options.get("collector", None)
11798
collectors = collector if isinstance(collector, list) else [collector] if collector is not None else []
118-
env = env_vars_to_dict(options.get("env", {}))
99+
env = _baml.env_vars_to_dict(options.get("env", {}))
119100
raw = await self.__runtime.call_function(
120101
"{{fn.name}}",
121102
{
@@ -129,15 +110,15 @@ class BamlAsyncClient:
129110
collectors,
130111
env,
131112
)
132-
return cast({{fn.return_type}}, raw.cast_to(types, types, partial_types, False))
113+
return cast({{fn.return_type}}, raw.cast_to(_baml.types, _baml.types, _baml.partial_types, False))
133114
{% endfor %}
134115

135116

136117
class BamlStreamClient:
137118
__runtime: baml_py.BamlRuntime
138119
__ctx_manager: baml_py.BamlCtxManager
139-
__baml_options: BamlCallOptions
140-
def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[BamlCallOptions] = None):
120+
__baml_options: _baml.BamlCallOptions
121+
def __init__(self, runtime: baml_py.BamlRuntime, ctx_manager: baml_py.BamlCtxManager, baml_options: Optional[_baml.BamlCallOptions] = None):
141122
self.__runtime = runtime
142123
self.__ctx_manager = ctx_manager
143124
self.__baml_options = baml_options or {}
@@ -148,9 +129,9 @@ class BamlStreamClient:
148129
{% for (name, type, default_value) in fn.args -%}
149130
{{name}}: {{type}}{% if let Some(d) = default_value %} = {{d}}{% endif %},
150131
{%- endfor %}
151-
baml_options: BamlCallOptions = {},
132+
baml_options: _baml.BamlCallOptions = {},
152133
) -> baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}]:
153-
options: BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
134+
options: _baml.BamlCallOptions = {**self.__baml_options, **(baml_options or {})}
154135
__tb__ = options.get("tb", None)
155136
if __tb__ is not None:
156137
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
@@ -159,7 +140,7 @@ class BamlStreamClient:
159140
__cr__ = options.get("client_registry", None)
160141
collector = options.get("collector", None)
161142
collectors = collector if isinstance(collector, list) else [collector] if collector is not None else []
162-
env = env_vars_to_dict(options.get("env", {}))
143+
env = _baml.env_vars_to_dict(options.get("env", {}))
163144
raw = self.__runtime.stream_function(
164145
"{{fn.name}}",
165146
{
@@ -177,8 +158,8 @@ class BamlStreamClient:
177158

178159
return baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}](
179160
raw,
180-
lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, types, partial_types, True)),
181-
lambda x: cast({{fn.return_type}}, x.cast_to(types, types, partial_types, False)),
161+
lambda x: cast({{fn.partial_return_type}}, x.cast_to(_baml.types, _baml.types, _baml.partial_types, True)),
162+
lambda x: cast({{fn.return_type}}, x.cast_to(_baml.types, _baml.types, _baml.partial_types, False)),
182163
self.__ctx_manager.get(),
183164
)
184165
{% endfor %}

engine/language_client_codegen/src/python/templates/async_request.py.j2

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,11 @@
1-
from typing import Any, Dict, List, Optional, Union, TypedDict, Type
2-
from typing_extensions import NotRequired, Literal
3-
import os
1+
from typing import Dict, List, Optional, Union
2+
from typing_extensions import Literal
43

54
import baml_py
65

7-
from . import types
8-
from .types import Checked, Check
9-
from .type_builder import TypeBuilder
6+
from . import _baml
107

118

12-
class BamlCallOptions(TypedDict, total=False):
13-
tb: NotRequired[TypeBuilder]
14-
client_registry: NotRequired[baml_py.baml_py.ClientRegistry]
15-
env: NotRequired[Dict[str, Optional[str]]]
16-
17-
def env_vars_to_dict(overrides: Dict[str, Optional[str]]) -> Dict[str, str]:
18-
base = os.environ.copy()
19-
for k, v in overrides.items():
20-
if v is not None:
21-
base[k] = v
22-
else:
23-
base.pop(k, None)
24-
return base
25-
269
class AsyncHttpRequest:
2710
__runtime: baml_py.BamlRuntime
2811
__ctx_manager: baml_py.BamlCtxManager
@@ -37,15 +20,15 @@ class AsyncHttpRequest:
3720
{% for (name, type, default_value) in fn.args -%}
3821
{{name}}: {{type}}{% if let Some(d) = default_value %} = {{d}}{% endif %},
3922
{%- endfor %}
40-
baml_options: BamlCallOptions = {},
23+
baml_options: _baml.BamlCallOptionsModApi = {},
4124
) -> baml_py.HTTPRequest:
4225
__tb__ = baml_options.get("tb", None)
4326
if __tb__ is not None:
4427
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
4528
else:
4629
tb = None
4730
__cr__ = baml_options.get("client_registry", None)
48-
env = env_vars_to_dict(baml_options.get("env", {}))
31+
env = _baml.env_vars_to_dict(baml_options.get("env", {}))
4932

5033
return await self.__runtime.build_request(
5134
"{{fn.name}}",
@@ -77,15 +60,15 @@ class AsyncHttpStreamRequest:
7760
{% for (name, type, default_value) in fn.args -%}
7861
{{name}}: {{type}}{% if let Some(d) = default_value %} = {{d}}{% endif %},
7962
{%- endfor %}
80-
baml_options: BamlCallOptions = {},
63+
baml_options: _baml.BamlCallOptionsModApi = {},
8164
) -> baml_py.HTTPRequest:
8265
__tb__ = baml_options.get("tb", None)
8366
if __tb__ is not None:
8467
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
8568
else:
8669
tb = None
8770
__cr__ = baml_options.get("client_registry", None)
88-
env = env_vars_to_dict(baml_options.get("env", {}))
71+
env = _baml.env_vars_to_dict(baml_options.get("env", {}))
8972

9073
return await self.__runtime.build_request(
9174
"{{fn.name}}",

0 commit comments

Comments
 (0)