**Installing Dependencies**

In [None]:
!pip install -U dspy
!pip install datasets
!pip install torch transformers

Collecting dspy
  Downloading dspy-3.0.3-py3-none-any.whl.metadata (7.2 kB)
Collecting backoff>=2.2 (from dspy)
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Collecting optuna>=3.4.0 (from dspy)
  Downloading optuna-4.5.0-py3-none-any.whl.metadata (17 kB)
Collecting magicattr>=0.1.6 (from dspy)
  Downloading magicattr-0.1.6-py2.py3-none-any.whl.metadata (3.2 kB)
Collecting litellm>=1.64.0 (from dspy)
  Downloading litellm-1.77.5-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.4/42.4 kB[0m [31m923.5 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting diskcache>=5.6.0 (from dspy)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting json-repair>=0.30.0 (from dspy)
  Downloading json_repair-0.51.0-py3-none-any.whl.metadata (11 kB)
Collecting asyncer==0.0.8 (from dspy)
  Downloading asyncer-0.0.8-py3-none-any.whl.metadata (6.7 kB)
Collecting gepa==0.0.7 (from gepa[dspy]==0.0.7->dspy)
  Downloading gepa-0.

**Setting Up Ollama**

In [None]:
# Download the Ollama install script and pipe it straight into `sh`.
# NOTE(review): this executes remote code unverified; acceptable on a throwaway runtime only.
!curl -fsSL https://ollama.com/install.sh | sh

>>> Installing ollama to /usr/local
>>> Downloading Linux amd64 bundle
######################################################################## 100.0%
>>> Creating ollama user...
>>> Adding ollama user to video group...
>>> Adding current user to ollama group...
>>> Creating ollama systemd service...
>>> The Ollama API is now available at 127.0.0.1:11434.
>>> Install complete. Run "ollama" from the command line.


In [None]:
import subprocess
import time
import urllib.error
import urllib.request

OLLAMA_URL = "http://localhost:11434"


def _ollama_is_up(url=OLLAMA_URL, timeout=1.0):
    """Return True when an HTTP request to the Ollama endpoint succeeds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except (urllib.error.URLError, OSError):
        return False


# Launch the server only if nothing is listening yet, so re-running this cell
# does not spawn a second `ollama serve` that fails to bind the port.
if _ollama_is_up():
    process = None  # an existing server is reused; there is no child process to track
else:
    # Argument list instead of shell=True: no shell is needed for a fixed command.
    process = subprocess.Popen(["ollama", "serve"])
    # Block until the API answers so the following `ollama pull` cell cannot race the startup.
    for _ in range(60):
        if _ollama_is_up():
            break
        time.sleep(1)
    else:
        raise RuntimeError(f"Ollama did not become reachable at {OLLAMA_URL} within 60 seconds")

In [None]:
# Fetch the llama3.1:8b weights; the DSPy LM configured later refers to this model tag.
!ollama pull llama3.1:8b

[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A

**Import Libraries**

In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import dspy
from datasets import load_dataset
import numpy as np
import re
import matplotlib.pyplot as plt
import time
import random
import json
from pprint import pprint
import logging
# NOTE: the Ollama server is launched once in the "Setting Up Ollama" section.
# A second `ollama serve` here would only collide on port 11434, and `subprocess`
# is not imported in this cell, so the duplicate launch has been removed.

**Configure DSPy & LLM**

In [None]:
# === Point DSPy at the locally served Ollama model ===
# The Ollama endpoint takes no credentials, hence the empty api_key.
lm = dspy.LM(
    model='ollama_chat/llama3.1:8b',
    api_base='http://localhost:11434',
    api_key='',
)
# Register `lm` as the default language model for every DSPy module created below.
dspy.configure(lm=lm)

**Load Training Datasets**

In [None]:
# === Load GSM8K and HotpotQA for training/testing ===
print("Loading datasets...")
gsm8k = load_dataset("gsm8k", "main", split="train[:10]")
# `trust_remote_code` is no longer accepted by `datasets` (it only logs an error), so it is omitted.
hotpotqa = load_dataset("hotpot_qa", "fullwiki", split="train[:10]")

gsm8k_list = list(gsm8k)
hotpotqa_list = list(hotpotqa)

# Combine prompts and ground truths as (question, answer) pairs.
# GSM8K answers hold the worked solution followed by "#### <final answer>"; keep only the final answer.
data = [(ex["question"], ex["answer"].split("####")[-1].strip()) for ex in gsm8k_list] \
     + [(ex["question"], ex["answer"]) for ex in hotpotqa_list]

# Interleave the two tasks so sampled training prompts are not grouped by dataset.
random.shuffle(data)

# Parallel lists indexed together by the training loop.
prompts = [question for question, _ in data]
ground_truths = [answer for _, answer in data]


Loading datasets...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'hotpot_qa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'hotpot_qa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


**Action Space & Pool**

In [None]:
# === Action space of the pipeline policy ===
# A pipeline is one DSPy module, then one signature string, then the "stop" marker.
MODULES = ["CoT", "Predict"]

SIGNATURES = [
    # generic question answering / reasoning
    "question -> answer",
    "text -> summary",
    "question -> reasoning",
    "question -> hypothesis",
    # problem solving and context digestion
    "problem -> solution",
    "problem_description -> explanation",
    "context -> summary",
    "context -> briefing",
    # math-flavoured and free-form prompts
    "word_problem -> solution",
    "math_problem -> answer",
    "prompt -> response",
    "query -> response",
    # open-ended generation
    "text -> response",
    "prompt -> generated_text",
    "query -> generated_text",
]

# Flat list of every action: modules first, then signatures, then the terminator.
ACTIONS = [*MODULES, *SIGNATURES, "stop"]

**Policy Model Setup**

In [None]:
# Pretrained GPT-2 serves as the policy network; its next-token logits score the pipeline actions.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 ships without a padding token; reuse EOS so `padding=True` works in the tokenizer calls.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

policy_model = model
policy_model.config.pad_token_id = tokenizer.pad_token_id
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_model.to(device)

print(f"Using device: {device}")
print("GPT-2 model loaded successfully")

# ============================================================
# Load Existing Model (Optional)
save_path = "gpt2_trained_policy_model_grpo.pt"
try:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime (and vice versa).
    policy_model.load_state_dict(torch.load(save_path, map_location=device))
    print(f"Loaded existing model from {save_path} for retraining")
except FileNotFoundError:
    print("No existing model found, starting fresh")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Using device: cpu
GPT-2 model loaded successfully
No existing model found, starting fresh


**Helper Functions**

In [None]:
# ============================================================
# Compare Answers Functions
def _strip_thousands_separators(text):
    """Turn ``1,234`` into ``1234`` so a number with separators is matched as one token."""
    return re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', text)


def _extract_final_answer(text, label):
    """
    Pull the final answer out of ``text``: the last ``[...]`` content if present,
    otherwise the last number appearing in the text (``None`` if there is none).
    ``label`` only prefixes the diagnostic print.
    """
    text = str(text)
    bracket_matches = re.findall(r'\[(.*?)\]', text)
    if bracket_matches:
        answer = bracket_matches[-1].strip()
        print(f"{label} bracket match: {answer}")
    else:
        numeric_matches = re.findall(r'-?\d*\.?\d+', _strip_thousands_separators(text))
        answer = numeric_matches[-1] if numeric_matches else None
        print(f"{label} last numeric match: {answer}")
    return answer


def compare_answers(response, ground_truth):
    """
    Extracts the final answer from both the response and the ground truth,
    then compares them. The extraction first looks for bracketed content;
    if not found, it falls back to the last numeric match.

    When both extracted answers contain a number, the first number of each is
    compared by value (so "18.0" equals "18" and "1,000" equals "1000");
    otherwise the extracted strings are compared verbatim.

    Args:
        response: Model output (converted with ``str``).
        ground_truth: Reference answer (converted with ``str``).

    Returns:
        regex_score (float): 1.0 if the answers match, 0.0 otherwise.
    """
    response_answer = _extract_final_answer(response, "Response")
    gt_answer = _extract_final_answer(ground_truth, "Ground truth")

    # An empty or missing extraction on either side can never count as a match.
    if not (response_answer and gt_answer):
        print("No valid match found, Regex Score: 0.0")
        return 0.0

    response_num_match = re.search(r'-?\d*\.?\d+', _strip_thousands_separators(response_answer))
    gt_num_match = re.search(r'-?\d*\.?\d+', _strip_thousands_separators(gt_answer))

    if response_num_match and gt_num_match:
        response_num = response_num_match.group(0)
        gt_num = gt_num_match.group(0)
        # Compare by numeric value, not by spelling, so "18.0" and "18" agree.
        regex_score = 1.0 if float(response_num) == float(gt_num) else 0.0
        print(f"Numeric comparison - Response: {response_num}, Ground Truth: {gt_num}, Regex Score: {regex_score}")
    else:
        regex_score = 1.0 if response_answer == gt_answer else 0.0
        print(f"String comparison - Response: {response_answer}, Ground Truth: {gt_answer}, Regex Score: {regex_score}")

    return regex_score

# ============================================================
# Compute Reward Function
def compute_reward(prompt, response, ground_truth):
    """
    Computes a reward in [0.0, 1.0] for the response.

    The answer is first checked with ``compare_answers`` (regex extraction that
    prefers bracketed content). A non-zero regex score is returned directly.
    When the regex score is 0.0, the module-level ``lm`` is asked to judge the
    response and its bracketed (or last numeric) score is used, clamped to [0, 1].

    Args:
        prompt: The question that was asked.
        response: The pipeline output, or None when execution failed.
        ground_truth: The reference answer.

    Returns:
        float: the reward; 0.0 when the response is None or the judge's reply
        cannot be interpreted as a number.
    """
    if response is None:
        print("Response is None")
        return 0.0

    regex1 = compare_answers(response, ground_truth)
    score = regex1
    print("score: ", score)

    # If regex score1 is 0, fall back to LLM judge
    if regex1 == 0.0:
        eval_prompt = f"""
        Evaluate whether the following response correctly answers the prompt based on the ground truth and give your final score in the square brackets[final score]. only the value in the [] should in your response and nothing else.
        Prompt: {prompt}
        Response: {response}
        Ground Truth: {ground_truth}
        Return a score in range between 1.0 to 0.0 if the response is correct or partially correct (matches or is equivalent to the ground truth), or 0.0 if incorrect.
        final score:[]
            """
        try:
            llm_response = lm(eval_prompt)
            print("LLM Response: ", llm_response)
            response_bracket_matches = re.findall(r'\[(.*?)\]', (llm_response[0]))
            print("llm matches: ", response_bracket_matches)
            if response_bracket_matches:
                response_answer = response_bracket_matches[-1].strip()
                print(f"llm Response bracket match: {response_answer}")
            else:
                response_num_matches = re.findall(r'-?\d*\.?\d+', (llm_response[0]))
                response_answer = response_num_matches[-1] if response_num_matches else None
                print(f"llm Response last numeric match: {response_answer}")

            # The judge gave neither a bracketed value nor any number: nothing to parse.
            if not response_answer:
                print("LLM evaluation failed, defaulting to 0.0")
                return 0.0

            score_str = response_answer
            score = float(score_str.strip())
            # NaN compares unequal to itself; it must not leak into the policy loss.
            if score != score:
                print("LLM evaluation failed, defaulting to 0.0")
                return 0.0
            print(f"LLM score: {score_str}, in float: {score}")
            return min(max(score, 0.0), 1.0)

        except (ValueError, TypeError, IndexError, AttributeError):
            print("LLM evaluation failed, defaulting to 0.0")
            return 0.0

    return score

# ============================================================
# Generate Pipeline Function
def generate_pipeline(prompt, max_steps=3):
    """
    Generate a discrete DSPy pipeline using a fixed set of actions.

    The pipeline is built step by step: a module from MODULES, then a signature
    from SIGNATURES, then "stop". At each step the policy model's next-token
    logits are restricted to the currently valid actions and one is sampled.

    Note: each action is approximated by the logit of its first token, so
    actions whose first token coincides receive identical logits.

    Args:
        prompt: The task prompt the pipeline is generated for.
        max_steps: Maximum number of actions to sample.

    Returns:
        tuple: (pipeline, actions_taken, log_probs) where ``pipeline`` and
        ``actions_taken`` list the sampled actions in order and ``log_probs``
        holds the log-probability tensor of each sampled action.
    """
    state = {"prompt": prompt, "partial_pipeline": []}
    actions_taken = []
    log_probs = []

    for _ in range(max_steps):
        partial = state["partial_pipeline"]

        # Decide what may come next before spending a forward pass on it.
        if not partial:
            valid_actions = MODULES
        elif len(partial) == 1 and partial[0] in MODULES:
            valid_actions = SIGNATURES
        elif len(partial) == 2:
            valid_actions = ["stop"]
        else:
            break

        # Pair every action with its first token id. Keeping the pairs together
        # guarantees the sampled index maps back to the right action even when an
        # action tokenizes to nothing and is skipped.
        candidates = []
        for candidate_action in valid_actions:
            tokens = tokenizer.encode(candidate_action, add_special_tokens=False)
            if tokens:
                candidates.append((candidate_action, tokens[0]))
        if not candidates:
            break

        input_text = f"Prompt: {state['prompt']} Pipeline: {' '.join(partial)}"
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)
        outputs = policy_model(**inputs)
        logits = outputs.logits[:, -1, :]

        # Restrict the distribution to the candidate tokens and sample one action.
        candidate_token_ids = [token_id for _, token_id in candidates]
        action_logits = logits[0, candidate_token_ids]
        # log_softmax is the numerically stable form of log(softmax(...)).
        action_log_probs = torch.log_softmax(action_logits, dim=-1)
        action_choice = torch.multinomial(action_log_probs.exp(), 1).item()
        action = candidates[action_choice][0]
        log_prob = action_log_probs[action_choice]

        actions_taken.append(action)
        log_probs.append(log_prob)
        partial.append(action)

        if action == "stop":
            break

    return state["partial_pipeline"], actions_taken, log_probs

# ============================================================
# Execute Pipeline Function
def execute_pipeline(prompt, pipeline):
    """
    Run a two-element pipeline ``[module, signature]`` through DSPy.

    ``module`` must be one of MODULES and ``signature`` one of SIGNATURES, written
    as ``"<input> -> <output>"``. The prompt, prefixed with an instruction to answer
    in square brackets, is passed as the signature's input field, and the value of
    the signature's output field is returned.

    Returns None (after printing the reason) when the pipeline is malformed or
    when building/running the DSPy program raises.
    """
    well_formed = (
        len(pipeline) == 2
        and pipeline[0] in MODULES
        and pipeline[1] in SIGNATURES
    )
    if not well_formed:
        print("Invalid pipeline format")
        return None

    module = pipeline[0]
    signature = pipeline[1]

    # Exactly one arrow separates the input field from the output field.
    if signature.count("->") != 1:
        print("Signature does not have proper format")
        return None

    parsed = re.match(r"\s*([a-zA-Z_ ]+)\s*->\s*([a-zA-Z_ ]+)\s*", signature)
    if parsed is None:
        print("Regex match failed for signature")
        return None

    # Normalise both field names: trimmed, spaces to underscores, lower-case.
    raw_input_name, raw_output_name = parsed.group(1), parsed.group(2)
    inputfield = raw_input_name.strip().replace(" ", "_").lower()
    outputfield = raw_output_name.strip().replace(" ", "_").lower()

    try:
        # "CoT" maps to ChainOfThought; the only other module is "Predict".
        program = dspy.ChainOfThought(signature) if module == "CoT" else dspy.Predict(signature)
        program_input = {
            inputfield: f"system instruction: Must give your final answer in square brackets without fail! e.g., [final answer] like this: [5] \n prompt: {prompt}"
        }
        result = program(**program_input)
        return result.get(outputfield)
    except Exception as e:
        print(f"Execution pipeline failed: {e}")
        return None

# ============================================================
# Testing The Trained Model Function
def test_model(trained_model, test_prompt, max_steps=3):
    """
    Evaluate the trained model on a test prompt by generating a discrete DSPy pipeline
    and then executing it.
    """
    trained_model.eval()
    with torch.no_grad():
        pipeline, actions, _ = generate_pipeline(test_prompt, max_steps=max_steps)
        response = execute_pipeline(test_prompt, pipeline[:-1] if pipeline[-1] == "stop" else pipeline)
        print(f"Test Prompt: {test_prompt}")
        print(f"Generated Pipeline: {pipeline}")
        print(f"Actions Taken: {actions}")
        print(f"Response: {response}")
    return response

**Training Model**

In [None]:
def train_grpo(num_episodes=500, learning_rate=1e-4, K=4, save_path="gpt2_trained_policy_model_grpo.pt"):
    """
    Train the policy model using a GRPO-inspired update with group-based advantage estimation.

    For each episode one prompt is drawn at random, K pipelines are sampled and
    executed for it, and each pipeline's advantage is its reward minus the group's
    mean reward. The loss is the sum of ``-log_prob * advantage`` over every
    sampled action of every pipeline.

    Args:
        num_episodes: Number of prompts (and optimizer steps) to run.
        learning_rate: Adam learning rate.
        K: Number of pipelines sampled per prompt.
        save_path: Where the final ``state_dict`` is written.

    Returns:
        The module-level ``policy_model`` after training.
    """
    optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
    rewards = []
    losses = []

    for episode in range(num_episodes):
        idx = np.random.randint(len(prompts))
        prompt = prompts[idx]
        ground_truth = ground_truths[idx]

        # Sample and score K pipelines for the same prompt.
        log_probs_list = []
        rewards_list = []
        for _ in range(K):
            pipeline, actions, log_probs = generate_pipeline(prompt)
            # Execute discrete DSPy module (exclude "stop" if present); tolerate an empty pipeline.
            executable = pipeline[:-1] if pipeline and pipeline[-1] == "stop" else pipeline
            response = execute_pipeline(prompt, executable)
            reward = compute_reward(prompt, response, ground_truth)
            log_probs_list.append(log_probs)
            rewards_list.append(reward)

        # Group baseline: the mean reward of the K samples.
        group_average = float(np.mean(rewards_list))
        advantages = [float(reward) - group_average for reward in rewards_list]

        # One policy-gradient term per sampled action.
        loss_terms = [
            -log_prob * advantage
            for advantage, log_probs in zip(advantages, log_probs_list)
            for log_prob in log_probs
        ]

        if loss_terms:
            loss = torch.stack([term.reshape(()) for term in loss_terms]).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
        else:
            # No action was sampled this episode, so there is nothing to differentiate.
            loss_value = 0.0

        # Log average reward and loss
        avg_reward = group_average
        rewards.append(avg_reward)
        losses.append(loss_value)

        print(f"Episode {episode}, Prompt: {prompt[:20]}..., Avg Reward: {avg_reward:.3f}, Loss: {loss_value:.3f}")

    torch.save(policy_model.state_dict(), save_path)
    print(f"Model saved to {save_path}")
    return policy_model


**Run Training**

In [None]:
# Short run: 2 episodes, 4 sampled pipelines per prompt; weights are saved to save_path at the end.
trained_model = train_grpo(num_episodes=2, learning_rate=2e-5, K=4, save_path="gpt2_trained_policy_model_grpo.pt")

Response last numeric match: 42
Ground truth last numeric match: 42
Numeric comparison - Response: 42, Ground Truth: 42, Regex Score: 1.0
score:  1.0
Response bracket match: 42
Ground truth last numeric match: 42
Numeric comparison - Response: 42, Ground Truth: 42, Regex Score: 1.0
score:  1.0
Response last numeric match: 42
Ground truth last numeric match: 42
Numeric comparison - Response: 42, Ground Truth: 42, Regex Score: 1.0
score:  1.0


In [None]:
def _parse_judge_score(judge_text):
    """
    Read the numeric verdict out of an LLM judge reply.

    The number inside the last ``[...]`` that contains one wins; without such a
    bracket the last number in the text is used. Returns None when no number exists.
    """
    text = str(judge_text)
    for content in reversed(re.findall(r'\[(.*?)\]', text)):
        found = re.search(r'-?\d*\.?\d+', content)
        if found:
            return float(found.group(0))
    numbers = re.findall(r'-?\d*\.?\d+', text)
    return float(numbers[-1]) if numbers else None


def evaluate_module(dataset, fulldataset, model, query=False, output_file="evaluation_results.json"):
    """
    Evaluates the given model on a dataset of examples.

    Each example is answered with ``test_model``. A response counts as correct when
    the expected answer appears in it (case-insensitive); otherwise the module-level
    ``lm`` is asked to judge it and its verdict is parsed from the reply.

    Parameters:
      dataset: A list of examples, where each example has at least:
               - example.question (input text)
               - example.answer (expected answer)
               - example.context (only read when ``query`` is True)
      fulldataset: Unused; kept so existing call sites keep working.
      model: The model passed to ``test_model``.
      query: If True, the example's context is prepended to the question and the
             per-example results are written to ``output_file``.
      output_file: The JSON file where results are stored (query mode only).

    Returns:
      total_time: Total elapsed time (in seconds).
      total_correct: Number of correct responses.
    """
    total_correct = 0
    results = []
    start_time = time.time()

    print("Query mode:", query)

    for i, example in enumerate(dataset):
        print(f"-------------------------------------Iteration: {i}------------------------------------------------")

        prompt = f"{example.context} \n\n {example.question} " if query else example.question
        prediction_text = test_model(model, prompt)

        # Containment also covers exact equality, so a single check is enough.
        is_correct = False
        if example.answer and prediction_text:
            is_correct = example.answer.strip().lower() in prediction_text.strip().lower()

        score = 1 if is_correct else 0

        if not is_correct:
            eval_prompt = f"""
                Evaluate whether the following response correctly answers the prompt based on the ground truth.
                Return a score of 1 if the response is correct and 0 if it is incorrect.
                Only return the score inside the square brackets [].

                Prompt: {example.question}
                Response: {prediction_text}
                Ground Truth: {example.answer}
                Final score:[]"""

            try:
                llm_res = lm(eval_prompt)
                judged = _parse_judge_score(llm_res[0])
            except (IndexError, TypeError):
                # Empty or malformed judge reply: treat as incorrect.
                judged = None
            # Parse the verdict; a bare substring test would also match digits in the explanation.
            score = 1 if judged is not None and judged >= 0.5 else 0

        total_correct += score

        print("Ground Truth: ", example.answer)
        print("Score: ", score)
        print("------------------------ total correct so far: ", total_correct, "--------------------------")
        if query:
            result = {
                "question": example.question,
                "context": example.context,
                "response": prediction_text,
                "ground_truth": example.answer if example.answer else "N/A",
                "score": score,
            }
            results.append(result)

    total_time = time.time() - start_time

    if query:
        with open(output_file, "w") as f:
            json.dump(results, f, indent=4)

    return total_time, total_correct

**Loading Test Datasets**

In [None]:
gsm8k_test = load_dataset("gsm8k", "main", split="test[:10]")
# `trust_remote_code` is no longer accepted by `datasets`, so it is omitted.
# Rows 20-29 of the train split are used as held-out examples (training used rows 0-9).
hotpotqa_test = load_dataset("hotpot_qa", "fullwiki", split="train[20:30]")

# GSM8K answers end with "#### <final answer>"; keep only the final answer as the ground truth.
gsm8k_test_data = [
    dspy.Example(
        question=ex['question'],
        answer=ex["answer"].split("####")[-1].strip(),
        task_type='math',
    ).with_inputs('question')
    for ex in gsm8k_test
]
hotpotqa_test_data = [
    dspy.Example(
        question=ex['question'],
        answer=ex['answer'],
        context=ex["context"],
        task_type='qa',
    ).with_inputs('question')
    for ex in hotpotqa_test
]


In [None]:

# Testing HotpotQA (query=True prepends each example's context and writes the results JSON)
# evaluate_module takes (dataset, fulldataset, model, query, output_file); the former extra
# positional argument was undefined and collided with the `query` keyword.
htotal_time, htotal_correct = evaluate_module(hotpotqa_test_data, hotpotqa_test, policy_model, query=True)
print(f"HotPotQA - Total time: {htotal_time}, Total correct: {htotal_correct}")

In [None]:
# Testing GSM8K (question-only mode; the former extra positional argument was undefined
# and would have been interpreted as the `query` flag)
gtotal_time, gtotal_correct = evaluate_module(gsm8k_test_data, gsm8k_test, policy_model)
print(f"GSM8K - Total time: {gtotal_time}, Total correct: {gtotal_correct}")

Query mode: False
-------------------------------------Iteration: 0------------------------------------------------
Test Prompt: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Generated Pipeline: ['Predict', 'text -> response', 'stop']
Actions Taken: ['Predict', 'text -> response', 'stop']
Response: [Janet starts with 16 eggs per day. She eats 3 and bakes 4, so she has 16 - 3 - 4 = 9 eggs left. Since each egg is worth $2, she makes 9 x $2 = $18 every day at the farmers' market.]
Ground Truth:  18
Score:  1
------------------------ total correct so far:  1 --------------------------
-------------------------------------Iteration: 1------------------------------------------------
Test Prompt: A robe takes 2 bolts of blue fiber and half that much white fiber.