In [9]:
from datasets import load_dataset, get_dataset_config_names 
import os
import json

from config import GEMINI_API_KEY 

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

In [10]:
# Fall back to the key from the local config module only when the environment
# does not already provide one, so an externally supplied key always wins.
# NOTE(review): presumably GOOGLE_API_KEY is the variable the Gemini client
# reads its credentials from — confirm against the client documentation.
if os.environ.get("GOOGLE_API_KEY") is None:
    os.environ["GOOGLE_API_KEY"] = GEMINI_API_KEY

In [11]:
# Dataset identifier and the split that every config is loaded from.
dataset_path = "keeve101/sutd-qa-dataset"
split = "train"

# Names of all configurations available under the dataset identifier.
config_names = get_dataset_config_names(dataset_path)

# Map each config name to its loaded split, shuffled with a fixed seed so the
# example order is the same on every run.
datasets = {
    config_name: load_dataset(dataset_path, config_name, split=split).shuffle(seed=0)
    for config_name in config_names
}

In [12]:
# One-shot prompt for paraphrasing a Q&A pair: an instruction, a single worked
# example, then the pair to rewrite. QUESTION and ANSWER are filled in via
# `.format(...)` where the prompt is rendered. The doubled braces around the
# example JSON are escapes, so they render as literal `{` / `}` rather than
# being treated as template variables.
question_paraphrase_prompt_template = PromptTemplate(
    input_variables=["QUESTION", "ANSWER"],
    template = """
You are a helpful assistant that paraphrases both the question and answer in a Q&A pair. Given one Q&A pair, generate one new version that uses different wording but keeps the same meaning. Return the output as a JSON object with "question" and "answer" fields.

Question: Where is SUTD located?  
Answer: SUTD is located at 8 Somapah Road, Singapore 487372.

Paraphrased Question Answer Pair:
{{
  "question": "What is the address of SUTD?",
  "answer": "The address of SUTD is 8 Somapah Road, Singapore 487372."
}}

---

Question: {QUESTION}
Answer: {ANSWER}

Paraphrased Question Answer Pair:
"""
)

In [13]:
# Parser applied to the model's reply; the prompt asks for a JSON object.
json_output_parser = JsonOutputParser()

# A moderate temperature keeps some randomness in the generated paraphrases.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5)

# Pipe the model's reply straight into the JSON parser.
chain = llm | json_output_parser

In [14]:
# Path template for the per-config JSON Lines output; `{config_name}` is
# substituted with `str.format` at write time, giving one file per config.
output_filepath = "validation_test_splits_{config_name}.jsonl"

In [15]:
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_message

# NOTE(review): not referenced by any of the visible cells — presumably a
# leftover from another experiment; confirm nothing else uses it before removing.
lexicon_question_search_query_pairs = {}

# Retry policy for rate-limited calls: only exceptions whose message matches
# the pattern are retried, waiting exponentially (bounded to 1-60 seconds)
# between tries, and giving up after 25 attempts.
# NOTE(review): tenacity appears to apply `match` with `re.match`, i.e.
# anchored at the start of str(exception), so this relies on the message
# beginning with "429" — confirm against the tenacity documentation.
@retry(
    retry=retry_if_exception_message(match="429|ResourceExhausted"),
    wait=wait_exponential(multiplier=1, min=1, max=60),
    stop=stop_after_attempt(25),
)
def invoke_chain(chain, prompt):
    """Run ``chain`` on ``prompt``, retrying when the call is rate limited.

    Args:
        chain: Object exposing an ``invoke(prompt)`` method.
        prompt: Input passed unchanged to ``chain.invoke``.

    Returns:
        Whatever ``chain.invoke(prompt)`` returns.
    """
    response = chain.invoke(prompt)
    return response

In [17]:
# Maximum number of examples to paraphrase per dataset config.
num_generations = 40

# For every config, paraphrase up to `num_generations` Q&A pairs and append
# them as JSON Lines to that config's output file.
for config_name, dataset in datasets.items():
    for idx, example in enumerate(dataset):
        # Check the budget *before* doing any work so exactly `num_generations`
        # examples are processed (the previous post-check with `>` let
        # num_generations + 2 examples through).
        if idx >= num_generations:
            break

        question = example["question"]
        answer = example["answer"]

        # Call the model before opening the file so no file handle is held
        # open while the request (and any rate-limit retry sleeps) is running.
        result = invoke_chain(chain, question_paraphrase_prompt_template.format(QUESTION=question, ANSWER=answer))

        # The prompt asks for a JSON object with "question" and "answer";
        # skip (loudly) anything that does not have that shape instead of
        # writing a broken record or aborting the whole run.
        if not (isinstance(result, dict) and "question" in result and "answer" in result):
            print(f"Skipping {config_name}[{idx}]: unexpected model output {result!r}")
            continue

        with open(output_filepath.format(config_name=config_name), "a") as f:
            # Write the paraphrased pair (previously the untouched originals
            # were written and the model output was discarded); keep the
            # source pair alongside it for traceability.
            f.write(json.dumps({
                "question": result["question"],
                "answer": result["answer"],
                "original_question": question,
                "original_answer": answer,
            }) + "\n")

Retrying langchain_google_genai.chat_models._chat_with_retry.<locals>._chat_with_retry in 2.0 seconds as it raised ResourceExhausted: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerMinutePerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_dimensions {
    key: "model"
    value: "gemini-2.0-flash"
  }
  quota_value: 15
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 17
}
].
Retrying langchain_google_genai.chat_models._chat_with_retry.<locals>._chat_with_retry in 2.0 seconds as it raised ResourceExhausted: 429 You exceeded your current quota, please check your plan and billing