In [1]:
%pip install datasets tqdm python-dotenv google-api-core --quiet
%pip install -U -q "google-genai>=1.10.0"

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [13]:
import os
import json
import time
import concurrent.futures
import threading 
from tqdm.notebook import tqdm
from google import genai
from google.genai import types
from dotenv import load_dotenv
from datasets import load_dataset

In [23]:
# --- Configuration ---
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY") # load_dotenv() does not override variables already in the environment, so restart the kernel (or use load_dotenv(override=True)) after editing .env
MODEL_ID = "gemini-2.5-flash-lite-preview-06-17"
OUTPUT_FILE = "./data/train_reasoning.jsonl"
FREE_LIMIT = 1000
MAX_CALLS_PER_RUN = 10

# --- Concurrency Configuration ---
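logical_cores = os.cpu_count() or 1  # os.cpu_count() can return None
MAX_WORKERS = logical_cores * 2 # threads making API calls to Gemini. NOTE: this is a conservative estimate; the calls are I/O-bound, so the API rate limit, not the CPU count, is the real constraint.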


In [24]:
sys_prompt = """
You are an AI language expert simulating the process of text editing. You will receive a source text (`src`) which contains an editing instruction followed by the text to be edited.

Your task is to:
1.  Carefully identify the editing instruction given at the beginning of the `src` text.
2.  Analyze the original text segment that follows the instruction in `src`.
3.  Apply the instruction step-by-step to the original text.
4.  For each step or change you make, articulate the precise reasoning based on linguistic rules, the instruction's goal (e.g., grammar, clarity, simplification), or common writing conventions. Explain *why* a change is necessary or beneficial according to the instruction.
5.  Your reasoning should demonstrate *how* to arrive at the corrected text, simulating the thought process of performing the edit yourself.

**Crucially:** Your reasoning process should *not* refer to or imply knowledge of a pre-existing target or corrected text. Focus *only* on applying the instruction to the source text to derive the necessary changes.

Your entire output must consist *only* of the step-by-step reasoning process, enclosed within `<think></think>` tags. Do not include the original instruction, the source text, the final corrected text, or any other text outside these tags.

**Example Interaction:**

**User Input (constructed from your dataset):**

```
Source (src): Fix grammar in this sentence: If engineers do not come up with new ideas, they cannot find best solution for the problems.
Target (tgt): If engineers do not come up with new ideas, they cannot find the best solution for different problems.
```

**Your Expected Output (demonstrating the reasoning to *derive* the correction):**

<think>
1. Instruction analysis: The task is to "Fix grammar". This requires identifying and correcting grammatical errors in the sentence "If engineers do not come up with new ideas, they cannot find best solution for the problems."
2. Sentence segment analysis: The first clause "If engineers do not come up with new ideas" appears grammatically sound.
3. Focus on the second clause: "they cannot find best solution for the problems."
4. Identify potential error 1: The noun phrase "best solution". "Best" is a superlative adjective. Grammatical rule: Superlative adjectives modifying singular countable nouns generally require a definite article ("the").
5. Apply correction 1: Insert "the" before "best solution". The phrase becomes "the best solution".
6. Identify potential error/improvement area 2: The phrase "for the problems". While grammatically plausible, using "the" implies specific, previously identified problems. The context seems general. Improving grammatical flow or clarity might involve adjusting this.
7. Consider alternatives for "the problems": "for problems" (too general?), "for specific problems" (adds info), "for different problems" (suggests variety, often fits general statements well).
8. Apply correction 2: Replace "the problems" with "different problems" for better contextual fit and naturalness, aligning with the goal of general grammatical improvement and clarity often included in "Fix grammar" tasks.
9. Synthesized correction based on reasoning: The corrected clause is "they cannot find the best solution for different problems." The full sentence derived from applying the grammar fixes is "If engineers do not come up with new ideas, they cannot find the best solution for different problems."
</think>
"""


In [25]:
# --- Main Processing Logic ---
if not gemini_api_key:
    print("Error: GEMINI_API_KEY not found in environment variables.")
    train_ds = None  # keeps the processing cell from running without a client
else:
    # Initialize client
    client = genai.Client(api_key=gemini_api_key)

    # Load dataset
    print("Loading dataset...")
    try:
        train_ds = load_dataset("grammarly/coedit", split="train", cache_dir="./data")
        print(f"Dataset loaded: {len(train_ds)} rows")
        train_ds = train_ds.shuffle(seed=42)
    except Exception as e:
        print(f"Could not load dataset. {e}")
        train_ds = None # Ensure loop doesn't run

# Global flag for controlled shutdown
shutdown_requested = threading.Event()

# Custom exception for controlled shutdown
class APIShutdownException(Exception):
    """Exception raised to trigger a controlled shutdown of API processing."""
    pass        

# --- Helper Function to Load Processed IDs ---

def load_processed_ids(filename):
    processed = set()
    if os.path.exists(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    if '_id' in data:
                        processed.add(data['_id'])
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON line in {filename}")
    print(f"Loaded {len(processed)} processed IDs from {filename}")
    return processed

# --- Worker Function ---

def process_item(item, client, sys_prompt):
    """Processes a single dataset item to get reasoning."""
    item_id = item['_id']
    
    if shutdown_requested.is_set():
        return {"_id": item_id, "error": "Skipped due to shutdown request"}

    item_src = item['src']
    item_tgt = item['tgt']
    user_prompt = f"Source (src): {item_src}\nTarget (tgt): {item_tgt}"

    try:
        # --- API call ---
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=user_prompt,
            config=types.GenerateContentConfig(
                system_instruction=sys_prompt,
                thinking_config=types.ThinkingConfig(
                    thinking_budget=600  # cap on thinking tokens; 0 disables thinking, -1 lets the model decide (dynamic thinking)
                ),
            ),
        )

        # --- Process response ---
        # response.text can be None (e.g. blocked or empty output), so guard before stripping
        reasoning_text = (response.text or "").strip()
        if not reasoning_text:
            return {"_id": item_id, "error": "Empty response"}

        # Remove a surrounding ```xml ... ``` code fence if the model added one
        if reasoning_text.startswith("```xml") and reasoning_text.endswith("```"):
            reasoning_text = reasoning_text[len("```xml"):-3].strip()

        # Ensure the trace is wrapped in <think></think> tags
        if not (reasoning_text.startswith("<think>") and reasoning_text.endswith("</think>")):
            reasoning_text = f"<think>{reasoning_text}</think>"
        return {"_id": item_id, "reasoning": reasoning_text}
    except Exception as e:
        code = getattr(e, 'code', None)
        if isinstance(code, int) and 400 <= code < 600:
            error_msg = f"Critical HTTP error {code}: {e}"
            print(f"\n{error_msg}")
            # Signal all threads to stop
            shutdown_requested.set()
            # Propagate the exception to interrupt processing
            raise APIShutdownException(error_msg)
        
        # --- Handle API or Processing Errors ---
        print(f"\nError processing ID {item_id}: {e}")
        return {"_id": item_id, "error": str(e)} # Return error indicator



Loading dataset...
Dataset loaded: 69071 rows
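The response clean-up inside `process_item` can be pulled into a small pure function so it can be checked against sample model outputs without calling the API. This is a minimal sketch; `normalize_reasoning` is a hypothetical helper name, not part of the original notebook, and it assumes the only wrapper the model adds is a ```` ```xml ```` fence.

```python
def normalize_reasoning(text):
    """Return the model output as one <think>...</think> string, or None if it is empty."""
    if text is None or not text.strip():
        return None  # e.g. response.text was None or only whitespace
    cleaned = text.strip()
    # Drop a surrounding ```xml ... ``` code fence if the model added one
    if cleaned.startswith("```xml") and cleaned.endswith("```"):
        cleaned = cleaned[len("```xml"):-3].strip()
    # Wrap in <think> tags unless they are already present
    if not (cleaned.startswith("<think>") and cleaned.endswith("</think>")):
        cleaned = f"<think>{cleaned}</think>"
    return cleaned


print(normalize_reasoning("```xml\n<think>\n1. Instruction analysis: ...\n</think>\n```"))
```

Stripping before slicing matters: slicing the unstripped text by fixed offsets cuts the wrong characters whenever the model adds leading or trailing whitespace.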


In [26]:

# --- Main processing section ---
if train_ds:
    shutdown_requested.clear()  # reset a shutdown flag left over from a previous run of this cell
    processed_ids = load_processed_ids(OUTPUT_FILE)
    api_calls_made_this_run = 0
    file_lock = threading.Lock()
    items_to_process = []

    print("Filtering items to process...")
    for item in tqdm(train_ds, desc="Filtering"):
        if item['_id'] not in processed_ids:
            items_to_process.append(item)
    print(f"Found {len(items_to_process)} items needing processing.")

    # Never request more items than remain unprocessed
    num_to_process = min(MAX_CALLS_PER_RUN, len(items_to_process))

    if num_to_process == 0:
        print("No items to process")

    elif MAX_CALLS_PER_RUN <= FREE_LIMIT:
        print(f"Entering Free tier. Processing {num_to_process} items sequentially...")
        for i in tqdm(range(num_to_process), desc="Processing items"):
            print(f"Processing item {i+1} of {num_to_process}")
            time.sleep(60 / 10)  # 60 s / 10 requests: stay under the 10 requests/min limit (going over may incur costs)
            try:
                response = process_item(items_to_process[i], client, sys_prompt)
            except APIShutdownException:
                break  # critical HTTP error, already reported by process_item
            if response and 'error' not in response:
                with open(OUTPUT_FILE, 'a', encoding='utf-8') as outfile:
                    outfile.write(json.dumps(response) + '\n')
                processed_ids.add(response['_id'])
                api_calls_made_this_run += 1
            else:
                print(f"Error processing item {i+1} of {num_to_process}: {response['error'] if response else 'Unknown error'}")

    else:
        print(f"Entering Paid tier. Processing {num_to_process} items concurrently...")
        with open(OUTPUT_FILE, 'a', encoding='utf-8') as outfile:
            try:
                with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                    # Map each submitted future to the id of the item it processes
                    futures = {}
                    print(f"Submitting {num_to_process} tasks to the executor...")
                    for item in items_to_process[:num_to_process]:
                        if shutdown_requested.is_set():
                            break  # stop submitting once a shutdown has been requested
                        futures[executor.submit(process_item, item, client, sys_prompt)] = item['_id']
                    print(f"Tasks submitted ({len(futures)}). Waiting for results...")

                    # Handle results as they complete
                    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing Results"):
                        if shutdown_requested.is_set():
                            print("\nShutdown requested. Canceling remaining tasks...")
                            for f in futures:
                                f.cancel()  # only affects tasks that have not started running
                            break

                        item_id = futures[future]
                        try:
                            result_data = future.result()
                            if result_data and "error" not in result_data:
                                with file_lock:
                                    if result_data['_id'] not in processed_ids:
                                        outfile.write(json.dumps(result_data) + '\n')
                                        outfile.flush()
                                        processed_ids.add(result_data['_id'])
                                        api_calls_made_this_run += 1
                        except APIShutdownException:
                            print(f"Task for item {item_id} aborted due to shutdown request")
                        except Exception as exc:
                            print(f'\nItem ID {item_id} generated an exception: {exc}')

            except KeyboardInterrupt:
                print("\nKeyboard interrupt detected. Shutting down gracefully...")
                shutdown_requested.set()

        # Final status report
        if shutdown_requested.is_set():
            print("\nProcessing terminated early due to errors or interruption.")
        else:
            print(f"\nFinished processing batch. Made approximately {api_calls_made_this_run} successful API calls in this run.")

        print(f"Total processed IDs (including previous runs): {len(processed_ids)}")

else:
    print("Dataset not loaded; nothing to process")


Loaded 56739 processed IDs from ./data/train_reasoning.jsonl
Filtering items to process...


Filtering:   0%|          | 0/69071 [00:00<?, ?it/s]

Found 12332 items needing processing.
Entering Free tier. Processing 10 items sequentially...


Processing items:   0%|          | 0/10 [00:00<?, ?it/s]

Processing item 1 of 10
Processing item 2 of 10
Processing item 3 of 10
Processing item 4 of 10
Processing item 5 of 10
Processing item 6 of 10
Processing item 7 of 10
Processing item 8 of 10
Processing item 9 of 10
Processing item 10 of 10
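A fixed sleep only paces the sequential loop, and the concurrent branch has no pacing at all. One way to respect a requests-per-minute limit in both is a small thread-safe pacer that hands out evenly spaced start times. This is a sketch under assumptions: `MinIntervalLimiter` is a hypothetical helper, and the 10 requests/minute figure is the one from the notebook's rate-limit comment, not a value checked against current quotas.

```python
import threading
import time


class MinIntervalLimiter:
    """Thread-safe pacer: acquire() calls are scheduled at least 60/rpm seconds apart."""

    def __init__(self, requests_per_minute):
        self.interval = 60.0 / requests_per_minute
        self._lock = threading.Lock()
        self._next_slot = time.monotonic()

    def acquire(self):
        # Reserve the next free start time under the lock, then sleep outside it
        with self._lock:
            now = time.monotonic()
            start_at = max(now, self._next_slot)
            self._next_slot = start_at + self.interval
        delay = start_at - now
        if delay > 0:
            time.sleep(delay)


# One shared limiter for the sequential loop and every worker thread
limiter = MinIntervalLimiter(requests_per_minute=10)
```

Calling `limiter.acquire()` just before `client.models.generate_content(...)` would pace both code paths. Slots are reserved while holding the lock but the sleep happens after releasing it, so waiting threads do not hold up each other's bookkeeping.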


In [27]:
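# Pricing in USD per 1 million tokens (the rates used for this estimate; check them against current Gemini API pricing for MODEL_ID)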
INPUT_PRICE_PER_MILLION = 0.15
OUTPUT_NON_THINKING_PRICE_PER_MILLION = 0.60
OUTPUT_THINKING_PRICE_PER_MILLION = 3.50

# --- Token counts from one sample call (hardcoded; copy them from response.usage_metadata) ---
print("Enter the token counts (use 0 if Thoughts tokens is None or zero):")

prompt_tokens = 800
thoughts_tokens = 393 # Assume 0 if None was the actual value
output_tokens = 609

# --- Calculate Costs ---

# Input Cost
input_cost = (prompt_tokens / 1_000_000) * INPUT_PRICE_PER_MILLION

# Output Cost (depends on whether thinking tokens were generated)
if thoughts_tokens > 0:
    # Thinking was used - price applies to output + thoughts tokens
    billable_output_tokens = output_tokens + thoughts_tokens
    output_cost = (billable_output_tokens / 1_000_000) * OUTPUT_THINKING_PRICE_PER_MILLION
    print(f"\nCalculating using THINKING output price (${OUTPUT_THINKING_PRICE_PER_MILLION:.2f} / 1M tokens)")
else:
    # No thinking - price applies only to output tokens
    billable_output_tokens = output_tokens
    output_cost = (billable_output_tokens / 1_000_000) * OUTPUT_NON_THINKING_PRICE_PER_MILLION
    print(f"\nCalculating using NON-THINKING output price (${OUTPUT_NON_THINKING_PRICE_PER_MILLION:.2f} / 1M tokens)")

# Total Cost
total_cost = input_cost + output_cost

# --- Display Results ---
print("--- Cost Breakdown ---")
print(f"Input Cost:       ${input_cost:.6f}")
print(f"Output Cost:      ${output_cost:.6f} (based on {billable_output_tokens} billable output tokens)")
print("----------------------")
print(f"Total Estimated Cost: ${total_cost:.6f}")

# Example calculation for 10,000 identical calls
num_calls = 10000
total_cost_10k = total_cost * num_calls
print(f"\nEstimated Cost for {num_calls:,} calls: ${total_cost_10k:.2f}")

Enter the token counts (use 0 if Thoughts tokens is None or zero):

Calculating using THINKING output price ($3.50 / 1M tokens)
--- Cost Breakdown ---
Input Cost:       $0.000120
Output Cost:      $0.003507 (based on 1002 billable output tokens)
----------------------
Total Estimated Cost: $0.003627

Estimated Cost for 10,000 calls: $36.27
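Instead of copying token counts by hand, the same estimate can be computed directly from a response. This sketch assumes the google-genai `response.usage_metadata` fields `prompt_token_count`, `candidates_token_count` and `thoughts_token_count` (the last is None when the model did not think), and reuses the prices from the cell above. `estimate_cost_usd` is a hypothetical helper, and a `SimpleNamespace` stands in for a real response so the example runs offline.

```python
from types import SimpleNamespace

INPUT_PRICE_PER_MILLION = 0.15
OUTPUT_NON_THINKING_PRICE_PER_MILLION = 0.60
OUTPUT_THINKING_PRICE_PER_MILLION = 3.50


def estimate_cost_usd(usage):
    """Estimate one call's cost from a usage-metadata object; None counts are treated as 0."""
    prompt = usage.prompt_token_count or 0
    output = usage.candidates_token_count or 0
    thoughts = usage.thoughts_token_count or 0
    input_cost = prompt / 1_000_000 * INPUT_PRICE_PER_MILLION
    if thoughts > 0:
        # Thinking used: output and thoughts tokens billed at the thinking rate
        output_cost = (output + thoughts) / 1_000_000 * OUTPUT_THINKING_PRICE_PER_MILLION
    else:
        output_cost = output / 1_000_000 * OUTPUT_NON_THINKING_PRICE_PER_MILLION
    return input_cost + output_cost


# Stand-in for response.usage_metadata, using the sample counts above
sample = SimpleNamespace(prompt_token_count=800, candidates_token_count=609, thoughts_token_count=393)
print(f"${estimate_cost_usd(sample):.6f}")  # $0.003627
```

With a real call, `estimate_cost_usd(response.usage_metadata)` removes the manual "use 0 if None" step.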


In [17]:
# test out other models or inspect API errors

prompt = """
    what is the capital of the moon?
"""
client = genai.Client(api_key=gemini_api_key)
try:
    response = client.models.generate_content(
        model="gemini-2.5-flash-lite-preview-06-17",
        # model="gemini-2.0-flash",
        contents=prompt
    )
except Exception as e:
    code = getattr(e, "code", None)  # HTTP status on API errors; absent on other exceptions
    if isinstance(code, int) and 400 <= code < 600:
        print(f"HTTP error {code}: {e}")
    raise  # surface every failure instead of leaving `response` undefined

In [18]:
print(response.text)

The Moon does not have a capital.

It's not a country or a political entity with a government. Currently, there are no permanent human settlements or established governing bodies on the Moon that would designate a capital city.

However, in science fiction, you'll often find imagined lunar cities or bases, but these are purely fictional concepts!


# -- Uploading augmented dataset to hf --

In [28]:
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
import json
import os

In [None]:
# load the original coedit dataset

og_coedit_train = load_dataset("grammarly/coedit", split="train")
og_coedit_val = load_dataset("grammarly/coedit", split="validation")

print(len(og_coedit_train), og_coedit_train[0])
print(len(og_coedit_val), og_coedit_val[0])

69071 {'_id': '1', 'task': 'gec', 'src': 'Remove all grammatical errors from this text: For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert.', 'tgt': 'For example, countries with a lot of deserts can transform their desert to increase their habitable land and use irrigation to provide clean water to the desert.'}
1712 {'_id': '1', 'task': 'paraphrase', 'src': 'Paraphrase this sentence: Why are you arresting me?', 'tgt': 'Why am I being arrested?'}


In [3]:
# load the reasoning traces (the "json" builder puts every row in a single "train" split)
train_reasoning_traces = load_dataset("json", data_files="./data/train_reasoning.jsonl", split="train")
val_reasoning_traces = load_dataset("json", data_files="./data/val_reasoning.jsonl", split="train")

print(train_reasoning_traces[0])
print(val_reasoning_traces[0])

print(f"Train reasoning traces: {len(train_reasoning_traces)} examples")
print(f"Val reasoning traces: {len(val_reasoning_traces)} examples")

{'_id': '1', 'reasoning': '<think>\n1.  **Instruction Analysis:** The goal is to "Remove all grammatical errors" from the provided text. I need to read the text carefully and identify any points that violate standard English grammar rules.\n2.  **Source Text Analysis:** The text is "For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert."\n3.  **Sentence Structure Review:** The sentence starts with an introductory phrase "For example". The main subject is "countries with a lot of deserts". The main verb phrase is "can [verb]".\n4.  **Identify Potential Error 1:** The phrase "can terraform their desert to increase their habitable land" appears structurally sound. "Can" is followed by the base form of the verb "terraform". The infinitive "to increase" explains the purpose.\n5.  **Identify Potential Error 2:** The coordinating conjunction "and" connects the action "terraform their d

In [4]:
train_reasoning_map = {row['_id']: row['reasoning'] for row in train_reasoning_traces}
val_reasoning_map = {row['_id']: row['reasoning'] for row in val_reasoning_traces}
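Because a `{row['_id']: ...}` comprehension silently keeps only the last trace when an `_id` repeats, it can be worth auditing the JSONL files for duplicated ids or unparsable lines (for example, a line cut short by an interrupted write) before merging. A minimal stdlib sketch; `audit_jsonl_ids` is a hypothetical helper, not part of the original pipeline.

```python
import json
from collections import Counter


def audit_jsonl_ids(path, id_key="_id"):
    """Summarize ids in a JSONL file: rows with an id, unique ids, repeats and bad lines."""
    id_counts = Counter()
    bad_lines = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # ignore blank lines
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                bad_lines += 1  # e.g. a line truncated by an interrupted write
                continue
            if id_key in row:
                id_counts[row[id_key]] += 1
    return {
        "rows_with_id": sum(id_counts.values()),
        "unique_ids": len(id_counts),
        "duplicated_ids": {k: n for k, n in id_counts.items() if n > 1},
        "bad_lines": bad_lines,
    }
```

Running `audit_jsonl_ids("./data/train_reasoning.jsonl")` and comparing `unique_ids` with the size of `train_reasoning_map` shows whether any traces were collapsed.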

In [None]:
def add_reasoning(example, reasoning_map):
    # Attach the trace for this row's _id, or None if no trace was generated
    example["reasoning"] = reasoning_map.get(example["_id"], None)
    return example

print("\nAdding reasoning traces to train split...")
train_with_reasoning = og_coedit_train.map(
    lambda x: add_reasoning(x, train_reasoning_map),
    desc="Adding reasoning to train"
)

print("Adding reasoning traces to validation split...")
val_with_reasoning = og_coedit_val.map(
    lambda x: add_reasoning(x, val_reasoning_map),
    desc="Adding reasoning to validation"
)


train_with_reasoning_count = sum(1 for x in train_with_reasoning if x['reasoning'] is not None)
val_with_reasoning_count = sum(1 for x in val_with_reasoning if x['reasoning'] is not None)

print("Reasoning coverage:")
print(f"Train: {train_with_reasoning_count}/{len(train_with_reasoning)} ({train_with_reasoning_count/len(train_with_reasoning)*100:.1f}%)")
print(f"Val: {val_with_reasoning_count}/{len(val_with_reasoning)} ({val_with_reasoning_count/len(val_with_reasoning)*100:.1f}%)")



Adding reasoning traces to train split...
Adding reasoning traces to validation split...
Reasoning coverage:
Train: 56736/69071 (82.1%)
Val: 1712/1712 (100.0%)


In [6]:
# create dataset to upload onto huggingface
print("\nCreating DatasetDict...")
coedit_w_reasoning = DatasetDict({
    "train": train_with_reasoning,
    "validation": val_with_reasoning
})

print("\nDataset preview:")
print("Train example:", coedit_w_reasoning["train"][0])
print("Validation example:", coedit_w_reasoning["validation"][0])



Creating DatasetDict...

Dataset preview:
Train example: {'_id': '1', 'task': 'gec', 'src': 'Remove all grammatical errors from this text: For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert.', 'tgt': 'For example, countries with a lot of deserts can transform their desert to increase their habitable land and use irrigation to provide clean water to the desert.', 'reasoning': '<think>\n1.  **Instruction Analysis:** The goal is to "Remove all grammatical errors" from the provided text. I need to read the text carefully and identify any points that violate standard English grammar rules.\n2.  **Source Text Analysis:** The text is "For example, countries with a lot of deserts can terraform their desert to increase their habitable land and using irrigation to provide clean water to the desert."\n3.  **Sentence Structure Review:** The sentence starts with an introductory phrase "F

In [9]:
# initialize write token for huggingface
import os
token = os.environ.get("HF_WRITE_TOKEN")
login(token=token)
print('logged in')

logged in


In [10]:
hf_repo = "muzzz/coedit-w-reasoning-traces" # USERNAME/REPO_NAME

coedit_w_reasoning.push_to_hub(
    hf_repo,
    private=True,
    commit_message="Added reasoning traces to coedit dataset",
)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/70 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/muzzz/coedit-w-reasoning-traces/commit/b8f25c337f2c136f789c8028fcb3484090aa42c4', commit_message='Added reasoning traces to coedit dataset', commit_description='', oid='b8f25c337f2c136f789c8028fcb3484090aa42c4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/muzzz/coedit-w-reasoning-traces', endpoint='https://huggingface.co', repo_type='dataset', repo_id='muzzz/coedit-w-reasoning-traces'), pr_revision=None, pr_num=None)