In [None]:
import os
import sys
import torch

from IPython.display import Audio

%load_ext autoreload
%autoreload 2

# Seed PyTorch's default random number generator so that sampling-based
# steps later in the notebook are repeatable between runs.
torch.manual_seed(1234)

## Set Global Paths

In [None]:
def path_from_env(var_name, default):
    """Return the path stored in environment variable ``var_name``.

    Falls back to ``default`` when the variable is unset or empty, so the
    notebook behaves exactly as before unless an override is exported, e.g.
    ``export STYLETTS2_CODE_ROOT=/path/to/StyleTTS2`` before starting Jupyter.
    """
    return os.environ.get(var_name) or default


# Prefer CUDA when present, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device='mps' if torch.backends.mps.is_available() else 'cpu'

print("Cuda available: ", torch.cuda.is_available())
print("MPS available: ", torch.backends.mps.is_available())
print("Using device: ", device)

# TTS paths. The defaults are machine-specific; override them through the
# environment variables of the same name instead of editing this cell.
STYLETTS2_CODE_ROOT = path_from_env(
    'STYLETTS2_CODE_ROOT',
    '/Users/pn/dev/avtar/other/StyleTTS2',  # where the StyleTTS2 repo was cloned to
)
STYLETTS2_CKPT_ROOT = path_from_env(
    'STYLETTS2_CKPT_ROOT',
    '/Users/pn/dev/avtar/other/Style-Talker/models/styletts2/epoch_2nd_00038.pth',
)
# Assign None here instead if no explicit espeak library path should be passed.
ESPEAK_PATH = path_from_env(
    'ESPEAK_PATH',
    '/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib',
)

# Audio LLM paths
QWENAUDIO_CKPT_ROOT = path_from_env(
    'QWENAUDIO_CKPT_ROOT',
    '/Users/pn/dev/avtar/other/Style-Talker/models/qwenaudio/r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820',
)
# '/engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_ep100_cos/checkpoint-28000'

# Make the StyleTTS2 repository importable (added once, even if the cell is re-run).
if str(STYLETTS2_CODE_ROOT) not in sys.path:
    sys.path.append(str(STYLETTS2_CODE_ROOT))

## Load Style-Talker

In [None]:
from inference.styletalker import StyleTalker

# Build the pipeline without an ASR model: this section feeds history
# transcripts and styles that were computed ahead of time.
styletalker = StyleTalker(
    # StyleTTS2 (speech synthesis)
    tts_code_root=STYLETTS2_CODE_ROOT,
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
    espeak_path=ESPEAK_PATH,
    # Audio LLM checkpoint and its loading options
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    audiollm_kwargs=dict(
        bf16=True,
        lora_r=16,
        lora_modules=['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
    ),
    # Pass 'openai/whisper-large-v3' here to enable transcription.
    asr_model=None,
    device=device,
)

## Inference with history texts and styles pre-computed

In [None]:
# n: conversation index, i: round index (both select files under samples/dailytalk/)
n, i = 0, 3

In [None]:
def read_text_file(path):
    """Return the full contents of the text file at ``path``.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as the text has been read, instead of being left open.
    """
    with open(path, 'r') as text_file:
        return text_file.read()


# Inputs for one turn: the latest utterance as raw audio, plus the two
# preceding rounds as pre-computed transcripts (.txt) and style files (.pt).
sample_inputs = {
    'latest_speech': f'samples/dailytalk/{n}/r{i+2}.wav',
    'history_texts': [
        read_text_file(f'samples/dailytalk/{n}/r{i}.txt'),
        read_text_file(f'samples/dailytalk/{n}/r{i+1}.txt'),
    ],
    # NOTE(review): torch.load unpickles its input; only load .pt files from
    # a trusted source.
    'history_styles': [
        torch.load(f'samples/dailytalk/{n}/r{i}.pt'),
        torch.load(f'samples/dailytalk/{n}/r{i+1}.pt'),
    ],
}

In [None]:
# Alternative override text that was tried previously:
#   "You know, I appreciate you saying that. I mean, you know, like, I don't know if I'm gonna be
#   doing this for another 10 years or whatever, but I, I really enjoy doing it and I really enjoy
#   the feedback that I get from people. So, you know, if I'm doing it, I'm doing it. And if I'm not
#   doing it, I'm not doing it. But, you know, I'm gonna do it for as long as I can."
generated = styletalker(
    **sample_inputs,
    # Adjacent literals are joined into the single override string.
    override_text=(
        "Oh, my goodness. "
        "I mean, it's been a whole new world of, of, of things. "
        "Um, but I would say the most mischievous thing they've done, I think, "
        "is that they've learned how to get into the trash. "
        "So, um, they've been getting into the trash, um, at night. "
        "And they, they leave their little paw prints all over the trash "
        "and they pull things out and, um-"
    ),
)

# The result is indexed by 'audio' (played back below) and 'text'.
wav, text = generated['audio'], generated['text']

print(text)

### History -3

In [None]:
# Oldest history turn: round i of conversation n.
Audio(f'samples/dailytalk/{n}/r{i}.wav')

### History -2

In [None]:
# Second history turn: round i+1 of conversation n.
Audio(f'samples/dailytalk/{n}/r{i+1}.wav')

### History -1 (raw speech without transcription)

In [None]:
# Most recent turn (round i+2): the same file passed as 'latest_speech' in sample_inputs.
Audio(f'samples/dailytalk/{n}/r{i+2}.wav')

### Generated follow-up

In [None]:
# Play the generated audio taken from generated['audio'].
# NOTE(review): rate=24000 assumes the synthesizer outputs 24 kHz audio - confirm.
Audio(wav, rate=24000)

### Ground-truth follow-up

In [None]:
# Recorded round i+3, i.e. the turn that follows 'latest_speech', for comparison
# with the generated reply.
Audio(f'samples/dailytalk/{n}/r{i+3}.wav')

## Inference with history speeches

In [None]:
import gc

from inference.styletalker import StyleTalker

# Release any previously built pipeline before constructing the new one.
# A plain rebinding evaluates StyleTalker(...) while the old object is still
# referenced, so both instances would be held in memory at the same time.
# pop() with a default keeps this cell runnable on a fresh kernel as well.
globals().pop('styletalker', None)
gc.collect()
if torch.cuda.is_available():
    # Hand cached, now-unused GPU blocks back to the driver.
    torch.cuda.empty_cache()

# Same configuration as before, but with an ASR model so that history
# speeches can be transcribed instead of supplying transcripts by hand.
styletalker = StyleTalker(
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    tts_code_root=STYLETTS2_CODE_ROOT,
    audiollm_kwargs={
        'bf16': True,
        'lora_r': 16,
        'lora_modules': ['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
    },
    asr_model='openai/whisper-large-v3',  # offline asr model
    espeak_path=ESPEAK_PATH,
    device=device
)

In [None]:
# n: conversation index, i: round index (both select files under samples/dailytalk/)
n, i = 0, 2

In [None]:
# Raw-audio inputs only: the latest utterance (round i+2) plus the two
# rounds before it (i and i+1) as history.
sample_inputs = dict(
    latest_speech=f'samples/dailytalk/{n}/r{i+2}.wav',
    history_speeches=[
        f'samples/dailytalk/{n}/r{i+offset}.wav'
        for offset in range(2)
    ],
)

### Transcribe and compute styles of history speeches

In [None]:
# One transcribe() result per history recording, in the same order.
history_texts = list(
    map(styletalker.transcribe, sample_inputs['history_speeches'])
)


In [None]:
# One compute_style() result per history recording, in the same order.
history_styles = list(
    map(styletalker.compute_style, sample_inputs['history_speeches'])
)


### Or, pass in history_speeches directly

In [None]:
# Run the pipeline directly on the raw history speeches.
# NOTE(review): the meaning of the leading positional string depends on
# StyleTalker.__call__'s signature, which is not visible here - confirm it
# does not clash with the 'latest_speech' keyword in sample_inputs.
generated = styletalker("I really like Harry Potter", **sample_inputs)

# The result is indexed by 'audio' (played back below) and 'text'.
wav, text = generated['audio'], generated['text']

print(text)

### History -3

In [None]:
# Oldest history turn: round i of conversation n (first entry of 'history_speeches').
Audio(f'samples/dailytalk/{n}/r{i}.wav')

### History -2

In [None]:
# Second history turn: round i+1 of conversation n (second entry of 'history_speeches').
Audio(f'samples/dailytalk/{n}/r{i+1}.wav')

### History -1 (raw speech without transcription)

In [None]:
# Most recent turn (round i+2): the same file passed as 'latest_speech' in sample_inputs.
Audio(f'samples/dailytalk/{n}/r{i+2}.wav')

### Generated follow-up

In [None]:
# Play the generated audio taken from generated['audio'].
# NOTE(review): rate=24000 assumes the synthesizer outputs 24 kHz audio - confirm.
Audio(wav, rate=24000)

In [None]:
from scipy.io.wavfile import write

# Destination file and the sample rate stored in its header; 24000 matches
# the rate passed to Audio() when the clip is played back in the notebook.
filename = "generated-follow-up.wav"
sample_rate = 24000

# Persist the generated waveform as a WAV file in the working directory.
write(filename=filename, rate=sample_rate, data=wav)
print(f"Audio saved to {filename}")