# STalk Model Architecture

## Preparation

In [None]:
# Selects which wheels the install cells below fetch: True -> CUDA 12.1 (cu121)
# builds of PyTorch and llama-cpp-python, False -> CPU builds.
# NOTE(review): DEVICE is later chosen from torch.cuda.is_available(), not from this flag.
USE_CUDA = False

In [None]:
# huggingface_hub provides hf_hub_download (used by get_resnet152).
# NOTE(review): transformers is not imported by any visible cell.
!pip install huggingface_hub transformers

In [None]:
# Imported as intel_npu_acceleration_library in the next section; its only visible
# use (the NPU compile step in get_resnet152) is commented out.
!pip install intel-npu-acceleration-library

In [None]:
# Reinstall PyTorch from the CUDA 12.1 wheel index or from the default index.
if USE_CUDA:
    !pip uninstall torch torchvision torchaudio -y
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
else:
    !pip uninstall torch torchvision torchaudio -y
    !pip install torch torchvision torchaudio

In [None]:
# Reinstall llama-cpp-python from the prebuilt-wheel index matching USE_CUDA.
if USE_CUDA:
    !pip uninstall llama-cpp-python -y
    !pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
else:
    !pip uninstall llama-cpp-python -y
    !pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu

In [None]:
# pyannote.audio: Audio / Segment, used to crop regions of the recorded WAV.
!pip install pyannote-audio

In [None]:
# faster-whisper: speech-to-text (WhisperModel).
!pip install faster-whisper

In [None]:
# WeSpeaker from GitHub: speaker-embedding models (wespeaker.load_model_local).
!pip install git+https://github.com/wenet-e2e/wespeaker.git

In [None]:
# PyAudio: microphone capture and test-audio playback.
!pip install pyaudio

## Initialize Pretrained Models

In [None]:
import intel_npu_acceleration_library

from multiprocessing import Process, Manager

import os
import gc

import wave
import pyaudio

import torch
import torchaudio

from llama_cpp import Llama

from pyannote.audio import Audio
from pyannote.core import Segment

from faster_whisper import WhisperModel

from huggingface_hub import hf_hub_download
import wespeaker

# Use CUDA when PyTorch can see a usable GPU; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    print("INFO: CUDA is not available on this machine.\n")

print("PyTorch:", torch.__version__)
print("TorchAudio:", torchaudio.__version__)
print("Uses Device:", DEVICE.upper())

In [None]:
class ChatHistory(list):
    """Class-level store of chat turns that is replayed to the LLM as context.

    All state lives in the class attribute ``messages`` and is accessed via the
    classmethods below; the class itself is never instantiated here.
    """

    messages = []  # list of {'role': ..., 'content': ...} dicts, shared by all callers

    @classmethod
    def add_messages(cls, role, content):
        """Append one message, or several paired messages, to the history.

        If ``content`` is a ``str``, a single message with ``role`` is stored.
        Otherwise ``role`` and ``content`` are walked in lockstep with ``zip``,
        so surplus items in the longer sequence are ignored.
        """
        if not isinstance(content, str):
            pairs = zip(role, content)
        else:
            pairs = ((role, content),)
        for speaker, text in pairs:
            cls.messages.append({'role': speaker, 'content': text})

    @classmethod
    def create_prompt(cls, system_prompt: str, user_prompt: str = ""):
        """Return a fresh message list: system turn, stored history, user turn.

        The user turn is always appended, even when ``user_prompt`` is empty.
        """
        system_turn = {"role": "system", "content": system_prompt}
        user_turn = {"role": "user", "content": user_prompt}
        return [system_turn] + list(cls.messages) + [user_turn]

In [None]:
def token_stream(token):
    """Return the text piece carried by one streamed chat-completion chunk.

    Reads ``token["choices"][0]["delta"]``; a delta without a ``"content"``
    key yields the empty string.
    """
    return token["choices"][0]["delta"].get("content", "")

In [None]:
def get_llama3():
    """Load Meta-Llama-3-8B-Instruct (GGUF, Q4_K_M) and return a chat helper.

    Returns:
        ``llama3(system_prompt, user_prompt, temp=0.5, show_prompt=False)``.
        It builds the message list with ``ChatHistory.create_prompt`` and returns
        the iterator from ``create_chat_completion(..., stream=True)``.
        With ``show_prompt`` the message list is printed first.
    """
    llm = Llama.from_pretrained(
        repo_id="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF",
        filename="*Q4_K_M.gguf",
        # chat_format="llama-3",
        verbose=False,
    )
    complete = llm.create_chat_completion

    def llama3(system_prompt, user_prompt, temp=0.5, show_prompt=False):
        messages = ChatHistory.create_prompt(system_prompt, user_prompt)
        if show_prompt:
            print("PROMPT:")
            for message in messages:
                print(message)
            print()
        return complete(messages, temperature=temp, stream=True)

    return llama3

In [None]:
def get_whisper(model_size="medium", compute_type="int8", cpu_threads=12):
    """Load a faster-whisper model on DEVICE and return its ``transcribe`` method.

    Args:
        model_size: Whisper checkpoint, e.g. 'tiny', 'base', 'small', 'medium',
            'large', 'large-v2' or 'large-v3'.
        compute_type: quantization / precision, e.g. 'float16' or 'int8'.
        cpu_threads: thread count passed to ``WhisperModel``.

    Returns:
        The bound ``WhisperModel.transcribe`` callable.
    """
    model = WhisperModel(
        model_size,
        device=DEVICE,
        cpu_threads=cpu_threads,
        compute_type=compute_type,
    )
    return model.transcribe

In [None]:
def extract_embedding(model, pcm, sample_rate):
    """Compute a speaker embedding for an in-memory waveform and return it on the CPU.

    Args:
        model: wespeaker model wrapper; uses its ``resample_rate``,
            ``compute_fbank``, ``device`` and ``model`` attributes.
        pcm: waveform tensor (presumably (channels, samples) as returned by
            pyannote ``Audio.crop`` — TODO confirm).
        sample_rate: sampling rate of ``pcm`` in Hz; the waveform is resampled
            to ``model.resample_rate`` when the two differ.

    Returns:
        The first element of the network output (last element of a tuple
        output), moved to the CPU.
    """
    waveform = pcm.to(torch.float)
    target_rate = model.resample_rate
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    # Add a batch dimension and move the features to the model's device.
    batch = model.compute_fbank(waveform, sample_rate=target_rate, cmn=True).unsqueeze(0).to(model.device)

    network = model.model
    network.eval()
    with torch.no_grad():
        result = network(batch)
        if isinstance(result, tuple):
            result = result[-1]
    return result[0].to(torch.device('cpu'))

In [None]:
def recognize(model, pcm, sample_rate):
    """Return the registered speaker whose embedding best matches ``pcm``.

    Args:
        model: wespeaker model wrapper exposing ``table`` (registered name ->
            embedding) and ``cosine_similarity(a, b)``.
        pcm: waveform tensor, passed to ``extract_embedding``.
        sample_rate: sampling rate of ``pcm`` in Hz.

    Returns:
        ``{'name': best_name, 'confidence': best_score}``. On ties the first
        entry in ``table`` wins. ``name`` is ``''`` and ``confidence`` is ``0.0``
        when no score exceeds 0.0, including when the table is empty.
    """
    query = extract_embedding(model, pcm, sample_rate)
    best_score = 0.0
    best_name = ''
    for name, embedding in model.table.items():
        score = model.cosine_similarity(query, embedding)
        if score > best_score:
            best_score, best_name = score, name
    # One full collection per call rather than one per registered speaker:
    # this function runs on every transcribed segment.
    gc.collect()
    return {'name': best_name, 'confidence': best_score}

In [None]:
def get_resnet152():
    """Load the WeSpeaker VoxCeleb ResNet152-LM model and wrap it for this notebook.

    The checkpoint (``<name>.pt``) and config (``<name>.yaml``) are downloaded
    next to ``<name>.onnx``. They are renamed in place to ``avg_model.pt`` and
    ``config.yaml`` (only if those do not exist yet) before the directory is
    handed to ``wespeaker.load_model_local``.

    Returns:
        ``resnet152(ado, sample_rate=None)``:
        - if ``ado`` is a path string, it delegates to the wespeaker model's own
          ``recognize(path)``;
        - otherwise it treats ``ado`` as a waveform tensor at ``sample_rate`` and
          uses the module-level ``recognize``.
        The returned function also has a ``register`` attribute that forwards to
        the wespeaker model's ``register``.
    """
    model_id = "Wespeaker/wespeaker-voxceleb-resnet152-LM"
    model_name = model_id.replace("Wespeaker/wespeaker-", "").replace("-", "_")

    # Derive the download directory from the path itself rather than by string
    # replacement, which would also strip any other occurrence of the filename.
    root_dir = os.path.dirname(hf_hub_download(model_id, filename=model_name + ".onnx"))

    for target_name, repo_suffix in (("avg_model.pt", ".pt"), ("config.yaml", ".yaml")):
        target_path = os.path.join(root_dir, target_name)
        if not os.path.isfile(target_path):
            os.rename(hf_hub_download(model_id, filename=model_name + repo_suffix), target_path)

    resnet = wespeaker.load_model_local(root_dir)

    # Optional NPU offload (currently disabled):
    # print("Compile model for the NPU")
    # resnet.model = intel_npu_acceleration_library.compile(resnet.model)

    def resnet152(ado, sample_rate=None):
        if isinstance(ado, str):
            return resnet.recognize(ado)
        return recognize(resnet, ado, sample_rate)

    resnet152.register = resnet.register
    return resnet152

In [None]:
# Build the streaming Llama-3 chat helper (see get_llama3).
llama3 = get_llama3()
print("INFO: Llama3 Ready -", llama3)

In [None]:
# `whisper` is the faster-whisper model's bound transcribe method.
whisper = get_whisper()
print("INFO: Whisper Ready -", whisper)

In [None]:
# pyannote Audio loader, used later to crop regions of the recorded WAV file.
audio = Audio()
# Speaker-identification callable; also exposes .register(name, wav_path).
resnet152 = get_resnet152()
print("INFO: ResNet152 Ready -", resnet152)

## Model Ready

#### Insert the System Chat Template into Llama3

In [None]:
# System prompt prepended to every Llama-3 request (via ChatHistory.create_prompt).
system_prompt = "You are a helpful, smart, kind, and efficient Conversation Analysis and Recommendation AI System. You always fulfill the user's requests to the best of your ability. You need to keep listening to the conversations. Please answer in Korean."

In [None]:
# Send only the system prompt (create_prompt appends an empty user turn) and
# stream the model's reply to stdout piece by piece.
for chunk in llama3(system_prompt, ""):
    print(token_stream(chunk), end="", flush=True)
print()

#### Speaker Registration with ResNet152

In [None]:
# (display name, reference recording) for each speaker to enroll.
speaker1 = "민서", "./SpeakerDiarization/sample_conversation/real/sentence_F.wav"
speaker2 = "연우", "./SpeakerDiarization/sample_conversation/real/sentence_M.wav"
speaker1, speaker2

In [None]:
# Enroll both speakers so resnet152 can match transcribed segments against them.
resnet152.register(*speaker1)
resnet152.register(*speaker2)

In [None]:
# Request sent after the stored conversation: suggest a new topic on behalf of speaker2.
user_prompt = f"Based on the conversations between {speaker1[0]} and {speaker2[0]}, on behalf of {speaker2[0]}, do recommend a new topic sentence related to the current situation or their personal interests."

## Run

In [None]:
# When True, record_audio also starts play_test_audio in a separate process,
# presumably so the microphone picks up a known sample conversation.
TEST_MODE = False

In [None]:
# Microphone capture settings: 16-bit samples, mono, 44.1 kHz, 1024-frame chunks.
RECORD_FORMAT = pyaudio.paInt16
RECORD_RATE = 44100
RECORD_CHANNELS = 1
RECORD_CHUNK = 1024
# Shared PyAudio instance used for both capture and playback (other cells refer to it as `recoder`).
recoder = pyaudio.PyAudio()

In [None]:
# Each capture batch reads FRAME_LENGTH chunks of RECORD_CHUNK frames, i.e. at
# most RECORD_SECONDS of audio (int() truncates the chunk count).
RECORD_SECONDS = 1
FRAME_LENGTH = int(RECORD_RATE / RECORD_CHUNK * RECORD_SECONDS)

CACHE_FOLDER = os.path.join(".", "cache")
OUTPUT_FILENAME = "conversation_output.wav"

# exist_ok avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs(CACHE_FOLDER, exist_ok=True)

In [None]:
def play_test_audio():
    """Play the bundled sample conversation through the default output device.

    Started by record_audio in a separate process when TEST_MODE is set. The
    call blocks until the PyAudio callback stream reports it is no longer
    active. The WAV file and the output stream are always released.
    """
    # `sleep` is not imported at notebook level before the Release Test cell.
    from time import sleep

    audio_path = "./SpeakerDiarization/sample_conversation/real/conversation_0530_out.wav"
    with wave.open(audio_path, "rb") as test_file:
        player = recoder.open(
            format=recoder.get_format_from_width(test_file.getsampwidth()),
            channels=test_file.getnchannels(),
            rate=test_file.getframerate(),
            output=True,
            # Feed frames straight from the file; the file must stay open while this runs.
            stream_callback=lambda _, frame_count, __, ___: (test_file.readframes(frame_count), pyaudio.paContinue),
        )
        try:
            player.start_stream()
            print("Playing test audio...")

            while player.is_active():
                sleep(0.1)
        finally:
            player.stop_stream()
            player.close()

In [None]:
def record_audio(params):
    """Record microphone audio into OUTPUT_FILENAME until ``params['interrupted']`` is set.

    Args:
        params: shared mapping (a ``Manager().dict`` in the caller) with keys
            ``'interrupted'`` (stop flag, checked once per batch of
            FRAME_LENGTH chunks) and ``'duration'`` (seconds of audio written so
            far, incremented here after each batch).

    The input stream, the WAV file and the optional test-playback process are
    released even if reading or writing raises.
    """
    stream = recoder.open(
        format=RECORD_FORMAT,
        channels=RECORD_CHANNELS,
        rate=RECORD_RATE,
        input=True,
        frames_per_buffer=RECORD_CHUNK,
    )

    print("Recording started...")

    player_process = None
    try:
        if TEST_MODE:
            player_process = Process(target=play_test_audio)
            player_process.start()

        with wave.open(OUTPUT_FILENAME, "wb") as output_file:
            output_file.setnchannels(RECORD_CHANNELS)
            output_file.setsampwidth(recoder.get_sample_size(RECORD_FORMAT))
            output_file.setframerate(RECORD_RATE)

            while not params['interrupted']:
                chunks = [stream.read(RECORD_CHUNK) for _ in range(FRAME_LENGTH)]
                # writeframes() patches the WAV header, so the file stays readable
                # by the transcription loop while recording continues.
                output_file.writeframes(b"".join(chunks))
                # Seconds just written: chunks * frames-per-chunk / sample rate.
                params['duration'] += len(chunks) / RECORD_RATE * RECORD_CHUNK
    finally:
        stream.stop_stream()
        stream.close()
        if player_process is not None:
            player_process.terminate()
        print("Recording stopped.")

## Release Test

In [None]:
from stalk_models import llama3, whisper, audio, resnet152, system_prompt
from stalk_streamer import record_audio, CACHE_FOLDER, OUTPUT_FILENAME

from multiprocessing import Process, Manager
from time import sleep
import os
import gc

import torch
import torchaudio
from pyannote.core import Segment

In [None]:
# (display name, reference recording) for each speaker; enroll both with the
# imported resnet152 so transcribed segments can be attributed to them.
speaker1 = "민서", "./SpeakerDiarization/sample_conversation/real/sentence_F.wav"
speaker2 = "연우", "./SpeakerDiarization/sample_conversation/real/sentence_M.wav"
resnet152.register(*speaker1)
resnet152.register(*speaker2)

In [None]:
# Request sent after the stored conversation: suggest a new topic on behalf of speaker2.
user_prompt = f"Based on the conversations between {speaker1[0]} and {speaker2[0]}, on behalf of {speaker2[0]}, do recommend a new topic sentence related to the current situation or their personal interests."

In [None]:
# Shared state with the recorder process: stop flag and seconds recorded so far.
manager = Manager()
RECORD_PARAMS = manager.dict(interrupted=False, duration=0.0)

record_thread = Process(target=record_audio, kwargs=dict(params=RECORD_PARAMS))
record_thread.start()

start_offset = 0.0  # position (seconds) in OUTPUT_FILENAME where the next pass starts
temp_file = os.path.join(CACHE_FOLDER, "temp.wav")
error_count = 0
last_error = None

try:
    while not RECORD_PARAMS['duration']:
        sleep(0.001)  # Wait until the recording starts

    while True:
        # Re-transcribe from the start of the last seen segment up to the
        # current end of the recording.
        pass_offset = start_offset
        audio_range = Segment(pass_offset, RECORD_PARAMS['duration'])
        print("Transcribing audio...", audio_range)
        torchaudio.save(temp_file, *audio.crop(OUTPUT_FILENAME, audio_range))

        segments, _ = whisper(temp_file, beam_size=5, word_timestamps=False)

        for segment in segments:
            try:
                # Whisper timestamps are relative to temp_file, which begins at
                # pass_offset; start_offset itself may advance within this loop.
                crop_range = (pass_offset + segment.start, pass_offset + segment.end)
                portion = audio.crop(OUTPUT_FILENAME, Segment(*crop_range))
                torchaudio.save(os.path.join(CACHE_FOLDER, f"{crop_range[0]}.partial.wav"), *portion)

                speaker = resnet152(*portion)

                print(f"\r{crop_range} -> [{speaker['name']}] {segment.text.strip()}", end="", flush=True)
                #ChatHistory.add_messages(speaker['name'], segment.text.strip())

                del portion, speaker
                torch.cuda.empty_cache()

                # A new segment start becomes the beginning of the next pass.
                if start_offset != crop_range[0]:
                    start_offset = crop_range[0]
                    print()
            except Exception as exc:
                # Skip a failing segment but keep streaming; Ctrl+C still propagates.
                error_count += 1
                last_error = exc

        gc.collect()
except KeyboardInterrupt:
    print("Recording stopped by user")
finally:
    RECORD_PARAMS['interrupted'] = True
    record_thread.join()
    print("Error count:", error_count)
    if last_error is not None:
        print("Last error:", repr(last_error))
    # Managers have shutdown(), not close().
    manager.shutdown()

In [None]:
# Release the PyAudio instance.
# NOTE(review): `recoder` is not among the Release Test imports; this relies on
# the earlier Run cells having been executed in the same kernel.
recoder.terminate()
print("Recording finished.")

In [None]:
# Print the collected conversation history.
# NOTE(review): it stays empty while the ChatHistory.add_messages call in the
# transcription loop is commented out.
for message in ChatHistory.messages:
    print(f"[{message['role']}] {message['content']}")

In [None]:
# Ask the LLM for a topic recommendation and stream the reply to stdout.
# NOTE(review): token_stream (like ChatHistory) is not imported by the Release
# Test cell; it comes from the earlier notebook cells.
for chunk in llama3(system_prompt, user_prompt):
    print(token_stream(chunk), end="", flush=True)
print()