From 786bafdadac590d1154bed945e68f44c91cc15f2 Mon Sep 17 00:00:00 2001 From: Rishab Ramanathan Date: Thu, 11 Sep 2025 12:25:52 -0700 Subject: [PATCH] fix: custom args with run_batch_from_df --- src/openlayer/lib/core/base_model.py | 10 ++-------- src/openlayer/lib/integrations/bedrock_tracer.py | 4 +--- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/openlayer/lib/core/base_model.py b/src/openlayer/lib/core/base_model.py index c69fabcb..9bd25a45 100644 --- a/src/openlayer/lib/core/base_model.py +++ b/src/openlayer/lib/core/base_model.py @@ -91,12 +91,10 @@ def batch(self, dataset_path: str, output_dir: str) -> None: raise ValueError(f"Unsupported dataset format: {dataset_path}") # Call the model's run_batch method, passing in the DataFrame - output_df, config = self.run_batch_from_df(df, custom_args=self.custom_args) + output_df, config = self.run_batch_from_df(df) self.write_output_to_directory(output_df, config, output_dir, fmt) - def run_batch_from_df( - self, df: pd.DataFrame, custom_args: dict = None - ) -> Tuple[pd.DataFrame, dict]: + def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: """Function that runs the model and returns the result.""" # Ensure the 'output' column exists if "output" not in df.columns: @@ -105,10 +103,6 @@ def run_batch_from_df( # Get the signature of the 'run' method run_signature = inspect.signature(self.run) - # If the model has a custom_args attribute, update it - if hasattr(self, "custom_args") and custom_args is not None: - self.custom_args.update(custom_args) - for index, row in df.iterrows(): # Filter row_dict to only include keys that are valid parameters # for the 'run' method diff --git a/src/openlayer/lib/integrations/bedrock_tracer.py b/src/openlayer/lib/integrations/bedrock_tracer.py index 336d7cda..a497474e 100644 --- a/src/openlayer/lib/integrations/bedrock_tracer.py +++ b/src/openlayer/lib/integrations/bedrock_tracer.py @@ -25,9 +25,7 @@ logger = logging.getLogger(__name__) -def trace_bedrock( - client: "boto3.client", -) -> "boto3.client": +def trace_bedrock(client: "boto3.client") -> "boto3.client": """Patch the Bedrock client to trace model invocations. The following information is collected for each model invocation: