Skip to content

Commit

Permalink
Resolving comments
Browse files Browse the repository at this point in the history
Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
  • Loading branch information
sunishsheth2009 committed Apr 21, 2023
1 parent 29f16fb commit 16272ec
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,20 +291,19 @@ def log_model(
"""
import langchain

if (
type(lc_model) != langchain.chains.llm.LLMChain
and type(lc_model) != langchain.agents.agent.AgentExecutor
if not isinstance(
lc_model, (langchain.chains.llm.LLMChain, langchain.agents.agent.AgentExecutory)
):
raise TypeError(
raise mlflow.MlflowException.invalid_parameter_value(
"MLflow langchain flavor only supports logging langchain.chains.llm.LLMChain and "
+ f"langchain.agents.agent.AgentExecutor instances, found {type(lc_model)}"
)
_SUPPORTED_LLMS = {langchain.llms.openai.OpenAI, langchain.llms.huggingface_hub.HuggingFaceHub}
if (
type(lc_model) == langchain.chains.llm.LLMChain
isinstance(lc_model, langchain.chains.llm.LLMChain)
and type(lc_model.llm) not in _SUPPORTED_LLMS
) or (
type(lc_model) == langchain.agents.agent.AgentExecutor
isinstance(lc_model, langchain.agents.agent.AgentExecutor)
and type(lc_model.agent.llm_chain.llm) not in _SUPPORTED_LLMS
):
logger.warning(
Expand Down Expand Up @@ -343,13 +342,17 @@ def _save_model(model, path):
if model.agent:
agent_data_path = os.path.join(path, _AGENT_DATA_FILE_NAME)
model.save_agent(agent_data_path)
model_data_kwargs = {**model_data_kwargs, _AGENT_DATA_KEY: _AGENT_DATA_FILE_NAME}
model_data_kwargs[_AGENT_DATA_KEY] = _AGENT_DATA_FILE_NAME

if model.tools:
tools_data_path = os.path.join(path, _TOOLS_DATA_FILE_NAME)
with open(tools_data_path, "wb") as f:
cloudpickle.dump(model.tools, f)
model_data_kwargs = {**model_data_kwargs, _TOOLS_DATA_KEY: _TOOLS_DATA_FILE_NAME}
model_data_kwargs[_TOOLS_DATA_KEY] = _TOOLS_DATA_FILE_NAME
else:
raise mlflow.MlflowException.invalid_parameter_value(
"For Initializing the AgentExecutor, tools must be provided."
)

key_to_ignore = ["llm_chain", "agent", "tools", "callback_manager"]
temp_dict = {}
Expand All @@ -361,13 +364,12 @@ def _save_model(model, path):
with open(agent_primitive_path, "w") as config_file:
json.dump(temp_dict, config_file, indent=4)

model_data_kwargs = {
**model_data_kwargs,
_AGENT_PRIMITIVES_DATA_KEY: _AGENT_PRIMITIVES_FILE_NAME,
}
model_data_kwargs[_AGENT_PRIMITIVES_DATA_KEY] = _AGENT_PRIMITIVES_FILE_NAME
else:
logger.error("Could not save model.")
pass
raise mlflow.MlflowException.invalid_parameter_value(
"MLflow langchain flavor only supports logging langchain.chains.llm.LLMChain and "
+ f"langchain.agents.agent.AgentExecutor instances, found {type(lc_model)}"

Check failure on line 371 in mlflow/langchain/__init__.py

View workflow job for this annotation

GitHub Actions / lint

E0602: Undefined variable 'lc_model' (undefined-variable)
)

return model_data_kwargs

Expand All @@ -383,10 +385,20 @@ def _load_model(path, agent_path=None, tools_path=None, agent_primitive_path=Non
from langchain.agents import initialize_agent

llm = load_chain(path)
with open(tools_path, "rb") as f:
tools = cloudpickle.load(f)
with open(agent_primitive_path, "r") as config_file:
args = json.load(config_file)
tools = []
args = {}

if os.path.exists(tools_path):
with open(tools_path, "rb") as f:
tools = cloudpickle.load(f)
else:
raise mlflow.MlflowException.invalid_parameter_value(
"For Initializing the AgentExecutor, tools must be provided."
)

if os.path.exists(agent_primitive_path):
with open(agent_primitive_path, "r") as config_file:
args = json.load(config_file)

model = initialize_agent(tools=tools, llm=llm, agent_path=agent_path, **args)
return model
Expand Down Expand Up @@ -439,15 +451,14 @@ def _load_model_from_local_fs(local_model_path):
lc_model_path = os.path.join(
local_model_path, flavor_conf.get(_MODEL_DATA_KEY, _MODEL_DATA_FILE_NAME)
)
agent_model_path = None

agent_model_path = tools_model_path = agent_primitive_path = None
if flavor_conf.get(_AGENT_DATA_KEY):
agent_model_path = os.path.join(local_model_path, flavor_conf.get(_AGENT_DATA_KEY))

tools_model_path = None
if flavor_conf.get(_TOOLS_DATA_KEY):
tools_model_path = os.path.join(local_model_path, flavor_conf.get(_TOOLS_DATA_KEY))

agent_primitive_path = None
if flavor_conf.get(_AGENT_PRIMITIVES_DATA_KEY):
agent_primitive_path = os.path.join(
local_model_path, flavor_conf.get(_AGENT_PRIMITIVES_DATA_KEY)
Expand Down

0 comments on commit 16272ec

Please sign in to comment.