<a href="https://colab.research.google.com/github/jingtang10/odhs-genai/blob/main/%5BODHS_D2_T1_05A_CL1%5D_Gemini_WHO_ANC_Guidelines_Eval_Data_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WHO recommendations on antenatal care for a positive pregnancy experience


### This notebook processes the WHO recommendations on antenatal care for a positive pregnancy experience and uses Gemini to generate a set of multiple-choice questions (MCQs) and short-answer questions (SQAs) based on them.

## Install the Gemini API SDK

The Python SDK for the Gemini API is contained in the [google-generativeai package](https://pypi.org/project/google-generativeai/). Install the dependency using pip.



In [None]:
! pip install -U -q google-generativeai

## Import the libraries


In [None]:
import os
import textwrap
import pandas as pd

import google.generativeai as genai

from typing import Union, List

from google.api_core import retry, exceptions


## Set up your API key


To use the Gemini API, you need an API key. Store it as a Colab Secret named `GOOGLE_API_KEY`.  
If you don't have an API key or need help creating a Colab Secret, see the [Authentication](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Authentication.ipynb) guide.

In [None]:
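# Read the API key from Colab Secrets; outside Colab, fall back to the
# GOOGLE_API_KEY environment variable.
try:
  from google.colab import userdata
  GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
except ImportError:
  GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')

genai.configure(api_key=GOOGLE_API_KEY)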

## Upload WHO's antenatal care guidelines file

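Use the [`upload_file`](https://ai.google.dev/gemini-api/docs/document-processing?lang=python#upload-document) API to temporarily store the WHO antenatal care guidelines PDF (uploaded files are deleted automatically after 48 hours). The call returns a file reference that can be included in a prompt. Make sure the PDF is available at `file_path` first, for example by uploading it to the Colab session.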

In [5]:
file_path = "9789241549912-eng.pdf" #@param {type:"string"}
display_name = "WHO recommendations on antenatal care for a positive pregnancy experience"

pdf_file = genai.upload_file(
    path=file_path,
    display_name=display_name,
)

file_ref = genai.get_file(name=pdf_file.name)
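Uploaded files go through server-side processing, and the Gemini API docs (mainly for video) poll `genai.get_file` until the file's `state` leaves `PROCESSING`. PDFs are usually ready almost immediately, but the check is cheap. A minimal sketch, assuming the returned object exposes `state.name` as in those docs; `wait_for_file` and its `fetch` argument are hypothetical names introduced here so the loop can be tried without an API key:

```python
import time
from typing import Any, Callable


def wait_for_file(fetch: Callable[[], Any], poll_seconds: float = 2.0,
                  max_polls: int = 30) -> Any:
  """Hypothetical helper: polls `fetch()` until the file is ACTIVE.

  `fetch` must return an object with a `state.name` attribute, for example
  `lambda: genai.get_file(pdf_file.name)`.
  """
  for _ in range(max_polls):
    current = fetch()
    state = current.state.name
    if state == "ACTIVE":
      return current
    if state == "FAILED":
      raise RuntimeError(f"File processing failed (state: {state}).")
    # Still processing (or unspecified): wait before asking again.
    time.sleep(poll_seconds)
  raise TimeoutError("File did not become ACTIVE within the polling budget.")
```

In the notebook you could call it as `file_ref = wait_for_file(lambda: genai.get_file(pdf_file.name))`.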


## Select a suitable Gemini model

The Gemini API offers different models optimized for specific use cases. See the [overview of Gemini model variants](https://ai.google.dev/gemini-api/docs/models/gemini) for the models that are available.
To make sure your prompts work, check each model's input and output token limits: your document plus prompts must fit within the input limit, and the desired output within the output limit.

This tutorial uses the [Gemini 1.5 Pro](https://ai.google.dev/gemini-api/docs/models/gemini#gemini-1.5-pro) model to generate the question sets.

In [None]:
model_name = "gemini-1.5-pro-latest" #@param ["gemini-1.5-pro-latest", "gemini-1.5-pro", "gemini-1.5-flash-latest", "gemini-1.5-flash"]
model_info = genai.get_model(f'models/{model_name}')

print(f"Model: {model_name}")
print(f"Input Token Limit: {model_info.input_token_limit}")
print(f"Output Token Limit: {model_info.output_token_limit}")
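A `ChatSession` resends the full conversation history, including the uploaded PDF, with every message, so the request in the last turn of the question-generation loop is the largest one. The sketch below (a hypothetical `check_token_budget` helper) estimates that final request as an upper bound, under the stated assumption that each turn (prompt plus response) costs roughly `tokens_per_turn` tokens:

```python
def check_token_budget(document_tokens: int, tokens_per_turn: int,
                       num_turns: int, max_output_tokens: int,
                       input_token_limit: int, output_token_limit: int) -> dict:
  """Hypothetical estimate of whether a multi-turn run fits the model limits."""
  # Every turn re-sends the document plus all earlier turns, so the final
  # request is bounded by the document plus num_turns turns of traffic.
  # tokens_per_turn is an assumed average, not a measured value.
  final_turn_input = document_tokens + num_turns * tokens_per_turn
  return {
      "final_turn_input_tokens": final_turn_input,
      "input_ok": final_turn_input <= input_token_limit,
      # max_output_tokens applies to each response individually.
      "output_ok": max_output_tokens <= output_token_limit,
  }
```

`document_tokens` could come from `genai.GenerativeModel(model_name).count_tokens([file_ref]).total_tokens`, and the limits from `model_info`; `tokens_per_turn` is a guess you would refine after one run.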

## Define utility functions for repetitive tasks

1.   Define the model
2.   Format previously generated items for the follow-up prompt
3.   Call the model and extract the generated items



In [None]:
def model_def(model_name, safety_settings, tools=None):
  model = genai.GenerativeModel(
    model_name=model_name,
    safety_settings=safety_settings,
    tools=tools
  )
  return model

In [None]:
def format_items_for_prompt(items: list, item_type: str, model_name: str) -> str:
  """Formats generated items (questions, options, answers) as prompt text.

  Args:
    items: List of question dicts, as returned by the function call.
    item_type: One of "mcq", "sqa", or "sqa_eval".
    model_name: "gemini" includes the stored answers; "gemma" leaves the
      answer blank (for "sqa_eval", it uses the `Answer_2b` field instead).

  Returns:
    A single string with one block per item.
  """
  formatted_text = ""
  for item in items:
    formatted_text += f"Question: {item['Question']}\n"
    formatted_text += f"Intervention: {item.get('Intervention', '')}\n"

    if item_type == "mcq":
      # Both model types list the five answer options.
      formatted_text += "".join(
          f"{option}: {item[option]}\n" for option in ["A", "B", "C", "D", "E"]
      )
      if model_name == "gemini":
        # Include the correct choice if present.
        formatted_text += f"Choice: {item.get('Choice', '')}\n\n"
      elif model_name == "gemma":
        formatted_text += "Answer: \n"

    elif item_type == "sqa_eval":
      if model_name == "gemini":
        formatted_text += f"Answer: {item.get('Answer', '')}\n\n"
      elif model_name == "gemma":
        formatted_text += f"Answer: {item.get('Answer_2b', '')}\n\n"

    elif item_type == "sqa":
      if model_name == "gemini":
        formatted_text += f"Answer: {item.get('Answer', '')}\n\n"
      elif model_name == "gemma":
        formatted_text += "Answer: \n\n"

  return formatted_text

In [None]:
def call_model_and_extract(
    chat_model: genai.ChatSession,
    prompt: list,
    function_name: str,
    generation_config: dict = None,
    tool_config: dict = None,
) -> list:
  """Sends `prompt` to the chat and returns the list passed to the function call.

  `function_name` is the key of the array argument in the function
  declaration (e.g. "multiple_choice_questions"), not the function's name.
  Returns an empty list if the model makes no function call or the call fails.
  """
  try:
    response = generate_text(
        chat_model,
        prompt,
        generation_config=generation_config,
        tool_config=tool_config,
    )

    part = response.candidates[0].content.parts[0]
    if part.function_call:
      # Convert the proto-plus message to a plain dict and pull out the list.
      args = type(part.function_call).to_dict(part.function_call)["args"]
      return args[function_name]
    return []  # No function call in the response.

  except Exception as e:
    print(f"Error calling model: {e}")
    return []  # Return an empty list on error so the loop can continue.

## Generate text

* The Gemini API's client library offers built-in retry mechanisms for handling transient errors.

* The `generate_text` function sends a message to the chat model with the given prompt and
  configurations, and returns the response. It includes retry logic to handle
  transient errors.

* For more info on error handling, take a look at the [error_handling quickstart](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Error_handling.ipynb).

In [None]:
@retry.Retry(
    predicate=retry.if_transient_error,
    initial=5,
    maximum=10,
    multiplier=2.0,
    timeout=100,
)
def generate_text(
    chat: genai.ChatSession,
    prompt: Union[List[str], str],
    generation_config: dict = None,
    tool_config: dict = None
):
  """Sends a message to a chat session, retrying on transient errors.

  The `retry.Retry` decorator re-invokes this function with exponential
  backoff (initial delay 5 s, growing by a factor of 2 up to 10 s) when the
  API raises a transient error (HTTP 429, 500 or 503), and stops retrying
  once 100 s have passed.

  Args:
    chat: The chat session (`genai.ChatSession`) created with
      `model.start_chat()`.
    prompt: The content to send. Can be a string or a list of parts
      (strings or file references).
    generation_config: (Optional) Generation parameters such as temperature,
      top_p, top_k and max_output_tokens.
    tool_config: (Optional) Tool configuration, e.g.
      `{"function_calling_config": {"mode": "ANY"}}` to force a function call.

  Returns:
    The model response (a `GenerateContentResponse`).

  Raises:
    google.api_core.exceptions.RetryError: If transient errors persist until
        the retry timeout expires. The last exception is chained to it.
    Any other exception raised by `chat.send_message`: Errors that are not
        considered transient are propagated immediately, without retries.
  """
  response = chat.send_message(
      prompt,
      generation_config=generation_config,
      tool_config=tool_config,
  )
  return response
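`google.api_core`'s `Retry` waits between attempts with capped exponential backoff, and `timeout` bounds the overall retry window. To get a feel for what `initial=5`, `maximum=10`, `multiplier=2.0`, `timeout=100` mean, here is a simplified sketch (a hypothetical `backoff_delays` helper, not the library's code). It ignores the random jitter the library applies and the time the calls themselves take:

```python
def backoff_delays(initial: float, maximum: float, multiplier: float,
                   timeout: float) -> list:
  """Hypothetical sketch: capped exponential delays that fit within `timeout`."""
  if initial <= 0:
    raise ValueError("initial must be positive")
  delays, elapsed = [], 0.0
  delay = min(initial, maximum)
  # Keep scheduling waits while the cumulative sleep stays within the timeout.
  while elapsed + delay <= timeout:
    delays.append(delay)
    elapsed += delay
    # Grow the delay geometrically, but never past the cap.
    delay = min(delay * multiplier, maximum)
  return delays
```

With the decorator's values, `backoff_delays(5, 10, 2.0, 100)` gives one 5 s wait followed by nine 10 s waits, so at most about ten retries fit in the window before the timeout ends the attempts.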

## Customize Safety Settings

The Gemini API provides safety settings that you can adjust during prototyping to decide whether your application needs a more or less restrictive configuration. You can adjust these settings across four filter categories (harassment, hate speech, sexually explicit, and dangerous content) to restrict or allow certain types of content.

To customize them, define a `safety_settings` list and pass it to the model at initialization, as in the cell below.

**Important:** In line with Google's commitment to responsible AI development and its [AI Principles](https://ai.google/responsibility/principles/), Gemini may still decline to generate results for some prompts even if you set all filters to `BLOCK_NONE`.

In [None]:

# Thresholds for the four adjustable Gemini safety categories.
HARM_CATEGORY_HARASSMENT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_HATE_SPEECH = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_SEXUALLY_EXPLICIT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_DANGEROUS_CONTENT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]

safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": HARM_CATEGORY_HARASSMENT,
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": HARM_CATEGORY_HATE_SPEECH,
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": HARM_CATEGORY_SEXUALLY_EXPLICIT,
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": HARM_CATEGORY_DANGEROUS_CONTENT,
    },
]

## Configure text generation


Every prompt you send to the model includes [parameters](https://ai.google.dev/gemini-api/docs/models/generative-models#model-parameters) that control how the model generates responses. You can use [genai.GenerationConfig](https://ai.google.dev/api/generate-content#generationconfig) to configure these parameters. If you don't configure the parameters, the model uses default options, which can vary by model.

In [None]:
temperature = 1 #@param {type:"slider", min:0, max:1, step:0.1}
top_p = 0.95 #@param {type:"slider", min:0, max:1, step:0.05}
top_k = 64 #@param {type:"integer"}
max_output_tokens = 8192 #@param {type:"integer"}

# Use the form values so that changes to the sliders and fields take effect.
generation_config = {
    "temperature": temperature,
    "top_p": top_p,
    "top_k": top_k,
    "max_output_tokens": max_output_tokens,
}
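To build intuition for what `temperature`, `top_k`, and `top_p` do, the toy sketch below (a hypothetical `candidate_distribution` function operating on made-up logits, not on anything returned by the API) applies temperature scaling, keeps the `top_k` most likely tokens, and then keeps the smallest set whose cumulative probability reaches `top_p`. This is one common way to combine the three settings; the service's exact procedure may differ.

```python
import math


def candidate_distribution(logits: dict, temperature: float = 1.0,
                           top_k: int = 64, top_p: float = 0.95) -> dict:
  """Toy illustration: which tokens survive temperature, top-k and top-p.

  Returns {token: probability} over the surviving tokens, renormalized.
  Assumes temperature > 0 (temperature 0 is usually treated as greedy).
  """
  # Temperature divides the logits: values < 1 sharpen, > 1 flatten.
  scaled = {tok: value / temperature for tok, value in logits.items()}
  # Softmax, subtracting the max logit for numerical stability.
  peak = max(scaled.values())
  weights = {tok: math.exp(v - peak) for tok, v in scaled.items()}
  total = sum(weights.values())
  ranked = sorted(((t, w / total) for t, w in weights.items()),
                  key=lambda item: item[1], reverse=True)

  # Top-k: keep the k most likely tokens and renormalize.
  ranked = ranked[:top_k]
  total = sum(p for _, p in ranked)
  ranked = [(t, p / total) for t, p in ranked]

  # Top-p: keep the smallest prefix whose cumulative probability >= top_p.
  kept, cumulative = [], 0.0
  for tok, p in ranked:
    kept.append((tok, p))
    cumulative += p
    if cumulative >= top_p:
      break
  total = sum(p for _, p in kept)
  return {tok: p / total for tok, p in kept}
```

For example, with probabilities `{"a": 0.5, "b": 0.3, "c": 0.15, "d": 0.05}` turned into logits via `math.log`, `top_k=3` and `top_p=0.8` leave only `a` and `b` (renormalized to 0.625 and 0.375). The next token would then be sampled from that distribution, e.g. with `random.choices(list(dist), weights=list(dist.values()))`. A low temperature such as 0.1 concentrates almost all of the mass on `a`.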

## Generate Multiple Choice Questions (MCQs) for the 5 main intervention types


The questions created from the World Health Organization's guidelines are divided equally among the five main categories of recommendations for pregnancy care:

1. Nutritional interventions
2. Maternal and fetal assessment
3. Preventive measures
4. Interventions for common physiological symptoms
5. Health systems interventions to improve the utilization and quality of antenatal care


### Utility function for generating questions

In [None]:
def generate_questions(
    chat_model: genai.ChatSession,
    file_ref: genai.types.file_types.File,
    base_prompt: str,
    total_questions: int,
    question_type: str,
    question_type_function_name: str,
    questions_per_prompt: int,
    generation_config: dict = None,
    tool_config: dict = None,
    model_name: str = "gemini"
):
  """
  Generates multiple-choice or short-answer questions based on a given file
  and prompt.

  Args:
    chat_model: `genai.ChatSession`. The chat session to use for generation.
    file_ref: `genai.types.file_types.File`. The uploaded file the questions
      are generated from.
    base_prompt: str. The base prompt for generating questions.
    total_questions: int. The total number of questions to generate.
    question_type: str. The type of questions to generate: "mcq" or "sqa".
    question_type_function_name: str. The argument name in the function
      declaration that holds the question list (e.g.
      "multiple_choice_questions").
    questions_per_prompt: int. The number of questions to generate per API
      call. Should evenly divide `total_questions`.
    generation_config: dict. Configuration for text generation.
    tool_config: dict. Configuration for tool usage.
    model_name: str. Passed to `format_items_for_prompt`; "gemini" includes
      the generated answers in the follow-up prompt.

  Returns:
    A list of dictionaries, where each dictionary represents a generated
      question.
  """
  questions = []
  # The first turn sends the PDF together with the base prompt.
  current_prompt = [file_ref, base_prompt]

  for prompt_index in range(total_questions // questions_per_prompt):
    # After the first turn, repeat the task and ask the model not to
    # duplicate questions already in the conversation.
    if prompt_index > 0:
      current_prompt.append(
          f"{base_prompt} Make sure NOT to generate questions which are "
          f"already present above."
      )

    per_prompt_questions = call_model_and_extract(
        chat_model,
        current_prompt,
        question_type_function_name,
        generation_config=generation_config,
        tool_config=tool_config,
    )

    print(f"Generated {len(per_prompt_questions)} questions in prompt {prompt_index + 1}")

    questions.extend(per_prompt_questions)
    # Echo this batch back in the next turn so the model can see it.
    formatted_text = format_items_for_prompt(per_prompt_questions, question_type, model_name)
    current_prompt = [formatted_text]

  return questions
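To make the turn-by-turn flow of `generate_questions` concrete, the dry run below mirrors its loop with placeholder strings (the PDF and the model output are stand-ins, and `sketch_prompt_turns` is a hypothetical name). Only the first turn includes the file explicitly; later turns rely on the `ChatSession` history, which still contains the PDF and all earlier turns.

```python
def sketch_prompt_turns(total_questions: int, questions_per_prompt: int,
                        base_prompt: str = "BASE") -> list:
  """Hypothetical dry run: the prompt list sent on each turn of the loop."""
  turns = []
  current_prompt = ["<PDF file_ref>", base_prompt]
  for prompt_index in range(total_questions // questions_per_prompt):
    if prompt_index > 0:
      current_prompt.append(
          f"{base_prompt} Make sure NOT to generate questions which are "
          f"already present above.")
    # Record a copy of what would be passed to send_message on this turn.
    turns.append(list(current_prompt))
    # Stand-in for format_items_for_prompt() applied to this turn's output.
    current_prompt = [f"<questions from turn {prompt_index + 1}>"]
  return turns
```

With the MCQ settings used later (100 questions, 20 per prompt), this yields five turns: the first carries the file and base prompt, and each later turn carries the previous batch plus the "do not repeat" instruction.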


### Define the base prompt to generate multiple choice questions

**Note:** `num_mcq_per_prompt` must evenly divide `total_mcq_questions` so that the loop generates exactly the requested total with no remainder. With the values below, `100 ÷ 20 = 5` prompts.
Because each prompt asks for `num_mcq_per_prompt // 5` questions per intervention, `num_mcq_per_prompt` should also be a multiple of 5 (here, 4 questions per intervention).
If you change these values, keep both relationships.
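Both divisibility requirements can be checked before any API call. A small sketch with a hypothetical `plan_batches` helper, which applies equally to the short-answer settings later in the notebook:

```python
def plan_batches(total_questions: int, questions_per_prompt: int,
                 num_interventions: int = 5) -> tuple:
  """Hypothetical check: returns (num_prompts, questions_per_intervention).

  Raises ValueError if the counts would not divide evenly.
  """
  if total_questions % questions_per_prompt:
    raise ValueError("questions_per_prompt must evenly divide total_questions")
  if questions_per_prompt % num_interventions:
    raise ValueError("questions_per_prompt must be a multiple of num_interventions")
  return (total_questions // questions_per_prompt,
          questions_per_prompt // num_interventions)
```

For example, `plan_batches(100, 20)` returns `(5, 4)`, while `plan_batches(100, 30)` raises because 30 does not divide 100.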

In [None]:
total_mcq_questions = 100
num_mcq_per_prompt = 20
num_mcq_per_intervention = num_mcq_per_prompt // 5

mcq_base_prompt = f"""
Prompt:

Task: Generate {num_mcq_per_prompt} multiple-choice questions based on the provided document.

Guidelines:

Paraphrase: Avoid direct quotes; rephrase the text to create unique questions.
Question Structure: Make each question clear and concise, avoiding ambiguity.
Answer Choices:
  1. Provide exactly five answer choices (A, B, C, D, E) for each question.
  2. Make all options plausible, but only one correct.
  3. Avoid using "All of the above" or "None of the above" as options.
Correct Answer: Clearly indicate the correct answer as a single letter (A, B, C, D, or E) for each question.

Question Balance: Each of the following five interventions must have exactly {num_mcq_per_intervention} questions, no more and no fewer.

1. Nutritional interventions
2. Maternal and fetal assessment
3. Preventive measures
4. Interventions for common physiological symptoms
5. Health systems interventions to improve the utilization and quality of antenatal care

Example:

Intervention: Preventive measures (the intervention must be one of the five topics above)
Question: In endemic areas, when is preventive anthelminthic treatment recommended for pregnant women?
Choices:
A. During the first trimester
B. Throughout pregnancy
C. During the second trimester
D. During the second and third trimesters
E. During the third trimester
Correct Answer: D

"""

print(mcq_base_prompt)

### Function calling

Function calling makes generative models more reliable at producing structured data. Here, the API's function calling feature is used to define a strict schema for the output, which makes the results more consistent and predictable.

With function calling, your function and its parameters are described to the API as a `genai.protos.FunctionDeclaration`.

In basic cases, the SDK can build a `genai.protos.FunctionDeclaration` automatically from a Python function's parameter type annotations. Defining the declaration explicitly, as done below, is preferable wherever possible because it gives precise control over nested objects and required fields.

#### MCQ single question schema

This schema defines the structure of a single multiple-choice question (MCQ).  
Each MCQ has the following fields:

* `Intervention`: The intervention type (one of the five topics) the question belongs to.
* `Question`: The question itself.
* `A` to `E`: Five distinct answer options.
* `Choice`: The letter of the correct option.

All fields are required and stored as strings. This standardized format keeps the generated MCQs consistent, which makes them easier to validate and process later.

In [None]:
mcq_schema = genai.protos.Schema(
    type = genai.protos.Type.OBJECT,
    properties = {
        'Intervention':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Question':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'A': genai.protos.Schema(type=genai.protos.Type.STRING),
        'B': genai.protos.Schema(type=genai.protos.Type.STRING),
        'C': genai.protos.Schema(type=genai.protos.Type.STRING),
        'D': genai.protos.Schema(type=genai.protos.Type.STRING),
        'E': genai.protos.Schema(type=genai.protos.Type.STRING),
        'Choice': genai.protos.Schema(type=genai.protos.Type.STRING)
    },
    required=['Intervention', 'Question', 'A', 'B', 'C', 'D', 'E', 'Choice']
)
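The schema marks every field as required, but the `STRING` type alone does not force `Choice` to be a single option letter, as the prompt requests. A minimal post-check, using a hypothetical `is_valid_mcq` helper, can filter out malformed items after generation:

```python
REQUIRED_MCQ_FIELDS = ("Intervention", "Question", "A", "B", "C", "D", "E", "Choice")
VALID_CHOICES = {"A", "B", "C", "D", "E"}


def is_valid_mcq(item: dict) -> bool:
  """Hypothetical post-check for one extracted MCQ dict."""
  # Every required field must be present and non-blank.
  if any(not str(item.get(field, "")).strip() for field in REQUIRED_MCQ_FIELDS):
    return False
  # Accept a bare letter ("D") or a letter followed by ".", ")" or ":"
  # (e.g. "D. During ..."), but not text that merely starts with a letter.
  choice = str(item["Choice"]).strip().upper()
  if choice[:1] not in VALID_CHOICES:
    return False
  return len(choice) == 1 or choice[1] in ".):"
```

After generation you could keep only well-formed items with `[q for q in mcq_questions if is_valid_mcq(q)]`.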

#### Array Schema

In [None]:
mcqs_schema = genai.protos.Schema(
    type=genai.protos.Type.ARRAY,
    items=mcq_schema
)

#### Define as function declaration

In [None]:
mcq_database = genai.protos.FunctionDeclaration(
    name="mcq_database",
    description=textwrap.dedent("""\
        Adds interventions and questions with multiple choices and its answers to the database.
        """),
    parameters=genai.protos.Schema(
        type=genai.protos.Type.OBJECT,
        properties = {
            'multiple_choice_questions': mcqs_schema,
        }
    )
)

### Define the model with the function declaration

In [None]:
model = model_def(model_name, safety_settings, tools=[mcq_database])

### Create a chat session

The Gemini API enables you to have freeform conversations across multiple turns. The `ChatSession` class will store the conversation history for multi-turn interactions.

In [None]:
chat = model.start_chat(history=[])

### Generate text in chat session


This code generates a series of multiple-choice questions (MCQs).

1. The code generates MCQs, aiming for a specified total number. It uses a base prompt and a file reference as initial input.

2. In each iteration, it calls a Gemini model to generate questions, extracts the structured data, and adds it to a list. To avoid repetition, it includes the previously generated questions in the next prompt.

3. It also handles errors during the API requests, so the process continues even if individual calls fail.

In [None]:
mcq_questions = generate_questions(
    chat_model=chat,
    file_ref=file_ref,
    base_prompt=mcq_base_prompt,
    total_questions=total_mcq_questions,
    question_type="mcq",
    question_type_function_name="multiple_choice_questions",
    questions_per_prompt=num_mcq_per_prompt,
    generation_config=generation_config,
    tool_config={"function_calling_config": {"mode": "ANY"}},
    model_name="gemini"
)

### Convert the MCQs into a DataFrame

In [None]:
mcq_questions[0]

In [None]:
mcq_df = pd.DataFrame(mcq_questions, columns=["Intervention", "Question", "A", "B", "C", "D", "E", "Choice"])
mcq_df
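The prompt asks the model not to repeat questions and to balance them across the five interventions, but neither is enforced. A quick sketch (hypothetical `summarize_questions` helper) that drops repeated questions and reports the per-intervention counts before the DataFrame is saved:

```python
import pandas as pd


def summarize_questions(df: pd.DataFrame) -> tuple:
  """Hypothetical cleanup: drop repeated questions, count per intervention."""
  # Compare questions case- and whitespace-insensitively; keep the first copy.
  normalized = df["Question"].str.strip().str.lower()
  deduped = df.loc[~normalized.duplicated()].reset_index(drop=True)
  # Counts per intervention show whether the requested balance actually held.
  counts = deduped["Intervention"].value_counts()
  return deduped, counts
```

For example, `mcq_df, mcq_counts = summarize_questions(mcq_df)` before writing the CSV; the same helper works on the short-answer DataFrame, which has the same `Question` and `Intervention` columns.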

In [None]:
# Write the dataframe to a CSV file
mcq_df.to_csv("mcq_questions.csv", index=False)

## Generate a set of short-answer questions

The questions created from the World Health Organization's guidelines are divided equally among the five main categories of recommendations for pregnancy care:

1. Nutritional interventions
2. Maternal and fetal assessment
3. Preventive measures
4. Interventions for common physiological symptoms
5. Health systems interventions to improve the utilization and quality of antenatal care


### Define the base prompt for the task

**Note:** `num_sqa_per_prompt` must evenly divide `total_number_sqa_questions` so that the loop generates exactly the requested total with no remainder. With the values below, `100 ÷ 20 = 5` prompts.
Because each prompt asks for `num_sqa_per_prompt // 5` questions per intervention, `num_sqa_per_prompt` should also be a multiple of 5.
If you change these values, keep both relationships.

In [None]:
total_number_sqa_questions = 100
num_sqa_per_prompt = 20
num_sqa_per_intervention = num_sqa_per_prompt // 5

sqa_base_prompt = f"""
Please generate {num_sqa_per_prompt} short-answer questions from the provided document.
Each answer should be at least 20 words long.
Paraphrase the original text rather than quoting it directly.

We have 5 intervention topics:
1. Nutritional interventions
2. Maternal and fetal assessment
3. Preventive measures
4. Interventions for common physiological symptoms
5. Health systems interventions to improve the utilization and quality of antenatal care

Each topic must have exactly {num_sqa_per_intervention} questions, no more and no fewer.
For each item, give:
- the intervention topic (one of the five above)
- the question
- the answer

For example, each short-answer item should look like this:

Intervention: Interventions for common physiological symptoms
Question: What advice can help manage varicose veins during pregnancy?
Answer: Women should be advised to wear compression stockings as prescribed by their healthcare provider and to avoid standing or sitting for long periods.
"""

print(sqa_base_prompt)

### Function calling

#### Single Short Question Answer Schema

This schema defines the structure for short answer questions (SQAs). It includes the following fields:

* Intervention: The topic or subject matter related to the question.
* Question: The actual short answer question.
* Answer: The correct answer to the question.


In [None]:
sqa = genai.protos.Schema(
    type = genai.protos.Type.OBJECT,
    properties = {
        'Intervention':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Question':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Answer': genai.protos.Schema(type=genai.protos.Type.STRING)
    },
    required=['Intervention', 'Question', 'Answer']
)

#### Declare the list of SQAs as an `ARRAY` type

In [None]:
sqas = genai.protos.Schema(
    type=genai.protos.Type.ARRAY,
    items=sqa
)

#### Function Declaration

In [None]:
sqa_database = genai.protos.FunctionDeclaration(
    name="sqa_database",
    description=textwrap.dedent("""\
        Adds interventions, questions and its short answers to the database.
        """),
    parameters=genai.protos.Schema(
        type=genai.protos.Type.OBJECT,
        properties = {
            'short_answer_questions': sqas,
        }
    )
)

### Define the model with the function declaration

In [None]:
model = model_def(model_name, safety_settings, tools=[sqa_database])

### Create a chat session

The Gemini API enables you to have freeform conversations across multiple turns. The `ChatSession` class will store the conversation history for multi-turn interactions.

In [None]:
chat = model.start_chat(history=[])

### Generate text in chat session


This code generates a series of short-answer questions (SQAs).

1. The code generates SQAs, aiming for a specified total number. It uses a base prompt and a file reference as initial input.

2. In each iteration, it calls a Gemini model to generate questions, extracts the structured data, and adds it to a list. To avoid repetition, it includes the previously generated questions in the next prompt.

3. It also handles errors during the API requests, so the process continues even if individual calls fail.

In [None]:
sqa_questions = generate_questions(
    chat_model=chat,
    file_ref=file_ref,
    base_prompt=sqa_base_prompt,
    total_questions=total_number_sqa_questions,
    question_type_function_name="short_answer_questions",
    question_type="sqa",
    questions_per_prompt=num_sqa_per_prompt,
    generation_config=generation_config,
    tool_config={"function_calling_config": {"mode": "ANY"}},
    model_name = 'gemini'
)

### Convert to DataFrame

In [None]:
sqa_df = pd.DataFrame(sqa_questions)
sqa_df
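The SQA prompt asks for answers of at least 20 words, but the schema cannot enforce a length. A quick sketch (hypothetical `flag_short_answers` helper) that lists the rows falling short, using a simple whitespace word count:

```python
import pandas as pd


def flag_short_answers(df: pd.DataFrame, min_words: int = 20) -> pd.DataFrame:
  """Hypothetical check: rows whose Answer has fewer than `min_words` words."""
  # Whitespace word count; fillna("") makes missing answers count as 0 words.
  word_counts = df["Answer"].fillna("").str.split().str.len()
  return df.loc[word_counts < min_words]
```

For example, `flag_short_answers(sqa_df)` returns the rows you might regenerate or drop before writing the CSV.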

In [None]:
# Write the dataframe to a CSV file
sqa_df.to_csv("sqa_questions.csv", index=False)