In [None]:
import os
import numpy as np
import cv2
import torch
import audio
import subprocess
from tqdm import tqdm
import face_detection
from models import Wav2Lip

# Width, in mel-spectrogram columns, of each audio chunk that is sliced out of
# the spectrogram in prepare_wav2lip_data (one chunk per output frame).
mel_step_size = 16
def _load(checkpoint_path):
    """Load a checkpoint file with ``torch.load`` and return the loaded object.

    The module-level ``device`` decides how the file is read: for ``'cuda'``
    torch's default placement is used; for anything else a ``map_location``
    callable that hands every storage back unchanged is supplied, so no CUDA
    device is required.
    """
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    if device != 'cuda':
        def _keep_storage(storage, loc):
            return storage

        return torch.load(checkpoint_path, map_location=_keep_storage)
    return torch.load(checkpoint_path)

def load_model(path):
    """Create a Wav2Lip network, restore its weights from ``path`` and return it in eval mode.

    Every ``'module.'`` substring is removed from the checkpoint's
    ``state_dict`` keys before they are handed to ``load_state_dict``.
    The network is moved to the module-level ``device``.
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_weights = _load(path)["state_dict"]
    cleaned = {
        name.replace('module.', ''): tensor
        for name, tensor in saved_weights.items()
    }
    net.load_state_dict(cleaned)
    return net.to(device).eval()
def get_smoothened_boxes(boxes, T):
    """Smooth per-frame boxes in place with a forward-looking moving average.

    Box ``i`` is replaced by the mean of boxes ``i .. i+T-1``. Near the end of
    the sequence, where fewer than ``T`` boxes remain, the mean of the last
    ``T`` boxes is used instead (or of all boxes when there are fewer than
    ``T`` in total).

    Args:
        boxes: NumPy array with one box per row; it is modified in place, so
            values written for earlier rows take part in later tail windows.
        T: window length in frames.

    Returns:
        The same ``boxes`` array. With an integer array the means are
        truncated on assignment.
    """
    total = len(boxes)
    for i in range(total):
        if i + T > total:
            # Clamp the start at 0: a negative start would be interpreted by
            # NumPy as "count from the end" and silently shrink the window
            # whenever the sequence is shorter than T.
            window = boxes[max(0, total - T):]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images, pads, device, batch_size, nosmooth):
    """Detect one face per image and return the padded crops with their coordinates.

    Args:
        images: list of frames, indexed as ``image[y, x]``. They are stacked
            with ``np.array`` per batch, so they are assumed to share one
            size - TODO confirm.
        pads: ``(top, bottom, left, right)`` padding added around each
            detected box, clipped to the image bounds.
        device: device passed to ``face_detection.FaceAlignment``.
        batch_size: initial detection batch size; halved after every
            ``RuntimeError`` raised by the detector.
        nosmooth: when false, boxes are smoothed over time with
            ``get_smoothened_boxes`` (window of 5).

    Returns:
        A list with one ``[face_crop, (y1, y2, x1, x2)]`` entry per image.

    Raises:
        RuntimeError: detection still fails with a batch size of 1.
        ValueError: the detector found no face in one of the images.
    """
    detector = face_detection.FaceAlignment(
        face_detection.LandmarksType._2D, flip_input=False, device=device)
    while True:
        # Start every attempt with an empty list. Keeping detections from a
        # failed attempt would duplicate them after the retry and shift every
        # later box onto the wrong frame.
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size), desc="人脸检测"):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
            break
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('图像过大，需降低分辨率')
            batch_size //= 2
            print(f'调整人脸检测批次大小为 {batch_size}')

    results = []
    pady1, pady2, padx1, padx2 = pads
    for rect, image in zip(predictions, images):
        if rect is None:
            # Save the offending frame for inspection; make sure the folder
            # exists, otherwise cv2.imwrite fails without raising.
            os.makedirs('temp', exist_ok=True)
            cv2.imwrite('temp/faulty_frame.jpg', image)
            raise ValueError('未检测到人脸！请确保视频含有人脸')
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    del detector
    return results

def _pack_batch(img_batch, mel_batch, img_size):
    """Convert accumulated face crops and mel chunks into the arrays yielded by ``datagen``.

    Returns ``(img_input, mel_input)``: the faces with every row from
    ``img_size // 2`` onwards zeroed, concatenated with the untouched faces
    along the last axis and divided by 255; and the mel chunks with a trailing
    singleton axis added.
    """
    img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
    img_masked = img_batch.copy()
    img_masked[:, img_size // 2:] = 0
    img_input = np.concatenate((img_masked, img_batch), axis=3) / 255.
    mel_input = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
    return img_input, mel_input


def _split_mel(mel, fps, step):
    """Cut the spectrogram into one ``step``-column window per video frame.

    Window ``i`` starts at column ``int(i * 80 / fps)``. The first window that
    would run past the end is replaced by the last ``step`` columns and ends
    the list.
    """
    # NOTE(review): the factor 80 presumably is the number of spectrogram
    # columns per second produced by audio.melspectrogram - confirm.
    n_cols = len(mel[0])
    if n_cols < step:
        # A negative slice start below would return a window narrower than `step`.
        raise ValueError(f'音频过短：梅尔频谱只有 {n_cols} 列，至少需要 {step} 列')
    stride = 80. / fps
    chunks = []
    i = 0
    while True:
        start_idx = int(i * stride)
        if start_idx + step > n_cols:
            chunks.append(mel[:, n_cols - step:])
            break
        chunks.append(mel[:, start_idx: start_idx + step])
        i += 1
    return chunks


def datagen(frames, mels, box, static, face_det_batch_size, pads, nosmooth, img_size, device,
            wav2lip_batch_size=128):
    """Yield model-ready batches that pair each mel chunk with a face crop.

    Args:
        frames: list of video frames (a single frame for a still image).
        mels: list of mel chunks; one output item is produced per chunk.
        box: manual face box ``(y1, y2, x1, x2)``; ``box[0] == -1`` means
            "run face detection instead".
        static: when true, frame 0 is used for every chunk.
        face_det_batch_size: initial batch size for ``face_detect``.
        pads: padding forwarded to ``face_detect``.
        nosmooth: forwarded to ``face_detect``.
        img_size: side length the face crop is resized to.
        device: device forwarded to ``face_detect``.
        wav2lip_batch_size: number of items per yielded batch (default 128,
            the value that used to be hard-coded).

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)`` where the last
        two are lists of frame copies and ``(y1, y2, x1, x2)`` tuples.
    """
    if box[0] == -1:
        targets = [frames[0]] if static else frames
        face_det_results = face_detect(targets, pads, device, face_det_batch_size, nosmooth)
    else:
        print('使用手动指定人脸框')
        y1, y2, x1, x2 = box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for i, m in enumerate(mels):
        # Frames are cycled when there are more mel chunks than frames.
        idx = 0 if static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx]
        img_batch.append(cv2.resize(face, (img_size, img_size)))
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)
        if len(img_batch) >= wav2lip_batch_size:
            img_input, mel_input = _pack_batch(img_batch, mel_batch, img_size)
            yield img_input, mel_input, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    if img_batch:
        # Final, possibly smaller, batch.
        img_input, mel_input = _pack_batch(img_batch, mel_batch, img_size)
        yield img_input, mel_input, frame_batch, coords_batch


def prepare_wav2lip_data(audio_path, face_path,
                         pads=[0,10,0,0], box=[-1,-1,-1,-1], crop=[0,-1,0,-1],
                         static=False, fps=25, face_det_batch_size=16, wav2lip_batch_size=128,
                         resize_factor=1, rotate=False, nosmooth=False, img_size=96, device='cuda'):
    """Read the face source and the audio and build the batch generator for inference.

    Args:
        audio_path: audio file; anything not ending in ``.wav`` is first
            converted to ``temp/temp.wav`` with ffmpeg.
        face_path: image (jpg/png/jpeg) or video file containing the face.
        pads, box, nosmooth, face_det_batch_size, img_size, device: forwarded
            to ``datagen``. The list defaults are never mutated here.
        crop: ``(y1, y2, x1, x2)`` crop applied to video frames; ``-1`` for
            ``y2``/``x2`` means "up to the frame edge". Not applied to images.
        static: use only the first frame; forced to true for an image.
        fps: frame rate used for a still image; replaced by the video's own
            frame rate when ``face_path`` is a video.
        wav2lip_batch_size: batch size of the returned generator.
        resize_factor: video frames are shrunk by this integer factor when > 1.
        rotate: rotate video frames 90 degrees clockwise.

    Returns:
        ``(gen, full_frames, fps)``: the ``datagen`` generator, the frames
        (truncated to the number of mel chunks) and the frame rate used.

    Raises:
        ValueError: bad face path, unreadable image/video, invalid frame rate,
            NaN values in the spectrogram or audio that is too short.
        RuntimeError: ffmpeg exits with a non-zero status.
    """
    # 1. Read frames
    if not os.path.isfile(face_path):
        raise ValueError('人脸文件路径错误')
    if face_path.split('.')[-1].lower() in ['jpg', 'png', 'jpeg']:
        full_frames = [cv2.imread(face_path)]
        if full_frames[0] is None:
            raise ValueError(f"无法读取图片: {face_path}")
        static = True
    else:
        video_stream = cv2.VideoCapture(face_path)
        fps = video_stream.get(cv2.CAP_PROP_FPS)
        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1] // resize_factor, frame.shape[0] // resize_factor))
            if rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
            y1, y2, x1, x2 = crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]
            frame = frame[y1:y2, x1:x2]
            full_frames.append(frame)
        if not full_frames:
            # Without this the generator later fails with an obscure
            # IndexError / ZeroDivisionError.
            raise ValueError(f"无法读取视频帧: {face_path}")
    if not fps or fps <= 0:
        # The capture reports 0 when the frame rate is unknown; 80 / fps would divide by zero.
        raise ValueError(f"无效的帧率: {fps}")

    # 2. Read audio
    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of a shell string: no quoting problems with
        # unusual paths, and the exit status is checked so a stale
        # temp/temp.wav is never used after a failed conversion.
        command = ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav']
        if subprocess.call(command) != 0:
            raise RuntimeError(f'ffmpeg 提取音频失败: {audio_path}')
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    # 3. Split the spectrogram into per-frame chunks
    mel_chunks = _split_mel(mel, fps, mel_step_size)
    full_frames = full_frames[:len(mel_chunks)]

    # 4. Return the batch generator together with the frames and frame rate
    gen = datagen(full_frames.copy(), mel_chunks, box, static, face_det_batch_size, pads, nosmooth,
                  img_size, device, wav2lip_batch_size)
    return gen, full_frames, fps

# 用法示例
if __name__ == '__main__':
    # Example inputs (hard-coded local paths).
    audio_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.wav"
    face_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
    check_path=r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model=load_model(check_path)
    # Build the batch generator that feeds the model.
    gen, full_frames, fps = prepare_wav2lip_data(
        audio_path, face_path,
        pads=[0, 10, 0, 0],
        box=[-1, -1, -1, -1],
        crop=[0, -1, 0, -1],
        static=False,
        fps=25,
        face_det_batch_size=16,
        wav2lip_batch_size=128,
        resize_factor=1,
        rotate=False,
        nosmooth=False,
        img_size=96,
        device=device
    )
    # No `total` is given to tqdm: the number of batches follows the number of
    # mel chunks, which is not len(full_frames) (a still image has one frame).
    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, desc="处理进度")):
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
        pred = pred.astype(np.uint8)

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            # Paste the predicted face back into the frame, then show it.
            p = cv2.resize(p, (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p
            f = cv2.resize(f, (224, 336))
            cv2.imshow('image', f)
            cv2.waitKey(1)
    # Close the preview window once all frames were shown.
    cv2.destroyAllWindows()


In [None]:
# 导入必要的库
from os import listdir, path
import numpy as np  # 数值计算库
import scipy, cv2, os, sys, audio  # 科学计算、OpenCV、系统操作、音频处理
import json, subprocess, random, string  # JSON处理、子进程调用、随机数生成
from tqdm import tqdm  # 进度条库
from glob import glob  # 文件路径匹配
import torch, face_detection  # PyTorch框架、人脸检测库
from models import Wav2Lip  # 导入自定义的Wav2Lip模型
import platform  # 系统平台信息获取
# Show the input portrait for two seconds before anything else runs.
preview_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
imag = cv2.imread(preview_path)
if imag is None:
    # cv2.imread returns None for a missing or unreadable file; fail here
    # with a clear message instead of inside cv2.resize.
    raise ValueError(f"无法读取图片: {preview_path}")
imag = cv2.resize(imag, (224, 336))
cv2.namedWindow('image')
cv2.resizeWindow('image', 300, 500)
cv2.imshow('image', imag)
cv2.waitKey(2000)
# ------------------ 参数设置区（无需命令行，直接修改这里即可） ------------------
class Args:
    """Settings for the demo; edit the values here instead of passing command-line flags."""

    # --- required ---
    checkpoint_path: str = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints/wav2lip.pth"  # model weights
    face: str = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"    # face image or video
    audio: str = r"D:\coding\projects\Python\human\Wav2Lip\input\1.wav"   # driving audio (or video with audio)
    outfile: str = 'results/woowowo.mp4'                                   # output video path

    # --- input handling ---
    static: bool = False               # treat the input as a single still image
    fps: float = 25.0                  # frame rate used for a still image
    resize_factor: int = 1             # shrink video frames by this factor
    rotate: bool = False               # rotate video frames
    crop: list = [0, -1, 0, -1]        # crop region applied to video frames

    # --- face detection ---
    pads: list = [0, 10, 0, 0]         # padding around the detected face
    box: list = [-1, -1, -1, -1]       # manual face box; -1 means "detect"
    face_det_batch_size: int = 16      # face-detection batch size
    nosmooth: bool = False             # disable temporal smoothing of boxes

    # --- model ---
    wav2lip_batch_size: int = 128      # Wav2Lip inference batch size
    img_size: int = 96                 # side length of the model's face input

# The class itself serves as the settings object.
args = Args

# An existing jpg/png/jpeg file is handled as a single still image.
if args.face.rsplit('.', 1)[-1].lower() in ('jpg', 'png', 'jpeg') and os.path.isfile(args.face):
    args.static = True

# -------------------------- 人脸检测相关函数 --------------------------
def get_smoothened_boxes(boxes, T):
    """Smooth face boxes in place with a forward-looking moving average to reduce jitter.

    Box ``i`` becomes the mean of boxes ``i .. i+T-1``; where fewer than ``T``
    boxes remain, the mean of the last ``T`` boxes is used (or of all boxes
    when the sequence is shorter than ``T``).

    Args:
        boxes: NumPy array with one box per row, modified in place. Values
            already written for earlier rows take part in later tail windows.
        T: window length in frames.

    Returns:
        The same ``boxes`` array; with an integer array the means are
        truncated on assignment.
    """
    count = len(boxes)
    for i in range(count):
        if i + T > count:
            # A negative start index would count from the end and shrink the
            # window when count < T, so clamp it at 0.
            window = boxes[max(0, count - T):]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):
    """Detect one face per image and return the padded crops with their coordinates.

    Reads ``args.face_det_batch_size``, ``args.pads`` and ``args.nosmooth``
    and the module-level ``device``.

    Args:
        images: list of frames, indexed as ``image[y, x]``. They are stacked
            with ``np.array`` per batch, so they are assumed to share one
            size - TODO confirm.

    Returns:
        A list with one ``[face_crop, (y1, y2, x1, x2)]`` entry per image.

    Raises:
        RuntimeError: detection still fails with a batch size of 1.
        ValueError: the detector found no face in one of the images.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
    batch_size = args.face_det_batch_size

    while True:
        # Reset on every attempt: detections left over from a failed attempt
        # would be duplicated after the retry and no longer line up with
        # `images`.
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size), desc="人脸检测"):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
            break
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('图像过大，需降低分辨率')
            batch_size //= 2
            print(f'调整人脸检测批次大小为 {batch_size}')

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            # Keep the frame without a face for debugging; the folder must
            # exist or cv2.imwrite fails without raising.
            os.makedirs('temp', exist_ok=True)
            cv2.imwrite('temp/faulty_frame.jpg', image)
            raise ValueError('未检测到人脸！请确保视频含有人脸')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not args.nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)

    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    del detector
    return results

# -------------------------- 数据生成器函数 --------------------------
def datagen(frames, mels):
    """Yield batches of model inputs, one item per mel chunk.

    Face crops come from ``args.box`` when it is set (``args.box[0] != -1``),
    otherwise from ``face_detect``; in static mode only frame 0 is used.
    Each yielded tuple is ``(img_input, mel_input, frame_batch, coords_batch)``
    and holds at most ``args.wav2lip_batch_size`` items.
    """

    def pack(faces, mel_list):
        # Stack the lists, zero the rows from img_size // 2 onwards in a copy
        # of the faces, and join that copy with the originals on the last axis.
        faces, mel_list = np.asarray(faces), np.asarray(mel_list)
        masked = faces.copy()
        masked[:, args.img_size // 2:] = 0
        img_input = np.concatenate((masked, faces), axis=3) / 255.
        mel_input = np.reshape(mel_list, [len(mel_list), mel_list.shape[1], mel_list.shape[2], 1])
        return img_input, mel_input

    if args.box[0] != -1:
        print('使用手动指定人脸框')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
    elif args.static:
        face_det_results = face_detect([frames[0]])
    else:
        face_det_results = face_detect(frames)

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for i, m in enumerate(mels):
        # Frames are reused cyclically when there are more mel chunks than frames.
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx]

        img_batch.append(cv2.resize(face, (args.img_size, args.img_size)))
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) < args.wav2lip_batch_size:
            continue
        img_input, mel_input = pack(img_batch, mel_batch)
        yield img_input, mel_input, frame_batch, coords_batch
        img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        # Remaining items that did not fill a whole batch.
        img_input, mel_input = pack(img_batch, mel_batch)
        yield img_input, mel_input, frame_batch, coords_batch

# -------------------------- 模型加载相关 --------------------------
# Width, in spectrogram columns, of each mel chunk cut out in main().
mel_step_size = 16

# Use the GPU when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'使用 {device} 进行推理')

def _load(checkpoint_path):
    """Return the object stored in ``checkpoint_path``, read with ``torch.load``.

    With the module-level ``device`` set to ``'cuda'`` the default placement
    is used. In every other case an identity ``map_location`` callable is
    passed so that storages are returned as deserialized and no CUDA device
    is needed.
    """
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    use_default_placement = device == 'cuda'
    if use_default_placement:
        return torch.load(checkpoint_path)

    def _identity(storage, loc):
        return storage

    return torch.load(checkpoint_path, map_location=_identity)

def load_model(path):
    """Instantiate Wav2Lip, load the weights stored at ``path`` and return the model in eval mode.

    ``'module.'`` is stripped from every key of the checkpoint's
    ``state_dict`` before loading; the model is placed on the module-level
    ``device``.
    """
    network = Wav2Lip()
    print(f"从 {path} 加载模型")
    raw_state = _load(path)["state_dict"]
    state = {key.replace('module.', ''): value for key, value in raw_state.items()}
    network.load_state_dict(state)
    network = network.to(device)
    network.eval()
    return network

# -------------------------- 主函数 --------------------------
def main():
    """Run the demo end to end and display the lip-synced frames.

    Steps: load the model, read the face image or video named by
    ``args.face``, compute the mel spectrogram of ``args.audio``, split it
    into one chunk per frame, run the model batch by batch and show every
    resulting frame in the ``'image'`` window.

    Nothing is written to ``args.outfile``: the code that wrote
    ``temp/result.avi`` with ``cv2.VideoWriter`` and muxed it with the audio
    through ffmpeg had been disabled, so frames are only displayed.

    Side effects:
        ``args.audio`` is replaced by ``'temp/temp.wav'`` when the audio had
        to be converted.

    Raises:
        ValueError: ``args.face`` is not a file, the image cannot be read, or
            the spectrogram contains NaN values.
        RuntimeError: ffmpeg exits with a non-zero status.
    """
    if not os.path.isfile(args.face):
        raise ValueError('人脸文件路径错误')
    # Load the model
    model = load_model(args.checkpoint_path)

    if args.face.split('.')[-1].lower() in ['jpg', 'png', 'jpeg']:
        image = cv2.imread(args.face)
        if image is None:
            # cv2.imread returns None instead of raising on an unreadable file.
            raise ValueError(f"无法读取图片: {args.face}")
        full_frames = [image]
        fps = args.fps
    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)
        print('读取视频帧...')
        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if args.resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor))
            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]
            frame = frame[y1:y2, x1:x2]
            full_frames.append(frame)

    print(f"可用推理帧数: {len(full_frames)}")

    if not args.audio.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of an interpolated shell string: the path was
        # previously unquoted, so any space in it broke the command. The exit
        # status is checked so a stale temp/temp.wav is never picked up.
        command = ['ffmpeg', '-y', '-i', args.audio, '-strict', '-2', 'temp/temp.wav']
        if subprocess.call(command) != 0:
            raise RuntimeError(f'ffmpeg 提取音频失败: {args.audio}')
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    print(f"梅尔频谱形状: {mel.shape}")

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    # One window of mel_step_size columns per video frame; the last window is
    # aligned to the end of the spectrogram.
    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1
    print(f"梅尔块数量: {len(mel_chunks)}")

    full_frames = full_frames[:len(mel_chunks)]
    batch_size = args.wav2lip_batch_size
    gen = datagen(full_frames.copy(), mel_chunks)

    n_batches = int(np.ceil(len(mel_chunks) / batch_size))
    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=n_batches, desc="处理进度")):
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
        pred = pred.astype(np.uint8)

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            # Paste the predicted face back into its frame and show the result.
            p = cv2.resize(p, (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p
            f = cv2.resize(f, (224, 336))
            cv2.imshow('image', f)
            cv2.waitKey(1)

if __name__ == '__main__':
    import time

    # Time the whole run and report it with two decimals.
    start_time = time.time()
    main()
    elapsed = time.time() - start_time
    print(f"总耗时: {elapsed:.2f} 秒")

In [6]:
import os
import time

import numpy as np
import cv2
import torch
import audio
import subprocess
from tqdm import tqdm
import face_detection
from models import Wav2Lip

# Module-level mel chunk width. prepare_audio_batches in this cell takes its
# own `mel_step_size` parameter (default 16), which shadows this global there.
mel_step_size = 16

def _load(checkpoint_path):
    """Read ``checkpoint_path`` with ``torch.load`` and return the result.

    The module-level ``device`` selects the loading mode: ``'cuda'`` uses
    torch's default placement, anything else passes a ``map_location``
    callable that returns each storage unchanged, so no CUDA device is
    required.
    """
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    if device != 'cuda':
        def _keep_storage(storage, loc):
            return storage

        return torch.load(checkpoint_path, map_location=_keep_storage)
    return torch.load(checkpoint_path)

def load_model(path):
    """Create a Wav2Lip network, restore its weights from ``path`` and return it in eval mode.

    Every ``'module.'`` substring is removed from the checkpoint's
    ``state_dict`` keys before loading. The network is moved to the
    module-level ``device``.
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_weights = _load(path)["state_dict"]
    cleaned = {
        name.replace('module.', ''): tensor
        for name, tensor in saved_weights.items()
    }
    net.load_state_dict(cleaned)
    return net.to(device).eval()

def preprocess_image(
    image_path,
    pads=[0,10,0,0],
    box=[-1,-1,-1,-1],
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Load a still image and return its face region resized to ``img_size`` x ``img_size``.

    Args:
        image_path: path of the image to read.
        pads: ``(top, bottom, left, right)`` padding added to the detected
            box, clipped to the image bounds.
        box: manual face box ``(y1, y2, x1, x2)``; ``box[0] == -1`` triggers
            face detection instead.
        img_size: side length of the returned face crop.
        face_det_batch_size: not used by this function.
        nosmooth: not used by this function.
        device: device passed to ``face_detection.FaceAlignment``.

    Returns:
        ``(face, coords)`` where ``coords`` is always
        ``(0, img_size, 0, img_size)``, i.e. the whole resized crop, not the
        position of the face in the source image.

    Raises:
        ValueError: the image cannot be read.
        RuntimeError: no face was detected.
    """
    frame = cv2.imread(image_path)
    if frame is None:
        raise ValueError(f"无法读取图片: {image_path}")

    if box[0] != -1:
        # A manual box was supplied: use it as is.
        top, bottom, left, right = box
    else:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        rect = detector.get_detections_for_batch(np.array([frame]))[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pad_top, pad_bottom, pad_left, pad_right = pads
        height, width = frame.shape[0], frame.shape[1]
        top = max(0, rect[1] - pad_top)
        bottom = min(height, rect[3] + pad_bottom)
        left = max(0, rect[0] - pad_left)
        right = min(width, rect[2] + pad_right)

    crop = cv2.resize(frame[top:bottom, left:right], (img_size, img_size))
    return crop, (0, img_size, 0, img_size)

def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model-ready batches that pair one fixed face image with each mel chunk.

    This is a generator: nothing (including audio loading) happens until it
    is iterated.

    Args:
        audio_path: audio file; anything not ending in ``.wav`` is first
            converted to ``temp/temp.wav`` with ffmpeg.
        face_img: face crop used for every chunk; expected to already be
            ``img_size`` x ``img_size`` (it is not resized here).
        face_coords: coordinates attached unchanged to every item.
        static: not used by this function.
        fps: frame rate; chunk ``i`` starts at column ``int(i * 80 / fps)``.
        mel_step_size: number of spectrogram columns per chunk.
        wav2lip_batch_size: maximum number of items per yielded batch.
        img_size: rows from ``img_size // 2`` onwards are zeroed in the
            masked copy of the face.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)``.

    Raises:
        RuntimeError: ffmpeg exits with a non-zero status.
        ValueError: NaN values in the spectrogram, or fewer spectrogram
            columns than ``mel_step_size``.
    """

    def pack(faces, mel_list):
        # Stack the lists, zero the lower rows in a copy of the faces and join
        # that copy with the originals along the last axis.
        faces, mel_list = np.asarray(faces), np.asarray(mel_list)
        masked = faces.copy()
        masked[:, img_size // 2:] = 0
        img_input = np.concatenate((masked, faces), axis=3) / 255.
        mel_input = np.reshape(mel_list, [len(mel_list), mel_list.shape[1], mel_list.shape[2], 1])
        return img_input, mel_input

    # 1. Read the audio
    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of a shell string, and a checked exit status,
        # so a failed conversion cannot fall through to a stale temp/temp.wav.
        command = ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav']
        if subprocess.call(command) != 0:
            raise RuntimeError(f'ffmpeg 提取音频失败: {audio_path}')
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    # 2. Split the spectrogram into chunks
    n_cols = len(mel[0])
    if n_cols < mel_step_size:
        # The tail slice below would otherwise get a negative start and
        # return a chunk narrower than mel_step_size.
        raise ValueError(f'音频过短：梅尔频谱只有 {n_cols} 列，至少需要 {mel_step_size} 列')
    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > n_cols:
            mel_chunks.append(mel[:, n_cols - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    # 3. Build the batches
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for m in mel_chunks:
        img_batch.append(face_img.copy())
        mel_batch.append(m)
        frame_batch.append(face_img.copy())
        coords_batch.append(face_coords)
        if len(img_batch) >= wav2lip_batch_size:
            img_input, mel_input = pack(img_batch, mel_batch)
            yield img_input, mel_input, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    if img_batch:
        img_input, mel_input = pack(img_batch, mel_batch)
        yield img_input, mel_input, frame_batch, coords_batch

def wav2lip_infer(
    model,
    gen,
    device,
    batch_size=128,
    window_size=(224, 336),
    show_window=True,
    window_name="Wav2Lip Result"
):
    """Run the model over every batch from ``gen`` and optionally display each resulting frame.

    Args:
        model: callable invoked as ``model(mel_batch, img_batch)`` under
            ``torch.no_grad()``.
        gen: iterable of ``(img_batch, mel_batch, frames, coords)`` tuples.
        device: target passed to ``Tensor.to`` for both inputs.
        batch_size: not used; kept so existing calls keep working.
        window_size: size passed to ``cv2.resize`` before display.
        show_window: when true, frames are shown with ``cv2.imshow`` and
            pressing ``q`` stops early.
        window_name: title of the OpenCV window.

    Returns:
        None. The prediction is pasted into each frame in ``frames`` in place.
    """
    for i, (img_batch, mel_batch, frames, coords) in enumerate(
        tqdm(gen, desc="处理进度")
    ):
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
        pred = pred.astype(np.uint8)

        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            # Resize the prediction to the target region and paste it in.
            p = cv2.resize(p, (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p
            if show_window:
                # The display copy is only needed when something is shown.
                cv2.imshow(window_name, cv2.resize(f, window_size))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    cv2.destroyAllWindows()
                    return
    if show_window:
        cv2.destroyAllWindows()

In [4]:

if __name__ == '__main__':
    # Hard-coded local paths: model checkpoint, portrait image, driving audio.
    check_path = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    img_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
    audio_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.wav"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `a` is reused as a wall-clock stopwatch for each stage.
    a=time.time()
    model = load_model(check_path)
    print(time.time() - a)  # seconds spent in load_model
    a=time.time()



    face_img, face_coords = preprocess_image(img_path, device=device)
    print(time.time() - a)  # seconds spent in preprocess_image

从 D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth 加载模型
1.32600998878479
1.024045467376709


In [7]:
    # Continuation of the previous cell's body (hence the indentation); it
    # relies on audio_path, face_img, face_coords, model and device from there.
    a=time.time()
    # prepare_audio_batches contains `yield`, so this call only creates the
    # generator; the audio is loaded once wav2lip_infer starts iterating it.
    gen = prepare_audio_batches(audio_path, face_img, face_coords)
    print(time.time() - a)
    a=time.time()

    wav2lip_infer(model, gen, device)
    print(time.time() - a)  # seconds for audio preparation plus inference and display

处理进度: 0it [00:00, ?it/s]

0.0


处理进度: 12it [00:14,  1.20s/it]

14.467391967773438





In [13]:
import os
import time

import numpy as np
import cv2
import torch
import audio
import subprocess
from tqdm import tqdm
import face_detection
from models import Wav2Lip

# Module-level mel chunk width. prepare_audio_batches in this cell takes its
# own `mel_step_size` parameter (default 16), which shadows this global there.
mel_step_size = 16

def _load(checkpoint_path):
    """Return the object stored in ``checkpoint_path``, read with ``torch.load``.

    With the module-level ``device`` set to ``'cuda'`` the default placement
    is used. In every other case an identity ``map_location`` callable is
    passed so that storages are returned as deserialized and no CUDA device
    is needed.
    """
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    use_default_placement = device == 'cuda'
    if use_default_placement:
        return torch.load(checkpoint_path)

    def _identity(storage, loc):
        return storage

    return torch.load(checkpoint_path, map_location=_identity)

def load_model(path):
    """Instantiate Wav2Lip, load the weights stored at ``path`` and return the model in eval mode.

    ``'module.'`` is stripped from every key of the checkpoint's
    ``state_dict`` before loading; the model is placed on the module-level
    ``device``.
    """
    network = Wav2Lip()
    print(f"从 {path} 加载模型")
    raw_state = _load(path)["state_dict"]
    state = {key.replace('module.', ''): value for key, value in raw_state.items()}
    network.load_state_dict(state)
    network = network.to(device)
    network.eval()
    return network

def preprocess_image(
    image_path,
    pads=[0,10,0,0],
    box=[-1,-1,-1,-1],
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Load a still image and locate, crop and resize its face region.

    Args:
        image_path: path of the image to read.
        pads: ``(top, bottom, left, right)`` padding added to the detected
            box, clipped to the image bounds.
        box: manual face box ``(y1, y2, x1, x2)``; ``box[0] == -1`` triggers
            face detection instead.
        img_size: side length of the returned face crop.
        face_det_batch_size: not used by this function.
        nosmooth: not used by this function.
        device: device passed to ``face_detection.FaceAlignment``.

    Returns:
        ``(face, coords, image)``: the resized face crop, the face position
        ``(y1, y2, x1, x2)`` in the source image, and the source image itself.

    Raises:
        ValueError: the image cannot be read.
        RuntimeError: no face was detected.
    """
    source = cv2.imread(image_path)
    if source is None:
        raise ValueError(f"无法读取图片: {image_path}")

    if box[0] != -1:
        # A manual box was supplied: use it as is.
        top, bottom, left, right = box
    else:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        rect = detector.get_detections_for_batch(np.array([source]))[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pad_top, pad_bottom, pad_left, pad_right = pads
        height, width = source.shape[0], source.shape[1]
        top = max(0, rect[1] - pad_top)
        bottom = min(height, rect[3] + pad_bottom)
        left = max(0, rect[0] - pad_left)
        right = min(width, rect[2] + pad_right)

    crop = cv2.resize(source[top:bottom, left:right], (img_size, img_size))
    return crop, (top, bottom, left, right), source

def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model-ready batches that pair one fixed face image with each mel chunk.

    This is a generator: nothing (including audio loading) happens until it
    is iterated.

    Args:
        audio_path: audio file; anything not ending in ``.wav`` is first
            converted to ``temp/temp.wav`` with ffmpeg.
        face_img: face crop used for every chunk; expected to already be
            ``img_size`` x ``img_size`` (it is not resized here).
        face_coords: ``(y1, y2, x1, x2)`` of the face in the original image;
            attached unchanged to every item.
        static: not used by this function.
        fps: frame rate; chunk ``i`` starts at column ``int(i * 80 / fps)``.
        mel_step_size: number of spectrogram columns per chunk.
        wav2lip_batch_size: maximum number of items per yielded batch.
        img_size: rows from ``img_size // 2`` onwards are zeroed in the
            masked copy of the face.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)``.

    Raises:
        RuntimeError: ffmpeg exits with a non-zero status.
        ValueError: NaN values in the spectrogram, or fewer spectrogram
            columns than ``mel_step_size``.
    """

    def pack(faces, mel_list):
        # Stack the lists, zero the lower rows in a copy of the faces and join
        # that copy with the originals along the last axis.
        faces, mel_list = np.asarray(faces), np.asarray(mel_list)
        masked = faces.copy()
        masked[:, img_size // 2:] = 0
        img_input = np.concatenate((masked, faces), axis=3) / 255.
        mel_input = np.reshape(mel_list, [len(mel_list), mel_list.shape[1], mel_list.shape[2], 1])
        return img_input, mel_input

    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of a shell string, and a checked exit status,
        # so a failed conversion cannot fall through to a stale temp/temp.wav.
        command = ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav']
        if subprocess.call(command) != 0:
            raise RuntimeError(f'ffmpeg 提取音频失败: {audio_path}')
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    n_cols = len(mel[0])
    if n_cols < mel_step_size:
        # The tail slice below would otherwise get a negative start and
        # return a chunk narrower than mel_step_size.
        raise ValueError(f'音频过短：梅尔频谱只有 {n_cols} 列，至少需要 {mel_step_size} 列')
    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > n_cols:
            mel_chunks.append(mel[:, n_cols - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for m in mel_chunks:
        img_batch.append(face_img.copy())
        mel_batch.append(m)
        frame_batch.append(face_img.copy())
        coords_batch.append(face_coords)
        if len(img_batch) >= wav2lip_batch_size:
            img_input, mel_input = pack(img_batch, mel_batch)
            yield img_input, mel_input, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    if img_batch:
        img_input, mel_input = pack(img_batch, mel_batch)
        yield img_input, mel_input, frame_batch, coords_batch

def wav2lip_infer(
    model,
    gen,
    device,
    orig_image,
    coords,
    batch_size=128,
    window_size=(224, 336),
    show_window=True,
    window_name="Wav2Lip Result"
):
    """Run the model on every batch from ``gen`` and preview the composited frames.

    Each predicted patch is resized to the face rectangle, pasted into a copy
    of ``orig_image`` and, when ``show_window`` is true, shown in an OpenCV
    window. Pressing ``q`` in that window stops the preview early.

    Args:
        model: Callable invoked as ``model(mel_batch, img_batch)``.
        gen: Iterable yielding ``(img_batch, mel_batch, frames, coords_batch)``
            tuples; only the first two items are used here.
        device: Device the input tensors are moved to.
        orig_image: Image the predicted face patch is pasted into.
        coords: ``(y1, y2, x1, x2)`` face rectangle inside ``orig_image``.
        batch_size: Unused; kept for interface compatibility.
        window_size: Size passed to ``cv2.resize`` for the preview.
        show_window: Whether to display frames at all.
        window_name: Title of the OpenCV window.
    """
    # Local import so the function does not rely on `time` having been
    # imported at module level by some other cell.
    import time

    # The per-frame debug print is measured from the module-level timestamp
    # `b` when the caller has set one; otherwise fall back to "now" instead of
    # failing with a NameError.
    t_ref = globals().get('b')
    if t_ref is None:
        t_ref = time.time()

    # The face rectangle is the same for every frame, so unpack it once.
    y1, y2, x1, x2 = coords
    h, w = y2 - y1, x2 - x1

    try:
        for img_batch, mel_batch, _frames, _coords_batch in tqdm(gen, desc="处理进度"):
            # Move axis 3 to position 1 (channels-last -> channels-first,
            # assuming the generator yields channels-last arrays).
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            # Back to channels-last and scaled by 255 for OpenCV
            # (presumably the model output lies in [0, 1] - TODO confirm).
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            pred = pred.astype(np.uint8)

            for p in pred:
                if h <= 0 or w <= 0:
                    print("Invalid coords:", coords)
                    continue
                p_resized = cv2.resize(p, (w, h))
                show_img = orig_image.copy()
                # Paste the prediction straight back over the face region.
                show_img[y1:y2, x1:x2] = p_resized
                if not show_window:
                    continue
                cv2.imshow(window_name, cv2.resize(show_img, window_size))
                print(time.time() - t_ref, "----")

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    return
    finally:
        # Close the preview on normal completion, on 'q' and on errors alike.
        if show_window:
            cv2.destroyAllWindows()

In [9]:

if __name__ == '__main__':
    # Hard-coded local paths to the checkpoint, the still image and the driving audio.
    check_path = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    img_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
    audio_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.wav"
    # `device` must live at module level: _load() and load_model() read it as a global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # `a` is a scratch timestamp reused to time each stage below.
    a=time.time()

    model = load_model(check_path)
    # Seconds spent loading the model.
    print(time.time() - a)
    a=time.time()
    face_img, face_coords, orig_image = preprocess_image(img_path, device=device)
    # Seconds spent in preprocess_image.
    print(time.time() - a)

从 D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth 加载模型
1.032834768295288
0.8675639629364014


In [14]:
    a=time.time()
    # prepare_audio_batches contains `yield`, so this call only creates the
    # generator; no audio work runs until the first batch is requested.
    gen = prepare_audio_batches(audio_path, face_img, face_coords)
    print(time.time() - a)
    a=time.time()
    # `b` is the module-level reference timestamp that wav2lip_infer reads for
    # its per-frame elapsed-time prints.
    b=time.time()
    wav2lip_infer(model, gen, device, orig_image=orig_image, coords=face_coords)
    # Total seconds spent inside wav2lip_infer.
    print(time.time() - a)

处理进度: 0it [00:00, ?it/s]

0.0
0.7216842174530029 ----
0.7306890487670898 ----
0.7537441253662109 ----
0.7637538909912109 ----
0.7741804122924805 ----
0.7741804122924805 ----
0.7964823246002197 ----
0.8066043853759766 ----
0.8169867992401123 ----
0.8274562358856201 ----
0.8380670547485352 ----
0.8501057624816895 ----
0.8609046936035156 ----
0.8720381259918213 ----
0.8800251483917236 ----
0.8906211853027344 ----
0.900759220123291 ----
0.9115715026855469 ----
0.9208118915557861 ----
0.935631275177002 ----
0.9473636150360107 ----
0.9574370384216309 ----
0.9675302505493164 ----
0.973548173904419 ----
0.9840149879455566 ----
0.9940829277038574 ----
1.004086971282959 ----
1.0140957832336426 ----
1.024106502532959 ----
1.0341150760650635 ----
1.0481393337249756 ----
1.0541324615478516 ----
1.0642547607421875 ----
1.0742661952972412 ----
1.0862843990325928 ----
1.0947074890136719 ----
1.1049871444702148 ----
1.1151738166809082 ----
1.1251804828643799 ----
1.1394896507263184 ----
1.1494879722595215 ----
1.159334421157837

处理进度: 1it [00:02,  2.03s/it]

1.9653503894805908 ----
1.9753570556640625 ----
1.9893746376037598 ----
1.9953858852386475 ----
2.0054023265838623 ----
2.0154306888580322 ----
2.025454044342041 ----
2.2490127086639404 ----
2.2593436241149902 ----
2.2693519592285156 ----
2.281252861022949 ----
2.2955474853515625 ----
2.3015544414520264 ----
2.311558485031128 ----
2.3215692043304443 ----
2.3315906524658203 ----
2.3419535160064697 ----
2.351966619491577 ----
2.36403226852417 ----
2.372039794921875 ----
2.3822128772735596 ----
2.3942313194274902 ----
2.402228355407715 ----
2.412245273590088 ----
2.4222524166107178 ----
2.434581995010376 ----
2.4442808628082275 ----
2.456587076187134 ----
2.462681531906128 ----
2.4726991653442383 ----
2.48272442817688 ----
2.4964027404785156 ----
2.5034291744232178 ----
2.5136194229125977 ----
2.5237624645233154 ----
2.5350182056427 ----
2.545856237411499 ----
2.559525728225708 ----
2.5655245780944824 ----
2.5775630474090576 ----
2.586320638656616 ----
2.5963408946990967 ----
2.6063582897

处理进度: 2it [00:03,  1.88s/it]

3.495086193084717 ----
3.5011074542999268 ----
3.5111353397369385 ----
3.5221776962280273 ----
3.5321900844573975 ----
3.5422117710113525 ----
3.555056095123291 ----
3.78037428855896 ----
3.790832281112671 ----
3.8008694648742676 ----
3.8108811378479004 ----
3.820887565612793 ----
3.831238269805908 ----
3.841662883758545 ----
3.8533098697662354 ----
3.861325740814209 ----
3.8753249645233154 ----
3.8833694458007812 ----
3.8913722038269043 ----
3.9013798236846924 ----
3.914522647857666 ----
3.92339825630188 ----
3.9334099292755127 ----
3.9459848403930664 ----
3.956009864807129 ----
3.96802020072937 ----
3.976036787033081 ----
3.9860665798187256 ----
3.9961936473846436 ----
4.006199598312378 ----
4.016249179840088 ----
4.026258230209351 ----
4.039780378341675 ----
4.0459253787994385 ----
4.0512855052948 ----
4.065926551818848 ----
4.075937509536743 ----
4.086032390594482 ----
4.096258163452148 ----
4.106276512145996 ----
4.116290807723999 ----
4.126318693161011 ----
4.13633918762207 ----


处理进度: 3it [00:05,  1.78s/it]

5.026949644088745 ----
5.0369789600372314 ----
5.049004793167114 ----
5.0570151805877686 ----
5.067426443099976 ----
5.077442169189453 ----
5.0876305103302 ----
5.312392234802246 ----
5.322707414627075 ----
5.334729433059692 ----
5.342980861663818 ----
5.352993011474609 ----
5.363447904586792 ----
5.373119354248047 ----
5.383134841918945 ----
5.393147706985474 ----
5.403159856796265 ----
5.413175344467163 ----
5.423192024230957 ----
5.437305927276611 ----
5.443463563919067 ----
5.453469514846802 ----
5.463474988937378 ----
5.473923444747925 ----
5.483933925628662 ----
5.494425535202026 ----
5.5039637088775635 ----
5.516392946243286 ----
5.524017572402954 ----
5.534030437469482 ----
5.544190168380737 ----
5.558276414871216 ----
5.564276695251465 ----
5.574563503265381 ----
5.586867094039917 ----
5.594867706298828 ----
5.60497784614563 ----
5.617141008377075 ----
5.626981258392334 ----
5.6318583488464355 ----
5.649573087692261 ----
5.659590005874634 ----
5.6696202754974365 ----
5.6800539

处理进度: 4it [00:06,  1.71s/it]

6.554898500442505 ----
6.5610339641571045 ----
6.571041822433472 ----
6.581058025360107 ----
6.591082334518433 ----
6.601525068283081 ----
6.616572380065918 ----
6.621728420257568 ----
6.631732702255249 ----
6.861947774887085 ----
6.874914646148682 ----
6.885297060012817 ----
6.897595405578613 ----
6.917247295379639 ----
6.931677579879761 ----
6.947032690048218 ----
6.9566969871521 ----
6.966705799102783 ----
6.976722955703735 ----
6.986741542816162 ----
6.997115135192871 ----
7.009826898574829 ----
7.017364740371704 ----
7.027378559112549 ----
7.037399530410767 ----
7.047406911849976 ----
7.059431076049805 ----
7.067440032958984 ----
7.081916093826294 ----
7.087922811508179 ----
7.1023924350738525 ----
7.110394716262817 ----
7.120535612106323 ----
7.129572868347168 ----
7.141598463058472 ----
7.149637699127197 ----
7.16002893447876 ----
7.17004132270813 ----
7.182065963745117 ----
7.190079212188721 ----
7.200093507766724 ----
7.2101216316223145 ----
7.22012996673584 ----
7.23017811775

处理进度: 5it [00:08,  1.67s/it]

8.102350950241089 ----
8.105459690093994 ----
8.115497350692749 ----
8.125515699386597 ----
8.13596796989441 ----
8.14614462852478 ----
8.156489372253418 ----
8.176856279373169 ----
8.18691110610962 ----
8.196928262710571 ----
8.208954095840454 ----
8.218875408172607 ----
8.441121101379395 ----
8.453948259353638 ----
8.470726251602173 ----
8.489731311798096 ----
8.50364875793457 ----
8.50974726676941 ----
8.520033836364746 ----
8.53004765510559 ----
8.54006314277649 ----
8.550081968307495 ----
8.560084819793701 ----
8.570107221603394 ----
8.580123901367188 ----
8.590158939361572 ----
8.600176811218262 ----
8.610187768936157 ----
8.622318744659424 ----
8.632338285446167 ----
8.64053750038147 ----
8.651031970977783 ----
8.661165237426758 ----
8.67118525505066 ----
8.681193828582764 ----
8.691622972488403 ----
8.70138692855835 ----
8.711400032043457 ----
8.721420764923096 ----
8.731435537338257 ----
8.741477489471436 ----
8.751511096954346 ----
8.763551712036133 ----
8.781689643859863 ---

处理进度: 6it [00:09,  1.64s/it]

9.679163694381714 ----
9.68541932106018 ----
9.695437908172607 ----
9.705455780029297 ----
9.715871572494507 ----
9.726048231124878 ----
9.738256216049194 ----
9.746248960494995 ----
9.758266925811768 ----
9.77629017829895 ----
9.787088394165039 ----
9.797143459320068 ----
10.02816128730774 ----
10.038169860839844 ----
10.048282146453857 ----
10.060306549072266 ----
10.068320989608765 ----
10.080460548400879 ----
10.08852744102478 ----
10.098538875579834 ----
10.108551979064941 ----
10.118557929992676 ----
10.139065742492676 ----
10.153203964233398 ----
10.159204721450806 ----
10.16920804977417 ----
10.179218053817749 ----
10.193256616592407 ----
10.19938850402832 ----
10.20942234992981 ----
10.219735860824585 ----
10.23311185836792 ----
10.240138292312622 ----
10.250144004821777 ----
10.260154962539673 ----
10.270159244537354 ----
10.280167818069458 ----
10.29018521308899 ----
10.300202369689941 ----
10.310217142105103 ----
10.320364952087402 ----
10.330370664596558 ----
10.3403797149

处理进度: 7it [00:11,  1.63s/it]

11.269317865371704 ----
11.279030323028564 ----
11.28948450088501 ----
11.299506664276123 ----
11.309850692749023 ----
11.319967269897461 ----
11.329972267150879 ----
11.342030763626099 ----
11.350027799606323 ----
11.36040472984314 ----
11.37051248550415 ----
11.38121247291565 ----
11.60758638381958 ----
11.619598388671875 ----
11.627609491348267 ----
11.637622594833374 ----
11.64764142036438 ----
11.657845735549927 ----
11.667861938476562 ----
11.678088665008545 ----
11.690487146377563 ----
11.700738430023193 ----
11.70841360092163 ----
11.718823909759521 ----
11.72885799407959 ----
11.73888111114502 ----
11.748898983001709 ----
11.758914470672607 ----
11.769226312637329 ----
11.779554605484009 ----
11.789997100830078 ----
11.800437927246094 ----
11.810450553894043 ----
11.820550441741943 ----
11.832584381103516 ----
11.840084791183472 ----
11.850100994110107 ----
11.86056900024414 ----
11.87261176109314 ----
11.884415626525879 ----
11.890382289886475 ----
11.902419567108154 ----
11.

处理进度: 8it [00:12,  1.61s/it]

12.84788179397583 ----
12.858232736587524 ----
12.868499517440796 ----
12.881017684936523 ----
12.889501571655273 ----
12.911750555038452 ----
12.919515132904053 ----
12.929736852645874 ----
12.939748764038086 ----
13.189582586288452 ----
13.199586868286133 ----
13.209752798080444 ----
13.219771146774292 ----
13.229872226715088 ----
13.239903211593628 ----
13.260460615158081 ----
13.27281904220581 ----
13.280831813812256 ----
13.29121470451355 ----
13.300922155380249 ----
13.314786672592163 ----
13.32297682762146 ----
13.331016302108765 ----
13.341142654418945 ----
13.353586196899414 ----
13.361613988876343 ----
13.371623754501343 ----
13.384751081466675 ----
13.401554822921753 ----
13.417024850845337 ----
13.42312741279602 ----
13.433692693710327 ----
13.45397400856018 ----
13.464471817016602 ----
13.474496364593506 ----
13.484906911849976 ----
13.494930028915405 ----
13.507215738296509 ----
13.53094744682312 ----
13.545873403549194 ----
13.56189775466919 ----
13.571909427642822 ----


处理进度: 9it [00:14,  1.62s/it]

14.4447340965271 ----
14.455212354660034 ----
14.46723198890686 ----
14.475545883178711 ----
14.485567808151245 ----
14.497865915298462 ----
14.509652137756348 ----
14.521694898605347 ----
14.533662796020508 ----
14.542586326599121 ----
14.5584397315979 ----
14.570865154266357 ----
14.587706804275513 ----
14.601625680923462 ----
14.824790239334106 ----
14.832798480987549 ----
14.842810869216919 ----
14.853212118148804 ----
14.863216400146484 ----
14.87339186668396 ----
14.883416414260864 ----
14.89366626739502 ----
14.903682947158813 ----
14.91580605506897 ----
14.924017429351807 ----
14.93448781967163 ----
14.944652557373047 ----
14.954672574996948 ----
14.96869158744812 ----
14.981126070022583 ----
14.996874570846558 ----
15.005215644836426 ----
15.019260168075562 ----
15.028979301452637 ----
15.03564453125 ----
15.05754566192627 ----
15.071215391159058 ----
15.081130266189575 ----
15.091156244277954 ----
15.101181507110596 ----
15.107194423675537 ----
15.132265329360962 ----
15.1491

处理进度: 10it [00:16,  1.63s/it]

16.09319567680359 ----
16.103493452072144 ----
16.116807460784912 ----
16.127755403518677 ----
16.13545846939087 ----
16.15235733985901 ----
16.16732120513916 ----
16.176859378814697 ----
16.188177824020386 ----
16.19336700439453 ----
16.203380346298218 ----
16.21736741065979 ----
16.22365713119507 ----
16.235753059387207 ----
16.24577260017395 ----
16.50809597969055 ----
16.520601511001587 ----
16.528500080108643 ----
16.53873348236084 ----
16.54875636100769 ----
16.56079912185669 ----
16.569385528564453 ----
16.5794038772583 ----
16.58944344520569 ----
16.60147190093994 ----
16.6200008392334 ----
16.63289475440979 ----
16.650169134140015 ----
16.67665433883667 ----
16.69641375541687 ----
16.71222686767578 ----
16.716642379760742 ----
16.72698187828064 ----
16.737086534500122 ----
16.74739933013916 ----
16.75749921798706 ----
16.767587184906006 ----
16.777795553207397 ----
16.78752064704895 ----
16.799578428268433 ----
16.811472415924072 ----
16.817832708358765 ----
16.829880237579346

处理进度: 12it [00:18,  1.23s/it]

18.009217023849487 ----
18.02262282371521 ----
18.028911352157593 ----
18.038928508758545 ----
18.04893398284912 ----
18.05896830558777 ----
18.071008443832397 ----
18.08152198791504 ----
18.093534469604492 ----
18.101686000823975 ----
18.11176037788391 ----
18.121366024017334 ----
18.131374835968018 ----
18.14139747619629 ----
18.153422355651855 ----
18.16356086730957 ----
18.17356824874878 ----
18.183578729629517 ----


处理进度: 12it [00:18,  1.52s/it]

18.200161933898926





In [1]:
# Scratch statement: prints the literal 2 (presumably a quick kernel sanity check).
print(2)

2


In [1]:
import os
import time
import threading

import numpy as np
import cv2
import torch
import audio
import subprocess
from tqdm import tqdm
import sounddevice as sd
import soundfile as sf
import face_detection
from models import Wav2Lip

# Width, in mel-spectrogram frames, of one audio window. prepare_audio_batches
# declares a parameter with the same name and default, which shadows this global there.
mel_step_size = 16

def _load(checkpoint_path):
    """Read a checkpoint file with ``torch.load``.

    The module-level ``device`` decides how storages are mapped: for
    ``'cuda'`` no ``map_location`` is passed; for anything else a callback
    returns each storage unchanged, which per the ``torch.load`` docs leaves
    the tensors on the CPU.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.
    """
    load_kwargs = {}
    if device != 'cuda':
        load_kwargs['map_location'] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)

def load_model(path):
    """Build a ``Wav2Lip`` network, load its weights and return it in eval mode.

    Every key of the checkpoint's ``"state_dict"`` has the substring
    ``'module.'`` removed before loading (presumably the prefix left by a
    parallel wrapper at training time - confirm). The network is moved to the
    module-level ``device``.

    Args:
        path: Checkpoint file path, forwarded to ``_load``.

    Returns:
        The model after ``.to(device)`` and ``.eval()``.
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_weights = _load(path)["state_dict"]
    cleaned = {
        key.replace('module.', ''): tensor
        for key, tensor in saved_weights.items()
    }
    net.load_state_dict(cleaned)
    return net.to(device).eval()

def preprocess_image(
    image_path,
    pads=(0, 10, 0, 0),
    box=(-1, -1, -1, -1),
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Read an image, locate the face rectangle and return the resized crop.

    Args:
        image_path: Path handed to ``cv2.imread``.
        pads: ``(top, bottom, left, right)`` padding added around the detected
            rectangle; the padded rectangle is clamped to the image bounds.
        box: Explicit ``(y1, y2, x1, x2)`` rectangle. Face detection runs only
            when ``box[0] == -1``.
        img_size: Side length of the square crop that is returned.
        face_det_batch_size: Unused; kept for interface compatibility.
        nosmooth: Unused; kept for interface compatibility.
        device: Device passed to the face detector.

    Returns:
        ``(face, coords, image)``: the resized crop, the ``(y1, y2, x1, x2)``
        rectangle it was cut from, and the full image as read from disk.

    Raises:
        ValueError: If the image cannot be read or the rectangle is empty.
        RuntimeError: If the detector finds no face.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"无法读取图片: {image_path}")
    if box[0] == -1:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        predictions = detector.get_detections_for_batch(np.array([image]))
        rect = predictions[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pady1, pady2, padx1, padx2 = pads
        # rect is indexed as (x1, y1, x2, y2); clamp the padded box to the image.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
    else:
        y1, y2, x1, x2 = box
    face = image[y1:y2, x1:x2]
    # An empty crop would otherwise surface as an opaque cv2.resize failure.
    if face.size == 0:
        raise ValueError(f"人脸区域为空: {(y1, y2, x1, x2)}")
    face = cv2.resize(face, (img_size, img_size))
    coords = (y1, y2, x1, x2)
    return face, coords, image

def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model-ready batches pairing one still face with windows of audio.

    This is a generator: none of the work below (including the ffmpeg
    conversion) happens until the first batch is requested.

    Args:
        audio_path: Audio file. Anything not ending in ``.wav`` is first
            converted to ``temp/temp.wav`` with ffmpeg.
        face_img: Face crop repeated for every audio window (assumed to be an
            ``(img_size, img_size, 3)`` array - TODO confirm against caller).
        face_coords: Passed through unchanged, once per frame.
        static: Unused; kept for interface compatibility.
        fps: Video frame rate; one mel window is taken every ``80 / fps`` mel
            frames (80 is presumably the mel frame rate per second - confirm).
        mel_step_size: Width of each mel window, in mel frames.
        wav2lip_batch_size: Maximum number of frames per yielded batch.
        img_size: Face crop height; rows from ``img_size // 2`` down are zeroed
            in the masked copy.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)`` where
        ``img_input`` is the masked and unmasked faces concatenated on axis 3
        and divided by 255, ``mel_input`` is the mel windows with a trailing
        axis of size 1, ``frame_batch`` holds independent copies of
        ``face_img`` and ``coords_batch`` repeats ``face_coords``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
        ValueError: If the mel spectrogram contains NaN or is shorter than one
            window.
    """
    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list without a shell: paths containing quotes, spaces or
        # shell metacharacters cannot break or inject into the command.
        command = ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav']
        return_code = subprocess.call(command)
        if return_code != 0:
            # Without this check a stale temp/temp.wav from an earlier run
            # would be used silently.
            raise RuntimeError(f'ffmpeg 转换音频失败（返回码 {return_code}）: {audio_path}')
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    n_mel_frames = len(mel[0])
    if n_mel_frames < mel_step_size:
        # A negative slice start below would produce a window of the wrong width.
        raise ValueError('音频过短，无法提取完整的梅尔片段')

    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > n_mel_frames:
            # The last window is right-aligned with the end of the spectrogram.
            mel_chunks.append(mel[:, n_mel_frames - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    # Sizes below 1 behaved like 1 in the original accumulate-and-flush loop.
    step = max(1, wav2lip_batch_size)
    for begin in range(0, len(mel_chunks), step):
        mel_batch = np.asarray(mel_chunks[begin:begin + step])
        count = len(mel_batch)
        img_batch = np.stack([face_img] * count)
        img_masked = img_batch.copy()
        # Blank the lower half of every face; the model regenerates it.
        img_masked[:, img_size // 2:] = 0
        img_input = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_input = np.reshape(mel_batch, [count, mel_batch.shape[1], mel_batch.shape[2], 1])
        frame_batch = [face_img.copy() for _ in range(count)]
        coords_batch = [face_coords] * count
        yield img_input, mel_input, frame_batch, coords_batch

def play_audio_stream(audio_path, start_event, stop_event):
    """Play an audio file through a ``sounddevice`` output stream.

    The stream callback waits on ``start_event`` before producing samples and
    raises ``sd.CallbackAbort`` once ``stop_event`` is set. When fewer samples
    remain than requested, the rest of the buffer is zero-filled,
    ``stop_event`` is set and ``sd.CallbackStop`` is raised. The calling
    thread polls every 100 ms until playback is stopped or every sample has
    been handed out.

    Args:
        audio_path: File read with ``soundfile.read``.
        start_event: ``threading.Event`` that gates the start of playback.
        stop_event: ``threading.Event`` used to stop playback; also set here
            when the data runs out.
    """
    samples, rate = sf.read(audio_path)
    if samples.ndim == 1:
        # Promote mono data to a single-column 2-D array.
        samples = samples[:, None]
    total = len(samples)
    cursor = 0  # index of the next sample row to hand to the stream

    def callback(outdata, frames, time_, status):
        nonlocal cursor
        start_event.wait()
        if stop_event.is_set():
            raise sd.CallbackAbort
        piece = samples[cursor:cursor + frames]
        available = len(piece)
        if available < frames:
            outdata[:available] = piece
            outdata[available:] = 0
            stop_event.set()
            raise sd.CallbackStop
        outdata[:] = piece
        cursor += frames

    with sd.OutputStream(channels=samples.shape[1], samplerate=rate, callback=callback):
        start_event.wait()
        while not (stop_event.is_set() or cursor >= total):
            sd.sleep(100)

def wav2lip_sync_play(
    model,
    gen,
    device,
    orig_image,
    coords,
    audio_path,
    fps=25,
    window_size=(224, 336),
    show_window=True,
    window_name="Wav2Lip Result"
):
    """Play the audio while showing the lip-synced frames at ``fps``.

    Audio playback and the frame clock both start once the first batch has
    been predicted, so preparing that batch does not make the opening frames
    late. The audio thread is always stopped and joined, and the window
    closed, even when inference raises.

    Args:
        model: Callable invoked as ``model(mel_batch, img_batch)``.
        gen: Iterable yielding ``(img_batch, mel_batch, frames, coords_batch)``
            tuples; only the first two items are used here.
        device: Device the input tensors are moved to.
        orig_image: Image the predicted face patch is pasted into.
        coords: ``(y1, y2, x1, x2)`` face rectangle inside ``orig_image``.
        audio_path: Audio file handed to ``play_audio_stream``.
        fps: Target display rate in frames per second.
        window_size: Size passed to ``cv2.resize`` for the preview.
        show_window: When false, frames are paced but not displayed.
        window_name: Title of the OpenCV window.
    """
    start_event = threading.Event()
    stop_event = threading.Event()
    audio_thread = threading.Thread(target=play_audio_stream, args=(audio_path, start_event, stop_event))

    frame_interval = 1.0 / fps
    frame_count = 0
    start_time = None

    # The face rectangle is the same for every frame, so unpack it once.
    y1, y2, x1, x2 = coords
    h, w = y2 - y1, x2 - x1

    try:
        for img_batch, mel_batch, _frames, _coords_batch in gen:
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            pred = pred.astype(np.uint8)

            if start_time is None:
                # The first frames are ready: start audio and the clock together.
                audio_thread.start()
                start_time = time.time()
                start_event.set()

            for p in pred:
                if h <= 0 or w <= 0:
                    continue
                p_resized = cv2.resize(p, (w, h))
                show_img = orig_image.copy()
                show_img[y1:y2, x1:x2] = p_resized
                # Align video with audio: wait until this frame's slot. Late
                # frames are shown immediately so the video catches up.
                target_time = start_time + frame_count * frame_interval
                now = time.time()
                if now < target_time:
                    time.sleep(target_time - now)
                if show_window:
                    cv2.imshow(window_name, cv2.resize(show_img, window_size))
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        return
                frame_count += 1
    finally:
        stop_event.set()
        # Release anything still blocked on the start gate before joining.
        start_event.set()
        if audio_thread.is_alive():
            audio_thread.join()
        cv2.destroyAllWindows()

In [2]:

if __name__ == '__main__':
    # Hard-coded local paths to the checkpoint, the still image and the driving audio.
    check_path = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    img_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
    audio_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.wav"
    # `device` must live at module level: _load() and load_model() read it as a global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = load_model(check_path)
    # Crop the face once; the same crop is reused for every generated frame.
    face_img, face_coords, orig_image = preprocess_image(img_path, device=device)

从 D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth 加载模型


In [3]:
    a=time.time()
    # prepare_audio_batches contains `yield`, so this call only creates the
    # generator; no audio work runs until the first batch is requested.
    gen = prepare_audio_batches(audio_path, face_img, face_coords)
    print(time.time() - a)
    # Reset the scratch timestamp just before playback starts.
    a = time.time()
    wav2lip_sync_play(model, gen, device, orig_image=orig_image, coords=face_coords, audio_path=audio_path)

0.0


KeyboardInterrupt: 

In [1]:
import os
import time
import threading

import numpy as np
import cv2
import torch
import Wav2Lip.audio as audio
import subprocess
from tqdm import tqdm
import sounddevice as sd
import soundfile as sf
import Wav2Lip.face_detection as face_detection
from Wav2Lip.models import Wav2Lip

# Width, in mel-spectrogram frames, of one audio window. prepare_audio_batches
# declares a parameter with the same name and default, which shadows this global there.
mel_step_size = 16

def _load(checkpoint_path):
    """Read a checkpoint file with ``torch.load``.

    The module-level ``device`` decides how storages are mapped: for
    ``'cuda'`` no ``map_location`` is passed; for anything else a callback
    returns each storage unchanged, which per the ``torch.load`` docs leaves
    the tensors on the CPU.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.
    """
    load_kwargs = {}
    if device != 'cuda':
        load_kwargs['map_location'] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)

def load_model(path):
    """Build a ``Wav2Lip`` network, load its weights and return it in eval mode.

    Every key of the checkpoint's ``"state_dict"`` has the substring
    ``'module.'`` removed before loading (presumably the prefix left by a
    parallel wrapper at training time - confirm). The network is moved to the
    module-level ``device``.

    Args:
        path: Checkpoint file path, forwarded to ``_load``.

    Returns:
        The model after ``.to(device)`` and ``.eval()``.
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_weights = _load(path)["state_dict"]
    cleaned = {
        key.replace('module.', ''): tensor
        for key, tensor in saved_weights.items()
    }
    net.load_state_dict(cleaned)
    return net.to(device).eval()

def preprocess_image(
    image_path,
    pads=(0, 10, 0, 0),
    box=(-1, -1, -1, -1),
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Read an image, locate the face rectangle and return the resized crop.

    Args:
        image_path: Path handed to ``cv2.imread``.
        pads: ``(top, bottom, left, right)`` padding added around the detected
            rectangle; the padded rectangle is clamped to the image bounds.
        box: Explicit ``(y1, y2, x1, x2)`` rectangle. Face detection runs only
            when ``box[0] == -1``.
        img_size: Side length of the square crop that is returned.
        face_det_batch_size: Unused; kept for interface compatibility.
        nosmooth: Unused; kept for interface compatibility.
        device: Device passed to the face detector.

    Returns:
        ``(face, coords, image)``: the resized crop, the ``(y1, y2, x1, x2)``
        rectangle it was cut from, and the full image as read from disk.

    Raises:
        ValueError: If the image cannot be read or the rectangle is empty.
        RuntimeError: If the detector finds no face.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"无法读取图片: {image_path}")
    if box[0] == -1:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        predictions = detector.get_detections_for_batch(np.array([image]))
        rect = predictions[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pady1, pady2, padx1, padx2 = pads
        # rect is indexed as (x1, y1, x2, y2); clamp the padded box to the image.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
    else:
        y1, y2, x1, x2 = box
    face = image[y1:y2, x1:x2]
    # An empty crop would otherwise surface as an opaque cv2.resize failure.
    if face.size == 0:
        raise ValueError(f"人脸区域为空: {(y1, y2, x1, x2)}")
    face = cv2.resize(face, (img_size, img_size))
    coords = (y1, y2, x1, x2)
    return face, coords, image

def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model-ready batches pairing one still face with windows of audio.

    This is a generator: none of the work below (including the ffmpeg
    conversion) happens until the first batch is requested.

    Args:
        audio_path: Audio file. Anything not ending in ``.wav`` is first
            converted to ``temp/temp.wav`` with ffmpeg.
        face_img: Face crop repeated for every audio window (assumed to be an
            ``(img_size, img_size, 3)`` array - TODO confirm against caller).
        face_coords: Passed through unchanged, once per frame.
        static: Unused; kept for interface compatibility.
        fps: Video frame rate; one mel window is taken every ``80 / fps`` mel
            frames (80 is presumably the mel frame rate per second - confirm).
        mel_step_size: Width of each mel window, in mel frames.
        wav2lip_batch_size: Maximum number of frames per yielded batch.
        img_size: Face crop height; rows from ``img_size // 2`` down are zeroed
            in the masked copy.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)`` where
        ``img_input`` is the masked and unmasked faces concatenated on axis 3
        and divided by 255, ``mel_input`` is the mel windows with a trailing
        axis of size 1, ``frame_batch`` holds independent copies of
        ``face_img`` and ``coords_batch`` repeats ``face_coords``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
        ValueError: If the mel spectrogram contains NaN or is shorter than one
            window.
    """
    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list without a shell: paths containing quotes, spaces,
        # braces or shell metacharacters cannot break or inject into the
        # command (the former f-string + .format() combination failed on
        # paths containing '{' or '}').
        command = ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav']
        return_code = subprocess.call(command)
        if return_code != 0:
            # Without this check a stale temp/temp.wav from an earlier run
            # would be used silently.
            raise RuntimeError(f'ffmpeg 转换音频失败（返回码 {return_code}）: {audio_path}')
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')

    n_mel_frames = len(mel[0])
    if n_mel_frames < mel_step_size:
        # A negative slice start below would produce a window of the wrong width.
        raise ValueError('音频过短，无法提取完整的梅尔片段')

    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > n_mel_frames:
            # The last window is right-aligned with the end of the spectrogram.
            mel_chunks.append(mel[:, n_mel_frames - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    # Sizes below 1 behaved like 1 in the original accumulate-and-flush loop.
    step = max(1, wav2lip_batch_size)
    for begin in range(0, len(mel_chunks), step):
        mel_batch = np.asarray(mel_chunks[begin:begin + step])
        count = len(mel_batch)
        img_batch = np.stack([face_img] * count)
        img_masked = img_batch.copy()
        # Blank the lower half of every face; the model regenerates it.
        img_masked[:, img_size // 2:] = 0
        img_input = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_input = np.reshape(mel_batch, [count, mel_batch.shape[1], mel_batch.shape[2], 1])
        frame_batch = [face_img.copy() for _ in range(count)]
        coords_batch = [face_coords] * count
        yield img_input, mel_input, frame_batch, coords_batch

def play_audio_stream(audio_path, start_event, stop_event):
    """Play an audio file through a ``sounddevice`` output stream.

    The stream callback waits on ``start_event`` before producing samples and
    raises ``sd.CallbackAbort`` once ``stop_event`` is set. When fewer samples
    remain than requested, the rest of the buffer is zero-filled,
    ``stop_event`` is set and ``sd.CallbackStop`` is raised. The calling
    thread polls every 100 ms until playback is stopped or every sample has
    been handed out.

    Args:
        audio_path: File read with ``soundfile.read``.
        start_event: ``threading.Event`` that gates the start of playback.
        stop_event: ``threading.Event`` used to stop playback; also set here
            when the data runs out.
    """
    samples, rate = sf.read(audio_path)
    if samples.ndim == 1:
        # Promote mono data to a single-column 2-D array.
        samples = samples[:, None]
    total = len(samples)
    cursor = 0  # index of the next sample row to hand to the stream

    def callback(outdata, frames, time_, status):
        nonlocal cursor
        start_event.wait()
        if stop_event.is_set():
            raise sd.CallbackAbort
        piece = samples[cursor:cursor + frames]
        available = len(piece)
        if available < frames:
            outdata[:available] = piece
            outdata[available:] = 0
            stop_event.set()
            raise sd.CallbackStop
        outdata[:] = piece
        cursor += frames

    with sd.OutputStream(channels=samples.shape[1], samplerate=rate, callback=callback):
        start_event.wait()
        while not (stop_event.is_set() or cursor >= total):
            sd.sleep(100)

def wav2lip_sync_play(
    model,
    gen,
    device,
    orig_image,
    coords,
    audio_path,
    fps=25,
    window_size=(224, 336),
    show_window=True,
    window_name="Wav2Lip Result"
):
    """Play the audio while showing the lip-synced frames at ``fps``.

    Audio playback and the frame clock both start once the first batch has
    been predicted, so preparing that batch does not make the opening frames
    late. The audio thread is always stopped and joined, and the window
    closed, even when inference raises.

    Args:
        model: Callable invoked as ``model(mel_batch, img_batch)``.
        gen: Iterable yielding ``(img_batch, mel_batch, frames, coords_batch)``
            tuples; only the first two items are used here.
        device: Device the input tensors are moved to.
        orig_image: Image the predicted face patch is pasted into.
        coords: ``(y1, y2, x1, x2)`` face rectangle inside ``orig_image``.
        audio_path: Audio file handed to ``play_audio_stream``.
        fps: Target display rate in frames per second.
        window_size: Size passed to ``cv2.resize`` for the preview.
        show_window: When false, frames are paced but not displayed.
        window_name: Title of the OpenCV window.
    """
    start_event = threading.Event()
    stop_event = threading.Event()
    audio_thread = threading.Thread(target=play_audio_stream, args=(audio_path, start_event, stop_event))

    frame_interval = 1.0 / fps
    frame_count = 0
    start_time = None

    # The face rectangle is the same for every frame, so unpack it once.
    y1, y2, x1, x2 = coords
    h, w = y2 - y1, x2 - x1

    try:
        for img_batch, mel_batch, _frames, _coords_batch in gen:
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            pred = pred.astype(np.uint8)

            if start_time is None:
                # The first frames are ready: start audio and the clock together.
                audio_thread.start()
                start_time = time.time()
                start_event.set()

            for p in pred:
                if h <= 0 or w <= 0:
                    continue
                p_resized = cv2.resize(p, (w, h))
                show_img = orig_image.copy()
                show_img[y1:y2, x1:x2] = p_resized
                # Align video with audio: wait until this frame's slot. Late
                # frames are shown immediately so the video catches up.
                target_time = start_time + frame_count * frame_interval
                now = time.time()
                if now < target_time:
                    time.sleep(target_time - now)
                if show_window:
                    cv2.imshow(window_name, cv2.resize(show_img, window_size))
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        return
                frame_count += 1
    finally:
        stop_event.set()
        # Release anything still blocked on the start gate before joining.
        start_event.set()
        if audio_thread.is_alive():
            audio_thread.join()
        cv2.destroyAllWindows()

def idle_display(orig_image, window_size=(224, 336), window_name="Wav2Lip Result"):
    """Show the still image and block until a key is pressed in the window.

    The keyboard is polled with ``cv2.waitKey(100)`` until it returns
    something other than ``-1``.

    Args:
        orig_image: Image to display.
        window_size: Size passed to ``cv2.resize`` for the preview.
        window_name: Title of the OpenCV window.

    Returns:
        ``False`` when the key was ``q`` (all windows are destroyed first),
        ``True`` for any other key.
    """
    preview = cv2.resize(orig_image, window_size)
    cv2.imshow(window_name, preview)

    pressed = cv2.waitKey(100)
    while pressed == -1:
        pressed = cv2.waitKey(100)

    if pressed == ord('q'):
        cv2.destroyAllWindows()
        return False  # quit
    return True  # got some input

In [2]:

if __name__ == '__main__':
    # Hard-coded local paths to the checkpoint and the still image.
    check_path = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    img_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"
    # `device` must live at module level: _load() and load_model() read it as a global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = load_model(check_path)
    face_img, face_coords, orig_image = preprocess_image(img_path, device=device)

    print("==== Wav2Lip 音频驱动图像同步播放 ====")
    print("按 'q' 退出；其余任意键输入音频路径并播放合成结果。")

    while True:
        # Keep showing the still image while idle; 'q' in the window quits.
        if not idle_display(orig_image):
            break

        audio_path = input("请输入要驱动的音频文件路径（或按回车跳过，输入 q 退出）：").strip()
        # The quit check has to come before the existence check: "q" is not
        # an existing file, so the old order rejected it and never quit.
        if audio_path.lower() == "q":
            break
        if audio_path == "" or not os.path.exists(audio_path):
            print("未输入音频或音频文件不存在，继续等待输入...")
            continue

        try:
            gen = prepare_audio_batches(audio_path, face_img, face_coords)
            wav2lip_sync_play(model, gen, device, orig_image=orig_image, coords=face_coords, audio_path=audio_path)
        except Exception as e:
            # Top-level boundary of the interactive loop: report and keep running.
            print(f"音频处理或合成异常: {e}")

    print("程序已退出。")

从 D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth 加载模型
==== Wav2Lip 音频驱动图像同步播放 ====
按 'q' 退出；其余任意键输入音频路径并播放合成结果。


KeyboardInterrupt: 

In [4]:
import os
import time
import threading

import numpy as np
import cv2
import torch
import Wav2Lip.audio as audio
import subprocess
import sounddevice as sd
import soundfile as sf
import Wav2Lip.face_detection as face_detection
from Wav2Lip.models import Wav2Lip

# Width, in mel-spectrogram frames, of one audio window. prepare_audio_batches
# declares a parameter with the same name and default, which shadows this global there.
mel_step_size = 16

def _load(checkpoint_path):
    """Read a checkpoint file with ``torch.load``.

    The module-level ``device`` decides how storages are mapped: for
    ``'cuda'`` no ``map_location`` is passed; for anything else a callback
    returns each storage unchanged, which per the ``torch.load`` docs leaves
    the tensors on the CPU.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` deserialises from the file.
    """
    load_kwargs = {}
    if device != 'cuda':
        load_kwargs['map_location'] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)

def load_model(path):
    """Build a ``Wav2Lip`` network, load its weights and return it in eval mode.

    Every key of the checkpoint's ``"state_dict"`` has the substring
    ``'module.'`` removed before loading (presumably the prefix left by a
    parallel wrapper at training time - confirm). The network is moved to the
    module-level ``device``.

    Args:
        path: Checkpoint file path, forwarded to ``_load``.

    Returns:
        The model after ``.to(device)`` and ``.eval()``.
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_weights = _load(path)["state_dict"]
    cleaned = {
        key.replace('module.', ''): tensor
        for key, tensor in saved_weights.items()
    }
    net.load_state_dict(cleaned)
    return net.to(device).eval()

def preprocess_image(
    image_path,
    pads=(0, 10, 0, 0),
    box=(-1, -1, -1, -1),
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Read an image, locate the face rectangle and return the resized crop.

    Args:
        image_path: Path handed to ``cv2.imread``.
        pads: ``(top, bottom, left, right)`` padding added around the detected
            rectangle; the padded rectangle is clamped to the image bounds.
        box: Explicit ``(y1, y2, x1, x2)`` rectangle. Face detection runs only
            when ``box[0] == -1``.
        img_size: Side length of the square crop that is returned.
        face_det_batch_size: Unused; kept for interface compatibility.
        nosmooth: Unused; kept for interface compatibility.
        device: Device passed to the face detector.

    Returns:
        ``(face, coords, image)``: the resized crop, the ``(y1, y2, x1, x2)``
        rectangle it was cut from, and the full image as read from disk.

    Raises:
        ValueError: If the image cannot be read or the rectangle is empty.
        RuntimeError: If the detector finds no face.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"无法读取图片: {image_path}")
    if box[0] == -1:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        predictions = detector.get_detections_for_batch(np.array([image]))
        rect = predictions[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pady1, pady2, padx1, padx2 = pads
        # rect is indexed as (x1, y1, x2, y2); clamp the padded box to the image.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
    else:
        y1, y2, x1, x2 = box
    face = image[y1:y2, x1:x2]
    # An empty crop would otherwise surface as an opaque cv2.resize failure.
    if face.size == 0:
        raise ValueError(f"人脸区域为空: {(y1, y2, x1, x2)}")
    face = cv2.resize(face, (img_size, img_size))
    coords = (y1, y2, x1, x2)
    return face, coords, image

def _build_wav2lip_inputs(img_batch, mel_batch, img_size):
    """Stack face crops and mel chunks into the two arrays fed to the model."""
    img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
    img_masked = img_batch.copy()
    # Zero everything from index img_size // 2 onward along axis 1 of each crop.
    img_masked[:, img_size // 2:] = 0
    img_input = np.concatenate((img_masked, img_batch), axis=3) / 255.
    mel_input = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
    return img_input, mel_input


def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model input batches that pair one still face crop with the audio.

    This is a generator: nothing (including the ffmpeg conversion) happens
    until the first batch is requested.

    Args:
        audio_path: Audio file. Anything whose name does not end in ``.wav``
            is converted to ``temp/temp.wav`` with ffmpeg first.
        face_img: Face crop that is repeated for every mel chunk.
        face_coords: Coordinates attached to every frame of a batch.
        static: Not used by this function; kept for interface compatibility.
        fps: Video frame rate; the mel start index advances by ``80 / fps``
            per frame.
        mel_step_size: Number of mel columns per chunk.
        wav2lip_batch_size: Maximum number of frames per yielded batch.
        img_size: Side length of ``face_img``; rows from ``img_size // 2`` on
            are zeroed in the masked copy.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)``.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with an error.
        ValueError: The mel spectrogram contains NaN or is shorter than one
            chunk.
    """
    if not audio_path.endswith('.wav'):
        print('提取音频...')
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of a shell string: the path is user supplied
        # and must not be interpreted by a shell. check=True stops us from
        # silently reusing a stale temp.wav when the conversion fails.
        subprocess.run(
            ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav'],
            check=True,
        )
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')
    mel_len = len(mel[0])
    if mel_len < mel_step_size:
        # A negative slice start below would produce a chunk of the wrong width.
        raise ValueError('音频过短，无法生成梅尔频谱片段')

    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > mel_len:
            # Last chunk: take the final mel_step_size columns and stop.
            mel_chunks.append(mel[:, mel_len - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for m in mel_chunks:
        img_batch.append(face_img.copy())
        mel_batch.append(m)
        frame_batch.append(face_img.copy())
        coords_batch.append(face_coords)
        if len(img_batch) >= wav2lip_batch_size:
            img_input, mel_input = _build_wav2lip_inputs(img_batch, mel_batch, img_size)
            yield img_input, mel_input, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    if img_batch:
        img_input, mel_input = _build_wav2lip_inputs(img_batch, mel_batch, img_size)
        yield img_input, mel_input, frame_batch, coords_batch

def play_audio_stream(audio_path, start_event, stop_event):
    """Play ``audio_path`` through a ``sounddevice`` output stream.

    The stream callback blocks until ``start_event`` is set, aborts when
    ``stop_event`` is set, and sets ``stop_event`` itself once it has handed
    out the last (partial) block of samples. The function returns when
    ``stop_event`` is set or all samples have been consumed.
    """
    samples, rate = sf.read(audio_path)
    if samples.ndim == 1:
        # Mono input: add a channel axis so the stream sees (frames, channels).
        samples = samples[:, None]
    total = len(samples)
    cursor = [0]  # read position, shared with the callback

    def fill_output(outdata, frames, time_, status):
        start_event.wait()
        if stop_event.is_set():
            raise sd.CallbackAbort
        begin = cursor[0]
        block = samples[begin:begin + frames]
        got = len(block)
        if got < frames:
            # Final partial block: pad with silence and end the stream.
            outdata[:got] = block
            outdata[got:] = 0
            stop_event.set()
            raise sd.CallbackStop
        outdata[:] = block
        cursor[0] = begin + frames

    stream = sd.OutputStream(channels=samples.shape[1], samplerate=rate, callback=fill_output)
    with stream:
        start_event.wait()
        while not (stop_event.is_set() or cursor[0] >= total):
            sd.sleep(100)

def wav2lip_sync_play(
    model,
    gen,
    device,
    orig_image,
    coords,
    audio_path,
    fps=25,
    window_size=(224, 336),
    show_window=True,
    window_name="Wav2Lip Result"
):
    """Run Wav2Lip on the batches from ``gen`` and show the frames while the
    audio plays on a background thread.

    Args:
        model: Callable invoked as ``model(mel_batch, img_batch)``.
        gen: Iterable of ``(img_batch, mel_batch, frames, coords_batch)``;
            only the first two items are used here.
        device: Torch device the input tensors are moved to.
        orig_image: Full image into which every predicted face is pasted.
        coords: ``(y1, y2, x1, x2)`` paste region (the per-batch coordinates
            yielded by ``gen`` are ignored).
        audio_path: File played by ``play_audio_stream``.
        fps: Target display rate.
        window_size: Size passed to ``cv2.resize`` for the displayed image.
        show_window: Not used by this function; kept for interface
            compatibility.
        window_name: OpenCV window title.

    Pressing ``q`` in the window stops playback early. Whatever way the
    function exits - normally, via ``q`` or through an exception from the
    generator/model - the audio thread is stopped and joined and the OpenCV
    windows are closed.
    """
    start_event = threading.Event()
    stop_event = threading.Event()
    audio_thread = threading.Thread(target=play_audio_stream, args=(audio_path, start_event, stop_event))
    audio_thread.start()

    frame_interval = 1.0 / fps
    frame_count = 0
    y1, y2, x1, x2 = coords
    h, w = y2 - y1, x2 - x1

    try:
        start_time = time.time()
        start_event.set()

        for img_batch, mel_batch, _frames, _coords_batch in gen:
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            pred = pred.astype(np.uint8)
            for p in pred:
                if h <= 0 or w <= 0:
                    continue
                p_resized = cv2.resize(p, (w, h))
                show_img = orig_image.copy()
                show_img[y1:y2, x1:x2] = p_resized
                show_img_disp = cv2.resize(show_img, window_size)
                # Align video with audio: wait until this frame's slot on the
                # timeline that started together with the audio.
                target_time = start_time + frame_count * frame_interval
                now = time.time()
                if now < target_time:
                    time.sleep(target_time - now)
                cv2.imshow(window_name, show_img_disp)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    return
                frame_count += 1
    finally:
        # Release the audio callback if it is still waiting, then stop it;
        # without this an exception above would leave the audio playing.
        start_event.set()
        stop_event.set()
        audio_thread.join()
        cv2.destroyAllWindows()

def show_image_idle(orig_image, window_size=(224, 336), window_name="Wav2Lip"):
    """Draw ``orig_image``, resized to ``window_size``, in the named window.

    A single 1 ms ``cv2.waitKey`` refreshes the window without blocking, so
    the caller has to invoke this from its own loop.
    """
    cv2.imshow(window_name, cv2.resize(orig_image, window_size))
    cv2.waitKey(1)

def wait_for_audio_path():
    """Ask on the console for an audio file path.

    Returns:
        ``None`` when the user typed ``q`` (any case), ``""`` when the path
        does not exist (after printing a notice), otherwise the stripped path.
    """
    entered = input("请输入要讲解的音频文件路径（或输入 q 退出）：").strip()
    if entered.lower() == "q":
        return None
    if os.path.exists(entered):
        return entered
    print("音频文件不存在，请重新输入。")
    return ""

In [5]:

if __name__ == '__main__':
    # Use the GPU when torch reports CUDA support, otherwise the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Hard-coded locations of the Wav2Lip checkpoint and the portrait image.
    check_path = r"D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth"
    img_path = r"D:\coding\projects\Python\human\Wav2Lip\input\1.png"

    model = load_model(check_path)
    face_img, face_coords, orig_image = preprocess_image(img_path, device=device)


从 D:\coding\projects\Python\human\Wav2Lip\checkpoints\wav2lip_gan.pth 加载模型


In [6]:
    window_name = "Wav2Lip 自助讲解"
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    print("==== 自助讲解员系统 ====")
    print("系统空闲时自动展示图片，收到音频地址后自动讲解并口型同步。按 'q' 退出。")

    while True:
        # Idle phase: draw the portrait once and give the window 100 ms to
        # report a key press before the (blocking) console prompt.
        print("空闲展示中，等待输入讲解音频路径 ...")
        show_image_idle(orig_image, window_name=window_name)
        if cv2.waitKey(100) == ord('q'):
            print("用户退出。")
            cv2.destroyAllWindows()
            exit(0)

        # Blocks until the user has typed a path on the console.
        audio_path = wait_for_audio_path()
        if audio_path is None:
            break
        if not audio_path:
            continue

        try:
            wav2lip_sync_play(
                model,
                prepare_audio_batches(audio_path, face_img, face_coords),
                device,
                orig_image=orig_image,
                coords=face_coords,
                audio_path=audio_path,
                window_name=window_name,
            )
        except Exception as e:
            print(f"音频处理或合成异常: {e}")

    cv2.destroyAllWindows()
    print("讲解系统已关闭。")

==== 自助讲解员系统 ====
系统空闲时自动展示图片，收到音频地址后自动讲解并口型同步。按 'q' 退出。
空闲展示中，等待输入讲解音频路径 ...
空闲展示中，等待输入讲解音频路径 ...


KeyboardInterrupt: Interrupted by user

In [None]:
#face
import os
import numpy as np
import cv2
import torch
import threading
import time
import Wav2Lip.audio as audio
import subprocess
import sounddevice as sd
import soundfile as sf
import Wav2Lip.face_detection as face_detection
from Wav2Lip.models import Wav2Lip

mel_step_size = 16

def _load(checkpoint_path, device):
    """Read a checkpoint file with ``torch.load``.

    For ``device == 'cuda'`` torch's default placement is used; for any other
    value a ``map_location`` callable that returns each storage unchanged is
    passed.
    """
    load_kwargs = {}
    if device != 'cuda':
        load_kwargs['map_location'] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)

def load_model(path, device):
    """Create a ``Wav2Lip`` network, load the weights stored at ``path`` and
    return it in eval mode on ``device``.

    Every ``'module.'`` substring is removed from the checkpoint's
    ``state_dict`` keys before they are loaded (presumably to undo a
    DataParallel prefix - confirm against how the checkpoint was saved).
    """
    net = Wav2Lip()
    print(f"从 {path} 加载模型")
    saved_state = _load(path, device)["state_dict"]
    cleaned = {name.replace('module.', ''): tensor for name, tensor in saved_state.items()}
    net.load_state_dict(cleaned)
    return net.to(device).eval()

def preprocess_image(
    image_path,
    pads=(0, 10, 0, 0),
    box=(-1, -1, -1, -1),
    img_size=96,
    face_det_batch_size=16,
    nosmooth=False,
    device='cuda'
):
    """Load an image, locate its face region and return the resized face crop.

    Args:
        image_path: Path handed to ``cv2.imread``.
        pads: ``(pady1, pady2, padx1, padx2)`` - padding subtracted from the
            top/left and added to the bottom/right of the detected rectangle;
            the result is clamped to the image bounds.
        box: ``(y1, y2, x1, x2)`` crop to use instead of face detection.
            Detection runs when ``box[0] == -1``.
        img_size: Side length of the square crop that is returned.
        face_det_batch_size: Not used by this function; kept for interface
            compatibility.
        nosmooth: Not used by this function; kept for interface compatibility.
        device: Device string passed to the face detector.

    Returns:
        ``(face, coords, image)`` where ``face`` is the crop resized to
        ``img_size x img_size``, ``coords`` is ``(y1, y2, x1, x2)`` and
        ``image`` is the full image as read from disk.

    Raises:
        ValueError: The image cannot be read, or the crop region is empty.
        RuntimeError: No face was detected.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"无法读取图片: {image_path}")
    if box[0] == -1:
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
        predictions = detector.get_detections_for_batch(np.array([image]))
        rect = predictions[0]
        if rect is None:
            raise RuntimeError("未检测到人脸")
        pady1, pady2, padx1, padx2 = pads
        # rect is indexed as (x1, y1, x2, y2); clamp the padded box to the image.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
    else:
        y1, y2, x1, x2 = box
    coords = (y1, y2, x1, x2)
    face = image[y1:y2, x1:x2]
    if face.size == 0:
        # Fail with a readable message instead of an opaque cv2.resize assertion.
        raise ValueError(f"人脸裁剪区域为空: {coords}")
    face = cv2.resize(face, (img_size, img_size))
    return face, coords, image

def _build_wav2lip_inputs(img_batch, mel_batch, img_size):
    """Stack face crops and mel chunks into the two arrays fed to the model."""
    img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
    img_masked = img_batch.copy()
    # Zero everything from index img_size // 2 onward along axis 1 of each crop.
    img_masked[:, img_size // 2:] = 0
    img_input = np.concatenate((img_masked, img_batch), axis=3) / 255.
    mel_input = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
    return img_input, mel_input


def prepare_audio_batches(
    audio_path,
    face_img,
    face_coords,
    static=True,
    fps=25,
    mel_step_size=16,
    wav2lip_batch_size=128,
    img_size=96
):
    """Yield model input batches that pair one still face crop with the audio.

    This is a generator: nothing (including the ffmpeg conversion) happens
    until the first batch is requested.

    Args:
        audio_path: Audio file. Anything whose name does not end in ``.wav``
            is converted to ``temp/temp.wav`` with ffmpeg first.
        face_img: Face crop that is repeated for every mel chunk.
        face_coords: Coordinates attached to every frame of a batch.
        static: Not used by this function; kept for interface compatibility.
        fps: Video frame rate; the mel start index advances by ``80 / fps``
            per frame.
        mel_step_size: Number of mel columns per chunk.
        wav2lip_batch_size: Maximum number of frames per yielded batch.
        img_size: Side length of ``face_img``; rows from ``img_size // 2`` on
            are zeroed in the masked copy.

    Yields:
        ``(img_input, mel_input, frame_batch, coords_batch)``.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with an error.
        ValueError: The mel spectrogram contains NaN or is shorter than one
            chunk.
    """
    if not audio_path.endswith('.wav'):
        os.makedirs('temp', exist_ok=True)
        # Argument list instead of a shell string: the path must not be
        # interpreted by a shell. check=True stops us from silently reusing a
        # stale temp.wav when the conversion fails.
        subprocess.run(
            ['ffmpeg', '-y', '-i', audio_path, '-strict', '-2', 'temp/temp.wav'],
            check=True,
        )
        audio_path = 'temp/temp.wav'
    wav = audio.load_wav(audio_path, 16000)
    mel = audio.melspectrogram(wav)
    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('梅尔频谱包含NaN值，请检查音频质量')
    mel_len = len(mel[0])
    if mel_len < mel_step_size:
        # A negative slice start below would produce a chunk of the wrong width.
        raise ValueError('音频过短，无法生成梅尔频谱片段')

    mel_idx_multiplier = 80. / fps
    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > mel_len:
            # Last chunk: take the final mel_step_size columns and stop.
            mel_chunks.append(mel[:, mel_len - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    # Frame padding: repeat chunks so there are at least min_frames of them.
    # NOTE(review): the formula reduces to mel_len * fps / 16000, which appears
    # to stay far below the number of chunks built above, so this branch looks
    # unreachable - confirm the intended minimum before relying on it.
    min_frames = int((mel_len / 80) * fps / (16000 / 80))
    if len(mel_chunks) < min_frames and len(mel_chunks) > 0:
        repeat = int(np.ceil(min_frames / len(mel_chunks)))
        mel_chunks = (mel_chunks * repeat)[:min_frames]

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    for m in mel_chunks:
        img_batch.append(face_img.copy())
        mel_batch.append(m)
        frame_batch.append(face_img.copy())
        coords_batch.append(face_coords)
        if len(img_batch) >= wav2lip_batch_size:
            img_input, mel_input = _build_wav2lip_inputs(img_batch, mel_batch, img_size)
            yield img_input, mel_input, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    if img_batch:
        img_input, mel_input = _build_wav2lip_inputs(img_batch, mel_batch, img_size)
        yield img_input, mel_input, frame_batch, coords_batch

class LipSyncPlayer:
    """Renders lip-synced frames for one still image and plays them with audio."""

    def __init__(self, model, device, orig_image, face_coords, fps=25):
        self.model, self.device = model, device
        self.orig_image, self.face_coords = orig_image, face_coords
        self.fps = fps

    def infer_frames(self, batch_gen):
        """Run inference only.

        Every prediction is resized to the ``face_coords`` region and pasted
        into a copy of ``orig_image``; predictions are skipped when that
        region has non-positive height or width.

        Returns:
            ``(elapsed_seconds, frames)``.
        """
        started = time.perf_counter()
        rendered = []
        for img_batch, mel_batch, _frames, _coords in batch_gen:
            img_t = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
            mel_t = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
            with torch.no_grad():
                pred = self.model(mel_t, img_t)
            pred = (pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.).astype(np.uint8)
            for mouth in pred:
                top, bottom, left, right = self.face_coords
                height, width = bottom - top, right - left
                if height > 0 and width > 0:
                    patch = cv2.resize(mouth, (width, height))
                    canvas = self.orig_image.copy()
                    canvas[top:bottom, left:right] = patch
                    rendered.append(canvas.copy())
        return time.perf_counter() - started, rendered

    def play_frames(self, audio_path, all_frames, frame_callback):
        """Play only: start the audio on a thread and hand frames to
        ``frame_callback`` at ``fps`` until the audio ends or its duration has
        elapsed. When the frames run out, the last shown frame is repeated.

        Returns:
            Elapsed playback time in seconds.
        """
        samples, rate = sf.read(audio_path)
        audio_seconds = len(samples) / rate
        available = len(all_frames)
        step = 1.0 / self.fps
        frame_budget = int(audio_seconds * self.fps)

        audio_done = threading.Event()

        def _play_audio():
            sd.play(samples, rate)
            sd.wait()
            audio_done.set()

        worker = threading.Thread(target=_play_audio)
        worker.start()

        play_started = time.perf_counter()
        wall_start = time.time()
        shown = 0
        previous = None
        while not audio_done.is_set():
            elapsed_frames = int((time.time() - wall_start) * self.fps)
            if elapsed_frames >= frame_budget:
                break
            if shown < available:
                frame_callback(all_frames[shown])
                previous = all_frames[shown]
            elif previous is not None:
                frame_callback(previous)
            shown += 1
            # Sleep until the absolute deadline of the next frame.
            deadline = wall_start + shown * step
            time.sleep(max(0, deadline - time.time()))
        worker.join()
        return time.perf_counter() - play_started

cf

In [1]:
#app:
import sys
import threading
import asyncio
import time
import numpy as np
import cv2
from PyQt5 import QtWidgets, QtCore, QtGui
from config import API_KEY, BASE_URL, WAV2LIP_MODEL_PATH, IDLE_IMAGE_PATH, FACE_IMAGE_PATH, DEVICE
from llm import ChatBot
from tts import generate_speech
from face import load_model, preprocess_image, prepare_audio_batches, LipSyncPlayer

class DigitalHumanUI(QtWidgets.QWidget):
    """Main window of the Q&A demo: chat history, stage log and a face view.

    A submitted question is answered by the chat bot, turned into speech,
    lip-synced and played back. That pipeline runs on a worker thread and
    talks to the widgets only through the signals declared below.
    """

    append_history_signal = QtCore.pyqtSignal(str, str)
    show_frame_signal = QtCore.pyqtSignal(np.ndarray)
    idle_signal = QtCore.pyqtSignal()
    stage_signal = QtCore.pyqtSignal(str)

    def __init__(self):
        super().__init__()
        self.setWindowTitle("数字人问答演示")
        self.resize(960, 520)
        self.init_ui()
        self.init_resources()
        self.append_history_signal.connect(self.append_history)
        self.show_frame_signal.connect(self.show_video_frame)
        self.idle_signal.connect(self.show_idle)
        self.stage_signal.connect(self.show_stage)
        self.show_idle()
        self.show_stage("系统待命...")

    def init_ui(self):
        """Create the widgets and lay them out (chat/log/input left, face right)."""
        self.input_box = QtWidgets.QLineEdit(self)
        self.input_box.setPlaceholderText("请输入你的问题...")
        self.send_btn = QtWidgets.QPushButton("发送", self)
        self.send_btn.clicked.connect(self.on_submit)
        self.chat_history = QtWidgets.QTextEdit(self)
        self.chat_history.setReadOnly(True)
        self.chat_history.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.face_label = QtWidgets.QLabel(self)
        self.face_label.setFixedSize(336, 448)
        self.stage_label = QtWidgets.QTextEdit(self)
        self.stage_label.setReadOnly(True)
        self.stage_label.setAlignment(QtCore.Qt.AlignLeft)
        self.stage_label.setStyleSheet("color:blue; font-size:14px;")
        self.stage_label.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        left_panel = QtWidgets.QVBoxLayout()
        left_panel.addWidget(self.chat_history)
        left_panel.addWidget(self.stage_label)
        input_layout = QtWidgets.QHBoxLayout()
        input_layout.addWidget(self.input_box)
        input_layout.addWidget(self.send_btn)
        left_panel.addLayout(input_layout)
        main_layout = QtWidgets.QHBoxLayout(self)
        main_layout.addLayout(left_panel, 2)
        main_layout.addWidget(self.face_label, 1)
        self.setLayout(main_layout)

    def init_resources(self):
        """Create the chat bot, load the Wav2Lip model and the two images.

        Raises:
            FileNotFoundError: The idle image cannot be read.
        """
        self.bot = ChatBot(
            api_key=API_KEY,
            base_url=BASE_URL,
            log_dir="logs",
            default_background="你是一个知识渊博的助手，能够简洁地回答问题。",
            default_prefix="请简洁地回答下述问题："
        )
        self.model = load_model(WAV2LIP_MODEL_PATH, DEVICE)
        self.face_img, self.face_coords, self.orig_image = preprocess_image(FACE_IMAGE_PATH, device=DEVICE)
        self.idle_img = cv2.imread(IDLE_IMAGE_PATH)
        if self.idle_img is None:
            # cv2.imread returns None on failure; fail here with a clear
            # message instead of inside cvtColor when the window is built.
            raise FileNotFoundError(f"无法读取待机图片: {IDLE_IMAGE_PATH}")
        self.lip_player = LipSyncPlayer(self.model, DEVICE, self.orig_image, self.face_coords, fps=25)

    def _show_bgr_image(self, bgr):
        """Convert a BGR image to RGB, resize it to 336x448 and display it."""
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (336, 448))
        qimg = QtGui.QImage(img.data, img.shape[1], img.shape[0], QtGui.QImage.Format_RGB888)
        self.face_label.setPixmap(QtGui.QPixmap.fromImage(qimg))

    def show_idle(self):
        """Show the idle image in the face view."""
        self._show_bgr_image(self.idle_img)

    @QtCore.pyqtSlot(np.ndarray)
    def show_video_frame(self, frame):
        """Show one lip-sync frame in the face view."""
        self._show_bgr_image(frame)

    @QtCore.pyqtSlot(str, str)
    def append_history(self, speaker, text):
        """Append one chat line; ``text`` is escaped so that ``<``, ``>`` and
        ``&`` in questions or answers are shown literally instead of being
        parsed as markup."""
        from html import escape
        self.chat_history.append(f"<b>{speaker}：</b>{escape(text, quote=False)}")

    @QtCore.pyqtSlot(str)
    def show_stage(self, text):
        """Append a line to the stage log and scroll to the bottom."""
        self.stage_label.append(text)
        self.stage_label.verticalScrollBar().setValue(self.stage_label.verticalScrollBar().maximum())

    def on_submit(self):
        """Take the question from the input box and process it on a new thread."""
        question = self.input_box.text().strip()
        if not question:
            return
        self.input_box.setText("")
        self.append_history("用户", question)
        self.show_idle()
        self.show_stage("开始处理...")
        threading.Thread(target=self.process_conversation, args=(question,)).start()

    def process_conversation(self, question):
        """Worker-thread entry point: run the pipeline and always return the
        face view to idle, reporting any failure in the stage log."""
        try:
            self._run_conversation(question)
        except Exception as e:
            # Thread boundary: without this the thread would die silently and
            # the UI would never be told that processing ended.
            self.stage_signal.emit(f"处理失败：{e}")
        finally:
            self.idle_signal.emit()

    def _run_conversation(self, question):
        """LLM answer -> TTS -> lip-sync inference -> playback, with timings."""
        t0 = time.perf_counter()
        self.stage_signal.emit("等待大模型回复...")
        t1 = time.perf_counter()
        answer = self.bot.chat(question)
        t2 = time.perf_counter()
        self.append_history_signal.emit("助手", answer)
        self.stage_signal.emit(f"大模型回复完成，耗时：{t2-t1:.2f}s")
        self.stage_signal.emit("正在合成语音...")

        # TTS generation
        if not answer or len(answer.strip()) < 2:
            self.stage_signal.emit("回答内容过短，跳过语音与口型合成。")
            time.sleep(0.5)
            return

        t3 = time.perf_counter()
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            audio_path = loop.run_until_complete(generate_speech(answer))
        except Exception as e:
            self.stage_signal.emit(f"语音合成失败：{e}")
            return
        finally:
            # One loop is created per question; close it so they do not pile up.
            asyncio.set_event_loop(None)
            loop.close()
        t4 = time.perf_counter()
        import os
        if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 800:
            self.stage_signal.emit("语音文件生成失败或内容太短，跳过口型合成。")
            return
        self.stage_signal.emit(f"语音合成完成，耗时：{t4-t3:.2f}s")
        self.stage_signal.emit("正在生成嘴型动画...")

        # Drive Wav2Lip and show the frames in the Qt window. Inference and
        # playback are separate so the inference time can be reported first.
        gen = prepare_audio_batches(audio_path, self.face_img, self.face_coords)
        infer_time, all_frames = self.lip_player.infer_frames(gen)
        self.stage_signal.emit(f"视频帧推理完成，耗时{infer_time:.2f}s")
        play_time = self.lip_player.play_frames(
            audio_path,
            all_frames,
            lambda frame: self.show_frame_signal.emit(frame)
        )
        self.stage_signal.emit(f"嘴型播放完成，耗时{play_time:.2f}s")
        t6 = time.perf_counter()
        self.stage_signal.emit(
            f"阶段总结：LLM:{t2-t1:.2f}s，TTS:{t4-t3:.2f}s，"
            f"视频帧推理:{infer_time:.2f}s，嘴型播放:{play_time:.2f}s，总计:{t6-t0:.2f}s"
        )

if __name__ == "__main__":
    import os
    os.environ["QT_FONT_DPI"] = "96"
    if sys.platform == "win32":
        import ctypes
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("digitalhuman.app")
    # sys.stdout may be None (windowed Python without a console) or a
    # replacement stream that has no reconfigure(); skip the call then
    # instead of crashing before the window is created.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8')
    app = QtWidgets.QApplication(sys.argv)
    win = DigitalHumanUI()
    win.show()
    sys.exit(app.exec_())

SyntaxError: invalid syntax (2323340730.py, line 1)

In [None]:
import sys
import threading
import asyncio
import time
import numpy as np
import cv2
from PyQt5 import QtWidgets, QtCore, QtGui

from config import API_KEY, BASE_URL, WAV2LIP_MODEL_PATH, IDLE_VIDEO_PATH, FACE_IMAGE_PATH, DEVICE
from llm import ChatBot
from tts import generate_speech
from face import load_model, preprocess_image, prepare_audio_batches, LipSyncPlayer
from asr import run_asr_thread

class BubbleTextEdit(QtWidgets.QTextEdit):
    """Read-only chat box with a rounded, bubble-like background."""

    # Style sheet applied once in __init__.
    _STYLE_SHEET = """
            QTextEdit {
                background: #f7f9fa;
                border-radius: 10px;
                padding: 8px;
                font-size: 15px;
            }
        """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setReadOnly(True)
        self.setStyleSheet(self._STYLE_SHEET)

class CardFrame(QtWidgets.QFrame):
    """Frame styled as a rounded card."""

    # Style sheet applied once in __init__.
    _STYLE_SHEET = """
            QFrame {
                background: rgba(255,255,255,0.78);
                border-radius: 16px;
                border: 1px solid #e7e7e7;
            }
        """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setStyleSheet(self._STYLE_SHEET)

class FaceDisplayWidget(QtWidgets.QLabel):
    """Label that applies a rounded-rectangle mask to every pixmap it shows
    (plain rounded card, no highlight ring)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def setPixmap(self, pixmap):
        """Mask a copy of ``pixmap`` with a rounded rectangle and display it."""
        corner_radius = 60

        # Build the rounded-corner stencil: transparent canvas with a filled
        # white rounded rectangle covering the whole pixmap area.
        stencil = QtGui.QPixmap(pixmap.size())
        stencil.fill(QtCore.Qt.transparent)
        stencil_painter = QtGui.QPainter(stencil)
        stencil_painter.setRenderHint(QtGui.QPainter.Antialiasing)
        stencil_painter.setBrush(QtCore.Qt.white)
        stencil_painter.setPen(QtCore.Qt.NoPen)
        stencil_painter.drawRoundedRect(
            0, 0, pixmap.width(), pixmap.height(), corner_radius, corner_radius
        )
        stencil_painter.end()

        # Work on a copy so the caller's pixmap is left untouched.
        rounded = pixmap.copy()
        rounded.setMask(stencil.createMaskFromColor(QtCore.Qt.transparent))
        super().setPixmap(rounded)

class DigitalHumanUI(QtWidgets.QWidget):
    """Main window of the Q&A demo with speech recognition and an idle video.

    Worker threads (conversation pipeline, idle video, ASR) never touch
    widgets directly; they emit the signals declared below, whose slots run
    in the GUI thread.
    """

    append_history_signal = QtCore.pyqtSignal(str, str)
    show_frame_signal = QtCore.pyqtSignal(np.ndarray)
    idle_signal = QtCore.pyqtSignal()
    stage_signal = QtCore.pyqtSignal(str)
    asr_text_signal = QtCore.pyqtSignal(str)
    asr_status_signal = QtCore.pyqtSignal(bool, bool)  # (is_asr_running, is_wake)

    def __init__(self):
        super().__init__()
        self.setWindowTitle("数字人问答演示")
        self.setStyleSheet("background: #ecf0f4;")
        self.resize(1180, 700)
        self.asr_running = False
        self.asr_thread = None
        self._asr_wake = False
        self.busy = False
        self.idle_video_thread = None
        self.idle_video_running = threading.Event()
        self.init_ui()
        self.init_resources()
        self.append_history_signal.connect(self.append_history)
        self.show_frame_signal.connect(self.show_video_frame)
        self.idle_signal.connect(self.show_idle)
        self.stage_signal.connect(self.show_stage)
        self.asr_text_signal.connect(self.on_asr_text)
        self.asr_status_signal.connect(self.update_asr_status)
        self.show_idle()
        self.show_stage("系统待命...")

    def init_ui(self):
        """Create the widgets and lay them out (chat/log/input left, face right)."""
        # Chat area
        self.chat_history = BubbleTextEdit(self)
        self.chat_history.setFixedHeight(280)

        # Stage / progress area
        self.stage_card = CardFrame(self)
        self.stage_label = QtWidgets.QTextEdit(self.stage_card)
        self.stage_label.setReadOnly(True)
        self.stage_label.setAlignment(QtCore.Qt.AlignLeft)
        self.stage_label.setStyleSheet("color:#2065d6; font-size:15px;background:transparent;border:none;")
        self.stage_label.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        vbox = QtWidgets.QVBoxLayout(self.stage_card)
        vbox.addWidget(self.stage_label)
        vbox.setContentsMargins(8,8,8,8)

        # Status bar
        self.asr_status_label = QtWidgets.QLabel("语音识别：关闭 | 唤醒：未唤醒", self)
        self.asr_status_label.setStyleSheet("color:green; font-size:14px;")
        self.asr_status_label.setFixedHeight(24)

        # Input widgets (bottom row)
        self.input_box = QtWidgets.QLineEdit(self)
        self.input_box.setPlaceholderText("请输入你的问题或说话（支持语音唤醒）...")
        self.input_box.setStyleSheet("font-size:15px; border-radius:7px; padding:6px;background:#fff;")
        self.send_btn = QtWidgets.QPushButton("发送", self)
        self.send_btn.setStyleSheet("font-size:15px; padding:6px 18px; background:#2065d6; color:#fff; border-radius:7px;")
        self.send_btn.clicked.connect(self.on_submit)
        self.asr_btn = QtWidgets.QPushButton("🎤", self)
        self.asr_btn.setStyleSheet("font-size:19px; padding:6px 15px; background:#fff; color:#2065d6; border-radius:50%;")
        self.asr_btn.setCheckable(True)
        self.asr_btn.clicked.connect(self.on_toggle_asr)

        left_layout = QtWidgets.QVBoxLayout()
        left_layout.addWidget(self.chat_history)
        left_layout.addWidget(self.stage_card)
        left_layout.addWidget(self.asr_status_label)

        left_layout.setStretch(0, 3)
        left_layout.setStretch(1, 2)
        left_layout.setStretch(2, 0)

        # Input row layout
        input_layout = QtWidgets.QHBoxLayout()
        input_layout.addWidget(self.input_box, 3)
        input_layout.addWidget(self.send_btn, 1)
        input_layout.addWidget(self.asr_btn, 0)
        input_layout.setSpacing(12)
        left_layout.addLayout(input_layout)

        # Portrait / video area (card on the right, no breathing halo)
        self.face_card = CardFrame(self)
        self.face_card.setFixedSize(410, 570)
        self.face_card.setStyleSheet("background:rgba(255,255,255,0.85);border-radius:30px;")
        face_vbox = QtWidgets.QVBoxLayout(self.face_card)
        face_vbox.setContentsMargins(0,0,0,0)
        face_vbox.setAlignment(QtCore.Qt.AlignCenter)
        self.face_label = FaceDisplayWidget(self.face_card)
        self.face_label.setFixedSize(370, 520)
        face_vbox.addWidget(self.face_label)

        # Main layout
        main_layout = QtWidgets.QHBoxLayout(self)
        main_layout.addLayout(left_layout, 3)
        main_layout.addWidget(self.face_card, 2)
        main_layout.setSpacing(34)
        self.setLayout(main_layout)

    def init_resources(self):
        """Create the chat bot, load the Wav2Lip model and the face image."""
        self.bot = ChatBot(
            api_key=API_KEY,
            base_url=BASE_URL,
            log_dir="logs",
            default_background="你是一个知识渊博的助手，能够简洁地回答问题。",
            default_prefix="请简洁地回答下述问题："
        )
        self.model = load_model(WAV2LIP_MODEL_PATH, DEVICE)
        self.face_img, self.face_coords, self.orig_image = preprocess_image(FACE_IMAGE_PATH, device=DEVICE)
        self.idle_video_path = IDLE_VIDEO_PATH  # e.g. 'idle.mp4'
        self.lip_player = LipSyncPlayer(self.model, DEVICE, self.orig_image, self.face_coords, fps=25)

    def show_idle(self):
        """Enable idle playback and start the idle-video thread if none is alive."""
        self.idle_video_running.set()
        if self.idle_video_thread is None or not self.idle_video_thread.is_alive():
            self.idle_video_thread = threading.Thread(target=self.play_idle_video, daemon=True)
            self.idle_video_thread.start()

    def play_idle_video(self):
        """Worker thread: loop the idle video while ``idle_video_running`` is set.

        Frames are handed to the GUI thread through ``show_frame_signal``;
        widgets and pixmaps are not touched from this thread.
        """
        cap = cv2.VideoCapture(self.idle_video_path)
        try:
            if not cap.isOpened():
                print(f"无法打开idle视频：{self.idle_video_path}")
                return
            frames_since_rewind = 0
            while self.idle_video_running.is_set():
                ret, frame = cap.read()
                if not ret:
                    if frames_since_rewind == 0:
                        # Nothing readable even right after a rewind: stop
                        # instead of spinning without any sleep.
                        print(f"无法读取idle视频帧：{self.idle_video_path}")
                        return
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    frames_since_rewind = 0
                    continue
                frames_since_rewind += 1
                # show_video_frame (GUI thread) does the BGR->RGB conversion,
                # the resize and the setPixmap call.
                self.show_frame_signal.emit(frame)
                time.sleep(1.0 / 25)
        finally:
            cap.release()

    def stop_idle_video(self):
        """Disable idle playback and briefly wait for the thread to notice."""
        self.idle_video_running.clear()
        if self.idle_video_thread is not None:
            self.idle_video_thread.join(timeout=0.2)

    @QtCore.pyqtSlot(np.ndarray)
    def show_video_frame(self, frame):
        """Convert a BGR frame to RGB, resize it to 360x480 and display it."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (360, 480))
        pix = QtGui.QPixmap.fromImage(QtGui.QImage(frame.data, frame.shape[1], frame.shape[0], QtGui.QImage.Format_RGB888))
        self.face_label.setPixmap(pix)

    @QtCore.pyqtSlot(str, str)
    def append_history(self, speaker, text):
        """Append one chat bubble; ``text`` is escaped so that ``<``, ``>``
        and ``&`` are shown literally instead of being parsed as markup."""
        from html import escape
        text = escape(text, quote=False)
        if speaker == "用户":
            self.chat_history.append(f"<div style='text-align:right; margin:10px;'><span style='background:#2065d6;color:white;border-radius:12px;padding:8px 12px;display:inline-block;'>{text}</span></div>")
        else:
            self.chat_history.append(f"<div style='text-align:left; margin:10px;'><span style='background:#fff;color:#222;border-radius:12px;padding:8px 12px;display:inline-block;border:1px solid #e7e7e7;'>{text}</span></div>")

    @QtCore.pyqtSlot(str)
    def show_stage(self, text):
        """Append a line to the stage log and scroll to the bottom."""
        self.stage_label.append(text)
        self.stage_label.verticalScrollBar().setValue(self.stage_label.verticalScrollBar().maximum())

    @QtCore.pyqtSlot(str)
    def on_asr_text(self, text):
        """Submit recognized speech as a question unless an answer is playing."""
        if self.busy:
            self.stage_signal.emit("正在播报回答，请稍后再提问。")
            return
        text = text.strip()
        if text:
            self.input_box.setText(text)
            self.on_submit()

    @QtCore.pyqtSlot(bool, bool)
    def update_asr_status(self, running, wake):
        """Mirror the ASR running/wake state in the status label and button."""
        self.asr_running = running
        self._asr_wake = wake
        s = f"语音识别：{'开启' if running else '关闭'} | 唤醒：{'已唤醒' if wake else '未唤醒'}"
        color = "blue" if running else "gray"
        wcolor = "red" if wake else "green"
        self.asr_status_label.setText(s)
        self.asr_status_label.setStyleSheet(f"color:{wcolor if wake else color}; font-size:14px;")
        if not running:
            self.asr_btn.setChecked(False)
            self.asr_btn.setText("🎤")
        else:
            self.asr_btn.setChecked(True)
            self.asr_btn.setText("⏹")

    def on_toggle_asr(self):
        """Microphone button handler: start ASR when stopped, stop it when running."""
        if self.asr_running:
            self._stop_asr()
        else:
            self._start_asr()

    def _start_asr(self):
        if self.asr_running:
            return
        self.asr_running = True
        self.asr_btn.setText("⏹")
        self.asr_status_signal.emit(True, False)
        self.asr_thread = threading.Thread(target=self.start_asr, daemon=True)
        self.asr_thread.start()

    def _stop_asr(self):
        if not self.asr_running:
            return
        # start_asr passes run_asr_thread a callable that reads this flag.
        self.asr_running = False
        self.asr_btn.setText("🎤")
        self.asr_status_signal.emit(False, False)

    def start_asr(self):
        """Worker thread: run speech recognition until ``asr_running`` is cleared."""
        def asr_callback(text, wake_state):
            self.asr_status_signal.emit(True, wake_state)
            if text and wake_state and not self.busy:
                self.asr_text_signal.emit(text)
        try:
            run_asr_thread(asr_callback, lambda: self.asr_running)
        except Exception as e:
            # Report the failure instead of swallowing it, then show ASR as off.
            self.stage_signal.emit(f"语音识别异常：{e}")
            self.asr_status_signal.emit(False, False)

    def on_submit(self):
        """Take the question from the input box and process it on a new thread."""
        if self.busy:
            self.stage_signal.emit("正在播报上一个回答，请稍后...")
            return
        question = self.input_box.text().strip()
        if not question:
            return
        self.input_box.setText("")
        self.append_history("用户", question)
        self.stop_idle_video()
        self.show_stage("开始处理...")
        self.busy = True
        threading.Thread(target=self.process_conversation, args=(question,)).start()

    def process_conversation(self, question):
        """Worker-thread entry point: run the pipeline, then always return to
        idle and clear ``busy`` - also when a step raises, so the UI can
        accept the next question."""
        try:
            self._run_conversation(question)
        except Exception as e:
            # Thread boundary: report the failure in the stage log.
            self.stage_signal.emit(f"处理失败：{e}")
        finally:
            self.idle_signal.emit()
            self.busy = False

    def _run_conversation(self, question):
        """LLM answer -> TTS -> lip-sync inference -> playback, with timings."""
        t0 = time.perf_counter()
        self.stage_signal.emit("等待大模型回复...")
        t1 = time.perf_counter()
        answer = self.bot.chat(question)
        t2 = time.perf_counter()
        self.append_history_signal.emit("助手", answer)
        self.stage_signal.emit(f"大模型回复完成，耗时：{t2-t1:.2f}s")
        self.stage_signal.emit("正在合成语音...")

        if not answer or len(answer.strip()) < 2:
            self.stage_signal.emit("回答内容过短，跳过语音与口型合成。")
            time.sleep(0.5)
            return

        t3 = time.perf_counter()
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            audio_path = loop.run_until_complete(generate_speech(answer))
        except Exception as e:
            self.stage_signal.emit(f"语音合成失败：{e}")
            return
        finally:
            # One loop is created per question; close it so they do not pile up.
            asyncio.set_event_loop(None)
            loop.close()
        t4 = time.perf_counter()
        import os
        if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 800:
            self.stage_signal.emit("语音文件生成失败或内容太短，跳过口型合成。")
            return
        self.stage_signal.emit(f"语音合成完成，耗时：{t4-t3:.2f}s")
        self.stage_signal.emit("正在生成嘴型动画...")

        gen = prepare_audio_batches(audio_path, self.face_img, self.face_coords)
        infer_time, all_frames = self.lip_player.infer_frames(gen)
        self.stage_signal.emit(f"视频帧推理完成，耗时{infer_time:.2f}s")
        play_time = self.lip_player.play_frames(
            audio_path,
            all_frames,
            lambda frame: self.show_frame_signal.emit(frame)
        )
        self.stage_signal.emit(f"嘴型播放完成，耗时{play_time:.2f}s")
        t6 = time.perf_counter()
        self.stage_signal.emit(
            f"阶段总结：LLM:{t2-t1:.2f}s，TTS:{t4-t3:.2f}s，"
            f"视频帧推理:{infer_time:.2f}s，嘴型播放:{play_time:.2f}s，总计:{t6-t0:.2f}s"
        )

if __name__ == "__main__":
    import os
    os.environ["QT_FONT_DPI"] = "96"
    if sys.platform == "win32":
        import ctypes
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("digitalhuman.app")
    # sys.stdout may be None (windowed Python without a console) or a
    # replacement stream that has no reconfigure(); skip the call then
    # instead of crashing before the window is created.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8')
    app = QtWidgets.QApplication(sys.argv)
    win = DigitalHumanUI()
    win.show()
    sys.exit(app.exec_())

In [None]:
import sys
import threading
import asyncio
import time
import numpy as np
import cv2
import os
from PyQt5 import QtWidgets, QtCore, QtGui, QtMultimedia

from config import API_KEY, BASE_URL, WAV2LIP_MODEL_PATH, IDLE_VIDEO_PATH, FACE_IMAGE_PATH, DEVICE
from llm import ChatBot
from tts import generate_speech
from face import load_model, preprocess_image, prepare_audio_batches, LipSyncPlayer
from asr import run_asr_thread

class BubbleTextEdit(QtWidgets.QTextEdit):
    """Read-only text view that renders each chat message as a styled bubble."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setReadOnly(True)
        self.setStyleSheet("""
            QTextEdit {
                background: #181a1b;
                border-radius: 12px;
                padding: 8px;
                font-size: 16px;
                color: #e6e6e6;
                border: 1px solid #232323;
            }
        """)

    def append_bubble(self, text, speaker):
        """Append one message bubble to the view.

        Args:
            text: Message content. It is treated as plain text: HTML special
                characters are escaped and newlines become line breaks.
            speaker: "用户" renders a right-aligned bubble; any other value
                renders a left-aligned bubble.
        """
        import html  # local import keeps this class self-contained

        # The text is typed by the user or produced by the model, so it must
        # not be interpreted as markup by the rich-text engine.
        safe_text = html.escape(str(text)).replace("\n", "<br>")
        if speaker == "用户":
            align, text_color = "right", "#4faaff"
        else:
            align, text_color = "left", "#f7f7f7"
        self.append(
            f"<div style='text-align:{align}; margin:10px;'>"
            f"<span style='background:#2e2f31;color:{text_color};border-radius:14px;"
            f"padding:10px 16px;display:inline-block;'>{safe_text}</span></div>"
        )

class CardFrame(QtWidgets.QFrame):
    """Frame styled as a rounded, semi-transparent dark card."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        card_style = """
            QFrame {
                background: rgba(34,34,34,0.92);
                border-radius: 22px;
                border: 1.5px solid #232323;
            }
        """
        self.setStyleSheet(card_style)

class FaceDisplayWidget(QtWidgets.QLabel):
    """Label used to display the avatar frame: centred, transparent, contents scaled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # These three settings are independent of one another.
        self.setScaledContents(True)
        self.setAlignment(QtCore.Qt.AlignCenter)
        self.setStyleSheet("background:transparent;")

class AudioPlayer(QtCore.QObject):
    """Small wrapper around QMediaPlayer that plays one local audio file at a time.

    Emits ``finished`` when the media status becomes EndOfMedia or InvalidMedia.
    """

    finished = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent)
        media_player = QtMultimedia.QMediaPlayer()
        media_player.setVolume(100)
        media_player.mediaStatusChanged.connect(self.handle_status)
        self.player = media_player

    def play(self, audio_path):
        """Stop playback if something is playing, then play ``audio_path``."""
        player = self.player
        if player.state() == QtMultimedia.QMediaPlayer.PlayingState:
            player.stop()
        absolute_path = os.path.abspath(audio_path)
        media = QtMultimedia.QMediaContent(QtCore.QUrl.fromLocalFile(absolute_path))
        player.setMedia(media)
        player.play()

    def handle_status(self, status):
        """Forward end-of-media / invalid-media status changes as ``finished``."""
        terminal_statuses = (
            QtMultimedia.QMediaPlayer.EndOfMedia,
            QtMultimedia.QMediaPlayer.InvalidMedia,
        )
        if status not in terminal_statuses:
            return
        self.finished.emit()

class DigitalHumanUI(QtWidgets.QWidget):
    append_history_signal = QtCore.pyqtSignal(str, str)
    show_frame_signal = QtCore.pyqtSignal(QtGui.QPixmap)
    idle_signal = QtCore.pyqtSignal()
    stage_signal = QtCore.pyqtSignal(str)
    asr_text_signal = QtCore.pyqtSignal(str)
    asr_status_signal = QtCore.pyqtSignal(bool, bool)
    play_video_frames_signal = QtCore.pyqtSignal(list, str, float)

    def __init__(self):
        """Set up window state, build the UI, load resources, connect signals and go idle."""
        super().__init__()
        self.setWindowTitle("数字人问答演示")
        self.setStyleSheet("QWidget { background: #000; }")
        self.resize(1400, 900)

        # Speech-recognition state.
        self.asr_thread = None
        self.asr_running = False
        self._asr_wake = False

        # Conversation and idle-video state.
        self.busy = False
        self.idle_video_running = threading.Event()
        self.idle_video_thread = None
        self._last_face_size = (0, 0)
        self._last_pixmap = None

        # Audio-synchronised playback state.
        self.video_frames = []
        self.video_frame_count = 0
        self.audio_total_ms = 0
        self.target_fps = 25
        self.audio_player = AudioPlayer()
        self.sync_timer = QtCore.QTimer(self)
        self.sync_timer.timeout.connect(self._sync_frame_with_audio)

        self.init_ui()
        self.init_resources()

        # Connect every signal of this window to the method that handles it.
        wiring = (
            (self.append_history_signal, self.append_history),
            (self.show_frame_signal, self._show_pixmap_mainthread),
            (self.idle_signal, self.show_idle),
            (self.stage_signal, self.show_stage),
            (self.asr_text_signal, self.on_asr_text),
            (self.asr_status_signal, self.update_asr_status),
            (self.play_video_frames_signal, self.play_video_frames),
        )
        for signal, handler in wiring:
            signal.connect(handler)

        self.show_idle()
        self.show_stage("系统待命...")

    def init_ui(self):
        """Build the widget tree.

        Layout, top to bottom: a three-column top panel (chat history, face
        display, stage log), an input row (text box, send button, microphone
        toggle) and a right-aligned speech-recognition status label. Also
        installs ``on_resize`` as this instance's ``resizeEvent`` handler.
        """
        # Root vertical layout and the horizontal top panel.
        self.outer_layout = QtWidgets.QVBoxLayout(self)
        self.outer_layout.setContentsMargins(0, 0, 0, 0)
        self.outer_layout.setSpacing(0)
        self.top_panel = QtWidgets.QWidget(self)
        self.top_layout = QtWidgets.QHBoxLayout(self.top_panel)
        self.top_layout.setContentsMargins(0, 0, 0, 0)
        self.top_layout.setSpacing(0)

        # Left column: chat history.
        self.left_panel = QtWidgets.QWidget(self.top_panel)
        self.left_layout = QtWidgets.QVBoxLayout(self.left_panel)
        self.left_layout.setContentsMargins(24, 24, 12, 24)
        self.left_layout.setSpacing(18)
        self.chat_history = BubbleTextEdit(parent=self.left_panel)
        self.left_layout.addWidget(self.chat_history)
        self.left_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.left_panel)

        # Centre column: face display, centred between two stretches.
        self.center_panel = QtWidgets.QWidget(self.top_panel)
        self.center_panel.setStyleSheet("background:transparent;")
        self.face_label = FaceDisplayWidget(self.center_panel)
        self.center_layout = QtWidgets.QVBoxLayout(self.center_panel)
        self.center_layout.setContentsMargins(0, 0, 0, 0)
        self.center_layout.setAlignment(QtCore.Qt.AlignCenter)
        self.center_layout.addStretch()
        self.center_layout.addWidget(self.face_label, alignment=QtCore.Qt.AlignHCenter | QtCore.Qt.AlignVCenter)
        self.center_layout.addStretch()
        self.top_layout.addWidget(self.center_panel)

        # Right column: read-only stage log inside a card frame.
        self.right_panel = QtWidgets.QWidget(self.top_panel)
        self.right_layout = QtWidgets.QVBoxLayout(self.right_panel)
        self.right_layout.setContentsMargins(12, 24, 24, 24)
        self.right_layout.setSpacing(18)
        self.stage_card = CardFrame(self.right_panel)
        self.stage_label = QtWidgets.QTextEdit(self.stage_card)
        self.stage_label.setReadOnly(True)
        self.stage_label.setAlignment(QtCore.Qt.AlignLeft)
        self.stage_label.setStyleSheet("color:#4faaff; font-size:16px;background:transparent;border:none;")
        self.stage_label.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        vbox = QtWidgets.QVBoxLayout(self.stage_card)
        vbox.addWidget(self.stage_label)
        vbox.setContentsMargins(10, 10, 10, 10)
        self.right_layout.addWidget(self.stage_card)
        self.right_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.right_panel)

        self.outer_layout.addWidget(self.top_panel, stretch=10)

        # Input row: text box, send button and microphone toggle.
        self.input_panel = QtWidgets.QWidget(self)
        input_layout = QtWidgets.QHBoxLayout(self.input_panel)
        input_layout.setContentsMargins(48, 0, 48, 18)
        input_layout.setSpacing(16)
        self.input_box = QtWidgets.QLineEdit(self.input_panel)
        self.input_box.setPlaceholderText("请输入你的问题或说话（支持语音唤醒）...")
        self.input_box.setStyleSheet("""
            font-size:17px; border-radius:12px; padding:10px; background:#181a1b;
            border:2px solid #232323; color:#e6e6e6;
        """)
        self.send_btn = QtWidgets.QPushButton("发送", self.input_panel)
        self.send_btn.setStyleSheet("""
            font-size:17px; padding:10px 28px; background:#4faaff; color:#15181a; border-radius:10px;
            font-weight:bold;
        """)
        self.send_btn.clicked.connect(self.on_submit)
        self.asr_btn = QtWidgets.QPushButton("🎤", self.input_panel)
        self.asr_btn.setStyleSheet("""
            font-size:24px; padding:10px 16px; background:#232323; color:#4faaff; border-radius:50%;
            border:2px solid #232323;
        """)
        self.asr_btn.setCheckable(True)
        self.asr_btn.clicked.connect(self.on_toggle_asr)

        # Second argument is the stretch factor of each widget within the row.
        input_layout.addWidget(self.input_box, 10)
        input_layout.addWidget(self.send_btn, 2)
        input_layout.addWidget(self.asr_btn, 1)
        self.input_panel.setStyleSheet("background:transparent;")
        self.outer_layout.addWidget(self.input_panel, stretch=0)

        # Speech-recognition status line.
        self.asr_status_label = QtWidgets.QLabel("语音识别：关闭 | 唤醒：未唤醒", self)
        self.asr_status_label.setStyleSheet("color:#4faaff; font-size:15px; font-weight:bold; background:transparent;")
        self.asr_status_label.setAlignment(QtCore.Qt.AlignRight)
        self.asr_status_label.setFixedHeight(24)
        self.outer_layout.addWidget(self.asr_status_label, alignment=QtCore.Qt.AlignRight)

        # Equal width for the three columns; the setStretch(0, 15) call below
        # replaces the stretch=10 given when the top panel was added above.
        self.top_layout.setStretch(0, 1)
        self.top_layout.setStretch(1, 1)
        self.top_layout.setStretch(2, 1)
        self.outer_layout.setStretch(0, 15)
        self.outer_layout.setStretch(1, 0)
        self.outer_layout.setStretch(2, 0)
        # Route resize events of this instance to on_resize.
        self.resizeEvent = self.on_resize

    def update_center_panel_geometry(self):
        """Lock the face label to the centre panel's current size and redraw the last frame."""
        panel_size = (self.center_panel.width(), self.center_panel.height())
        self.face_label.setMinimumSize(*panel_size)
        self.face_label.setMaximumSize(*panel_size)
        self._last_face_size = panel_size
        last_pixmap = self._last_pixmap
        if last_pixmap is None:
            return
        # Re-show the most recent frame so it matches the new size.
        self._show_pixmap_mainthread(last_pixmap)

    def on_resize(self, event):
        """Resize handler (installed as ``resizeEvent`` in ``init_ui``): refit the face label."""
        self.update_center_panel_geometry()
        event.accept()

    def init_resources(self):
        """Create the chat bot, load the lip-sync model, preprocess the face image and build the player."""
        bot_options = dict(
            api_key=API_KEY,
            base_url=BASE_URL,
            log_dir="logs",
            default_background="你是一个知识渊博的助手，能够简洁地回答问题。",
            default_prefix="请简洁地回答下述问题：",
        )
        self.bot = ChatBot(**bot_options)
        self.model = load_model(WAV2LIP_MODEL_PATH, DEVICE)
        preprocessed = preprocess_image(FACE_IMAGE_PATH, device=DEVICE)
        self.face_img, self.face_coords, self.orig_image = preprocessed
        self.idle_video_path = IDLE_VIDEO_PATH
        self.lip_player = LipSyncPlayer(
            self.model,
            DEVICE,
            self.orig_image,
            self.face_coords,
            fps=25,
        )

    def show_idle(self):
        """Leave synchronised playback and make sure the idle-video thread is running."""
        self.stop_sync()
        self.idle_video_running.set()
        worker = self.idle_video_thread
        if worker is not None and worker.is_alive():
            # The existing loop keeps going now that the flag is set again.
            return
        self.idle_video_thread = threading.Thread(target=self.play_idle_video, daemon=True)
        self.idle_video_thread.start()

    def play_idle_video(self):
        """Loop the idle video and emit its frames until ``idle_video_running`` is cleared.

        Used as the target of the thread started by ``show_idle``. Each frame is
        converted to RGB, resized to the last known face-label size (300x400
        while that size is still unknown) and emitted through
        ``show_frame_signal`` at roughly 25 frames per second.
        """
        frame_interval = 1.0 / 25
        while self.idle_video_running.is_set():
            cap = cv2.VideoCapture(self.idle_video_path)
            if not cap.isOpened():
                print(f"无法打开idle视频：{self.idle_video_path}")
                cap.release()
                return
            try:
                failed_reads = 0
                while self.idle_video_running.is_set():
                    ret, frame = cap.read()
                    if not ret:
                        # End of file: rewind and try again. If the rewind did
                        # not produce a frame either, pause so an unreadable
                        # file does not turn this into a busy loop.
                        failed_reads += 1
                        if failed_reads > 1:
                            time.sleep(frame_interval)
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        continue
                    failed_reads = 0
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    w, h = self._last_face_size
                    if w < 10 or h < 10:
                        w, h = 300, 400
                    frame = cv2.resize(frame, (w, h))
                    # Pass the row stride explicitly: without it QImage expects
                    # 32-bit aligned scanlines, which a packed RGB array only
                    # has when width * 3 is a multiple of 4.
                    qtimg = QtGui.QImage(
                        frame.data, frame.shape[1], frame.shape[0],
                        frame.strides[0], QtGui.QImage.Format_RGB888,
                    )
                    # NOTE(review): this QPixmap is created on the worker thread;
                    # Qt documents QPixmap as GUI-thread only. Confirm whether the
                    # signal should carry a QImage instead.
                    pix = QtGui.QPixmap.fromImage(qtimg)
                    self._last_pixmap = pix
                    self.show_frame_signal.emit(pix)
                    time.sleep(frame_interval)
            finally:
                # Release the capture even if frame processing raised.
                cap.release()

    def stop_idle_video(self):
        """Clear the run flag that the idle-video loop checks on every iteration."""
        self.idle_video_running.clear()

    def stop_sync(self):
        """Stop the sync timer if it is active and reset the queued playback state."""
        timer = self.sync_timer
        if timer.isActive():
            timer.stop()
        self.video_frames = []
        self.video_frame_count = self.audio_total_ms = 0

    @QtCore.pyqtSlot(QtGui.QPixmap)
    def _show_pixmap_mainthread(self, pix):
        """Show ``pix`` on the face label, scaled to the last recorded label size.

        Null pixmaps are ignored. While no valid size has been recorded yet
        (the initial value is (0, 0)), the pixmap is shown unscaled, because
        scaling to a zero width or height would yield a null pixmap.
        """
        if pix is None or pix.isNull():
            return
        w, h = self._last_face_size
        if w > 0 and h > 0:
            shown = pix.scaled(w, h, QtCore.Qt.IgnoreAspectRatio, QtCore.Qt.SmoothTransformation)
        else:
            shown = pix
        self.face_label.setPixmap(shown)
        self._last_pixmap = pix

    @QtCore.pyqtSlot(str, str)
    def append_history(self, speaker, text):
        """Add one message from ``speaker`` to the chat history view."""
        self.chat_history.append_bubble(text=text, speaker=speaker)

    @QtCore.pyqtSlot(str)
    def show_stage(self, text):
        """Append a status line to the stage log and scroll to the bottom.

        ``text`` is treated as plain text: it may contain exception messages,
        so HTML special characters are escaped before insertion.
        """
        import html  # local import keeps this method self-contained

        safe_text = html.escape(str(text))
        self.stage_label.append(f"<div style='color:#4faaff'>{safe_text}</div>")
        scroll_bar = self.stage_label.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    @QtCore.pyqtSlot(np.ndarray)
    def show_video_frame(self, frame):
        """Convert a BGR frame to a pixmap at the face-label size and emit it for display.

        Uses 300x400 while the label size is still unknown. The pixmap is
        delivered through ``show_frame_signal``.
        """
        w, h = self._last_face_size
        if w < 10 or h < 10:
            w, h = 300, 400
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (w, h))
        # Pass the row stride explicitly: without it QImage expects 32-bit
        # aligned scanlines, which a packed RGB array only has when
        # width * 3 is a multiple of 4.
        qtimg = QtGui.QImage(
            frame.data, frame.shape[1], frame.shape[0],
            frame.strides[0], QtGui.QImage.Format_RGB888,
        )
        pix = QtGui.QPixmap.fromImage(qtimg)
        self._last_pixmap = pix
        self.show_frame_signal.emit(pix)

    @QtCore.pyqtSlot(list, str, float)
    def play_video_frames(self, frames, audio_path, audio_duration):
        """Switch from the idle video to audio-synchronised playback of ``frames``.

        ``audio_duration`` is in seconds; the sync timer fires every 20 ms.
        """
        self.stop_idle_video()  # the idle video is paused only here
        self.stop_sync()
        self.target_fps = 25
        self.video_frames, self.video_frame_count = frames, len(frames)
        self.audio_total_ms = int(audio_duration * 1000)
        self.audio_player.play(audio_path)
        self.sync_timer.start(20)

    def _sync_frame_with_audio(self):
        """Timer tick: show the frame that matches the audio position.

        Stops the timer and returns to idle when the audio position reaches
        the end, the last frame has been shown, or the media turned out to be
        invalid (in which case the position would never advance).
        """
        player = self.audio_player.player
        ms = player.position()
        if ms <= 0:
            # Playback has not advanced yet. If the media could not be loaded
            # it never will, so give up instead of ticking forever.
            if player.mediaStatus() == QtMultimedia.QMediaPlayer.InvalidMedia:
                self.sync_timer.stop()
                self.idle_signal.emit()
            return
        last_idx = self.video_frame_count - 1
        # Frame index for the current audio position, clamped to the last frame.
        idx = min(int(ms * self.target_fps / 1000), last_idx)
        if idx >= 0:
            self.show_video_frame(self.video_frames[idx])
        if ms >= self.audio_total_ms - 20 or idx >= last_idx:
            self.sync_timer.stop()
            self.idle_signal.emit()

    @QtCore.pyqtSlot(str)
    def on_asr_text(self, text):
        """Submit recognised speech as a question, unless a reply is being processed."""
        if self.busy:
            self.stage_signal.emit("正在播报回答，请稍后再提问。")
            return
        recognised = text.strip()
        if not recognised:
            return
        self.input_box.setText(recognised)
        self.on_submit()

    @QtCore.pyqtSlot(bool, bool)
    def update_asr_status(self, running, wake):
        """Record the ASR state and refresh the status label and microphone button."""
        self.asr_running = running
        self._asr_wake = wake

        running_text = '开启' if running else '关闭'
        wake_text = '已唤醒' if wake else '未唤醒'
        # Awake takes precedence; otherwise the colour reflects running/stopped.
        if wake:
            label_color = "#ff5050"
        elif running:
            label_color = "#4faaff"
        else:
            label_color = "#555"

        self.asr_status_label.setText(f"语音识别：{running_text} | 唤醒：{wake_text}")
        self.asr_status_label.setStyleSheet(f"color:{label_color}; font-size:15px; font-weight:bold; background:transparent;")
        self.asr_btn.setChecked(True if running else False)
        self.asr_btn.setText("⏹" if running else "🎤")

    def on_toggle_asr(self):
        """Microphone button handler: stop recognition if running, otherwise start it."""
        action = self._stop_asr if self.asr_running else self._start_asr
        action()

    def _start_asr(self):
        """Mark recognition as running and start the ASR worker thread (no-op if already running)."""
        if not self.asr_running:
            self.asr_running = True
            self.asr_btn.setText("⏹")
            self.asr_status_signal.emit(True, False)
            worker = threading.Thread(target=self.start_asr, daemon=True)
            self.asr_thread = worker
            worker.start()

    def _stop_asr(self):
        """Clear the running flag polled by the ASR worker and update the UI (no-op if stopped)."""
        if self.asr_running:
            self.asr_running = False
            self.asr_btn.setText("🎤")
            self.asr_status_signal.emit(False, False)

    def start_asr(self):
        """Body of the ASR worker thread.

        Calls ``run_asr_thread`` with a callback that reports the wake state
        and forwards recognised text, and with a predicate that returns the
        current ``asr_running`` flag. If it raises, the error is shown in the
        stage log and the status is reset to stopped.
        """
        def asr_callback(text, wake_state):
            self.asr_status_signal.emit(True, wake_state)
            # Only forward text while awake and while no reply is in progress.
            if text and wake_state and not self.busy:
                self.asr_text_signal.emit(text)
        try:
            run_asr_thread(asr_callback, lambda: self.asr_running)
        except Exception as e:
            # Report the failure instead of silently turning recognition off.
            self.stage_signal.emit(f"语音识别出错：{e}")
            self.asr_status_signal.emit(False, False)

    def on_submit(self):
        """Take the question from the input box and process it on a worker thread.

        Ignored (with a status message) while a previous question is still
        being processed, and ignored silently when the input is empty.
        """
        if self.busy:
            self.stage_signal.emit("正在播报上一个回答，请稍后...")
            return
        question = self.input_box.text().strip()
        if not question:
            return
        self.input_box.setText("")
        self.append_history("用户", question)
        # The idle video keeps running here on purpose; it is paused only when
        # playback of the answer starts.
        self.show_stage("开始处理...")
        self.busy = True
        # Daemon thread, like the idle-video and ASR threads, so a pending
        # request does not keep the process alive after the window is closed.
        threading.Thread(
            target=self.process_conversation,
            args=(question,),
            daemon=True,
        ).start()

    def process_conversation(self, question):
        """Run one question through LLM reply, speech synthesis and lip-sync inference.

        Runs on the worker thread started by ``on_submit`` and reports progress
        through signals. On success the frames, audio path and audio duration
        are handed to ``play_video_frames_signal``. Any failure is reported in
        the stage log and the UI returns to idle. ``busy`` is always reset when
        this method returns, so one failure cannot block later questions.

        NOTE(review): ``busy`` is cleared as soon as playback has been
        requested, not when playback ends - confirm this is intended.
        """
        loop = None
        try:
            self.stage_signal.emit("等待大模型回复...")
            t1 = time.perf_counter()
            answer = self.bot.chat(question)
            t2 = time.perf_counter()
            # Only forward a real reply; the history signal is declared for str.
            if answer:
                self.append_history_signal.emit("助手", answer)
            self.stage_signal.emit(f"大模型回复完成，耗时：{t2-t1:.2f}s")
            self.stage_signal.emit("正在合成语音...")

            if not answer or len(answer.strip()) < 2:
                self.stage_signal.emit("回答内容过短，跳过语音与口型合成。")
                time.sleep(0.5)
                self.idle_signal.emit()
                return

            # This thread has no event loop of its own; create one for the TTS
            # coroutine and close it in the finally block below.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            t3 = time.perf_counter()
            try:
                audio_path = loop.run_until_complete(generate_speech(answer))
            except Exception as e:
                self.stage_signal.emit(f"语音合成失败：{e}")
                self.idle_signal.emit()
                return
            t4 = time.perf_counter()
            # Reject a missing file or one smaller than 800 bytes.
            if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 800:
                self.stage_signal.emit("语音文件生成失败或内容太短，跳过口型合成。")
                self.idle_signal.emit()
                return
            import soundfile as sf
            audio_duration = float(sf.info(audio_path).duration)
            self.stage_signal.emit(f"语音合成完成，耗时：{t4-t3:.2f}s")
            self.stage_signal.emit("正在生成嘴型动画...")

            gen = prepare_audio_batches(audio_path, self.face_img, self.face_coords)
            infer_time, all_frames = self.lip_player.infer_frames(gen)
            self.stage_signal.emit(f"视频帧推理完成，耗时{infer_time:.2f}s")
            self.stage_signal.emit("正在播放语音与动画...")

            self.play_video_frames_signal.emit(all_frames, audio_path, audio_duration)
        except Exception as e:
            # Covers failures of the LLM call, audio inspection and inference.
            self.stage_signal.emit(f"处理失败：{e}")
            self.idle_signal.emit()
        finally:
            if loop is not None:
                asyncio.set_event_loop(None)
                loop.close()
            self.busy = False

if __name__ == "__main__":
    # Set the Qt font DPI through the environment before the QApplication exists.
    os.environ["QT_FONT_DPI"] = "96"
    if sys.platform == "win32":
        import ctypes
        # Register an explicit AppUserModelID for this process on Windows.
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("digitalhuman.app")
    # sys.stdout can be None (windowed interpreter) or a replacement stream that
    # has no reconfigure() (e.g. a notebook kernel); only switch when supported.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8')
    app = QtWidgets.QApplication(sys.argv)
    win = DigitalHumanUI()
    win.show()
    sys.exit(app.exec_())

nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnb,wameni


In [None]:
import sys
import threading
import asyncio
import time
import numpy as np
import cv2
import os
from PyQt5 import QtWidgets, QtCore, QtGui, QtMultimedia

from config import API_KEY, BASE_URL, WAV2LIP_MODEL_PATH, IDLE_VIDEO_PATH, FACE_IMAGE_PATH, DEVICE
from llm import ChatBot
from tts import generate_speech
from face import load_model, preprocess_image, prepare_audio_batches, LipSyncPlayer
from asr import run_asr_thread

class BubbleTextEdit(QtWidgets.QTextEdit):
    """Read-only text view that renders each chat message as a styled bubble."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setReadOnly(True)
        self.setStyleSheet("""
            QTextEdit {
                background: #181a1b;
                border-radius: 12px;
                padding: 8px;
                font-size: 16px;
                color: #e6e6e6;
                border: 1px solid #232323;
            }
        """)

    def append_bubble(self, text, speaker):
        """Append one message bubble to the view.

        Args:
            text: Message content. It is treated as plain text: HTML special
                characters are escaped and newlines become line breaks.
            speaker: "用户" renders a right-aligned bubble; any other value
                renders a left-aligned bubble.
        """
        import html  # local import keeps this class self-contained

        # The text is typed by the user or produced by the model, so it must
        # not be interpreted as markup by the rich-text engine.
        safe_text = html.escape(str(text)).replace("\n", "<br>")
        if speaker == "用户":
            align, text_color = "right", "#4faaff"
        else:
            align, text_color = "left", "#f7f7f7"
        self.append(
            f"<div style='text-align:{align}; margin:10px;'>"
            f"<span style='background:#2e2f31;color:{text_color};border-radius:14px;"
            f"padding:10px 16px;display:inline-block;'>{safe_text}</span></div>"
        )

class CardFrame(QtWidgets.QFrame):
    """Frame styled as a rounded, semi-transparent dark card."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        card_style = """
            QFrame {
                background: rgba(34,34,34,0.92);
                border-radius: 22px;
                border: 1.5px solid #232323;
            }
        """
        self.setStyleSheet(card_style)

class FaceDisplayWidget(QtWidgets.QLabel):
    """Label used to display the avatar frame: centred, transparent, contents scaled."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # These three settings are independent of one another.
        self.setScaledContents(True)
        self.setAlignment(QtCore.Qt.AlignCenter)
        self.setStyleSheet("background:transparent;")

class AudioPlayer(QtCore.QObject):
    """Small wrapper around QMediaPlayer that plays one local audio file at a time.

    Emits ``finished`` when the media status becomes EndOfMedia or InvalidMedia.
    """

    finished = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent)
        media_player = QtMultimedia.QMediaPlayer()
        media_player.setVolume(100)
        media_player.mediaStatusChanged.connect(self.handle_status)
        self.player = media_player

    def play(self, audio_path):
        """Stop playback if something is playing, then play ``audio_path``."""
        player = self.player
        if player.state() == QtMultimedia.QMediaPlayer.PlayingState:
            player.stop()
        absolute_path = os.path.abspath(audio_path)
        media = QtMultimedia.QMediaContent(QtCore.QUrl.fromLocalFile(absolute_path))
        player.setMedia(media)
        player.play()

    def handle_status(self, status):
        """Forward end-of-media / invalid-media status changes as ``finished``."""
        terminal_statuses = (
            QtMultimedia.QMediaPlayer.EndOfMedia,
            QtMultimedia.QMediaPlayer.InvalidMedia,
        )
        if status not in terminal_statuses:
            return
        self.finished.emit()

class DigitalHumanUI(QtWidgets.QWidget):
    append_history_signal = QtCore.pyqtSignal(str, str)
    show_frame_signal = QtCore.pyqtSignal(QtGui.QPixmap)
    idle_signal = QtCore.pyqtSignal()
    stage_signal = QtCore.pyqtSignal(str)
    asr_text_signal = QtCore.pyqtSignal(str)
    asr_status_signal = QtCore.pyqtSignal(bool, bool)
    play_video_frames_signal = QtCore.pyqtSignal(list, str, float)

    def __init__(self):
        """Set up window state, build the UI, load resources, connect signals and go idle."""
        super().__init__()
        self.setWindowTitle("数字人问答演示")
        self.setStyleSheet("QWidget { background: #000; }")
        self.resize(1400, 900)

        # Speech-recognition state.
        self.asr_thread = None
        self.asr_running = False
        self._asr_wake = False

        # Conversation and idle-video state.
        self.busy = False
        self.idle_video_running = threading.Event()
        self.idle_video_thread = None
        self._last_face_size = (0, 0)
        self._last_pixmap = None

        # Audio-synchronised playback state.
        self.video_frames = []
        self.video_frame_count = 0
        self.audio_total_ms = 0
        self.target_fps = 25
        self.audio_player = AudioPlayer()
        self.sync_timer = QtCore.QTimer(self)
        self.sync_timer.timeout.connect(self._sync_frame_with_audio)

        self.init_ui()
        self.init_resources()

        # Connect every signal of this window to the method that handles it.
        wiring = (
            (self.append_history_signal, self.append_history),
            (self.show_frame_signal, self._show_pixmap_mainthread),
            (self.idle_signal, self.show_idle),
            (self.stage_signal, self.show_stage),
            (self.asr_text_signal, self.on_asr_text),
            (self.asr_status_signal, self.update_asr_status),
            (self.play_video_frames_signal, self.play_video_frames),
        )
        for signal, handler in wiring:
            signal.connect(handler)

        self.show_idle()
        self.show_stage("系统待命...")

    def init_ui(self):
        """Build the widget tree.

        Layout, top to bottom: a three-column top panel (chat history, face
        display, stage log), an input row (text box, send button, microphone
        toggle) and a right-aligned speech-recognition status label. Also
        installs ``on_resize`` as this instance's ``resizeEvent`` handler.
        """
        # Root vertical layout and the horizontal top panel.
        self.outer_layout = QtWidgets.QVBoxLayout(self)
        self.outer_layout.setContentsMargins(0, 0, 0, 0)
        self.outer_layout.setSpacing(0)
        self.top_panel = QtWidgets.QWidget(self)
        self.top_layout = QtWidgets.QHBoxLayout(self.top_panel)
        self.top_layout.setContentsMargins(0, 0, 0, 0)
        self.top_layout.setSpacing(0)

        # Left column: chat history.
        self.left_panel = QtWidgets.QWidget(self.top_panel)
        self.left_layout = QtWidgets.QVBoxLayout(self.left_panel)
        self.left_layout.setContentsMargins(24, 24, 12, 24)
        self.left_layout.setSpacing(18)
        self.chat_history = BubbleTextEdit(parent=self.left_panel)
        self.left_layout.addWidget(self.chat_history)
        self.left_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.left_panel)

        # Centre column: face display, centred between two stretches.
        self.center_panel = QtWidgets.QWidget(self.top_panel)
        self.center_panel.setStyleSheet("background:transparent;")
        self.face_label = FaceDisplayWidget(self.center_panel)
        self.center_layout = QtWidgets.QVBoxLayout(self.center_panel)
        self.center_layout.setContentsMargins(0, 0, 0, 0)
        self.center_layout.setAlignment(QtCore.Qt.AlignCenter)
        self.center_layout.addStretch()
        self.center_layout.addWidget(self.face_label, alignment=QtCore.Qt.AlignHCenter | QtCore.Qt.AlignVCenter)
        self.center_layout.addStretch()
        self.top_layout.addWidget(self.center_panel)

        # Right column: read-only stage log inside a card frame.
        self.right_panel = QtWidgets.QWidget(self.top_panel)
        self.right_layout = QtWidgets.QVBoxLayout(self.right_panel)
        self.right_layout.setContentsMargins(12, 24, 24, 24)
        self.right_layout.setSpacing(18)
        self.stage_card = CardFrame(self.right_panel)
        self.stage_label = QtWidgets.QTextEdit(self.stage_card)
        self.stage_label.setReadOnly(True)
        self.stage_label.setAlignment(QtCore.Qt.AlignLeft)
        self.stage_label.setStyleSheet("color:#4faaff; font-size:16px;background:transparent;border:none;")
        self.stage_label.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        vbox = QtWidgets.QVBoxLayout(self.stage_card)
        vbox.addWidget(self.stage_label)
        vbox.setContentsMargins(10, 10, 10, 10)
        self.right_layout.addWidget(self.stage_card)
        self.right_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.right_panel)

        self.outer_layout.addWidget(self.top_panel, stretch=10)

        # Input row: text box, send button and microphone toggle.
        self.input_panel = QtWidgets.QWidget(self)
        input_layout = QtWidgets.QHBoxLayout(self.input_panel)
        input_layout.setContentsMargins(48, 0, 48, 18)
        input_layout.setSpacing(16)
        self.input_box = QtWidgets.QLineEdit(self.input_panel)
        self.input_box.setPlaceholderText("请输入你的问题或说话（支持语音唤醒）...")
        self.input_box.setStyleSheet("""
            font-size:17px; border-radius:12px; padding:10px; background:#181a1b;
            border:2px solid #232323; color:#e6e6e6;
        """)
        self.send_btn = QtWidgets.QPushButton("发送", self.input_panel)
        self.send_btn.setStyleSheet("""
            font-size:17px; padding:10px 28px; background:#4faaff; color:#15181a; border-radius:10px;
            font-weight:bold;
        """)
        self.send_btn.clicked.connect(self.on_submit)
        self.asr_btn = QtWidgets.QPushButton("🎤", self.input_panel)
        self.asr_btn.setStyleSheet("""
            font-size:24px; padding:10px 16px; background:#232323; color:#4faaff; border-radius:50%;
            border:2px solid #232323;
        """)
        self.asr_btn.setCheckable(True)
        self.asr_btn.clicked.connect(self.on_toggle_asr)

        # Second argument is the stretch factor of each widget within the row.
        input_layout.addWidget(self.input_box, 10)
        input_layout.addWidget(self.send_btn, 2)
        input_layout.addWidget(self.asr_btn, 1)
        self.input_panel.setStyleSheet("background:transparent;")
        self.outer_layout.addWidget(self.input_panel, stretch=0)

        # Speech-recognition status line.
        self.asr_status_label = QtWidgets.QLabel("语音识别：关闭 | 唤醒：未唤醒", self)
        self.asr_status_label.setStyleSheet("color:#4faaff; font-size:15px; font-weight:bold; background:transparent;")
        self.asr_status_label.setAlignment(QtCore.Qt.AlignRight)
        self.asr_status_label.setFixedHeight(24)
        self.outer_layout.addWidget(self.asr_status_label, alignment=QtCore.Qt.AlignRight)

        # Equal width for the three columns; the setStretch(0, 15) call below
        # replaces the stretch=10 given when the top panel was added above.
        self.top_layout.setStretch(0, 1)
        self.top_layout.setStretch(1, 1)
        self.top_layout.setStretch(2, 1)
        self.outer_layout.setStretch(0, 15)
        self.outer_layout.setStretch(1, 0)
        self.outer_layout.setStretch(2, 0)
        # Route resize events of this instance to on_resize.
        self.resizeEvent = self.on_resize

    def update_center_panel_geometry(self):
        """Lock the face label to the centre panel's current size and redraw the last frame."""
        panel_size = (self.center_panel.width(), self.center_panel.height())
        self.face_label.setMinimumSize(*panel_size)
        self.face_label.setMaximumSize(*panel_size)
        self._last_face_size = panel_size
        last_pixmap = self._last_pixmap
        if last_pixmap is None:
            return
        # Re-show the most recent frame so it matches the new size.
        self._show_pixmap_mainthread(last_pixmap)

    def on_resize(self, event):
        """Resize handler (installed as ``resizeEvent`` in ``init_ui``): refit the face label."""
        self.update_center_panel_geometry()
        event.accept()

    def init_resources(self):
        """Create the chat bot, load the lip-sync model, preprocess the face image and build the player."""
        bot_options = dict(
            api_key=API_KEY,
            base_url=BASE_URL,
            log_dir="logs",
            default_background="你是一个知识渊博的助手，能够简洁地回答问题。",
            default_prefix="请简洁地回答下述问题：",
        )
        self.bot = ChatBot(**bot_options)
        self.model = load_model(WAV2LIP_MODEL_PATH, DEVICE)
        preprocessed = preprocess_image(FACE_IMAGE_PATH, device=DEVICE)
        self.face_img, self.face_coords, self.orig_image = preprocessed
        self.idle_video_path = IDLE_VIDEO_PATH
        self.lip_player = LipSyncPlayer(
            self.model,
            DEVICE,
            self.orig_image,
            self.face_coords,
            fps=25,
        )

    def show_idle(self):
        """Leave synchronised playback and make sure the idle-video thread is running."""
        self.stop_sync()
        self.idle_video_running.set()
        worker = self.idle_video_thread
        if worker is not None and worker.is_alive():
            # The existing loop keeps going now that the flag is set again.
            return
        self.idle_video_thread = threading.Thread(target=self.play_idle_video, daemon=True)
        self.idle_video_thread.start()

    def play_idle_video(self):
        """Loop the idle video and emit its frames until ``idle_video_running`` is cleared.

        Used as the target of the thread started by ``show_idle``. Each frame is
        converted to RGB, resized to the last known face-label size (300x400
        while that size is still unknown) and emitted through
        ``show_frame_signal`` at roughly 25 frames per second.
        """
        frame_interval = 1.0 / 25
        while self.idle_video_running.is_set():
            cap = cv2.VideoCapture(self.idle_video_path)
            if not cap.isOpened():
                print(f"无法打开idle视频：{self.idle_video_path}")
                cap.release()
                return
            try:
                failed_reads = 0
                while self.idle_video_running.is_set():
                    ret, frame = cap.read()
                    if not ret:
                        # End of file: rewind and try again. If the rewind did
                        # not produce a frame either, pause so an unreadable
                        # file does not turn this into a busy loop.
                        failed_reads += 1
                        if failed_reads > 1:
                            time.sleep(frame_interval)
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        continue
                    failed_reads = 0
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    w, h = self._last_face_size
                    if w < 10 or h < 10:
                        w, h = 300, 400
                    frame = cv2.resize(frame, (w, h))
                    # The row stride is passed explicitly so QImage does not
                    # have to assume 32-bit aligned scanlines.
                    qtimg = QtGui.QImage(
                        frame.data, frame.shape[1], frame.shape[0],
                        frame.strides[0], QtGui.QImage.Format_RGB888,
                    )
                    # NOTE(review): this QPixmap is created on the worker thread;
                    # Qt documents QPixmap as GUI-thread only. Confirm whether the
                    # signal should carry a QImage instead.
                    pix = QtGui.QPixmap.fromImage(qtimg)
                    self._last_pixmap = pix
                    self.show_frame_signal.emit(pix)
                    time.sleep(frame_interval)
            finally:
                # Release the capture even if frame processing raised.
                cap.release()

    def stop_idle_video(self):
        """Clear the run flag that the idle-video loop checks on every iteration."""
        self.idle_video_running.clear()

    def stop_sync(self):
        """Stop the sync timer if it is active and reset the queued playback state."""
        timer = self.sync_timer
        if timer.isActive():
            timer.stop()
        self.video_frames = []
        self.video_frame_count = self.audio_total_ms = 0

    @QtCore.pyqtSlot(QtGui.QPixmap)
    def _show_pixmap_mainthread(self, pix):
        """Show ``pix`` on the face label, scaled to the last recorded label size.

        Null pixmaps are ignored. While no valid size has been recorded yet
        (the initial value is (0, 0)), the pixmap is shown unscaled, because
        scaling to a zero width or height would yield a null pixmap.
        """
        if pix is None or pix.isNull():
            return
        w, h = self._last_face_size
        if w > 0 and h > 0:
            shown = pix.scaled(w, h, QtCore.Qt.IgnoreAspectRatio, QtCore.Qt.SmoothTransformation)
        else:
            shown = pix
        self.face_label.setPixmap(shown)
        self._last_pixmap = pix

    @QtCore.pyqtSlot(str, str)
    def append_history(self, speaker, text):
        """Add one message from ``speaker`` to the chat history view."""
        self.chat_history.append_bubble(text=text, speaker=speaker)

    @QtCore.pyqtSlot(str)
    def show_stage(self, text):
        """Append a status line to the stage log and scroll to the bottom.

        ``text`` is treated as plain text: it may contain exception messages,
        so HTML special characters are escaped before insertion.
        """
        import html  # local import keeps this method self-contained

        safe_text = html.escape(str(text))
        self.stage_label.append(f"<div style='color:#4faaff'>{safe_text}</div>")
        scroll_bar = self.stage_label.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    @QtCore.pyqtSlot(np.ndarray)
    def show_video_frame(self, frame):
        """Convert ``frame`` to RGB, fit it to the face label and display it.

        Accepts single-channel, 3-channel and 4-channel arrays. Uses 300x400
        while the label size is still unknown.
        """
        w, h = self._last_face_size
        if w < 10 or h < 10:
            # Label size not recorded yet (initially (0, 0)); cv2.resize cannot
            # take an empty target size. Same fallback as play_idle_video.
            w, h = 300, 400
        # Make sure the frame is 3-channel RGB before wrapping it as RGB888.
        if len(frame.shape) == 2 or frame.shape[2] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        elif frame.shape[2] == 4:
            # Assumed BGRA, matching the BGR assumption for 3-channel input.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        frame = cv2.resize(frame, (w, h))
        qtimg = QtGui.QImage(frame.data, frame.shape[1], frame.shape[0], frame.strides[0], QtGui.QImage.Format_RGB888)
        pix = QtGui.QPixmap.fromImage(qtimg)
        self.face_label.setPixmap(pix)
        self._last_pixmap = pix

    @QtCore.pyqtSlot(list, str, float)
    def play_video_frames(self, frames, audio_path, audio_duration):
        """Switch from the idle video to audio-synchronised playback of ``frames``.

        ``audio_duration`` is in seconds; the sync timer fires every 20 ms.
        """
        self.stop_idle_video()  # the idle video is paused only here
        self.stop_sync()
        self.target_fps = 25
        self.video_frames, self.video_frame_count = frames, len(frames)
        self.audio_total_ms = int(audio_duration * 1000)
        self.audio_player.play(audio_path)
        self.sync_timer.start(20)

    def _sync_frame_with_audio(self):
        """Show the frame matching the audio position; finish at the end.

        When the audio position reaches the stored duration (minus 20 ms) or
        the last frame has been reached, the sync timer is stopped and
        ``idle_signal`` is emitted.
        """
        position_ms = self.audio_player.player.position()
        if position_ms <= 0:
            # NOTE(review): nothing is done while the reported position is 0,
            # so if playback never starts this keeps returning here -- confirm
            # whether a failed/invalid media state needs handling.
            return
        last_index = self.video_frame_count - 1
        frame_index = min(int(position_ms * self.target_fps / 1000), last_index)
        if 0 <= frame_index <= last_index:
            self.show_video_frame(self.video_frames[frame_index])
        audio_done = position_ms >= self.audio_total_ms - 20
        if audio_done or frame_index >= last_index:
            self.sync_timer.stop()
            self.idle_signal.emit()

    # ----------- 语音识别部分（参考修正版） -----------
    @QtCore.pyqtSlot(str)
    def on_asr_text(self, text):
        """Submit recognized speech as a question unless the UI is busy.

        Args:
            text: Recognized text; surrounding whitespace is stripped and
                empty results are ignored.
        """
        if self.busy:
            self.stage_signal.emit("正在播报回答，请稍后再提问。")
            return
        recognized = text.strip()
        if not recognized:
            return
        self.input_box.setText(recognized)
        self.on_submit()

    @QtCore.pyqtSlot(bool, bool)
    def update_asr_status(self, running, wake):
        """Record the ASR state and refresh the status label and mic button.

        Args:
            running: Whether speech recognition is on.
            wake: Whether the wake state is active.
        """
        self.asr_running = running
        self._asr_wake = wake

        run_text = '开启' if running else '关闭'
        wake_text = '已唤醒' if wake else '未唤醒'
        # The wake colour wins; otherwise the colour reflects the running state.
        if wake:
            label_color = "#ff5050"
        elif running:
            label_color = "#4faaff"
        else:
            label_color = "#555"

        status_label = self.asr_status_label
        status_label.setText(f"语音识别：{run_text} | 唤醒：{wake_text}")
        status_label.setStyleSheet(f"color:{label_color}; font-size:15px; font-weight:bold; background:transparent;")

        self.asr_btn.setChecked(True if running else False)
        self.asr_btn.setText("⏹" if running else "🎤")

    def on_toggle_asr(self):
        """Toggle speech recognition: stop it when running, start it otherwise."""
        action = self._stop_asr if self.asr_running else self._start_asr
        action()

    def _start_asr(self):
        """Mark ASR as running, publish the status and start its daemon thread.

        Does nothing when ASR is already flagged as running.
        """
        if not self.asr_running:
            self.asr_running = True
            self.asr_btn.setText("⏹")
            self.asr_status_signal.emit(True, False)
            worker = threading.Thread(target=self.start_asr, daemon=True)
            self.asr_thread = worker
            worker.start()

    def _stop_asr(self):
        """Mark ASR as stopped and publish the stopped status.

        Does nothing when ASR is not flagged as running.
        """
        if self.asr_running:
            self.asr_running = False
            self.asr_btn.setText("🎤")
            self.asr_status_signal.emit(False, False)

    def start_asr(self):
        """Run the speech-recognition loop; used as the ASR thread target.

        ``run_asr_thread`` receives a result callback and a callable returning
        the current ``asr_running`` flag. Failures are reported on the stage
        view and the status is reset to stopped.
        """
        def asr_callback(text, wake_state):
            # Drop results that arrive after recognition was switched off:
            # emitting a "running" status here would set ``asr_running`` back
            # to True in ``update_asr_status`` and silently re-enable ASR.
            if not self.asr_running:
                return
            self.asr_status_signal.emit(True, wake_state)
            if text and wake_state and not self.busy:
                self.asr_text_signal.emit(text)
        try:
            run_asr_thread(asr_callback, lambda: self.asr_running)
        except Exception as e:
            # Surface the failure instead of stopping without any feedback.
            self.stage_signal.emit(f"语音识别异常：{e}")
            self.asr_status_signal.emit(False, False)
    # ----------- 语音识别部分 END -----------

    def on_submit(self):
        """Take the question from the input box and process it on a worker thread.

        Empty input is ignored. While a previous question is still being
        handled, only a status message is emitted.
        """
        if self.busy:
            self.stage_signal.emit("正在播报上一个回答，请稍后...")
            return
        question = self.input_box.text().strip()
        if not question:
            return
        self.input_box.setText("")
        self.append_history("用户", question)
        self.show_stage("开始处理...")
        self.busy = True
        # Daemon, like the ASR worker thread, so a request that is still in
        # flight cannot keep the process alive after the UI has exited.
        worker = threading.Thread(target=self.process_conversation, args=(question,), daemon=True)
        worker.start()

    def process_conversation(self, question):
        """Produce and play the assistant's reply to ``question``.

        Steps: ask the chat bot, synthesize speech for the answer, generate
        lip-sync frames, then hand frames and audio over through
        ``play_video_frames_signal``. Progress and failures are reported via
        ``stage_signal``; every early exit emits ``idle_signal``. ``busy`` is
        cleared on every exit path, including unexpected errors, so the UI can
        never get stuck in the busy state.

        Args:
            question: The user's question text.
        """
        loop = None
        try:
            self.stage_signal.emit("等待大模型回复...")
            t1 = time.perf_counter()
            answer = self.bot.chat(question)
            t2 = time.perf_counter()
            # The signal is declared (str, str); guard against a None reply.
            self.append_history_signal.emit("助手", answer or "")
            self.stage_signal.emit(f"大模型回复完成，耗时：{t2-t1:.2f}s")
            self.stage_signal.emit("正在合成语音...")

            if not answer or len(answer.strip()) < 2:
                self.stage_signal.emit("回答内容过短，跳过语音与口型合成。")
                time.sleep(0.5)
                self.idle_signal.emit()
                return

            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            t3 = time.perf_counter()
            try:
                audio_path = loop.run_until_complete(generate_speech(answer))
            except Exception as e:
                self.stage_signal.emit(f"语音合成失败：{e}")
                self.idle_signal.emit()
                return
            t4 = time.perf_counter()
            if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 800:
                self.stage_signal.emit("语音文件生成失败或内容太短，跳过口型合成。")
                self.idle_signal.emit()
                return
            import soundfile as sf
            audio_duration = float(sf.info(audio_path).duration)
            self.stage_signal.emit(f"语音合成完成，耗时：{t4-t3:.2f}s")
            self.stage_signal.emit("正在生成嘴型动画...")

            gen = prepare_audio_batches(audio_path, self.face_img, self.face_coords)
            infer_time, all_frames = self.lip_player.infer_frames(gen)
            self.stage_signal.emit(f"视频帧推理完成，耗时{infer_time:.2f}s")
            self.stage_signal.emit("正在播放语音与动画...")

            self.play_video_frames_signal.emit(all_frames, audio_path, audio_duration)
        except Exception as e:
            # Worker-thread boundary: report the failure and return to idle
            # rather than letting the thread die with ``busy`` still set.
            self.stage_signal.emit(f"处理失败：{e}")
            self.idle_signal.emit()
        finally:
            if loop is not None:
                # The loop is created per call; close it so it is not leaked.
                loop.close()
            # NOTE(review): busy is cleared once playback has been requested,
            # not when playback finishes -- confirm this is intended.
            self.busy = False

if __name__ == "__main__":
    # Script entry point: configure the environment, then run the Qt app.
    os.environ["QT_FONT_DPI"] = "96"
    if sys.platform == "win32":
        import ctypes
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("digitalhuman.app")
    # sys.stdout is None when there is no console (e.g. pythonw) and may be
    # replaced by an object without reconfigure(); only switch the encoding
    # when that is actually possible.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8')
    app = QtWidgets.QApplication(sys.argv)
    win = DigitalHumanUI()
    win.show()
    sys.exit(app.exec_())

In [None]:
import sys
import threading
import asyncio
import time
import numpy as np
import cv2
import os
from PyQt5 import QtWidgets, QtCore, QtGui, QtMultimedia

from config import (
    API_KEY, BASE_URL, WAV2LIP_MODEL_PATH, IDLE_VIDEO_PATH, FACE_IMAGE_PATH, DEVICE,
    CHAT_FONT_SIZE,  # 左侧聊天区字体大小
    STAGE_FONT_SIZE  # 右侧右栏状态区字体大小
)
from llm import ChatBot
from tts import generate_speech
from face import load_model, preprocess_image, prepare_audio_batches, LipSyncPlayer
from asr import run_asr_thread

class BubbleTextEdit(QtWidgets.QTextEdit):
    """Read-only text view that appends chat messages as styled bubbles."""

    def __init__(self, *args, font_size=16, **kwargs):
        """Create the view.

        Args:
            *args: Forwarded to ``QTextEdit``.
            font_size: Font size used for both the widget font and the
                stylesheet ``font-size``.
            **kwargs: Forwarded to ``QTextEdit``.
        """
        super().__init__(*args, **kwargs)
        self.setReadOnly(True)
        self.setFont(QtGui.QFont("微软雅黑", font_size))
        # White background with a light scroll bar.
        self.setStyleSheet(f"""
            QTextEdit {{
                background: #ffffff;
                border-radius: 18px;
                padding: 12px 10px 12px 10px;
                font-size: {font_size}px;
                color: #222;
                border: 1.5px solid #e2e2e2;
            }}
            QScrollBar:vertical {{
                background: #f5f6fa;
                width: 10px;
                margin: 2px 0 2px 0;
                border-radius: 4px;
            }}
            QScrollBar::handle:vertical {{
                background: #e2e2e2;
                border-radius: 5px;
                min-height: 20px;
            }}
        """)

    def append_bubble(self, text, speaker="用户"):
        """Append ``text`` as a bubble; the look depends on ``speaker``.

        Bubbles for "用户" are right-aligned with a blue palette; any other
        speaker gets a left-aligned grey bubble.

        Args:
            text: Plain-text message. It is HTML-escaped before insertion so
                characters such as '<' and '&' are shown literally.
            speaker: Name of the speaker.
        """
        from html import escape

        # NOTE(review): Qt rich text supports only a subset of CSS -- confirm
        # which of the properties below actually take effect.
        if speaker == "用户":
            align, gradient, color, radius, border, shadow = (
                "right", "#7ed6ff 0%, #81ecec 100%", "#2176ae",
                "20px 4px 20px 20px", "#b6eaff", "rgba(126,214,255,0.11)",
            )
        else:
            align, gradient, color, radius, border, shadow = (
                "left", "#f5f6fa 0%, #e9e9e9 100%", "#222",
                "4px 20px 20px 20px", "#e2e2e2", "rgba(225,225,225,0.10)",
            )
        self.append(
            f"<div style='margin:12px 0; text-align:{align};'>"
            "<span style=\""
            f"background: linear-gradient(135deg, {gradient});"
            f"color:{color};"
            f"border-radius:{radius};"
            f"border: 1.5px solid {border};"
            f"box-shadow: 0 4px 18px {shadow};"
            "padding:14px 26px;"
            "font-weight:500;"
            "display:inline-block;"
            "max-width:67%;"
            "line-height:1.8;"
            "word-break:break-all;"
            "transition: background 0.2s;"
            "\">"
            f"{escape(str(text))}"
            "</span>"
            "</div>")

class BubbleStageEdit(QtWidgets.QTextEdit):
    """Bubble text box for the right-hand status area."""

    def __init__(self, *args, font_size=16, **kwargs):
        """Create the view.

        Args:
            *args: Forwarded to ``QTextEdit``.
            font_size: Font size used for both the widget font and the
                stylesheet ``font-size``.
            **kwargs: Forwarded to ``QTextEdit``.
        """
        super().__init__(*args, **kwargs)
        self.setReadOnly(True)
        self.setFont(QtGui.QFont("微软雅黑", font_size))
        self.setStyleSheet(f"""
            QTextEdit {{
                background: #ffffff;
                border-radius: 18px;
                padding: 12px 10px 12px 10px;
                font-size: {font_size}px;
                color: #2176ae;
                border: 1.5px solid #e2e2e2;
            }}
            QScrollBar:vertical {{
                background: #f5f6fa;
                width: 10px;
                margin: 2px 0 2px 0;
                border-radius: 4px;
            }}
            QScrollBar::handle:vertical {{
                background: #e2e2e2;
                border-radius: 5px;
                min-height: 20px;
            }}
        """)

    def append_bubble(self, text, speaker="助手"):
        """Append ``text`` as a bubble; the look depends on ``speaker``.

        Bubbles for "助手" are left-aligned with a blue palette; any other
        speaker gets a right-aligned grey bubble.

        Args:
            text: Plain-text message. It is HTML-escaped before insertion so
                characters such as '<' and '&' are shown literally.
            speaker: Name of the speaker.
        """
        from html import escape

        # NOTE(review): Qt rich text supports only a subset of CSS -- confirm
        # which of the properties below actually take effect.
        if speaker == "助手":
            align, gradient, color, radius, border, shadow = (
                "left", "#eaf6ff 0%, #d0e2ff 100%", "#2176ae",
                "4px 20px 20px 20px", "#b1d2fa", "rgba(161,206,255,0.13)",
            )
        else:
            align, gradient, color, radius, border, shadow = (
                "right", "#f5f6fa 0%, #e9e9e9 100%", "#222",
                "20px 4px 20px 20px", "#e2e2e2", "rgba(225,225,225,0.10)",
            )
        self.append(
            f"<div style='margin:12px 0; text-align:{align};'>"
            "<span style=\""
            f"background: linear-gradient(135deg, {gradient});"
            f"color:{color};"
            f"border-radius:{radius};"
            f"border: 1.5px solid {border};"
            f"box-shadow: 0 4px 18px {shadow};"
            "padding:14px 26px;"
            "font-weight:500;"
            "display:inline-block;"
            "max-width:75%;"
            "line-height:1.8;"
            "word-break:break-all;"
            "transition: background 0.2s;"
            "\">"
            f"{escape(str(text))}"
            "</span>"
            "</div>")

class CardFrame(QtWidgets.QFrame):
    """QFrame with a dark, rounded card stylesheet applied at construction."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``QFrame`` and apply the card style."""
        super().__init__(*args, **kwargs)
        card_style = """
            QFrame {
                background: rgba(34,34,34,0.92);
                border-radius: 22px;
                border: 1.5px solid #232323;
            }
        """
        self.setStyleSheet(card_style)

class FaceDisplayWidget(QtWidgets.QLabel):
    """QLabel configured for the face image: transparent, centered, scaled contents."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``QLabel`` and apply the display settings."""
        super().__init__(*args, **kwargs)
        self.setScaledContents(True)
        self.setAlignment(QtCore.Qt.AlignCenter)
        self.setStyleSheet("background:transparent;")

class AudioPlayer(QtCore.QObject):
    """Small wrapper around ``QMediaPlayer`` for playing local audio files.

    ``finished`` is emitted when the media status becomes ``EndOfMedia`` or
    ``InvalidMedia``.
    """
    finished = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        """Create the media player at full volume and watch its status."""
        super().__init__(parent)
        self.player = QtMultimedia.QMediaPlayer()
        media_player = self.player
        media_player.setVolume(100)
        media_player.mediaStatusChanged.connect(self.handle_status)

    def play(self, audio_path):
        """Play the file at ``audio_path``, stopping current playback first."""
        media_player = self.player
        currently_playing = media_player.state() == QtMultimedia.QMediaPlayer.PlayingState
        if currently_playing:
            media_player.stop()
        absolute_path = os.path.abspath(audio_path)
        content = QtMultimedia.QMediaContent(QtCore.QUrl.fromLocalFile(absolute_path))
        media_player.setMedia(content)
        media_player.play()

    def handle_status(self, status):
        """Emit ``finished`` when ``status`` is end-of-media or invalid-media."""
        terminal_states = (
            QtMultimedia.QMediaPlayer.EndOfMedia,
            QtMultimedia.QMediaPlayer.InvalidMedia,
        )
        if status in terminal_states:
            self.finished.emit()

class DigitalHumanUI(QtWidgets.QWidget):
    """Main window of the digital-human Q&A demo.

    The class-level signals below are connected to slots of this class in
    ``__init__``; code running on worker threads emits them to request UI
    updates.
    """
    # (speaker, text) -> append_history
    append_history_signal = QtCore.pyqtSignal(str, str)
    # pixmap -> _show_pixmap_mainthread
    show_frame_signal = QtCore.pyqtSignal(QtGui.QPixmap)
    # no arguments -> show_idle
    idle_signal = QtCore.pyqtSignal()
    # status text -> show_stage
    stage_signal = QtCore.pyqtSignal(str)
    # recognized text -> on_asr_text
    asr_text_signal = QtCore.pyqtSignal(str)
    # (running, wake) -> update_asr_status
    asr_status_signal = QtCore.pyqtSignal(bool, bool)
    # (frames, audio_path, audio_duration) -> play_video_frames
    play_video_frames_signal = QtCore.pyqtSignal(list, str, float)

    def __init__(self):
        """Set up state, build the UI, load resources, wire signals, go idle."""
        super().__init__()
        self.setWindowTitle("数字人问答演示")
        self.setStyleSheet("QWidget { background: #000; }")
        self.resize(1400, 900)

        # Speech-recognition state.
        self.asr_running = False
        self.asr_thread = None
        self._asr_wake = False
        # Set while a question is being processed.
        self.busy = False
        # Idle-video worker thread and its run flag.
        self.idle_video_thread = None
        self.idle_video_running = threading.Event()
        # Last shown pixmap and the size it is displayed at.
        self._last_pixmap = None
        self._last_face_size = (0, 0)
        # Audio playback plus the timer that keeps frames in step with it.
        self.audio_player = AudioPlayer()
        self.sync_timer = QtCore.QTimer(self)
        self.sync_timer.timeout.connect(self._sync_frame_with_audio)
        self.video_frames = []
        self.video_frame_count = 0
        self.target_fps = 25
        self.audio_total_ms = 0

        self.init_ui()
        self.init_resources()

        wiring = (
            (self.append_history_signal, self.append_history),
            (self.show_frame_signal, self._show_pixmap_mainthread),
            (self.idle_signal, self.show_idle),
            (self.stage_signal, self.show_stage),
            (self.asr_text_signal, self.on_asr_text),
            (self.asr_status_signal, self.update_asr_status),
            (self.play_video_frames_signal, self.play_video_frames),
        )
        for signal, slot in wiring:
            signal.connect(slot)

        self.show_idle()
        self.show_stage("系统待命...")

    def init_ui(self):
        """Build the widget tree and layouts.

        Layout: a top row with three equally stretched columns (chat history,
        face display, stage/status view), an input row (text box, send button,
        speech-recognition toggle) and a right-aligned ASR status label.
        Finally ``on_resize`` is installed as this widget's ``resizeEvent``.
        """
        # Outer vertical layout and the top row container.
        self.outer_layout = QtWidgets.QVBoxLayout(self)
        self.outer_layout.setContentsMargins(0, 0, 0, 0)
        self.outer_layout.setSpacing(0)
        self.top_panel = QtWidgets.QWidget(self)
        self.top_layout = QtWidgets.QHBoxLayout(self.top_panel)
        self.top_layout.setContentsMargins(0, 0, 0, 0)
        self.top_layout.setSpacing(0)

        # Left column: chat history.
        self.left_panel = QtWidgets.QWidget(self.top_panel)
        self.left_layout = QtWidgets.QVBoxLayout(self.left_panel)
        self.left_layout.setContentsMargins(24, 24, 12, 24)
        self.left_layout.setSpacing(18)
        self.chat_history = BubbleTextEdit(parent=self.left_panel, font_size=CHAT_FONT_SIZE)
        self.left_layout.addWidget(self.chat_history)
        self.left_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.left_panel)

        # Center column: face display, centered between two stretches.
        self.center_panel = QtWidgets.QWidget(self.top_panel)
        self.center_panel.setStyleSheet("background:transparent;")
        self.face_label = FaceDisplayWidget(self.center_panel)
        self.center_layout = QtWidgets.QVBoxLayout(self.center_panel)
        self.center_layout.setContentsMargins(0, 0, 0, 0)
        self.center_layout.setAlignment(QtCore.Qt.AlignCenter)
        self.center_layout.addStretch()
        self.center_layout.addWidget(self.face_label, alignment=QtCore.Qt.AlignHCenter | QtCore.Qt.AlignVCenter)
        self.center_layout.addStretch()
        self.top_layout.addWidget(self.center_panel)

        # Right column: stage/status view inside a card frame.
        self.right_panel = QtWidgets.QWidget(self.top_panel)
        self.right_layout = QtWidgets.QVBoxLayout(self.right_panel)
        self.right_layout.setContentsMargins(12, 24, 24, 24)
        self.right_layout.setSpacing(18)
        self.stage_card = CardFrame(self.right_panel)
        # The right side uses a bubble text view.
        self.stage_label = BubbleStageEdit(self.stage_card, font_size=STAGE_FONT_SIZE)
        self.stage_label.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        vbox = QtWidgets.QVBoxLayout(self.stage_card)
        vbox.addWidget(self.stage_label)
        vbox.setContentsMargins(10, 10, 10, 10)
        self.right_layout.addWidget(self.stage_card)
        self.right_panel.setStyleSheet("background:transparent;")
        self.top_layout.addWidget(self.right_panel)

        self.outer_layout.addWidget(self.top_panel, stretch=10)

        # Input row: text box, send button and speech-recognition toggle.
        self.input_panel = QtWidgets.QWidget(self)
        input_layout = QtWidgets.QHBoxLayout(self.input_panel)
        input_layout.setContentsMargins(48, 0, 48, 18)
        input_layout.setSpacing(16)
        self.input_box = QtWidgets.QLineEdit(self.input_panel)
        self.input_box.setPlaceholderText("请输入你的问题或说话（支持语音唤醒）...")
        self.input_box.setStyleSheet("""
            font-size:17px; border-radius:12px; padding:10px; background:#181a1b;
            border:2px solid #232323; color:#e6e6e6;
        """)
        self.send_btn = QtWidgets.QPushButton("发送", self.input_panel)
        self.send_btn.setStyleSheet("""
            font-size:17px; padding:10px 28px; background:#4faaff; color:#15181a; border-radius:10px;
            font-weight:bold;
        """)
        self.send_btn.clicked.connect(self.on_submit)
        self.asr_btn = QtWidgets.QPushButton("🎤", self.input_panel)
        self.asr_btn.setStyleSheet("""
            font-size:24px; padding:10px 16px; background:#232323; color:#4faaff; border-radius:50%;
            border:2px solid #232323;
        """)
        self.asr_btn.setCheckable(True)
        self.asr_btn.clicked.connect(self.on_toggle_asr)

        input_layout.addWidget(self.input_box, 10)
        input_layout.addWidget(self.send_btn, 2)
        input_layout.addWidget(self.asr_btn, 1)
        self.input_panel.setStyleSheet("background:transparent;")
        self.outer_layout.addWidget(self.input_panel, stretch=0)

        # ASR status line at the bottom right.
        self.asr_status_label = QtWidgets.QLabel("语音识别：关闭 | 唤醒：未唤醒", self)
        self.asr_status_label.setStyleSheet("color:#4faaff; font-size:15px; font-weight:bold; background:transparent;")
        self.asr_status_label.setAlignment(QtCore.Qt.AlignRight)
        self.asr_status_label.setFixedHeight(24)
        self.outer_layout.addWidget(self.asr_status_label, alignment=QtCore.Qt.AlignRight)

        # Equal width for the three columns; the top row takes the vertical stretch.
        self.top_layout.setStretch(0, 1)
        self.top_layout.setStretch(1, 1)
        self.top_layout.setStretch(2, 1)
        self.outer_layout.setStretch(0, 15)
        self.outer_layout.setStretch(1, 0)
        self.outer_layout.setStretch(2, 0)
        # Route resize events to on_resize by replacing the instance attribute.
        self.resizeEvent = self.on_resize

    def update_center_panel_geometry(self):
        """Pin the face label to the center panel's size and redraw the last pixmap."""
        panel = self.center_panel
        size = (panel.width(), panel.height())
        # Same minimum and maximum size fixes the label at exactly that size.
        self.face_label.setMinimumSize(*size)
        self.face_label.setMaximumSize(*size)
        self._last_face_size = size
        cached = self._last_pixmap
        if cached is not None:
            self._show_pixmap_mainthread(cached)

    def on_resize(self, event):
        """Resize handler: refresh the face display geometry, then accept ``event``."""
        refresh_geometry = self.update_center_panel_geometry
        refresh_geometry()
        event.accept()

    def init_resources(self):
        """Create the chat bot, load the lip-sync model and prepare the face image."""
        bot_options = dict(
            api_key=API_KEY,
            base_url=BASE_URL,
            log_dir="logs",
            default_background="你是一个知识渊博的助手，能够简洁地回答问题。",
            default_prefix="请简洁地回答下述问题：",
        )
        self.bot = ChatBot(**bot_options)
        self.model = load_model(WAV2LIP_MODEL_PATH, DEVICE)
        face_img, face_coords, orig_image = preprocess_image(FACE_IMAGE_PATH, device=DEVICE)
        self.face_img, self.face_coords, self.orig_image = face_img, face_coords, orig_image
        self.idle_video_path = IDLE_VIDEO_PATH
        self.lip_player = LipSyncPlayer(self.model, DEVICE, orig_image, face_coords, fps=25)

    def show_idle(self):
        """Return to idle: stop lip-sync playback and make sure the idle loop runs."""
        self.stop_sync()
        self.idle_video_running.set()
        worker = self.idle_video_thread
        if worker is not None and worker.is_alive():
            # The existing idle thread is still running.
            return
        self.idle_video_thread = threading.Thread(target=self.play_idle_video, daemon=True)
        self.idle_video_thread.start()

    def play_idle_video(self):
        """Loop the idle video, emitting each frame as a pixmap, until stopped.

        Thread target started by ``show_idle``; it runs while
        ``idle_video_running`` is set. The video capture is released on every
        exit path, including errors.
        """
        frame_interval = 1.0 / 25
        while self.idle_video_running.is_set():
            cap = cv2.VideoCapture(self.idle_video_path)
            try:
                if not cap.isOpened():
                    print(f"无法打开idle视频：{self.idle_video_path}")
                    return
                just_rewound = False
                while self.idle_video_running.is_set():
                    ret, frame = cap.read()
                    if not ret:
                        # No frame: rewind to the start. If the read right
                        # after a rewind fails too, wait a frame interval so
                        # an unreadable source does not spin the CPU.
                        if just_rewound:
                            time.sleep(frame_interval)
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        just_rewound = True
                        continue
                    just_rewound = False
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    w, h = self._last_face_size
                    if w < 10 or h < 10:
                        # Display size not known yet; use a default.
                        w, h = 300, 400
                    frame = cv2.resize(frame, (w, h))
                    qtimg = QtGui.QImage(frame.data, frame.shape[1], frame.shape[0], frame.strides[0], QtGui.QImage.Format_RGB888)
                    # NOTE(review): this QPixmap is created on a worker thread;
                    # Qt documents pixmaps as GUI-thread objects -- consider
                    # emitting a QImage and converting in the slot.
                    pix = QtGui.QPixmap.fromImage(qtimg)
                    self._last_pixmap = pix
                    self.show_frame_signal.emit(pix)
                    time.sleep(frame_interval)
            finally:
                cap.release()

    def stop_idle_video(self):
        """Clear the idle-video run flag so the idle playback loop exits."""
        run_flag = self.idle_video_running
        run_flag.clear()

    def stop_sync(self):
        """Stop the audio/frame sync timer and reset the queued playback state."""
        timer = self.sync_timer
        if timer.isActive():
            timer.stop()
        # Drop the previous clip so nothing stale is left to draw.
        self.video_frames, self.video_frame_count, self.audio_total_ms = [], 0, 0

    @QtCore.pyqtSlot(QtGui.QPixmap)
    def _show_pixmap_mainthread(self, pix):
        """Show ``pix`` on the face label, scaled to the last recorded face size.

        The unscaled pixmap is remembered in ``_last_pixmap``.

        Args:
            pix: Pixmap to display.
        """
        w, h = self._last_face_size
        if w > 0 and h > 0:
            shown = pix.scaled(w, h, QtCore.Qt.IgnoreAspectRatio, QtCore.Qt.SmoothTransformation)
        else:
            # QPixmap.scaled() returns a null pixmap for a zero or negative
            # size, which would blank the label; show the pixmap as-is instead.
            shown = pix
        self.face_label.setPixmap(shown)
        self._last_pixmap = pix

    @QtCore.pyqtSlot(str, str)
    def append_history(self, speaker, text):
        """Add a bubble with ``text`` attributed to ``speaker`` to the chat history."""
        history_view = self.chat_history
        history_view.append_bubble(text, speaker)

    @QtCore.pyqtSlot(str)
    def show_stage(self, text):
        """Append ``text`` as a status bubble and scroll the stage view to the bottom."""
        stage_view = self.stage_label
        stage_view.append_bubble(text, "助手")
        scroll_bar = stage_view.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    @QtCore.pyqtSlot(np.ndarray)
    def show_video_frame(self, frame):
        """Convert an image array to a pixmap and show it on the face label.

        Args:
            frame: Image array. 2-D or single-channel input is converted from
                grayscale and 3-channel input from BGR to RGB; other channel
                counts are passed through unchanged.
        """
        w, h = self._last_face_size
        if w < 10 or h < 10:
            # cv2.resize() raises on an empty target size (the recorded size
            # starts out as (0, 0)); fall back to a default size instead.
            w, h = 300, 400
        # Make sure the frame is RGB.
        if frame.ndim == 2 or frame.shape[2] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (w, h))
        qtimg = QtGui.QImage(frame.data, frame.shape[1], frame.shape[0], frame.strides[0], QtGui.QImage.Format_RGB888)
        pix = QtGui.QPixmap.fromImage(qtimg)
        self.face_label.setPixmap(pix)
        self._last_pixmap = pix

    @QtCore.pyqtSlot(list, str, float)
    def play_video_frames(self, frames, audio_path, audio_duration):
        """Start audio playback and arm the timer that shows ``frames`` in step.

        Args:
            frames: Frames to display while the audio plays.
            audio_path: Path handed to the audio player.
            audio_duration: Audio length in seconds (stored in milliseconds).
        """
        # Idle playback is paused only here.
        self.stop_idle_video()
        self.stop_sync()
        duration_ms = int(audio_duration * 1000)
        self.video_frames, self.video_frame_count = frames, len(frames)
        self.target_fps = 25
        self.audio_total_ms = duration_ms
        self.audio_player.play(audio_path)
        self.sync_timer.start(20)

    def _sync_frame_with_audio(self):
        """Timer tick: show the frame matching the audio position; finish at the end.

        When the audio position reaches the stored duration (minus 20 ms) or
        the last frame has been reached, the sync timer is stopped and
        ``idle_signal`` is emitted.
        """
        position_ms = self.audio_player.player.position()
        if position_ms <= 0:
            # NOTE(review): nothing is done while the reported position is 0,
            # so if playback never starts this keeps returning here -- confirm
            # whether a failed/invalid media state needs handling.
            return
        last_index = self.video_frame_count - 1
        frame_index = min(int(position_ms * self.target_fps / 1000), last_index)
        if 0 <= frame_index <= last_index:
            self.show_video_frame(self.video_frames[frame_index])
        audio_done = position_ms >= self.audio_total_ms - 20
        if audio_done or frame_index >= last_index:
            self.sync_timer.stop()
            self.idle_signal.emit()

    # ----------- 语音识别部分（参考修正版） -----------
    @QtCore.pyqtSlot(str)
    def on_asr_text(self, text):
        """Submit recognized speech as a question unless the UI is busy.

        Args:
            text: Recognized text; surrounding whitespace is stripped and
                empty results are ignored.
        """
        if self.busy:
            self.stage_signal.emit("正在播报回答，请稍后再提问。")
            return
        recognized = text.strip()
        if not recognized:
            return
        self.input_box.setText(recognized)
        self.on_submit()

    @QtCore.pyqtSlot(bool, bool)
    def update_asr_status(self, running, wake):
        """Record the ASR state and refresh the status label and mic button.

        Args:
            running: Whether speech recognition is on.
            wake: Whether the wake state is active.
        """
        self.asr_running = running
        self._asr_wake = wake

        run_text = '开启' if running else '关闭'
        wake_text = '已唤醒' if wake else '未唤醒'
        # The wake colour wins; otherwise the colour reflects the running state.
        if wake:
            label_color = "#ff5050"
        elif running:
            label_color = "#4faaff"
        else:
            label_color = "#555"

        status_label = self.asr_status_label
        status_label.setText(f"语音识别：{run_text} | 唤醒：{wake_text}")
        status_label.setStyleSheet(f"color:{label_color}; font-size:15px; font-weight:bold; background:transparent;")

        self.asr_btn.setChecked(True if running else False)
        self.asr_btn.setText("⏹" if running else "🎤")

    def on_toggle_asr(self):
        """Toggle speech recognition: stop it when running, start it otherwise."""
        action = self._stop_asr if self.asr_running else self._start_asr
        action()

    def _start_asr(self):
        """Mark ASR as running, publish the status and start its daemon thread.

        Does nothing when ASR is already flagged as running.
        """
        if not self.asr_running:
            self.asr_running = True
            self.asr_btn.setText("⏹")
            self.asr_status_signal.emit(True, False)
            worker = threading.Thread(target=self.start_asr, daemon=True)
            self.asr_thread = worker
            worker.start()

    def _stop_asr(self):
        """Mark ASR as stopped and publish the stopped status.

        Does nothing when ASR is not flagged as running.
        """
        if self.asr_running:
            self.asr_running = False
            self.asr_btn.setText("🎤")
            self.asr_status_signal.emit(False, False)

    def start_asr(self):
        """Run the speech-recognition loop; used as the ASR thread target.

        ``run_asr_thread`` receives a result callback and a callable returning
        the current ``asr_running`` flag. Failures are reported on the stage
        view and the status is reset to stopped.
        """
        def asr_callback(text, wake_state):
            # Drop results that arrive after recognition was switched off:
            # emitting a "running" status here would set ``asr_running`` back
            # to True in ``update_asr_status`` and silently re-enable ASR.
            if not self.asr_running:
                return
            self.asr_status_signal.emit(True, wake_state)
            if text and wake_state and not self.busy:
                self.asr_text_signal.emit(text)
        try:
            run_asr_thread(asr_callback, lambda: self.asr_running)
        except Exception as e:
            # Surface the failure instead of stopping without any feedback.
            self.stage_signal.emit(f"语音识别异常：{e}")
            self.asr_status_signal.emit(False, False)
    # ----------- 语音识别部分 END -----------

    def on_submit(self):
        """Take the question from the input box and process it on a worker thread.

        Empty input is ignored. While a previous question is still being
        handled, only a status message is emitted.
        """
        if self.busy:
            self.stage_signal.emit("正在播报上一个回答，请稍后...")
            return
        question = self.input_box.text().strip()
        if not question:
            return
        self.input_box.setText("")
        self.append_history("用户", question)
        self.show_stage("开始处理...")
        self.busy = True
        # Daemon, like the idle-video and ASR worker threads, so a request
        # that is still in flight cannot keep the process alive after the UI
        # has exited.
        worker = threading.Thread(target=self.process_conversation, args=(question,), daemon=True)
        worker.start()

    def process_conversation(self, question):
        """Produce and play the assistant's reply to ``question``.

        Steps: ask the chat bot, synthesize speech for the answer, generate
        lip-sync frames, then hand frames and audio over through
        ``play_video_frames_signal``. Progress and failures are reported via
        ``stage_signal``; every early exit emits ``idle_signal``. ``busy`` is
        cleared on every exit path, including unexpected errors, so the UI can
        never get stuck in the busy state.

        Args:
            question: The user's question text.
        """
        loop = None
        try:
            self.stage_signal.emit("等待大模型回复...")
            t1 = time.perf_counter()
            answer = self.bot.chat(question)
            t2 = time.perf_counter()
            # The signal is declared (str, str); guard against a None reply.
            self.append_history_signal.emit("助手", answer or "")
            self.stage_signal.emit(f"大模型回复完成，耗时：{t2-t1:.2f}s")
            self.stage_signal.emit("正在合成语音...")

            if not answer or len(answer.strip()) < 2:
                self.stage_signal.emit("回答内容过短，跳过语音与口型合成。")
                time.sleep(0.5)
                self.idle_signal.emit()
                return

            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            t3 = time.perf_counter()
            try:
                audio_path = loop.run_until_complete(generate_speech(answer))
            except Exception as e:
                self.stage_signal.emit(f"语音合成失败：{e}")
                self.idle_signal.emit()
                return
            t4 = time.perf_counter()
            if not audio_path or not os.path.exists(audio_path) or os.path.getsize(audio_path) < 800:
                self.stage_signal.emit("语音文件生成失败或内容太短，跳过口型合成。")
                self.idle_signal.emit()
                return
            import soundfile as sf
            audio_duration = float(sf.info(audio_path).duration)
            self.stage_signal.emit(f"语音合成完成，耗时：{t4-t3:.2f}s")
            self.stage_signal.emit("正在生成嘴型动画...")

            gen = prepare_audio_batches(audio_path, self.face_img, self.face_coords)
            infer_time, all_frames = self.lip_player.infer_frames(gen)
            self.stage_signal.emit(f"视频帧推理完成，耗时{infer_time:.2f}s")
            self.stage_signal.emit("正在播放语音与动画...")

            self.play_video_frames_signal.emit(all_frames, audio_path, audio_duration)
        except Exception as e:
            # Worker-thread boundary: report the failure and return to idle
            # rather than letting the thread die with ``busy`` still set.
            self.stage_signal.emit(f"处理失败：{e}")
            self.idle_signal.emit()
        finally:
            if loop is not None:
                # The loop is created per call; close it so it is not leaked.
                loop.close()
            # NOTE(review): busy is cleared once playback has been requested,
            # not when playback finishes -- confirm this is intended.
            self.busy = False

if __name__ == "__main__":
    # Script entry point: configure the environment, then run the Qt app.
    os.environ["QT_FONT_DPI"] = "96"
    if sys.platform == "win32":
        import ctypes
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("digitalhuman.app")
    # sys.stdout is None when there is no console (e.g. pythonw) and may be
    # replaced by an object without reconfigure(); only switch the encoding
    # when that is actually possible.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding='utf-8')
    app = QtWidgets.QApplication(sys.argv)
    win = DigitalHumanUI()
    win.show()
    sys.exit(app.exec_())