In [None]:
import io
import functools

import numpy as np

import soundfile
import matplotlib
import matplotlib.pylab as plt
from IPython.display import display, Audio

from einops import rearrange

from nara_wpe.utils import stft, istft

from pb_bss.distribution import CACGMMTrainer, CBMMTrainer, CWMMTrainer
from pb_bss.permutation_alignment import DHTVPermutationAlignment, OraclePermutationAlignment
from pb_bss.evaluation import InputMetrics, OutputMetrics

In [None]:
def string_function(a):
    """Return a compact text form of an array for notebook output.

    Arrays with fewer than 50 elements are rendered with the normal ``str``;
    larger arrays are summarised by their shape and dtype so outputs stay short.
    """
    if a.size < 50:
        return str(a)
    return f'array(shape={a.shape}, dtype={a.dtype})'


# ``np.set_string_function`` was removed in NumPy 2.0; newer NumPy releases
# offer the same repr hook as ``np.set_printoptions(override_repr=...)``.
# This is purely cosmetic, so on a NumPy that has neither, keep the default repr
# instead of crashing the notebook.
if hasattr(np, 'set_string_function'):
    np.set_string_function(string_function)
else:
    try:
        np.set_printoptions(override_repr=string_function)
    except TypeError:
        pass

# Read some data
When reading local files, you can call `soundfile.read` directly.

Read the observation $\mathbf y_n$, the source images $\mathbf x_{0, n}$ and $\mathbf x_{1, n}$, and the noise image $\mathbf n_n$, which are related by
$$
\mathbf{y}_n = \mathbf x_{0, n} + \mathbf x_{1, n} + \mathbf n_{n} \in \mathbb R ^ D
$$

In [None]:
from urllib.request import urlopen

sample_rate = 8000


@functools.lru_cache()
def soundfile_read(url):
    """Download an audio file from ``url`` and decode it with ``soundfile``.

    Results are memoised per URL via ``functools.lru_cache``, so each file is
    downloaded only once per kernel. The cache returns the *same* array object
    on every call, so callers must not modify the returned array in place.

    Args:
        url: Location of an audio file that ``soundfile`` can decode.

    Returns:
        The decoded samples, transposed relative to ``soundfile.read``'s
        output (presumably channels first for multi-channel files; TODO
        confirm), as a C-contiguous array.

    Raises:
        AssertionError: If the file's sample rate differs from ``sample_rate``.
    """
    # Close the HTTP response deterministically instead of leaving the
    # connection open until garbage collection.
    with urlopen(url) as response:
        payload = response.read()
    data, data_sample_rate = soundfile.read(io.BytesIO(payload))

    assert sample_rate == data_sample_rate, (sample_rate, data_sample_rate)

    print(f'Read: {url}.\nSample rate: {data_sample_rate}')
    return np.ascontiguousarray(data.T)



In [None]:

# All test signals live in one directory of the fgnt/pb_test_data repository.
# Point `data_url` at another directory with the same file names to load
# different data.
data_url = "https://github.com/fgnt/pb_test_data/raw/master/bss_data/low_reverberation"

observation = soundfile_read(f"{data_url}/observation.wav")
speech_image_0 = soundfile_read(f"{data_url}/speech_image_0.wav")
speech_image_1 = soundfile_read(f"{data_url}/speech_image_1.wav")
# Stack the two speaker images along a new leading (speaker) axis.
speech_image = np.array([speech_image_0, speech_image_1])
noise_image = soundfile_read(f"{data_url}/noise_image.wav")
speech_source_0 = soundfile_read(f"{data_url}/speech_source_0.wav")
speech_source_1 = soundfile_read(f"{data_url}/speech_source_1.wav")
speech_source = np.array([speech_source_0, speech_source_1])

observation, speech_image, noise_image, speech_source

In [None]:
def plot_mask(mask, *, ax=None):
    """Show a mask / affiliation matrix as an image with a colour bar.

    Args:
        mask: 2-D array drawn on a fixed [0, 1] colour scale with
            ``origin='lower'``. The cells of this notebook pass
            (frequency, time) slices, so low frequencies are at the bottom.
        ax: Axes to draw into; defaults to the current axes (``plt.gca()``).

    Returns:
        The axes that were drawn into.
    """
    if ax is None:
        ax = plt.gca()
    image = ax.imshow(
        mask,
        interpolation='nearest',
        vmin=0,
        vmax=1,
        origin='lower',
    )
    # Attach the colour bar to the figure that owns `ax`, not to pyplot's
    # "current" figure, which need not be the same one.
    ax.figure.colorbar(image, ax=ax)
    return ax

def plot_stft(stft_signal, *, ax=None, dynamic_range=60):
    """Show the energy of an STFT signal in dB, with a colour bar.

    Args:
        stft_signal: 2-D (usually complex) array drawn with ``origin='lower'``.
            The cells of this notebook pass (frequency, time) arrays.
        ax: Axes to draw into; defaults to the current axes (``plt.gca()``).
        dynamic_range: Displayed range in dB. Bins quieter than
            ``peak - dynamic_range`` dB are clipped to that level. The default
            of 60 dB is sufficient for visual inspection.

    Returns:
        The axes that were drawn into.
    """
    if ax is None:
        ax = plt.gca()

    # Energy is the squared magnitude. Taking 10*log10 of the plain magnitude
    # would halve every dB value and contradict the colour-bar label.
    energy = np.abs(stft_signal).astype(np.float64) ** 2

    # Clip `dynamic_range` dB below the peak. The `tiny` lower bound avoids
    # log10(0) = -inf when the input is entirely silent.
    floor = np.max(energy) * 10 ** (-dynamic_range / 10)
    floor = max(floor, np.finfo(np.float64).tiny)
    energy_db = 10 * np.log10(np.maximum(energy, floor))

    image = ax.imshow(
        energy_db,
        interpolation='nearest',
        origin='lower',
    )
    cbar = ax.figure.colorbar(image, ax=ax)
    cbar.set_label('Energy / dB')
    return ax

# Calculate the STFT signals
Variable names that start with an uppercase letter denote STFT-domain signals.

In [None]:
# Transform every time-domain signal with the same STFT parameters (512, 128).
# The results are laid out as (..., frames, frequency bins).
Observation, Speech_image, Noise_image = [
    stft(time_signal, 512, 128)
    for time_signal in (observation, speech_image, noise_image)
]
Observation, Speech_image, Noise_image

In [None]:
# First channel, transposed to (frequency, time) so frequency is the vertical axis.
plot_stft(Observation[0, :, :].T)

# Train a mixture model on each frequency

- Instantiate a trainer.
  - In our experiments, the cACGMM yields the best results most of the time.
  - The trainer is stateful, because some distributions (e.g. complex Watson) instantiate a numeric solver.
- Call the `fit` function. There are three ways to start the EM algorithm:
  - `initialization`: np.array with shape (..., K, N)
     - An affiliation mask that indicates the initial probabilities.
  - `num_classes`: scalar
     - Calculates an i.i.d. initialization mask and then proceeds as in the case above.
  - `initialization`: a probability model instance (i.e. the model returned by the `fit` function, e.g. a CACGMM instance)
     - Provide a trained model.
- Call `predict` on the trained model to obtain the affiliations (i.e. the posterior).

In [None]:
# cACGMM trainer. CWMMTrainer() and CBMMTrainer() (imported above) can be
# swapped in here instead.
trainer = CACGMMTrainer()

In [None]:
# The mixture model is trained independently for each frequency, so frequency
# becomes the leading axis: (channel, time, frequency) -> (frequency, time, channel).
Observation_mm = rearrange(Observation, 'd t f -> f t d')

model = trainer.fit(
    Observation_mm,
    num_classes=3,  # speaker 0, speaker 1, noise
    iterations=40,
    inline_permutation_aligner=None,  # alignment is done explicitly further below
)
model

In [None]:
# Posterior class affiliations; indexed as (frequency, class, time) in the cells below.
affiliation = model.predict(Observation_mm)
affiliation

# Permutation alignment
 - The model is trained independently for each frequency, so the order (permutation) of the classes differs randomly between frequencies (see the next plot).
 - Apply a permutation alignment across the frequencies; the algorithm was originally implemented by Dang Hai Tran Vu.
 - Calculate the global permutation to identify the speakers and the noise (this uses oracle information).
   - Use the observation multiplied by the mask as the enhanced signal (ToDo: Move the beamformer code to pb_bss and use beamforming instead of masking).

In [None]:
# Class 0 over (frequency, time), before permutation alignment.
plot_mask(affiliation[:, 0])

In [None]:
# Uses the same STFT size (512) as the STFT cells above.
pa = DHTVPermutationAlignment.from_stft_size(512)

In [None]:

# Reorder once to (class, frequency, time), the layout passed to the aligner,
# instead of recomputing the same rearrangement for each call.
affiliation_kft = rearrange(affiliation, 'f k t -> k f t')

mapping = pa.calculate_mapping(affiliation_kft)
affiliation_pa = pa.apply_mapping(affiliation_kft, mapping)

# Alternative to obtain affiliation_pa in a single call:
#     affiliation_pa = pa(affiliation_kft)

affiliation_pa

In [None]:
# Quick look (shapes / dtypes) at the STFT signals used by the next cell.
(Observation, Speech_image, Noise_image)

In [None]:
# Mask every class onto all channels:
# (K, 1, T, F) * (1, D, T, F) -> (K, D, T, F), then flatten per class.
global_pa_est = rearrange(affiliation_pa, 'k f t -> k 1 t f') * Observation[None]
global_pa_est = rearrange(global_pa_est, 'k d t f -> k (d t f)')
# Oracle references in the class order used below: speaker 0, speaker 1, noise.
global_pa_reference = rearrange(
    np.array([*Speech_image, Noise_image]), 'k d t f -> k (d t f)'
)

global_pa = OraclePermutationAlignment()
global_permutation = global_pa.calculate_mapping(global_pa_est, global_pa_reference)

# Reorder the classes; the following cells treat classes 0 and 1 as the
# speakers and class 2 as the noise.
affiliation_pa = affiliation_pa[global_permutation]
global_permutation

# Visualize the output
 - Display the masks, the enhanced signals and the clean signals

In [None]:
reference_channel = 0
class_titles = ('speech_image_0', 'speech_image_1', 'noise_image')

# Row 1: the aligned masks, one (frequency, time) image per class.
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
for ax, mask, title in zip(axes, affiliation_pa, class_titles):
    plot_mask(mask, ax=ax)
    ax.set_title(f'Mask: {title}')

# Row 2: mask-based estimates, i.e. the reference-channel observation
# (transposed to (frequency, time)) multiplied by each class mask.
# The unpacking relies on the model having been fit with three classes.
Speech_image_0_est, Speech_image_1_est, Noise_image_est = (
    Observation[reference_channel].T * affiliation_pa
)
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
estimates = (Speech_image_0_est, Speech_image_1_est, Noise_image_est)
for ax, Estimate, title in zip(axes, estimates, class_titles):
    plot_stft(Estimate, ax=ax)
    ax.set_title(f'Estimate: {title}')

# Row 3: oracle images on the reference channel, for comparison.
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
references = (
    Speech_image[0, reference_channel].T,
    Speech_image[1, reference_channel].T,
    Noise_image[reference_channel].T,
)
for ax, Reference, title in zip(axes, references, class_titles):
    plot_stft(Reference, ax=ax)
    ax.set_title(f'Oracle: {title}')

In [None]:
# Masked STFT estimate of speaker 0, laid out as (frequency, time).
Speech_image_0_est

In [None]:
# Back to the time domain: transpose to (time, frequency) for istft and trim
# to the length of the observation.
speech_image_0_est, speech_image_1_est, noise_image_est = [
    istft(Estimate.T, 512, 128)[..., :observation.shape[-1]]
    for Estimate in (Speech_image_0_est, Speech_image_1_est, Noise_image_est)
]
speech_image_0_est, speech_image_1_est, noise_image_est

In [None]:
# Listen to the two separated speakers (label, then player, for each speaker).
display(
    'speech_image_0_est', Audio(speech_image_0_est, rate=sample_rate),
    'speech_image_1_est', Audio(speech_image_1_est, rate=sample_rate),
)

# Metrics

pb_bss provides several metrics. Some are wrappers around external libraries (e.g. peaq, stoi), and some are implemented in this package (e.g. invasive SXR). `mir_eval` is an external library whose metrics were slightly modified in this package.

- Input metrics:
  - Note: The input metrics are calculated for each speaker and each channel.
    With all values available, the user can analyse the scores in detail (e.g. the SNR per channel).
- Output metrics:
  - Note: invasive_sxr is only defined for linear enhancements.

## Input metric

In [None]:
input_metric = InputMetrics(
    observation=observation,
    speech_source=speech_source,
    speech_image=speech_image,
    noise_image=noise_image,
    sample_rate=sample_rate,
)
# Build the score dictionary once and reuse it, instead of calling as_dict()
# separately for the keys and for the values.
input_scores = input_metric.as_dict()
print(input_scores.keys())
input_scores

Display the metrics for each speaker

In [None]:
for metric_name, scores in input_metric.as_dict().items():
    # Average over the last (channel) axis -> one value per speaker.
    print(metric_name, np.mean(scores, axis=-1))

Display the metrics for each channel

In [None]:
for metric_name, scores in input_metric.as_dict().items():
    # Average over the first (speaker) axis -> one value per channel.
    print(metric_name, np.mean(scores, axis=0))

Display the average metrics

In [None]:
for metric_name, scores in input_metric.as_dict().items():
    # Average over all speakers and channels -> one value per metric.
    print(metric_name, np.mean(scores))

## Output metric

Masking is a linear enhancement, so we can calculate the contributions of each speech source and of the noise to every estimate.

In [None]:
# Apply every class mask to every clean component on the reference channel:
# speech (K_source, 1, T, F) * masks (K, T, F) -> (K_source, K, T, F),
# noise (T, F) * masks (K, T, F) -> (K, T, F).
Speech_contribution = (
    Speech_image[:, reference_channel, np.newaxis, :, :]
    * rearrange(affiliation_pa, 'k f t -> k t f')
)
Noise_contribution = (
    Noise_image[reference_channel, :, :]
    * rearrange(affiliation_pa, 'k f t -> k t f')
)

# Back to the time domain, trimmed to the observation length.
speech_contribution, noise_contribution = (
    istft(Contribution, 512, 128)[..., :observation.shape[-1]]
    for Contribution in (Speech_contribution, Noise_contribution)
)

In [None]:
output_metric = OutputMetrics(
    speech_prediction=np.array([speech_image_0_est, speech_image_1_est, noise_image_est]),
    speech_source=speech_source,
    speech_contribution=speech_contribution,
    noise_contribution=noise_contribution,
    sample_rate=sample_rate,
)
# Build the score dictionary once and reuse it for both outputs.
output_scores = output_metric.as_dict()
print(output_scores.keys())
output_scores

In [None]:
# Build both score dictionaries once; the original rebuilt the input
# dictionary on every row of the table.
input_scores = input_metric.as_dict()
output_scores = output_metric.as_dict()

print(f'{"Score": <19}{"in": <22} + {"gain": <20} -> out')
print('-' * 61)
for name, output_value in output_scores.items():
    if name == 'mir_eval_sxr_selection':
        # Presumably the estimate-to-source assignment chosen by mir_eval,
        # not a score, so no gain is computed.
        print(name, output_value)
        continue
    out_mean = np.mean(output_value)
    if name not in input_scores:
        # No input-side counterpart; show the output score alone instead of
        # aborting the whole table with a KeyError.
        print(f'{name + ":": <19}{"n/a": <22} + {"n/a": <20} -> {out_mean}')
        continue
    in_mean = np.mean(input_scores[name])
    print(f'{name + ":": <19}{in_mean: <22} + {out_mean - in_mean: <20} -> {out_mean}')