diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py index 00ba1f9e0..99cd9928d 100644 --- a/diffsynth/core/data/operators.py +++ b/diffsynth/core/data/operators.py @@ -2,8 +2,6 @@ import torch, torchvision, imageio, os import imageio.v3 as iio from PIL import Image -import torchaudio -from diffsynth.utils.data.audio import read_audio class DataProcessingPipeline: @@ -249,9 +247,11 @@ def __call__(self, data): class LoadAudio(DataProcessingOperator): def __init__(self, sr=16000): self.sr = sr - def __call__(self, data: str): import librosa - input_audio, sample_rate = librosa.load(data, sr=self.sr) + self.audio_loader = librosa.load + + def __call__(self, data: str): + input_audio, sample_rate = self.audio_loader(data, sr=self.sr) return input_audio @@ -259,13 +259,15 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin): def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True): FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate) + import torchaudio + self.audio_loader = torchaudio.load def __call__(self, data: str): try: reader = self.get_reader(data) num_frames = self.get_num_frames(reader) duration = num_frames / self.frame_rate - waveform, sample_rate = torchaudio.load(data) + waveform, sample_rate = self.audio_loader(data) target_samples = int(duration * sample_rate) current_samples = waveform.shape[-1] if current_samples > target_samples: @@ -285,10 +287,12 @@ def __init__(self, target_sample_rate=None, target_duration=None): self.target_sample_rate = target_sample_rate self.target_duration = target_duration self.resample = True if target_sample_rate is not None else False + from diffsynth.utils.data.audio import read_audio + self.audio_loader = read_audio def __call__(self, data: str): try: - waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate) + waveform, sample_rate = self.audio_loader(data, resample=self.resample, resample_rate=self.target_sample_rate) if self.target_duration is not None: target_samples = int(self.target_duration * sample_rate) current_samples = waveform.shape[-1] diff --git a/pyproject.toml b/pyproject.toml index ccfad82a2..86f8e5d0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,8 +48,17 @@ npu = [ "torchvision==0.22.1+cpu" ] audio = [ + "av", "torchaudio", - "torchcodec" + "torchcodec", + "librosa" +] +all = [ + "av", + "torchaudio", + "torchcodec", + "librosa", + "streamlit" ] [tool.setuptools]