# STalk Model Architecture

## Preparation

In [None]:
# Install-time switch read by the two reinstall cells below:
# True selects the CUDA 12.1 (cu121) wheel indexes, False the default/CPU ones.
USE_CUDA = False

In [None]:
!pip install huggingface_hub transformers

In [None]:
!pip install intel-npu-acceleration-library

In [None]:
# Reinstall torch / torchvision / torchaudio so the installed build matches USE_CUDA.
# Both branches uninstall first; only the index used for the reinstall differs.
if USE_CUDA:
    !pip uninstall torch torchvision torchaudio -y
    # CUDA build: wheels come from the cu121 PyTorch index.
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
else:
    !pip uninstall torch torchvision torchaudio -y
    # No index override here, so pip's default index is used.
    !pip install torch torchvision torchaudio

In [None]:
# Reinstall llama-cpp-python from the prebuilt wheel index that matches USE_CUDA.
if USE_CUDA:
    !pip uninstall llama-cpp-python -y
    # cu121 wheel index.
    !pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
else:
    !pip uninstall llama-cpp-python -y
    # cpu wheel index.
    !pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu

In [None]:
!pip install pyannote-audio

In [None]:
!pip install faster-whisper

In [None]:
!pip install git+https://github.com/wenet-e2e/wespeaker.git

In [None]:
!pip install pyaudio

## Initialize Pretrained Model

In [1]:
import intel_npu_acceleration_library

import gc

import wave
import pyaudio

import torch
import torchaudio

from llama_cpp import Llama

from pyannote.audio import Audio
from pyannote.core import Segment

from faster_whisper import WhisperModel

from huggingface_hub import hf_hub_download
import wespeaker

# Choose the compute device once; DEVICE is later handed to WhisperModel.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cpu":
    print("INFO: CUDA is diabled on this machine.\n\n")

# Report the library versions and the device that was selected.
print("PyTorch:", torch.__version__)
print("TorchAudio:", torchaudio.__version__)
print("Uses Device:", DEVICE.upper())

INFO: CUDA is diabled on this machine.


PyTorch: 2.3.1+cpu
TorchAudio: 2.3.1+cpu
Uses Device: CPU


In [2]:
class ChatHistory(list):
    """Conversation log kept on the class itself and shared by every caller.

    All state lives in the class attribute ``messages``; the classmethods
    below append to it and read from it, so no instance is ever required.
    """

    # Shared list of {'role': ..., 'content': ...} dicts.
    messages = []

    @classmethod
    def add_messages(cls, role, content):
        """Append one message, or several paired messages, to the shared log.

        If *content* is a string, a single entry with the given *role* is
        appended. Otherwise *role* and *content* are walked in lockstep and
        one entry is appended per (role, content) pair.
        """
        if isinstance(content, str):
            pairs = [(role, content)]
        else:
            pairs = zip(role, content)
        # extend() mutates the shared list in place, exactly like append().
        cls.messages.extend({'role': r, 'content': c} for r, c in pairs)

    @classmethod
    def create_prompt(cls, system_prompt: str, user_prompt: str = ""):
        """Return a new message list: system entry, the log, then a user entry.

        The user entry is always present, even when *user_prompt* is empty.
        """
        opening = {"role": "system", "content": system_prompt}
        closing = {"role": "user", "content": user_prompt}
        return [opening, *cls.messages, closing]

In [3]:
def token_stream(token):
    """Return the text carried by one streamed chat-completion chunk.

    Looks at ``token["choices"][0]["delta"]`` and returns its ``"content"``
    entry, or an empty string when the delta has no ``"content"`` key.
    """
    return token["choices"][0]["delta"].get("content", "")

In [4]:
def get_llama3():
    """Load Meta-Llama-3-8B-Instruct (Q4_K_M GGUF file) and return a chat callable.

    Returns:
        A function ``llama3(system_prompt, user_prompt, temp=0.5,
        show_prompt=False)``. It builds the message list with
        ``ChatHistory.create_prompt`` and returns the result of
        ``create_chat_completion(..., stream=True)``.
    """
    # chat_format is not passed to from_pretrained here.
    model = Llama.from_pretrained(
        repo_id="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF",
        filename="*Q4_K_M.gguf",
        verbose=False,
    )
    complete = model.create_chat_completion

    def llama3(system_prompt, user_prompt, temp=0.5, show_prompt=False):
        messages = ChatHistory.create_prompt(system_prompt, user_prompt)

        # Optional debug dump of every message that is about to be sent.
        if show_prompt:
            print("PROMPT:")
            for entry in messages:
                print(entry)
            print()

        return complete(messages, temperature=temp, stream=True)

    return llama3

In [5]:
def get_whisper():
    """Create a faster-whisper model on DEVICE and return its bound ``transcribe`` method."""
    model_size = "medium"  #@param ['tiny', 'base', 'small', 'medium', 'large', 'large-v2', 'large-v3']
    compute_type = "int8"  #@param ['float16', 'int8']

    model = WhisperModel(model_size, device=DEVICE, compute_type=compute_type)
    return model.transcribe

In [6]:
def extract_embedding(model, pcm, sample_rate):
    """Run *pcm* through the speaker model and return its first output row on CPU.

    Args:
        model: Object exposing ``resample_rate``, ``compute_fbank``, ``device``
            and a callable ``model`` network (the wespeaker model loaded in
            ``get_resnet152``).
        pcm: Audio tensor; it is cast to ``torch.float`` before use.
        sample_rate: Sample rate of *pcm*. When it differs from
            ``model.resample_rate`` the audio is resampled first.

    Returns:
        ``outputs[0]`` of the network output (the last element if the network
        returns a tuple), moved to the CPU.
    """
    target_rate = model.resample_rate
    waveform = pcm.to(torch.float)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    fbank = model.compute_fbank(waveform, sample_rate=target_rate, cmn=True)
    # Add a leading batch dimension and move to the model's device.
    batch = fbank.unsqueeze(0).to(model.device)

    network = model.model
    network.eval()
    with torch.no_grad():
        result = network(batch)
    # Some networks return a tuple; only its last element is used.
    if isinstance(result, tuple):
        result = result[-1]
    return result[0].to(torch.device('cpu'))

In [7]:
def recognize(model, pcm, sample_rate):
    """Identify which enrolled speaker best matches an in-memory audio clip.

    Args:
        model: Speaker model exposing ``table`` (name -> enrolled embedding)
            and ``cosine_similarity``; also passed to ``extract_embedding``.
        pcm: Audio tensor for the clip to identify.
        sample_rate: Sample rate of *pcm*.

    Returns:
        ``{'name': best_name, 'confidence': best_score}``. If no enrolled
        speaker scores above 0.0, the name is ``''`` and the confidence 0.0.
    """
    query = extract_embedding(model, pcm, sample_rate)
    best_name, best_score = '', 0.0
    for name, enrolled in model.table.items():
        score = model.cosine_similarity(query, enrolled)
        # Strict comparison: on a tie the earlier entry in the table wins.
        if best_score < score:
            best_name, best_score = name, score
    # One collection after the scan instead of a full GC pass per enrolled
    # speaker; the result is the same, only the per-iteration cost is gone.
    gc.collect()
    return {'name': best_name, 'confidence': best_score}

In [8]:
def get_resnet152():
    """Download the wespeaker ResNet152 checkpoint and return a recognizer callable.

    The Hub files are fetched with ``hf_hub_download``; the ``.pt`` and
    ``.yaml`` files are renamed to ``avg_model.pt`` / ``config.yaml`` inside
    the download directory (only when those names are not already present)
    before ``wespeaker.load_model_local`` is called on that directory.

    Returns:
        ``resnet152(ado, sample_rate=None)``: a string *ado* is forwarded to
        the model's own ``recognize``; anything else goes through the
        module-level ``recognize`` together with *sample_rate*. The callable
        also carries a ``register`` attribute that forwards to the model's
        ``register``.
    """
    import os

    model_id = "Wespeaker/wespeaker-voxceleb-resnet152-LM"
    model_name = model_id.replace("Wespeaker/wespeaker-", "").replace("-", "_")

    # The directory is derived by stripping the file name from the .onnx path.
    onnx_name = model_name + ".onnx"
    root_dir = hf_hub_download(model_id, filename=onnx_name).replace(onnx_name, "")

    for suffix, local_name in ((".pt", "avg_model.pt"), (".yaml", "config.yaml")):
        target = root_dir + local_name
        if os.path.isfile(target):
            continue
        os.rename(hf_hub_download(model_id, filename=model_name + suffix), target)

    resnet = wespeaker.load_model_local(root_dir)

    # NPU compilation is currently switched off:
    #print("Compile model for the NPU")
    #resnet.model = intel_npu_acceleration_library.compile(resnet.model)

    def resnet152(ado, sample_rate=None):
        if not isinstance(ado, str):
            return recognize(resnet, ado, sample_rate)
        return resnet.recognize(ado)

    def register(*args, **kwargs):
        return resnet.register(*args, **kwargs)

    resnet152.register = register

    return resnet152

In [9]:
# Build the chat callable once; later cells call llama3(system_prompt, user_prompt).
llama3 = get_llama3()
print("INFO: Llama3 Ready -", llama3)

INFO: Llama3 Ready - <function get_llama3.<locals>.llama3 at 0x000001828EBAA3E0>


In [10]:
# `whisper` is the bound WhisperModel.transcribe method returned by get_whisper().
whisper = get_whisper()
print("INFO: Whisper Ready -", whisper)

INFO: Whisper Ready - <bound method WhisperModel.transcribe of <faster_whisper.transcribe.WhisperModel object at 0x000001828EB17B10>>


In [11]:
# pyannote Audio helper; the transcription loop uses audio.crop(...) to cut out segments.
audio = Audio()
# Speaker recognizer callable (also exposes .register for enrollment).
resnet152 = get_resnet152()
print("INFO: ResNet152 Ready -", resnet152)



INFO: ResNet152 Ready - <function get_resnet152.<locals>.resnet152 at 0x000001828EBAB060>


## Model Ready

#### Insert System Chat Template to Llama3

In [12]:
# System message; passed as the first argument to llama3(...) in the cells below.
system_prompt = "You are a helpful, smart, kind, and efficient Conversation Analysis and Recommendation AI System. You always fulfill the user's requests to the best of your ability. You need to keep listen to the conversations. Please answer in Korean language."

In [13]:
# Stream a first reply with an empty user message, printing each chunk's text as it arrives.
for chunk in llama3(system_prompt, ""):
    print(token_stream(chunk), end="", flush=True)
print()

😊

안녕하세요! 저는 지능형 대화 분석 및 추천 AI 시스템입니다. 사용자의 요청을 최선을 다해 충족합니다. 대화를.listen하고 답변을 제공할 것입니다. 한국어로 답변이 가능합니다. 무엇을 도와드릴까요? 🤔


#### Speaker Registration to ResNet152

In [14]:
# (speaker name, enrollment wav path) tuples; unpacked into resnet152.register(...) below.
speaker1 = "민서", "./SpeakerDiarization/sample_conversation/real/sentence_F.wav"
speaker2 = "연우", "./SpeakerDiarization/sample_conversation/real/sentence_M.wav"
speaker1, speaker2

(('민서', './SpeakerDiarization/sample_conversation/real/sentence_F.wav'),
 ('연우', './SpeakerDiarization/sample_conversation/real/sentence_M.wav'))

In [15]:
# Enroll both speakers: each call receives (name, wav path).
resnet152.register(*speaker1)
resnet152.register(*speaker2)

In [16]:
# Final user request sent to llama3 after the conversation has been collected;
# it names both speakers and asks for a suggestion on behalf of speaker2.
user_prompt = f"Based on the conversations between {speaker1[0]} and {speaker2[0]}, on be half of {speaker2[0]}, do recommend a new topic sentence related the current situation or their personal interests."

## Test

In [17]:
# When True, record_audio() also starts play_test_audio() in a separate process while recording.
TEST_MODE = False

In [18]:
# Capture parameters used when opening the input stream and writing wav headers.
RECORD_FORMAT = pyaudio.paInt16
RECORD_RATE = 44100
RECORD_CHANNELS = 1
# Frames requested per stream.read() call.
RECORD_CHUNK = 1024
# Single PyAudio instance shared by recording and playback; terminated in a later cell.
recoder = pyaudio.PyAudio()

In [19]:
# Length of one recording window handed to the transcription loop.
RECORD_SECONDS = 5
# Number of RECORD_CHUNK-sized reads per window (truncated to an int).
FRAME_LENGTH = int(RECORD_RATE / RECORD_CHUNK * RECORD_SECONDS)

# Template for per-iteration cache files: ".wav" is replaced by "_<i>.wav" in the loop.
CACHE_FILENAME = "./cache/cache.wav"
# Full recording written by record_audio().
OUTPUT_FILENAME = "conversation_output.wav"

In [20]:
from time import sleep

def play_test_audio():
    """Play the bundled sample conversation through an output stream on ``recoder``.

    The stream is fed by a callback that reads frames from the wav file.
    The function polls ``is_active()`` every 0.1 s and returns once the
    stream is no longer active. The wav file and the stream are released
    even if playback raises.
    """
    audio_path = "./SpeakerDiarization/sample_conversation/real/conversation_0530_out.wav"

    # Context manager: the file handle used to be left open after playback.
    with wave.open(audio_path, "rb") as test_file:

        def feed(_in_data, frame_count, _time_info, _status):
            # Hand PortAudio the next frames; the status flag is always paContinue.
            return test_file.readframes(frame_count), pyaudio.paContinue

        player = recoder.open(
            format=recoder.get_format_from_width(test_file.getsampwidth()),
            channels=test_file.getnchannels(),
            rate=test_file.getframerate(),
            output=True,
            stream_callback=feed
        )
        try:
            player.start_stream()
            print("Playing test audio...")

            while player.is_active():
                sleep(0.1)

            player.stop_stream()
        finally:
            # Always give the stream back, including when playback failed.
            player.close()

In [21]:
from multiprocessing import Process

# Stop flag polled by record_audio(); the transcription cell sets it to True when it ends.
RECORD_INTERRUPTED = False
# Hand-off buffer: record_audio() appends one bytes object per window, the loop pops them.
RECORDED_FRAMES = []

def record_audio():
    """Record from the input stream until RECORD_INTERRUPTED becomes true.

    Each pass reads FRAME_LENGTH chunks of RECORD_CHUNK frames, joins them
    into one bytes object, writes it to OUTPUT_FILENAME and appends it to the
    module-level RECORDED_FRAMES list for the transcription loop. When
    TEST_MODE is set, ``play_test_audio`` is started in a separate process
    for the duration of the recording.

    The stream, the wav file and the helper process are released even when
    an error interrupts the recording.
    """
    stream = recoder.open(
        format=RECORD_FORMAT, channels=RECORD_CHANNELS,
        rate=RECORD_RATE, input=True,
        frames_per_buffer=RECORD_CHUNK
    )
    process = None

    try:
        print("Recording started...")

        if TEST_MODE:
            # NOTE(review): the child process needs play_test_audio and
            # `recoder` to be available there; confirm this works with the
            # start method used by the notebook's platform.
            process = Process(target=play_test_audio)
            process.start()

        with wave.open(OUTPUT_FILENAME, "wb") as output_file:
            output_file.setnchannels(RECORD_CHANNELS)
            output_file.setsampwidth(recoder.get_sample_size(RECORD_FORMAT))
            output_file.setframerate(RECORD_RATE)

            while not RECORD_INTERRUPTED:
                # One window = FRAME_LENGTH consecutive chunk reads.
                frame = b"".join(stream.read(RECORD_CHUNK) for _ in range(FRAME_LENGTH))
                output_file.writeframes(frame)
                RECORDED_FRAMES.append(frame)
    finally:
        # Cleanup used to run only on the success path.
        stream.stop_stream()
        stream.close()
        if process is not None:
            process.terminate()
    print("Recording stopped.")

In [22]:
import os
from threading import Thread
from time import sleep

# Live pipeline: a background thread records fixed-length windows, and this
# loop transcribes each window (plus retained earlier audio), identifies the
# speaker of every segment and appends the result to ChatHistory.

# wave.open() does not create directories, so make sure the cache one exists.
os.makedirs(os.path.dirname(CACHE_FILENAME) or ".", exist_ok=True)

try:
    RECORD_INTERRUPTED = False
    RECORDED_FRAMES = []

    record_thread = Thread(target=record_audio)
    record_thread.start()

    i = 0
    # Audio carried over from earlier windows and prepended to the next one.
    last_frame = []
    # Set from the index of the last segment handled in the previous pass
    # (True when that index was <= 1); it decides how much audio and history
    # is carried over. Presumably "few segments" means the utterance is still
    # in progress — confirm the intent with the author.
    identical = True

    while True:
        if not RECORDED_FRAMES:
            if not record_thread.is_alive():
                # The recorder died, so no further windows will ever arrive.
                print("Recording thread exited; stopping.")
                break
            # Wait briefly instead of spinning at full CPU.
            sleep(0.05)
            continue
        frame = RECORDED_FRAMES.pop(0)

        CACHE = CACHE_FILENAME.replace(".wav", f"_{i}.wav")

        # Close (and flush) the cache file before it is handed to the models;
        # the handle used to stay open for every iteration.
        with wave.open(CACHE, "wb") as cache_file:
            cache_file.setnchannels(RECORD_CHANNELS)
            cache_file.setsampwidth(recoder.get_sample_size(RECORD_FORMAT))
            cache_file.setframerate(RECORD_RATE)
            cache_file.writeframes(b"".join([*last_frame, frame]))

        segments, info = whisper(CACHE, beam_size=5, word_timestamps=False)
        print("Transcription finished.")
        segments = iter(segments)

        if not identical:
            # Skip the first segment and keep only the newest retained chunk.
            # The default avoids StopIteration when the transcript is empty.
            next(segments, None)
            last_frame = [last_frame[-1]]
        else:
            # Drop the most recent message before the segments are re-added.
            ChatHistory.messages = ChatHistory.messages[:-1]

        for index, segment in enumerate(segments):
            try:
                # audio.crop's result is unpacked as positional args for resnet152.
                embedding = audio.crop(CACHE, Segment(segment.start, segment.end))
            except Exception:
                # Cropping failed: fall back to recognizing the whole cache file.
                # (A bare except here also swallowed Ctrl-C.)
                embedding = (CACHE, )

            print(f"{i+1}-{index+1}", (segment.start, segment.end), "->", end=" ", flush=True)

            speaker = resnet152(*embedding)
            print("[%s] %s" % (speaker['name'], segment.text.strip()))
            ChatHistory.add_messages(speaker['name'], segment.text.strip())

            del embedding, speaker

            identical = index <= 1

        if identical:
            # Merge everything heard so far into a single retained chunk.
            last_frame = [b"".join([*last_frame, frame])]
        else:
            last_frame.append(frame)

        del frame
        gc.collect()

        i += 1
except KeyboardInterrupt:
    print("Recording stopped by user")
finally:
    # Ask the recorder to stop and wait until it has released its resources.
    RECORD_INTERRUPTED = True
    record_thread.join()

Recording started...
Transcription finished.
1-1 (0.0, 5.0) -> [민서] 요즘 가족들이 잘 지내고 계신가요?
Transcription finished.
2-1 (0.0, 4.4) -> [민서] 요즘 가족들이 잘 지내고 계신가요?
2-2 (4.4, 7.36) -> [연우] 네, 가족들은 잘 지내고 있어요.
2-3 (7.36, 10.4) -> [민서] 저는 아이들과 스포츠를 즐기는 것 같아요.
Transcription finished.
Transcription finished.
4-1 (2.5, 6.5) -> [민서] 저는 아이들과 스포츠를 즐기는 걸 좋아해요.
4-2 (6.5, 9.5) -> [민서] 어떤 종류의 스포츠를 함께 하시나요?
4-3 (9.5, 11.9) -> [연우] 축구와 테니스를 함께 하고 있어요.
4-4 (11.9, 15.0) -> [민서] 아이들이 적극적으로 참여하면서 즐거운 시간을 보내요.
Transcription finished.
5-1 (4.48, 11.16) -> [연우] 축구와 테니스를 함께 하고 있어요 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요
5-2 (11.16, 17.04) -> [민서] 축구와 테니스는 정말 가족끼리 함께 하기 좋은 스포츠
Transcription finished.
6-1 (0.0, 6.0) -> [연우] 축구와 테니스를 함께하고 있어요. 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요.
6-2 (6.0, 15.0) -> [민서] 축구와 테니스는 정말 가족끼리 함께하기 좋은 스포츠죠. 활동적인 시간을 보내면서 가족 간의 유대감도 높일 수 있어요.
Transcription finished.
7-1 (0.0, 2.0) -> [연우] 축구와 테니스를 함께하고 있어요.
7-2 (2.0, 6.0) -> [연우] 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요.
7-3 (6.0, 11.0) -> [민서] 축구와 테니스는 정말 가족끼리 함

In [23]:
# Release the shared PyAudio instance; no stream can be opened on `recoder` afterwards.
recoder.terminate()
print("Recording finished.")

Recording finished.


In [24]:
# Print the collected conversation; 'role' holds the recognized speaker name.
for message in ChatHistory.messages:
    print(f"[{message['role']}] {message['content']}")

[민서] 요즘 가족들이 잘 지내고 계신가요?
[연우] 네, 가족들은 잘 지내고 있어요.
[민서] 저는 아이들과 스포츠를 즐기는 것 같아요.
[민서] 저는 아이들과 스포츠를 즐기는 걸 좋아해요.
[민서] 어떤 종류의 스포츠를 함께 하시나요?
[연우] 축구와 테니스를 함께 하고 있어요.
[민서] 아이들이 적극적으로 참여하면서 즐거운 시간을 보내요.
[연우] 축구와 테니스를 함께 하고 있어요 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요
[연우] 축구와 테니스를 함께하고 있어요. 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요.
[연우] 축구와 테니스를 함께하고 있어요.
[연우] 아이들이 적극적으로 참여하면서 즐거운 시간을 보내고 있어요.
[민서] 축구와 테니스는 정말 가족끼리 함께하기 좋은 스포츠죠.
[민서] 활동적인 시간을 보내면서 가족 간의 유대감도 높일 수 있을 거예요.
[민서] 다음에는 함께 스프를 즐기며
[민서] 축구와 테니스는 정말 가족끼리 함께하기 좋은 스포츠죠. 활동적인 시간을 보내면서 가족 간의 유대감도 높일 수 있을 거에요.
[민서] 다음에는 함께 스포츠를 즐기며 가족끼리 더 많은 시간을 보내는 건 어떨까요?
[민서] 네 좋아요


In [None]:
# Ask the model for a topic suggestion; the prompt is the system message,
# the collected ChatHistory and user_prompt, and the reply is streamed to stdout.
for chunk in llama3(system_prompt, user_prompt):
    print(token_stream(chunk), end="", flush=True)
print()