In [1]:
!pip install google-genai datasets pandas tqdm



In [21]:
from google import genai
from google.genai import types
import os
import time
from datasets import load_dataset
import pandas as pd
from tqdm.notebook import tqdm
import json
import random

In [8]:
# --- 1. Configure API Key ---
# The key is looked up under the name 'GOOGLE_API_KEY' via Colab's `userdata`
# helper, so it never has to be written into the notebook itself.
from google.colab import userdata

api_key = userdata.get('GOOGLE_API_KEY')
if api_key:
    # Client object used by every later Gemini request in this notebook.
    client = genai.Client(api_key=api_key)
else:
    raise ValueError("API Key not found in Colab Secrets.")

print("Gemini API Key Configured.")

Gemini API Key Configured.


In [13]:
# --- 2. Configuration ---
# Every input and output path below lives under the Drive folder, so Drive must
# already be mounted at /content/drive (uncomment the two lines below if it is not).
# from google.colab import drive
# drive.mount('/content/drive')
output_base_path = "/content/drive/MyDrive/"

# Input datasets. The training file is only used as a pool of few-shot examples.
test_dataset_path = os.path.join(output_base_path, "test_json_extraction.jsonl")
train_dataset_path = os.path.join(output_base_path, "train_json_extraction.jsonl")

# How many training examples to include in the few-shot prompt.
num_few_shot_examples = 3

# Output files: one JSONL file per prompting strategy.
output_results_zeroshot_filename = "results_gemini_zeroshot_task2.jsonl"
output_results_fewshot_filename = "results_gemini_fewshot_task2.jsonl"

output_results_zs_filepath = os.path.join(
    output_base_path, output_results_zeroshot_filename
)
output_results_fs_filepath = os.path.join(
    output_base_path, output_results_fewshot_filename
)

In [14]:
# --- 3. Load Datasets ---
# Both files are JSONL and are loaded through the `datasets` "json" builder.
# Any failure is printed and then re-raised so the notebook stops here rather
# than continuing with missing data.

print(f"Loading test dataset from: {test_dataset_path}")
try:
    test_dataset = load_dataset("json", data_files=test_dataset_path, split="train")
    print(f"Test dataset loaded with {len(test_dataset)} examples.")
    required_cols = ['formatted_prompt', 'ground_truth_json', 'schema']
    missing_cols = [col for col in required_cols if col not in test_dataset.column_names]
    if missing_cols:
        # Name the columns that are actually absent, not just the required set.
        raise ValueError(
            f"Test file missing required columns: {missing_cols} (required: {required_cols})"
        )
except Exception as e:
    print(f"Error loading test dataset: {e}")
    raise

print(f"Loading train dataset (for few-shot examples) from: {train_dataset_path}")
try:
    train_dataset = load_dataset("json", data_files=train_dataset_path, split="train")
    print(f"Train dataset loaded with {len(train_dataset)} examples.")
    # The few-shot cell reads train_dataset[index]['text']; fail fast with a clear
    # message here instead of a bare KeyError later.
    if 'text' not in train_dataset.column_names:
        raise ValueError("Train file missing required column: 'text'")
    if len(train_dataset) < num_few_shot_examples:
        print(f"Warning: Training dataset has only {len(train_dataset)} examples, less than requested {num_few_shot_examples} for few-shot.")
        # Cap the request so that sampling without replacement cannot fail.
        num_few_shot_examples = len(train_dataset)
except Exception as e:
    print(f"Error loading train dataset: {e}")
    raise

Loading test dataset from: /content/drive/MyDrive/test_json_extraction.jsonl
Test dataset loaded with 40 examples.
Loading train dataset (for few-shot examples) from: /content/drive/MyDrive/train_json_extraction.jsonl


In [15]:
# --- 4. Prepare Few-Shot Examples String ---
# Builds `few_shot_prompt_string`, the block of worked examples that is inserted
# into every few-shot prompt. Each training record's 'text' is expected to look like
#   ... [INST] ... Schema:\n```json <schema> ``` ... Text:\n''' <text> ''' ... [/INST] <json> </s>
# Records that do not match this layout are skipped with a warning.
print(f"Selecting {num_few_shot_examples} few-shot examples...")
# Select random examples from the training set
random.seed(42) # for reproducibility
few_shot_indices = random.sample(range(len(train_dataset)), num_few_shot_examples)

few_shot_blocks = []
for index in few_shot_indices:
    example_text = train_dataset[index]['text']
    # Extract relevant parts from the formatted training string
    try:
        parts = example_text.split("[/INST]")
        inst_part = parts[0].split("[INST]")[1].strip()
        # Re-extract schema and text from the instruction part for clarity in the FS prompt
        schema_section = inst_part.split("Schema:\n```json")[1].split("```")[0].strip()
        text_section = inst_part.split("Text:\n'''")[1].split("'''")[0].strip()
        output_json_str = parts[1].split("</s>")[0].strip()
    except Exception as e:
        print(f"Warning: Could not parse training example at index {index} for few-shot prompt: {e}")
        continue # Skip this example if parsing fails

    few_shot_blocks.append(
        f"Example:\nSchema:\n```json\n{schema_section}\n```\n\nText:\n'''\n{text_section}\n'''\n\nOutput JSON:\n```json\n{output_json_str}\n```\n\n---\n\n"
    )

if few_shot_blocks:
    few_shot_prompt_string = "Here are some examples:\n\n" + "".join(few_shot_blocks)
else:
    # Do not announce examples that are not there.
    few_shot_prompt_string = ""

if len(few_shot_blocks) < len(few_shot_indices):
    print(
        f"Warning: Only {len(few_shot_blocks)} of {len(few_shot_indices)} selected "
        f"few-shot examples could be parsed and will appear in the prompt."
    )

Selecting 3 few-shot examples...


In [16]:
# --- 5. Define Prompting Functions for Gemini ---
# Shared instruction placed at the top of every prompt.
base_instruction = "Extract information from the following text based on the provided JSON schema. Output ONLY the valid JSON object."


def create_gemini_zeroshot_prompt(schema_text, input_text):
    """Build the zero-shot prompt.

    Args:
        schema_text: JSON schema, inserted inside a json code fence.
        input_text: Source text, inserted between triple single quotes.

    Returns:
        The instruction, schema section, text section and a trailing
        "Output JSON:" cue, separated by blank lines.
    """
    sections = [
        base_instruction,
        f"Schema:\n```json\n{schema_text}\n```",
        f"Text:\n'''\n{input_text}\n'''",
        "Output JSON:",
    ]
    return "\n\n".join(sections)


def create_gemini_fewshot_prompt(schema_text, input_text, examples_string):
    """Build the few-shot prompt.

    Same layout as the zero-shot prompt, except that `examples_string` is
    inserted verbatim after the instruction, immediately followed by a sentence
    that introduces the actual task.

    Args:
        schema_text: JSON schema, inserted inside a json code fence.
        input_text: Source text, inserted between triple single quotes.
        examples_string: Pre-formatted examples; nothing is added between it
            and the following sentence, so it must bring its own separator.

    Returns:
        The complete prompt string.
    """
    task_intro = "Now, based on the following schema and text, provide the output JSON."
    sections = [
        base_instruction,
        examples_string + task_intro,
        f"Schema:\n```json\n{schema_text}\n```",
        f"Text:\n'''\n{input_text}\n'''",
        "Output JSON:",
    ]
    return "\n\n".join(sections)

In [23]:
# --- 6. Function to Call Gemini API and Parse JSON Response ---
# model = genai.GenerativeModel('gemini-2.0-flash-latest')

def _extract_json_str(raw_response):
    """Return the first substring of `raw_response` that parses as JSON, or None.

    Candidates are tried in this order:
      1. the contents of the first ```json fenced block, if it is closed;
      2. the span from the first '{' to the last '}'.
    """
    candidates = []

    fence_open = '```json'
    start_md = raw_response.find(fence_open)
    if start_md != -1:
        end_md = raw_response.find('```', start_md + len(fence_open))
        if end_md != -1:
            candidates.append(raw_response[start_md + len(fence_open):end_md].strip())

    start_brace = raw_response.find('{')
    end_brace = raw_response.rfind('}')
    if start_brace != -1 and end_brace > start_brace:
        candidates.append(raw_response[start_brace:end_brace + 1])

    for candidate in candidates:
        try:
            json.loads(candidate)  # Validate only; the string form is what gets returned
        except (ValueError, RecursionError):
            continue
        return candidate
    return None


def get_gemini_json_output(prompt, model='gemini-2.0-flash-001', max_output_tokens=1024,
                           request_delay=4, error_delay=5):
    """Send `prompt` to Gemini and pull a JSON string out of the reply.

    Args:
        prompt: Full prompt text sent as the request contents.
        model: Gemini model name passed to the API.
        max_output_tokens: Generation cap. Replies cut off at this limit cannot
            be parsed, so raise it if warnings show truncated JSON.
        request_delay: Seconds slept after every call to stay under rate limits.
        error_delay: Extra seconds slept after a failed call.

    Returns:
        A tuple `(predicted_json_str, latency, raw_response)`:
          - predicted_json_str: validated JSON text, or None if none was found;
          - latency: seconds from function entry until the call returned or failed;
          - raw_response: stripped response text, "Error: No response" when the
            reply carried no text, or "API Error: ..." when the call raised.

    Exceptions from the API call are caught and reported, never propagated.
    """
    start_time = time.time()
    predicted_json_str = None
    raw_response = "Error: No response"
    latency = 0

    try:
        # Keep safety filtering off for this extraction task.
        harm_categories = (
            types.HarmCategory.HARM_CATEGORY_HARASSMENT,
            types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        )
        safety_settings = [
            types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.BLOCK_NONE)
            for category in harm_categories
        ]
        generation_config = types.GenerateContentConfig(
            temperature=0.0,
            max_output_tokens=max_output_tokens,
            safety_settings=safety_settings,
        )

        response = client.models.generate_content(
            model=model,
            contents=prompt,
            config=generation_config
        )
        latency = time.time() - start_time

        # `response.text` may be None when the reply has no text part; calling
        # .strip() on it would be reported as a bogus "API Error".
        response_text = response.text
        if response_text is None:
            print(f"Warning: Gemini returned no text for prompt: '{prompt[:100]}...'")
        else:
            raw_response = response_text.strip()
            predicted_json_str = _extract_json_str(raw_response)
            if predicted_json_str is None:
                print(f"Warning: Could not parse valid JSON from Gemini response: '{raw_response[:200]}...' for prompt: '{prompt[:100]}...'")

    except Exception as e:
        latency = time.time() - start_time
        raw_response = f"API Error: {e}"
        print(f"Error calling Gemini API: {e}")
        time.sleep(error_delay) # Longer sleep on error

    # Add delay to avoid rate limits
    time.sleep(request_delay)

    return predicted_json_str, latency, raw_response

In [24]:
# --- 7. Run Evaluation Loop ---
# Every test example is sent to Gemini twice: once with the zero-shot prompt and
# once with the few-shot prompt. One record per call is collected in the lists below.
results_zeroshot = []
results_fewshot = []

print("\nStarting evaluation for Gemini Flash - Task 2...")
for example in tqdm(test_dataset):
    # The schema and the input text are recovered by re-parsing 'formatted_prompt'
    # (the instruction between [INST] and [/INST]); examples that do not follow
    # that layout are reported and skipped.
    try:
        full_prompt_str = example['formatted_prompt']
        after_inst_open = full_prompt_str.split("[INST]")[1]
        inst_part = after_inst_open.split("[/INST]")[0].strip()
        after_schema_fence = inst_part.split("Schema:\n```json")[1]
        schema_text = after_schema_fence.split("```")[0].strip()
        after_text_quote = inst_part.split("Text:\n'''")[1]
        input_text = after_text_quote.split("'''")[0].strip()
        ground_truth = example['ground_truth_json'] # The GT JSON object
    except Exception as e:
        print(f"Skipping example due to parsing error in test data: {e} - Data: {example}")
        continue

    # --- Zero-Shot ---
    prompt_zs = create_gemini_zeroshot_prompt(schema_text, input_text)
    pred_zs, lat_zs, raw_zs = get_gemini_json_output(prompt_zs)
    record_zs = dict(
        schema=schema_text,
        ground_truth=ground_truth,
        predicted_json_str=pred_zs,
        latency=lat_zs,
        raw_response=raw_zs,
    )
    results_zeroshot.append(record_zs)

    # --- Few-Shot ---
    prompt_fs = create_gemini_fewshot_prompt(schema_text, input_text, few_shot_prompt_string)
    pred_fs, lat_fs, raw_fs = get_gemini_json_output(prompt_fs)
    record_fs = dict(
        schema=schema_text,
        ground_truth=ground_truth,
        predicted_json_str=pred_fs,
        latency=lat_fs,
        raw_response=raw_fs,
    )
    results_fewshot.append(record_fs)

print("\nEvaluation complete.")


Starting evaluation for Gemini Flash - Task 2...


  0%|          | 0/40 [00:00<?, ?it/s]

{"facilityName":"Acme Corporation","facilityLocation":{"name":"Acme Headquarters","address":"123 Main Street","city":"Anytown","state":"CA","zip":"12345"},"emergencyContacts":[{"name":"John Sm...' for prompt: 'Extract information from the following text based on the provided JSON schema. Output ONLY the valid...'
{"id": "DRP-001", "name": "IT Disaster Recovery Plan", "objectives": ["Ensure the uninterrupted operation of essential business functions during and after a disaster.", "Safeguard the company'...' for prompt: 'Extract information from the following text based on the provided JSON schema. Output ONLY the valid...'
{"loyaltyProgram":{"id":"lp-123","name":"My Loyalty Program","description":"Rewards Customers for Purchases. Tiered Benefits","tiers":[{"name":"Bronze Tier","benefits":["10% discount on all pu...' for prompt: 'Extract information from the following text based on the provided JSON schema. Output ONLY the valid...'
{"id": "a1a1a1a1-b2b2-c3c3-d4d4-e5e5e5e5e5e5", "patient

In [25]:
# --- 8. Save Results ---
def save_results(results, filepath):
    """Write `results` to `filepath` as JSON Lines (one record per line, UTF-8).

    A 'ground_truth' value that is a dict is stored as a JSON-encoded string;
    any other value is written unchanged. The input records are not modified.

    Errors (unwritable path, record without 'ground_truth', ...) are printed,
    not raised; lines written before the error remain in the file.

    Args:
        results: Iterable of dict records, each with a 'ground_truth' key.
        filepath: Destination path; an existing file is overwritten.
    """
    print(f"Saving raw results to: {filepath}")
    try:
        with open(filepath, 'w', encoding='utf-8') as out_file:
            for record in results:
                ground_truth = record['ground_truth']
                # Ensure ground_truth is saved as a string
                if isinstance(ground_truth, dict):
                    ground_truth = json.dumps(ground_truth, ensure_ascii=False)
                serializable = {**record, 'ground_truth': ground_truth}
                out_file.write(json.dumps(serializable, ensure_ascii=False) + '\n')
        print(f"Successfully saved {len(results)} results to {filepath}")
    except Exception as e:
        print(f"Error saving results file {filepath}: {e}")

# Persist both result sets: zero-shot first, then few-shot.
for _results, _filepath in (
    (results_zeroshot, output_results_zs_filepath),
    (results_fewshot, output_results_fs_filepath),
):
    save_results(_results, _filepath)

Saving raw results to: /content/drive/MyDrive/results_gemini_zeroshot_task2.jsonl
Successfully saved 40 results to /content/drive/MyDrive/results_gemini_zeroshot_task2.jsonl
Saving raw results to: /content/drive/MyDrive/results_gemini_fewshot_task2.jsonl
Successfully saved 40 results to /content/drive/MyDrive/results_gemini_fewshot_task2.jsonl


In [40]:
# --- 9. Calculate and Print Results ---
def calculate_metrics(results, verbose=True):
    """Compute exact-match accuracy, mean latency and JSON parse rate.

    Args:
        results: List of records with keys 'latency' (seconds),
            'predicted_json_str' (JSON text or None) and 'ground_truth'
            (a dict, or a JSON-encoded string).
        verbose: When True, print a per-key comparison for every prediction
            that parsed but does not equal the ground truth.

    Returns:
        A tuple `(accuracy, avg_latency, parse_rate)`: accuracy and parse_rate
        are percentages of `len(results)`; avg_latency is in milliseconds.
        All three are 0 for an empty `results`.
    """
    total = len(results)
    if total == 0:
        return 0, 0, 0

    correct = 0
    total_latency = 0
    successful_parses = 0 # Count successfully parsed predictions

    for res in results:
        total_latency += res['latency']
        predicted_json_parsed = None
        ground_truth_parsed = None

        # Attempt to parse predicted JSON
        if res['predicted_json_str']:
            try:
                predicted_json_parsed = json.loads(res['predicted_json_str'])
                successful_parses += 1 # Increment if parsing was successful
            except json.JSONDecodeError:
                pass # Keep predicted_json_parsed as None if parsing fails

        # Ground truth is either a JSON string (e.g. after saving) or already a dict
        ground_truth = res['ground_truth']
        if isinstance(ground_truth, str):
            try:
                ground_truth_parsed = json.loads(ground_truth)
            except json.JSONDecodeError:
                pass # Keep ground_truth_parsed as None if parsing fails
        elif isinstance(ground_truth, dict):
            ground_truth_parsed = ground_truth

        # Compare parsed objects only if both sides were successfully parsed
        if predicted_json_parsed is None or ground_truth_parsed is None:
            continue

        if predicted_json_parsed == ground_truth_parsed:
            correct += 1
        elif verbose:
            print('==================================')
            if isinstance(ground_truth_parsed, dict) and isinstance(predicted_json_parsed, dict):
                for key, value in ground_truth_parsed.items():
                    # Keys absent from the prediction are reported as 'EMPTY'
                    predicted = predicted_json_parsed.get(key, 'EMPTY')
                    equals = predicted == value
                    print('key', key,' equals: ',equals,' || truth: ',value, ' <> predicted: ', predicted)
            else:
                # A non-dict on either side has no keys to walk; show both values whole
                print('truth: ', ground_truth_parsed, ' <> predicted: ', predicted_json_parsed)
            print('==================================')

    accuracy = (correct / total) * 100
    avg_latency = (total_latency / total) * 1000 # in ms
    # Parse rate is based on the total number of examples
    parse_rate = (successful_parses / total) * 100
    return accuracy, avg_latency, parse_rate

# Compute the metrics for both prompting strategies, then print a summary.
acc_zs, lat_zs_avg, parse_zs = calculate_metrics(results_zeroshot)
acc_fs, lat_fs_avg, parse_fs = calculate_metrics(results_fewshot)

# This notebook evaluates the Task 2 JSON-extraction test set, so the title names that task.
print("\n--- Gemini Flash Performance on JSON Extraction (Task 2) ---")
print(f"Number of test samples: {len(test_dataset)}")

print("\nZero-Shot Results:")
print(f"  Accuracy: {acc_zs:.2f}%")
print(f"  Avg. Latency: {lat_zs_avg:.2f} ms per request")
print(f"  Successfully Parsed Responses: {parse_zs:.2f}%")

# Report the configured example count instead of a hard-coded "3".
print(f"\nFew-Shot Results ({num_few_shot_examples} examples):")
print(f"  Accuracy: {acc_fs:.2f}%")
print(f"  Avg. Latency: {lat_fs_avg:.2f} ms per request")
print(f"  Successfully Parsed Responses: {parse_fs:.2f}%")

key origin  equals:  True  || truth:  123 Main Street, Springfield  <> predicted:  123 Main Street, Springfield
key destination  equals:  True  || truth:  456 Elm Street, Anytown  <> predicted:  456 Elm Street, Anytown
key packages  equals:  True  || truth:  [{'weight': 10, 'dimensions': {'length': 12, 'width': 8, 'height': 6}}, {'weight': 15, 'dimensions': {'length': 16, 'width': 10, 'height': 8}}]  <> predicted:  [{'weight': 10, 'dimensions': {'length': 12, 'width': 8, 'height': 6}}, {'weight': 15, 'dimensions': {'length': 16, 'width': 10, 'height': 8}}]
key shippingOptions  equals:  False  || truth:  [{'name': 'Standard Shipping', 'carrier': 'USPS', 'service': 'First Class', 'cost': 5.99, 'deliveryTime': '3-5 business days'}, {'name': 'Expedited Shipping', 'carrier': 'UPS', 'service': 'Next Day Air', 'cost': 19.99, 'deliveryTime': '1 business day'}]  <> predicted:  [{'name': 'Standard Shipping', 'carrier': 'USPS', 'service': 'First Class', 'cost': 5.99}, {'name': 'Expedited Shipping