Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: fix
packages:
- "@typespec/http-client-python"
---

Fix import when body parameter is union of models
6 changes: 5 additions & 1 deletion packages/http-client-python/eng/scripts/ci/regenerate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const argv = parseArgs({
});

// Add this near the top with other constants
const SKIP_SPECS = ["type/union/discriminated"];
const SKIP_SPECS = [];

// Get the directory of the current file
const PLUGIN_DIR = argv.values.pluginDir
Expand Down Expand Up @@ -272,6 +272,10 @@ const EMITTER_OPTIONS: Record<string, Record<string, string> | Record<string, st
"package-name": "typetest-union",
namespace: "typetest.union",
},
"type/union/discriminated": {
"package-name": "typetest-union-discriminated",
namespace: "typetest.union.discriminated",
},
};

function toPosix(dir: string): string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
file_import.merge(self.get_request_builder_import(self.request_builder, async_mode, serialize_namespace))
if self.overloads:
file_import.add_submodule_import("typing", "overload", ImportType.STDLIB)
for overload in self.overloads:
if overload.parameters.has_body:
file_import.merge(overload.parameters.body_parameter.type.imports(**kwargs))
if self.code_model.options["models-mode"] == "dpg":
relative_path = self.code_model.get_relative_import_path(
serialize_namespace, module_name="_utils.model_base"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,14 @@ def _initialize_overloads(self, builder: OperationType, is_paging: bool = False)
client_names = [
overload.request_builder.parameters.body_parameter.client_name for overload in builder.overloads
]
for v in sorted(set(client_names), key=client_names.index):
retval.append(f"_{v} = None")
all_dpg_model_overloads = False
if self.code_model.options["models-mode"] == "dpg" and builder.overloads:
all_dpg_model_overloads = all(
isinstance(o.parameters.body_parameter.type, DPGModelType) for o in builder.overloads
)
if not all_dpg_model_overloads:
for v in sorted(set(client_names), key=client_names.index):
retval.append(f"_{v} = None")
try:
# if there is a binary overload, we do a binary check first.
binary_overload = cast(
Expand All @@ -803,17 +809,20 @@ def _initialize_overloads(self, builder: OperationType, is_paging: bool = False)
f'"{other_overload.parameters.body_parameter.default_content_type}"{check_body_suffix}'
)
except StopIteration:
for idx, overload in enumerate(builder.overloads):
if_statement = "if" if idx == 0 else "elif"
body_param = overload.parameters.body_parameter
retval.append(
f"{if_statement} {body_param.type.instance_check_template.format(body_param.client_name)}:"
)
if body_param.default_content_type and not same_content_type:
if all_dpg_model_overloads:
retval.extend(f"{l}" for l in self._create_body_parameter(cast(OperationType, builder.overloads[0])))
else:
for idx, overload in enumerate(builder.overloads):
if_statement = "if" if idx == 0 else "elif"
body_param = overload.parameters.body_parameter
retval.append(
f' content_type = content_type or "{body_param.default_content_type}"{check_body_suffix}'
f"{if_statement} {body_param.type.instance_check_template.format(body_param.client_name)}:"
)
retval.extend(f" {l}" for l in self._create_body_parameter(cast(OperationType, overload)))
if body_param.default_content_type and not same_content_type:
retval.append(
f' content_type = content_type or "{body_param.default_content_type}"{check_body_suffix}'
)
retval.extend(f" {l}" for l in self._create_body_parameter(cast(OperationType, overload)))
return retval

def _create_request_builder_call(
Expand Down
Loading