forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
413fba2
commit 6cc3833
Showing
7 changed files
with
291 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""Chain that interprets a prompt and executes python code to do math. | ||
Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
"""Chain that interprets a prompt and executes python code to do math.""" | ||
import datetime | ||
import json | ||
from typing import Dict, List, Any | ||
|
||
from pydantic import BaseModel, Extra, root_validator | ||
|
||
from langchain.chains.base import Chain | ||
from langchain.chains.llm import LLMChain | ||
from langchain.chains.llm_google_calendar.prompt import CREATE_EVENT_PROMPT, CLASSIFICATION_PROMPT | ||
from langchain.llms.base import BaseLLM | ||
from langchain.prompts.base import BasePromptTemplate | ||
from langchain.python import PythonREPL | ||
from langchain.utilities.google_calendar.loader import google_credentials_loader | ||
|
||
|
||
class LLMGoogleCalendarChain(Chain, BaseModel): | ||
"""Chain that interprets a prompt and executes python code to do math. | ||
Example: | ||
.. code-block:: python | ||
from langchain import LLMMathChain, OpenAI | ||
llm_math = LLMMathChain(llm=OpenAI()) | ||
""" | ||
|
||
llm: BaseLLM | ||
"""LLM wrapper to use.""" | ||
create_event_prompt: BasePromptTemplate = CREATE_EVENT_PROMPT | ||
"""Prompt to use for creating event.""" | ||
classification_prompt: BasePromptTemplate = CLASSIFICATION_PROMPT | ||
"""Prompt to use for classification.""" | ||
|
||
query:str | ||
query_input_key: str = "query" #: :meta private: | ||
date_input_key: str = "date" #: :meta private: | ||
u_timezone_input_key: str = "u_timezone" #: :meta private: | ||
|
||
service: Any #: :meta private: | ||
google_http_error: Any #: :meta private: | ||
creds: Any #: :meta private: | ||
|
||
|
||
output_key: str = "answer" #: :meta private: | ||
|
||
class Config: | ||
"""Configuration for this pydantic object.""" | ||
|
||
extra = Extra.forbid | ||
arbitrary_types_allowed = True | ||
|
||
@classmethod | ||
def from_default(cls, query: str) -> LLMGoogleCalendarChain: | ||
"""Load with default LLM.""" | ||
return cls(llm=cls.llm, query=query) | ||
|
||
|
||
# @property | ||
# def input_keys(self) -> List[str]: | ||
# """Expect input key. | ||
|
||
# :meta private: | ||
# """ | ||
# return [self.query_input_key, self.date_input_key, self.u_timezone_input_key] | ||
|
||
# @property | ||
# def output_keys(self) -> List[str]: | ||
# """Expect output key. | ||
|
||
# :meta private: | ||
# """ | ||
# return [self.output_key] | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key and python package exists in environment.""" | ||
# | ||
# Auth done through OAuth2.0 | ||
|
||
try: | ||
from langchain.utilities.google_calendar.loader import google_credentials_loader | ||
# save the values from loader to values | ||
values.update(google_credentials_loader()) | ||
|
||
except ImportError: | ||
raise ValueError( | ||
"Could not import google python packages. " | ||
"""Please it install it with `pip install --upgrade | ||
google-api-python-client google-auth-httplib2 google-auth-oauthlib`.""" | ||
) | ||
return values | ||
|
||
|
||
|
||
def run_classification(self, query: str) -> str: | ||
"""Run classification on query.""" | ||
from langchain import LLMChain, OpenAI, PromptTemplate | ||
|
||
prompt = PromptTemplate( | ||
template=CLASSIFICATION_PROMPT, input_variables=["query"] | ||
) | ||
llm_chain = LLMChain( | ||
llm=OpenAI(temperature=0, model="text-davinci-003"), | ||
prompt=prompt, | ||
verbose=True, | ||
) | ||
return llm_chain.run(query=query).strip().lower() | ||
|
||
def run_create_event(self, query: str) -> str: | ||
create_event_chain = LLMChain( | ||
llm=self.llm, | ||
prompt=self.create_event_prompt, | ||
verbose=True, | ||
) | ||
date = datetime.datetime.utcnow().isoformat() + "Z" | ||
u_timezone = str( | ||
datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo | ||
) | ||
|
||
date = datetime.datetime.utcnow().isoformat() + "Z" | ||
u_timezone = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo | ||
output = create_event_chain.run( | ||
query=query, date=date, u_timezone=u_timezone | ||
).strip() | ||
|
||
loaded = json.loads(output) | ||
( | ||
event_summary, | ||
event_start_time, | ||
event_end_time, | ||
event_location, | ||
event_description, | ||
user_timezone, | ||
) = loaded.values() | ||
|
||
event = self.create_event( | ||
event_summary=event_summary, | ||
event_start_time=event_start_time, | ||
event_end_time=event_end_time, | ||
user_timezone=user_timezone, | ||
event_location=event_location, | ||
event_description=event_description, | ||
) | ||
return "Event created successfully, details: event " + event.get("htmlLink") | ||
|
||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: | ||
output = "" | ||
classification = self.run_classification(self.query) | ||
if classification == "create_event": | ||
output = self.run_create_event(query=self.query) | ||
|
||
|
||
return {self.output_key: output} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# flake8: noqa | ||
from langchain.prompts.prompt import PromptTemplate | ||
|
||
_CREATE_EVENT_PROMPT = """ | ||
Date format: YYYY-MM-DDThh:mm:ss+00:00 | ||
Based on this event description:\n'Joey birthday tomorrow at 7 pm', | ||
output a json of the following parameters: \n | ||
Today's datetime on UTC time 2021-05-02T10:00:00+00:00 and timezone | ||
of the user is -5, take into account the timezone of the user and today's date. | ||
1. event_summary \n | ||
2. event_start_time \n | ||
3. event_end_time \n | ||
4. event_location \n | ||
5. event_description \n | ||
6. user_timezone \n | ||
event_summary:\n | ||
{{ | ||
"event_summary": "Joey birthday", | ||
"event_start_time": "2021-05-03T19:00:00-05:00", | ||
"event_end_time": "2021-05-03T20:00:00-05:00", | ||
"event_location": "", | ||
"event_description": "", | ||
"user_timezone": "America/New_York" | ||
}} | ||
Date format: YYYY-MM-DDThh:mm:ss+00:00 | ||
Based on this event description:\n{query}, output a json of the | ||
following parameters: \n | ||
Today's datetime on UTC time {date} and timezone of the user {u_timezone}, | ||
take into account the timezone of the user and today's date. | ||
1. event_summary \n | ||
2. event_start_time \n | ||
3. event_end_time \n | ||
4. event_location \n | ||
5. event_description \n | ||
6. user_timezone \n | ||
event_summary: \n | ||
""" | ||
|
||
CREATE_EVENT_PROMPT = PromptTemplate(input_variables=["query","date","u_timezone"], template=_CREATE_EVENT_PROMPT) | ||
|
||
|
||
_CLASSIFICATION_PROMPT = """ | ||
Reschedule our meeting for 5 pm today. \n | ||
The following is an action to be taken in a calendar. | ||
Classify it as one of the following: \n\n | ||
1. create_event \n | ||
2. view_event \n | ||
3. view_events \n | ||
4. delete_event \n | ||
5. reschedule_event \n | ||
Classification: Reschedule an event | ||
{query} | ||
The following is an action to be taken in a calendar. | ||
Classify it as one of the following: \n\n | ||
1. create_event \n | ||
2. view_event \n | ||
3. view_events \n | ||
4. delete_event \n | ||
5. reschedule_event \n | ||
Classification: | ||
""" | ||
CLASSIFICATION_PROMPT = PromptTemplate(input_variables=["query"], template=_CLASSIFICATION_PROMPT) |
Oops, something went wrong.