diff --git a/lib/rvc/train.py b/lib/rvc/train.py
index babe473..d3beaad 100644
--- a/lib/rvc/train.py
+++ b/lib/rvc/train.py
@@ -22,46 +22,56 @@
 from . import commons, utils
 from .checkpoints import save
 from .config import DatasetMetadata, TrainConfig
-from .data_utils import (DistributedBucketSampler, TextAudioCollate,
-                         TextAudioCollateMultiNSFsid, TextAudioLoader,
-                         TextAudioLoaderMultiNSFsid)
+from .data_utils import (
+    DistributedBucketSampler,
+    TextAudioCollate,
+    TextAudioCollateMultiNSFsid,
+    TextAudioLoader,
+    TextAudioLoaderMultiNSFsid,
+)
 from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
 from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
-                     SynthesizerTrnMs256NSFSidNono)
+from .models import (
+    MultiPeriodDiscriminator,
+    SynthesizerTrnMs256NSFSid,
+    SynthesizerTrnMs256NSFSidNono,
+)
 
 
 def glob_dataset(glob_str: str, speaker_id: int):
+    globs = glob_str.split(",")
     datasets_speakers = []
-    if os.path.isdir(glob_str):
-        files = os.listdir(glob_str)
-        # pattern: {glob_str}/{decimal}[_]* and isdir
-        dirs = [
-            (os.path.join(glob_str, f), int(f.split("_")[0]))
-            for f in files
-            if os.path.isdir(os.path.join(glob_str, f)) and f.split("_")[0].isdecimal()
-        ]
-
-        if len(dirs) > 0:
-            # multi speakers at once train
-            match_files_re = re.compile(r".+\.(wav|flac)")  # matches .wav and .flac
-            datasets_speakers = [
-                (file, dir[1])
-                for dir in dirs
-                for file in glob.iglob(os.path.join(dir[0], "*"), recursive=True)
-                if match_files_re.search(file)
+    for glob_str in globs:
+        if os.path.isdir(glob_str):
+            files = os.listdir(glob_str)
+            # pattern: {glob_str}/{decimal}[_]* and isdir
+            dirs = [
+                (os.path.join(glob_str, f), int(f.split("_")[0]))
+                for f in files
+                if os.path.isdir(os.path.join(glob_str, f))
+                and f.split("_")[0].isdecimal()
             ]
-            # for dir in dirs:
-            #     for file in glob.iglob(dir[0], recursive=True):
-            #         if match_files_re.search(file):
-            #             datasets_speakers.append((file, dirs[1]))
-            # return sorted(datasets_speakers, key=operator.itemgetter(0))
-    glob_str = os.path.join(glob_str, "*.wav")
-
-    datasets_speakers.extend(
-        [(file, speaker_id) for file in glob.iglob(glob_str, recursive=True)]
-    )
+            if len(dirs) > 0:
+                # multi speakers at once train
+                match_files_re = re.compile(r".+\.(wav|flac)")  # matches .wav and .flac
+                datasets_speakers = [
+                    (file, dir[1])
+                    for dir in dirs
+                    for file in glob.iglob(os.path.join(dir[0], "*"), recursive=True)
+                    if match_files_re.search(file)
+                ]
+                # for dir in dirs:
+                #     for file in glob.iglob(dir[0], recursive=True):
+                #         if match_files_re.search(file):
+                #             datasets_speakers.append((file, dirs[1]))
+                # return sorted(datasets_speakers, key=operator.itemgetter(0))
+
+            glob_str = os.path.join(glob_str, "*.wav")
+
+        datasets_speakers.extend(
+            [(file, speaker_id) for file in glob.iglob(glob_str, recursive=True)]
+        )
 
     return sorted(datasets_speakers, key=operator.itemgetter(0))
diff --git a/modules/tabs/training.py b/modules/tabs/training.py
index ae74d14..0d3f398 100644
--- a/modules/tabs/training.py
+++ b/modules/tabs/training.py
@@ -7,7 +7,7 @@
 from lib.rvc.preprocessing import extract_f0, extract_feature, split
 from lib.rvc.train import create_dataset_meta, glob_dataset, train_index, train_model
 from modules import models, utils
-from modules.shared import MODELS_DIR, device
+from modules.shared import MODELS_DIR
 from modules.ui import Tab
 
 SR_DICT = {
@@ -31,6 +31,7 @@ def train_index_only(
     f0,
     dataset_glob,
     speaker_id,
+    gpu_id,
     num_cpu_process,
     norm_audio_when_preprocess,
     pitch_extraction_algo,
@@ -42,6 +43,7 @@ def train_index_only(
     f0 = f0 == "Yes"
     norm_audio_when_preprocess = norm_audio_when_preprocess == "Yes"
     training_dir = os.path.join(MODELS_DIR, "training", "models", model_name)
+    gpu_ids = [int(x.strip()) for x in gpu_id.split(",")]
 
     yield f"Training directory: {training_dir}"
 
@@ -78,8 +80,7 @@ def train_index_only(
         embedder_filepath,
         embedder_load_from,
         embedding_channels == 768,
-        None,
-        device,
+        gpu_ids,
     )
 
     out_dir = os.path.join(MODELS_DIR, "checkpoints")
@@ -304,6 +305,7 @@ def train_all(
     f0,
     dataset_glob,
     speaker_id,
+    gpu_id,
     num_cpu_process,
     norm_audio_when_preprocess,
     pitch_extraction_algo,