In [None]:
import os
import random
import threading
import wave
import subprocess
import xml.etree.ElementTree as ET
from abc import ABC, abstractmethod
from collections import defaultdict
from io import BytesIO
from pathlib import Path
from typing import Literal

import ipywidgets as widgets
from ipywidgets import Output
import ipyaudioworklet as ipyaudio
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from gtts import gTTS
import pygame
import spacy
import speech_recognition as sr
from google import genai
from google.genai import types
from IPython.display import display, clear_output

In [None]:
class LLMJokerAgent(ABC):
    """Base class for LLM-powered joke generation agents."""

    MEMORY_SYSTEM_INSTRUCTION = "You are an advanced agent memory manager that keeps track of conversation content by maintaining a SHORT summary. You transform the old summary, with the new user message and agent response to the new summary."

    def __init__(self, prompts_dir: Path):
        self.prompts_dir = Path(prompts_dir)
        self.memory = ""

        output_format_path = self.prompts_dir / "output-format.md"
        with open(output_format_path, 'r') as f:
            self.output_format = f.read()

        self.profiles = {}
        for file in self.prompts_dir.iterdir():
            if file.suffix == '.md' and 'output' not in file.name:
                with open(file, 'r') as f:
                    self.profiles[file.stem] = f.read()

        self.profile_names = list(self.profiles.keys())

    @abstractmethod
    def _call_llm(self, system_instruction: str, user_content: str) -> str | None:
        """Sub-classes implement this to either prompt with the CLI or the API"""
        pass

    def get_random_profile(self):
        name = np.random.choice(self.profile_names)
        return name, self.profiles[name]

    def generate_response(self, user_response: str = "I don't have much to say.", personality=None, update_memory=True, N=3, example_joke: str | None = None):
        """Generate a joke with a random personality if personaltiy is not set."""
        _, profile = self.get_random_profile()

        if personality is not None:
            profile = self.profiles.get(personality, profile)

        system_instruction = f"{profile}\n{self.output_format.replace('[N]', str(N))}"

        prompt = user_response
        if example_joke:
            prompt = f"{user_response}\n\n<example_of_joke_user_liked>{example_joke}</example_of_joke_user_liked>"

        model_response = self._call_llm(system_instruction, prompt)
        
        if not model_response:
            return ["I'm done for the day"]

        # Strip markdown code fences if present (e.g., ```xml ... ```)
        model_response = model_response.strip()
        if model_response.startswith("```"):
            lines = model_response.split("\n")
            lines = [l for l in lines[1:] if l.strip() != "```"]
            model_response = "\n".join(lines)

        try:
            parsed_response = ET.fromstring(model_response)
            response_jokes = list(map(lambda x: x.text, parsed_response.findall("joke")))
            if response_jokes is None or response_jokes[0] is None:
                return ["I'm not feeling so funny today."]
            if update_memory:
                self.update_memory(user_response, response_jokes[0])
        except:
            print("Actual model response: ", model_response)
            return ["I lost my train of thought. Could you repeat that?"]

        return response_jokes

    def update_memory(self, user_response: str, model_response: str) -> str:
        """Update conversation memory with a summary of the exchange."""
        prompt = f"previous summary: {self.memory}, new user msg: {user_response}, new system msg: {model_response}"

        summary = self._call_llm(self.MEMORY_SYSTEM_INSTRUCTION, prompt)

        if summary:
            self.memory = summary
            print("New memory:\n", self.memory)

        return self.memory

class LLMAPIJokerAgent(LLMJokerAgent):
    """Implementation with the gemini API"""

    def __init__(self, prompts_dir: Path, api_key: str | None = None, model="gemini-2.5-flash"):
        """Create the API client.

        Args:
            prompts_dir: Directory holding the prompt markdown files.
            api_key: Gemini API key. When omitted, the key is read from the
                ``GEMINI_API_KEY`` or ``GOOGLE_API_KEY`` environment variable.
            model: Model name passed to every ``generate_content`` call.

        Raises:
            ValueError: If no API key was passed and none is set in the
                environment.
        """
        super().__init__(prompts_dir)
        # SECURITY: a key used to be hard-coded here as the default argument.
        # Secrets must not live in source; the previously committed key should
        # be treated as leaked and revoked.
        api_key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "No Gemini API key provided: pass api_key or set GEMINI_API_KEY / GOOGLE_API_KEY."
            )
        self.client = genai.Client(api_key=api_key)
        self.model = model

    def _call_llm(self, system_instruction: str, user_content: str) -> str | None:
        """Send one request to the configured model and return the reply text."""
        response = self.client.models.generate_content(
            model=self.model,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction
            ),
            contents=user_content,
        )
        return response.text

class LLMCLIJokerAgent(LLMJokerAgent):
    """Joker agent that talks to the model through the ``gemini`` CLI."""

    def __init__(self, prompts_dir: Path):
        super().__init__(prompts_dir)

    def _call_llm(self, system_instruction: str, user_content: str) -> str | None:
        """Run ``gemini -p <prompt>`` and return its trimmed stdout.

        The system instruction and user content are combined into a single
        prompt string. Returns ``None`` (after printing a short notice) when
        the command times out after 60 seconds, the executable is missing,
        the exit code is non-zero, or stdout is empty.
        """
        command = ["gemini", "-p", f"{system_instruction}\n\nUser: {user_content}"]

        try:
            completed = subprocess.run(command, capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired:
            print("CLI timeout")
            return None
        except FileNotFoundError:
            print("Gemini CLI not found.")
            return None

        if completed.returncode != 0:
            print(f"CLI Error: {completed.stderr}")
            return None

        reply = completed.stdout.strip()
        return reply if reply else None


class MemoryAgent:
    """Rolling memory of linguistic features extracted from user messages.

    Each stored entry is the dict produced by :meth:`_extract_features`.
    Entries are considered related when they share at least one non-pronoun
    reference (subject, object, entity or resolved reference).
    """

    PRONOUNS = {"it", "they", "he", "she", "this", "that", "these", "those", "i", "me", "my", "we", "us", "our", "you", "your"}

    def __init__(self, max_memory_length):
        self.memory = []
        self.max_memory_length = max_memory_length

    def user_update(self, content: str):
        """Extract features from ``content`` and append them to the memory.

        When a subject of the new message is a pronoun, the previous entry's
        subjects, objects and entities are attached as ``resolved_refs``.
        The oldest entries are dropped once ``max_memory_length`` is exceeded.
        """
        features = self._extract_features(content)
        if self.memory and any(s.lower() in self.PRONOUNS for s in features["subj"]):
            prev = self.memory[-1]
            features["resolved_refs"] = prev.get("subj", []) + prev.get("obj", []) + prev.get("entities", [])
        self.memory.append(features)
        # Count the overflow explicitly: slicing with [-0:] would keep the
        # whole list when max_memory_length is 0.
        overflow = len(self.memory) - self.max_memory_length
        if overflow > 0:
            del self.memory[:overflow]

    def get_full_memory_summary(self, n_contrast: int = 1) -> str:
        """Summarise entries related to the latest message plus some contrast.

        Args:
            n_contrast: Maximum number of unrelated entry groups to add,
                picked at random from the groups available.

        Returns:
            Summaries joined by ``" | "``: first the entries related to the
            latest message, then the sampled contrast groups. Entries inside
            each summary are in chronological order. Empty string when the
            memory is empty.
        """
        if not self.memory:
            return ""

        current_idx = len(self.memory) - 1
        relevant_idx = self._get_relevant_indices(current_idx)
        non_relevant_idx = [i for i in range(current_idx + 1) if i not in relevant_idx]
        non_relevant_set = set(non_relevant_idx)

        contrast_groups = []
        used = set()
        # Walk from newest to oldest so each group is anchored on its most
        # recent member.
        for idx in reversed(non_relevant_idx):
            if idx in used:
                continue
            # Sorted: set iteration order is not chronological.
            group = sorted(i for i in self._get_relevant_indices(idx) if i in non_relevant_set)
            if group:
                contrast_groups.append(group)
                used.update(group)

        n_selected = max(0, min(n_contrast, len(contrast_groups)))
        selected_contrasts = random.sample(contrast_groups, n_selected)

        parts = [self._summarize([self.memory[i] for i in sorted(relevant_idx)])]
        for group in selected_contrasts:
            parts.append(self._summarize([self.memory[i] for i in group]))

        return " | ".join(filter(None, parts))

    def _get_relevant_indices(self, target_idx: int) -> set[int]:
        """Return indices up to ``target_idx`` sharing a reference with it.

        An out-of-range index, or an entry with no non-pronoun references,
        yields an empty set.
        """
        if target_idx < 0 or target_idx >= len(self.memory):
            return set()
        target_refs = self._get_refs(self.memory[target_idx])
        return {i for i in range(target_idx + 1) if not target_refs.isdisjoint(self._get_refs(self.memory[i]))}

    def _get_refs(self, f: dict) -> set:
        """Collect the non-pronoun subjects, objects, entities and resolved refs of ``f``."""
        all_refs = f.get("subj", []) + f.get("obj", []) + f.get("entities", []) + f.get("resolved_refs", [])
        return {r for r in all_refs if r.lower() not in self.PRONOUNS}

    def _summarize(self, memory_list: list[dict]) -> str:
        """Flatten entries into one string grouped by feature kind.

        The output lists all subjects first, then verbs, adjectives, objects
        and entities, each group spanning every entry. Pronouns are dropped
        from subjects and objects.
        """
        if not memory_list:
            return ""
        tuples = []
        for f in memory_list:
            tuples.append((
                " ".join(w for w in f.get("subj", []) if w.lower() not in self.PRONOUNS),
                " ".join(f.get("verbs", [])),
                " ".join(f.get("adjectives", [])),
                " ".join(w for w in f.get("obj", []) if w.lower() not in self.PRONOUNS),
                " ".join(f.get("entities", [])),
            ))
        # zip(*tuples) transposes per-entry rows into per-feature columns.
        return " ".join(filter(None, (" ".join(filter(None, t)) for t in zip(*tuples))))

    def _extract_features(self, sentence: str) -> dict:
        """Run the module-level spaCy pipeline ``nlp`` and pick out features."""
        doc = nlp(sentence)
        return {
            "entities": [e.text for e in doc.ents],
            "subj": [t.text for t in doc if "subj" in t.dep_],
            "obj": [t.text for t in doc if "obj" in t.dep_],
            "adjectives": [t.text for t in doc if t.pos_ == "ADJ"],
            "verbs": [t.lemma_ for t in doc if t.pos_ == "VERB"],
            "past_tense": any(t.tag_ == "VBD" for t in doc),
            "negated": any(t.dep_ == "neg" for t in doc),
            "numbers": [t.text for t in doc if t.like_num],
        }

class UserFeedbackTrackerAgent:
    """Collect liked and disliked jokes per personality.

    A joke is registered with :meth:`await_feedback` and filed under its
    personality by the next :meth:`process_feedback` call.
    """

    def __init__(self) -> None:
        # personality name -> list of jokes that received that verdict
        self.positive_feedbacks = defaultdict(list)
        self.negative_feedbacks = defaultdict(list)
        # (personality, joke) waiting for a verdict, or None
        self.awaiting_feedback = None

    def await_feedback(self, joke: str, personality: str):
        """Remember ``joke`` as the one the next verdict applies to."""
        self.awaiting_feedback = (personality, joke)

    def process_feedback(self, polarity: bool):
        """File the pending joke as liked (``True``) or disliked (``False``).

        Does nothing when no joke is pending. The pending joke is cleared
        afterwards so further clicks cannot record it again.
        """
        if self.awaiting_feedback is None:
            return
        personality, joke = self.awaiting_feedback
        feedbacks = self.positive_feedbacks if polarity else self.negative_feedbacks
        feedbacks[personality].append(joke)
        self.awaiting_feedback = None

In [None]:
# Shared singletons used by the classes, helpers and widget callbacks.
recognizer = sr.Recognizer()
nlp = spacy.load("en_core_web_sm")  # read by MemoryAgent._extract_features

feedback = UserFeedbackTrackerAgent()
memory = MemoryAgent(10)
# Prompt files are looked up in ./prompts relative to the working directory.
joker_agent = LLMCLIJokerAgent(Path(os.getcwd()) / "prompts")

['I had to take a test once. The only question was "What is the capital of pancakes?" I answered "The sky is made of retired librarians." The teacher marked it correct and then I woke up inside a filing cabinet.']

<div style="text-align: center !important; max-width: 600px; margin: 0 auto;">

# JokeBloke

Press `boot recorder` to wake him up. Then press `record` once, after which you can talk to him. When you're done talking, press `stop` and patiently wait for him to respond.

</div>

In [None]:
# Helper Functions
def get_audio_as_wav_bytes(recorder):
    """Return recorder audio as a WAV byte stream.

    Args:
        recorder: Object exposing ``audiodata`` (sample array) and
            ``sampleRate``. Samples are assumed to be floats nominally in
            [-1, 1], as the 16-bit scaling implies (TODO confirm with the
            recorder widget).

    Returns:
        A ``BytesIO`` positioned at 0 holding a mono 16-bit PCM WAV file.

    Raises:
        ValueError: If the recorder holds no audio.
    """
    data = recorder.audiodata
    srate = recorder.sampleRate

    if data is None or len(data) == 0:
        raise ValueError("No audio recorded!")

    # Clip before scaling: samples outside [-1, 1] would otherwise overflow
    # the int16 conversion and come out as loud clicks.
    data_int16 = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    wav_bytes = BytesIO()
    with wave.open(wav_bytes, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(srate)
        wf.writeframes(data_int16.tobytes())
    wav_bytes.seek(0)
    return wav_bytes

def transcribe(recorder):
    """Transcribe the recorder's audio with the Google recognizer.

    Prints the transcription and returns it. When the audio cannot be
    understood or the recognition request fails, prints a notice and
    returns ``None``.
    """
    with sr.AudioFile(get_audio_as_wav_bytes(recorder)) as source:
        audio = recognizer.record(source)

    try:
        text = recognizer.recognize_google(audio)
    except sr.UnknownValueError:
        print("Could not understand audio")
        return None
    except sr.RequestError as e:
        print(f"API error: {e}")
        return None
    else:
        print("Transcription:", text)
        return text

def speak_text(text):
    """Speak ``text`` aloud: synthesise it with gTTS and play it via pygame.

    Blocks until playback has finished. The temporary mp3 file is removed
    afterwards, including when synthesis or playback fails.
    """
    # Create the temp file and close the handle right away, so gTTS and
    # pygame open the path themselves rather than sharing an open handle.
    fd, mp3_path = tempfile.mkstemp(suffix='.mp3')
    os.close(fd)
    try:
        gTTS(text).save(mp3_path)
        pygame.mixer.init()
        pygame.mixer.music.load(mp3_path)
        pygame.mixer.music.play()
        while pygame.mixer.music.get_busy():
            pygame.time.wait(100)
    finally:
        # Release the file before deleting it. unload() only exists in
        # pygame >= 2.0, hence the getattr.
        unload = getattr(pygame.mixer.music, "unload", None)
        if unload is not None and pygame.mixer.get_init():
            unload()
        try:
            os.remove(mp3_path)
        except OSError:
            # Best effort: a leftover temp file must not break the UI.
            pass

# Filler lines shown while the joke is being generated.
LOADING_MESSAGES = [
    "Hold on, I'm workshopping this one...",
    "Comedy takes time, unlike my ex's patience...",
    "Still thinking... unlike my comedy career, this is going somewhere",
    "One moment, genius at work... or at least mild amusement",
    "Loading wit... please stand by",
    "My writers are on strike again...",
    "Buffering humor... have you tried turning me off and on again?",
    "Consulting my inner clown...",
    "The punchline is stuck in traffic...",
    "Almost there... comedy gold doesn't mine itself",
    "Hang tight, I'm funnier than this pause suggests",
    "Processing... this joke better be worth it",
    "My funny bone needs a moment...",
    "Joke loading... unlike my love life, this will complete",
    "I'm not stalling, I'm building suspense...",
    "Rome wasn't built in a day, and neither is comedy gold",
]

def generate_with_loading_messages(generator_func, on_message=None, min_delay=2.0, max_delay=5.0):
    """
    Call generator_func on a worker thread and return its result.

    While the worker is running, a random loading message is handed to the
    on_message callback (if given) after each random pause of between
    min_delay and max_delay seconds. No message repeats until every message
    has been shown. An exception raised by generator_func is re-raised here.
    """
    outcome = {}
    finished = threading.Event()

    def worker():
        try:
            outcome["value"] = generator_func()
        except Exception as exc:
            outcome["error"] = exc
        finally:
            finished.set()

    worker_thread = threading.Thread(target=worker, daemon=True)
    worker_thread.start()

    # Messages not yet shown in the current round, in their original order.
    unseen = []
    while not finished.is_set():
        if finished.wait(timeout=random.uniform(min_delay, max_delay)):
            break
        if not unseen:
            unseen = list(LOADING_MESSAGES)
        picked = random.choice(unseen)
        unseen.remove(picked)
        if on_message:
            on_message(picked)

    worker_thread.join()

    if "error" in outcome:
        raise outcome["error"]
    return outcome["value"]

def save_wav_from_recorder(recorder, filename="recording.wav"):
    """Save the recorded audio from recorder into a WAV file.

    Args:
        recorder: Object exposing ``audiodata`` (sample array) and
            ``sampleRate``. Samples are assumed to be floats nominally in
            [-1, 1], as the 16-bit scaling implies (TODO confirm with the
            recorder widget).
        filename: Destination path of the mono 16-bit PCM WAV file.

    Returns:
        ``filename`` on success, or ``None`` when there is no audio to save.
    """
    data = recorder.audiodata
    # Not named ``sr``: that would shadow the speech_recognition alias.
    sample_rate = recorder.sampleRate

    if data is None or len(data) == 0:
        print("No audio data to save!")
        return None

    # Clip before scaling: samples outside [-1, 1] would otherwise overflow
    # the int16 conversion.
    data_int16 = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    with wave.open(filename, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(sample_rate)
        f.writeframes(data_int16.tobytes())

    print(f"Saved: {filename}")
    return filename

def plot_audio(rec):
    """Show the waveform of the audio held by ``rec`` in a 10x3 inch figure."""
    _, ax = plt.subplots(figsize=(10, 3))
    ax.plot(rec.audiodata)
    ax.set_title("Recorded Audio Waveform")
    ax.set_xlabel("Samples")
    ax.set_ylabel("Amplitude")
    plt.show()

# Widgets
# Recorder plus the two output areas: status_out shows recorder status and
# the conversation, output shows the reaction to like/dislike clicks.
rec = ipyaudio.AudioRecorder()
status_out = Output(layout={'padding': '5px'})
output = widgets.Output()

# Feedback row: prompt label and the two verdict buttons.
text_label = widgets.Label(value="Do you like this Answer?")
like_button = widgets.Button(description=": )", button_style='success')
dislike_button = widgets.Button(description=": (", button_style='danger')

# Spinner with a caption; hidden until a joke is being generated.
loading_indicator = widgets.HTML(
    value='<img src="https://i.gifer.com/ZZ5H.gif" width="30" style="vertical-align:middle;"> <span>Thinking of something funny...</span>'
)
loading_indicator.layout.display = 'none'

def set_loading_state(is_loading: bool):
    """Switch the UI between its busy and idle look.

    While loading, the feedback buttons are disabled, the recorder is dimmed
    and ignores pointer input, the spinner is shown and the prompt label is
    hidden. Passing ``False`` reverses all of that.
    """
    for button in (like_button, dislike_button):
        button.disabled = is_loading

    # ipyaudioworklet doesn't have a disabled property, so we use CSS
    rec.layout.opacity = '0.5' if is_loading else '1'
    rec.layout.pointer_events = 'none' if is_loading else 'auto'

    # The spinner and the prompt label swap places.
    loading_indicator.layout.display = 'inline' if is_loading else 'none'
    text_label.layout.display = 'none' if is_loading else 'inline'

# Event Handlers
@status_out.capture(clear_output=True)
def status_changed(change):
    """Print the recorder's new status into ``status_out``, clearing what was there."""
    print("Status:", change.new)

def _select_personality_ucb(personalities, tracker):
    """Pick a personality by UCB (Upper Confidence Bound) over like/dislike counts.

    Personalities without any feedback score infinity so they are tried
    first; ties for the best score are broken at random.
    """
    counts = []
    for pers in personalities:
        positive = len(tracker.positive_feedbacks[pers])
        negative = len(tracker.negative_feedbacks[pers])
        counts.append((positive, positive + negative))

    n_total = sum(n for _, n in counts)

    ucb_scores = []
    for positive, n in counts:
        if n == 0:
            # Untested personalities get high UCB to encourage exploration
            ucb_scores.append(float('inf'))
        else:
            mean = positive / n
            exploration_bonus = np.sqrt(2 * np.log(n_total + 1) / n)
            ucb_scores.append(mean + exploration_bonus)

    # Pick personality with highest UCB score (random tiebreak for inf)
    max_ucb = max(ucb_scores)
    best_indices = [i for i, score in enumerate(ucb_scores) if score == max_ucb]
    return personalities[np.random.choice(best_indices)]


@status_out.capture()
def on_status_change(change):
    """Handle a finished recording: transcribe, generate a joke and speak it.

    Runs only when the recorder status becomes ``STOPPED`` or ``RECORDED``.
    The transcription is stored in the memory agent, a personality is chosen
    by UCB, and the joke is generated on a background thread while loading
    messages are shown. The joke is then printed, spoken and registered as
    awaiting feedback.
    """
    if change.new not in ("STOPPED", "RECORDED"):
        return

    try:
        user_input = transcribe(rec)
    except ValueError as exc:
        # get_audio_as_wav_bytes raises ValueError when nothing was recorded;
        # report it instead of dumping a traceback into the status area.
        print(exc)
        return

    if not user_input:
        return

    memory.user_update(user_input)

    personalities = joker_agent.profile_names
    if not personalities:
        print("No personalities available.")
        return

    personality = _select_personality_ucb(personalities, feedback)
    print("Personality: ", personality)

    # Get most recent liked joke for this personality as an example
    liked_jokes = feedback.positive_feedbacks[personality]
    example_joke = liked_jokes[-1] if liked_jokes else None

    def update_loading_message(msg):
        loading_indicator.value = f'<img src="https://i.gifer.com/ZZ5H.gif" width="30" style="vertical-align:middle;"> <span>{msg}</span>'

    set_loading_state(True)
    try:
        prompt = f"<history>{memory.get_full_memory_summary(2)}</history>\n{user_input}"
        jokes = generate_with_loading_messages(
            lambda: joker_agent.generate_response(prompt, update_memory=False, personality=personality, N=1, example_joke=example_joke),
            on_message=update_loading_message
        )
    finally:
        # Always re-enable the UI, even when generation failed.
        set_loading_state(False)

    print("Response: ", jokes[0])

    speak_text(jokes[0])
    feedback.await_feedback(jokes[0], personality)

def _acknowledge_feedback(liked, spoken_reply, log_line):
    """Record the verdict for the pending joke, speak a reply, then log it."""
    with output:
        clear_output()
        feedback.process_feedback(liked)
        speak_text(spoken_reply)
        print(log_line)

def on_like_clicked(b):
    """Click handler of the like button."""
    _acknowledge_feedback(True, "Thanks!", "You liked this!")

def on_dislike_clicked(b):
    """Click handler of the dislike button."""
    _acknowledge_feedback(False, "Wow, tough crowd!", "You disliked this!")

# Wire events
# Both observers fire on every change of the recorder's "status" trait.
# status_changed is registered first and clears status_out before printing;
# on_status_change then appends its own output to the same area.
rec.observe(status_changed, "status")
rec.observe(on_status_change, "status")
like_button.on_click(on_like_clicked)
dislike_button.on_click(on_dislike_clicked)

# Display UI
# Inject custom CSS
# The stylesheet is wrapped in a <style> tag and displayed through an HTML
# widget. Its rules centre the cell output, hide audio/video elements inside
# widgets, and style the buttons, output area, label and loading indicator.
STYLE_SHEET_CONTENT = """
/* Center the layout and set a max-width */
.jp-Cell-outputArea {
    display: flex;
    flex-direction: column;
    align-items: center;
}

.jp-OutputArea-child {
    max-width: 600px;
    width: 100%;
    padding: 5px 0;
    box-sizing: border-box;
}

/* Hide the audio player widget */
.jupyter-widgets audio,
.jupyter-widgets video {
    display: none !important;
}

/* Style the buttons - auto height to fit content */
.widget-button {
    background-color: #f0f0f0;
    border: 1px solid #ccc;
    border-radius: 8px;
    padding: 4px 12px;
    margin: 5px;
    font-size: 14px;
    font-weight: bold;
    cursor: pointer;
    transition: background-color 0.3s, transform 0.1s;
    box-sizing: border-box;
    height: auto !important;
    min-height: unset !important;
    line-height: 1.4;
}

.widget-button:hover {
    background-color: #e0e0e0;
}

.widget-button:active {
    transform: scale(0.98);
}

.widget-button:disabled {
    background-color: #f5f5f5;
    color: #aaa;
    cursor: not-allowed;
}

.widget-button.mod-success {
    background-color: #28a745;
    color: white;
    border-color: #28a745;
}

.widget-button.mod-success:hover {
    background-color: #218838;
}

.widget-button.mod-danger {
    background-color: #dc3545;
    color: white;
    border-color: #dc3545;
}

.widget-button.mod-danger:hover {
    background-color: #c82333;
}

/* Style the output text area - auto-grow with content */
.widget-output {
    width: 100%;
    min-height: 50px;
    max-height: none;
    overflow: visible;
}

.lm-Widget {
    text-align: center;
}

.widget-output .jp-RenderedText {
    background-color: #f8f9fa;
    border: 1px solid #dee2e6;
    border-radius: 8px;
    padding: 15px;
    margin-top: 10px;
    text-align: left;
    min-height: 40px;
    height: auto;
}

.widget-output .jp-RenderedText pre {
    white-space: pre-wrap;
    word-wrap: break-word;
    font-family: monospace;
    font-size: 14px;
    margin: 0;
}

/* Style the text label */
.widget-label {
    font-weight: bold;
    font-size: 16px;
    text-align: center;
}

/* Style the loading indicator */
.widget-html-content {
    display: flex;
    align-items: center;
    gap: 8px;
}

.widget-html-content img {
    width: 30px;
    height: 30px;
}

.widget-html-content span {
    font-style: italic;
    color: #666;
}

/* Like/dislike buttons sizing - auto height */
.widget-button.mod-success,
.widget-button.mod-danger {
    min-width: 70px;
    font-size: 18px;
    padding: 6px 16px;
    height: auto !important;
    min-height: unset !important;
}
"""
display(widgets.HTML(f"<style>{STYLE_SHEET_CONTENT}</style>"))

# Seed the status area with an initial message.
status_out.append_stdout("Recorder ready.\n")

# Feedback panel, stacked vertically and centred: the prompt label and the
# loading indicator share the first row, the like/dislike buttons the second,
# and the click-reaction output area sits below them.
feedback_ui = widgets.VBox([
    widgets.HBox([text_label, loading_indicator], layout=widgets.Layout(justify_content='center')),
    widgets.HBox([like_button, dislike_button], layout=widgets.Layout(justify_content='center')),
    output
], layout=widgets.Layout(align_items='center'))

# Render order in the cell output: recorder, status area, feedback panel.
display(rec)
display(status_out)
display(feedback_ui)