In [1]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

import random
import json
from typing import Dict, List, Tuple, Optional, Any
from deap import base, creator, tools, algorithms

# Load variables from a local .env file into the process environment.
load_dotenv()
# The key is read from the environment rather than hardcoded; os.getenv returns None when it is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Chat-model client handed to GeneticOperators further down.
llm = ChatOpenAI(api_key=openai_api_key)

In [7]:
class PromptStr(str):
    """A ``str`` that additionally carries a ``diff`` attribute for traceability.

    Args:
        value: Text content of the string.
        diff: Traceability payload stored on the instance; defaults to "".
    """

    def __new__(cls, value, diff=""):
        # str is immutable, so the text must be fixed at construction time in
        # __new__; the extra attribute is then attached to the new instance.
        instance = super().__new__(cls, value)
        instance.diff = diff
        return instance

# Register the DEAP types used by this notebook on the `creator` module.
# FitnessMax: single-objective fitness; the positive weight means the score is maximized.
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
# Individual: a PromptStr (prompt text plus `diff`) that also carries a FitnessMax `fitness` attribute.
creator.create("Individual", PromptStr, fitness=creator.FitnessMax)

In [8]:
from src.evolution.genetic_operators import GeneticOperators

# Single operator object shared by mutate_prompt_individual and crossover_prompts.
gen_operator = GeneticOperators(llm)


def load_prompts(file_path='../data/data2.json'):
    """Load the prompt texts from a JSON file.

    The file must contain a JSON array of objects, each with a "prompt" key.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        List with the "prompt" value of every item, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an item has no "prompt" key.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [item["prompt"] for item in data]

# Read the prompt pool once, when this cell runs; get_random_prompt() samples from it.
prompts = load_prompts()

def get_random_prompt():
    """Pick one entry at random from the module-level ``prompts`` list."""
    chosen = random.choice(prompts)
    return chosen

In [9]:
def mutate_prompt_individual(ind: PromptStr,
                             trigger_id: Optional[str] = "REWARD_MISSPECIFICATION",
                             dim_id: Optional[str] = None,
                             **llm_kwargs: Any) -> Tuple[PromptStr]:
    """Mutate a prompt through ``gen_operator.mutate_operator`` (DEAP ``mutate`` hook).

    ``str`` objects are immutable, so the text of ``ind`` cannot be replaced in
    place: calling ``ind.__init__(new_text)`` leaves the characters untouched.
    The mutated text is therefore wrapped in a *new* individual of the same
    class as ``ind``, and callers must use the returned object.

    Args:
        ind: Prompt to mutate; only its text (``str(ind)``) is sent to the operator.
        trigger_id: Forwarded unchanged to ``mutate_operator``.
        dim_id: Forwarded unchanged to ``mutate_operator``.
        **llm_kwargs: Extra keyword arguments forwarded to ``mutate_operator``.
            A ``mutation_rate`` entry overrides the default of 0.9.

    Returns:
        One-element tuple holding the new individual; its ``diff`` attribute is
        the second value returned by ``mutate_operator``.
    """
    # Let callers override the rate instead of hitting a duplicate-keyword TypeError.
    mutation_rate = llm_kwargs.pop("mutation_rate", 0.9)
    clean, diff = gen_operator.mutate_operator(
        prompt=str(ind),
        trigger_id=trigger_id,
        dim_id=dim_id,
        mutation_rate=mutation_rate,
        **llm_kwargs
    )

    # Rebuild with type(ind) so a subclass instance (e.g. creator.Individual)
    # yields an offspring of that same subclass.
    mutant = type(ind)(clean)
    mutant.diff = diff
    return (mutant,)

def crossover_prompts(parent1: PromptStr, parent2: PromptStr) -> Tuple[PromptStr, PromptStr]:
    """Combine two prompts through ``gen_operator.crossover_operator`` (DEAP ``mate`` hook).

    ``str`` objects are immutable, so ``parent1`` cannot take on the child text
    in place: ``parent1.__init__(child)`` leaves its characters unchanged. The
    child text is instead wrapped in a *new* individual of the same class as
    ``parent1``, and callers must use the returned objects.

    Args:
        parent1: First parent; only its text is sent to the operator.
        parent2: Second parent; only its text is sent to the operator.

    Returns:
        ``(child, parent2)``: the new individual built from the operator's
        output, followed by ``parent2`` returned unchanged (the operator is
        called once and yields a single child).
    """
    child_text = gen_operator.crossover_operator(
        parent1=str(parent1),
        parent2=str(parent2)
    )
    # assumes crossover_operator returns the child prompt as one string -- TODO confirm
    # Rebuild with type(parent1) so the offspring keeps the parent's (sub)class.
    child = type(parent1)(child_text)
    return child, parent2

In [16]:
# Wire the prompt-specific pieces into a DEAP toolbox.
toolbox = base.Toolbox()
# An individual is a creator.Individual built from one randomly chosen prompt.
toolbox.register("individual", tools.initIterate, creator.Individual, get_random_prompt)
# A population is a plain list of such individuals.
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("mate", crossover_prompts)
# trigger_id / dim_id are bound here, so toolbox.mutate(ind) needs only the individual.
toolbox.register("mutate", mutate_prompt_individual, trigger_id="REWARD_MISSPECIFICATION", dim_id=None)
# Tournament selection with tournsize=3.
toolbox.register("select", tools.selTournament, tournsize=3)
# NOTE(review): no "evaluate" function is registered yet. The disabled line below also passes
# evaluate_fitness twice, which looks unintended -- confirm before enabling.
# toolbox.register("evaluate", evaluate_fitness, evaluate_fitness)