diff --git a/src/openlayer/lib/core/base_model.py b/src/openlayer/lib/core/base_model.py index c69fabcb..9bd25a45 100644 --- a/src/openlayer/lib/core/base_model.py +++ b/src/openlayer/lib/core/base_model.py @@ -91,12 +91,10 @@ def batch(self, dataset_path: str, output_dir: str) -> None: raise ValueError(f"Unsupported dataset format: {dataset_path}") # Call the model's run_batch method, passing in the DataFrame - output_df, config = self.run_batch_from_df(df, custom_args=self.custom_args) + output_df, config = self.run_batch_from_df(df) self.write_output_to_directory(output_df, config, output_dir, fmt) - def run_batch_from_df( - self, df: pd.DataFrame, custom_args: dict = None - ) -> Tuple[pd.DataFrame, dict]: + def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: """Function that runs the model and returns the result.""" # Ensure the 'output' column exists if "output" not in df.columns: @@ -105,10 +103,6 @@ def run_batch_from_df( # Get the signature of the 'run' method run_signature = inspect.signature(self.run) - # If the model has a custom_args attribute, update it - if hasattr(self, "custom_args") and custom_args is not None: - self.custom_args.update(custom_args) - for index, row in df.iterrows(): # Filter row_dict to only include keys that are valid parameters # for the 'run' method diff --git a/src/openlayer/lib/integrations/bedrock_tracer.py b/src/openlayer/lib/integrations/bedrock_tracer.py index 336d7cda..a497474e 100644 --- a/src/openlayer/lib/integrations/bedrock_tracer.py +++ b/src/openlayer/lib/integrations/bedrock_tracer.py @@ -25,9 +25,7 @@ logger = logging.getLogger(__name__) -def trace_bedrock( - client: "boto3.client", -) -> "boto3.client": +def trace_bedrock(client: "boto3.client") -> "boto3.client": """Patch the Bedrock client to trace model invocations. The following information is collected for each model invocation: