Skip to content

Commit

Permalink
watsonx longchain example
Browse files Browse the repository at this point in the history
  • Loading branch information
huang-cn committed Apr 1, 2024
1 parent 6767a89 commit ec5a98a
Showing 1 changed file with 81 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as WatsonMLGenParams
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
# from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.foundation_models import Model as WatsonAIModel
from ibm_watsonx_ai.foundation_models.extensions.langchain import WatsonxLLM as WatsonxLLM_AI
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as WatsonMLGenParams

from langchain.llms.openai import OpenAI
from langchain.agents import load_tools
Expand All @@ -30,29 +33,29 @@

load_dotenv(find_dotenv())

# Traceloop.init(api_endpoint=os.environ["OTLP_EXPORTER_HTTP"],
# # api_key=os.environ["TRACELOOP_API_KEY"],
# app_name=os.environ["SVC_NAME"],
# )
Traceloop.init(api_endpoint=os.environ["OTLP_EXPORTER_HTTP"],
# api_key=os.environ["TRACELOOP_API_KEY"],
app_name=os.environ["SVC_NAME"],
)

""" only need 2 lines code to instrument Langchain LLM
"""
from otel_lib.instrumentor import LangChainHandlerInstrumentor as SimplifiedLangChainHandlerInstrumentor
from opentelemetry.sdk._logs import LoggingHandler
tracer_provider, metric_provider, logger_provider = SimplifiedLangChainHandlerInstrumentor().instrument(
otlp_endpoint=os.environ["OTLP_EXPORTER"],
# otlp_endpoint=os.environ["OTLP_EXPORTER_GRPC"],
# metric_endpoint=os.environ["OTEL_METRICS_EXPORTER"],
# log_endpoint=os.environ["OTEL_LOG_EXPORTER"],
service_name=os.environ["SVC_NAME"],
insecure = True,
)
# from otel_lib.instrumentor import LangChainHandlerInstrumentor as SimplifiedLangChainHandlerInstrumentor
# from opentelemetry.sdk._logs import LoggingHandler
# tracer_provider, metric_provider, logger_provider = SimplifiedLangChainHandlerInstrumentor().instrument(
# otlp_endpoint=os.environ["OTLP_EXPORTER"],
# # otlp_endpoint=os.environ["OTLP_EXPORTER_GRPC"],
# # metric_endpoint=os.environ["OTEL_METRICS_EXPORTER"],
# # log_endpoint=os.environ["OTEL_LOG_EXPORTER"],
# service_name=os.environ["SVC_NAME"],
# insecure = True,
# )
"""=======================================================
"""
handler = LoggingHandler(level=logging.DEBUG,logger_provider=logger_provider)
# handler = LoggingHandler(level=logging.DEBUG,logger_provider=logger_provider)
# Create different namespaced loggers
logger = logging.getLogger("mylog_test")
logger.addHandler(handler)
# logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

# os.environ["WATSONX_APIKEY"] = os.getenv("IAM_API_KEY")
Expand Down Expand Up @@ -82,40 +85,68 @@
# params=watson_ml_parameters,
# )

api_key = os.getenv("IBM_GENAI_KEY", None)
api_url = "https://bam-api.res.ibm.com"
creds = Credentials(api_key, api_endpoint=api_url)

genai_parameters = GenaiGenerateParams(
decoding_method="sample", # Literal['greedy', 'sample']
max_new_tokens=300,
min_new_tokens=10,
top_p=1,
top_k=50,
temperature=0.05,
time_limit=30000,
# length_penalty={"decay_factor": 2.5, "start_index": 5},
# repetition_penalty=1.2,
truncate_input_tokens=2048,
# random_seed=33,
stop_sequences=["fail", "stop1"],
return_options={
"input_text": True,
"generated_tokens": True,
"input_tokens": True,
"token_logprobs": True,
"token_ranks": False,
"top_n_tokens": False
},
)
os.environ["WATSONX_APIKEY"] = os.getenv("IAM_API_KEY")
apikey=os.getenv("IAM_API_KEY")
project_id=os.getenv("PROJECT_ID")
watson_ml_url="https://us-south.ml.cloud.ibm.com"


# api_key = os.getenv("IBM_GENAI_KEY", None)
# api_url = "https://bam-api.res.ibm.com"
# creds = Credentials(api_key, api_endpoint=api_url)

# genai_parameters = GenaiGenerateParams(
# decoding_method="sample", # Literal['greedy', 'sample']
# max_new_tokens=300,
# min_new_tokens=10,
# top_p=1,
# top_k=50,
# temperature=0.05,
# time_limit=30000,
# # length_penalty={"decay_factor": 2.5, "start_index": 5},
# # repetition_penalty=1.2,
# truncate_input_tokens=2048,
# # random_seed=33,
# stop_sequences=["fail", "stop1"],
# return_options={
# "input_text": True,
# "generated_tokens": True,
# "input_tokens": True,
# "token_logprobs": True,
# "token_ranks": False,
# "top_n_tokens": False
# },
# )

watsonx_genai_llm = LangChainInterface(
# model="google/flan-t5-xxl",
# model="meta-llama/llama-2-70b",
model = "ibm/granite-13b-chat-v1",
params=genai_parameters,
credentials=creds
watson_ml_parameters = {
WatsonMLGenParams.DECODING_METHOD: "sample",
WatsonMLGenParams.MAX_NEW_TOKENS: 30,
WatsonMLGenParams.MIN_NEW_TOKENS: 1,
WatsonMLGenParams.TEMPERATURE: 0.5,
WatsonMLGenParams.TOP_K: 50,
WatsonMLGenParams.TOP_P: 1,
}

model = WatsonAIModel(
model_id="google/flan-ul2",
credentials={
"apikey": apikey,
"url": watson_ml_url
},
params=watson_ml_parameters,
project_id=project_id,
)
watsonx_ai_llm = WatsonxLLM_AI(model=model)
# watsonx_ml_llm = WatsonxLLM_ML(model=model)

# watsonx_genai_llm = LangChainInterface(
# # model="google/flan-t5-xxl",
# # model="meta-llama/llama-2-70b",
# # model = "ibm/granite-13b-chat-v1",
# model="google/flan-ul2",
# params=genai_parameters,
# credentials=creds
# )


# openai_llm = OpenAI(
Expand Down Expand Up @@ -152,7 +183,7 @@ def langchain_watson_genai_llm_chain():
HumanMessage(content=f"tell me what is the most famous dish in {RandomCountryName()}?"),
]
first_prompt_template = ChatPromptTemplate.from_messages(first_prompt_messages)
first_chain = LLMChain(llm=watsonx_genai_llm, prompt=first_prompt_template, output_key="target")
first_chain = LLMChain(llm=watsonx_ai_llm, prompt=first_prompt_template, output_key="target")
logger.info("first chain set", extra={"action": "set llm chain", "chain name": "first chain"})

second_prompt_messages = [
Expand All @@ -161,7 +192,7 @@ def langchain_watson_genai_llm_chain():
HumanMessagePromptTemplate.from_template("pls provide the recipe for dish {target}\n "),
]
second_prompt_template = ChatPromptTemplate.from_messages(second_prompt_messages)
second_chain = LLMChain(llm=watsonx_genai_llm, prompt=second_prompt_template)
second_chain = LLMChain(llm=watsonx_ai_llm, prompt=second_prompt_template)
logger.info("second chain set", extra={"action": "set llm chain", "chain name": "second chain"})

workflow = SequentialChain(chains=[first_chain, second_chain], input_variables=[])
Expand Down

0 comments on commit ec5a98a

Please sign in to comment.