In [13]:
import torch
import torchaudio
from resemblyzer import VoiceEncoder, preprocess_wav
from thop import profile


In [14]:
# Load audio file
waveform, sample_rate = torchaudio.load("data/20250306170609.wav")

# Resample if necessary
if sample_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

# d-vector output encoder
speaker_encoder = VoiceEncoder(device="cpu")
# Convert to numpy array and preprocess for d-vector
waveform_preprocessed = preprocess_wav(waveform.numpy().squeeze())
speaker_embedding = torch.from_numpy(speaker_encoder.embed_speaker([waveform_preprocessed])).unsqueeze(0)
speaker_embedding.shape

Loaded the voice encoder model on cpu in 0.02 seconds.


torch.Size([1, 256])

In [15]:
from model import SpeakerBeamSS

In [16]:
batch_size = 1
input_len = 16000  # 1秒分 @16kHz
mixture = torch.randn(batch_size, 1, input_len)

model = SpeakerBeamSS()
with torch.no_grad():
    out = model(mixture, speaker_embedding)
    print("Input:", mixture.shape, "Output:", out.shape)
    flops, params = profile(model, inputs=(mixture, speaker_embedding))
    print(f"FLOPs: {flops / 1e9:.2f}G, Params: {params / 1e6:.2f}M")

Input: torch.Size([1, 1, 16000]) Output: torch.Size([1, 1, 16000])
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose1d'>.
FLOPs: 21.60G, Params: 7.64M
