In [1]:
import os
import sys
import torch

from IPython.display import Audio

# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and code edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Seed torch's global random number generator. The call returns the Generator
# object, which is what this cell displays as its output.
torch.manual_seed(1234)

<torch._C.Generator at 0x2b20032255d0>

## Set Global Paths

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Each path below can be overridden by an environment variable of the same
# name; when the variable is unset the original default is used, so the
# notebook behaves exactly as before on the machine it was written for.

# TTS Paths
STYLETTS2_CODE_ROOT = os.environ.get(
    'STYLETTS2_CODE_ROOT',
    '/engram/naplab/users/xj2289/repos/StyleTTS2',  # where StyleTTS2 repo was cloned to
)
STYLETTS2_CKPT_ROOT = os.environ.get(
    'STYLETTS2_CKPT_ROOT',
    '/engram/naplab/projects/StyleTalker/DailyTalkModel/epoch_2nd_00038.pth',
)
# NOTE(review): the original inline note ("# None") suggests None is also an
# accepted value for ESPEAK_PATH; confirm against StyleTalker before relying on it.
ESPEAK_PATH = os.environ.get(
    'ESPEAK_PATH',
    '/engram/naplab/shared/espeak-ng/lib/libespeak-ng.so.1',
)

# Audio LLM's Paths
QWENAUDIO_CKPT_ROOT = os.environ.get(
    'QWENAUDIO_CKPT_ROOT',
    '/engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820',
)
# Alternative checkpoint kept for reference:
# '/engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_ep100_cos/checkpoint-28000'

# Locate StyleTTS2's repository: make it importable, adding it to sys.path
# only once so that re-running this cell does not create duplicate entries.
if STYLETTS2_CODE_ROOT not in sys.path:
    sys.path.append(STYLETTS2_CODE_ROOT)

## Load Style-Talker

In [3]:
from inference.styletalker import StyleTalker

# Options forwarded to the audio LLM through `audiollm_kwargs`.
qwen_audio_options = {
    'bf16': True,
    'lora_r': 16,
    'lora_modules': ['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
}

# No ASR model is requested in this section; the alternative value is
# 'openai/whisper-large-v3'.
styletalker = StyleTalker(
    device=device,
    espeak_path=ESPEAK_PATH,
    asr_model=None,
    audiollm_kwargs=qwen_audio_options,
    tts_code_root=STYLETTS2_CODE_ROOT,
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
)

Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

audio_start_id: 155163, audio_end_id: 155164, audio_pad_id: 151851.
in_style_id: 151769, out_style_id: 151770
trainable params: 31,981,568 || all params: 8,419,877,888 || trainable%: 0.3798
LoRA Loaded from /engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820/lora.pt.
Initialized adapted & finetuned QwenAudio for dialog understanding.
177




Initialized finetuned StyleTTS2 for dialog generation.


## Inference with history texts and styles pre-computed

In [4]:
# n: conversation index, i: round index
n, i = 0, 2

In [6]:
def _read_text(path):
    """Return the whole contents of the text file at ``path``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the text has been read, instead of being left for the garbage collector.
    """
    with open(path, 'r') as fh:
        return fh.read()


# Inputs for one generation step: the latest speech is given as a wav path,
# and the two preceding rounds are given as text plus a saved style object.
sample_inputs = {
    'latest_speech': f'samples/dailytalk/{n}/r{i+2}.wav',
    'history_texts': [
        _read_text(f'samples/dailytalk/{n}/r{i}.txt'),
        _read_text(f'samples/dailytalk/{n}/r{i+1}.txt'),
    ],
    'history_styles': [
        # NOTE(review): torch.load unpickles the file's contents; only load
        # .pt files that come from a trusted source.
        torch.load(f'samples/dailytalk/{n}/r{i}.pt'),
        torch.load(f'samples/dailytalk/{n}/r{i+1}.pt'),
    ],
}

In [7]:
# Run Style-Talker on the prepared inputs and unpack the two fields of the
# returned mapping that the following cells use.
generated = styletalker(**sample_inputs)
wav, text = generated['audio'], generated['text']

print(text)

'About five hundred dollars.'


### History -3

In [8]:
# Render an audio player for the round-i recording of conversation n.
Audio(f'samples/dailytalk/{n}/r{i}.wav')

### History -2

In [9]:
# Render an audio player for the round-(i+1) recording of conversation n.
Audio(f'samples/dailytalk/{n}/r{i+1}.wav')

### History -1 (raw speech without transcription)

In [10]:
# Render an audio player for the round-(i+2) recording of conversation n,
# the same file passed to the model as 'latest_speech'.
Audio(f'samples/dailytalk/{n}/r{i+2}.wav')

### Generated follow-up

In [11]:
# Play the generated waveform held in `wav`.
# NOTE(review): rate=24000 is presumably the sample rate of the TTS output — confirm.
Audio(wav, rate=24000)

### Ground-truth follow-up

In [16]:
# Render an audio player for the round-(i+3) recording of conversation n,
# i.e. the round that follows the 'latest_speech' input.
Audio(f'samples/dailytalk/{n}/r{i+3}.wav')

## Inference with history speeches

In [12]:
import gc

from inference.styletalker import StyleTalker

# Release the instance built earlier in the notebook (if any) before creating
# a new one. Rebinding `styletalker` directly would keep the previous models
# alive until the new instance is fully constructed, so two copies of the
# audio LLM (about 8.4B parameters according to the load log) would be
# resident at the same time.
globals().pop('styletalker', None)
gc.collect()
if torch.cuda.is_available():
    # Hand cached, now-unused GPU blocks back to the driver so the new
    # instance can allocate them.
    torch.cuda.empty_cache()

# Same configuration as the first instance, except that an ASR model is
# requested so that history speeches can be passed as raw audio.
styletalker = StyleTalker(
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    tts_code_root=STYLETTS2_CODE_ROOT,
    audiollm_kwargs={
        'bf16': True,
        'lora_r': 16,
        'lora_modules': ['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
    },
    asr_model='openai/whisper-large-v3', # offline asr model
    espeak_path=ESPEAK_PATH,
    device=device
)

Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

audio_start_id: 155163, audio_end_id: 155164, audio_pad_id: 151851.
in_style_id: 151769, out_style_id: 151770
trainable params: 31,981,568 || all params: 8,419,877,888 || trainable%: 0.3798
LoRA Loaded from /engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820/lora.pt.
Initialized adapted & finetuned QwenAudio for dialog understanding.
177




Initialized finetuned StyleTTS2 for dialog generation.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initialized openai/whisper-large-v3 for offline speech recognition.


In [13]:
# n: conversation index, i: round index
n, i = 0, 2

In [14]:
# Inputs given purely as audio: the latest speech plus the two preceding
# rounds, all as wav paths.
sample_inputs = dict(
    latest_speech=f'samples/dailytalk/{n}/r{i+2}.wav',
    history_speeches=[
        f'samples/dailytalk/{n}/r{i}.wav',
        f'samples/dailytalk/{n}/r{i+1}.wav',
    ],
)

### Transcribe and compute styles of history speeches

In [15]:
# Transcribe every history recording, keeping the original order.
history_texts = list(
    map(styletalker.transcribe, sample_inputs['history_speeches'])
)


In [16]:
# Compute a style for every history recording, keeping the original order.
history_styles = list(
    map(styletalker.compute_style, sample_inputs['history_speeches'])
)


### Or, pass in history_speeches directly

In [20]:
# Run Style-Talker on the audio-only inputs and unpack the two fields of the
# returned mapping that the following cells use.
generated = styletalker(**sample_inputs)
wav, text = generated['audio'], generated['text']

print(text)

'Umm about two thousand dollars.'


### History -3

In [19]:
# Render an audio player for the round-i recording of conversation n,
# the first entry of 'history_speeches'.
Audio(f'samples/dailytalk/{n}/r{i}.wav')

### History -2

In [21]:
# Render an audio player for the round-(i+1) recording of conversation n,
# the second entry of 'history_speeches'.
Audio(f'samples/dailytalk/{n}/r{i+1}.wav')

### History -1 (raw speech without transcription)

In [22]:
# Render an audio player for the round-(i+2) recording of conversation n,
# the same file passed to the model as 'latest_speech'.
Audio(f'samples/dailytalk/{n}/r{i+2}.wav')

### Generated follow-up

In [23]:
# Play the generated waveform held in `wav`.
# NOTE(review): rate=24000 is presumably the sample rate of the TTS output — confirm.
Audio(wav, rate=24000)