In [1]:
%pip install datasets tqdm python-dotenv google-api-core --quiet
%pip install -U -q "google-genai>=1.10.0" --quiet

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [9]:
import os
import json
import time
import concurrent.futures
import threading # Import threading for Lock
import signal
from tqdm.notebook import tqdm
from google import genai
from google.genai import types
from dotenv import load_dotenv
from datasets import load_dataset
from google.api_core import exceptions as core_exceptions

In [26]:
# --- Configuration ---
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY") # restart kernel to make changes to this take effect
MODEL_ID = "gemini-2.5-flash-preview-04-17"
OUTPUT_FILE = "./data/train_reasoning.jsonl"
MAX_CALLS_PER_RUN = 174

# --- Concurrency Configuration ---
# recommend not going over 15
MAX_WORKERS = 1 # Number of concurrent API requests


In [27]:
sys_prompt = """
You are an AI language expert simulating the process of text editing. You will receive a source text (`src`) which contains an editing instruction followed by the text to be edited.

Your task is to:
1.  Carefully identify the editing instruction given at the beginning of the `src` text.
2.  Analyze the original text segment that follows the instruction in `src`.
3.  Apply the instruction step-by-step to the original text.
4.  For each step or change you make, articulate the precise reasoning based on linguistic rules, the instruction's goal (e.g., grammar, clarity, simplification), or common writing conventions. Explain *why* a change is necessary or beneficial according to the instruction.
5.  Your reasoning should demonstrate *how* to arrive at the corrected text, simulating the thought process of performing the edit yourself.

**Crucially:** Your reasoning process should *not* refer to or imply knowledge of a pre-existing target or corrected text. Focus *only* on applying the instruction to the source text to derive the necessary changes.

Your entire output must consist *only* of the step-by-step reasoning process, enclosed within `<think></think>` tags. Do not include the original instruction, the source text, the final corrected text, or any other text outside these tags.

**Example Interaction:**

**User Input (constructed from your dataset):**

```
Source (src): Fix grammar in this sentence: If engineers do not come up with new ideas, they cannot find best solution for the problems.
Target (tgt): If engineers do not come up with new ideas, they cannot find the best solution for different problems.
```

**Your Expected Output (demonstrating the reasoning to *derive* the correction):**

<think>
1. Instruction analysis: The task is to "Fix grammar". This requires identifying and correcting grammatical errors in the sentence "If engineers do not come up with new ideas, they cannot find best solution for the problems."
2. Sentence segment analysis: The first clause "If engineers do not come up with new ideas" appears grammatically sound.
3. Focus on the second clause: "they cannot find best solution for the problems."
4. Identify potential error 1: The noun phrase "best solution". "Best" is a superlative adjective. Grammatical rule: Superlative adjectives modifying singular countable nouns generally require a definite article ("the").
5. Apply correction 1: Insert "the" before "best solution". The phrase becomes "the best solution".
6. Identify potential error/improvement area 2: The phrase "for the problems". While grammatically plausible, using "the" implies specific, previously identified problems. The context seems general. Improving grammatical flow or clarity might involve adjusting this.
7. Consider alternatives for "the problems": "for problems" (too general?), "for specific problems" (adds info), "for different problems" (suggests variety, often fits general statements well).
8. Apply correction 2: Replace "the problems" with "different problems" for better contextual fit and naturalness, aligning with the goal of general grammatical improvement and clarity often included in "Fix grammar" tasks.
9. Synthesized correction based on reasoning: The corrected clause is "they cannot find the best solution for different problems." The full sentence derived from applying the grammar fixes is "If engineers do not come up with new ideas, they cannot find the best solution for different problems."
</think>
"""


In [28]:
# --- Main Processing Logic ---
if not gemini_api_key:
    print("Error: GEMINI_API_KEY not found in environment variables.")
else:
    # Initialize client
    client = genai.Client(api_key=gemini_api_key)

    # Load dataset
    print("Loading dataset...")
    try:
        train_ds = load_dataset("grammarly/coedit", split="train", cache_dir="./data")
        print(f"Dataset loaded: {len(train_ds)} rows")
        train_ds = train_ds.shuffle(seed=42)
    except Exception as e:
        print(f"Could not load dataset. {e}")
        train_ds = None # Ensure loop doesn't run

# Global flag for controlled shutdown
shutdown_requested = threading.Event()

# Custom exception for controlled shutdown
class APIShutdownException(Exception):
    """Exception raised to trigger a controlled shutdown of API processing."""
    pass        

# --- Helper Function to Load Processed IDs ---

def load_processed_ids(filename):
    processed = set()
    if os.path.exists(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    if '_id' in data:
                        processed.add(data['_id'])
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON line in {filename}")
    print(f"Loaded {len(processed)} processed IDs from {filename}")
    return processed

# --- Worker Function ---

def process_item(item, client, sys_prompt):
    """Processes a single dataset item to get reasoning."""
    item_id = item['_id']
    
    if shutdown_requested.is_set():
        return {"_id": item_id, "error": "Skipped due to shutdown request"}

    item_src = item['src']
    item_tgt = item['tgt']
    user_prompt = f"Source (src): {item_src}\nTarget (tgt): {item_tgt}"

    try:
        # ---  API Call ---
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=user_prompt,
            config=types.GenerateContentConfig(
            system_instruction=sys_prompt,
                    thinking_config=types.ThinkingConfig(
                        thinking_budget=600 # 0 means no thinking None means no limit to thinking it can take however many tokens it wants.
                    )
                )
            )
        
        # --- Process Response ---
        reasoning_text = response.text

        if reasoning_text and reasoning_text.strip().startswith("```xml"):
            reasoning_text = reasoning_text[6:-3].strip() # removing the xml codefences

        if reasoning_text.strip().startswith("<think>") and reasoning_text.strip().endswith("</think>"):
            return {"_id": item_id, "reasoning": reasoning_text.strip()}
        else:
            print(f"\nWarning: Unexpected response format for ID {item_id}. Skipping. Response: {reasoning_text[:200]}...")
            return {"_id": item_id, "error": "Bad Format"} # Return error indicator
    except Exception as e:
        if hasattr(e, 'code') and 400 <= e.code < 600:
            error_msg = f"Critical HTTP error {e.code}: {e}"
            print(f"\n{error_msg}")
            # Signal all threads to stop
            shutdown_requested.set()
            # Propagate the exception to interrupt processing
            raise APIShutdownException(error_msg)
        
        # --- Handle API or Processing Errors ---
        print(f"\nError processing ID {item_id}: {e}")
        return {"_id": item_id, "error": str(e)} # Return error indicator



Loading dataset...
Dataset loaded: 69071 rows


In [29]:

# --- Main processing section ---
if train_ds:
    processed_ids = load_processed_ids(OUTPUT_FILE)
    api_calls_made_this_run = 0
    file_lock = threading.Lock()
    items_to_process = []

    print("Filtering items to process...")
    for item in tqdm(train_ds, desc="Filtering"):
        if item['_id'] not in processed_ids:
            items_to_process.append(item)
    print(f"Found {len(items_to_process)} items needing processing.")

if MAX_CALLS_PER_RUN < 500:
    for i in range(min(MAX_CALLS_PER_RUN, len(items_to_process))):
        print(f"Processing item {i} of {MAX_CALLS_PER_RUN}")
        time.sleep(6) # to avoid rpm rate limit (10/min if you go over you incur costs)
        response = process_item(items_to_process[i], client, sys_prompt)
        if response and 'error' not in response:
            with open(OUTPUT_FILE, 'a', encoding='utf-8') as outfile:
                json_line = json.dumps(response)
                outfile.write(json_line + '\n')
                outfile.flush()
                processed_ids.add(response['_id'])
        else:
            print(f"Error processing item {i} of {MAX_CALLS_PER_RUN}: {response['error'] if response else 'Unknown error'}")

elif MAX_CALLS_PER_RUN >= 500:
    # Process items concurrently
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as outfile:
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                # stores ids of items that will be processed
                futures = {}
                
                # Submit tasks only up to the daily limit
                num_to_submit = min(len(items_to_process), MAX_CALLS_PER_RUN - api_calls_made_this_run)
                print(f"Submitting {num_to_submit} tasks to the executor...")

                submitted_count = 0
                for i in range(len(items_to_process)):
                    if submitted_count >= num_to_submit or shutdown_requested.is_set():
                        break  # Stop submitting if we hit the limit or shutdown requested
                    item = items_to_process[i]
                    future = executor.submit(process_item, item, client, sys_prompt)
                    futures[future] = item['_id']
                    submitted_count += 1

                print(f"Tasks submitted ({submitted_count}). Waiting for results...")

                # Process results as they complete
                completed_futures = []

                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing Results"):
                    if shutdown_requested.is_set():
                        print("\nShutdown requested. Canceling remaining tasks...")
                        for f in futures:
                            if f not in completed_futures and not f.done():
                                f.cancel()
                        break
        
                    completed_futures.append(future)
                    item_id = futures[future]       
                    
                    try:
                        result_data = future.result()
                        
                        if result_data and "error" not in result_data:
                            with file_lock:
                                if result_data['_id'] not in processed_ids:
                                    json_line = json.dumps(result_data)
                                    outfile.write(json_line + '\n')
                                    outfile.flush()
                                    processed_ids.add(result_data['_id'])
                                    api_calls_made_this_run += 1
                                    
                    except APIShutdownException:
                        print(f"Task for item {item_id} aborted due to shutdown request")
                    except Exception as exc:
                        print(f'\nItem ID {item_id} generated an exception: {exc}')
                        
        except KeyboardInterrupt:
            print("\nKeyboard interrupt detected. Shutting down gracefully...")
            shutdown_requested.set()
            
    # Final status report
    if shutdown_requested.is_set():
        print("\nProcessing terminated early due to errors or interruption.")
    else:
        print(f"\nFinished processing batch. Made approximately {api_calls_made_this_run} successful API calls in this run.")
    
    print(f"Total processed IDs (including previous runs): {len(processed_ids)}")

else:
    print("No items to process")


Loaded 51896 processed IDs from ./data/train_reasoning.jsonl
Filtering items to process...


Filtering:   0%|          | 0/69071 [00:00<?, ?it/s]

Found 17175 items needing processing.
Processing item 0 of 174
Processing item 1 of 174
Processing item 2 of 174
Processing item 3 of 174
Processing item 4 of 174
Processing item 5 of 174
Processing item 6 of 174
Processing item 7 of 174
Processing item 8 of 174
Processing item 9 of 174
Processing item 10 of 174
Processing item 11 of 174
Processing item 12 of 174
Processing item 13 of 174
Processing item 14 of 174

Critical HTTP error 500: 500 INTERNAL. {'error': {'code': 500, 'message': 'An internal error has occurred. Please retry or report in https://developers.generativeai.google/guide/troubleshooting', 'status': 'INTERNAL'}}


APIShutdownException: Critical HTTP error 500: 500 INTERNAL. {'error': {'code': 500, 'message': 'An internal error has occurred. Please retry or report in https://developers.generativeai.google/guide/troubleshooting', 'status': 'INTERNAL'}}

In [4]:
# Pricing per 1 Million tokens
INPUT_PRICE_PER_MILLION = 0.15
OUTPUT_NON_THINKING_PRICE_PER_MILLION = 0.60
OUTPUT_THINKING_PRICE_PER_MILLION = 3.50

# --- Get Token Counts from User ---
print("Enter the token counts (use 0 if Thoughts tokens is None or zero):")

prompt_tokens = 800
thoughts_tokens = 393 # Assume 0 if None was the actual value
output_tokens = 609

# --- Calculate Costs ---

# Input Cost
input_cost = (prompt_tokens / 1_000_000) * INPUT_PRICE_PER_MILLION

# Output Cost (depends on whether thinking tokens were generated)
if thoughts_tokens > 0:
    # Thinking was used - price applies to output + thoughts tokens
    billable_output_tokens = output_tokens + thoughts_tokens
    output_cost = (billable_output_tokens / 1_000_000) * OUTPUT_THINKING_PRICE_PER_MILLION
    print("\nCalculating using THINKING output price ($3.50 / 1M tokens)")
else:
    # No thinking - price applies only to output tokens
    billable_output_tokens = output_tokens
    output_cost = (billable_output_tokens / 1_000_000) * OUTPUT_NON_THINKING_PRICE_PER_MILLION
    print("\nCalculating using NON-THINKING output price ($0.60 / 1M tokens)")

# Total Cost
total_cost = input_cost + output_cost

# --- Display Results ---
print(f"\n--- Cost Breakdown ---")
print(f"Input Cost:       ${input_cost:.6f}")
print(f"Output Cost:      ${output_cost:.6f} (based on {billable_output_tokens} billable output tokens)")
print(f"----------------------")
print(f"Total Estimated Cost: ${total_cost:.6f}")

# Example calculation for 10,000 identical calls
num_calls = 70000
total_cost_10k = total_cost * num_calls
print(f"\nEstimated Cost for {num_calls:,} calls: ${total_cost_10k:.2f}")

Enter the token counts (use 0 if Thoughts tokens is None or zero):

Calculating using THINKING output price ($3.50 / 1M tokens)

--- Cost Breakdown ---
Input Cost:       $0.000120
Output Cost:      $0.003507 (based on 1002 billable output tokens)
----------------------
Total Estimated Cost: $0.003627

Estimated Cost for 70,000 calls: $253.89


In [6]:
import sys

prompt = """
    what is the capital of new zealand?
"""
# The animal I'm thinking of is a platipus
client = genai.Client(api_key=gemini_api_key)
try:
    response = client.models.generate_content(
        model="gemini-2.5-flash-preview-04-17",
        # model="gemini-2.0-flash",
        contents=prompt
    )
except Exception as e:
    if 400 <= e.code < 600:
        print(f"Exiting due to HTTP error {e.code}: {e}")
        sys.exit(1)
        
        

In [7]:
print(response.text)

The capital of New Zealand is **Wellington**.
