In [4]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

import random
import json
from typing import Dict, List, Tuple
from deap import base, creator, tools, algorithms

# Pull variables from a local .env file (if one exists) into the process
# environment, then build the chat model that the mutation operator wraps.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
llm = ChatOpenAI(api_key=openai_api_key)

In [6]:
from src.evolution.operators import Operator

class PromptStr(str):
    """Immutable prompt text that additionally carries a ``diff`` attribute.

    ``diff`` is intended to hold the phrase changed by the most recent edit,
    for traceability; it defaults to an empty string. Because ``str`` is
    immutable, the extra field must be attached in ``__new__``.
    """

    def __new__(cls, value, diff=""):
        instance = super().__new__(cls, value)
        instance.diff = diff
        return instance


def mutate_prompt_individual(ind, *, operator=None, **llm_kwargs):
    """Mutate a prompt individual using the LLM-backed mutation operator.

    Python strings are immutable, so the prompt text cannot be rewritten in
    place. Instead, a new individual of the same type as ``ind`` is built from
    the mutated text. For DEAP ``creator`` classes this also instantiates a
    fresh ``fitness``, so the mutant starts out unevaluated.

    Args:
        ind: Individual to mutate. Expected to be a ``PromptStr`` (or a subclass
            such as ``creator.Individual``) whose constructor accepts
            ``(value, diff=...)``. It is left unchanged.
        operator: Optional object exposing
            ``mutate_prompt(prompt=..., **kwargs) -> (clean, diff)``. When
            omitted, a new ``Operator(llm)`` is created for this call.
        **llm_kwargs: Forwarded unchanged to ``operator.mutate_prompt``.

    Returns:
        A 1-tuple ``(mutant,)``, following DEAP's mutation-operator
        convention. ``mutant.diff`` holds the diff reported by the operator.
    """
    if operator is None:
        operator = Operator(llm)
    clean, diff = operator.mutate_prompt(prompt=str(ind), **llm_kwargs)
    # ``ind.__init__(clean)`` cannot change a str's value; construct a new
    # instance of the same class instead.
    mutant = type(ind)(clean, diff=diff)
    return (mutant,)

In [2]:
# Single-objective fitness; the positive weight makes DEAP maximise it.
creator.create(name="FitnessMax", base=base.Fitness, weights=(1.0,))

# An individual is a PromptStr that also carries a `fitness` of the type above.
creator.create(name="Individual", base=PromptStr, fitness=creator.FitnessMax)

In [4]:
def load_prompts(file_path='../data/data2.json'):
    """Load the seed prompts from a JSON file.

    The file is expected to contain a JSON array of objects, each with a
    ``"prompt"`` key. Only those values are returned, in file order.

    Args:
        file_path: Path to the JSON file. The default is relative to the
            notebook's working directory.

    Returns:
        list: The ``"prompt"`` value of every item in the array.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an item has no ``"prompt"`` key.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding,
    # which can mis-decode non-ASCII prompts (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [item["prompt"] for item in data]

# Pool of seed prompts, loaded once when this cell runs.
prompts = load_prompts()


def get_random_prompt():
    """Return one prompt drawn uniformly at random from the ``prompts`` pool.

    Used as the generator passed to ``tools.initIterate`` when building
    individuals. ``random.choice`` raises ``IndexError`` if the pool is empty.
    """
    pool = prompts
    return random.choice(pool)

In [5]:
from src.evolution.operators import Operator


toolbox = base.Toolbox()

# Population construction: each individual wraps one randomly drawn seed prompt.
toolbox.register("individual", tools.initIterate, creator.Individual, get_random_prompt)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

# Variation: extra keyword arguments are forwarded by mutate_prompt_individual
# to the operator's mutate_prompt call.
toolbox.register("mutate", mutate_prompt_individual, trigger_id="RM", dim_id=None)

# Selection: tournaments of 3 individuals.
toolbox.register("select", tools.selTournament, tournsize=3)

# TODO: register crossover once `crossover_prompts` exists:
#   toolbox.register("mate", crossover_prompts)
# TODO: register fitness evaluation once `evaluate_fitness` exists. The earlier
#   draft passed `evaluate_fitness` twice, which would bind the function as its
#   own first argument; it should presumably be:
#   toolbox.register("evaluate", evaluate_fitness)