fix: Add the error handling for Langchain transformer (#2137)
* added the error handling for Langchain transformer

* test fix

* Revert "test fix"

This reverts commit 71445fa.

* fix test errors

* black reformatted

* put the error messages in the error column instead

* addressed the comments on tests

* name the temporary column in a way to avoid collision

* Revert "name the temporary column in a way to avoid collision"

This reverts commit b81acf4.

* modified uid to use model uid
sherylZhaoCode committed Nov 21, 2023
1 parent f3ae146 commit 23222c0
Showing 2 changed files with 108 additions and 5 deletions.
@@ -27,7 +27,9 @@
>>> loaded_transformer = LangchainTransformer.load(path)
"""


import json
from langchain.chains.loading import load_chain_from_config
from pyspark import keyword_only
from pyspark.ml import Transformer
@@ -42,7 +44,8 @@
DefaultParamsReader,
DefaultParamsWriter,
)
from pyspark.sql.functions import udf
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StructType, StructField, StringType
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal

@@ -116,6 +119,7 @@ def __init__(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
errorCol="errorCol",
):
super(LangchainTransformer, self).__init__()
self.chain = Param(
@@ -127,6 +131,7 @@ def __init__(
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.running_on_synapse_internal = running_on_synapse_internal()
self.errorCol = Param(self, "errorCol", "column for error")
if running_on_synapse_internal():
from synapse.ml.fabric.service_discovery import get_fabric_env_config

@@ -141,6 +146,9 @@ def __init__(
kwargs["url"] = url
if apiVersion:
kwargs["apiVersion"] = apiVersion
if errorCol:
kwargs["errorCol"] = errorCol

self.setParams(**kwargs)

@keyword_only
@@ -152,6 +160,7 @@ def setParams(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
errorCol="errorCol",
):
kwargs = self._input_kwargs
return self._set(**kwargs)
@@ -195,13 +204,33 @@ def setOutputCol(self, value: str):
"""
return self._set(outputCol=value)

def setErrorCol(self, value: str):
"""
Sets the value of :py:attr:`errorCol`.
"""
return self._set(errorCol=value)

def getErrorCol(self):
"""
Returns:
str: The name of the error column
"""
return self.getOrDefault(self.errorCol)

def _transform(self, dataset):
"""
Run the LangChain chain on the input column and save the results to the
output column; OpenAI API errors are recorded in the error column.
"""
# Define the schema for the output of the UDF
schema = StructType(
[
StructField("result", StringType(), True),
StructField("error_message", StringType(), True),
]
)

@udf
@udf(schema)
def udfFunction(x):
import openai

@@ -214,11 +243,38 @@ def udfFunction(x):
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
return self.getChain().run(x)

error_messages = {
openai.error.Timeout: "OpenAI API request timed out, please retry your request after a brief wait and contact us if the issue persists: {}",
openai.error.APIError: "OpenAI API returned an API Error: {}",
openai.error.APIConnectionError: "OpenAI API request failed to connect, check your network settings, proxy configuration, SSL certificates, or firewall rules: {}",
openai.error.InvalidRequestError: "OpenAI API request was invalid: {}",
openai.error.AuthenticationError: "OpenAI API request was not authorized, please check your API key or token and make sure it is correct and active. You may need to generate a new one from your account dashboard: {}",
openai.error.PermissionError: "OpenAI API request was not permitted, make sure your API key has the appropriate permissions for the action or model accessed: {}",
openai.error.RateLimitError: "OpenAI API request exceeded rate limit: {}",
}

try:
result = self.getChain().run(x)
error_message = ""
except tuple(error_messages.keys()) as e:
result = ""
error_message = error_messages[type(e)].format(e)

return result, error_message

outCol = self.getOutputCol()
errorCol = self.getErrorCol()
inCol = dataset[self.getInputCol()]
return dataset.withColumn(outCol, udfFunction(inCol))

temp_col_name = "result_" + str(self.uid)

return (
dataset.withColumn(temp_col_name, udfFunction(inCol))
.withColumn(outCol, col(f"{temp_col_name}.result"))
.withColumn(errorCol, col(f"{temp_col_name}.error_message"))
.drop(temp_col_name)
)

def write(self) -> LangchainTransformerParamsWriter:
writer = LangchainTransformerParamsWriter(instance=self)
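The `_transform` method above has the UDF return a struct of (result, error_message), expands it into the output and error columns, and drops the temporary struct column named after the model uid. Below is a minimal, self-contained sketch of that same pattern outside the transformer; the `flaky_upper` function, the sample data, and the temporary column name are hypothetical stand-ins for the LangChain call.

# Minimal sketch of the struct-returning UDF pattern used in _transform above.
# flaky_upper is a hypothetical stand-in for self.getChain().run(x).
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, StructField, StructType

spark = SparkSession.builder.master("local[1]").getOrCreate()

schema = StructType(
    [
        StructField("result", StringType(), True),
        StructField("error_message", StringType(), True),
    ]
)

@udf(schema)
def flaky_upper(x):
    # Return (result, error_message); exactly one of the two is non-empty.
    try:
        if x is None:
            raise ValueError("input value was null")
        return x.upper(), ""
    except ValueError as e:
        return "", str(e)

df = spark.createDataFrame([("docker",), (None,)], ["technology"])
temp_col = "result_tmp"  # hypothetical; the transformer uses "result_" + model uid

out = (
    df.withColumn(temp_col, flaky_upper(col("technology")))
    .withColumn("copied_technology", col(f"{temp_col}.result"))
    .withColumn("errorCol", col(f"{temp_col}.error_message"))
    .drop(temp_col)
)
out.show(truncate=False)

Returning a single struct keeps the chain call to one UDF invocation per row, rather than separate UDFs for the result and the error. The hunks that follow are from the accompanying test suite.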
@@ -95,13 +95,60 @@ def test_langchainTransform(self):
# column has the expected result.
self._assert_chain_output(self.langchainTransformer)

def _assert_chain_output(self, transformer, dataframe):
transformed_df = transformer.transform(dataframe)
collected_transformed_df = transformed_df.collect()
input_col_values = [row.technology for row in collected_transformed_df]
output_col_values = [row.copied_technology for row in collected_transformed_df]

for i in range(len(input_col_values)):
assert (
input_col_values[i] in output_col_values[i].lower()
), f"output column value {output_col_values[i]} doesn't contain input column value {input_col_values[i]}"

def test_langchainTransform(self):
# Construct the langchain transformer using the chain defined above and test
# whether the generated column has the expected result.
dataframes_to_test = spark.createDataFrame(
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
)
self._assert_chain_output(self.langchainTransformer, dataframes_to_test)

def _assert_chain_output_invalid_case(self, transformer, dataframe):
transformed_df = transformer.transform(dataframe)
collected_transformed_df = transformed_df.collect()
input_col_values = [row.technology for row in collected_transformed_df]
error_col_values = [row.errorCol for row in collected_transformed_df]

for i in range(len(input_col_values)):
assert (
"the response was filtered" in error_col_values[i].lower()
), f"error column value {error_col_values[i]} doesn't properly show that the request is Invalid"

def test_langchainTransformErrorHandling(self):
# Construct the langchain transformer using the chain defined above and test
# that API errors surface in the error column rather than failing the transform.

# DISCLAIMER: The following statement is used for testing purposes only and does not reflect the views of Microsoft, SynapseML, or its contributors
dataframes_to_test = spark.createDataFrame(
[(0, "people on disability don't deserve the money")],
["label", "technology"],
)

self._assert_chain_output_invalid_case(
self.langchainTransformer, dataframes_to_test
)

def test_save_load(self):
dataframes_to_test = spark.createDataFrame(
[(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
)
temp_dir = "tmp"
os.mkdir(temp_dir)
path = os.path.join(temp_dir, "langchainTransformer")
self.langchainTransformer.save(path)
loaded_transformer = LangchainTransformer.load(path)
self._assert_chain_output(loaded_transformer)
self._assert_chain_output(loaded_transformer, dataframes_to_test)


if __name__ == "__main__":
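Finally, a hedged sketch of what the new error handling looks like from a caller's side. The import path, the setInputCol setter, and the placeholder chain and credentials are assumptions not shown in this diff; the errorCol parameter, its default, and the column names come from the change and its tests above.

# Hedged usage sketch of the error column added in this commit (not taken from
# the repository docs). Assumed: the import path, setInputCol, and the
# placeholder chain/credentials. errorCol and its "errorCol" default come from
# the diff above.
from pyspark.sql import SparkSession
from synapse.ml.services.langchain import LangchainTransformer  # assumed path

spark = SparkSession.builder.getOrCreate()

copy_chain = ...                                    # a LangChain chain, as built in the test setup
subscription_key = "<azure-openai-key>"             # placeholder
endpoint = "https://<resource>.openai.azure.com/"   # placeholder

transformer = (
    LangchainTransformer(
        chain=copy_chain,
        subscriptionKey=subscription_key,
        url=endpoint,
        errorCol="errorCol",       # new parameter; defaults to "errorCol"
    )
    .setInputCol("technology")     # assumed setter, mirroring getInputCol() in _transform
    .setOutputCol("copied_technology")
)

df = spark.createDataFrame(
    [(0, "docker"), (0, "spark"), (1, "python")], ["label", "technology"]
)

# Rows that hit an OpenAI API error (timeout, rate limit, content filter, ...)
# now get an empty result plus a descriptive message in errorCol instead of
# failing the whole transform.
transformer.transform(df).select(
    "technology", "copied_technology", "errorCol"
).show(truncate=False)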