<a href="https://colab.research.google.com/github/jingtang10/odhs-genai/blob/main/%5BODHS_D2_T1_05A_CL2_2%5D_Gemma_WHO_ANC_Guidelines_Eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Evaluation of Questions and Answers with Gemma

## Install required libraries

In [None]:
! pip install -qU google-generativeai
! pip install -qU transformers
! pip install 'accelerate>=0.26.0'



In [None]:
import os
import textwrap
import pandas as pd

import google.generativeai as genai

from typing import Union, List

from google.api_core import retry, exceptions

## Import required libraries

In [None]:
import torch
import re

from transformers import AutoTokenizer, AutoModelForCausalLM

## Setup

### Select colab runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:


1.   In the upper-right of the Colab window, select **▾ (Additional connection options).**
2.   Select **Change runtime type.**
3.   Under **Hardware accelerator**, select **T4 GPU**.



### Gemma setup


**Before you dive into the tutorial, let's get you set up with Gemma:**

1. **Hugging Face Account:** If you don't already have one, create a free Hugging Face account.
2. **Gemma Model Access:** Head over to the [Gemma model page](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.
3. **Colab with Gemma Power:** For this tutorial, you'll need a Colab runtime with enough resources to handle the Gemma 2B model. Choose an appropriate runtime when starting your Colab session.
4. **Hugging Face Token:** Generate a Hugging Face access token (preferably with write permission) from your account settings. You'll need this token later in the tutorial.


**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**

### Configure your Hugging Face token


Add your Hugging Face token to the Colab Secrets manager to securely store it.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel.
2. Create a new secret with the name HF_TOKEN.
3. Copy/paste your token key into the Value input box of HF_TOKEN.
4. Toggle the button on the left to allow notebook access to the secret.

In [None]:
# import os
from google.colab import userdata

# Set the Hugging Face token as an environment variable
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

### Logging into Hugging Face Hub

Next, log in to the Hugging Face Hub with your access token so that you can download the Gemma model.

In [None]:
from huggingface_hub import login

login(os.environ["HF_TOKEN"])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## Select the model

In [None]:
gemma_model_name = "google/gemma-2-2b-it" #@param ["google/gemma-2-2b-it","google/gemma-2-9b-it"]

Every prompt you send to the model includes parameters that control how the model generates responses.

- `temperature`: Controls the randomness of the model's output. A higher value (closer to 1) makes the output more diverse, while a lower value (closer to 0) makes it more focused and deterministic.
- `top_p`: Implements nucleus sampling, which selects the smallest set of tokens whose cumulative probability is greater than or equal to this value (e.g., 0.95), promoting diversity by considering multiple tokens.
- `max_new_tokens`: The maximum number of new tokens to generate in the model's output.
- `repetition_penalty`: Penalizes the model for repeating tokens that have already been generated. A value above 1 discourages repetition.
- `no_repeat_ngram_size`: Prevents the model from repeating phrases of this size (e.g., 2-gram means no repeated pairs of words).
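
To make `temperature` and `top_p` more concrete, here is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering on a toy set of scores. The function names and the toy logits are illustrative assumptions; this is not how the `transformers` library implements sampling internally.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities; a lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of tokens whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    # Renormalize so the kept probabilities sum to 1
    return {i: probs[i] / total for i in kept}


def sample_token(logits, temperature=0.7, top_p=0.95, rng=random):
    """Sample one token index from the temperature-scaled, top-p-filtered distribution."""
    nucleus = top_p_filter(softmax_with_temperature(logits, temperature), top_p)
    tokens = list(nucleus)
    weights = [nucleus[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Hypothetical scores for a four-token vocabulary
toy_logits = [2.0, 1.0, 0.1, -1.0]
print(top_p_filter(softmax_with_temperature(toy_logits, 1.0), 0.9))
print(sample_token(toy_logits))
```

With these toy scores, a `top_p` of 0.9 keeps the three most likely tokens and drops the least likely one, while lowering the temperature concentrates more probability on the top token.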


In [None]:
temperature = 0.7 #@param {type:"slider", min:0, max:1, step:0.1}
top_p = 0.95 #@param {type:"slider", min:0, max:1, step:0.05}
max_new_tokens = 256 #@param {type:"integer"}
repetition_penalty = 1.2 #@param {type:"slider", min:1, max:2, step:0.1}
no_repeat_ngram_size = 2 #@param {type:"integer"}


### Utility Functions

In [None]:
def model(model_name:str):
  """
    Loads a tokenizer and model for a specified pre-trained causal language model.

    Args:
        model_name (str): The name or path of the pre-trained model to load.

    Returns:
        tuple: A tuple containing:
            - tokenizer: The tokenizer corresponding to the specified model,
              loaded using AutoTokenizer.
            - model: The pre-trained causal language model loaded using AutoModelForCausalLM.

    The model is loaded with automatic device mapping (i.e., device_map="auto")
    and uses bfloat16 precision (torch_dtype=torch.bfloat16).
    """
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(
          model_name,
          device_map="auto",
          torch_dtype=torch.bfloat16,)
  return tokenizer, model

In [None]:
def generate(
    model,
    input_ids
)-> torch.Tensor:
  """
    Generates text using a pre-trained language model with customizable parameters.

    Args:
        model: The pre-trained language model used for text generation.
        input_ids: The input tokens for the model, typically the tokenized input text.

    Returns:
        outputs: The generated text tokens from the model based on the given parameters.
    """
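  # temperature and top_p only take effect when sampling is enabled;
  # fall back to greedy decoding if temperature is set to 0.
  outputs = model.generate(**input_ids,
                           max_new_tokens=max_new_tokens,
                           do_sample=temperature > 0,
                           temperature=temperature,
                           top_p=top_p,
                           repetition_penalty=repetition_penalty,
                           no_repeat_ngram_size=no_repeat_ngram_size)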
  return outputs




In [None]:
def get_option_or_answer(
    output,
    item_type: str
) -> str:
    """
    Extracts and returns the option or answer from the provided text output based on the specified item type.

    Args:
        output (str): The output text containing the answer, assumed to have the format "Answer: <content>".
        item_type (str): The type of item, either "mcq" for multiple-choice questions or "sqa" for short-answer questions.

    Returns:
        str: The extracted option (first character after "Answer:") if `item_type` is "mcq".
             For "sqa", returns the full answer after "Answer:", with extra whitespace and asterisks removed.
             If no answer is found for "mcq", returns "None".
    """
    # Text after the first "Answer:" marker; empty if the marker is absent.
    parts = output.split("Answer:")
    text = parts[1].replace("*", "").strip() if len(parts) > 1 else ""

    if item_type == "mcq":
        # The option letter is the first character; "None" if nothing was generated.
        return text[0] if text else "None"
    elif item_type == "sqa":
        # Collapse newlines and repeated whitespace into single spaces.
        return " ".join(text.split())
    raise ValueError(f"Unsupported item_type: {item_type!r}")


In [None]:
def merge_df(sqa_eval_df: pd.DataFrame, sqa_eval_gemma_df: pd.DataFrame) -> pd.DataFrame:
    """
    Merges two DataFrames containing evaluation results from different models (Gemini and Gemma 2b)
    based on a common 'Question' column and returns the merged DataFrame.

    Parameters:
    ----------
    sqa_eval_df : pd.DataFrame
        The DataFrame containing evaluation data for the Gemini model.
        Expected to have columns: 'Question', 'Grade', and 'Answer'.

    sqa_eval_gemma_df : pd.DataFrame
        The DataFrame containing evaluation data for the Gemma 2b model.
        Expected to have columns: 'Question', 'Grade', 'Answer', 'Context', and 'Intervention'.
        Columns 'Context' and 'Intervention' will be dropped in the merged DataFrame.

    Returns:
    -------
    pd.DataFrame
        A DataFrame resulting from an inner merge on 'Question' between the two input DataFrames.
        The output DataFrame will rename the overlapping columns as follows:
            - 'Grade_x' to 'Grade_gemini'
            - 'Answer_x' to 'Answer_gemini'
            - 'Grade_y' to 'Grade_gemma_2b'
            - 'Answer_y' to 'Answer_gemma_2b'
    """
    # Drop on a copy so the caller's DataFrame is not modified
    sqa_eval_gemma_df = sqa_eval_gemma_df.drop(columns=['Context', 'Intervention'])
    result = pd.merge(sqa_eval_df, sqa_eval_gemma_df, on='Question', how='inner')
    result = result.rename(columns={
        "Grade_x": "Grade_gemini",
        "Answer_x": "Answer_gemini",
        "Grade_y": "Grade_gemma_2b",
        "Answer_y": "Answer_gemma_2b"
    })
    return result


In [None]:
def get_metrics(result):
  """Prints how Gemma's grades compare with Gemini's for each question."""
  same_grade = 0
  gemma_better = 0
  gemma_worse = 0
  # Questions grouped by how many grade points Gemma fell short of Gemini
  grade_difference = {}
  for idx in range(len(result['Grade_gemini'])):
    gemini_grade = result['Grade_gemini'][idx]
    gemma_grade = result['Grade_gemma_2b'][idx]
    # Gemma and Gemini got the same grade
    if gemini_grade == gemma_grade:
      same_grade += 1
    # Gemma got a better grade than Gemini
    elif gemini_grade < gemma_grade:
      gemma_better += 1
    # Gemma got a worse grade than Gemini
    else:
      gemma_worse += 1
      diff = gemini_grade - gemma_grade
      grade_difference.setdefault(diff, []).append(result['Question'][idx])

  print(f"Number of questions where Gemma and Gemini got the same grade: {same_grade}")
  print(f"Number of questions where Gemma got a better grade than Gemini: {gemma_better}")
  print(f"Number of questions where Gemma got a worse grade than Gemini: {gemma_worse}")

  question_counts = {diff: len(questions) for diff, questions in grade_difference.items()}
  print("Grade differences and counts between Gemma and Gemini for questions where Gemma underperformed.")
  print(question_counts)
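
As a rough illustration of the comparison logic, the same tallies can be sketched on plain Python lists with `collections.Counter`. The function name `compare_grades` and the grade values below are made up for the example and are not results from this notebook.

```python
from collections import Counter


def compare_grades(gemini_grades, gemma_grades):
    """Tally same/better/worse outcomes and the size of each gap where Gemma is worse."""
    outcome = Counter()
    gaps = Counter()
    for gemini_grade, gemma_grade in zip(gemini_grades, gemma_grades):
        if gemma_grade == gemini_grade:
            outcome["same"] += 1
        elif gemma_grade > gemini_grade:
            outcome["gemma_better"] += 1
        else:
            outcome["gemma_worse"] += 1
            gaps[gemini_grade - gemma_grade] += 1
    return outcome, gaps


# Hypothetical grades for four questions
outcome, gaps = compare_grades([5, 4, 3, 5], [5, 5, 1, 4])
print(dict(outcome))
print(dict(gaps))
```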

In [None]:
mcq_base_prompt = textwrap.dedent('''\
Your task is to Carefully read the question and analyze each option based on your knowledge as a WHO expert.
Select the most appropriate option considering the context of antenatal care.

Instructions:
Choose the correct option (A, B, C, D, or E).
Please strictly provide only the option which is correct.
Don't give any "explanations", "option" and "asterisks" in the output.\n\n
''')
print(mcq_base_prompt)

Your task is to Carefully read the question and analyze each option based on your knowledge as a WHO expert.
Select the most appropriate option considering the context of antenatal care.

Instructions:
Choose the correct option (A, B, C, D, or E).
Please strictly provide only the option which is correct.
Don't give any "explanations", "option" and "asterisks" in the output.





## Download the selected model

In [None]:
tokenizer, model = model(gemma_model_name)


In [None]:
# Selects the device based on the availability

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def model_def(model_name, safety_settings, tools=None):
  model = genai.GenerativeModel(
    model_name=model_name,
    safety_settings=safety_settings,
    tools=tools
  )
  return model

In [None]:
def format_items_for_prompt(items: list, item_type: str,model_name: str) -> str:
  """
  Formats a list of items (questions, etc.) into a string for the prompt.
  Handles different item types (e.g., "mcq", "sqa").
  """
  formatted_text = ""
  for item in items:
    formatted_text += f"Question: {item['Question']}\n"
    formatted_text += f"{item.get('Intervention', '')}\n"
    if item_type == "mcq":
      if model_name == 'gemini':
        formatted_text += "".join(
            [f"{option}: {item[option]}\n" for option in ["A", "B", "C", "D", "E"]]
        )
        formatted_text += f"Choice: {item.get('Choice', '')}\n\n" # Include Choice if present
      elif model_name == "gemma":
        formatted_text += "".join(
            [f"{option}: {item[option]}\n" for option in ["A", "B", "C", "D", "E"]]
        )
        # Leave the answer blank so that the model completes it
        formatted_text += "Answer: \n"

    elif item_type == "sqa_eval":
      if model_name == 'gemini':
        formatted_text += f"Answer: {item.get('Answer', '')}\n\n"
      elif model_name == "gemma":
        formatted_text += f"Answer: {item.get('Answer_2b', '')}\n\n"

    elif item_type == "sqa":
      if model_name == 'gemini':
        formatted_text += f"Answer: {item.get('Answer', '')}\n\n"
      elif model_name == 'gemma':
        formatted_text += f"Answer: \n\n"

  return formatted_text

The following code iterates through the multiple-choice questions, generates an answer for each with the selected Gemma model, and decodes the output to extract the chosen option (A-E). Each extracted option is appended to the `choice` list; when no option is found, the string "None" is stored instead.

In [None]:
mcq_df = pd.read_csv('mcq_questions.csv')

In [None]:
choice = []
mcq_df_dict_list = mcq_df.to_dict(orient="records")
for idx in range(len(mcq_df_dict_list)):
  formatted_text = format_items_for_prompt(items = [mcq_df_dict_list[idx]],
                                           item_type = "mcq",
                                           model_name = 'gemma')
  prompt = mcq_base_prompt + formatted_text
  input_ids = tokenizer(prompt, return_tensors="pt").to(device)
  outputs = generate(model, input_ids)
  output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  # Return the option from the output
  option = get_option_or_answer(output,"mcq")
  choice.append(option)
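
The extraction step above takes the first character that follows `Answer:`. If the model adds extra words before the letter, a slightly stricter variant could search for the first standalone letter A-E instead. The helper below is a hypothetical sketch, not part of the notebook's pipeline; note that it would also pick up the article "A" at the start of a sentence.

```python
import re


def extract_option_letter(output: str) -> str:
    """Hypothetical helper: first standalone letter A-E after the first 'Answer:' marker."""
    _, marker, tail = output.partition("Answer:")
    if not marker:
        # The marker is missing, so there is nothing to extract
        return "None"
    match = re.search(r"\b([A-E])\b", tail)
    return match.group(1) if match else "None"


print(extract_option_letter("Question: ...\nAnswer: **B**"))
print(extract_option_letter("Question: ...\nAnswer: The answer is C."))
```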



#### Compare the choices selected by Gemma with the Gemini choices

In [None]:
count = 0
for i, mcq_question in mcq_df.iterrows():
  if choice[i].lower() == mcq_question['Choice'].lower():
    count+=1

In [None]:
print(f"number of the correct answers for {gemma_model_name} is {count}")

number of the correct answers for google/gemma-2-2b-it is 77


In [None]:
mcq_df['choice_2b'] = choice
mcq_df.to_csv('mcq_questions_gemma.csv',index=False)

## Generating short answers using Gemma

In [None]:
sqa_df = pd.read_csv('sqa_questions.csv')

In [None]:
eval_base_prompt =textwrap.dedent("""\
Your task is to Carefully read the question and generate short answers based on your knowledge as a WHO expert.

Instructions for short answers:
The number words should be strictly between 10 to 20 words.
Refrain from giving single or two word answers
Please strictly provide only the answer which is correct.
""")
print(eval_base_prompt)

Your task is to Carefully read the question and generate short answers based on your knowledge as a WHO expert.

Instructions for short answers:
The number words should be strictly between 10 to 20 words.
Refrain from giving single or two word answers
Please strictly provide only the answer which is correct.



The code below takes each question and intervention from the dataset, generates a short answer with the Gemma model, cleans the output, and stores the extracted answers in the `final_answer` list.

In [None]:
final_answer = []
sqa_df_dict_list = sqa_df.to_dict(orient="records")
for idx in range(len(sqa_df_dict_list)):
  formatted_text = format_items_for_prompt(items = [sqa_df_dict_list[idx]],
                                           item_type = "sqa",
                                           model_name = 'gemma')
  prompt = eval_base_prompt + formatted_text
  # print(prompt)
  input_ids = tokenizer(prompt, return_tensors="pt").to(device)
  outputs = generate(model,input_ids)
  output = tokenizer.decode(outputs[0], skip_special_tokens=True)

  # Return the answer text from the output
  answer = get_option_or_answer(output,"sqa")
  # print(answer)
  final_answer.append(answer)



In [None]:
sqa_df['Answer_2b'] = final_answer

In [None]:
sqa_df.to_csv('sqa_questions.csv',index=False)
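
The prompt asks for answers of 10 to 20 words, but nothing in the pipeline enforces this. A quick sanity check could look like the sketch below; the helper name `within_word_limit` and the placeholder answers are assumptions made for illustration.

```python
def within_word_limit(answer: str, low: int = 10, high: int = 20) -> bool:
    """Return True if the answer has between `low` and `high` whitespace-separated words."""
    return low <= len(answer.split()) <= high


# Placeholder answers of 3, 12 and 25 words
sample_answers = [
    " ".join(["word"] * 3),
    " ".join(["word"] * 12),
    " ".join(["word"] * 25),
]
in_range = [within_word_limit(answer) for answer in sample_answers]
print(in_range)
```

Applied to the `final_answer` list, the same check would show how many generated answers follow the length instruction.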

# Evaluation of Gemma SQA

## Evaluation functions

## Set up your API key


To use the Gemini API, you'll need an API key. Store your API key in a Colab Secret named `GOOGLE_API_KEY`.
If you don't have an API key or need help creating a Colab Secret, see the [Authentication](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Authentication.ipynb) guide.

In [None]:
# passing the API key
try:
  from google.colab import userdata
  GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
  genai.configure(api_key=GOOGLE_API_KEY)
except ImportError:
  pass

## Select a suitable Gemini model

The Gemini API offers different models that are optimized for specific use cases. Here's a [brief overview of Gemini variants](https://ai.google.dev/gemini-api/docs/models/gemini) that are available.
To ensure your prompts work correctly, check the input and output token limits. Make sure your document and desired output fit within these limits.

We will be using the [Gemini 1.5 Pro](https://ai.google.dev/gemini-api/docs/models/gemini#gemini-1.5-pro) model in this tutorial to evaluate the questions and answers.

In [None]:
model_name = "gemini-1.5-pro-latest" #@param ["gemini-1.5-pro-latest", "gemini-1.5-pro", "gemini-1.5-flash-latest", "gemini-1.5-flash"]
model_info = genai.get_model(f'models/{model_name}')

print(f"Model: {model_name}")
print(f"Input Token Limit: {model_info.input_token_limit}")
print(f"Output Token Limit: {model_info.output_token_limit}")

Model: gemini-1.5-pro-latest
Input Token Limit: 2000000
Output Token Limit: 8192


## Upload WHO's antenatal care guidelines file

Use the [`upload_file`](https://ai.google.dev/gemini-api/docs/document-processing?lang=python#upload-document) API to temporarily store the PDF file of WHO's antenatal care guidelines. This process produces a file reference that can be used to prompt a model.

In [None]:
file_path = "9789241549912-eng.pdf" #@param {type:"string"}
display_name = "WHO recommendations on antenatal care for a positive pregnancy experience"

pdf_file = genai.upload_file(
    path=file_path,
    display_name=display_name,
)

file_ref = genai.get_file(name=pdf_file.name)

In [None]:
HARM_CATEGORY_DANGEROUS = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_HARASSMENT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_HATE_SPEECH = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_SEXUALLY_EXPLICIT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]
HARM_CATEGORY_DANGEROUS_CONTENT = "BLOCK_NONE" # @param ["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_LOW_AND_ABOVE", "HARM_BLOCK_THRESHOLD_UNSPECIFIED"]

safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": HARM_CATEGORY_DANGEROUS,
        },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": HARM_CATEGORY_HARASSMENT,
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": HARM_CATEGORY_HATE_SPEECH,
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": HARM_CATEGORY_SEXUALLY_EXPLICIT,
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": HARM_CATEGORY_DANGEROUS_CONTENT,
    },
]

## Configure text generation


Every prompt you send to the model includes [parameters](https://ai.google.dev/gemini-api/docs/models/generative-models#model-parameters) that control how the model generates responses. You can use [genai.GenerationConfig](https://ai.google.dev/api/generate-content#generationconfig) to configure these parameters. If you don't configure the parameters, the model uses default options, which can vary by model.

In [None]:
temperature = 1 #@param {type:"slider", min:0, max:1, step:0.1}
top_p = 0.95 #@param {type:"slider", min:0, max:1, step:0.05}
top_k = 64 #@param {type:"integer"}
max_output_tokens = 8192 #@param {type:"integer"}

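# Build the config from the form values above so that changing a slider takes effect
generation_config = {
    "temperature": temperature,
    "top_p": top_p,
    "top_k": top_k,
    "max_output_tokens": max_output_tokens,
}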

## Evaluate the generated MCQ questions with Gemini 1.5 model

In this section, we evaluate the previously generated multiple-choice questions with the Gemini model and identify the context in the guidelines that was used to frame each question.

In [None]:
def call_model_and_extract(
    chat_model: genai.GenerativeModel,
    prompt: list,
    function_name: str,
    generation_config: dict = None,
    tool_config: dict = None,
) -> list:
  """Calls the language model and extracts the results from the function call."""

  try:
    response = generate_text(
        chat_model,
        prompt,
        generation_config=generation_config,
        tool_config=tool_config,
    )

    if response.candidates[0].content.parts[0].function_call:
      function_call = response.candidates[0].content.parts[0].function_call
      extracted_results = type(function_call).to_dict(function_call)["args"][
          function_name
      ]
      return extracted_results
    else:
      return []  # Return empty list if no function call


  except Exception as e:
    print(f"Error calling model: {e}")
    return []  # Return empty list on error

In [None]:
def eval_questions(
    chat_model: genai.GenerativeModel,
    file_ref: genai.types.file_types.File,
    base_prompt: str,
    questions_dict_list: list,
    question_type: str,
    eval_type_function_name: str,
    eval_per_prompt: int,
    generation_config: dict = None,
    tool_config: dict = None,
    model_name: str = "gemini"
):
  evaluated_questions = []
  current_prompt = [file_ref, base_prompt]

  for idx in range(0, len(questions_dict_list), eval_per_prompt):
    formatted_text = format_items_for_prompt(
        questions_dict_list[idx:idx+eval_per_prompt],
        question_type,
        model_name
    )
    current_prompt.append(formatted_text)

    per_prompt_evaluated_questions = call_model_and_extract(
        chat_model,
        current_prompt,
        eval_type_function_name,
        generation_config=generation_config,
        tool_config=tool_config,
    )
    print(f"Evaluated {len(per_prompt_evaluated_questions)} questions in prompt {idx//eval_per_prompt+1}")

    evaluated_questions.extend(per_prompt_evaluated_questions)
    current_prompt = []

  return evaluated_questions


### Base Prompt

In [None]:
eval_mcq_base_prompt = textwrap.dedent("""\
Task: Evaluate the provided multiple-choice questions using above extracted text corpus based on the following criteria:

Accuracy: Ensure the correct response aligns with the information presented in the provided text corpus.
Relevance: Verify that the question and answer choices are directly related to the stated intervention topic.
Clarity: Check if the question and answer choices are clear, concise, and avoid ambiguity.
Consistency: Ensure that the correct response is consistent with other relevant information in the provided text corpus.

Evaluation Format:

Question: [Question text]
Intervention Topic: [Intervention topic]
Correct Answer: [Correct option]
Context: [Page number or section reference where the correct answer can be found in the provided text corpus]

Example:

Question: According to the WHO recommendations, what is the recommended daily intake of calcium for pregnant women?
Intervention Topic: Nutritional interventions
Correct Answer: B
Context: Page 72, Section 3.1 of the WHO Guidelines for the Prevention and Management of Gestational Diabetes Mellitus
Note: To ensure accurate and relevant evaluations, please provide the specific text corpus that the questions are based on.
This will allow for a more precise assessment of the accuracy, relevance, clarity, and consistency of the questions and answers.
"""
)
print(eval_mcq_base_prompt)

Task: Evaluate the provided multiple-choice questions using above extracted text corpus based on the following criteria:

Accuracy: Ensure the correct response aligns with the information presented in the provided text corpus.
Relevance: Verify that the question and answer choices are directly related to the stated intervention topic.
Clarity: Check if the question and answer choices are clear, concise, and avoid ambiguity.
Consistency: Ensure that the correct response is consistent with other relevant information in the provided text corpus.

Evaluation Format:

Question: [Question text]
Intervention Topic: [Intervention topic]
Correct Answer: [Correct option]
Context: [Page number or section reference where the correct answer can be found in the provided text corpus]

Example:

Question: According to the WHO recommendations, what is the recommended daily intake of calcium for pregnant women?
Intervention Topic: Nutritional interventions
Correct Answer: B
Context: Page 72, Section 3.1

### Function Calling

#### Single Eval Schema

This schema defines the structure for evaluating multiple choice questions. It includes the following fields:

* Intervention: The topic or subject matter related to the question.
* Question: The actual multiple-choice question.
* Choice: The chosen option or answer for the question.
* Correct: A numeric flag indicating whether the chosen answer is correct (1) or incorrect (0).
* Context: A reference or source where the correct answer can be found, such as a page number or section in a document.
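
To illustrate what a record matching these fields looks like, the sketch below checks a plain dictionary against the field names and types listed above. The helper `missing_or_mistyped` is hypothetical and independent of the `genai.protos.Schema` objects used in this notebook; the sample values are taken from the first row of the evaluation output shown later.

```python
# Field names from the list above, mapped to the Python types we expect
REQUIRED_MCQ_FIELDS = {
    "Intervention": str,
    "Question": str,
    "Choice": str,
    "Correct": (int, float),
    "Context": str,
}


def missing_or_mistyped(record: dict) -> list:
    """Return a list of problems found in one evaluation record (empty if it is valid)."""
    problems = []
    for field, expected_type in REQUIRED_MCQ_FIELDS.items():
        if field not in record:
            problems.append(f"missing: {field}")
        elif not isinstance(record[field], expected_type):
            problems.append(f"wrong type: {field}")
    return problems


valid_record = {
    "Intervention": "Nutritional interventions",
    "Question": "Which supplement is routinely recommended for ...",
    "Choice": "A",
    "Correct": 1.0,
    "Context": "Page 23",
}
print(missing_or_mistyped(valid_record))
print(missing_or_mistyped({"Question": "Q", "Choice": "A", "Correct": "yes"}))
```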

In [None]:
mcq_eval = genai.protos.Schema(
    type = genai.protos.Type.OBJECT,
    properties = {
        'Intervention':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Question':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Choice': genai.protos.Schema(type=genai.protos.Type.STRING),
        'Correct': genai.protos.Schema(type=genai.protos.Type.NUMBER),
        'Context': genai.protos.Schema(type=genai.protos.Type.STRING),
    },
    required=['Intervention', 'Question', 'Choice', 'Correct', 'Context']
)

#### Array Schema

In [None]:
mcq_eval_schema = genai.protos.Schema(
    type=genai.protos.Type.ARRAY,
    items=mcq_eval
)

#### SQA Eval Schema

In [None]:
sqa_eval_single_schema = genai.protos.Schema(
    type = genai.protos.Type.OBJECT,
    properties = {
        'Intervention':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Question':  genai.protos.Schema(type=genai.protos.Type.STRING),
        'Answer': genai.protos.Schema(type=genai.protos.Type.STRING),
        'Grade': genai.protos.Schema(type=genai.protos.Type.NUMBER),
        'Context': genai.protos.Schema(type=genai.protos.Type.STRING),
    },
    required=['Intervention', 'Question', 'Answer', 'Grade', 'Context']
)

#### Array Schema

In [None]:
sqa_eval_schema = genai.protos.Schema(
    type=genai.protos.Type.ARRAY,
    items=sqa_eval_single_schema
)

#### Function Declaration

In [None]:
sqa_eval_database = genai.protos.FunctionDeclaration(
    name="sqa_eval_database",
    description=textwrap.dedent("""\
        Adds interventions, questions, answers, grading and its context to the database.
        """),
    parameters=genai.protos.Schema(
        type=genai.protos.Type.OBJECT,
        properties = {
            'sqa_eval': sqa_eval_schema,
        }
    )
)

In [None]:
mcq_eval_database = genai.protos.FunctionDeclaration(
    name="mcq_eval_database",
    description=textwrap.dedent("""\
        Adds interventions, questions, answers, correctness and its context to the database.
        """),
    parameters=genai.protos.Schema(
        type=genai.protos.Type.OBJECT,
        properties = {
            'mcq_eval': mcq_eval_schema,
        }
    )
)

## Generate text

* The Gemini API's client library offers built-in retry mechanisms for handling transient errors.

* The `generate_text` function sends a message to the chat model with the given prompt and
  configurations, and returns the response. It includes retry logic to handle
  transient errors.

* For more info on error handling, take a look at the [error_handling quickstart](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Error_handling.ipynb).
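
To get a feel for the retry settings used below (`initial=5`, `maximum=10`, `multiplier=2.0`, `timeout=100`), here is a simplified sketch of an exponential backoff schedule. It is an illustration only: the function `backoff_delays` is hypothetical, it ignores the time spent on each request, and the waits applied by the actual client library may differ (for example, because of random jitter).

```python
def backoff_delays(initial=5.0, maximum=10.0, multiplier=2.0, timeout=100.0):
    """List the waits of a plain exponential backoff that fit within `timeout` seconds."""
    delays = []
    delay = initial
    elapsed = 0.0
    while elapsed + delay <= timeout:
        delays.append(delay)
        elapsed += delay
        # Grow the wait geometrically, but never beyond the cap
        delay = min(delay * multiplier, maximum)
    return delays


schedule = backoff_delays()
print(schedule)
print(f"{len(schedule)} waits, {sum(schedule)} seconds in total")
```

Under these simplifying assumptions the wait reaches the 10-second cap after the first retry and stays there until the timeout budget is used up.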

In [None]:
@retry.Retry(
    predicate=retry.if_transient_error,
    initial=5,
    maximum=10,
    multiplier=2.0,
    timeout=100,
)
def generate_text(
    chat: genai.GenerativeModel,
    prompt: Union[List[str], str],
    generation_config: dict = None,
    tool_config: dict = None
):
  """Generates text using a chat model, with retry mechanism for transient errors.

  This function sends a prompt to a chat model and returns the generated response.
  It uses a retry decorator to handle transient errors, such as network issues,
  allowing the function to automatically retry the operation multiple times
  before giving up.

  Args:
    chat: The chat session to send the prompt to, as returned by
        `GenerativeModel.start_chat()`.
    prompt: The prompt to send to the chat session. Can be a string or a list
        of parts (for example, a file reference followed by strings).
    generation_config: (Optional) A dictionary containing configuration
        parameters for the text generation process, such as temperature
        and maximum output tokens.
    tool_config: (Optional) A dictionary containing configuration
        parameters for any tools that the model might use, such as the
        function-calling mode.

  Returns:
    The response returned by `chat.send_message`.


  Raises:
    exceptions.RetryError: If the function fails to generate text after multiple
        retries due to persistent transient errors. The original exception
        that triggered the retries will be chained to the `RetryError`.
    Any other exception raised by `chat.send_message`: If the `send_message`
        method raises an exception that is not considered a transient error,
        the exception will be propagated directly without retries.
  """

  response = chat.send_message(
      prompt,
      generation_config=generation_config,
      tool_config=tool_config,
  )
  return response

### Define the model with function declaration

In [None]:
model = model_def(model_name, safety_settings, tools=[mcq_eval_database])
chat = model.start_chat(history=[])

### Load the CSV file containing the previously generated MCQs

In [None]:
mcq_questions_df = pd.read_csv("mcq_questions.csv")
mcq_questions_dict_list = mcq_questions_df.to_dict(orient="records")

### Evaluate in chat session

In [None]:
eval_num_mcq_per_prompt = 20

mcq_evaluated_questions = eval_questions(
    chat_model=chat,
    file_ref=file_ref,
    base_prompt=eval_mcq_base_prompt,
    questions_dict_list=mcq_questions_dict_list,
    question_type="mcq",
    eval_type_function_name="mcq_eval",
    eval_per_prompt=eval_num_mcq_per_prompt,
    generation_config=generation_config,
    tool_config={"function_calling_config": {"mode": "ANY"}},
    model_name="gemini",
)

Evaluated 20 questions in prompt 1
Evaluated 20 questions in prompt 2
Evaluated 20 questions in prompt 3
Evaluated 20 questions in prompt 4
Evaluated 20 questions in prompt 5
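
The output reports five prompts of 20 questions each, which matches `eval_per_prompt=20`. How `eval_questions` splits the list is not shown in this section, so the following is only a sketch of one plausible batching scheme. `chunk_questions` is a hypothetical helper, and the placeholder records stand in for `mcq_questions_dict_list`.

```python
from typing import Dict, Iterator, List


def chunk_questions(questions: List[Dict], per_prompt: int) -> Iterator[List[Dict]]:
    """Yield consecutive batches of at most `per_prompt` questions."""
    if per_prompt < 1:
        raise ValueError("per_prompt must be at least 1")
    for start in range(0, len(questions), per_prompt):
        yield questions[start:start + per_prompt]


# Placeholder records (assumption): stand-ins for mcq_questions_dict_list.
questions = [{"Question": f"Q{i}"} for i in range(100)]
batches = list(chunk_questions(questions, per_prompt=20))

for number, batch in enumerate(batches, start=1):
    print(f"Prompt {number}: {len(batch)} questions")
```

When the list length is not a multiple of the batch size, the last batch is simply shorter.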


In [None]:
mcq_eval_df = pd.DataFrame(mcq_evaluated_questions)
mcq_eval_df

| Unnamed: 0 | Question | Choice | Correct | Context | Intervention |
|---|---|---|---|---|---|
| 0 | Which supplement is routinely recommended for ... | A | 1.0 | Page 23 | Nutritional interventions |
| 1 | What does a healthy diet during pregnancy cons... | D | 1.0 | Page 14 | Nutritional interventions |
| 2 | For undernourished populations, which type of ... | B | 1.0 | Page 20 | Nutritional interventions |
| 3 | When is vitamin A supplementation recommended ... | B | 1.0 | Page 29 | Nutritional interventions |
| 4 | What is the recommended method for diagnosing ... | C | 1.0 | Page 41 | Maternal and fetal assessment |
| ... | ... | ... | ... | ... | ... |
| 95 | Which of the following may be recommended for ... | D | 1.0 | Page 74 | Interventional measures for common physiologic... |
| 96 | Which statement about midwife-led continuity o... | A | 1.0 | Page 89 | Health systems interventions |
| 97 | What is a key consideration regarding task-shi... | D | 1.0 | Page 99 | Health systems interventions |
| 98 | What is an accurate statement about group ante... | D | 1.0 | Page 91 | Health systems interventions |


In [None]:
# Save the evaluations to a CSV file
mcq_eval_df.to_csv("mcq_eval.csv", index=False)

The evaluation results show that the Gemini 1.5 Pro model answered all of the MCQs correctly.
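
A claim like this can be checked directly from the evaluated records rather than by scanning the table. The sketch below assumes each record carries a `Correct` field where `1.0` means the answer was judged correct, as in the table output. `mcq_accuracy` is a hypothetical helper and the sample records are made up for illustration.

```python
from typing import Dict, List


def mcq_accuracy(records: List[Dict]) -> float:
    """Return the fraction of records whose `Correct` field equals 1.0."""
    if not records:
        raise ValueError("no evaluated questions")
    correct = sum(1 for record in records if float(record["Correct"]) == 1.0)
    return correct / len(records)


# Made-up sample (assumption): two correct answers and one incorrect answer.
sample = [
    {"Choice": "A", "Correct": 1.0},
    {"Choice": "D", "Correct": 1.0},
    {"Choice": "B", "Correct": 0.0},
]
print(f"Accuracy: {mcq_accuracy(sample):.2%}")
```

With the notebook's data, the equivalent call would be `mcq_accuracy(mcq_evaluated_questions)`, assuming that variable is a list of dictionaries.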





In [None]:
eval_sqa_base_prompt = textwrap.dedent("""\
Task: Evaluate the provided short answer responses using above extracted text corpus based on the following criteria:

Accuracy: Ensure the response aligns with the information presented in the provided text corpus.
Relevance: Verify that the response is directly related to the stated intervention topic.
Clarity: Check if the response is clear, concise, and avoids ambiguity.
Consistency: Ensure the response is consistent with other relevant information in the provided text corpus.

Evaluation Format:

Question: [Question text]
Intervention Topic: [Intervention topic]
Response: [Short answer response]
Grade: [1-5] (1: Poor, 2: Needs Improvement, 3: Satisfactory, 4: Good, 5: Excellent)
Context: [Page number or section reference where the response is supported or contradicted in the provided text corpus]

Example:

Question: According to the WHO recommendations, what is the recommended daily intake of calcium for pregnant women?
Intervention Topic: Nutritional interventions
Response: "Pregnant women should consume 1000 mg of calcium per day."
Grade: 4
Context: Page 75, Section 3.2 of the WHO Guidelines for the Prevention and Management of Gestational Diabetes Mellitus

Note: To ensure accurate and relevant evaluations, please provide the specific text corpus that the short answer responses are based on.
This will allow for a more precise assessment of the accuracy, relevance, clarity, and consistency of the responses.
"""
)
print(eval_sqa_base_prompt)


### Evaluation of Gemma short-answer responses using Gemini 1.5

In [None]:
sqa_questions_df = pd.read_csv("sqa_questions.csv")
sqa_questions_dict_list = sqa_questions_df.to_dict(orient="records")

### Define the model with function declaration



In [None]:
model = model_def(model_name, safety_settings, tools=[sqa_eval_database])
chat = model.start_chat(history=[])

### Evaluate in chat session

In [None]:
num_sqa_eval_per_prompt = 20

In [None]:
sqa_evaluated_questions = eval_questions(
    chat_model=chat,
    file_ref=file_ref,
    base_prompt=eval_sqa_base_prompt,
    questions_dict_list=sqa_questions_dict_list,
    question_type="sqa_eval",
    eval_type_function_name="sqa_eval",
    eval_per_prompt=num_sqa_eval_per_prompt,
    generation_config=generation_config,
    tool_config={"function_calling_config": {"mode": "ANY"}},
    model_name="gemma",
)



Evaluated 20 questions in prompt 1
Evaluated 19 questions in prompt 2
Evaluated 20 questions in prompt 3
Evaluated 20 questions in prompt 4
Evaluated 20 questions in prompt 5


In [None]:
sqa_eval_df_gemma = pd.DataFrame(sqa_evaluated_questions)
sqa_eval_df_gemma

| Unnamed: 0 | Intervention | Answer | Grade | Context | Question |
|---|---|---|---|---|---|
| 0 | Nutritional interventions | A balanced intake of essential nutrients, incl... | 4.0 | Page 14: "Pregnancy requires a healthy diet t... | What constitutes a healthy diet during pregnancy? |
| 1 | Nutritional interventions | To optimize maternal health, fetal growth, and... | 5.0 | Page 15: "...optimize maternal and newborn he... | What is the goal of nutritional counseling dur... |
| 2 | Nutritional interventions | Addressing undernourishment during pregnant wo... | 4.0 | Page 14: "...maternal undernutrition is highl... | Why is addressing undernutrition during pregna... |
| 3 | Nutritional interventions | During the first trimester, second trimester a... | 5.0 | Page 20: "...balanced energy and protein diet... | When is balanced energy and protein supplement... |
| 4 | Maternal and fetal assessment | Hemoglobin level, Hematocrit level, Red blood ... | 1.0 | Page 41: "Full blood count testing...quantifi... | What are the recommended methods for diagnosin... |
| ... | ... | ... | ... | ... | ... |
| 94 | Interventions from common physiological symptoms | Compression stockings, elevation of legs, exer... | 4.0 | Page 83: Recommendation D.6 recommends these a... | What non-pharmacological approaches can help m... |
| 95 | Health systems interventions to improve the ut... | Woman-Held Case Notes encourage better communi... | 5.0 | Page 87: Recommendation E.1 and its remarks me... | Why are woman-held case notes encouraged durin... |
| 96 | Health systems interventions to improve the ut... | Midwife-Led Continuity of Care Models are most... | 4.0 | Page 89: Recommendation E.2 emphasizes the imp... | Where are midwife-led continuity of care model... |
| 97 | Health systems interventions to improve the ut... | Incentives include: Competitive salaries: Offe... | 4.0 | Page 100: Recommendation E.6 suggests consider... | What incentives can be used to address healthc... |


In [None]:
sqa_eval_df_gemma.to_csv("sqa_eval_gemma.csv", index=False)
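
Because the grades come from a model, it can be worth confirming that every record has a grade within the 1 to 5 rubric defined in the evaluation prompt. The following is a minimal sketch of such a check. `find_invalid_grades` is a hypothetical helper, the sample records are made up, and the labels are copied from the prompt's rubric.

```python
from typing import Dict, List

# Labels taken from the rubric in the evaluation prompt.
GRADE_LABELS = {
    1: "Poor",
    2: "Needs Improvement",
    3: "Satisfactory",
    4: "Good",
    5: "Excellent",
}


def find_invalid_grades(records: List[Dict]) -> List[int]:
    """Return positions of records whose Grade is missing or not a whole number in 1..5."""
    invalid = []
    for position, record in enumerate(records):
        try:
            value = float(record.get("Grade"))
        except (TypeError, ValueError):
            # Missing or non-numeric grade.
            invalid.append(position)
            continue
        if not value.is_integer() or int(value) not in GRADE_LABELS:
            invalid.append(position)
    return invalid


# Made-up sample (assumption): two valid grades, one missing, one out of range.
sample = [{"Grade": 4.0}, {"Grade": 5.0}, {"Grade": None}, {"Grade": 7}]
print(find_invalid_grades(sample))
```

An empty result means every grade can be mapped to a rubric label through `GRADE_LABELS`.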

### Merge the Gemini and Gemma SQA evaluation DataFrames

In [None]:
sqa_eval_df = pd.read_csv("sqa_eval.csv")
sqa_eval_gemma_df = pd.read_csv("sqa_eval_gemma.csv")

In [None]:
result = merge_df(sqa_eval_df, sqa_eval_gemma_df)
result.to_csv("sqa_eval_result.csv", index=False)

### Comparison of the Gemini and Gemma grades

### Comparison of the Gemma model with Gemini by mean grade

In [None]:
print(f"The mean of the grades is gemini {result['Grade_gemini'].agg('mean')}")

The mean of the grades is gemini 4.99


In [None]:
print(f"The mean of the grades is gemma_2b {result['Grade_gemma_2b'].agg('mean')}")

The mean of the grades is gemma_2b 3.56


### Comparison of the Gemma model with Gemini by count

In [None]:
get_metrics(result)

The number of questions gemma and gemini got same grades is 20
The number of questions gemma got better grades compared gemini 0
The number of questions gemma got worst grades compared to gemini 80
Grade differences and counts between Gemma and Gemini for questions where Gemma underperformed.
{1.0: 47, 4.0: 10, 2.0: 13, 3.0: 10}
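
The counts and the two means reported above can be cross-checked against each other. The sketch below assumes the dictionary keys are grade differences (Gemini grade minus Gemma grade) for the questions where Gemma scored lower, and that both means were computed over the same set of questions. All figures are copied from the outputs in this section.

```python
import math

# Figures copied from the outputs above.
mean_gemini = 4.99
mean_gemma_2b = 3.56
same, better, worse = 20, 0, 80
diff_counts = {1.0: 47, 4.0: 10, 2.0: 13, 3.0: 10}

# The per-difference counts should add up to the "worse" total.
assert sum(diff_counts.values()) == worse

total_questions = same + better + worse
total_gap = sum(diff * count for diff, count in diff_counts.items())

# Questions with equal grades contribute a gap of zero.
mean_gap_from_counts = total_gap / total_questions
mean_gap_from_means = mean_gemini - mean_gemma_2b

print(f"questions compared:   {total_questions}")
print(f"mean gap from counts: {mean_gap_from_counts:.2f}")
print(f"mean gap from means:  {mean_gap_from_means:.2f}")
assert math.isclose(mean_gap_from_counts, mean_gap_from_means, abs_tol=0.005)
```

Both routes give a mean gap of 1.43 grade points over 100 questions, so the histogram and the reported means are mutually consistent under these assumptions.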
