In [None]:
import dspy
import json
import yaml
import re
import httpx
import random
from collections import Counter
from typing import List, Optional
from bot.models import Config, LlmEndpoint
from bot.ai import ChatBotModule
from bot.tools import configure_web_search

# Account whose notes are requested from the server to build the training pairs.
USER_ID: str = "a825m3bdiv"  # user id to use to generate training inputs
# Fetching stops once this many (message, reply) pairs have been collected.
NUM_MESSAGES: int = 50  # Keep fetching until this many messages
# Writing sample passed to the style judge as the style the bot's replies should match.
TARGET_STYLE: str = """hi every1 im new!!!!!!! *holds up spork* my name is katy but u can call me t3h PeNgU1N oF d00m!!!!!!!! lol…as u can see im very random!!!! thats why i came here, 2 meet random ppl like me ^_^… im 13 years old (im mature 4 my age tho!!) i like 2 watch invader zim w/ my girlfreind (im bi if u dont like it deal w/it) its our favorite tv show!!! bcuz its SOOOO random!!!! shes random 2 of course but i want 2 meet more random ppl =) like they say the more the merrier!!!! lol…neways i hope 2 make alot of freinds here so give me lots of commentses!!!!
DOOOOOMMMM!!!!!!!!!!!!!!!! <--- me bein random again ^_^ hehe…toodles!!!!!
"""

# Load the bot configuration. The file is decoded as UTF-8 explicitly so the
# result does not depend on the platform's default text encoding.
with open("config.local.yaml", "r", encoding="utf-8") as f:
    config_data = yaml.safe_load(f)

# Build the model after the file handle has been released.
config = Config(**config_data)

# Fail with an actionable message instead of a bare IndexError when the
# configuration lists no LLM endpoints.
if not config.llm_endpoints:
    raise ValueError(
        "config.local.yaml must define at least one entry in llm_endpoints"
    )

# The first configured endpoint is the one used by the rest of the notebook.
endpoint = config.llm_endpoints[0]


In [None]:
# Build the DSPy language-model client from the selected endpoint and make it
# the globally configured LM.
model_name = f"openai/{endpoint.model}"
lm = dspy.LM(
    model_name,
    api_key=endpoint.key,
    api_base=str(endpoint.url),
)
dspy.configure(lm=lm, verbose=True)

In [None]:
from bot.models import Note

# Collect (message, reply) pairs from the user's notes, paging backwards with
# `untilId` until NUM_MESSAGES pairs are gathered or the server has no more notes.
messages = []
with httpx.Client() as client:
    last_id = None
    while len(messages) < NUM_MESSAGES:
        payload = {"userId": USER_ID, "i": config.token, "limit": 100, "reply": True}
        if last_id:
            payload["untilId"] = last_id
        response = client.post(f"{config.url}api/users/notes", json=payload)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Best effort: keep whatever was collected so far, but say why we stopped
            # instead of silently swallowing every possible exception.
            print(f"Stopping note fetch after HTTP error: {exc}")
            break
        notes = [Note(**o) for o in response.json()]
        if not notes:
            # An empty page means the history is exhausted; without this guard
            # `notes[-1]` below raises IndexError.
            break
        last_id = notes[-1].id
        for note in notes:
            if len(messages) >= NUM_MESSAGES:
                break
            # Only notes that have text and reply to a note with text form a usable pair.
            if note.text and note.reply and note.reply.text:
                messages.append({"message": note.reply.text, "reply": note.text})
display(len(messages))
# Preview a handful of pairs; clamp the sample size so a short history does not
# make random.sample raise ValueError.
random.sample(messages, min(5, len(messages)))

In [None]:
# Matches "@name", optionally followed by "@host.tld".
_mention_re = re.compile(r"(@[\w\-]+(?:@[\w\-]+\.\w+)?)")


def _build_example(pair):
    """Turn one fetched {"message", "reply"} dict into a DSPy example.

    The mentions found in the incoming message are stored without their
    leading "@" characters, and only `message` is marked as an input field.
    """
    source_text = pair["message"]
    handles = [found.lstrip("@") for found in _mention_re.findall(source_text)]
    example = dspy.Example(
        message=source_text,
        reply=pair["reply"],
        mentions=handles,
    )
    return example.with_inputs('message')


training_data = [_build_example(pair) for pair in messages]

In [None]:
class Style(dspy.Signature):
    """Evaluate if a generated response matches the writing style of the expected response"""

    # ---- inputs -------------------------------------------------------
    # The message the reply under evaluation was generated for.
    message: str = dspy.InputField(
        desc="Input used to generate the reply",
    )
    # The generated reply being graded.
    response: str = dspy.InputField(
        desc="The response to evaluate",
    )
    # Reference text that demonstrates the desired writing style.
    style_example: Optional[str] = dspy.InputField(desc="Example of target style to compare the response to")

    # ---- outputs ------------------------------------------------------
    # Numeric grade produced by the judge.
    style_match_score: float = dspy.OutputField(desc="Score from 0.0 to 1.0 indicating how well the response matches the style of the expected response")
    # Short free-text justification for the grade.
    explanation: str = dspy.OutputField(desc="Brief explanation of the style match assessment")


class StyleJudgeModule(dspy.Module):
    """Chain-of-thought judge that grades a reply against a target writing style.

    The grading rubric is kept in ``INSTRUCTIONS`` and attached to the ``Style``
    signature when the module is constructed. DSPy builds the prompt from the
    signature's instructions, not from a module's docstring, so a rubric that
    only lives in this docstring would never be shown to the judge model.
    """

    # Rubric sent to the judge model as the signature instructions.
    INSTRUCTIONS = (
        "You are an expert at analyzing writing styles.\n"
        "Evaluate how well the generated response matches the style shown in the style example.\n"
        "Consider: tone, vocabulary, sentence structure, emoji usage, punctuation, formality level,\n"
        "and any unique patterns or expressions.\n"
        "Also consider how relevant the response is to the input message.\n"
        "The response should also NOT start with any usernames.\n"
        "Give a score from 0.0 (completely different style) to 1.0 (perfect style match)."
    )

    def __init__(self):
        super().__init__()
        # with_instructions returns a copy of the signature carrying the rubric;
        # the Style class itself is left untouched.
        self.judge = dspy.ChainOfThought(Style.with_instructions(self.INSTRUCTIONS))

    def forward(self, message, response, style_example):
        """Run the judge.

        Args:
            message: The input message the reply was generated for.
            response: The generated reply to grade.
            style_example: Text demonstrating the target style.

        Returns:
            The prediction returned by the underlying ChainOfThought predictor
            for the ``Style`` signature.
        """
        return self.judge(
            message=message, response=response, style_example=style_example
        )

In [None]:
def llm_style_metric(example, pred, trace=None):
    """Metric that accepts a prediction only if it passes cheap checks and the LLM style judge.

    Args:
        example: Training example; its ``message`` and ``mentions`` are read.
        pred: Prediction to grade; its ``reply`` and ``mentions`` are read.
        trace: Accepted for compatibility with the metric call signature; unused.

    Returns:
        bool: False when the reply is missing or shorter than 3 characters
        after stripping, when the predicted mentions differ (as a multiset)
        from the example's, when the reply starts with a user mention, or when
        the judge's score cannot be parsed as a number. Otherwise True when the
        judge's score is at least 0.6.
    """
    # Basic sanity checks first
    raw_reply = getattr(pred, "reply", None)
    if not raw_reply:
        return False

    reply = raw_reply.strip()
    if len(reply) < 3:
        return False

    # A missing or None `mentions` attribute is treated as "no mentions"
    # instead of raising.
    predicted_mentions = getattr(pred, "mentions", None) or []
    expected_mentions = getattr(example, "mentions", None) or []
    if Counter(predicted_mentions) != Counter(expected_mentions):
        # Mentions do not match the example.
        return False

    # The reply must not START with a user mention. The check runs on the
    # stripped reply so leading whitespace cannot hide a mention, and the
    # optional "@host.tld" part is a proper non-capturing group.
    if re.match(r"@[\w\-]+(?:@[\w\-]+\.\w+)?", reply):
        return False

    # Use LLM judge for style evaluation
    judge = StyleJudgeModule()
    judgment = judge(
        message=example.message,
        style_example=TARGET_STYLE,
        response=raw_reply,
    )

    # The score comes from an LLM, so it may be missing or non-numeric; treat
    # an unparsable score as a failed example rather than aborting the run.
    try:
        style_score = float(judgment.style_match_score)
    except (TypeError, ValueError):
        return False

    # Convert score to boolean (you can adjust threshold)
    return style_score >= 0.6


In [None]:
# Program to optimise: the chat bot built from the configured system prompt
# with a single web-search tool attached.
system_prompt = config.system_prompt
chatbot_tools = [configure_web_search(config)]
chatbot = ChatBotModule(system_prompt, tools=chatbot_tools)

# Use DSPy's optimizer (BootstrapFewShot works well for this)
optimizer = dspy.BootstrapFewShot(metric=llm_style_metric)

# Compile the chat bot against the collected training examples, using the
# style metric to decide which examples pass.
optimized = optimizer.compile(chatbot, trainset=training_data)

In [None]:
# Write the compiled program to disk.
optimized_program_path = "k8s/optimized.json"
optimized.save(optimized_program_path)