In [2]:
%%writefile ai.py
# Required imports
from google import genai
from google.genai import types
from autots import AutoTS
import pandas as pd
import streamlit as st
import google.generativeai as genai
import json
import plotly.express as px

st.title("💬 Chatbot")

# Local import keeps this script section self-contained. The key is read from the
# environment so it is never hard-coded in source.
import os

api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not api_key:
    # Both the chat and the data Q&A below call Gemini, so stop until a key is set.
    st.info("Please add your Gemini API key (GEMINI_API_KEY) to continue.")
    st.stop()

genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-1.5-flash")

if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the stored conversation; Streamlit re-executes the script on every interaction.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    try:
        reply = model.generate_content(prompt).text
    except ValueError:
        # The .text accessor may raise ValueError when no usable candidate came back.
        reply = "Sorry, I could not generate a response."
    st.session_state.messages.append({"role": "assistant", "content": reply})
    st.chat_message("assistant").write(reply)

uploaded_file = st.file_uploader("Upload your Excel or CSV file", type=["xlsx", "csv"])

def _parse_plan(text):
    """Extract a ``{"function": str, "args": dict}`` plan from raw LLM text, or None."""
    if not isinstance(text, str):
        return None
    # Models often wrap JSON in ```json fences or add prose; keep the outermost object.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        plan = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    if not isinstance(plan, dict) or not isinstance(plan.get("function"), str):
        return None
    args = plan.get("args")
    if args is None:
        plan["args"] = {}
    elif not isinstance(args, dict):
        return None
    return plan


def interpret_with_prompt(user_input, llm=None):
    """Ask the LLM to map a natural-language question onto one supported helper call.

    Args:
        user_input: The user's question about the uploaded data.
        llm: Object exposing ``generate_content(prompt)`` whose result has a
            ``.text`` attribute. Defaults to the module-level Gemini ``model``.

    Returns:
        A dict with a ``"function"`` name (str) and an ``"args"`` dict, or
        ``None`` when the reply cannot be parsed into that shape.
    """
    llm = model if llm is None else llm
    # json.dumps quotes and escapes the question so embedded quotes cannot break the prompt.
    prompt = f"""
You are an AI that converts user questions into Python function calls.
User question: {json.dumps(user_input)}
Return ONLY a JSON object (no markdown, no explanation) like:
{{
    "function": "forecast_model",
    "args": {{"target_kpi": "sales"}}
}}

Supported functions:
- forecast_model(target_kpi)
- summary_statistics(kpi, stat)   (stat: mean | median | mode)
- generate_graphs(x_col, y_col, graph_type)   (graph_type: line | bar | pie)
"""
    try:
        # The .text accessor may raise ValueError when the response has no usable part.
        text = llm.generate_content(prompt).text
    except ValueError:
        return None
    return _parse_plan(text)
# Data-analysis helpers (forecasting, summary statistics, plotting) used by the question-answering UI below.

def forecast_model(df, date_col='datetime', value_col='value', forecast_length=30):
    """Fit an AutoTS model on a time series in ``df`` and return its forecast.

    Args:
        df: DataFrame containing ``date_col`` and ``value_col``.
        date_col: Name of the timestamp column (default ``'datetime'``).
        value_col: Name of the numeric column to forecast (default ``'value'``).
        forecast_length: Number of future periods to predict (default 30).

    Returns:
        The ``forecast`` attribute of the AutoTS prediction object.

    Raises:
        KeyError: If ``date_col`` or ``value_col`` is not a column of ``df``.
    """
    missing = [col for col in (date_col, value_col) if col not in df.columns]
    if missing:
        raise KeyError(
            f"forecast_model: missing column(s) {missing}; available: {list(df.columns)}"
        )
    # Named ts_model so it does not shadow the module-level Gemini `model`.
    ts_model = AutoTS(
        forecast_length=forecast_length,
        frequency='infer',
        ensemble='simple',  # options: simple, weighted, horizontal, etc.
        model_list='fast',  # 'superfast', 'fast', 'default', or list of models
        max_generations=5,
        num_validations=2,
        validation_method='backwards',
    )
    # No series-identifier column is passed, so df is expected to hold a single series.
    ts_model = ts_model.fit(df, date_col=date_col, value_col=value_col, id_col=None)
    return ts_model.predict().forecast

def summary_statistics(kpi, stat='mean'):
    """Compute one summary statistic of a pandas Series.

    Args:
        kpi: The Series to summarize.
        stat: Statistic name, case-insensitive: 'mean' (default), 'median',
            'mode', 'min', 'max', 'sum', 'std' or 'count'.

    Returns:
        The statistic's value. For 'mode' this is the first mode, or ``None``
        for an empty Series. The string "Invalid statistic" is returned for an
        unsupported or non-string ``stat``.
    """
    if not isinstance(stat, str):
        return "Invalid statistic"
    name = stat.strip().lower()
    if name == 'mode':
        # Compute the mode once; Series.mode() returns an empty Series for empty input.
        modes = kpi.mode()
        return None if modes.empty else modes.iloc[0]
    if name in ('mean', 'median', 'min', 'max', 'sum', 'std', 'count'):
        return getattr(kpi, name)()
    return "Invalid statistic"

def generate_graphs(df, x_col, y_col=None, graph_type='line'):
    """Build a Plotly Express figure from ``df``.

    Args:
        df: DataFrame holding the data to plot.
        x_col: Column for the x axis (for 'pie': the slice names).
        y_col: Column for the y axis (for 'pie': the slice values); optional.
        graph_type: 'line' (default), 'bar', 'pie', 'scatter' or 'histogram',
            case-insensitive. ``None`` falls back to 'line'.

    Returns:
        A Plotly figure, or the string "Invalid graph type" for unsupported types.
    """
    kind = 'line' if graph_type is None else str(graph_type).strip().lower()
    if kind == 'pie':
        # Pie charts take slice labels/values rather than x/y axes.
        return px.pie(df, names=x_col, values=y_col)
    builders = {
        'line': px.line,
        'bar': px.bar,
        'scatter': px.scatter,
        'histogram': px.histogram,
    }
    builder = builders.get(kind)
    if builder is None:
        return "Invalid graph type"
    return builder(df, x=x_col, y=y_col)


col1, col2 = st.columns(2, border=True)
if uploaded_file:
    # st.columns returns the column containers; use them directly (st has no col1/col2).
    with col1:
        if uploaded_file.name.lower().endswith(".xlsx"):
            df = pd.read_excel(uploaded_file)
        else:
            df = pd.read_csv(uploaded_file)
        st.write("Preview of Data:", df.head())

    with col2:
        question = st.text_input("Ask a question about your data:")
        if st.button("Ask"):
            if not question.strip():
                st.error("Please enter a question.")
            else:
                plan = interpret_with_prompt(question)
                if not isinstance(plan, dict):
                    st.error("Could not understand the question.")
                else:
                    func = plan.get("function")
                    args = plan.get("args")
                    if not isinstance(args, dict):
                        args = {}
                    try:
                        if func == "forecast_model":
                            st.write("Forecasting:", forecast_model(df))
                        elif func == "summary_statistics":
                            kpi = args.get("kpi")
                            stat = args.get("stat", "mean")
                            if kpi not in df.columns:
                                st.error(f"Column not found: {kpi}")
                            else:
                                st.write(f"{stat} of {kpi}:", summary_statistics(df[kpi], stat))
                        elif func == "generate_graphs":
                            fig = generate_graphs(
                                df,
                                args.get("x_col"),
                                args.get("y_col"),
                                args.get("graph_type", "line"),
                            )
                            # generate_graphs reports unsupported types as a message string.
                            if isinstance(fig, str):
                                st.error(fig)
                            else:
                                st.plotly_chart(fig)
                        else:
                            st.error("Could not understand the question.")
                    except Exception as err:  # UI boundary: show helper failures instead of crashing
                        st.error(f"Could not complete the request: {err}")


Overwriting ai1.py
