In [None]:
import os
import glob
import tqdm
import torch
import numpy as np
import omegaconf
from mllib.src.train import main

import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt

import torch.cuda
from mllib.src.evaluate import evaluate
from mllib.src.utils import prepare_device, load_yaml
from mllib.src.distrib import get_dev_wav_clarity

In [None]:
# Select the run to inspect by pointing `config` at that run's saved
# config.yaml. Keep exactly one assignment active; the commented-out paths are
# earlier runs retained for reference.

# --- mel-rnn ---
# config = "./mllib/result/mel-rnn/20230202-145405/config.yaml"
# config = "./mllib/result/mel-rnn/20230203-121042/config.yaml"

# --- dnn ---
# config = "./mllib/result/dnn/20230202-142249/config.yaml"
# config = "./mllib/result/dnn/20230202-163959/config.yaml"
# config = "./mllib/result/dnn/20230202-170504/config.yaml"
# config = "./mllib/result/dnn/20230202-171624/config.yaml"
# config = "./mllib/result/dnn/20230202-185453/config.yaml"
# config = "./mllib/result/dnn/20230203-115011/config.yaml"

# --- unet ---
# config = "./mllib/result/unet/20230203-183804/config.yaml"

# --- conv-tasnet (under ./mllib/result) ---
# config = "./mllib/result/conv-tasnet/20230203-183838/config.yaml"

# --- demucs / wav-unet / dcunet ---
# config = "./result/demucs/20230201-104202/config.yaml"
# config = "./result/wav-unet/20230201-104328/config.yaml"
# config = "./result/dcunet/20230201-104116/config.yaml"

# --- conv-tasnet (under ./result) ---
# config = "./result/conv-tasnet/20230207-080249/config.yaml"
# config = "./result/conv-tasnet/20230207-184607/config.yaml"  # samples only including target's period, PIT
config = "./result/conv-tasnet/20230207-185011/config.yaml"  # including all samples, PIT
# config = "./result/conv-tasnet/20230208-175200/config.yaml"  # including all samples and no PIT

In [None]:
# Call the training entry point with the selected config; `return_solver=True`
# is passed and the call's return value is kept as `solver`.
# NOTE(review): presumably this builds the solver for the saved run without
# starting a training loop — confirm in mllib.src.train.main.
solver = main(path_config=config, return_solver=True)

In [None]:
# Model object held by the solver; it is the one passed to `evaluate` below.
model = solver.model

In [None]:
# Load the same config file as a settings object used by the cells below.
args = load_yaml(config)
# Number of CUDA devices PyTorch can see.
n_gpu = torch.cuda.device_count()
# Device handle later passed to `evaluate`, built from the GPU count and the
# config's `solver.cudnn_deterministic` flag.
device = prepare_device(n_gpu, cudnn_deterministic=args.solver.cudnn_deterministic)

In [None]:
# Development dataset built from the `dset` section of the config.
# NOTE(review): the Clarity-specific loader is called regardless of
# `args.dset.name`, while the next cell branches on that name — confirm this is
# intended for non-Clarity configs.
dev_dataset = get_dev_wav_clarity(args.dset)

In [None]:
from mllib.src.distrib import get_train_wav_dataset

# Not read anywhere in this cell; the window actually used is snr_min/snr_max.
SNR = '0' # '0', '5', '10', '15' # SNR = P_{Signal} / P_{Noise}

# Bind the results up front so the print/display below never raise NameError
# when the dataset is not Clarity or no scene falls inside the SNR window.
data_test = None
snr = None

if args.dset.name == "Clarity":

    log_clarity = "./data/metadata/scenes.dev.snr.json"
    metadata = omegaconf.OmegaConf.load(log_clarity)
    # Peek at the first metadata entry (value, then key).
    print(list(metadata.values())[0], list(metadata.keys())[0])
    snr_min = 0
    snr_max = 5
    # Keep the first dev item whose scene SNR lies in [snr_min, snr_max).
    for data in tqdm.tqdm(dev_dataset, ncols=120):
        # `origial_length` (sic): spelling kept so any cell relying on this
        # global keeps working.
        mixture, sources, origial_length, name = data
        # The metadata is keyed by the part of the name before the first "_".
        scene_name = name.split("_")[0]
        scene_snr = metadata[scene_name]
        if snr_min <= scene_snr < snr_max:
            data_test = data
            snr = scene_snr
            break

    if data_test is None:
        print(f"No dev scene with SNR in [{snr_min}, {snr_max}) was found")
    print("Clarity dataset SNR: ", snr)

# Displays the selected item, or None when nothing was selected.
data_test

In [None]:
# A dataset item unpacks to (mixture, sources, <third field, unused here>, name).
# NOTE(review): this always takes item 0, not the `data_test` item selected by
# SNR in the previous cell — confirm which sample is meant to be evaluated.
mixture, sources, _, name = dev_dataset[0]

In [None]:
from mllib.src.model.types import (
    MULTI_SPEECH_SEPERATION_MODELS,
    MULTI_CHANNEL_SEPERATION_MODELS,
    MONARCH_SPEECH_SEPARTAION_MODELS,
    STFT_MODELS,
    WAV_MODELS,
)

# `mixture` has to be 2-D at this point (the unpack fails otherwise); axis 1 of
# `sources` is taken as the number of speakers.
nchannel, nsample = mixture.shape
num_spk = sources.shape[1]

# Sanity checks only: the sample must agree with what the model config
# declares. Nothing is converted here.
assert args.model.audio_channels == nchannel, f"Channel between {args.dset.name} and {args.model.name} did not match..."
assert args.model.num_spk == num_spk, f"number of speakers between {args.dset.name} and {args.model.name} did not match..."

# Models in this group additionally list their sources by name.
if args.model.name in MULTI_SPEECH_SEPERATION_MODELS:
    assert len(args.model.sources) == num_spk, f"number of speakers between {args.dset.name} and {args.model.name} did not match..."

# Models in this group get a singleton middle axis: (nchannel, 1, nsample).
# After this branch runs `mixture` is 3-D, so re-running the cell would fail at
# the two-name unpack above.
if args.model.name in MONARCH_SPEECH_SEPARTAION_MODELS:
    mixture = mixture.reshape(nchannel, 1, nsample)


In [None]:
# Quick look at the model input shape and the model name from the config.
mixture.shape, args.model.name

In [None]:
# Add a leading axis of size 1 for the call, then drop axis 0 of the result
# again (a no-op if that axis is not of size 1).
enhanced = evaluate(
    mixture=mixture.unsqueeze(0),
    model=model,
    device=device,
    config=args,
).squeeze(dim=0)

In [None]:
# Compare the model output shape with the target shape before post-processing.
enhanced.shape, sources.shape

In [None]:
# Detach both tensors from autograd and move them to host memory.
enhanced, sources = (tensor.detach().cpu() for tensor in (enhanced, sources))

# For the multi-speech models keep only index 0 along axis 1 (the axis used as
# the speaker count earlier). Re-running this cell would index a second time.
if args.model.name in MULTI_SPEECH_SEPERATION_MODELS:
    enhanced, sources = enhanced[:, 0, ...], sources[:, 0, ...]



In [None]:
# Shapes after the move to CPU (and after the axis-1 selection, if it ran).
enhanced.shape, sources.shape, mixture.shape

In [None]:
import julius
from omegaconf import OmegaConf
from recipes.icassp_2023.MLbaseline.enhance  import enhance
from recipes.icassp_2023.MLbaseline.evaluate import get_amplified_signal

In [None]:
# Scene id: the part of the sample name before the first underscore.
name_scene = name.split("_", 1)[0]
# Config of the Clarity ICASSP 2023 ML baseline recipe; its `nalr.fs` entry is
# used below as the target sample rate.
config_clarity_challenge = OmegaConf.load(
    "./recipes/icassp_2023/MLbaseline/config.yaml"
)

In [None]:
# Resample the model output from the dataset rate to the recipe's `nalr.fs`.
enhanced_signal_resample = julius.resample.resample_frac(
    enhanced,
    args.dset.sample_rate,
    config_clarity_challenge.nalr.fs,
)

# Run the recipe's amplification/scoring step for this scene. No audiogram is
# passed in; the one returned here is handed to the clean-source run that follows.
amplified, ref, haspi_score, hasqi_score, audiogram = get_amplified_signal(
    enhance_signal=enhanced_signal_resample,
    fs_signal=config_clarity_challenge.nalr.fs,
    scene=name_scene,
    cfg=config_clarity_challenge,
)


In [None]:
# Scores and audiogram returned for the enhanced signal.
haspi_score, hasqi_score, audiogram

In [None]:
# Same pipeline for the clean target, as a reference point for the scores.
sources_signal_resample = julius.resample.resample_frac(
    sources,
    args.dset.sample_rate,
    config_clarity_challenge.nalr.fs,
)

# The current `audiogram` is passed in, and the name is rebound to whatever the
# call returns.
amplified_clean, ref_clean, haspi_score_clean, hasqi_score_clean, audiogram = get_amplified_signal(
    enhance_signal=sources_signal_resample,
    fs_signal=config_clarity_challenge.nalr.fs,
    scene=name_scene,
    cfg=config_clarity_challenge,
    audiogram=audiogram,
)

In [None]:
# Scores for the clean target, plus the audiogram returned by that call.
haspi_score_clean, hasqi_score_clean, audiogram

In [None]:
# Shapes of the two amplified signals (enhanced vs. clean target).
amplified.shape, amplified_clean.shape

In [None]:
# Bind the `*_np` names to the objects computed above. These are plain
# aliases: nothing is copied or converted in this cell.
enhanced_np = enhanced
mixture_np = mixture
sources_np = sources
amplified_np = amplified
amplified_clean_np = amplified_clean

In [None]:
# Flatten every signal to a 1-D array for plotting and playback.
# Each value is derived from its source object rather than from the `*_np`
# name itself, so re-running this cell does not call `.numpy()` on an ndarray.
# NOTE(review): flattening a multi-channel signal places the channels one
# after another — confirm that is what the plots/players should show.
enhanced_np = enhanced.numpy().flatten()
mixture_np = mixture.numpy().flatten()
sources_np = sources.numpy().flatten()
# The amplified signals are flattened without `.numpy()`, as in the original.
amplified_np = amplified.flatten()
amplified_clean_np = amplified_clean.flatten()

In [None]:
# Lengths of the flattened signals.
enhanced_np.shape, mixture_np.shape, sources_np.shape, amplified_np.shape, amplified_clean_np.shape

In [None]:
fig, (ax0, ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=6)

# One waveform per row, each titled so the six panels can be told apart.
for _ax, _signal, _title in zip(
    (ax0, ax1, ax2, ax3, ax4, ax5),
    (mixture_np, sources_np, enhanced_np, amplified_np, amplified_clean_np, ref.flatten()),
    ("mixture", "sources (target)", "enhanced", "amplified (enhanced)", "amplified (clean sources)", "ref"),
):
    _ax.plot(_signal)
    _ax.set_title(_title)

# Keep the stacked titles from overlapping the neighbouring axes.
fig.tight_layout()

In [None]:
# New figure for the spectrograms: six stacked axes sharing one y-axis.
# This rebinds `fig` and `ax0`..`ax5` from the waveform figure.
fig, (ax0, ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=6, sharey=True)

def show_stft(y, _fig, _ax, n_fft=4096):
    """Draw the dB-scaled STFT magnitude of ``y`` on ``_ax`` with a colorbar.

    Args:
        y: Audio samples handed to ``librosa.stft``.
        _fig: Figure that receives the colorbar.
        _ax: Axes the spectrogram image is drawn on.
        n_fft: FFT size forwarded to ``librosa.stft``. Defaults to 4096, the
            value that was previously hard-coded.

    Returns:
        None. The image and its colorbar are added as side effects.
    """
    magnitude = np.abs(librosa.stft(y, n_fft=n_fft))
    # ref=np.max makes the dB scale relative to this signal's own peak.
    S_db = librosa.amplitude_to_db(magnitude, ref=np.max)
    img = librosa.display.specshow(S_db, ax=_ax)
    _fig.colorbar(img, ax=_ax)

# One spectrogram per row, in the same order as the waveform figure.
for _axis, _signal in zip(
    (ax0, ax1, ax2, ax3, ax4, ax5),
    (mixture_np, sources_np, enhanced_np, amplified_np, amplified_clean_np, ref.flatten()),
):
    show_stft(_signal, fig, _axis)

In [None]:
# Listen to the input mixture, played at the dataset sample rate.
ipd.Audio(mixture_np, rate=args.dset.sample_rate)

In [None]:
# Listen to the clean target, played at the dataset sample rate.
ipd.Audio(sources_np, rate=args.dset.sample_rate)

In [None]:
# Listen to the model output, played at the dataset sample rate.
ipd.Audio(enhanced_np, rate=args.dset.sample_rate)

In [None]:
# Listen to the amplified model output, played at the recipe's `nalr.fs` rate
# (the rate it was resampled to before amplification).
ipd.Audio(amplified_np, rate=config_clarity_challenge.nalr.fs)

In [None]:
# Listen to the amplified clean target, played at the recipe's `nalr.fs` rate.
ipd.Audio(amplified_clean_np, rate=config_clarity_challenge.nalr.fs)

In [None]:
# Listen to `ref`, the second value returned by get_amplified_signal for the
# enhanced signal, played at the recipe's `nalr.fs` rate.
ipd.Audio(ref.flatten(), rate=config_clarity_challenge.nalr.fs)