# Imports

In [1]:
# Generic imports
from pathlib import Path
import time
import threading
import torch
import numpy as np
import sounddevice as sd

# Whisper imports live inside the Whisper class initializer because the backend (whisper package or transformers) is chosen at runtime, so only one of them needs to be installed

# XTTS imports
from TTS.api import TTS

# Ollama access imports
import json
import requests

# Components: speech-to-text, text-to-speech, and LLM chat

In [2]:
class Whisper:
    """Record speech from the microphone and transcribe it with Whisper.

    Two backends are supported:

    * ``using_as_package=True``: the ``whisper`` package, loading ``model_name``
      with ``whisper.load_model``.
    * ``using_as_package=False``: a ``transformers`` automatic-speech-recognition
      pipeline built from ``model_path`` (on ``cuda:0`` when CUDA is available).

    Args:
        using_as_package: Selects the backend (see above).
        model_name: Whisper model name, used by the package backend only.
        model_path: Model path or id, used by the transformers backend only.
        min_duration: Seconds that must be recorded before silence may end
            the recording.
        silence_threshold: Mean absolute amplitude below which a chunk is
            treated as silent.
        silence_duration: Seconds of continuous silence (counted only after
            ``min_duration``) that end the recording.
    """

    def __init__(self, using_as_package, model_name, model_path, min_duration, silence_threshold, silence_duration) -> None:
        # Parameters
        self.using_as_package = using_as_package
        self.model_name = model_name
        self.model_path = model_path
        self.min_duration = min_duration
        self.silence_threshold = silence_threshold
        self.silence_duration = silence_duration

        # Capture settings: mono audio at SAMPLE_RATE, read CHUNK_SIZE samples at a time.
        self.SAMPLE_RATE = 16000
        self.CHUNK_SIZE = 1024

        # Backend imports are local so only the selected backend has to be installed.
        if self.using_as_package:
            import whisper
            self.model = whisper.load_model(self.model_name)
        else:
            from transformers import pipeline
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=model_path,
                chunk_length_s=30,
                device=device,
            )

    def record_until_silence(self):
        """Record mono audio from the default input device until silence is detected.

        Recording stops once more than ``min_duration`` seconds have been captured
        and the most recent ``silence_duration`` seconds were all silent. Durations
        are counted in recorded chunks (audio time) rather than wall-clock time,
        so scheduling delays do not stretch or shrink the silence window.

        Returns:
            tuple: ``(audio_data, always_silent)``. ``audio_data`` is every chunk
            returned by ``stream.read`` concatenated along the first axis, and
            ``always_silent`` is True when no chunk reached ``silence_threshold``.
        """
        # Convert the configured durations into chunk counts once.
        min_chunks = self.min_duration * self.SAMPLE_RATE / self.CHUNK_SIZE
        silence_chunks = self.silence_duration * self.SAMPLE_RATE / self.CHUNK_SIZE

        recorded_chunks = []
        silent_run = 0  # consecutive silent chunks since the minimum duration was reached
        always_silent = True

        with sd.InputStream(samplerate=self.SAMPLE_RATE, channels=1, dtype='float32', blocksize=self.CHUNK_SIZE) as stream:
            while True:
                chunk, _overflowed = stream.read(self.CHUNK_SIZE)
                recorded_chunks.append(chunk)

                silent = np.abs(chunk).mean() < self.silence_threshold
                if not silent:
                    always_silent = False

                # Silence only counts toward stopping once the minimum duration is recorded;
                # any sound resets the run.
                if silent and len(recorded_chunks) > min_chunks:
                    silent_run += 1
                    if silent_run > silence_chunks:
                        break
                else:
                    silent_run = 0

        return np.concatenate(recorded_chunks), always_silent

    def transcript(self):
        """Record one utterance and transcribe it.

        Returns:
            tuple: ``(text, always_silent)``. ``text`` is the transcription with
            surrounding whitespace stripped, or None when the recording never
            reached ``silence_threshold``.
        """
        print("Recording...")
        audio_data, always_silent = self.record_until_silence()
        print("Recording complete.")

        if always_silent:
            return None, always_silent

        audio_data = audio_data.flatten()  # mono (frames, 1) -> 1-D array

        if self.using_as_package:
            result = self.model.transcribe(audio_data)
        else:
            # State the capture rate explicitly instead of relying on the pipeline
            # to assume its feature extractor's rate.
            result = self.pipe({"raw": audio_data, "sampling_rate": self.SAMPLE_RATE})

        # Whisper text often starts with a space and can be whitespace-only for
        # noise; strip it so callers can test for empty input directly.
        return result["text"].strip(), always_silent

In [3]:
class XTTS:
    """Text-to-speech with ``TTS.api.TTS("xtts_v2")`` and non-overlapping background playback.

    Args:
        speaker_wav: Path to the reference recording passed as ``speaker_wav``
            to ``TTS.tts`` (the voice to imitate).
        language: Language code passed to ``TTS.tts`` (e.g. ``"en"``).
    """

    def __init__(self, speaker_wav, language) -> None:
        # Parameters
        self.speaker_wav = speaker_wav
        self.language = language

        # Playback rate for generated audio.
        # NOTE(review): assumed to match the xtts_v2 output rate; confirm if the model changes.
        self.SAMPLE_RATE = 24000

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts = TTS("xtts_v2").to(device)

        self.infer = lambda text: self.tts.tts(text=text, speaker_wav=self.speaker_wav, language=self.language)

        # Placeholder that was never started, so is_alive() is False until the first
        # call to speak() replaces it with the playback-waiting thread.
        self.sd_thread = threading.Thread(target=lambda: None)

    def speak(self, text):
        """Synthesize ``text`` and play it without overlapping the previous clip.

        Synthesis runs before waiting, so the next clip is generated while the
        previous one is still playing. Playback is started with ``sd.play`` and a
        background thread blocks on ``sd.wait``; this method returns as soon as
        playback has begun, and ``self.sd_thread`` stays alive until it ends.
        """
        wav = self.infer(text)

        # join() raises on a never-started thread, so only join a live one.
        if self.sd_thread.is_alive():
            self.sd_thread.join()

        sd.play(wav, samplerate=self.SAMPLE_RATE)
        self.sd_thread = threading.Thread(target=sd.wait)
        self.sd_thread.start()

In [4]:
class Ollama:
    """Stream chat completions from an Ollama server and voice them sentence by sentence.

    Args:
        url: Full URL of the Ollama chat endpoint, e.g. ``http://localhost:11434/api/chat``.
        model: Name of the Ollama model to chat with.
        xtts: Speech engine exposing ``speak(text)`` and an ``sd_thread``
            attribute (a ``threading.Thread``) that is alive while audio plays.
        temperature: Sampling temperature sent in the request ``options``.
    """

    def __init__(self, url, model, xtts, temperature=0.3) -> None:
        # Parameters
        self.url = url
        self.model = model
        self.xtts = xtts
        self.temperature = temperature

    @staticmethod
    def _split_complete(text):
        """Split ``text`` into ``(complete_sentences, trailing_fragment)``.

        Full-width terminators (``。！？``) end a sentence immediately. ASCII
        ``.``, ``!`` and ``?`` end a sentence only when followed by whitespace.
        A token ending in ``.`` therefore waits for the next token, and decimals
        such as ``3.14`` are not split.
        """
        cut = 0
        for i, ch in enumerate(text):
            if ch in "。！？":
                cut = i + 1
            elif ch in ".!?" and i + 1 < len(text) and text[i + 1].isspace():
                cut = i + 1
        return text[:cut], text[cut:]

    def _speak_async(self, text):
        """Print ``text`` and run ``self.xtts.speak(text)`` on a new thread; return that thread."""
        print(text, flush=True)
        thread = threading.Thread(target=self.xtts.speak, args=(text,))
        thread.start()
        return thread

    def chat(self, messages):
        """Send ``messages`` to Ollama and speak the streamed reply.

        Complete sentences are printed and handed to TTS as they arrive. A new
        batch is only started once the previous one has finished, so speech never
        overlaps. Any trailing text without terminal punctuation is spoken when
        the stream ends.

        Args:
            messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            dict: ``{"role": role, "content": full_reply}`` for appending to the history.

        Raises:
            requests.HTTPError: If the server returns an error status.
            RuntimeError: If Ollama streams an ``error`` object.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "options": {"temperature": self.temperature},
            "stream": True,
        }

        output = ""          # full reply text
        pending = ""         # text after the last sentence boundary
        ready = ""           # complete sentences not yet handed to TTS
        role = "assistant"
        speak_thread = None

        with requests.post(self.url, json=payload, stream=True) as r:
            r.raise_for_status()
            for line in r.iter_lines():
                if not line:
                    continue  # blank keep-alive lines are not JSON

                body = json.loads(line)
                if "error" in body:
                    raise RuntimeError(body["error"])

                message = body.get("message") or {}
                role = message.get("role", role)
                content = message.get("content", "")
                output += content

                complete, pending = self._split_complete(pending + content)
                ready += complete

                # Start speaking accumulated sentences only when the previous batch is
                # done; tokens keep streaming in meanwhile.
                if ready.strip() and (speak_thread is None or not speak_thread.is_alive()):
                    speak_thread = self._speak_async(ready.strip())
                    ready = ""

                if body.get("done"):
                    break

        # Speak whatever is left, including a final fragment with no terminator.
        ready += pending
        if speak_thread is not None:
            speak_thread.join()
        if ready.strip():
            speak_thread = self._speak_async(ready.strip())
            speak_thread.join()

        # Return only after playback finishes, so the next recording does not
        # overlap the spoken reply.
        while self.xtts.sd_thread.is_alive():
            time.sleep(0.1)

        return {"role": role, "content": output}

# Initialization and Parameters

In [6]:
# Speech-to-text: pick one backend, then tune silence detection.
whisper = Whisper(
    # TODO: True loads `model_name` through the whisper package;
    #       False loads the local model at `model_path` through a transformers pipeline.
    using_as_package=True,
    model_name="small.en",  # TODO: Whisper model name (package backend only)
    model_path="path/to/your/model/whisper-small.en/",  # TODO: local model path (transformers backend only)
    min_duration=3,  # seconds recorded before silence detection can stop the recording
    silence_threshold=0.01,  # mean absolute amplitude below which a chunk counts as silent
    silence_duration=1,  # seconds of continuous silence that end the recording
)

100%|███████████████████████████████████████| 461M/461M [02:50<00:00, 2.84MiB/s]


In [7]:
# Text-to-speech voice settings.
xtts = XTTS(
    language="en",  # TODO: language code used for synthesis
    speaker_wav="./voice_examples/edited_voice_example_lycaon_en.wav",  # TODO: reference clip of the voice to use
)

 > Using model: xtts


In [8]:
# LLM connection: local Ollama chat endpoint, voiced through the XTTS engine above.
ollama = Ollama(
    model="llama3.1:8b",  # TODO: name of the Ollama model to chat with
    temperature=0.5,  # sampling temperature sent in the request options
    url="http://localhost:11434/api/chat",  # Ollama chat API endpoint
    xtts=xtts,  # speech engine that voices the streamed reply
)

In [9]:
# TODO: update this to your own prompt (the leading/trailing newlines are part of the text).
system_prompt = (
    "\n"
    "You are a helpful assistant. Please respond in a conversational, speech-like style without any formatting.\n"
)

# Main

In [10]:
def main():
    """Run the voice chat loop: listen, send the transcript to Ollama, speak the reply.

    When a recording is silent or transcribes to blank text, the loop pauses for
    keyboard input. Enter resumes listening. "bye" (case-insensitive, surrounding
    whitespace ignored) speaks a goodbye and exits.
    """
    messages = [{"role": "system", "content": system_prompt}]

    while True:
        user_input, always_silent = whisper.transcript()

        # Treat whitespace-only transcripts the same as silence.
        if always_silent or not (user_input and user_input.strip()):
            print("Nothing input. Pausing.")
            halt_input = input("Press Enter to resume or enter 'bye' to exit.")
            if halt_input.strip().lower() == "bye":
                bye = "Exiting, goodbye."
                print(bye)
                xtts.speak(bye)
                break
            continue

        print(user_input)

        messages.append({"role": "user", "content": user_input})
        messages.append(ollama.chat(messages))
        print("\n\n")


if __name__ == "__main__":
    main()

Recording...
Recording complete.
 Hello, hello. Why is the sky blue?
The sky being blue is actually one of those really cool science-y things that's pretty easy to understand once you know what's going on.


 > Text splitted to sentences.
["The sky being blue is actually one of those really cool science-y things that's pretty easy to understand once you know what's going on."]


 > Processing time: 10.303298234939575
 > Real-time factor: 0.9148400798934413
So, basically, when sunlight enters Earth's atmosphere, it encounters all sorts of tiny molecules like nitrogen and oxygen. These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths.
 > Text splitted to sentences.
["So, basically, when sunlight enters Earth's atmosphere, it encounters all sorts of tiny molecules like nitrogen and oxygen.", 'These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths.']
 > Processing time: 19.825655698776245
 > Real-time factor: 0.7850764829571867
 That's why we see blue skies most of the time!

It's kind of like when you're outside on a sunny day and you look up at the sky - it looks bright blue because all that scattered sunlight is bouncing around everywhere, hitting your eyes from every direction.

Of course, there are some othe