In [5]:
# Generating acoustic reconstruction data for postprocessing stages.
# For every utterance a random clip is reconstructed by the decoder and written
# next to the matching 48 kHz ground-truth clip as gtru_*/pred_* .wav files.
CONFIG_PATH = "Models/Multi0_40/config.yml"  # YAML training config of the model
CKPT_PATH = "Models/Multi0_40/epoch_2nd_40_1c872.pth"  # checkpoint; its 'net' entry maps module name -> state dict
DIR_48KS = "Z:/"  # passed as `dir_48ks` to the 48 kHz dataloaders
DATA_OUTPUT_DIR = "acoustic_rec_data"  # destination of the generated .wav pairs
BATCH_SIZE = 2  # dataloader batch size

from meldataset import build_dataloader_with_ref_48k
from models import *
from utils import *
from losses import *
import torch
import yaml
import os
from itertools import chain
from scipy.io.wavfile import write
from accelerate import Accelerator, DistributedDataParallelKwargs

# Read the training config. An explicit encoding keeps the result independent
# of the platform's locale codec (e.g. cp1252 on Windows).
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = yaml.safe_load(f)
dp = config['data_params']
sr = config['preprocess_params'].get('sr', 24000)  # rate used when writing decoder output

batch_size = BATCH_SIZE
log_dir = config['log_dir']
train_list, val_list = get_data_path_list(
    dp['train_data'], dp['val_data'])
val_path = dp['val_data']
root_path = dp['root_path']
min_length = dp['min_length']
OOD_data = dp['OOD_data']
max_len = 800  # caps the reconstructed clip at max_len mel frames

# The accelerator (fp16 mixed precision) owns device placement; every tensor
# and module below is moved to `device` rather than a hard-coded GPU name.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(project_dir=log_dir,
    kwargs_handlers=[ddp_kwargs], mixed_precision='fp16')
device = accelerator.device

# Pretrained auxiliary networks named in the config. A missing key yields
# False, which is handed to the corresponding loader unchanged.
ASR_config, ASR_path = (
    config.get(cfg_key, False) for cfg_key in ('ASR_config', 'ASR_path'))
text_aligner = load_ASR_models(ASR_path, ASR_config)

F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)


def _build_48k_loader(path_list, **split_options):
    """Build a `build_dataloader_with_ref_48k` loader for one split.

    Settings shared by every split (root path, 48 kHz directory, OOD data,
    minimum length, batch size, device) are read from the globals above;
    `split_options` carries the per-split keyword arguments. A fresh
    `dataset_config` dict is created on every call.
    """
    return build_dataloader_with_ref_48k(
        path_list, root_path,
        dir_48ks=DIR_48KS, OOD_data=OOD_data, min_length=min_length,
        batch_size=batch_size, dataset_config={}, device=device,
        **split_options)


train_dataloader = _build_48k_loader(train_list, num_workers=2)
val_dataloader = _build_48k_loader(val_list, validation=True, num_workers=0)

import logging
from collections import OrderedDict

model_params = recursive_munch(config['model_params'])
multispeaker = model_params.multispeaker
model = build_model(model_params, text_aligner, pitch_extractor, plbert)

# Hand every sub-network to accelerate (device placement / mixed precision)
# and switch it to inference mode.
for k in model:
    model[k] = accelerator.prepare(model[k])
    model[k].eval()

# The aligner's downsampling exponent; reached through `.module` when
# accelerate wrapped the network, directly otherwise.
try:
    n_down = model.text_aligner.module.n_down
except AttributeError:
    n_down = model.text_aligner.n_down

print(f"Loading {CKPT_PATH}")
params_whole = torch.load(CKPT_PATH, map_location='cpu')
params = params_whole['net']  # module name -> state dict

for key in model:
    if key not in params:
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # Key mismatch: retry with the DDP `module.` prefix removed. Only keys
        # that actually carry the prefix are rewritten; blindly cutting 7
        # characters would truncate unprefixed keys and load nothing.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (name[len(prefix):] if name.startswith(prefix) else name, value)
            for name, value in params[key].items())
        result = model[key].load_state_dict(new_state_dict, strict=False)
        # strict=False hides mismatches, so report them explicitly.
        if result.missing_keys or result.unexpected_keys:
            print(f"{key}: non-strict load, {len(result.missing_keys)} missing / "
                  f"{len(result.unexpected_keys)} unexpected keys")

# load_state_dict does not touch the train/eval flag; re-assert eval anyway.
for key in model:
    model[key].eval()

# Reconstruct a random clip of every validation + training utterance and write
# (48 kHz ground truth, decoder reconstruction) .wav pairs to DATA_OUTPUT_DIR.
import numpy as np  # np.random / np.stack; previously only reachable via star imports

os.makedirs(DATA_OUTPUT_DIR, exist_ok=True)

# NOTE(review): file names contain only the batch/item index, so running this
# under several accelerate processes would make them overwrite each other.
for i, batch in enumerate(chain(val_dataloader, train_dataloader)):
    # Pure inference: no graph is built (anomaly detection would be useless here).
    with torch.no_grad():
        waves_48k = batch[-1]  # per-item 48 kHz waveforms (NumPy arrays)
        texts, input_lengths, _, _, mels, mel_input_length, _ = [
            b.to(device) for b in batch[1:-1]]

        # Text-to-frame alignment from the pretrained aligner.
        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
        ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)

        # Drop the first position along the text axis (axis -2).
        s2s_attn = s2s_attn.transpose(-1, -2)
        s2s_attn = s2s_attn[..., 1:]
        s2s_attn = s2s_attn.transpose(-1, -2)

        # Zero the attention wherever the text position or the frame lies beyond
        # its valid length (assumes length_to_mask marks padding as True).
        text_mask = length_to_mask(input_lengths).to(texts.device)
        attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
        attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
        attn_mask = (attn_mask < 1)
        s2s_attn.masked_fill_(attn_mask, 0.0)

        # Encode the text and project it onto frames through the alignment.
        t_en = model.text_encoder(texts, input_lengths, text_mask)
        asr = (t_en @ s2s_attn)

        # Clip lengths are shared across the gathered batch so clips stack.
        # Units are `asr` frames, each covering two mel frames: `mel_len` (capped
        # at max_len // 2) for the target, `mel_len_st` (uncapped) for style.
        mel_input_length_all = accelerator.gather(mel_input_length)  # for balanced load
        shortest = mel_input_length_all.min().item()
        mel_len = min(int(shortest / 2 - 1), max_len // 2)
        mel_len_st = int(shortest / 2 - 1)

        en, gt, wav48k, st = [], [], [], []
        for bib in range(len(mel_input_length)):
            mel_length = int(mel_input_length[bib].item() / 2)

            random_start = np.random.randint(0, mel_length - mel_len)
            en.append(asr[bib, :, random_start:random_start + mel_len])
            gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
            # 600 samples of 48 kHz audio per mel frame.
            wav48k.append(waves_48k[bib][
                (random_start * 2) * 600:((random_start + mel_len) * 2) * 600])

            # Independent window of length mel_len_st for the style reference.
            random_start = np.random.randint(0, mel_length - mel_len_st)
            st.append(mels[bib, :, (random_start * 2):((random_start + mel_len_st) * 2)])

        en = torch.stack(en)
        gt = torch.stack(gt)
        st = torch.stack(st)
        # Ground-truth audio is only written to disk, so it never leaves the CPU.
        wav48k = np.stack(wav48k).astype(np.float32)

        F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
        # Multi-speaker models take style from the independent reference window.
        s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
        real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
        print(f"en {en.max()} F0_real {F0_real.max()} real_norm {real_norm.max()} s {s.max()}")
        y_rec = model.decoder(en, F0_real, real_norm, s)
        print(f"nan check decoder: {y_rec.isnan().any()}")

        # scipy's WAV writer accepts float32 but not float16, so force float32.
        y_rec = y_rec.squeeze(1).float().cpu().numpy()

        for i2, (w48k, yr) in enumerate(zip(wav48k, y_rec)):
            print(yr.shape)
            print(yr.max())
            write(os.path.join(DATA_OUTPUT_DIR, f"gtru_ard_{i}_{i2}.wav"), 48000, w48k)
            write(os.path.join(DATA_OUTPUT_DIR, f"pred_ard_{i}_{i2}.wav"), sr, yr)

Loading Models/Multi0_40/epoch_2nd_40_1c872.pth
bert loaded
bert_encoder loaded
predictor loaded
decoder loaded
text_encoder loaded
predictor_encoder loaded
style_encoder loaded
diffusion loaded
text_aligner loaded
pitch_extractor loaded
mpd loaded
msd loaded
wd loaded
en 1.223442554473877 F0_real 621.5 real_norm 8.060370445251465 s 0.8671875
nan check decoder: False
(43800,)
0.20376517
(43800,)
0.14051805
en 1.398130178451538 F0_real 436.0 real_norm 7.806671142578125 s 0.88134765625
nan check decoder: False
(97800,)
0.2139518
(97800,)
0.35668683
en 1.3756866455078125 F0_real 613.5 real_norm 8.382726669311523 s 0.80615234375
nan check decoder: False
(66600,)
0.14601736
(66600,)
0.22770905
en 1.5676051378250122 F0_real 548.0 real_norm 8.450801849365234 s 0.71728515625
nan check decoder: False
(79200,)
0.25255433
(79200,)
0.24256547
en 1.2638163566589355 F0_real 623.5 real_norm 10.001133918762207 s 0.69140625
nan check decoder: False
(59400,)
0.5914138
(59400,)
0.42668417
en 1.0846501588

In [1]:
# Repackage the generated (ground truth, reconstruction) .wav pairs into
# per-split folders of vocals.wav / mixture.wav files.
DATA_OUTPUT_DIR = "acoustic_rec_data"
DEMUCS_OUTPUT_DIR = "demucs_acoustic_rec_data"
import os
from pydub import AudioSegment
from functools import reduce
from tqdm import tqdm

# Explicit check rather than `assert`, which is skipped under `python -O`.
if not os.path.isdir(DATA_OUTPUT_DIR):
    raise FileNotFoundError(
        f"{DATA_OUTPUT_DIR!r} not found; run the generation cell first")

# Pair every ground-truth clip with its reconstruction. The listing is sorted
# because os.listdir order is unspecified and would make the splits
# non-reproducible across file systems.
path_pairs = []
for gtru_path in sorted(os.listdir(DATA_OUTPUT_DIR)):
    if not gtru_path.startswith('gtru'):
        continue
    pred_path = 'pred' + gtru_path[len('gtru'):]
    if not os.path.exists(os.path.join(DATA_OUTPUT_DIR, pred_path)):
        raise FileNotFoundError(
            f"missing reconstruction {pred_path!r} for {gtru_path!r}")
    path_pairs.append((gtru_path, pred_path))

# Turn the split weights into contiguous [begin, end) index ranges over
# path_pairs. Integer arithmetic keeps the boundaries exact, so the last split
# always ends at len(path_pairs); accumulating floats and truncating could
# round it down and silently drop a pair.
splits = {'train': 18, 'valid': 1, 'test': 1}
sum_splits = sum(splits.values())
n_pairs = len(path_pairs)
cumulative = 0
for k, weight in list(splits.items()):
    begin = cumulative * n_pairs // sum_splits
    cumulative += weight
    splits[k] = (begin, cumulative * n_pairs // sum_splits)

os.makedirs(DEMUCS_OUTPUT_DIR, exist_ok=True)

def process_file(path, target_path):
    """Convert ``DATA_OUTPUT_DIR/<path>`` to a 48 kHz, mono, 4-byte-sample WAV.

    The converted audio is exported to ``target_path``, which is returned.
    """
    source = os.path.join(DATA_OUTPUT_DIR, path)
    audio = AudioSegment.from_file(source, format="wav")
    audio = audio.set_sample_width(4)
    audio = audio.set_frame_rate(48000)
    audio = audio.set_channels(1)
    audio.export(target_path, format="wav")
    return target_path

# Demucs training performance is REALLY BAD with lots of small files, so each
# output folder concatenates up to BATCH_SIZE consecutive pairs.
BATCH_SIZE = 20


def _load_48k_mono(name):
    """Load ``DATA_OUTPUT_DIR/<name>`` as 48 kHz mono audio with 4-byte samples.

    Note: in pydub a sample width of 4 means 32-bit integer PCM, not float.
    """
    return (AudioSegment.from_file(os.path.join(DATA_OUTPUT_DIR, name), format="wav")
            .set_sample_width(4).set_frame_rate(48000).set_channels(1))


for k, (split_begin, split_end) in splits.items():
    print(f"Handling split {k}")
    os.makedirs(os.path.join(DEMUCS_OUTPUT_DIR, k), exist_ok=True)
    for i in tqdm(range(split_begin, split_end, BATCH_SIZE)):
        song_dir = os.path.join(DEMUCS_OUTPUT_DIR, k, str(i))
        os.makedirs(song_dir, exist_ok=True)
        gtru = AudioSegment.empty()
        pred = AudioSegment.empty()
        for gtru_name, pred_name in path_pairs[i:min(i + BATCH_SIZE, split_end)]:
            gtru_temp = _load_48k_mono(gtru_name)
            pred_temp = _load_48k_mono(pred_name)
            # Trim both to the shorter duration (pydub slices in milliseconds)
            # so the two tracks stay time-aligned across the concatenation.
            min_len = min(len(gtru_temp), len(pred_temp))
            gtru += gtru_temp[:min_len]
            pred += pred_temp[:min_len]

        gtru.export(os.path.join(song_dir, "vocals.wav"), format="wav")
        pred.export(os.path.join(song_dir, "mixture.wav"), format="wav")
            

Handling split train


100%|██████████| 1605/1605 [1:15:08<00:00,  2.81s/it]


Handling split valid


100%|██████████| 90/90 [04:25<00:00,  2.95s/it]


Handling split test


100%|██████████| 90/90 [04:06<00:00,  2.74s/it]
