**Install Dependencies**

In [None]:
!pip install -U dspy
!pip install datasets
!pip install torch transformers

In [None]:
# Install the Ollama runtime, which serves the local Llama model that DSPy is configured to use below.
# NOTE(review): this pipes a remote script straight into `sh`; review it before running outside a throwaway environment.
!curl -fsSL https://ollama.com/install.sh | sh

In [None]:
import subprocess
import time
import urllib.error
import urllib.request

OLLAMA_URL = "http://localhost:11434"


def _ollama_is_up(url=OLLAMA_URL, timeout=1.0):
    """Return True if something answers HTTP requests at `url` (i.e. an Ollama server is running)."""
    try:
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except urllib.error.HTTPError:
        return True  # the server answered, just not with 200
    except OSError:
        return False


# Start `ollama serve` in the background only if it is not already running, so re-running this
# cell does not spawn a second server that would just fail to bind the port.
if not _ollama_is_up():
    process = subprocess.Popen("ollama serve", shell=True)
    # Wait (up to ~30 s) for the server so the `ollama pull` in the next cell doesn't race its start-up.
    for _ in range(30):
        if _ollama_is_up():
            break
        time.sleep(1)

In [None]:
# Download the model that `dspy.LM('ollama_chat/llama3.1:8b', ...)` uses below.
# Requires the `ollama serve` process started in the previous cell to be running.
!ollama pull llama3.1:8b




**Import Libraries**

In [None]:
# === Imports ===
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel  # Change the models to your prefered model
import dspy
from datasets import load_dataset
import numpy as np
import re
import matplotlib.pyplot as plt
import time
import random
from pprint import pprint
import subprocess
import urllib.error
import urllib.request

# `ollama serve` is normally already running from the setup cell, and a second launch would only
# fail to bind the port. Start it here only if nothing answers on the Ollama port
# (e.g. after a kernel restart where the setup cells were skipped).
try:
    urllib.request.urlopen("http://localhost:11434", timeout=1).close()
except urllib.error.HTTPError:
    pass  # something answered on the port, so a server is already up
except OSError:
    process = subprocess.Popen("ollama serve", shell=True)


**Configure DSPy & LLM**

In [None]:
# === Configure DSPy with Ollama LLM ===
# `lm` stays a module-level name: besides being DSPy's default LM, it is called directly
# later as the LLM judge in the reward and evaluation code.
LLM_MODEL = "ollama_chat/llama3.1:8b"    # must match the model fetched with `ollama pull`
LLM_API_BASE = "http://localhost:11434"  # the local `ollama serve` started above

lm = dspy.LM(model=LLM_MODEL, api_base=LLM_API_BASE, api_key='')
dspy.configure(lm=lm)

**Load Training Datasets**

In [None]:
# === Load GSM8K and HotpotQA for training/testing ===
DATA_SHUFFLE_SEED = 0  # fixed so the prompt order (and thus training) is reproducible across runs

print("Loading datasets...")
gsm8k = load_dataset("gsm8k", "main", split="train[:10]")
hotpotqa = load_dataset("hotpot_qa", "fullwiki", split="train[:10]", trust_remote_code=True)

gsm8k_list = list(gsm8k)
hotpotqa_list = list(hotpotqa)

# Combine (prompt, ground truth) pairs from both datasets.
# GSM8K answers end with "#### <number>": keep only that final number, and drop any thousands
# separators so it compares cleanly with numbers extracted from model responses.
data = (
    [(ex["question"], ex["answer"].split("####")[-1].strip().replace(",", "")) for ex in gsm8k_list]
    + [(ex["question"], ex["answer"]) for ex in hotpotqa_list]
)

random.Random(DATA_SHUFFLE_SEED).shuffle(data)
prompts = [question for question, _ in data]
ground_truths = [answer for _, answer in data]


**Action Space & Pool**

In [None]:
# === Define Action Space ===
# A pipeline is built from a DSPy module type followed by a signature, then "stop".
MODULES = ["CoT", "Predict"]

# (input field, output field) pairs; each becomes a DSPy signature string "input -> output".
_SIGNATURE_FIELDS = [
    ("question", "answer"), ("text", "summary"), ("question", "reasoning"), ("question", "hypothesis"),
    ("problem", "solution"), ("problem_description", "explanation"), ("context", "summary"), ("context", "briefing"),
    ("word_problem", "solution"), ("math_problem", "answer"), ("prompt", "response"), ("query", "response"),
    ("text", "response"), ("prompt", "generated_text"), ("query", "generated_text"),
]
SIGNATURES = [f"{source} -> {target}" for source, target in _SIGNATURE_FIELDS]

# Every action the policy can emit.
ACTIONS = [*MODULES, *SIGNATURES, "stop"]

**Policy Model Setup**

In [None]:
# === Initialize Policy Model (GPT-2) ===
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token; reuse EOS for padding

policy_model = GPT2LMHeadModel.from_pretrained("gpt2")
policy_model.config.pad_token_id = tokenizer.pad_token_id

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_model.to(device)

# Resume from previously saved policy weights, if any.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file -- only load checkpoints you produced yourself.
save_path = "trained_policy_model.pt"
try:
    policy_model.load_state_dict(torch.load(save_path, map_location=device))
    print(f"✅ Loaded existing model from {save_path}")
except FileNotFoundError:
    print("⚠️ No existing model found, starting fresh...")


**Helper Functions**

In [None]:
# === Answer Comparison Function ===
# A number such as "-3", "2.5", ".5" or "1,000" (comma thousands separators allowed).
_NUMBER_PATTERN = r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d*\.?\d+'


def _extract_final_answer(text):
    """
    Pull the final answer out of `text`.

    Preference order: the last [bracketed] span, else the last number, else the whole
    stripped text. Returns (answer, how), where `how` names the rule that matched.
    """
    text = str(text)
    bracket_matches = re.findall(r'\[(.*?)\]', text)
    if bracket_matches:
        return bracket_matches[-1].strip(), "bracket match"
    number_matches = re.findall(_NUMBER_PATTERN, text)
    if number_matches:
        return number_matches[-1], "numeric fallback"
    return text.strip(), "text fallback"


def _as_number(answer):
    """Parse `answer` as a number (ignoring a leading "$", thousands commas and a trailing "."), else None."""
    cleaned = answer.strip().lstrip("$").rstrip(".").strip()
    if re.fullmatch(r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?', cleaned):
        cleaned = cleaned.replace(",", "")
    if re.fullmatch(r'-?\d*\.?\d+', cleaned):
        return float(cleaned)
    return None


def _answers_match(a, b):
    """True if two extracted answers agree: numerically when both are numbers, else case-insensitively."""
    if not a or not b:
        return False
    num_a, num_b = _as_number(a), _as_number(b)
    if num_a is not None and num_b is not None:
        return abs(num_a - num_b) <= 1e-9 * max(1.0, abs(num_a), abs(num_b))
    normalize = lambda s: " ".join(s.lower().rstrip(".").split())
    return normalize(a) == normalize(b)


def compare_answers(response, ground_truth, verbose=True):
    """
    Decide whether `response` carries the same final answer as `ground_truth`.

    The final answer of each side is its last [bracketed] span, else its last number, else
    its whole text. Answers match numerically when both parse as numbers ("5" == "5.0" ==
    "$5", "1,000" == "1000"), otherwise case-insensitively ("Paris" == "paris"). The full
    ground-truth text is also accepted, so text answers containing digits (e.g. "1990s")
    can still match.

    Args:
        response: model output (any object; converted with str()). None never matches.
        ground_truth: reference answer (any object; converted with str()).
        verbose: print the extracted answers (debug output).

    Returns:
        1.0 if they match, else 0.0.
    """
    if response is None or ground_truth is None:
        return 0.0

    response_answer, how = _extract_final_answer(response)
    if verbose:
        print(f"Response {how}: {response_answer}")

    gt_answer, how = _extract_final_answer(ground_truth)
    if verbose:
        print(f"GT {how}: {gt_answer}")

    matched = (_answers_match(response_answer, gt_answer)
               or _answers_match(response_answer, str(ground_truth).strip()))
    return 1.0 if matched else 0.0


def chain_response(prompt):
    """
    Answer `prompt` with the fixed two-step baseline chain.

      1. dspy.ChainOfThought("question -> reasoning") produces reasoning for the question.
      2. dspy.Predict("question, reasoning -> answer") turns that reasoning into an answer;
         the question is prefixed with an instruction to put the final answer in [brackets].

    Returns the predicted answer ("" when the prompt is empty or no reasoning came back),
    or None if any DSPy call raises (the error is printed).
    """
    try:
        if not prompt:
            return ""

        reasoning = dspy.ChainOfThought("question -> reasoning")(question=prompt).get("reasoning", "")
        if not reasoning:
            return ""

        instructed_question = f"system instruction: Must give your final answer in square brackets without fail! e.g., [final answer] like this: [5] \n prompt:{prompt}"
        prediction = dspy.Predict("question, reasoning -> answer")(question=instructed_question, reasoning=reasoning)
        return prediction.get("answer", "")

    except Exception as e:
        print("Error during processing:", e)
        return None


def _score_actions(context_text, candidate_actions):
    """
    Score every candidate action as a continuation of `context_text` under the policy model.

    Each score is the summed log-probability of *all* of the action's tokens. Scoring only the
    first token would give identical scores to actions sharing it (e.g. "question -> answer" /
    "question -> reasoning", "context -> summary" / "context -> briefing").

    Returns a 1-D tensor with one differentiable score per candidate, in order.
    """
    # The leading space matches how an action follows "Pipeline:" in the context text.
    action_ids = [tokenizer.encode(" " + a, add_special_tokens=False) for a in candidate_actions]
    context_ids = tokenizer.encode(context_text, add_special_tokens=False)

    # Keep the *end* of an over-long context (the "Pipeline: ..." part the action continues),
    # leaving room for the longest action inside the model's position limit.
    # GPT-2 configs expose `n_positions`; 1024 is only a fallback for other configs (TODO confirm).
    max_positions = getattr(policy_model.config, "n_positions", 1024)
    longest_action = max(len(ids) for ids in action_ids)
    context_ids = context_ids[-(max_positions - longest_action):]
    n_ctx = len(context_ids)

    # Right-pad all context+action sequences into one batch.
    sequences = [context_ids + ids for ids in action_ids]
    width = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), width), tokenizer.pad_token_id, dtype=torch.long, device=device)
    attention_mask = torch.zeros_like(input_ids)
    for row, seq in enumerate(sequences):
        input_ids[row, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
        attention_mask[row, :len(seq)] = 1

    # Run the transformer body once, then project only the positions we need through the LM head;
    # projecting every position to the full vocabulary would waste a lot of memory.
    # NOTE(review): assumes the LM head is a plain projection of the final hidden state (true for GPT-2).
    hidden = policy_model.base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    lm_head = policy_model.get_output_embeddings()

    scores = []
    for row, ids in enumerate(action_ids):
        # The hidden state at position t predicts token t + 1, so the action tokens at positions
        # n_ctx .. n_ctx+len(ids)-1 are predicted from positions n_ctx-1 .. n_ctx+len(ids)-2.
        token_logits = lm_head(hidden[row, n_ctx - 1 : n_ctx - 1 + len(ids)])
        token_log_probs = torch.log_softmax(token_logits, dim=-1)
        targets = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(1)
        scores.append(token_log_probs.gather(1, targets).sum())
    return torch.stack(scores)


def generate_pipeline(prompt, max_steps=3):
    """
    Sample a DSPy pipeline from the policy model, one action per step.

    Step 1 picks from MODULES, step 2 from SIGNATURES, and step 3 is the forced "stop".
    Each choice is sampled from a softmax over the policy's full-sequence log-likelihood of
    every valid action (see `_score_actions`). At most `max_steps` actions are taken.

    Returns:
        (pipeline, actions_taken, log_probs): the actions chosen, the same actions as a separate
        list, and the log-probability tensor of each sampled action (used by REINFORCE).
    """
    partial_pipeline = []
    actions_taken = []
    log_probs = []

    for _ in range(max_steps):
        if not partial_pipeline:
            valid_actions = MODULES
        elif len(partial_pipeline) == 1 and partial_pipeline[0] in MODULES:
            valid_actions = SIGNATURES
        elif len(partial_pipeline) == 2:
            valid_actions = ["stop"]
        else:
            break

        if len(valid_actions) == 1:
            # Forced move: probability 1, so log-prob 0 and no policy forward pass is needed.
            action = valid_actions[0]
            log_prob = torch.zeros((), device=device)
        else:
            context = f"Prompt: {prompt} Pipeline:" + "".join(f" {a}" for a in partial_pipeline)
            scores = _score_actions(context, valid_actions)
            choice = torch.multinomial(torch.softmax(scores, dim=-1).detach(), 1).item()
            action = valid_actions[choice]
            log_prob = torch.log_softmax(scores, dim=-1)[choice]

        actions_taken.append(action)
        log_probs.append(log_prob)
        partial_pipeline.append(action)

        if action == "stop":
            break

    return partial_pipeline, actions_taken, log_probs

def execute_pipeline(prompt, pipeline):
    """
    Run a sampled pipeline and, for comparison, the fixed baseline chain.

    `pipeline` is [module, signature], optionally followed by "stop" (ignored). `module` must be
    in MODULES ("CoT" -> dspy.ChainOfThought, otherwise dspy.Predict) and `signature` in
    SIGNATURES, a single-input "input -> output" DSPy signature string. The prompt is passed as
    the input field, prefixed with an instruction to answer in [brackets].

    Returns:
        (response, fixedPipeLineResponse): the sampled program's output field and
        `chain_response(prompt)`. (None, None) for a malformed pipeline or a failed DSPy call,
        so callers can always unpack two values.
    """
    # A trailing "stop" only marks the end of sampling; it is not an executable step.
    if pipeline and pipeline[-1] == "stop":
        pipeline = pipeline[:-1]

    if len(pipeline) != 2 or pipeline[0] not in MODULES or pipeline[1] not in SIGNATURES:
        print("Invalid pipeline format")
        return None, None

    module, signature = pipeline

    if signature.count("->") != 1:
        print("Signature does not have proper format")
        return None, None

    match = re.match(r"\s*([a-zA-Z_ ]+)\s*->\s*([a-zA-Z_ ]+)\s*", signature)
    if not match:
        print("Regex match failed for signature")
        return None, None

    # Field names become keyword arguments / attribute keys, so make them identifier-like.
    inputfield = match.group(1).strip().replace(" ", "_").lower()
    outputfield = match.group(2).strip().replace(" ", "_").lower()

    try:
        program = dspy.ChainOfThought(signature) if module == "CoT" else dspy.Predict(signature)
        response = program(**{inputfield: f"system instruction: Must give your final answer in square brackets without fail! e.g., [final answer] like this: [5] \n prompt: {prompt}"})
        fixedPipeLineResponse = chain_response(prompt)
        print("fixed: ", fixedPipeLineResponse)
        return response.get(outputfield), fixedPipeLineResponse
    except Exception as e:
        print(f"Execution pipeline failed: {e}")
        return None, None


def _llm_judge_score(prompt, response, ground_truth, tag=""):
    """
    Ask the global DSPy LM `lm` to grade `response` against `ground_truth`.

    The judge is asked to reply with a score in square brackets. The last bracketed value is
    used, falling back to the last number in the reply. `tag` is only appended to the debug
    print labels (e.g. "2" for the fixed pipeline).

    Returns the score clamped to [0.0, 1.0], or 0.0 if the call fails or no finite number can be
    parsed from the reply.
    """
    import math

    eval_prompt = f"""
        Evaluate whether the following response correctly answers the prompt based on the ground truth and give your final score in the square brackets[final score]. only the value in the [] should in your response and nothing else.
        Prompt: {prompt}
        Response: {response}
        Ground Truth: {ground_truth}
        Return a score in range between 1.0 to 0.0 if the response is correct or partially correct (matches or is equivalent to the ground truth), or 0.0 if incorrect.
        final score:[]
            """
    try:
        llm_response = lm(eval_prompt)
        print(f"LLM Response{tag}: ", llm_response)
        reply = llm_response[0]
        bracket_matches = re.findall(r'\[(.*?)\]', reply)
        print(f"llm matches{tag}: ", bracket_matches)
        if bracket_matches:
            score_str = bracket_matches[-1].strip()  # take the last bracketed content
            print(f"llm Response bracket match{tag}: {score_str}")
        else:
            # Fallback: extract the last numeric value if there are no brackets
            number_matches = re.findall(r'-?\d*\.?\d+', reply)
            score_str = number_matches[-1] if number_matches else None
            print(f"llm Response last numeric match{tag}: {score_str}")
        if score_str is None:
            raise ValueError("no score found in the judge reply")
        score = float(score_str)
        print(f"LLM score{tag}: {score_str}, in float: {score}")
    except Exception as exc:
        # A failed or unparsable judgement must not abort a long training run.
        print("LLM evaluation failed, defaulting to 0.0")
        print("  reason:", repr(exc))
        return 0.0

    if not math.isfinite(score):  # e.g. "[nan]" would otherwise poison the reward baseline
        print("LLM evaluation failed, defaulting to 0.0")
        return 0.0
    return min(max(score, 0.0), 1.0)


def compute_reward(prompt, response, fixedPipeLineResponse, ground_truth):
    """
    Reward for one episode, in [0.0, 1.0].

    Both the sampled pipeline's `response` and the fixed baseline chain's
    `fixedPipeLineResponse` are scored: first by exact final-answer match (`compare_answers`),
    then, only for a response that does not match, by the LLM judge. The reward is the higher
    of the two scores.

    Returns 0.0 when the sampled pipeline produced no response.
    """
    if response is None:
        print("Response is None")
        return 0.0

    regex1 = compare_answers(response, ground_truth)
    regex2 = compare_answers(fixedPipeLineResponse, ground_truth)
    print("regex1: ", regex1)
    print("regex2: ", regex2)

    # An exact match is already the maximum reward; skip the slow judge calls.
    if max(regex1, regex2) >= 1.0:
        return max(regex1, regex2)

    score1 = regex1 if regex1 > 0.0 else _llm_judge_score(prompt, response, ground_truth)
    if score1 >= 1.0:
        return score1

    if regex2 > 0.0:
        score2 = regex2
    elif fixedPipeLineResponse is None:
        score2 = 0.0  # the baseline chain failed; nothing to judge
    else:
        score2 = _llm_judge_score(prompt, fixedPipeLineResponse, ground_truth, tag="2")
    return max(score1, score2)


# Run Test models
def test_model(trained_model, test_prompt):
    trained_model.eval()
    with torch.no_grad():
        pipeline, actions, _ = generate_pipeline(test_prompt)
        response = execute_pipeline(test_prompt, pipeline[:-1] if pipeline[-1]=="stop" else pipeline)
        print("=== Test Result ===")
        print(f"Prompt: {test_prompt}")
        print(f"Pipeline: {pipeline}")
        print(f"Actions: {actions}")
        print(f"Response: {response}")
    return response

**Training Model**

In [None]:
def train_rl(num_episodes=500, learning_rate=1e-4, save_path="trained_policy_model.pt"):
    """
    Train the policy model with REINFORCE and an exponential-moving-average reward baseline.

    Each episode draws a random training prompt, samples a pipeline from the policy, executes
    it, scores it with `compute_reward`, and takes one Adam step on -sum(log_prob * advantage).
    The weights are saved to `save_path` at the end.

    Returns:
        (policy_model, rewards, losses), where rewards/losses are per-episode lists.
    """
    optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
    rewards, losses = [], []
    baseline, baseline_alpha = 0.0, 0.9

    for episode in range(num_episodes):
        idx = np.random.randint(len(prompts))
        prompt, ground_truth = prompts[idx], ground_truths[idx]

        pipeline, _, log_probs = generate_pipeline(prompt)
        runnable = pipeline[:-1] if pipeline and pipeline[-1] == "stop" else pipeline
        outcome = execute_pipeline(prompt, runnable)
        # Malformed pipelines may come back as a bare None rather than a (response, fixed) pair.
        response, fixedPipeLineResponse = outcome if isinstance(outcome, tuple) else (outcome, None)

        reward = compute_reward(prompt, response, fixedPipeLineResponse, ground_truth)

        # REINFORCE with baseline: the advantage uses the baseline from *previous* episodes,
        # then the baseline is updated with this episode's reward.
        advantage = reward - baseline
        baseline = baseline_alpha * baseline + (1 - baseline_alpha) * reward

        loss = -sum(lp * advantage for lp in log_probs)
        # Only step when the loss is connected to the policy (e.g. not when no action was sampled).
        if torch.is_tensor(loss) and loss.requires_grad:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_value = float(loss)

        rewards.append(reward)
        losses.append(loss_value)

        print(f"Episode {episode+1}/{num_episodes} | Reward={reward:.3f}, Loss={loss_value:.3f}, Pipeline={pipeline}")

    torch.save(policy_model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")
    return policy_model, rewards, losses


**Run Training**

In [None]:
# Change the parameters to suit your needs
NUM_EPISODES = 5      # number of REINFORCE episodes
LEARNING_RATE = 2e-5  # Adam learning rate for the policy model
trained_model, rewards, losses = train_rl(learning_rate=LEARNING_RATE, num_episodes=NUM_EPISODES)


**Load Testing Dataset**

In [None]:
# === Load test split of datasets ===
# HotpotQA examples come from train[20:40], disjoint from the train[:10] used for training.
gsm8k_test = load_dataset("gsm8k", "main", split="test[:20]")
hotpotqa_test = load_dataset("hotpot_qa", "fullwiki", split="train[20:40]", trust_remote_code=True)


def _format_hotpot_context(context):
    """
    Render a HotpotQA context as readable text: one "Title: sentence sentence ..." line per article.

    Assumes the Hugging Face `hotpot_qa` layout {"title": [...], "sentences": [[...], ...]}
    (TODO confirm for other dataset versions); anything else falls back to str(context).
    """
    if isinstance(context, dict) and "title" in context and "sentences" in context:
        return "\n".join(
            f"{title}: " + " ".join(sentence.strip() for sentence in sentences)
            for title, sentences in zip(context["title"], context["sentences"])
        )
    return str(context) if context else ""


gsm8k_test_data = [
    dspy.Example(
        question=ex['question'],
        answer=ex["answer"].split("####")[-1].strip(),
        task_type='math',
    ).with_inputs('question')
    for ex in gsm8k_test
]

hotpotqa_test_data = [
    dspy.Example(
        question=ex['question'],
        answer=ex['answer'],
        # Readable text instead of a raw dict repr when the context is prepended to the prompt.
        context=_format_hotpot_context(ex.get("context", "")),
        task_type='qa',
    ).with_inputs('question')
    for ex in hotpotqa_test
]

print(f"GSM8K test: {len(gsm8k_test_data)} | HotpotQA test: {len(hotpotqa_test_data)}")


**Testing Model**

In [None]:
def evaluate_module(dataset, model, query=False, output_file="evaluation_results.json"):
    """
    Evaluate the trained policy on `dataset` and save per-example results to JSON.

    For each example a pipeline is sampled and executed via `test_model`. The prediction counts
    as correct if it equals or contains the ground-truth answer (case-insensitive). Otherwise the
    LLM judge (`lm`) is asked for a [1]/[0] verdict.

    Args:
        dataset: iterable of examples with `question` and `answer` attributes and, optionally,
            `context`.
        model: the trained policy model, passed through to `test_model`.
        query: if True, prepend the example's context (when it has one) to the question.
        output_file: path of the JSON file the per-example results are written to.

    Returns:
        (total_time, total_correct): wall-clock seconds spent evaluating, and the number of
        examples scored correct.
    """
    import json  # json is not among the notebook's top-level imports

    total_correct = 0
    results = []
    start_time = time.time()

    print("Query mode:", query)

    for i, example in enumerate(dataset):
        print(f"------ Iteration {i} ------")
        # Not every dataset has a context (GSM8K examples don't), so only prepend one if present.
        context = getattr(example, "context", None) if query else None
        prompt = f"{context}\n\n{example.question}" if context else example.question

        prediction = test_model(model, prompt)
        # test_model hands back execute_pipeline's (response, fixed-chain response) pair;
        # the sampled pipeline's response is the prediction under evaluation.
        if isinstance(prediction, tuple):
            prediction = prediction[0] if prediction else None
        prediction_text = "" if prediction is None else str(prediction)

        # direct comparison (case-insensitive equality or containment)
        predicted = prediction_text.strip().lower()
        expected = str(example.answer).strip().lower()
        is_correct = bool(predicted and expected) and (predicted == expected or expected in predicted)

        score = 1 if is_correct else 0

        if not is_correct:  # fallback evaluation using LLM
            eval_prompt = f"""
            Evaluate if the response correctly answers the prompt based on the ground truth.
            Return [1] if correct, [0] if incorrect.

            Prompt: {example.question}
            Response: {prediction_text}
            Ground Truth: {example.answer}
            Final score: []
            """
            try:
                llm_res = lm(eval_prompt)
                # Use the judge's last [1]/[0] verdict; a bare "1" anywhere in the reply is not a verdict.
                verdicts = re.findall(r'\[\s*([01])(?:\.0*)?\s*\]', llm_res[0])
                score = int(verdicts[-1]) if verdicts else 0
            except Exception as exc:
                print("LLM evaluation failed, counting as incorrect:", repr(exc))
                score = 0

        total_correct += score

        result = {
            "question": example.question,
            "context": context if context is not None else "N/A",
            "response": prediction_text,
            "ground_truth": example.answer,
            "score": score
        }
        results.append(result)

    total_time = time.time() - start_time

    # Save results
    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)

    print(f"✅ Evaluation complete. Total correct = {total_correct}/{len(dataset)}")
    return total_time, total_correct


**Run Testing**

In [None]:
# Evaluate on HotpotQA test subset (context + question as the prompt).
# Each run writes its own results file so the GSM8K run below doesn't overwrite this one.
total_time, total_correct = evaluate_module(
    hotpotqa_test_data, trained_model, query=True, output_file="hotpotqa_evaluation_results.json"
)
print(f"Evaluation time: {total_time:.2f}s | Correct: {total_correct}")


In [None]:
# Evaluate on GSM8K test subset. GSM8K examples carry no `context`, so query mode stays off.
total_time, total_correct = evaluate_module(
    gsm8k_test_data, trained_model, query=False, output_file="gsm8k_evaluation_results.json"
)
print(f"Evaluation time: {total_time:.2f}s | Correct: {total_correct}")
