# Setup Environment
The following code loads the environment variables required to run this notebook.


In [None]:
# Workshop lab name; passed to the remote env loader below.
FILE="GenAI Lab 4"

# NOTE(review): installs the loader from the repo's `main` branch (unpinned); pin a commit for reproducible runs.
! pip install -qqq git+https://github.com/elastic/notebook-workshop-loader.git@main
from notebookworkshoploader import loader
import os
from dotenv import load_dotenv

# Prefer a local env file at ../env (its values override existing variables);
# otherwise fetch this lab's environment from the workshop API.
if os.path.isfile("../env"):
    load_dotenv("../env", override=True)
    print('Successfully loaded environment variables from local env file')
else:
    loader.load_remote_env(file=FILE, env_url="https://notebook-workshop-api-voldmqr2bq-uc.a.run.app")

In [None]:
### TODO: Custom connection strings are still required; not ready for prod.


In [None]:
! pip install -qqq  openai==0.28.1 # tiktoken==0.5.2 cohere==4.38
! pip install -qqq streamlit==1.30.0 elasticsearch==8.12.0 elastic-apm==6.20.0 inquirer==3.2.1 python-dotenv==1.0.1

import os, inquirer, re, secrets, requests
import streamlit as st
import openai

from IPython.display import display
from ipywidgets import widgets
from pprint import pprint
from elasticsearch import Elasticsearch
from string import Template
from requests.auth import HTTPBasicAuth

#if using the Elastic AI proxy, then generate the correct API key
# The key sent to the proxy is "<original key> <user hash>".
if os.environ['ELASTIC_PROXY'] == "True":

  #remove the api type variable: it's a must when using the proxy
  os.environ.pop("OPENAI_API_TYPE", None)

  # Strip the hash appended by a previous run of this cell, so re-running it
  # replaces the hash instead of stacking "<key> <hash1> <hash2>".
  api_key = os.environ['OPENAI_API_KEY']
  previous_hash = os.environ.get('USER_HASH')
  if previous_hash and api_key.endswith(f" {previous_hash}"):
    api_key = api_key[:-(len(previous_hash) + 1)]

  #generate and share "your" unique hash
  os.environ['USER_HASH'] = secrets.token_hex(nbytes=6)
  print(f"Your unique user hash is: {os.environ['USER_HASH']}")

  #combine the original (un-hashed) API key with your hash
  os.environ['OPENAI_API_KEY'] = f"{api_key} {os.environ['USER_HASH']}"


## Create Elasticsearch client connection

In [None]:
if 'ELASTIC_CLOUD_ID' in os.environ:
  es = Elasticsearch(
    cloud_id=os.environ['ELASTIC_CLOUD_ID'],
    api_key=(os.environ['ELASTIC_APIKEY_ID'], os.environ['ELASTIC_APIKEY_SECRET']),
    request_timeout=30
  )
elif 'ELASTIC_URL' in os.environ:
  es = Elasticsearch(
    os.environ['ELASTIC_URL'],
    api_key=(os.environ['ELASTIC_APIKEY_ID'], os.environ['ELASTIC_APIKEY_SECRET']),
    request_timeout=30
  )
else:
  print("env needs to set either ELASTIC_CLOUD_ID or ELASTIC_URL")

# Lab 4-1
RAG Lab

## Step 1: Write the Streamlit single-page app to the file system

In [None]:
%%writefile app.py

import os

import streamlit as st
import openai
from elasticsearch import Elasticsearch
import elasticapm
import base64

######################################
# Streamlit Configuration
# Use the full browser width. NOTE(review): Streamlit requires set_page_config
# to run before other st.* calls, so keep it above the rest of the UI code.
st.set_page_config(layout="wide")


import textwrap
# wrap text when printing, because colab scrolls output to the right too much
def wrap_text(text, width):
    """Hard-wrap ``text`` so that no line is longer than ``width`` characters.

    Thin wrapper around ``textwrap.wrap``; the wrapped lines are joined back
    into a single newline-separated string.
    """
    lines = textwrap.wrap(text, width=width)
    return "\n".join(lines)

@st.cache_data()
def get_base64(bin_file):
    """Return the contents of ``bin_file`` base64-encoded as an ASCII ``str``.

    Wrapped in ``st.cache_data`` so repeated reruns with the same path reuse
    the cached result instead of re-reading the file.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode()


def set_background(png_file):
    """Use ``png_file`` as the full-page background of the Streamlit app.

    The PNG is inlined as a base64 ``data:`` URI inside a CSS rule that targets
    ``.stApp`` and is injected via ``st.markdown`` with HTML enabled.
    """
    encoded = get_base64(png_file)
    css = f'''
    <style>
    .stApp {{
    background-image: url("data:image/png;base64,{encoded}");
    background-size: cover;
    }}
    </style>
    '''
    st.markdown(css, unsafe_allow_html=True)


#set_background('images/smoke_blackyellow1.png')

######################################

######################################
# Sidebar Options
def sidebar_bg(side_bg):
    """Set the image at ``side_bg`` as the background of the Streamlit sidebar.

    The image is inlined into a CSS rule as a base64 ``data:`` URI and injected
    with ``st.markdown`` (HTML enabled).

    :param side_bg: path to the image file.
    """
    import mimetypes

    # Derive the MIME type from the file name instead of always claiming PNG
    # (the sidebar image used by this app is a .jpg). Fall back to PNG if unknown.
    mime_type = mimetypes.guess_type(side_bg)[0] or "image/png"
    # Read through a context manager so the file handle is closed promptly.
    with open(side_bg, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode()
    st.markdown(
        f"""
      <style>
      [data-testid="stSidebar"] > div:first-child {{
          background: url(data:{mime_type};base64,{encoded});
      }}
      </style>
      """,
        unsafe_allow_html=True,
    )


# Sidebar background image.
side_bg = './images/sidebar_chairs3.jpg'
sidebar_bg(side_bg)

# CSS that centres the logo image inside the sidebar.
sidebar_logo_css = """
    <style>
        [data-testid=stSidebar] [data-testid=stImage]{
            text-align: center;
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 100%;
        }
    </style>
    """
st.markdown(sidebar_logo_css, unsafe_allow_html=True)

# Sidebar logo.
st.sidebar.image("images/elastic_logo_transp_100.png")

######################################


# Configure OpenAI client (module-level settings of the pre-1.0 openai package
# pinned by this notebook).
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ['OPENAI_API_BASE']
# Read back by chat_gpt() as the model name; presumably a custom attribute
# rather than a setting the openai package itself uses -- TODO confirm.
openai.default_model = os.environ['OPENAI_API_ENGINE']
# NOTE(review): intended to disable TLS certificate verification. Confirm it is
# actually required (and honored by the pinned openai version) before shipping.
openai.verify_ssl_certs = False


# Initialize Elasticsearch and APM clients
# Configure APM and Elasticsearch clients
@st.cache_resource
def initElastic():
    """Create the APM client and the Elasticsearch client.

    Cached with ``st.cache_resource`` so the same clients are reused across
    Streamlit reruns. Side effects: sets ``ELASTIC_APM_SERVICE_NAME``, enables
    APM auto-instrumentation and, unless the Elastic proxy is in use, sets
    ``openai.api_type`` / ``openai.api_version`` from the environment.

    :return: tuple ``(apmclient, es)``.
    """
    # NOTE(review): service name says lab 2-2 while this notebook is Lab 4-1 -- confirm.
    os.environ['ELASTIC_APM_SERVICE_NAME'] = "genai_workshop_v2_lab_2-2"
    apm = elasticapm.Client()
    elasticapm.instrument()

    # An Elastic Cloud ID takes precedence; otherwise connect to ELASTIC_URL.
    if 'ELASTIC_CLOUD_ID' in os.environ:
        target = {'cloud_id': os.environ['ELASTIC_CLOUD_ID']}
    else:
        target = {'hosts': os.environ['ELASTIC_URL']}
    client = Elasticsearch(
        **target,
        basic_auth=(os.environ['ELASTIC_USER'], os.environ['ELASTIC_PASSWORD']),
        request_timeout=30,
    )

    # The proxy path leaves api_type / api_version unset.
    if os.environ['ELASTIC_PROXY'] != "True":
        openai.api_type = os.environ['OPENAI_API_TYPE']
        openai.api_version = os.environ['OPENAI_API_VERSION']

    return apm, client


# Module-level clients; initElastic is cached, so reruns reuse the same instances.
apmclient, es = initElastic()

# Set our data index (read as a global by search_bm25 and search_knn).
index = os.environ['ELASTIC_INDEX_DOCS']


# Run an Elasticsearch query using BM25 relevance scoring
@elasticapm.capture_span("bm25_search")
def search_bm25(query_text, es, size=1, augment_method="Full Text"):
    """Lexical (BM25) search over the docs index, filtered to "Living people".

    :param query_text: the user's question.
    :param es: Elasticsearch client.
    :param size: number of top-level hits to return.
    :param augment_method: ``"Full Text"`` matches the document ``text`` field;
        ``"Matching Chunk"`` runs a nested match on ``passages.text`` and returns
        the matching passages as ``inner_hits``.
    :return: tuple ``(response, url)``; ``url`` is the placeholder ``'nothing'``.
    :raises ValueError: if ``augment_method`` is not one of the two values above.
    """
    living_people = {"term": {"categories.keyword": "Living people"}}

    if augment_method == "Full Text":
        query = {
            "bool": {
                "must": {"match": {"text": query_text}},
                "filter": living_people,
            }
        }
    elif augment_method == "Matching Chunk":
        # This dict is passed as the `query` argument of es.search, which already
        # places it under "query" in the request body -- so it must be the query
        # clause itself, not wrapped in another {"query": ...}.
        query = {
            "bool": {
                "must": [
                    {
                        "nested": {
                            "path": "passages",
                            "query": {
                                "bool": {
                                    "must": [
                                        {"match": {"passages.text": query_text}}
                                    ]
                                }
                            },
                            "inner_hits": {
                                "_source": False,
                                "fields": ["passages.text"],
                            },
                        }
                    }
                ],
                "filter": living_people,
            }
        }
    else:
        # Fail loudly instead of hitting an unbound `query` below.
        raise ValueError(f"Invalid augment method: {augment_method}")

    fields = ["text", "title"]

    resp = es.search(index=index,
                     query=query,
                     fields=fields,
                     size=size,
                     source=False)
    # No URL field is retrieved by this search; keep the placeholder value.
    url = 'nothing'

    return resp, url


@elasticapm.capture_span("knn_search")
def search_knn(query_text, es, size=1, augment_method="Full Text"):
    """Semantic kNN search over passage embeddings, filtered to "Living people".

    The query vector is requested via ``query_vector_builder`` using the
    ``sentence-transformers__all-distilroberta-v1`` model on ``query_text``.
    Matching passages are always requested as ``inner_hits``; ``augment_method``
    is accepted for signature parity with ``search_bm25`` but is not used.

    :return: tuple ``(response, None)`` -- no URL is available from this search.
    """
    embed_query = {
        "text_embedding": {
            "model_id": "sentence-transformers__all-distilroberta-v1",
            "model_text": query_text,
        }
    }
    knn = {
        "inner_hits": {"_source": False, "fields": ["passages.text"]},
        "field": "passages.embeddings",
        "k": size,
        "num_candidates": 100,
        "query_vector_builder": embed_query,
        "filter": {"term": {"categories.keyword": "Living people"}},
    }

    response = es.search(
        index=index,
        knn=knn,
        fields=["title", "text"],
        size=size,
        source=False,
    )
    return response, None


def truncate_text(text, max_tokens):
    """Cap ``text`` at ``max_tokens`` whitespace-separated words.

    Text that already fits is returned unchanged (original spacing kept);
    longer text is cut to its first ``max_tokens`` words, re-joined with single
    spaces. Words are only an approximation of model tokens.
    """
    words = text.split()
    if len(words) > max_tokens:
        return ' '.join(words[:max_tokens])
    return text


def build_text_obj(resp, aug_method):
    '''
    Group the retrieved content of an Elasticsearch response by document title.

    :param resp: search response; each entry of ``resp['hits']['hits']`` must
        carry ``fields.title``.
    :param aug_method: ``"Matching Chunk"`` collects every inner-hit passage with
        its score; ``"Full Text"`` collects the hit's whole ``fields`` dict; any
        other value yields an empty list per title.
    :return: dict mapping title -> list of collected entries (hits sharing a
        title are merged into the same list).
    '''
    grouped = {}
    for hit in resp['hits']['hits']:
        entries = grouped.setdefault(hit['fields']['title'][0], [])
        if aug_method == "Matching Chunk":
            entries.extend(
                {'passage': inner['fields']['passages'][0]['text'][0],
                 '_score': inner['_score']}
                for inner in hit['inner_hits']['passages']['hits']['hits']
            )
        elif aug_method == "Full Text":
            entries.append(hit['fields'])
    return grouped

def build_text_summary(resp):
    """Return a header line followed by one `` * <title>`` bullet per hit in ``resp``."""
    titles = (hit['fields']['title'][0] for hit in resp['hits']['hits'])
    return "\n".join(["Titles of Documents Retrieved: ", *(f" * {t}" for t in titles)])


def generate_response(query, es, search_method, custom_prompt, negative_response, show_prompt, size=1,
                      augment_method="Full Text"):
    """
    Generates a response from ChatGPT based on the given query and Search Method.

    Retrieves context from Elasticsearch, substitutes it into the prompt
    template, sends the prompt to ChatGPT, and renders the answer (left column)
    next to the retrieval details (right column).

    :param query: the user's question; substituted for ``$query``.
    :param es: Elasticsearch client passed through to the search function.
    :param search_method: ``'bm25'`` (lexical) or ``'knn'`` (semantic).
    :param custom_prompt: prompt template; ``$query``, ``$response`` and
        ``$negResponse`` are replaced before sending.
    :param negative_response: text substituted for ``$negResponse``.
    :param show_prompt: when true, also display the full prompt sent to the LLM.
    :param size: number of documents/chunks to retrieve.
    :param augment_method: ``"Full Text"`` or ``"Matching Chunk"``.
    :raises ValueError: if ``search_method`` is not a known method.
    """

    # Perform the search based on the specified method
    search_functions = {
        'bm25': {'method': search_bm25, 'display': 'Lexical Search'},
        'knn': {'method': search_knn, 'display': 'Semantic Search'}
    }
    # Look the entry up before indexing it: subscripting the result of .get()
    # directly raises TypeError for an unknown key, so the ValueError would
    # never be reached.
    search_entry = search_functions.get(search_method)
    if search_entry is None:
        raise ValueError(f"Invalid search method: {search_method}")
    search_func = search_entry['method']

    # Perform the search and format the docs
    response, url = search_func(query, es, size, augment_method)
    augment_text = build_text_obj(response, augment_method)
    augment_summary = build_text_summary(response)

    res_col1, res_col2 = st.columns(2)
    # Display the search results from ES
    with res_col2:
        st.subheader(':rainbow[Elasticsearch Response]')
        st.write(':gray[Search Method:] :blue[%s]' % search_entry['display'])
        st.write(':gray[Size Setting:] :blue[%s]' % size)
        st.write(':gray[Augment Setting:] :blue[%s]' % augment_method)

        st.write(':green[Augment Chunk(s) from Elasticsearch]')
        st.write(str(augment_summary))
        st.json(dict(augment_text))

        st.write(':violet[Elasticsearch Response]')

        st.json(dict(response))

    # Fill the template. NOTE(review): the str.replace calls run in sequence, so a
    # literal "$response" typed inside the user's query would also be substituted.
    formatted_prompt = custom_prompt.replace("$query", query).replace("$response", str(augment_text)).replace(
        "$negResponse", negative_response)

    # Generate the ChatGPT response
    with res_col1:
        st.subheader(':orange[GenAI Response]')

        chat_response = chat_gpt(formatted_prompt, system_prompt="You are a helpful assistant.")
        st.markdown(chat_response)

    # Display results
    if show_prompt:
        st.text("Full prompt sent to ChatGPT:")
        st.text(wrap_text(formatted_prompt,70))

    # NOTE(review): negative_response is the whole instruction sentence, which the
    # model is unlikely to echo verbatim, so this branch runs for almost every answer.
    if negative_response not in chat_response:
        st.text("Reference URL:")
        st.text(url)


def chat_gpt(user_prompt, system_prompt):
    """
    Send a system + user prompt pair to the chat model and return the reply text.

    The user prompt is truncated to a word budget of 4000 - 1024 - 5 so it fits
    a 4000-token context with room reserved for the reply. The APM custom
    context records the model name and the *untruncated* prompt.
    """
    context_budget = 4000
    reply_reserve = 1024  # NOTE(review): reserved here but not passed to the API as max_tokens
    margin = 5
    prompt_budget = context_budget - reply_reserve - margin

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": truncate_text(user_prompt, prompt_budget)},
    ]

    completion = openai.ChatCompletion.create(model=openai.default_model, temperature=0, messages=messages)

    elasticapm.set_custom_context({'model': openai.default_model, 'prompt': user_prompt})
    return completion["choices"][0]["message"]["content"]


# Main chat form
st.title("Search Powered AI - Famous Living People:")

# Define the default prompt and negative response
# These seed the editable prompt-template fields in the form below.
# "$response" is replaced with the retrieved context in generate_response.
default_prompt_intro = "You are a helpful AI assistant that provides concise answers to questions about famous people using the provided context. "
default_response_instructions = "Context: $response\n"
default_negative_response = "If you cannot answer the question using only the provided context, reply with 'I'm unable to answer the question based on the information I have from wikipedia' and nothing else."


# All inputs live inside an st.form; per Streamlit form semantics, widget values
# are only delivered to the script when the submit button is pressed.
with st.form("chat_form"):

    query = st.text_input(f"Ask a question about a famous living person :",
                          placeholder='Who is Bill Gates and what was his biggest philanthropic act?')

    opt_col1, opt_col2 = st.columns(2)
    with opt_col1:
        with st.expander("Customize Prompt Template"):
            prompt_intro = st.text_area("Introduction/context of the prompt:", value=default_prompt_intro)
            prompt_query_placeholder = st.text_area("Placeholder for the user's query:", value="Answer this question: $query")
            prompt_response_placeholder = st.text_area("Placeholder for the Elasticsearch response:", value=default_response_instructions)
            prompt_negative_response = st.text_area("Negative response placeholder:", value=default_negative_response)
            prompt_closing = st.text_area("Closing remarks of the prompt:",
                                          value="Answer the question in markdown bulletpoints.")

            # The negative-response text is embedded directly (not as "$negResponse"),
            # so generate_response's "$negResponse" substitution finds nothing to replace.
            combined_prompt = f"{prompt_intro}\n{prompt_query_placeholder}\n{prompt_response_placeholder}\n{prompt_negative_response}\n{prompt_closing}"
            st.text_area("Preview of your custom prompt:", value=combined_prompt, disabled=True)

    with opt_col2:
        with st.expander("Retrieval Search and Display Options"):
            st.subheader("Retrieval Options")
            ret_1, ret_2 = st.columns(2)
            with ret_1:
                search_method = st.radio("Search Method", ("Semantic Search", "Lexical Search"))
                augment_method = st.radio("Augment Method", ("Full Text", "Matching Chunk"))
            with ret_2:
                # TODO this should update the title based on the augment_method
                # (inside the form, augment_method only changes after submit, so the
                # label lags one submission behind the radio selection).
                # NOTE(review): changing the label may give Streamlit a new widget
                # identity and reset the slider to 1 -- confirm for the pinned version.
                doc_count_title = "Number of docs or chunks to Augment with" if augment_method == "Full Text" else "Number of Matching Chunks to Retrieve"
                doc_count = st.slider(doc_count_title, min_value=1, max_value=5, value=1)

            st.subheader("Display Options")
            show_full_prompt = st.checkbox('Show Full Prompt Sent to LLM')

    # Only col1 is used; the split keeps the submit button to the left half.
    col1, col2 = st.columns(2)
    with col1:
        answer_button = st.form_submit_button("Find my answer!")

if answer_button:
    # Map the UI label to the key generate_response expects ('knn' or 'bm25').
    search_method = "knn" if search_method == "Semantic Search" else "bm25"

    # Trace each question as one APM transaction, labelled with method and query.
    apmclient.begin_transaction("query")
    elasticapm.label(search_method=search_method)
    elasticapm.label(query=query)

    try:
        # Use combined_prompt and show_full_prompt as arguments
        generate_response(query, es, search_method, combined_prompt, prompt_negative_response, show_full_prompt,
                          doc_count, augment_method)
        apmclient.end_transaction("query", "success")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        # Report the exception to APM as well; otherwise the transaction is only
        # marked "failure" with no record of what went wrong. Must be called while
        # the exception is being handled so the agent can read it.
        apmclient.capture_exception()
        apmclient.end_transaction("query", "failure")


## Step 2: Install localtunnel
This is a Colab-specific requirement that lets us connect to the Streamlit app we're building from within Colab.

In [None]:
# Install localtunnel (run via npx in Step 4 to expose Streamlit's port 8501).
!npm install localtunnel --loglevel=error

## Step 3: Download the images used to style the app

In [None]:
! git clone "https://github.com/elastic/genai-workshop-colab.git"
! cd genai-workshop-colab;  git checkout wave2; cd ..; cp -r ./genai-workshop-colab/notebooks/images images

## Step 4: run Streamlit
Running this cell starts localtunnel and generates a random URL.

Copy the IP address printed on the first line, then open the generated URL and paste the IP address into the "Endpoint IP" input box.

This will then load the Streamlit app.

In [None]:
# Start Streamlit in the background (output -> /content/logs.txt, a Colab path),
# expose port 8501 through localtunnel, and print the IPv4 address reported by
# icanhazip (the value the tunnel page asks for).
# NOTE(review): `&>` is bash syntax -- fine where `!` runs under bash (e.g. Colab); confirm elsewhere.
!streamlit run app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl ipv4.icanhazip.com