diff --git a/labs/TrapheusAI/llm/model.py b/labs/TrapheusAI/llm/model.py index fd41438..a945472 100644 --- a/labs/TrapheusAI/llm/model.py +++ b/labs/TrapheusAI/llm/model.py @@ -1,15 +1,17 @@ -from dataclasses import asdict -from typing import Tuple, List -import openai - import matplotlib +import openai import pandas -from pandasai import SmartDataframe -from pandasai.llm.openai import OpenAI import streamlit as streamlit - +from dataclasses import asdict +from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.llms import Ollama from llm.prompts import Prompt +from pandasai import SmartDataframe +from pandasai.llm.openai import OpenAI +from typing import Tuple, List +from pandasai.llm import LangchainLLM # Since most backends are non GUI setting this to Agg to pick this up later from a file # TODO later let the user choose the ibackend he is operating on from a command line param. @@ -26,11 +28,16 @@ def ask_foundational_model(discourses: List[Prompt]) -> Tuple[str, List[Prompt]] answer = Prompt(**result["choices"][0]["message"]) return answer.content, discourses + [answer] -def ask_foundational_data_model(dataframe: pandas.core.frame.DataFrame, query: str): +def ask_foundational_data_model(dataframe: pandas.core.frame.DataFrame, query: str, llm_type): # local llm is still having issues, i have reported this at # https://github.com/gventuri/pandas-ai/issues/340#issuecomment-1637184573 # seeing if chart introduction can help https://github.com/gventuri/pandas-ai/pull/497/files#r1341966270 - llm = OpenAI(api_token=streamlit.secrets["OPENAI_API_KEY"]) + if llm_type == 'LocalLLM': + olama = Ollama(model="llama2") + # Wrapping up ollama in langchain LLM till it's supported https://github.com/gventuri/pandas-ai/pull/611 + llm = LangchainLLM(olama) + else: + llm = OpenAI(api_token=streamlit.secrets["OPENAI_API_KEY"]) smart_df = SmartDataframe(dataframe, config={"llm": llm}) 
response = smart_df.chat(query) return response diff --git a/labs/TrapheusAI/trapheusai_app.py b/labs/TrapheusAI/trapheusai_app.py index d7c6bf7..a74cdc0 100644 --- a/labs/TrapheusAI/trapheusai_app.py +++ b/labs/TrapheusAI/trapheusai_app.py @@ -20,13 +20,15 @@ def initialize_page(): streamlit.sidebar.image("logo/TrapheusAILogo.png", use_column_width=True) streamlit.sidebar.subheader("Select the type of analysis") analysis_type = streamlit.sidebar.selectbox("", options=["Concept Search", "Dataset Search"]) + streamlit.sidebar.subheader("Select the type of LLM for analysis") + llm_type = streamlit.sidebar.selectbox("", options=["LocalLLM", "OpenAI"]) if analysis_type == "Dataset Search": - handle_dataset_search() + handle_dataset_search(llm_type) else: handle_concept_search() # TODO Later move analysis types to a common abstract factory -def handle_dataset_search(): +def handle_dataset_search(llm_type): query = streamlit.sidebar.text_area( "Search for data", value=streamlit.session_state.get("dataset-input", ""), @@ -54,13 +56,9 @@ def handle_dataset_search(): mime = "text/csv", on_click = showDownloadMessage) question = streamlit.chat_input("Ask any question related to the dataset") if question: - - answer = ask_foundational_data_model(df, question) - print(question) + answer = ask_foundational_data_model(df, question, llm_type) if ("plot" or "Plot" or "Chart" or "chart" or "Graph" or "graph") in question: - print('inside') plot_folder = glob("exports/charts/temp_chart.png") - print (plot_folder) plot = Image.open(plot_folder[0]) streamlit.image(plot, use_column_width=False) else: