In [1]:
import os
import sys
import torch

from IPython.display import Audio

%load_ext autoreload
%autoreload 2

torch.manual_seed(1234)

<torch._C.Generator at 0x108ce2950>

## Select Device and Set Global Paths

In [2]:
# Prefer CUDA when present, otherwise fall back to CPU.
# NOTE(review): MPS may be available on Apple silicon but is not selected here;
# swap in the commented line below to try it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'mps' if torch.backends.mps.is_available() else 'cpu'

print("Cuda available: ", torch.cuda.is_available())
print("MPS available: ", torch.backends.mps.is_available())
print("Using device: ", device)

# Code / checkpoint locations. Each one can be overridden by an environment variable
# of the same name, so the notebook runs on another machine without editing this cell.
# The defaults are the original author's local paths.

# TTS paths
STYLETTS2_CODE_ROOT = os.environ.get(  # where the StyleTTS2 repo was cloned to
    'STYLETTS2_CODE_ROOT',
    '/Users/pn/dev/avtar/other/StyleTTS2',
)
STYLETTS2_CKPT_ROOT = os.environ.get(
    'STYLETTS2_CKPT_ROOT',
    '/Users/pn/dev/avtar/other/Style-Talker/models/styletts2/epoch_2nd_00038.pth',
)
# espeak shared library. None is also accepted, and an empty ESPEAK_PATH env var maps to None.
ESPEAK_PATH = os.environ.get(
    'ESPEAK_PATH',
    '/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib',
) or None

# Audio LLM paths
QWENAUDIO_CKPT_ROOT = os.environ.get(
    'QWENAUDIO_CKPT_ROOT',
    '/Users/pn/dev/avtar/other/Style-Talker/models/qwenaudio/r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820',
)
# Alternative checkpoint: '/engram/naplab/projects/StyleTalker/QwenCkpts/DT_styletalker_ep100_cos/checkpoint-28000'

# Make the StyleTTS2 repository importable (added once, even if this cell is re-run).
if STYLETTS2_CODE_ROOT not in sys.path:
    sys.path.append(STYLETTS2_CODE_ROOT)

Cuda available:  False
MPS available:  True
Using device:  cpu


## Load Style-Talker

In [3]:
from inference.styletalker import StyleTalker

# No ASR model is loaded here; this instance is used below with
# pre-computed history texts and styles.
styletalker = StyleTalker(
    device=device,
    espeak_path=ESPEAK_PATH,
    tts_code_root=STYLETTS2_CODE_ROOT,
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    # LoRA settings for the fine-tuned QwenAudio. Presumably these must match the
    # adapter saved under QWENAUDIO_CKPT_ROOT (its directory name says r16).
    audiollm_kwargs=dict(
        bf16=True,
        lora_r=16,
        lora_modules=['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
    ),
    asr_model=None,  # alternatively 'openai/whisper-large-v3'
)

  from .autonotebook import tqdm as notebook_tqdm
Try importing flash-attention for faster inference...
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████| 9/9 [00:01<00:00,  5.72it/s]


audio_start_id: 155163, audio_end_id: 155164, audio_pad_id: 151851.
in_style_id: 151769, out_style_id: 151770
trainable params: 31,981,568 || all params: 8,419,877,888 || trainable%: 0.37983410716181637
LoRA Loaded from /Users/pn/dev/avtar/other/Style-Talker/models/qwenaudio/r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820/lora.pt.
Initialized adapted & finetuned QwenAudio for dialog understanding.
177


  peft_state_dict = torch.load(path, map_location=torch.device(device))
  torch.load(
  torch.load(
  params = torch.load(model_path, map_location='cpu')['model']
  params = torch.load(path, map_location='cpu')['net']
  checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
  WeightNorm.apply(module, name, dim)
  params_whole = torch.load(ckpt_root, map_location=device)


Initialized finetuned StyleTTS2 for dialog generation.


## Inference with pre-computed history texts and styles

In [4]:
# Sample selection: conversation `n`; rounds i and i+1 are history, round i+2 is the latest turn.
n, i = 0, 2

In [5]:
def _read_text(path):
    """Return the contents of the text file at ``path``, closing the handle afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


# History is rounds i and i+1 (transcript + pre-computed style); round i+2 is the latest speech.
history_rounds = (i, i + 1)
sample_inputs = {
    'latest_speech': f'samples/dailytalk/{n}/r{i+2}.wav',
    'history_texts': [
        _read_text(f'samples/dailytalk/{n}/r{r}.txt') for r in history_rounds
    ],
    # weights_only=True restricts unpickling to tensors/primitive containers. This is
    # safer and silences torch.load's FutureWarning.
    # NOTE(review): assumes the .pt files hold plain tensors; drop the flag if loading fails.
    'history_styles': [
        torch.load(f'samples/dailytalk/{n}/r{r}.pt', weights_only=True)
        for r in history_rounds
    ],
}

  torch.load(f'samples/dailytalk/{n}/r{i}.pt'),
  torch.load(f'samples/dailytalk/{n}/r{i+1}.pt'),


In [10]:
# Generate the follow-up turn (speech + text) from the prepared history.
generated = styletalker(**sample_inputs)
wav, text = generated['audio'], generated['text']

print(text)

'About three hundred dollars.'


### History -3

In [11]:
# Oldest history turn (round i).
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i))

### History -2

In [12]:
# Second history turn (round i+1).
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i + 1))

### History -1 (raw speech without transcription)

In [13]:
# Latest turn (round i+2); the model receives it as audio via sample_inputs['latest_speech'].
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i + 2))

### Generated follow-up

In [14]:
# `wav` is passed as sample data rather than a file, so the playback rate must be given explicitly.
Audio(data=wav, rate=24_000)

### Ground-truth follow-up

In [15]:
# Dataset's actual next turn (round i+3), for comparison with the generated one.
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i + 3))

## Inference with raw history speeches (Whisper ASR enabled)

In [16]:
import gc

from inference.styletalker import StyleTalker

# Release the previous StyleTalker (if any) before building a new one. Otherwise the
# old instance stays bound to `styletalker` while the new one loads, keeping two copies
# of the ~8.4B-parameter model in memory at once.
if 'styletalker' in globals():
    del styletalker
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

styletalker = StyleTalker(
    tts_ckpt_root=STYLETTS2_CKPT_ROOT,
    audiollm_ckpt_root=QWENAUDIO_CKPT_ROOT,
    tts_code_root=STYLETTS2_CODE_ROOT,
    audiollm_kwargs={
        'bf16': True,
        'lora_r': 16,
        'lora_modules': ['c_attn', 'attn.c_proj', 'w1', 'w2', 'query', 'key', 'value'],
    },
    asr_model='openai/whisper-large-v3', # offline asr model
    espeak_path=ESPEAK_PATH,
    device=device
)

Try importing flash-attention for faster inference...
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████| 9/9 [00:01<00:00,  7.31it/s]


audio_start_id: 155163, audio_end_id: 155164, audio_pad_id: 151851.
in_style_id: 151769, out_style_id: 151770
trainable params: 31,981,568 || all params: 8,419,877,888 || trainable%: 0.37983410716181637
LoRA Loaded from /Users/pn/dev/avtar/other/Style-Talker/models/qwenaudio/r16_lr1e-4_ga8_ls1_ep20/checkpoint-44820/lora.pt.
Initialized adapted & finetuned QwenAudio for dialog understanding.
177


  peft_state_dict = torch.load(path, map_location=torch.device(device))
  torch.load(
  torch.load(
  params = torch.load(model_path, map_location='cpu')['model']
  params = torch.load(path, map_location='cpu')['net']
  checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
  WeightNorm.apply(module, name, dim)
  params_whole = torch.load(ckpt_root, map_location=device)


Initialized finetuned StyleTTS2 for dialog generation.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initialized openai/whisper-large-v3 for offline speech recognition.


In [17]:
# Same sample as above: conversation `n`; rounds i and i+1 are history, round i+2 is the latest turn.
n, i = 0, 2

In [18]:
# Raw audio only: the latest turn plus the two preceding turns as history speeches.
sample_inputs = dict(
    latest_speech=f'samples/dailytalk/{n}/r{i+2}.wav',
    history_speeches=[f'samples/dailytalk/{n}/r{r}.wav' for r in (i, i + 1)],
)

### Option 1: transcribe the history speeches and compute their styles

In [19]:
# Transcribe each history utterance with the loaded ASR model.
history_texts = list(map(styletalker.transcribe, sample_inputs['history_speeches']))


In [20]:
# One computed style per history utterance.
# NOTE(review): history_texts / history_styles are not passed to styletalker anywhere in
# this notebook; the generation cell below passes history_speeches directly instead.
history_styles = list(map(styletalker.compute_style, sample_inputs['history_speeches']))


### Option 2: pass `history_speeches` directly

In [None]:
# Generate the follow-up turn directly from the raw history speeches.
generated = styletalker(**sample_inputs)
wav, text = generated['audio'], generated['text']

print(text)

### History -3

In [19]:
# Oldest history turn (round i).
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i))

### History -2

In [21]:
# Second history turn (round i+1).
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i + 1))

### History -1 (raw speech without transcription)

In [22]:
# Latest turn (round i+2), given to the model as raw audio.
Audio('samples/dailytalk/{}/r{}.wav'.format(n, i + 2))

### Generated follow-up

In [23]:
# `wav` is passed as sample data, so the playback rate must be given explicitly.
# NOTE(review): this plays whatever `wav` was last assigned; run the generation cell above first.
Audio(data=wav, rate=24_000)