In [None]:
%matplotlib widget
import numpy as np
import pyroomacoustics as pra
import soundfile as sf
import os
import matplotlib.pyplot as plt
from wandas.core import ChannelFrame
from models.overlap_add import OverlapAdd

In [None]:
# Load the dry source recordings (e.g. utterances from the LibriSpeech dataset).
data_dir = "./data"
fs = 32000  # simulation sample rate [Hz]; every source file must already be at this rate
clip_seconds = 5  # only the first `clip_seconds` seconds of each recording are simulated
source_audio_files = [
    "target.wav",
    "noise.wav",
]
source_audio_files = [os.path.join(data_dir, f) for f in source_audio_files]
assert all(os.path.exists(f) for f in source_audio_files), "音声ファイルが見つかりません"


def load_source_signal(path, expected_fs, max_seconds):
    """Read one mono recording and return its leading excerpt.

    Args:
        path: Audio file to read with ``soundfile``.
        expected_fs: Sample rate [Hz] the room simulation runs at.
        max_seconds: Length of the excerpt to keep, in seconds.

    Returns:
        1-D array holding at most ``expected_fs * max_seconds`` samples.

    Raises:
        ValueError: If the file's sample rate differs from ``expected_fs``
            (the samples would otherwise be simulated at the wrong speed),
            or if the file holds more than one channel.
    """
    data, file_fs = sf.read(path)
    if file_fs != expected_fs:
        raise ValueError(
            f"{path}: sample rate is {file_fs} Hz but the simulation runs at "
            f"{expected_fs} Hz; resample the file first"
        )
    if data.ndim != 1:
        raise ValueError(
            f"{path}: expected a mono recording, got an array of shape {data.shape}"
        )
    return data[: expected_fs * max_seconds]


source_signals = [
    load_source_signal(path, fs, clip_seconds) for path in source_audio_files
]


# Room outline given by its corner coordinates (an 8 x 8 square).
corners = np.array([
    [0, 0], [8, 0], [8, 8], [0, 8]
]).T  # transpose to shape [2, N]

# Wall material.
absorption = 0.4  # absorption coefficient applied to every wall
scattering = 0.1  # scattering coefficient

# Build the room.
room = pra.Room.from_corners(
    corners=corners,
    fs=fs,
    materials=pra.Material(absorption, scattering),
    max_order=12  # maximum reflection order
)

# Place the microphones.
mic_positions = np.array([[1.0, 1.0], [7.0, 7.0]]).T  # coordinates of the two microphones
room.add_microphone_array(mic_positions)

# Place the sources and attach their signals
# (source i plays source_signals[i]).
source_positions = [[2.0, 2.0], [6.0, 6.0]]  # coordinates of the two sources
for position, signal in zip(source_positions, source_signals):
    room.add_source(position, signal=signal)


# Draw the layout. The room is a plain square without obstacles, so the title
# only mentions what is actually in the figure.
fig, ax = room.plot()
ax.set_title("Room with Multiple Microphones/Sources")
plt.show()

# Run the simulation; later cells read room.rir and room.mic_array.signals.
room.simulate()


In [None]:
# Signals captured by the microphone array.
mic_signals = room.mic_array.signals

# Reverberant image of every source as heard at microphone 0
# (dry signal convolved with the corresponding room impulse response).
#
# The causal ("full") convolution is truncated to the dry-signal length so the
# references keep the time origin of the convolution. mode="same" would return
# the *centred* part instead, which advances each reference by
# (len(rir) - 1) // 2 samples and breaks sample-wise comparison with the
# microphone signal.
reverberant_signals = []
for i in range(len(room.sources)):
    dry_signal = source_signals[i]
    rir = room.rir[0][i]  # RIR between microphone 0 and source i
    reverberant = np.convolve(dry_signal, rir, mode="full")[: len(dry_signal)]
    reverberant_signals.append(reverberant)
reverberant_signals = np.array(reverberant_signals)

# Stack the microphone signals (cropped to the reference length) on top of the
# per-source references and wrap everything in one labelled frame:
#   rows from mic_signals         -> "mixed", "noise"
#   rows from reverberant_signals -> "tgt" (source 0), "noise_clean" (source 1)
n_samples = reverberant_signals.shape[-1]
val_signal = ChannelFrame.from_ndarray(
    np.vstack((mic_signals[..., :n_samples], reverberant_signals)),
    sampling_rate=fs,
    labels=["mixed", "noise"] + ["tgt", "noise_clean"],
)


In [None]:
# Show the summary of the four validation channels assembled above
# (what exactly is rendered is defined by wandas' ChannelFrame.describe).
val_signal.describe()

In [None]:
import torch
import yaml
import numpy as np
from wandas.core import ChannelFrame
# TwoChDCUNet comes from the `models` package.
# NOTE(review): nothing in this cell modifies sys.path, so `models` presumably
# has to be importable from the notebook's working directory — confirm.
from models.dcunet import TwoChDCUNet

from torchmetrics.audio.snr import (
    signal_noise_ratio as snr,
)

# Let PyTorch trade some float32 matmul precision for speed
# (see the torch.set_float32_matmul_precision documentation for what "high" permits).
torch.set_float32_matmul_precision("high")


In [None]:
# Read the fine-tuning configuration that parameterises the network.
with open("./DCUNet/conf_finetuning.yml") as f:
    conf = yaml.safe_load(f)
conf["exp_dir"] = "../exp/DCUNet_TwoNoise_ClossTalk"

# Instantiate the network from the filterbank / mask-net sections of the config.
model = TwoChDCUNet(
    **conf["filterbank"],
    **conf["masknet"],
    sample_rate=conf["data"]["sample_rate"],
)

# Load the checkpoint on the CPU; weights_only=True restricts unpickling to tensors.
state_dict = torch.load("./exp/checkpoints/epoch=66-step=670000.ckpt", weights_only=True, map_location="cpu")

# The checkpoint stores the whole training system. Keep only the entries whose
# key starts with "model." and strip that prefix so the names match `model`.
checkpoint_prefix = "model."
model_state_dict = {
    name[len(checkpoint_prefix):]: tensor
    for name, tensor in state_dict["state_dict"].items()
    if name.startswith(checkpoint_prefix)
}

# Copy the weights into the network and switch it to inference mode.
model.load_state_dict(model_state_dict)
model.eval()

In [None]:
def process_two_channel_audio(mixed, noise_clean, model, ola):
    """Run ``model`` over a two-channel recording segment by segment via ``ola``.

    The two inputs are stacked into a ``[2, length]`` array and handed to
    ``ola.process_signal`` together with a callback that processes a batch of
    segments with the network.

    Args:
        mixed: 1-D array; becomes channel 0 of the model input. The noise
            estimate is computed as this channel minus the model output.
        noise_clean: 1-D array of the same length; becomes channel 1 of the
            model input.
        model: ``torch.nn.Module`` called with a tensor of shape
            ``(n_segments, 2, window_size)``. It must return one estimate per
            segment with ``n_segments * window_size`` elements in total, e.g.
            shaped ``(n_segments, window_size)`` or
            ``(n_segments, 1, window_size)``. ``None`` disables the network:
            segments are passed through unchanged.
        ola: Object exposing ``process_signal(x, fn)``; ``fn`` receives and
            returns arrays shaped ``(n_segments, n_channels, window_size)``.

    Returns:
        Whatever ``ola.process_signal`` returns. The caller treats it as a
        ``[2, length]`` array: row 0 is the estimated target, row 1 the
        estimated noise.
    """
    # Input layout: [ch, length] = [2, length]
    x = np.stack([mixed, noise_clean], axis=0).astype(np.float64)

    def batch_process_with_dnn(segments):
        """Process a batch of multi-channel segments with the network.

        Args:
            segments: Array of shape ``(n_segments, n_channels, window_size)``.

        Returns:
            Array of shape ``(n_segments, 2, window_size)`` holding the
            estimated target (channel 0) and estimated noise (channel 1), or
            ``segments`` itself when no model is given.
        """
        # Pass-through when there is no model.
        if model is None:
            return segments

        # Match the model's device / dtype; fall back to CPU float32 for a
        # module without parameters instead of raising StopIteration.
        first_param = next(model.parameters(), None)
        if first_param is None:
            device, dtype = torch.device("cpu"), torch.float32
        else:
            device, dtype = first_param.device, first_param.dtype

        with torch.no_grad():
            # NumPy -> PyTorch, already laid out as (batch, channels, time).
            tsegs = torch.from_numpy(segments).to(device=device, dtype=dtype)
            n_segments, window_size = tsegs.shape[0], tsegs.shape[-1]

            # Collapse the output to (n_segments, window_size) explicitly.
            # A bare .squeeze() would also drop the batch axis when the batch
            # holds exactly one segment and break the stacking below.
            est_targets = model(tsegs).reshape(n_segments, window_size)

            # Everything in channel 0 that is not target is treated as noise.
            est_noise = tsegs[:, 0] - est_targets

            # PyTorch -> NumPy, channels stacked as [target, noise].
            result = torch.stack([est_targets, est_noise], dim=1).cpu().numpy()
        return result

    # Segment, process and recombine via overlap-add.
    processed = ola.process_signal(x, batch_process_with_dnn)

    # Output layout: [ch, length] = [2, length]
    return processed


In [None]:
# Pull the channels needed for inference and scoring out of the validation
# frame and repeat each of them five times back to back.
# NOTE: `noise_clean` here is the tiled "noise" channel of val_signal (a
# microphone signal), not val_signal's own "noise_clean" channel.
noise = val_signal["noise"].data
mixed = np.tile(val_signal["mixed"].data, 5)
noise_clean = np.tile(noise, 5)
tgt = np.tile(val_signal["tgt"].data, 5)

# Overlap-add with a Hann window of three seconds' worth of samples (at the
# model's sample rate) and a hop of half a window.
model_sample_rate = conf["data"]["sample_rate"]
window_size = model_sample_rate * 3
hop_size = window_size // 2
ola = OverlapAdd(window_size=window_size, hop_size=hop_size, window='hann')

# Separate, then bundle inputs, reference and estimates into one labelled frame.
processed = process_two_channel_audio(mixed, noise_clean, model, ola)
channel_rows = [mixed, noise_clean, tgt, processed[0], processed[1]]
channel_labels = ["mixed", "noise_clean", "tgt", "est_tgt", "est_noise"]
sep_signal_segment = ChannelFrame.from_ndarray(
    np.stack(channel_rows, axis=0),
    sampling_rate=model_sample_rate,
    labels=channel_labels,
)


In [None]:
# Report the SNR of every channel measured against the tiled target reference,
# then render the frame's audio output.
tgt_tensor = torch.from_numpy(tgt)
for ch in sep_signal_segment:
    ch_snr = snr(torch.from_numpy(ch.data), tgt_tensor)
    print(f"{ch.label} snr:{ch_snr}")
sep_signal_segment.to_audio()

In [None]:
# Plot the STFT of every channel with the same 0-80 value range so the
# figures can be compared side by side.
for ch in sep_signal_segment:
    spectrogram = ch.stft()
    spectrogram.plot(vmin=0, vmax=80)