In [None]:
# Import python packages (standard library first, then third-party / Snowflake).
import json

import pandas as pd
import streamlit as st
from snowflake.core import Root
from snowflake.snowpark.context import get_active_session

# Grab the Snowpark session that is already active and wrap it in a Root
# object; `root` is what the next cell uses to look up the Cortex Search
# service by database / schema / name.
session = get_active_session()
root = Root(session)


In [None]:
# Walk database -> schema -> Cortex Search service, one step at a time, to get
# the service handle that the search cells below call `.search(...)` on.
pmc_database = root.databases["PMC_DATA"]
pmc_schema = pmc_database.schemas["PMC_OA_OPENDATA"]
pmc_search_service = pmc_schema.cortex_search_services["my_pmc_search_service"]

In [None]:
# Search for "heavy metals", filtered with @eq on accessionid so that only
# article PMC9388453 is considered; at most 10 rows are requested.
# Another query noted for experimentation: "type II diabetes".
similar_articles = pmc_search_service.search(
    query="heavy metals",
    columns=["CHUNK", "ACCESSIONID"],
    filter={"@eq": {"accessionid": "PMC9388453"}},
    limit=10,
)

In [None]:
# Same "heavy metals" query with no article filter and a larger limit (100).
# This rebinds `similar_articles`, so the cells below work on this result.
# Values noted for experimentation: article PMC9388453, query "type II diabetes".
similar_articles = pmc_search_service.search(
    query="heavy metals",
    columns=["CHUNK", "ACCESSIONID"],
    limit=100,
)

In [None]:
# Serialise the search response to a JSON string, parse it back into Python
# objects, and flatten the "results" list into a DataFrame (one row per entry).
similar_articles_resp = similar_articles.to_json()
data = json.loads(similar_articles_resp)
df = pd.json_normalize(data["results"])

In [None]:
# Show the search results table built in the previous cell.
df

In [None]:
# Distinct ACCESSIONID values present in the search results.
df["ACCESSIONID"].unique()

In [None]:
# Example Streamlit code.
# The setup is self-contained: everything the functions below rely on is
# imported here, including `json` and `Root`, which `create_prompt` uses.
import json

import pandas as pd
import streamlit as st  # Import python packages
from snowflake.core import Root
from snowflake.snowpark.context import get_active_session

session = get_active_session()  # Get the current credentials

# Do not truncate long text columns when a DataFrame is rendered.
pd.set_option("display.max_colwidth", None)

# Num-chunks provided as context. Play with this to check how it affects your accuracy.
# NOTE(review): no function in this cell reads this value; the searches pass
# their own `limit` -- confirm whether it should be wired in.
num_chunks = 3

def create_prompt(myquestion, rag, article_chosen=None):
    """Build the LLM prompt for a question, optionally grounded in article chunks.

    With ``rag == 1`` the Cortex Search service is queried with the user's
    question (restricted to ``article_chosen`` when one is given) and the
    text of every returned chunk is concatenated into the prompt as context.
    For any other ``rag`` value the prompt contains only the question and the
    search service is not contacted.

    Args:
        myquestion: Question text entered by the user.
        rag: ``1`` to retrieve context from the search service; anything
            else produces a context-free prompt.
        article_chosen: Optional accession id. When truthy, the search is
            filtered to that article and the retrieved rows are written to
            the Streamlit page.

    Returns:
        Tuple ``(prompt, accessionid)``. ``accessionid`` is the ACCESSIONID of
        the first search result, or the string ``"None"`` when there is no
        retrieved context (RAG disabled, or the search returned no rows).
    """
    st.write(article_chosen)

    if rag != 1:
        prompt = f"""
         'Question:  
           {myquestion} 
           Answer: '
           """
        return prompt, "None"

    search_service = (
        Root(session)
        .databases["PMC_DATA"]
        .schemas["PMC_OA_OPENDATA"]
        .cortex_search_services["my_pmc_search_service"]
    )

    # Retrieve chunks relevant to the question actually asked.
    search_args = {
        "query": myquestion,
        "columns": ["CHUNK", "ACCESSIONID"],
        "limit": 10,
    }
    if article_chosen:
        search_args["filter"] = {"@eq": {"accessionid": str(article_chosen)}}

    search_response = search_service.search(**search_args)
    results = json.loads(search_response.to_json())["results"]
    df_context = pd.json_normalize(results)

    if article_chosen:
        st.write(df_context)

    if df_context.empty:
        # Nothing retrieved: keep the prompt well-formed instead of failing
        # on a lookup of row 0.
        prompt_context = ""
        accessionid = "None"
    else:
        # Use every returned chunk. Single quotes are removed from the
        # context -- presumably so they cannot clash with the quotes that
        # wrap the prompt text; confirm before changing.
        prompt_context = "".join(df_context["CHUNK"]).replace("'", "")
        accessionid = df_context["ACCESSIONID"].iloc[0]

    prompt = f"""
              'You are an expert assistance extracting information from context provided. 
               Answer the question based on the context. Be concise and do not hallucinate. 
               If you don´t have the information just say so.
              Context: {prompt_context}
              Question:  
               {myquestion} 
               Answer: '
               """
    return prompt, accessionid

def complete(myquestion, model_name, rag = 1, article_chosen=None):
    """Ask the chosen model the question via SNOWFLAKE.CORTEX.COMPLETE.

    The prompt is produced by ``create_prompt``; the model name and the prompt
    are passed to the SQL statement as bound parameters.

    Returns:
        Tuple ``(rows, accessionid)``: the collected result of the SQL call
        (its single column is aliased ``response``) and the accession id
        reported by ``create_prompt``.
    """
    llm_prompt, accessionid = create_prompt(myquestion, rag, article_chosen)

    query = """
             select SNOWFLAKE.CORTEX.COMPLETE(?,?) as response
           """
    bound_values = [model_name, llm_prompt]
    rows = session.sql(query, params=bound_values).collect()

    return rows, accessionid

def display_response (question, model, rag=0, article_chosen=None):
    """Render the model's answer and, in RAG mode, the related accession id.

    Calls ``complete`` and writes the ``RESPONSE`` field of the first returned
    row as markdown. When ``rag == 1`` an extra line naming the associated
    accession id is written below the answer.
    """
    rows, accessionid = complete(question, model, rag, article_chosen)
    st.markdown(rows[0].RESPONSE)

    if rag != 1:
        return

    st.write(f"Associated NCBI AccessionID: {accessionid} that may be useful")

# Main code: page header, inputs, and the call that renders the answer.

st.title("Asking Questions on NCBI Articles with Snowflake Cortex:")
st.write("""You can ask questions and decide if you want to the NCBI Articles for context or allow the model to create their own response.""")

# Up to 10 distinct accession ids from PMC_SERVICE_VW, offered as context choices.
docs_available = session.sql("select distinct ACCESSIONID from PMC_DATA.PMC_OA_OPENDATA.PMC_SERVICE_VW limit 10").collect()
list_docs = [doc["ACCESSIONID"] for doc in docs_available]

rag = st.checkbox('Use articles as context?')

# The article picker is only rendered when the checkbox is ticked.
use_rag = 1 if rag else 0
article_chosen = st.selectbox("Choose an Article for Context", list_docs) if rag else None

#Here you can choose what LLM to use. Please note that they will have different cost & performance
model_options = (
    'mixtral-8x7b',
    'snowflake-arctic',
    'mistral-large',
    'llama3-8b',
    'llama3-70b',
    'reka-flash',
    'mistral-7b',
    'llama2-70b-chat',
    'gemma-7b',
)
model = st.selectbox('Select your model:', model_options)

question = st.text_input(
    "Enter question",
    placeholder="Is there articles related to type-II diabetes?",
    label_visibility="collapsed",
)

# Nothing is sent to the model until the user has typed a question.
if question:
    display_response(question, model, use_rag, article_chosen)