Skip to content

Commit

Permalink
Update dyn inf
Browse files Browse the repository at this point in the history
  • Loading branch information
vowelparrot committed Apr 29, 2023
1 parent ae7dd39 commit 06f32a9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
18 changes: 14 additions & 4 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,25 @@ def get_filtered_args(
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
return {k: schema[k] for k in valid_keys if k != "run_manager"}


class _SchemaConfig:
"""Configuration for the pydantic model."""

extra = Extra.forbid
arbitrary_types_allowed = True


def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore
inferred_model = validated.model # type: ignore
if "run_manager" in inferred_model.__fields__:
del inferred_model.__fields__["run_manager"]
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
Expand Down Expand Up @@ -143,8 +153,8 @@ def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
return get_filtered_args(inferred_model, self._run)
schema = create_schema_from_function(self.name, self._run)
return schema.schema()["properties"]

def _parse_input(
self,
Expand Down
6 changes: 5 additions & 1 deletion langchain/tools/google_places/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from pydantic import Field
from pydantic import BaseModel, Field

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
Expand All @@ -12,6 +12,10 @@
from langchain.utilities.google_places_api import GooglePlacesAPIWrapper


class GooglePlacesSchema(BaseModel):
query: str = Field(..., description="Query for goole maps")


class GooglePlacesTool(BaseTool):
"""Tool that adds the capability to query the Google places API."""

Expand Down
13 changes: 10 additions & 3 deletions langchain/tools/plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Optional
from typing import Optional, Type

import requests
import yaml
Expand Down Expand Up @@ -49,9 +49,16 @@ def marshal_spec(txt: str) -> dict:
return yaml.safe_load(txt)


class AIPLuginToolSchema(BaseModel):
"""AIPLuginToolSchema."""

tool_input: Optional[str] = ""


class AIPluginTool(BaseTool):
plugin: AIPlugin
api_spec: str
args_schema: Type[AIPLuginToolSchema] = AIPLuginToolSchema

@classmethod
def from_plugin_url(cls, url: str) -> AIPluginTool:
Expand All @@ -78,15 +85,15 @@ def from_plugin_url(cls, url: str) -> AIPluginTool:

def _run(
self,
tool_input: str,
tool_input: Optional[str] = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
return self.api_spec

async def _arun(
self,
tool_input: str,
tool_input: Optional[str] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
Expand Down

0 comments on commit 06f32a9

Please sign in to comment.