In [1]:
# Copyright 2023, YOUDAO
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import streamlit as st
import os, glob
import numpy as np
from yacs import config as CONFIG
import torch
import re

from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p
from config.joint.config import Config
from models.prompt_tts_modified.jets import JETSGenerator
from models.prompt_tts_modified.simbert import StyleEncoder
from transformers import AutoTokenizer

import base64
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_WAV_VALUE = 32768.0

config = Config()

def scan_checkpoint(cp_dir, prefix, c=8):
    pattern = os.path.join(cp_dir, prefix + '?'*c)
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
@st.cache_resource
def get_models():
    
    am_checkpoint_path = scan_checkpoint(f'{config.output_directory}/prompt_tts_open_source_joint/ckpt', 'g_')

    style_encoder_checkpoint_path = scan_checkpoint(f'{config.output_directory}/style_encoder/ckpt', 'checkpoint_', 6)#f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431' 

    with open(config.model_config_path, 'r') as fin:
        conf = CONFIG.load_cfg(fin)
    
    conf.n_vocab = config.n_symbols
    conf.n_speaker = config.speaker_n_labels

    style_encoder = StyleEncoder(config)
    model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu")
    model_ckpt = {}
    for key, value in model_CKPT['model'].items():
        new_key = key[7:]
        model_ckpt[new_key] = value
    style_encoder.load_state_dict(model_ckpt, strict=False)
    generator = JETSGenerator(conf).to(DEVICE)

    model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE)
    generator.load_state_dict(model_CKPT['generator'])
    generator.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.bert_path)

    with open(config.token_list_path, 'r') as f:
        token2id = {t.strip():idx for idx, t, in enumerate(f.readlines())}

    with open(config.speaker2id_path, encoding='utf-8') as f:
        speaker2id = {t.strip():idx for idx, t in enumerate(f.readlines())}


    return (style_encoder, generator, tokenizer, token2id, speaker2id)

def get_style_embedding(prompt, tokenizer, style_encoder):
    prompt = tokenizer([prompt], return_tensors="pt")
    input_ids = prompt["input_ids"]
    token_type_ids = prompt["token_type_ids"]
    attention_mask = prompt["attention_mask"]
    with torch.no_grad():
        output = style_encoder(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
    )
    style_embedding = output["pooled_output"].cpu().squeeze().numpy()
    return style_embedding

def tts(name, text, prompt, content, speaker, models):
    (style_encoder, generator, tokenizer, token2id, speaker2id)=models
    

    style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
    content_embedding = get_style_embedding(content, tokenizer, style_encoder)
    print(speaker2id)
    print(123)
    speaker = speaker2id[speaker]
    
    text_int = [token2id[ph] for ph in text.split()]
    
    sequence = torch.from_numpy(np.array(text_int)).to(DEVICE).long().unsqueeze(0)
    sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE)
    style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0)
    content_embedding = torch.from_numpy(content_embedding).to(DEVICE).unsqueeze(0)
    speaker = torch.from_numpy(np.array([speaker])).to(DEVICE)

    with torch.no_grad():

        infer_output = generator(
                inputs_ling=sequence,
                inputs_style_embedding=style_embedding,
                input_lengths=sequence_len,
                inputs_content_embedding=content_embedding,
                inputs_speaker=speaker,
                alpha=1.0
            )

    audio = infer_output["wav_predictions"].squeeze()* MAX_WAV_VALUE
    audio = audio.cpu().numpy().astype('int16')

    return audio


In [15]:
g2p = G2p()
#快乐、兴奋、悲伤、愤怒
models = get_models()
speakers = config.speakers
lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
text =  g2p_cn_en("哈哈哈哈哈哈哈，我就是宇智波一族", g2p, lexicon)
path = tts(0, text, "愤怒", "哈哈哈哈哈哈哈，我就是宇智波一族", "11614", models)

{'8051': 0, '11614': 1, '9017': 2, '6097': 3, '6671': 4, '6670': 5, '9136': 6, '11697': 7, '92': 8, '12787': 9, '1006': 10, '1012': 11, '1018': 12, '101': 13, '1025': 14, '1027': 15, '1028': 16, '102': 17, '1034': 18, '103': 19, '1040': 20, '1046': 21, '1049': 22, '104': 23, '1050': 24, '1051': 25, '1052': 26, '1058': 27, '1061': 28, '1065': 29, '1066': 30, '1069': 31, '1079': 32, '107': 33, '1081': 34, '1084': 35, '1085': 36, '1088': 37, '1092': 38, '1093': 39, '1094': 40, '1096': 41, '1097': 42, '1098': 43, '1107': 44, '110': 45, '1110': 46, '1112': 47, '1116': 48, '111': 49, '1121': 50, '1124': 51, '112': 52, '1132': 53, '1152': 54, '1154': 55, '1160': 56, '1161': 57, '1165': 58, '1166': 59, '1168': 60, '1171': 61, '1175': 62, '1179': 63, '1182': 64, '1183': 65, '1184': 66, '1187': 67, '118': 68, '1195': 69, '119': 70, '1200': 71, '1222': 72, '1224': 73, '1225': 74, '1226': 75, '122': 76, '1230': 77, '1235': 78, '1239': 79, '123': 80, '1246': 81, '1250': 82, '1252': 83, '1258': 84, 

In [16]:
import soundfile as sf

def save_audio(audio, filename):
    sf.write(filename, audio, config.sampling_rate, 'PCM_16')
audio_path = 'output.wav'
save_audio(path, audio_path)
st.audio(audio_path, sample_rate=config.sampling_rate)

DeltaGenerator()