# Imports

In [None]:
from numpy import pi
from numpy.random import default_rng
from tqdm import trange

from asintf.audio import binauralize, player
from asintf.datasets import load_file
from asintf.IS_BWLP import IS_BWLP
from asintf.reconstruction import mimo_mwf
from asintf.stft import analysis, estimate_covariance_matrices, magnitude_compression, synthesis

# Parameters

In [None]:
# Example recording to load. Alternatives: MUSIC_I_example.pkl,
# MUSIC_II_example.pkl, SPEECH_example.pkl, DSD100_example.pkl
file_path = 'MUSIC_II_example.pkl'

# STFT settings shared by `analysis` and `synthesis`: window name, frame
# length and overlap between consecutive frames (both in samples).
window = 'hann'
samples_per_frame = 2048
overlapping_samples = 1024  # half of samples_per_frame, i.e. 50 % overlap
# Backward-compatible alias for the original (misspelled) name, which the
# other cells still reference.
overlapping_sampes = overlapping_samples

# IS_BWLP model size and optimisation budget.
components_per_source = 25
number_of_directions = 162
number_of_iterations = 100

# Equals 1 / sqrt(pi). Not referenced by the visible cells — TODO confirm
# where (or whether) it is consumed.
standard_deviation = pi ** (-1 / 2)

# Fixed seed so that every run of the notebook draws the same random numbers
# (the generator is handed to IS_BWLP, presumably for its initialisation).
random_generator = default_rng(0)

# Audio loading

In [None]:
# Load the example: sampling rate, the ambisonic image of each source, an item
# not needed here, and the directions of arrival (Cartesian, per the name).
(
    sampling_frequency,
    ambisonic_source_images,
    _,  # third item returned by load_file is discarded
    directions_of_arrival_cartesian,
) = load_file(file_path)
# The first axis indexes the sources: summing over it yields the mixture and
# its length is the number of sources.
ambisonic_mixture = ambisonic_source_images.sum(axis=0)
number_of_sources = len(ambisonic_source_images)

# STFT & preprocessing

In [None]:
# Forward STFT of the mixture. Only the third item returned by `analysis` is
# kept — presumably the (frequencies, times, stft) layout of
# scipy.signal.stft; confirm against asintf.stft.
stft = analysis(
    ambisonic_mixture,
    sampling_frequency,
    window,
    samples_per_frame,
    overlapping_sampes,
)[2]


# Estimation

In [None]:
# Build the IS_BWLP estimator from the mixture STFT and the model settings,
# then run a fixed number of update iterations (trange shows progress).
is_bwlp = IS_BWLP(
    stft,
    number_of_sources,
    components_per_source,
    number_of_directions,
    directions_of_arrival_cartesian,
    random_generator,
)
# The iteration counter itself is never used, so don't bind a named variable.
for _ in trange(number_of_iterations):
    is_bwlp.iteration()

# Reconstruction

In [None]:
# Filter the mixture STFT with the covariance matrices estimated by IS_BWLP
# (presumably a MIMO multichannel Wiener filter, per the name `mimo_mwf`).
reconstructed_stft = mimo_mwf(stft, is_bwlp.covariance_matrices)
# Back to the time domain. Only the second item returned by `synthesis` is
# kept — presumably the (times, signal) layout of scipy.signal.istft; confirm.
reconstructed_ambisonic_source_images = synthesis(
    reconstructed_stft,
    sampling_frequency,
    window,
    samples_per_frame,
    overlapping_sampes,
)[1]

# Playback

In [None]:
# First play a binaural rendering of the input mixture ...
binaural_audio = binauralize(ambisonic_mixture, sampling_frequency)
player(binaural_audio, sampling_frequency, 'Input mixture')

# ... then one player per estimated source image, labelled by 0-based index.
for source_index in range(number_of_sources):
    binaural_audio = binauralize(
        reconstructed_ambisonic_source_images[source_index], sampling_frequency
    )
    player(binaural_audio, sampling_frequency, f'Estimated source {source_index}')