# Demo the Text-to-Speech module

In [None]:
import glados.tts as tts
import sounddevice as sd

In [None]:
# Instantiate the TTS engine once; the speech cells below reuse `glados_tts`.
glados_tts = tts.TTSEngine()

In [None]:
# Synthesize the greeting. "Gladohs" is a deliberate misspelling of "Glados"
# that makes the TTS pronunciation more accurate.
audio = glados_tts.generate_speech_audio(
    "Hello, my name is Gladohs. I am an AI created by Aperture Science."
)

# Play the clip at the TTS engine's sample rate (sd.play does not block by default).
sd.play(audio, samplerate=tts.RATE)

# Demo the Automatic Speech Recognition system
This detects and transcribes your speech. In this demo, GLaDOS then repeats back to you what it heard.

In [None]:
import glados.voice_recognition as vr

In [None]:
def say_text(text: str) -> None:
    """Speak *text* with the TTS engine and wait until playback has finished."""
    waveform = glados_tts.generate_speech_audio(text)
    sd.play(waveform, samplerate=tts.RATE)
    # Block here so the call only returns once the audio is done.
    sd.wait()


# Hand say_text to the recognizer as its callback, so recognized speech is
# spoken back by GLaDOS.
demo = vr.VoiceRecognition(function=say_text)

# Begin listening.
demo.start()

# Demo the LLM
This section lets you interact directly with the large language model (LLM).

In [None]:
import os
from glados.llama import LlamaServer, LlamaServerConfig
from pathlib import Path


# There are two ways to get a llama.cpp backend: launch the Llama server
# directly from Python, or point at an external, already-running server.
# This demo launches the server itself, so load its settings from the YAML
# config (a relative path is resolved against the current working directory):
llama_server_config = LlamaServerConfig.from_yaml(
    Path("glados_config.yml").expanduser().resolve()
)

In [None]:
# Display the loaded server configuration for inspection.
llama_server_config

In [None]:
# Build the llama.cpp server from the loaded config and start it.
# The last cell of this notebook shuts it down again with server.stop().
server = LlamaServer.from_config(llama_server_config)
server.start()

In [None]:
import json
import re
import requests
import yaml
from dataclasses import dataclass
from datetime import datetime
from jinja2 import Template
from typing import List, Optional, Sequence


# Meta-Llama-3 end-of-turn token.
LLM_STOP_SEQUENCE = "<|eot_id|>"


def _compact_template(fragments):
    """Join Jinja template fragments, dropping each fragment's leading indentation.

    The fragments below are indented only for readability. Jinja copies any
    text between tags verbatim into the rendered output, so keeping that
    indentation would inject runs of spaces into the prompt (for example,
    before ``<|begin_of_text|>`` and before every message). No fragment has
    meaningful leading whitespace, so stripping it is safe.
    """
    return "".join(fragment.lstrip() for fragment in fragments)


TEMPLATES = {
    "LLAMA3": _compact_template([
        "{% set loop_messages = messages %}",
        "{% for message in loop_messages %}",
        "    {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}",
        "    {% if loop.index0 == 0 %}",
        "        {% set content = bos_token + content %}",
        "    {% endif %}",
        "    {{ content }}",
        "{% endfor %}",
        "{% if add_generation_prompt %}",
        "    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
        "{% endif %}"
    ]),
    # NOTE: ``raise_exception`` is not a Jinja2 built-in. Unless it is passed to
    # ``render``, a role-order violation fails with an UndefinedError instead
    # of the message below.
    "CHATML": _compact_template([
        "{% if messages[0]['role'] == 'system' %}",
        "    {% set offset = 1 %}",
        "{% else %}",
        "    {% set offset = 0 %}",
        "{% endif %}",
        "",
        "{{ bos_token }}",
        "{% for message in messages %}",
        "    {% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}",
        "        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}",
        "    {% endif %}",
        "",
        "    {{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}",
        "{% endfor %}",
        "",
        "{% if add_generation_prompt %}",
        "    {{ '<|im_start|>assistant\n' }}",
        "{% endif %}"
    ])
}
# The system message is a str.format template; "{t:...}" is filled with the
# current datetime when the prompt is rendered.
# NOTE: "%l" (space-padded 12-hour clock) is a platform-specific strftime
# directive and may be unsupported on Windows.
DEFAULT_PERSONALITY_PREPROMPT = [
    {
        "role": "system",
        "content": "You are a helpful AI assistant. You are here to assist the user in their tasks. The current time is {t:%l}:{t:%M} {t:%p}.",
    },
]


@dataclass
class GladosConfig:
    completion_url: str
    api_key: Optional[str]
    wake_word: Optional[str]
    announcement: Optional[str]
    personality_preprompt: List[dict[str, str]]
    interruptible: bool
    template: str = "LLAMA3"
    voice_model: str = "glados.onnx"
    speaker_id: int = None

    @classmethod
    def from_yaml(cls, path: str, key_to_config: Sequence[str] | None = ("Glados",)):
        key_to_config = key_to_config or []

        with open(path, "r") as file:
            data = yaml.safe_load(file)

        config = data
        for nested_key in key_to_config:
            config = config[nested_key]

        return cls(**config)


class LlamaClient:
    """Streaming chat client for a llama.cpp server's completion endpoint.

    The conversation history lives in ``messages``. Each ``process_query``
    call does four things:

    - appends the user's query to the history;
    - renders the whole history into a raw prompt with a Jinja template from
      ``TEMPLATES``;
    - streams the completion;
    - appends the cleaned-up assistant reply to the history.
    """

    # Tokens after which the buffered text is flushed as one sentence.
    _PAUSE_TOKENS = frozenset({".", "!", "?", "?!", "\n", "\n\n"})
    # Tokens that end the assistant's turn (ChatML and Llama-3); streaming stops there.
    _END_OF_TURN_TOKENS = frozenset({"<|im_end|>", LLM_STOP_SEQUENCE})

    @property
    def messages(self) -> Sequence[dict[str, str]]:
        """Conversation history: the preprompt followed by user/assistant turns."""
        return self._messages

    def __init__(
        self,
        server: LlamaServer,
        api_key: str | None = None,
        template: str = "LLAMA3",
        personality_preprompt: Sequence[dict[str, str]] = DEFAULT_PERSONALITY_PREPROMPT
    ):
        """
        Args:
            server: only its ``completion_url`` attribute is read.
            api_key: sent verbatim as the ``Authorization`` header. A
                placeholder bearer token is used when omitted.
            template: key into ``TEMPLATES`` selecting the prompt format.
            personality_preprompt: initial ``role``/``content`` messages.
                Their contents are ``str.format`` templates receiving the
                current datetime as ``t``.

        Raises:
            KeyError: if ``template`` is not a key of ``TEMPLATES``.
        """
        self.completion_url = server.completion_url
        # LLAMA_SERVER_HEADERS
        self.prompt_headers = {"Authorization": api_key or "Bearer your_api_key_here"}
        # Copy the preprompt. Appending conversation turns must never mutate the
        # shared DEFAULT_PERSONALITY_PREPROMPT or a caller-owned list.
        self._messages = [dict(message) for message in personality_preprompt]
        # Only these leading preprompt messages are treated as format templates;
        # user and assistant text is sent verbatim.
        self._preprompt_len = len(self._messages)
        self.template = Template(TEMPLATES[template])

    @classmethod
    def from_config(
        cls,
        server: LlamaServer,
        config: GladosConfig
    ) -> "LlamaClient":
        """Create a client from a ``GladosConfig``.

        Each ``config.personality_preprompt`` entry is a mapping whose first
        key is the role and whose first value is the message content.
        """
        personality_preprompt = []
        for line in config.personality_preprompt:
            role, content = list(line.items())[0]
            personality_preprompt.append({"role": role, "content": content})

        return cls(
            server=server,
            api_key=config.api_key,
            template=config.template,
            personality_preprompt=personality_preprompt
        )

    def process_query(self, query: str) -> None:
        """Send *query* to the server and stream back the assistant's reply.

        The rendered prompt and every streamed token are printed, and each
        completed sentence is printed after cleanup. Both the query and the
        cleaned reply are appended to ``messages``.

        Args:
            query: the user's message. It is sent verbatim; braces are not
                interpreted as format placeholders.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        self._messages.append({"role": "user", "content": query})
        prompt = self._render_prompt()
        print(prompt)
        data = {
            "stream": True,
            "prompt": prompt
        }
        sentences = []
        with requests.post(
            self.completion_url,
            headers=self.prompt_headers,
            json=data,
            stream=True
        ) as response:
            # Surface HTTP errors directly, instead of trying to parse an error
            # body as a stream of JSON events.
            response.raise_for_status()
            sentence = []
            for line in response.iter_lines():
                if not line:
                    continue  # blank separator lines between streamed events
                next_token = self._process_line(self._clean_raw_bytes(line))
                if not next_token:
                    continue
                print(f"\x1b[36m*{next_token}* \x1b[0m", end="")
                sentence.append(next_token)
                # A pause token completes a sentence: flush the buffer.
                if next_token in self._PAUSE_TOKENS:
                    sentences.append(self._process_sentence(sentence))
                    sentence = []
                if next_token in self._END_OF_TURN_TOKENS:
                    break
            if sentence:
                sentences.append(self._process_sentence(sentence))
        self._messages.append({"role": "assistant", "content": "".join(sentences)})

    def _render_prompt(self) -> str:
        """Render the history with the chat template, filling the preprompt's time placeholders."""
        now = datetime.now()
        messages = [
            {
                "role": message["role"],
                # Only preprompt messages carry "{t:...}" placeholders. Formatting
                # user/assistant text would crash on literal braces.
                "content": (
                    message["content"].format(t=now)
                    if index < self._preprompt_len
                    else message["content"]
                ),
            }
            for index, message in enumerate(self._messages)
        ]
        return self.template.render(
            messages=messages,
            bos_token="<|begin_of_text|>",
            add_generation_prompt=True
        )

    def _clean_raw_bytes(self, line):
        """Decode one streamed ``data: {...}`` line and parse its JSON payload."""
        line = line.decode("utf-8")
        line = line.removeprefix("data: ")
        return json.loads(line)

    def _process_line(self, line) -> Optional[str]:
        """Return the token text of a streamed event, or None for the final (stop) event."""
        if line.get("stop"):
            return None
        return line.get("content")

    def _process_sentence(self, current_sentence: List[str]) -> str:
        """Join tokens into speakable text, printing it when non-empty.

        Processing steps:

        - removes ``<|im_end|>`` and anything after it on the same line;
        - strips ``*actions*``, ``(parentheticals)`` and ``<|special|>``
          tokens;
        - turns newlines into sentence breaks.
        """
        sentence = "".join(current_sentence)
        sentence = re.sub(r"\<\|im_end\|\>.*$", "", sentence)
        sentence = re.sub(r"\*.*?\*|\(.*?\)|\<\|.*?\|\>", "", sentence)
        sentence = (
            sentence.replace("\n\n", ". ")
            .replace("\n", ". ")
            .replace("  ", " ")
        )
        if sentence:
            print()
            print(sentence)
        return sentence


In [None]:
# Load the client-side settings (the "Glados" section by default) from the same YAML file.
llama_client_config = GladosConfig.from_yaml(
    Path("glados_config.yml").expanduser().resolve()
)

In [None]:
# Display the loaded client configuration for inspection.
llama_client_config

In [None]:
# Connect a chat client to the running server using the loaded settings.
client = LlamaClient.from_config(
    server=server,
    config=llama_client_config,
)

In [None]:
# First conversation turn; prints the rendered prompt and the streamed reply.
client.process_query("Hello, how are you today?")

In [None]:
# Ask for the time. The answer relies on the preprompt's "{t:...}" time
# placeholder, if the configured preprompt contains one.
client.process_query("What time is it?")

In [None]:
# Shut down the llama.cpp server started earlier in this notebook.
server.stop()