Skip to content

Commit

Permalink
fix index training
Browse files Browse the repository at this point in the history
  • Loading branch information
ddPn08 committed Apr 23, 2023
1 parent a9572eb commit e79ffea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 35 deletions.
74 changes: 42 additions & 32 deletions lib/rvc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,56 @@
from . import commons, utils
from .checkpoints import save
from .config import DatasetMetadata, TrainConfig
from .data_utils import (DistributedBucketSampler, TextAudioCollate,
TextAudioCollateMultiNSFsid, TextAudioLoader,
TextAudioLoaderMultiNSFsid)
from .data_utils import (
DistributedBucketSampler,
TextAudioCollate,
TextAudioCollateMultiNSFsid,
TextAudioLoader,
TextAudioLoaderMultiNSFsid,
)
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
SynthesizerTrnMs256NSFSidNono)
from .models import (
MultiPeriodDiscriminator,
SynthesizerTrnMs256NSFSid,
SynthesizerTrnMs256NSFSidNono,
)


def glob_dataset(glob_str: str, speaker_id: int):
    """Collect ``(audio_path, speaker_id)`` pairs for a dataset specification.

    Args:
        glob_str: One or more comma-separated entries. Each entry is either
            a glob pattern or a directory path.
        speaker_id: Speaker id assigned to every file that is not located
            through a numbered speaker sub-directory.

    Returns:
        A list of ``(path, speaker_id)`` tuples sorted by path.

    For a directory entry:
      * every immediate sub-directory named ``{decimal}`` or
        ``{decimal}_*`` is treated as a speaker directory; its ``.wav`` and
        ``.flac`` files are added with that decimal as the speaker id;
      * ``*.wav`` files directly inside the directory are added with
        ``speaker_id``.

    For any other entry, the entry is used as a glob pattern and every match
    is added with ``speaker_id``.
    """
    audio_suffixes = (".wav", ".flac")
    datasets_speakers = []

    for pattern in glob_str.split(","):
        if os.path.isdir(pattern):
            # pattern: {pattern}/{decimal}[_]* and isdir
            speaker_dirs = [
                (os.path.join(pattern, name), int(name.split("_")[0]))
                for name in os.listdir(pattern)
                if os.path.isdir(os.path.join(pattern, name))
                and name.split("_")[0].isdecimal()
            ]

            # Train several speakers at once. Extend rather than assign, so
            # files gathered from earlier comma-separated entries are kept.
            # The suffix test rejects names such as "x.wav.txt".
            datasets_speakers.extend(
                (file, dir_speaker_id)
                for speaker_dir, dir_speaker_id in speaker_dirs
                for file in glob.iglob(os.path.join(speaker_dir, "*"), recursive=True)
                if file.endswith(audio_suffixes)
            )

            # Files directly inside the directory use the caller's speaker id.
            pattern = os.path.join(pattern, "*.wav")

        datasets_speakers.extend(
            (file, speaker_id) for file in glob.iglob(pattern, recursive=True)
        )

    return sorted(datasets_speakers, key=operator.itemgetter(0))

Expand Down
8 changes: 5 additions & 3 deletions modules/tabs/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lib.rvc.preprocessing import extract_f0, extract_feature, split
from lib.rvc.train import create_dataset_meta, glob_dataset, train_index, train_model
from modules import models, utils
from modules.shared import MODELS_DIR, device
from modules.shared import MODELS_DIR
from modules.ui import Tab

SR_DICT = {
Expand All @@ -31,6 +31,7 @@ def train_index_only(
f0,
dataset_glob,
speaker_id,
gpu_id,
num_cpu_process,
norm_audio_when_preprocess,
pitch_extraction_algo,
Expand All @@ -42,6 +43,7 @@ def train_index_only(
f0 = f0 == "Yes"
norm_audio_when_preprocess = norm_audio_when_preprocess == "Yes"
training_dir = os.path.join(MODELS_DIR, "training", "models", model_name)
gpu_ids = [int(x.strip()) for x in gpu_id.split(",")]
yield f"Training directory: {training_dir}"

if os.path.exists(training_dir) and ignore_cache:
Expand Down Expand Up @@ -78,8 +80,7 @@ def train_index_only(
embedder_filepath,
embedder_load_from,
embedding_channels == 768,
None,
device,
gpu_ids,
)

out_dir = os.path.join(MODELS_DIR, "checkpoints")
Expand Down Expand Up @@ -304,6 +305,7 @@ def train_all(
f0,
dataset_glob,
speaker_id,
gpu_id,
num_cpu_process,
norm_audio_when_preprocess,
pitch_extraction_algo,
Expand Down

0 comments on commit e79ffea

Please sign in to comment.