In [None]:
import os
from typing import List, Optional

import pandas as pd
from tqdm.autonotebook import tqdm

from google.protobuf.json_format import MessageToDict
from google.api_core.client_options import ClientOptions
from google.cloud import discoveryengine_v1beta as discoveryengine

from vertexai.preview.language_models import TextGenerationModel

In [None]:
# Input parameters start here

In [None]:
# Prompt sent to the text generation model for every question.
# It has two positional `{}` placeholders, filled in this order via
# `str.format(context, question)`:
#   1. the retrieved search context (inside <CONTEXT>...</CONTEXT>)
#   2. the user's question          (inside <QUESTION>...</QUESTION>)
# Keep both tag pairs properly closed: the instructions refer to them by name.
prompt_generation_template = """
<INSTRUCTIONS>
1. ALL INSTRUCTIONS must be adhered to exactly when generating the Truthful Response.
2. Look through the text in <CONTEXT></CONTEXT> and answer the user's question <QUESTION></QUESTION>
3. The text in <CONTEXT></CONTEXT> may contain documents that are not relevant to the user's query (for example, document titles or snippets could be about a different topic), you should ignore those results.
4. Think step-by-step. First, determine the set of documents and snippets in the context that are relevant to the user's query. Then, synthesize them to create a detailed truthful response.
5. If none of the information in <CONTEXT></CONTEXT> is relevant to the user's question, explain that to the user instead of making up an answer.
6. Use "they/them" by default, avoid gendered identifiers if unspecified. Otherwise, use the pronoun in the person summary.
7. Be helpful to the user, but avoid workplace violations.
8. Respond in the same language as the text in <QUESTION></QUESTION>
</INSTRUCTIONS>
<CONTEXT>
{}
</CONTEXT>
<QUESTION>
{}
</QUESTION>

Truthful Response:

"""

In [None]:
# Vertex AI Search data store ID; used to build the serving-config path in `search`.
data_store_id = ""

In [None]:
# Search request sizing, read by `search`:
#   max_search_result      -> SearchRequest.page_size
#   max_extractive_segment -> extractive_content_spec.max_extractive_segment_count
max_search_result, max_extractive_segment = 5, 3

In [None]:
# Context-selection knobs.
# NOTE(review): none of these three values is read by any cell visible in this
# notebook (`prepare_context` always takes the first segment of every result);
# presumably they were meant to limit what goes into the prompt -- confirm.
max_docs = max_segment = 1
min_relevance = 0

In [None]:
# Model selection.
#   use_model == 'tune' -> `model_name` is loaded with TextGenerationModel.get_tuned_model
#   anything else       -> `model_name` is loaded with TextGenerationModel.from_pretrained
# To switch to a tuned model, set use_model = 'tune' and put the tuned model's
# name in model_name.
use_model, model_name = 'base', 'text-bison@latest'

In [None]:
# Path of the input CSV. It must contain the columns 'original_question' and
# 'ground_truth_answer', which later cells read.
input_file = ''

# Output path: "<input path without extension>_gen_<last '/'-separated part of model_name>.csv".
# os.path.splitext strips only the final extension, so paths that contain other
# dots (e.g. "./data/questions.csv" or "eval.v2.csv") keep their directory and
# full stem instead of being truncated at the first '.'.
output_file = f"{os.path.splitext(input_file)[0]}_gen_{model_name.split('/')[-1]}.csv"

In [None]:
# Input parameters end here

In [None]:
# Load the evaluation set (one question per row) from the configured CSV path.
df = pd.read_csv(filepath_or_buffer=input_file)

In [None]:
# Preview the first rows of the input to sanity-check the loaded columns.
df.head(n=5)

In [None]:
def search(
    project_id: str,
    location: str,
    data_store_id: str,
    search_query: str,
    page_size: Optional[int] = None,
    extractive_segment_count: Optional[int] = None,
) -> "discoveryengine.services.search_service.pagers.SearchPager":
    """Run one query against a Vertex AI Search data store.

    Args:
        project_id: Google Cloud project that owns the data store.
        location: Data store location. Any value other than "global" is used
            as the prefix of the regional API endpoint.
        data_store_id: ID of the data store to query.
        search_query: Free-text query string.
        page_size: Value for ``SearchRequest.page_size``. When omitted, the
            module-level ``max_search_result`` is used.
        extractive_segment_count: Value for
            ``extractive_content_spec.max_extractive_segment_count``. When
            omitted, the module-level ``max_extractive_segment`` is used.

    Returns:
        Whatever ``SearchServiceClient.search`` returns for the request (a
        pager over search results; callers in this notebook read its
        ``.results`` attribute).
    """
    # Fall back to the notebook-level configuration when the caller did not
    # pass explicit limits; this keeps existing four-argument calls unchanged.
    if page_size is None:
        page_size = max_search_result
    if extractive_segment_count is None:
        extractive_segment_count = max_extractive_segment

    # For more information, refer to:
    # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
    # Non-global locations need an explicit regional endpoint; for "global"
    # the client default is used.
    client_options = None
    if location != "global":
        client_options = ClientOptions(
            api_endpoint=f"{location}-discoveryengine.googleapis.com"
        )

    # Create a client
    client = discoveryengine.SearchServiceClient(client_options=client_options)

    # The full resource name of the search engine serving config
    # e.g. projects/{project_id}/locations/{location}/dataStores/{data_store_id}/servingConfigs/{serving_config_id}
    serving_config = client.serving_config_path(
        project=project_id,
        location=location,
        data_store=data_store_id,
        serving_config="default_config",
    )

    # Optional: Configuration options for search
    # Refer to the `ContentSearchSpec` reference for all supported fields:
    # https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1.types.SearchRequest.ContentSearchSpec
    content_search_spec = {
        'extractive_content_spec': {
            'max_extractive_segment_count': extractive_segment_count
        },
    }

    # Refer to the `SearchRequest` reference for all supported fields:
    # https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1.types.SearchRequest
    request = discoveryengine.SearchRequest(
        serving_config=serving_config,
        query=search_query,
        page_size=page_size,
        content_search_spec=content_search_spec,
        # Let the service decide when to expand the query / fix spelling.
        query_expansion_spec=discoveryengine.SearchRequest.QueryExpansionSpec(
            condition=discoveryengine.SearchRequest.QueryExpansionSpec.Condition.AUTO,
        ),
        spell_correction_spec=discoveryengine.SearchRequest.SpellCorrectionSpec(
            mode=discoveryengine.SearchRequest.SpellCorrectionSpec.Mode.AUTO
        ),
    )

    return client.search(request)

In [None]:
# Google Cloud project and data store location passed to `search`.
# `location` is compared against "global"; any other value is used as the
# prefix of the regional Discovery Engine endpoint.
project_id = ""
location = ""

In [None]:
def prepare_context(response):
    """Flatten search results into the context text that is put into the prompt.

    For every result in ``response.results`` the document is converted to a
    dict and its ``derivedStructData`` is read. Each usable result contributes
    two lines: a ``#<index> <link>`` header and the content of its first
    extractive segment. Entries are separated by a newline so one document's
    text does not run into the next document's header.

    Results that carry no ``derivedStructData`` or no extractive segments are
    skipped instead of raising; the ``#<index>`` of the remaining entries is
    still the result's position in the search response.

    Args:
        response: Object exposing ``.results``, where each result has a
            ``.document._pb`` protobuf message (as returned by ``search``).

    Returns:
        The concatenated context string; empty when there is nothing usable.
    """
    entries = []
    for i, result in enumerate(response.results):
        doc = MessageToDict(result.document._pb).get('derivedStructData') or {}
        segments = doc.get('extractive_segments') or []
        if not segments:
            # Nothing to quote from this document.
            continue

        link = doc.get('link', '')
        content = segments[0].get('content', '')
        entries.append(f"#{i} {link}\n{content}")

    return "\n".join(entries)

In [None]:
# Load the text generation model chosen in the configuration cells:
# a tuned model when use_model is 'tune', otherwise a pretrained one.
summarizer = (
    TextGenerationModel.get_tuned_model(model_name)
    if use_model == 'tune'
    else TextGenerationModel.from_pretrained(model_name)
)

In [None]:
# Result table: the question and reference answer copied from the input,
# plus an initially empty column that the generation loop fills in.
result_df = df[['original_question', 'ground_truth_answer']].assign(
    generated_answer=None
)

In [None]:
# For every input question: retrieve context from search, build the prompt,
# generate an answer and store it in `result_df`.
# `df.iterrows()` is a generator without a length, so `total` is passed
# explicitly; otherwise the progress bar cannot show completion or ETA.
for index, row in tqdm(df.iterrows(), total=len(df)):

    search_query = row['original_question']
    print(index, search_query)

    # Retrieval step.
    search_response = search(
        project_id=project_id,
        location=location,
        search_query=search_query,
        data_store_id=data_store_id
    )

    context = prepare_context(search_response)

    # Generation step: temperature 0 with a 1024-token output limit.
    input_prompt = prompt_generation_template.format(context, search_query)
    response = summarizer.predict(input_prompt, temperature=0, max_output_tokens=1024)

    result_df.loc[index, 'generated_answer'] = response.text

In [None]:
# Preview the first generated answers next to their ground truth.
result_df.head(n=5)

In [None]:
# Persist the results to the configured output path, without the index column.
result_df.to_csv(path_or_buf=output_file, index=False)