In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before code is executed and edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Show matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import paderbox as pb

In [None]:
from collections import defaultdict
import itertools
from mms_msg.visualization.plot import plot_mixture
from mms_msg import keys
from mms_msg.simulation.utils import load_audio
        
def plot_mixtures(generator_dataset, number=6, columns=3, figure_width=10):
    """Plot the first `number` examples of `generator_dataset`.

    Every example is passed to `plot_mixture` together with an axis obtained
    from a `pb.visualization.axes_context`.

    Args:
        generator_dataset: Iterable of examples that `plot_mixture` can draw.
        number: Maximum number of examples taken from the start of the
            iterable.
        columns: Forwarded as `columns` to the axes context.
        figure_width: First entry of the `figure_size` tuple handed to the
            axes context; the second entry is fixed to 3.
    """
    figure_size = (figure_width, 3)
    axes_context = pb.visualization.axes_context(columns=columns, figure_size=figure_size)
    with axes_context as axes:
        # islice stops after `number` examples without consuming further ones
        for example in itertools.islice(generator_dataset, number):
            plot_mixture(example, ax=axes.new)

In [None]:
import mms_msg
from mms_msg import sampling

## Preparation: The input dataset
The mixture/meeting generators are generic, i.e., they work with any database that contains examples of single-speaker speech.
The input database has to have its examples in the correct format, i.e., they have to contain the correct keys.

The examples have to have the following format:
 - `example_id` (`str`): The ID of the input example. Has to be unique in the input dataset
 - `num_samples` or `num_samples.observation` (`int`): The number of samples in the example
 - `speaker_id` (`str`): The ID of the speaker that uttered the speech in this example
 - `audio_path` or `audio_path.observation` (`str`): The path to the audio file. In the generated examples, this path is stored under `audio_path.original_source`
 
For meeting data additionally:
 - `scenario` (`str`): An identifier that uniquely identifies a "scenario" that should not change for a single speaker in a meeting. E.g., in LibriSpeech the scenario should be `f"{chapter_id}_{speaker_id}"`. Defaults to `speaker_id`.

All other keys are simply copied over from the input examples, so all information present in the input examples will be present in the generated mixtures.

In [None]:
# Prepare input datasets. Use LibriSpeech here because it is freely available, but WSJ (or any other database) works as well
from mms_msg.databases.single_speaker.librispeech.database import LibriSpeech8kHz
input_db = LibriSpeech8kHz()
# `input_ds` is the single-speaker input dataset that all pipelines below are built from
input_ds = input_db.get_dataset('test_clean')
# Display the first input example (cell output) to inspect its keys
input_ds[0]

## Fully overlapped mixtures

### Like WSJ0-2mix

In [None]:
# Deterministic, anechoic, no offset, like WSJ0-2mix.
# The pipeline is written as a single chain; each `.map` stage is optional.
ds = (
    # Compute a composition of base examples. This makes sure that the speaker
    # distribution in the mixtures is equal to the speaker distribution in the
    # original database.
    sampling.source_composition.get_composition_dataset(input_dataset=input_ds, num_speakers=2)
    # If required: Offset the utterances (a constant offset of 0 here)
    .map(sampling.pattern.classical.ConstantOffsetSampler(0))
    # If required: Add log_weights to simulate volume differences
    .map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))
    # If required: Truncate to the shorter utterance
    .map(mms_msg.simulation.truncation.truncate_min)
)

# Cell output: number of mixtures and the first generated example
len(ds), ds[0]

In [None]:
# Plot the first examples of `ds` with the defaults of `plot_mixtures` (6 examples, 3 columns)
plot_mixtures(ds)

In [None]:
# Load audio: read the original source signals and apply the anechoic
# scenario map function to every example.
# NOTE: `ds` is rebound to its own mapped version, so re-running only this
# cell stacks both map steps on top of the previous result.
from mms_msg import keys
ds = (
    ds
    .map(lambda example: load_audio(example, keys.ORIGINAL_SOURCE))
    .map(mms_msg.simulation.anechoic.anechoic_scenario_map_fn)
)
ex = ds[0]
# Play back the observation of the first example
pb.io.play(ex['audio_data']['observation'], sample_rate=8000)

### Like SMS-WSJ

In [None]:
from mms_msg.databases.reverberation.sms_wsj import SMSWSJRIRDatabase

# Two-speaker composition followed by the samplers for offset, scaling,
# SNR and room impulse responses, written as one chain.
ds = (
    sampling.source_composition.get_composition_dataset(input_dataset=input_ds, num_speakers=2)
    .map(sampling.pattern.classical.SMSWSJOffsetSampler())
    .map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))
    .map(sampling.environment.noise.UniformSNRSampler(20, 30))
    # RIRs are sampled from the 'test_eval92' dataset of the SMS-WSJ RIR database
    .map(sampling.environment.rir.RIRSampler(SMSWSJRIRDatabase().get_dataset('test_eval92')))
)

# Cell output: the first generated example
ds[0]

In [None]:
# Plot the first examples of the current `ds` (defaults: 6 examples, 3 columns)
plot_mixtures(ds)

In [None]:
# Load an example: read the original sources and the RIRs, then apply the
# reverberant scenario map function and the white microphone noise function.
# NOTE: `ds` is rebound to its own mapped version, so re-running only this
# cell stacks all three map steps on top of the previous result.
from mms_msg import keys
ds = (
    ds
    .map(lambda example: load_audio(example, keys.ORIGINAL_SOURCE, keys.RIR))
    .map(mms_msg.simulation.reverberant.reverberant_scenario_map_fn)
    .map(mms_msg.simulation.noise.white_microphone_noise)
)
ex = ds[0]
# Play back the observation of the first example
pb.io.play(ex['audio_data']['observation'], sample_rate=8000)

### Dynamic Mixing

In [None]:
# Dynamic mixing: Set the rng argument to `True` to get a non-deterministic dataset that changes its contents
# every time it is iterated. Useful if you want to train on an infinite stream of randomly generated examples
# TODO: dynamic_ -> rng
ds = sampling.source_composition.get_composition_dataset(input_dataset=input_ds, num_speakers=2, rng=True)
# Only the call above this line changed compared to the deterministic case
# -------------------------------------------------------------------------------------------------------------------
# The part below this line is deterministic: it uses the same offset and
# scaling samplers as the SMS-WSJ-like cell
ds = (
    ds
    .map(sampling.pattern.classical.SMSWSJOffsetSampler())
    .map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))
)

In [None]:
# Check that iterating two times gives different examples
# (every call to `plot_mixtures` starts a new iteration over `ds`)
for _ in range(2):
    plot_mixtures(ds, number=3)

## Generate Meetings

### Anechoic

In [None]:
# Deterministic, anechoic, use the same base function as for SMS-WSJ, i.e., we have the same initial examples as SMS-WSJ
ds = (
    # A list is passed for the number of speakers instead of a single value
    sampling.source_composition.get_composition_dataset(input_dataset=input_ds, num_speakers=[3, 4, 5])
    .map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))
    # NOTE(review): 60*8000 is presumably 60 s expressed in samples at the
    # 8 kHz used elsewhere in this notebook - confirm the unit of `duration`
    .map(sampling.pattern.meeting.MeetingSampler(duration=60*8000)(input_ds))
)

# Cell output: the first generated meeting example
ds[0]

In [None]:
# Plot six meetings in two columns, using a wider figure than the default
plot_mixtures(ds, columns=2, figure_width=20, number=6)

### With reverberation

In [None]:
# With rir, use the same base function as for SMS-WSJ, i.e., we have the same initial examples as SMS-WSJ
import functools
ds = (
    sampling.source_composition.get_composition_dataset(input_dataset=input_ds, num_speakers=[3, 4])
    .map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))
    # RIRs are sampled from the 'test_eval92' dataset of the SMS-WSJ RIR database
    .map(sampling.environment.rir.RIRSampler(SMSWSJRIRDatabase().get_dataset('test_eval92')))
    # NOTE(review): 60*8000 is presumably 60 s expressed in samples at the
    # 8 kHz used elsewhere in this notebook - confirm the unit of `duration`
    .map(sampling.pattern.meeting.MeetingSampler(duration=60*8000)(input_ds))
)

# Cell output: the first generated meeting example
ds[0]

## Class-based interface

In [None]:
# Class-based interface: create the Libri2MixClean database object used by the following cells
db = mms_msg.databases.classical.full_overlap.Libri2MixClean()

In [None]:
# Dataset names are the same as LibriSpeech
# (the names are displayed as the cell output)
db.dataset_names

In [None]:
# Plot the first three examples of the 'test_clean' dataset of the database object
plot_mixtures(db.get_dataset('test_clean'), number=3)

In [None]:
# Dynamic mixing can be enabled by appending "_rng" (for a random seed) or "_rng<seed>" (for a fixed seed) to the dataset name.
# The top two plotted rows are different because the seed is random by default
# The bottom two plotted rows are equal because the seed is fixed to 42
plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)
plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)
plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)
plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)

In [None]:
# Audio can be loaded with the `load_example` method of the database object
ex = db.load_example(db.get_dataset('test_clean')[0])
# NOTE(review): unlike the earlier `pb.io.play` calls in this notebook, no
# `sample_rate` is passed here - confirm that the default of `pb.io.play`
# matches the sample rate of this database's audio.
# Play the mixture first, then the two entries of 'speech_image'
pb.io.play(ex['audio_data']['observation'], name='observation')
pb.io.play(ex['audio_data']['speech_image'][0], name='speech_image 1')
pb.io.play(ex['audio_data']['speech_image'][1], name='speech_image 2')

## Generate JSON
A JSON file that contains the hyperparameters and can be read with a `lazy_dataset.JSONDatabase` object can easily be created by iterating over the dataset.
If you want to dump all generated signals as audio files, refer to TODO

In [None]:
from tqdm.notebook import tqdm

# Build {'datasets': {<dataset name>: {...}}}: for every dataset name of the
# database, the pairs yielded by `.items()` are collected into a plain dict
# while tqdm shows a progress bar labelled with the dataset name.
database_dict = {
    'datasets': {
        dataset_name: dict(
            tqdm(db.get_dataset(dataset_name).items(), desc=dataset_name)
        )
        for dataset_name in db.dataset_names
    }
}

# Write the collected structure to 'libri_mix.json'
pb.io.dump(database_dict, 'libri_mix.json')