In [None]:
import os
import librosa
import torch
import yaml
import numpy as np
from wandas.core import ChannelFrame
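# Project-local packages (system/, models/): assumes the project root is already importable (e.g. the notebook runs from it); nothing here modifies sys.path.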
from system.sep_system import TwoChSepSystem
from models.dcunet import TwoChDCUNet

from torchmetrics.audio.snr import (
    signal_noise_ratio as snr,
)

import pytorch_lightning as pl

In [None]:
# Experiment configuration for the fine-tuned Large-DCUNet (filterbank, masknet and data sections are read below).
CONFIG_PATH = "./DCUNet/conf_finetuning_crosstalk_add_EQ_Large-DCUNet.yml"

# Explicit UTF-8: without it open() uses the platform locale encoding (e.g. cp932 on
# Japanese Windows), which can fail on or mis-decode non-ASCII text in the YAML file.
with open(CONFIG_PATH, encoding="utf-8") as f:
    conf = yaml.safe_load(f)

# Build the network straight from the config's filterbank / masknet hyperparameters.
model = TwoChDCUNet(
    **conf["filterbank"],
    **conf["masknet"],
    sample_rate=conf["data"]["sample_rate"],
)

def loss_fn(pred, tgt):
    """Return the negated mean SNR between ``pred`` and ``tgt``.

    Negated so that minimising the loss maximises the signal-to-noise ratio.
    """
    per_item_snr = snr(pred, tgt)
    return -per_item_snr.mean()

# The checkpoint is loaded through the training system rather than the bare model
# (presumably its keys carry the system's module prefix). Optimizer, loaders and
# scheduler are irrelevant for inference, hence None.
system = TwoChSepSystem(
    model=model,
    loss_func=loss_fn,
    optimizer=None,
    train_loader=None,
    val_loader=None,
    scheduler=None,
    config=conf,
)

CHECKPOINT_PATH = "./exp/checkpoints/epoch=33-step=340000.ckpt"
# weights_only=True refuses arbitrary pickled objects; map_location="cpu" avoids needing a GPU.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
system.load_state_dict(state_dict=state_dict["state_dict"])

# Run on CPU with the network in evaluation mode.
system.cpu()
model.eval()


In [None]:
# Input recording (relative to data_dir).
data_dir = "./data"
source_audio_file = "input.wav"  # TODO: set to the recording to analyse
source_path = os.path.join(data_dir, source_audio_file)
if not os.path.isfile(source_path):
    raise FileNotFoundError(
        f"Audio file not found: {source_path!r} - set `source_audio_file` to a file in {data_dir!r}"
    )

# NOTE(review): presumably this should equal conf["data"]["sample_rate"] (the rate the
# model was built with) - confirm.
fs = 32000
# librosa resamples to fs; only the first 10 s are analysed.
source_signal, _ = librosa.load(source_path, sr=fs, duration=10)

# Ensure float32 and scale the peak to 1/1.2 (about -1.6 dBFS) to leave headroom.
source_signal = source_signal.astype(np.float32)
peak = np.max(np.abs(source_signal), initial=0.0)
if peak > 0:  # a silent (or empty) clip would otherwise turn into NaNs
    source_signal /= peak * 1.2

# HPSS: subtract librosa's harmonic estimate; what remains is the non-harmonic part.
source_signal_r = source_signal - librosa.effects.harmonic(source_signal, margin=1, kernel_size=62)

# Spectral gating of that residual: in every frequency bin, frames whose magnitude is
# below the bin's mean over time are pushed down to a tiny floor.
D = librosa.stft(source_signal_r)
amplitude, phase = np.abs(D), np.angle(D)
bin_mean = amplitude.mean(axis=-1, keepdims=True)  # one mean per bin; broadcasts across frames
amplitude[amplitude < bin_mean] = 1e-12
# Recombine the gated magnitude with the original phase.
D_modified = amplitude * np.exp(1j * phase)

# Inverse STFT. `length` keeps the output exactly as long as the input; without it
# istft may return fewer samples and the subtraction below fails on a shape mismatch.
source_signal_r = librosa.istft(D_modified, length=len(source_signal_r))

# Harmonic part = observation minus the gated residual.
source_signal_h = source_signal - source_signal_r


In [None]:
# Model input: channel 0 = observation, channel 1 = reference built from the HPSS
# harmonic estimate, optionally delayed and scaled.
noise_shift_s = 0.0  # delay applied to the reference, in seconds (0 = aligned)
noise_gain = 3  # gain applied to the reference
mixed = source_signal
noise = np.roll(source_signal_h, int(fs * noise_shift_s)) * noise_gain

model_input = torch.stack(
    [torch.from_numpy(mixed.astype(np.float32)), torch.from_numpy(noise.astype(np.float32))],
    dim=0,
).unsqueeze(0)  # (1, 2, n_samples)

# Pure inference: no autograd graph needed.
with torch.no_grad():
    est_targets = model(model_input).squeeze()
est_targets = est_targets.cpu().numpy()
if est_targets.shape != mixed.shape:
    raise ValueError(
        f"model output shape {est_targets.shape} does not match input shape {mixed.shape}"
    )

# What the network removed from the observation.
est_noise = (mixed - est_targets).squeeze()

sep_signal = ChannelFrame.from_ndarray(
    np.stack([source_signal, source_signal_r, est_targets, source_signal_h, est_noise], axis=0),
    sampling_rate=fs,
    labels=["obs", "hpss residual", "dnn percussive", "hpss harmonic", "dnn harmonic"],
)

In [None]:
# Summarise the five stacked signals (observation, HPSS residual/harmonic, DNN
# percussive/harmonic) with wandas' ChannelFrame.describe() for side-by-side comparison.
sep_signal.describe()