Skip to content

Commit

Permalink
langchain: cherry-pick moderation fix into v0.1 (#21544)
Browse files Browse the repository at this point in the history
```
git checkout v0.1
git pull
git checkout -b cc/cherry_pick_into_v01
git cherry-pick d3ca2cc
```

Co-authored-by: Matt Florence <matt@mattflo.com>
Co-authored-by: Emilia Katari <emilia@outpace.com>
Co-authored-by: Erick Friis <erickfriis@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
5 people committed May 10, 2024
1 parent 73a5b3d commit 7440ce0
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions libs/langchain/langchain/chains/moderation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Pass input through a moderation endpoint."""

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import check_package_version, get_from_dict_or_env

from langchain.chains.base import Chain

Expand All @@ -25,6 +29,7 @@ class OpenAIModerationChain(Chain):
"""

client: Any #: :meta private:
async_client: Any #: :meta private:
model_name: Optional[str] = None
"""Moderation model name to use."""
error: bool = False
Expand All @@ -33,6 +38,7 @@ class OpenAIModerationChain(Chain):
output_key: str = "output" #: :meta private:
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
_openai_pre_1_0: bool = Field(default=None)

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand All @@ -52,7 +58,16 @@ def validate_environment(cls, values: Dict) -> Dict:
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
values["client"] = openai.Moderation # type: ignore
values["_openai_pre_1_0"] = False
try:
check_package_version("openai", gte_version="1.0")
except ValueError:
values["_openai_pre_1_0"] = True
if values["_openai_pre_1_0"]:
values["client"] = openai.Moderation
else:
values["client"] = openai.OpenAI()
values["async_client"] = openai.AsyncOpenAI()
except ImportError:
raise ImportError(
"Could not import openai python package. "
Expand All @@ -76,8 +91,12 @@ def output_keys(self) -> List[str]:
"""
return [self.output_key]

def _moderate(self, text: str, results: dict) -> str:
if results["flagged"]:
def _moderate(self, text: str, results: Any) -> str:
if self._openai_pre_1_0:
condition = results["flagged"]
else:
condition = results.flagged
if condition:
error_str = "Text was found that violates OpenAI's content policy."
if self.error:
raise ValueError(error_str)
Expand All @@ -87,10 +106,26 @@ def _moderate(self, text: str, results: dict) -> str:

def _call(
self,
inputs: Dict[str, str],
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> Dict[str, Any]:
text = inputs[self.input_key]
if self._openai_pre_1_0:
results = self.client.create(text)
output = self._moderate(text, results["results"][0])
else:
results = self.client.moderations.create(input=text)
output = self._moderate(text, results.results[0])
return {self.output_key: output}

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self._openai_pre_1_0:
return await super()._acall(inputs, run_manager=run_manager)
text = inputs[self.input_key]
results = self.client.create(text)
output = self._moderate(text, results["results"][0])
results = await self.async_client.moderations.create(input=text)
output = self._moderate(text, results.results[0])
return {self.output_key: output}

0 comments on commit 7440ce0

Please sign in to comment.