**Installing Dependencies**

In [None]:
!pip install -U dspy
!pip install datasets
!pip install torch transformers

**Setting Up Ollama**

In [None]:
# Install the Ollama runtime by piping the install script from ollama.com into sh.
!curl -fsSL https://ollama.com/install.sh | sh

In [None]:
import subprocess
import time
import urllib.error
import urllib.request

# Base URL of the local Ollama server (the DSPy LM is configured with the same address).
OLLAMA_BASE_URL = "http://localhost:11434"


def _ollama_is_up(url=OLLAMA_BASE_URL, timeout=1.0):
    """Return True if an HTTP server answers at `url` within `timeout` seconds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except urllib.error.HTTPError:
        # Any HTTP status code means something is listening on the port.
        return True
    except (urllib.error.URLError, OSError):
        return False


# Start `ollama serve` in the background only if nothing is listening yet, then wait
# until it answers so the following `ollama pull` does not race the server start-up.
if not _ollama_is_up():
    process = subprocess.Popen("ollama serve", shell=True)
    for _ in range(30):
        if _ollama_is_up():
            break
        time.sleep(1)
    else:
        print(f"Warning: Ollama server did not respond at {OLLAMA_BASE_URL} within 30 s")

In [None]:
# Download the llama3.1:8b model that the DSPy LM ('ollama_chat/llama3.1:8b') uses.
!ollama pull llama3.1:8b

**Import Libraries**

In [None]:
# === Imports ===
import random
import re
import socket
import subprocess
import time
from pprint import pprint

import dspy
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.optim import Adam
from transformers import GPT2LMHeadModel, GPT2Tokenizer  # Change the models to your preferred model
# The Ollama server is normally already started in the "Setting Up Ollama" section.
# A second `ollama serve` would most likely fail to bind the same port, and `process`
# would then hold that dead handle. Only launch one if nothing listens on Ollama's port.
try:
    socket.create_connection(("localhost", 11434), timeout=1).close()
except OSError:
    process = subprocess.Popen("ollama serve", shell=True)


**Configure DSPy & LLM**

In [None]:
# === Configure DSPy with Ollama LLM ===
# Route DSPy programs (and the LLM judge, which calls `lm` directly) to the local
# Ollama server; api_key is left empty for this local endpoint.
lm = dspy.LM(
    'ollama_chat/llama3.1:8b',
    api_base='http://localhost:11434',
    api_key='',
)
dspy.configure(lm=lm)

**Load Training Datasets**

In [None]:
# === Load GSM8K and HotpotQA for training/testing ===
# Seed every RNG the notebook draws from so runs are reproducible: the prompt shuffle
# below, np.random prompt sampling in train_ppo, and torch.multinomial action sampling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Loading datasets...")
gsm8k = load_dataset("gsm8k", "main", split="train[:10]")
hotpotqa = load_dataset("hotpot_qa", "fullwiki", split="train[:10]", trust_remote_code=True)

gsm8k_list = list(gsm8k)
hotpotqa_list = list(hotpotqa)

# Combine prompts and ground truths.
# GSM8K answers end with "#### <final answer>"; only the text after the last "####" is kept.
# HotpotQA answers are used as-is.
data = [(ex["question"], ex["answer"].split("####")[-1].strip()) for ex in gsm8k_list] \
     + [(ex["question"], ex["answer"]) for ex in hotpotqa_list]

random.shuffle(data)
prompts, ground_truths = zip(*data)

prompts = list(prompts)
ground_truths = list(ground_truths)


**Action Space (Module & Signature Pool)**

In [None]:
# === Define Action Space ===
# A pipeline is built in order: one DSPy module type, then one signature string,
# then the terminal "stop" action. List order matters: actions are chosen by index.
MODULES = ["CoT", "Predict"]

SIGNATURES = [
    "question -> answer",
    "text -> summary",
    "question -> reasoning",
    "question -> hypothesis",
    "problem -> solution",
    "problem_description -> explanation",
    "context -> summary",
    "context -> briefing",
    "word_problem -> solution",
    "math_problem -> answer",
    "prompt -> response",
    "query -> response",
    "text -> response",
    "prompt -> generated_text",
    "query -> generated_text",
]

ACTIONS = [*MODULES, *SIGNATURES, "stop"]

**Policy Model Setup**

In [None]:
# Device used by the policy, the value head and every tokenized batch in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 has no pad token; reuse EOS so the `padding=True` tokenizer calls work.
tokenizer.pad_token = tokenizer.eos_token
policy_model = GPT2LMHeadModel.from_pretrained("gpt2")
policy_model.config.pad_token_id = tokenizer.pad_token_id
policy_model.to(device)

print("Policy model loaded")

# ============================================================
class ValueHead(nn.Module):
    """Linear critic mapping a hidden-state vector to a scalar value estimate."""

    def __init__(self, hidden_size):
        super().__init__()
        self.value = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """Map hidden states of shape (..., hidden_size) to values of shape (..., 1)."""
        return self.value(hidden_states)

value_head = ValueHead(policy_model.config.n_embd).to(device)
print("Value head initialized")

# ============================================================
# A single optimizer updates the policy (GPT-2) and the value head jointly.
optimizer = Adam(list(policy_model.parameters()) + list(value_head.parameters()), lr=1e-4)
print("Optimizer initialized")

**Helper Functions**

In [None]:
def compute_reward(prompt, response, ground_truth):
    """
    Compute reward based on response accuracy compared to ground truth using LLM.

    Args:
        prompt: The question the pipeline was asked.
        response: The pipeline output, or None when the pipeline failed.
        ground_truth: Reference answer for `prompt`.

    Returns:
        A float in [0.0, 1.0]. Returns 0.0 when `response` is None, when the judge
        call raises, or when no number can be found in the judge's reply.
    """
    if response is None:
        return 0.0

    eval_prompt = f"""
    Evaluate the correctness of the response compared to the ground truth for the given prompt.
    Provide a numerical score between 0.0 and 1.0, where 1.0 means completely correct and 0.0 means completely incorrect.
    Return only the numeric score (e.g., 0.75) with no additional text.

    Prompt: {prompt}
    Response: {response}
    Ground Truth: {ground_truth}
    """

    try:
        llm_response = lm(eval_prompt)[0].strip()
    except Exception as e:
        print(f"[LLM Judge] Error during evaluation: {e}. Defaulting to 0.0")
        return 0.0
    print(f"[LLM Judge] Response: '{llm_response}'")

    # Extract the first decimal number instead of float()-ing the whole reply.
    # Judges often add words ("Score: 0.8"), and float() also accepts "nan", which the
    # clamp below would turn into 1.0 because min(1.0, nan) == 1.0.
    match = re.search(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)", llm_response)
    if match is None:
        print(f"[LLM Judge] Invalid score format: '{llm_response}'. Defaulting to 0.0")
        return 0.0

    return max(0.0, min(1.0, float(match.group())))

# ============================================================
def execute_pipeline(prompt, pipeline):
    """
    Run a two-step [module, signature] pipeline on `prompt` through DSPy.

    Args:
        prompt: The user question, embedded in a fixed instruction string.
        pipeline: Two-element sequence: a name from MODULES and a string from SIGNATURES.

    Returns:
        The program's value for the signature's output field, or None if the pipeline
        is malformed or the DSPy call raises.
    """
    is_valid = (
        len(pipeline) == 2
        and pipeline[0] in MODULES
        and pipeline[1] in SIGNATURES
    )
    if not is_valid:
        print("Invalid pipeline format")
        return None

    module_name, signature = pipeline

    # Turn "input -> output" into snake_case input/output field names.
    parsed = re.match(r"\s*([a-zA-Z_ ]+)\s*->\s*([a-zA-Z_ ]+)\s*", signature)
    input_name, output_name = (
        parsed.group(side).strip().replace(" ", "_").lower() for side in (1, 2)
    )

    # Unknown module names fall back to Predict; "PoT" only applies if added to MODULES.
    builder_attr = {"CoT": "ChainOfThought", "PoT": "ProgramOfThought"}.get(module_name, "Predict")
    instruction = f"system instruction: Must give your final answer in square brackets without fail! e.g., [final answer] like this: [5] \n prompt: {prompt}"

    try:
        program = getattr(dspy, builder_attr)(signature)
        prediction = program(**{input_name: instruction})
        return prediction.get(output_name)
    except Exception as e:
        print(f"Execution pipeline failed: {e}")
        return None

# ============================================================
def generate_pipeline_ppo(prompt, max_steps=3):
    """
    Roll out one DSPy pipeline for `prompt` with the current policy, collecting PPO data.

    At each step the policy's next-token logits are restricted to the token ids of the
    currently valid actions (MODULES, then SIGNATURES, then "stop"). An action is
    sampled and the value head scores the last token's final hidden state. The finished
    pipeline is executed and judged, and that reward goes to the last step (earlier
    steps get 0.0).

    Args:
        prompt: Question to build a pipeline for. It must appear in the global `prompts`
            list so its ground truth can be looked up.
        max_steps: Maximum number of actions to sample.

    Returns:
        (states, actions, log_probs, rewards, values): per-step input texts, chosen
        actions, detached log-probabilities (0-dim tensors), rewards, and float values.
    """
    state = {"prompt": prompt, "partial_pipeline": []}
    states, actions, log_probs, values, rewards = [], [], [], [], []

    # Rollouts only collect data. train_ppo recomputes log-probs and values with
    # gradients, so no autograd graph is needed here (saves memory per forward pass).
    with torch.no_grad():
        for step in range(max_steps):
            input_text = f"Prompt: {state['prompt']} Pipeline: {' '.join(state['partial_pipeline'])}"
            inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)

            outputs = policy_model(**inputs, output_hidden_states=True)
            logits = outputs.logits[:, -1, :]

            # Determine valid actions: a module first, then a signature, then "stop".
            if not state["partial_pipeline"]:
                valid_actions = MODULES
            elif len(state["partial_pipeline"]) == 1 and state["partial_pipeline"][0] in MODULES:
                valid_actions = SIGNATURES
            elif len(state["partial_pipeline"]) == 2:
                valid_actions = ["stop"]
            else:
                break

            # Map actions to token IDs.
            # NOTE(review): only the FIRST token of each action string is used. Actions
            # sharing a leading token (e.g. the "question -> ..." signatures, if "question"
            # encodes to a single token) get identical logits and cannot be told apart.
            valid_token_ids = [tokenizer.encode(a, add_special_tokens=False)[0] for a in valid_actions]
            action_logits = logits[0, valid_token_ids]
            # log_softmax is numerically safer than log(softmax(...)).
            action_log_probs = torch.log_softmax(action_logits, dim=-1)
            action_probs = action_log_probs.exp()

            # Sample action
            action_idx = torch.multinomial(action_probs, 1).item()
            action = valid_actions[action_idx]
            log_prob = action_log_probs[action_idx]

            # Compute value from the last token's final-layer hidden state
            hidden_states = outputs.hidden_states[-1][:, -1, :]
            value = value_head(hidden_states).item()

            # Debug output
            print(f"Step {step+1}:")
            print(f"  Input: {input_text}")
            print(f"  Valid Actions: {valid_actions}")
            print(f"  Selected Action: {action}")
            print(f"  Log Prob: {log_prob.item():.4f}, Value: {value:.4f}\n")

            # Store trajectory data
            states.append(input_text)
            actions.append(action)
            log_probs.append(log_prob.detach())
            values.append(value)
            rewards.append(0.0)

            state["partial_pipeline"].append(action)
            if action == "stop":
                break

    # Execute pipeline (without the trailing "stop") and compute final reward
    pipeline = state["partial_pipeline"][:-1] if state["partial_pipeline"] and state["partial_pipeline"][-1] == "stop" else state["partial_pipeline"]
    response = execute_pipeline(prompt, pipeline)
    ground_truth = ground_truths[prompts.index(prompt)]
    reward = compute_reward(prompt, response, ground_truth)
    # Sparse terminal reward; there is no step to credit when nothing was sampled.
    if rewards:
        rewards[-1] = reward

    print(f"Final Pipeline: {pipeline}")
    print(f"Response: {response}")
    print(f"Computed Reward: {reward}\n")

    return states, actions, log_probs, rewards, values

# ============================================================
def compute_gae(rewards, values, next_value, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation (GAE) for one trajectory.

    Walking backwards over steps t:
        delta_t = r_t + gamma * V_{t+1} - V_t   (V_{t+1} is `next_value` at the last step)
        A_t     = delta_t + gamma * lam * A_{t+1}

    Args:
        rewards: Per-step rewards.
        values: Per-step value estimates (at least as long as `rewards`).
        next_value: Bootstrap value for the state after the final step.
        gamma: Discount factor.
        lam: GAE smoothing factor.

    Returns:
        List of advantages aligned with `rewards`.
    """
    horizon = len(rewards)
    advantages = [0.0] * horizon
    running = 0
    for t in range(horizon - 1, -1, -1):
        bootstrap = next_value if t == horizon - 1 else values[t + 1]
        delta = rewards[t] + gamma * bootstrap - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

def test_model(trained_model, value_head, test_prompt, max_steps=3):
    """
    Test the trained model by generating and executing a pipeline.
    """
    trained_model.eval()
    with torch.no_grad():
        state = {"prompt": test_prompt, "partial_pipeline": []}
        for step in range(max_steps):
            input_text = f"Prompt: {state['prompt']} Pipeline: {' '.join(state['partial_pipeline'])}"
            inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = trained_model(**inputs, output_hidden_states=True)
            logits = outputs.logits[:, -1, :]

            if not state["partial_pipeline"]:
                valid_actions = MODULES
            elif len(state["partial_pipeline"]) == 1 and state["partial_pipeline"][0] in MODULES:
                valid_actions = SIGNATURES
            else:
                valid_actions = ["stop"]

            valid_token_ids = [tokenizer.encode(a, add_special_tokens=False)[0] for a in valid_actions]
            action_logits = logits[0, valid_token_ids]
            action_probs = torch.softmax(action_logits, dim=-1)
            action_idx = torch.argmax(action_probs).item()
            action = valid_actions[action_idx]

            state["partial_pipeline"].append(action)
            print(f"Test Step {step+1}: Selected Action: {action}")

            if action == "stop":
                break

        pipeline = state["partial_pipeline"][:-1] if state["partial_pipeline"] and state["partial_pipeline"][-1] == "stop" else state["partial_pipeline"]
        response = execute_pipeline(test_prompt, pipeline)
        print(f"\nTest Prompt: {test_prompt}")
        print(f"Generated Pipeline: {pipeline}")
        print(f"Response: {response}")
        return response


**Training Model**

In [None]:
def _valid_actions_for_steps(actions):
    """Rebuild, for each step of one trajectory, the action set that step was sampled from."""
    partial_pipeline, valid_per_step = [], []
    for action in actions:
        # Same rule as generate_pipeline_ppo: a module first, then a signature, then "stop".
        if not partial_pipeline:
            valid_actions = MODULES
        elif len(partial_pipeline) == 1 and partial_pipeline[0] in MODULES:
            valid_actions = SIGNATURES
        else:
            valid_actions = ["stop"]
        valid_per_step.append(valid_actions)
        partial_pipeline.append(action)
    return valid_per_step


def train_ppo(num_episodes=500, clip_eps=0.2, epochs_per_batch=10, batch_size=10):
    """
    Train the policy model using PPO with GAE and batching.

    Each episode rolls out `batch_size` pipelines on randomly drawn training prompts
    (via generate_pipeline_ppo) and computes GAE advantages and returns per trajectory.
    It then makes `epochs_per_batch` passes over the collected steps, taking one
    optimizer step per step (clipped policy loss + 0.5 * value loss - 0.01 * entropy).

    Args:
        num_episodes: Number of collect-then-update rounds.
        clip_eps: PPO clipping range; the probability ratio is clamped to [1 - eps, 1 + eps].
        epochs_per_batch: Update passes over each episode's collected steps.
        batch_size: Trajectories collected per episode.

    Returns:
        (policy_model, value_head). Their state dicts are also saved to
        "ppo_policy_model.pt" and "ppo_value_head.pt".
    """
    rewards_history = []

    for episode in range(num_episodes):
        trajectories = []
        print(f"\n=== Episode {episode + 1}/{num_episodes} ===")
        for _ in range(batch_size):
            idx = np.random.randint(len(prompts))
            prompt = prompts[idx]
            traj = generate_pipeline_ppo(prompt)
            trajectories.append(traj)

        # Compute advantages and returns
        all_states, all_actions, all_valid_actions = [], [], []
        all_log_probs, all_advantages, all_returns = [], [], []
        for states, actions, log_probs, rewards, values in trajectories:
            advantages = compute_gae(rewards, values, next_value=0)
            returns = [adv + val for adv, val in zip(advantages, values)]
            all_states.extend(states)
            all_actions.extend(actions)
            # Derive each step's valid actions from the action history rather than by
            # re-parsing the "Pipeline: ..." text, which breaks if a prompt itself
            # contains "Pipeline: ".
            all_valid_actions.extend(_valid_actions_for_steps(actions))
            all_log_probs.extend(log_probs)
            all_advantages.extend(advantages)
            all_returns.extend(returns)

        if not all_states:
            print(f"Episode {episode + 1}: no steps collected, skipping update")
            continue

        # Convert to tensors
        all_log_probs = torch.stack(all_log_probs)
        all_advantages = torch.tensor(all_advantages, dtype=torch.float32).to(device)
        all_returns = torch.tensor(all_returns, dtype=torch.float32).to(device)

        # PPO update
        for epoch in range(epochs_per_batch):
            for i, (state, action, valid_actions) in enumerate(zip(all_states, all_actions, all_valid_actions)):
                inputs = tokenizer(state, return_tensors="pt", padding=True, truncation=True).to(device)
                outputs = policy_model(**inputs, output_hidden_states=True)
                logits = outputs.logits[:, -1, :]

                # NOTE(review): as in generate_pipeline_ppo, only each action's first token
                # id is scored, so actions sharing a leading token share one logit.
                valid_token_ids = [tokenizer.encode(a, add_special_tokens=False)[0] for a in valid_actions]
                action_logits = logits[0, valid_token_ids]
                new_probs = torch.softmax(action_logits, dim=-1)

                try:
                    action_idx = valid_actions.index(action)
                except ValueError:
                    print(f"Action '{action}' not found in valid actions: {valid_actions}")
                    continue

                new_log_prob = torch.log(new_probs[action_idx] + 1e-10)

                # PPO clipped objective
                ratio = torch.exp(new_log_prob - all_log_probs[i])
                surr1 = ratio * all_advantages[i]
                surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * all_advantages[i]
                policy_loss = -torch.min(surr1, surr2)

                # Value loss
                hidden_states = outputs.hidden_states[-1][:, -1, :]
                new_value = value_head(hidden_states).squeeze()
                value_loss = F.mse_loss(new_value, all_returns[i])

                # Entropy bonus
                entropy = -torch.sum(new_probs * torch.log(new_probs + 1e-10))

                # Total loss
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Debug output
                if i % 5 == 0:
                    print(f"Epoch {epoch+1}, Update {i+1}: Loss = {loss.item():.4f}, Policy Loss = {policy_loss.item():.4f}, Value Loss = {value_loss.item():.4f}, Entropy = {entropy.item():.4f}")

        # The terminal reward of each trajectory is its pipeline's judged score.
        avg_reward = np.mean([traj[3][-1] for traj in trajectories])
        rewards_history.append(avg_reward)
        print(f"Episode {episode + 1} Average Reward: {avg_reward:.3f}")

    torch.save(policy_model.state_dict(), "ppo_policy_model.pt")
    torch.save(value_head.state_dict(), "ppo_value_head.pt")
    return policy_model, value_head

**Run Training**

In [None]:
# Train for 20 episodes of 10 rollouts each. Returns the policy model and value head,
# whose weights train_ppo also saves to ppo_policy_model.pt / ppo_value_head.pt.
trained_model, trained_value_head = train_ppo(num_episodes=20, clip_eps=0.2, epochs_per_batch=10, batch_size=10)

**Testing Model**

In [None]:
def _answers_match(prediction, answer):
    """Return True if `prediction` states `answer` (case-insensitive) as a whole token.

    If the prediction contains square brackets, only the last bracketed span is
    checked, because execute_pipeline instructs the LLM to put its final answer in
    [brackets]. Otherwise the whole prediction is searched. A bare substring test is
    avoided because it accepts e.g. "5" inside "[15]".
    """
    if not prediction or not answer:
        return False
    pred = str(prediction).strip().lower()
    gold = str(answer).strip().lower()
    if not gold:
        return False
    if pred == gold:
        return True
    bracketed = re.findall(r"\[([^\[\]]*)\]", pred)
    candidate = bracketed[-1].strip() if bracketed else pred
    if candidate == gold:
        return True
    return re.search(r"(?<!\w)" + re.escape(gold) + r"(?!\w)", candidate) is not None


def _llm_judge_binary(question, prediction, answer):
    """Ask the LLM judge `lm` for a 0/1 correctness verdict.

    The score is read from the last [bracketed] span that contains a digit, or else from
    the first number in the reply. Returns 1 if that score is >= 0.5, otherwise 0.
    Judge errors and replies without a number count as 0.
    """
    eval_prompt = f"""
                Evaluate whether the following response correctly answers the prompt based on the ground truth.
                Return a score of 1.0 if the response is correct and 0.0 if it is incorrect.
                Only return the score inside the square brackets [].

                Prompt: {question}
                Response: {prediction}
                Ground Truth: {answer}
                Final score:[]"""
    try:
        verdict = str(lm(eval_prompt)[0])
    except Exception as e:
        print(f"[LLM Judge] Error during evaluation: {e}. Defaulting to 0")
        return 0
    bracketed_scores = [s for s in re.findall(r"\[([^\[\]]*)\]", verdict) if re.search(r"\d", s)]
    search_space = bracketed_scores[-1] if bracketed_scores else verdict
    number = re.search(r"\d+(?:\.\d+)?", search_space)
    return 1 if number is not None and float(number.group()) >= 0.5 else 0


def evaluate_module(dataset, fulldataset, model, trained_value_head, query=False):
    """
    Evaluates the given module on a dataset of examples.

    Each example is answered via test_model. A whole-token string match is tried first;
    if it fails (or there is no ground truth), the LLM judge decides.

    Args:
        dataset: Iterable of examples exposing `.question` and `.answer` (and `.context`
            when `query` is True), e.g. dspy.Example objects.
        fulldataset: Unused; kept for call compatibility.
        model: Policy model passed to test_model.
        trained_value_head: Value head passed to test_model.
        query: If True, prepend each example's context to its question.

    Returns:
        (total_time_seconds, total_correct).
    """
    total_correct = 0
    start_time = time.time()

    print("Query mode:", query)

    for i, example in enumerate(dataset):
        print(f"-------------------------------------DATASET NUMBER : {i+1}------------------------------------------------")

        prompt = f"{example.context} \n\n {example.question} " if query else example.question
        prediction_text = test_model(model, trained_value_head, prompt)

        score = 1 if _answers_match(prediction_text, example.answer) else 0
        if not score:
            # String match failed (or no ground truth): fall back to the LLM judge.
            score = _llm_judge_binary(example.question, prediction_text, example.answer)

        total_correct += score

        print("Ground Truth: ", example.answer)
        print("Score: ", score)

    total_time = time.time() - start_time
    return total_time, total_correct

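**Loading Test Datasets**

GSM8K is evaluated on its `test[0:20]` split. HotpotQA uses `train[20:40]`, a slice disjoint from the `train[:10]` examples used for training.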

In [None]:
gsm8k_test = load_dataset("gsm8k", "main", split="test[0:20]")
hotpotqa_test = load_dataset("hotpot_qa", "fullwiki", split="train[20:40]", trust_remote_code=True)


def _to_gsm8k_example(row):
    """Wrap a GSM8K row as a math dspy.Example whose answer is the text after the last '####'."""
    final_answer = row["answer"].split("####")[-1].strip()
    return dspy.Example(question=row['question'], answer=final_answer, task_type='math').with_inputs('question')


def _to_hotpotqa_example(row):
    """Wrap a HotpotQA row as a QA dspy.Example that also carries the row's context."""
    return dspy.Example(
        question=row['question'], answer=row['answer'], context=row["context"], task_type='qa'
    ).with_inputs('question')


gsm8k_test_data = list(map(_to_gsm8k_example, gsm8k_test))
hotpotqa_test_data = list(map(_to_hotpotqa_example, hotpotqa_test))


**Run Testing**

In [None]:

# Testing HotpotQA (query=True prepends each example's context to its question)
htotal_time, htotal_correct = evaluate_module(
    dataset=hotpotqa_test_data,
    fulldataset=hotpotqa_test,
    model=policy_model,
    trained_value_head=trained_value_head,
    query=True,
)
print(f"HotPotQA - Total time: {htotal_time}, Total correct: {htotal_correct}")

In [None]:
# Testing GSM8K (question only; query defaults to False)
gtotal_time, gtotal_correct = evaluate_module(
    dataset=gsm8k_test_data,
    fulldataset=gsm8k_test,
    model=policy_model,
    trained_value_head=trained_value_head,
)
print(f"GSM8K - Total time: {gtotal_time}, Total correct: {gtotal_correct}")