In [18]:
import os
import pprint

import numpy as np
import pandas as pd
import google.generativeai as palm
import tiktoken

### Grab an API Key

To get started, you'll need to [create an API key](https://developers.generativeai.google/tutorials/setup).

In [2]:
# Read the API key from the GOOGLE_API_KEY environment variable instead of embedding it
# in the notebook: a key stored in a notebook is exposed to everyone who can read the file.
# Any key that was previously committed here should be revoked and replaced.
# A missing variable raises KeyError so the problem is visible immediately.
palm.configure(api_key=os.environ["GOOGLE_API_KEY"])

In [11]:
from langchain.embeddings import GooglePalmEmbeddings
from langchain.llms import GooglePalm

## Embedding models

In [13]:
# Keep only the models whose supported generation methods include 'embedText',
# then use the name of the first one as the embedding model.
models = [
    candidate
    for candidate in palm.list_models()
    if 'embedText' in candidate.supported_generation_methods
]
EMBEDDING_MODEL = models[0].name
print(EMBEDDING_MODEL)

models/embedding-gecko-001


## Text generation

In [14]:
# Keep only the models whose supported generation methods include 'generateText',
# then use the name of the first one as the completions model.
models = [
    candidate
    for candidate in palm.list_models()
    if 'generateText' in candidate.supported_generation_methods
]
COMPLETIONS_MODEL = models[0].name
print(COMPLETIONS_MODEL)

models/text-bison-001


# Question answering against your own data (xlsx, csv, etc.)

### Reading open-source data (from OpenAI)

In [49]:
# Load the Olympics sections dataset published with the OpenAI examples.
df = pd.read_csv('https://cdn.openai.com/API/examples/data/olympics_sections_text.csv')

# Prepend the title and heading to each section's text, then index the rows by (title, heading).
# NOTE(review): the `tokens` column is left untouched here, so it presumably still counts the
# original content only - confirm before relying on it as an exact token budget.
df = (
    df
    .assign(content=df['title'] + ' ' + df['heading'] + ' ' + df['content'])
    .set_index(["title", "heading"])
)
print(f"{len(df)} rows in the data.")
df.sample(5)

3964 rows in the data.


Unnamed: 0_level_0,Unnamed: 1_level_0,content,tokens
title,heading,Unnamed: 2_level_1,Unnamed: 3_level_1
Chile at the 2020 Summer Olympics,Artistic,Chile at the 2020 Summer Olympics Artistic Chi...,79
France at the 2020 Summer Olympics,Weightlifting,France at the 2020 Summer Olympics Weightlifti...,126
Canoeing at the 2020 Summer Olympics – Women's K-1 200 metres,Schedule,Canoeing at the 2020 Summer Olympics – Women's...,43
Boxing at the 2020 Summer Olympics – Men's middleweight,Qualification,Boxing at the 2020 Summer Olympics – Men's mid...,236
Cycling at the 2020 Summer Olympics – Qualification,Road cycling,Cycling at the 2020 Summer Olympics – Qualific...,166


### Getting embeddings

In [25]:
def get_embedding(text: str, model: str=EMBEDDING_MODEL) -> list[float]:
    """
    Return the embedding vector for `text`.

    The text is sent to `palm.generate_embeddings` and the value stored under the
    "embedding" key of the response is returned.

    text:  the string to embed.
    model: name of the embedding model to use; defaults to EMBEDDING_MODEL.
    """
    # Pass the `model` argument through (not the module-level constant) so that a
    # caller-supplied model name is actually honoured.
    result = palm.generate_embeddings(
      model=model,
      text=text
    )
    return result["embedding"]

def compute_doc_embeddings(df: pd.DataFrame) -> dict[tuple[str, str], list[float]]:
    """
    Embed the `content` of every row in `df`.

    Returns a dictionary keyed by each row's index value; the value is the embedding
    vector that `get_embedding` produced for that row's content.
    """
    embeddings_by_index = {}
    for row_index, row in df.iterrows():
        embeddings_by_index[row_index] = get_embedding(row.content)
    return embeddings_by_index

In [50]:

# Embed only the first 100 rows of the dataframe; sections outside this slice have no
# embedding and therefore can never be selected as context later on.
document_embeddings = compute_doc_embeddings(df.iloc[:100])

### We could use the dataframe directly for QA, but because of the LLM's token limits we first retrieve the most relevant context for the question from the dataframe, and then pass that context together with the question in a formatted prompt to get the desired answer.

In [33]:
def vector_similarity(x: list[float], y: list[float]) -> float:
    """
    Return the cosine similarity between two vectors.

    The dot product is divided by the product of the two Euclidean norms, so the result
    does not rely on the embeddings being normalized to length 1. For unit-length
    inputs the value equals the plain dot product.

    x, y: equal-length sequences of floats.

    Returns 0.0 when either vector has zero length, where cosine similarity is undefined.
    """
    a = np.asarray(x, dtype=float)
    b = np.asarray(y, dtype=float)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0.0:
        # Avoid a division by zero (and a NaN score that would break sorting).
        return 0.0
    return float(np.dot(a, b) / norm_product)

def order_document_sections_by_query_similarity(query: str, contexts: dict[(str, str), np.array]) -> list[(float, (str, str))]:
    """
    Embed `query` and score it against every pre-computed document embedding in `contexts`.

    Returns a list of (similarity, document index) tuples, sorted so that the most
    similar sections come first.
    """
    query_vector = get_embedding(query)

    scored_sections = []
    for section_key, section_vector in contexts.items():
        score = vector_similarity(query_vector, section_vector)
        scored_sections.append((score, section_key))

    # Highest similarity first.
    scored_sections.sort(reverse=True)
    return scored_sections

In [60]:
# Prompt-building configuration
MAX_SECTION_LEN = 500  # token budget for the context sections added to a prompt
SEPARATOR = "\n* "  # placed in front of every context section
# Name of the tiktoken encoding used to count tokens.
# NOTE(review): this is a tiktoken (OpenAI) encoding, while the completions model comes from
# palm.list_models(); the counts are presumably only an approximation for that model - confirm.
ENCODING = "gpt2"

encoding = tiktoken.get_encoding(ENCODING)
separator_len = len(encoding.encode(SEPARATOR))

# A temperature of 0.0 is used because it gives the most predictable, factual answer.
COMPLETIONS_API_PARAMS = dict(
    temperature=0.0,
    model=COMPLETIONS_MODEL,
)

### Formatted prompt creation

In [61]:
def construct_prompt(question: str, context_embeddings: dict, df: pd.DataFrame) -> str:
    """
    Build the prompt for `question` from the most relevant sections of `df`.

    Sections are ranked with `order_document_sections_by_query_similarity` and appended,
    most similar first, until the running total of each section's `tokens` value plus the
    separator length would exceed MAX_SECTION_LEN. The selected section indexes are
    printed as diagnostics.

    Returns the instruction header, the chosen sections, and the question followed by
    an answer marker.
    """
    ranked_sections = order_document_sections_by_query_similarity(question, context_embeddings)

    context_parts = []
    selected_keys = []
    tokens_used = 0

    for _score, key in ranked_sections:
        row = df.loc[key]

        # Stop at the first section that would push the context over the budget.
        tokens_used += row.tokens + separator_len
        if tokens_used > MAX_SECTION_LEN:
            break

        # Newlines inside a section are flattened to spaces.
        context_parts.append(SEPARATOR + row.content.replace("\n", " "))
        selected_keys.append(str(key))

    # Useful diagnostic information
    print(f"Selected {len(context_parts)} document sections:")
    print("\n".join(selected_keys))

    header = """Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say "I don't know."\n\nContext:\n"""

    return "".join([header, *context_parts, "\n\n Q: ", question, "\n A:"])

### Getting the desired answer

In [77]:
def answer_query_with_context(
    query: str,
    df: pd.DataFrame,
    document_embeddings: dict[(str, str), np.array],
    show_prompt: bool = False
) -> str:
    """
    Answer `query` using context drawn from `df`.

    A prompt is built with `construct_prompt`, optionally printed when `show_prompt`
    is true, and sent to `palm.generate_text` with COMPLETIONS_API_PARAMS. The output of
    the first returned candidate is returned with surrounding spaces and newlines removed.
    """
    prompt = construct_prompt(query, document_embeddings, df)

    if show_prompt:
        print(prompt)

    completion = palm.generate_text(prompt=prompt, **COMPLETIONS_API_PARAMS)

    # NOTE(review): indexing [0] raises IndexError if no candidates come back - confirm
    # whether the API can return an empty candidate list.
    first_candidate = completion.candidates[0]
    return first_candidate['output'].strip(" \n")

In [78]:
# Ask a sample question against the embedded sections and display the answer.
response = answer_query_with_context(
    "Why was the 2020 Summer Olympics originally postponed?",
    df,
    document_embeddings,
)
response

Selected 1 document sections:
('2020 Summer Olympics', 'Postponement to 2021')


'due to the COVID-19 pandemic'