In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
import copy

class ExecutableCodeBlocks:
    """Incrementally executes a sequence of generated code snippets.

    Snippets are registered with :meth:`add`. :meth:`execute` runs the
    executable ones in order and caches the namespace after each of them, so
    that on the next call only the first snippet that differs from the
    snippets registered before the last :meth:`step` (and everything after
    it) is re-executed. Non-executable snippets only appear in
    :meth:`get_code`.
    """

    def __init__(self):
        # Snippets of the previous round (before the last ``step()``), as
        # ``(code, executable)`` tuples; used for change detection.
        self.old_blocks = []
        # Snippets of the current round, as ``(code, executable)`` tuples.
        self.blocks = []
        # states[i]: namespace dict after the i-th *executable* snippet, or
        # None when it has to be (re-)computed.
        self.states = []

    def step(self):
        """Start a new round; the current snippets become the reference for change detection."""
        self.old_blocks = self.blocks
        self.blocks = []

    def add(self, code, executable=True):
        """Register a snippet; non-executable snippets are shown but never run."""
        self.blocks.append((code, executable))

    def execute(self):
        """Run the changed executable snippets and return the final namespace.

        Execution resumes from the cached namespace of the last unchanged
        snippet. Every snippet is executed with this module's globals and an
        explicit local namespace dict (instead of mutating ``locals()``,
        which is not guaranteed to work and does not since Python 3.13).

        Returns:
            A new dict with the names defined by the executed snippets, so
            callers cannot corrupt the cache. Empty if there are no
            executable snippets.

        Raises:
            Anything a snippet raises. The cached namespaces of the failing
            snippet and all following ones are invalidated first, so the next
            call re-executes them instead of reusing stale results.
        """
        num_executable = sum(1 for _, executable in self.blocks if executable)
        if len(self.states) < num_executable:
            self.states.extend([None] * (num_executable - len(self.states)))

        namespace = {}
        dirty = False
        index = 0  # index among the executable snippets
        for position, block in enumerate(self.blocks):
            code, executable = block
            if not executable:
                continue
            old = self.old_blocks[position] if position < len(self.old_blocks) else None
            if dirty or block != old or self.states[index] is None:
                if not dirty:
                    # First snippet to run: resume from the previous snippet's
                    # cached namespace (copied, so the cache stays intact).
                    if index > 0:
                        namespace = dict(self.states[index - 1])
                    dirty = True
                try:
                    exec(code, globals(), namespace)
                except BaseException:
                    for stale in range(index, len(self.states)):
                        self.states[stale] = None
                    raise
                self.states[index] = dict(namespace)
            index += 1

        if index == 0:
            return {}
        return dict(self.states[index - 1])

    def get_code(self):
        """Return the code of all snippets (executable or not), newline-joined."""
        return '\n'.join(code for code, _ in self.blocks)

    
def generate_code(
        input_database='WSJ', input_dataset_name='test_eval92',
        num_speakers=4, duration=60, p_silence=0.2, silence=(0, 2),
        overlap=(0, 8), example=0, reverberation=False,
        style='Meeting', snr_range=(20, 30),
        code_block=None,
):
    """Generate mms_msg example code for the given simulation settings.

    The code is registered in ``code_block`` as separate snippets: two
    imports, the input database, the composition, the offset/meeting
    sampling (all executable) and a final non-executable snippet that shows
    how to plot the example and load/play its audio.

    Args:
        input_database: 'WSJ' (8 kHz) or 'LibriSpeech' (16 kHz).
        input_dataset_name: Dataset name passed to ``db.get_dataset``.
        num_speakers: Passed to ``mms_msg.get_composition_dataset``.
        duration: Meeting duration in seconds (style 'Meeting' only).
        p_silence: Passed to ``UniformOverlapSampler`` (style 'Meeting' only).
        silence: (min, max) silence in seconds (style 'Meeting' only).
        overlap: (min, max) overlap in seconds (styles 'Meeting' and
            'Partial overlap').
        example: Index of the example used in the display snippet.
        reverberation: If False, the RIR sampling line is emitted commented
            out and the anechoic scenario map function is shown.
        style: 'Meeting', 'Partial overlap', 'Full overlap' or 'Same start'.
        snr_range: Passed to the scenario map function in the display snippet.
        code_block: ``ExecutableCodeBlocks`` to fill. A new one is created if
            None; otherwise ``step()`` is called on it before adding snippets.

    Returns:
        The filled ``code_block``.

    Raises:
        ValueError: If ``input_database`` or ``style`` is not supported.
            Raised before ``code_block`` is modified.
    """
    # Per database: (module, class name, sample rate in Hz).
    databases = {
        'WSJ': ('padercontrib.database.wsj', 'WSJ_8kHz', 8000),
        'LibriSpeech': ('padercontrib.database.librispeech', 'LibriSpeech', 16000),
    }
    styles = ('Meeting', 'Partial overlap', 'Full overlap', 'Same start')
    # Validate first, so an invalid request neither leaves ``code_block``
    # half-filled nor fails later with an unbound ``sample_rate``.
    if input_database not in databases:
        raise ValueError(input_database)
    if style not in styles:
        raise ValueError(style)
    database_module, database_class, sample_rate = databases[input_database]

    if code_block is None:
        code_block = ExecutableCodeBlocks()
    else:
        code_block.step()

    # Widen empty ranges by 0.1 s (presumably the samplers need min < max).
    if silence[0] == silence[1]:
        silence = (silence[0], silence[0] + 0.1)
    if overlap[0] == overlap[1]:
        overlap = (overlap[0], overlap[0] + 0.1)

    code_block.add('import mms_msg\n')
    code_block.add('import functools\n')

    ## Select input database
    code_block.add(
        '# Set up the input database\n'
        f'from {database_module} import {database_class}\n'
        f'db = {database_class}()\n'
        f'sample_rate = {sample_rate}\n'
        f'input_dataset = db.get_dataset("{input_dataset_name}")\n'
    )

    ## Composition
    code_block.add(f'ds = mms_msg.get_composition_dataset(input_dataset, num_speakers={num_speakers})\n')

    ## Log Weights
    meeting_code = 'ds = ds.map(mms_msg.utils.scaling.UniformLogWeightSampler())\n'

    ## Reverberation (emitted commented out when disabled)
    meeting_code += f'''{"" if reverberation else "# "}ds = ds.map(mms_msg.RIRSampler.from_scenarios_json('/net/db/sms_wsj/rirs/scenarios.json', {input_dataset_name!r}))\n'''

    ## Offsets / meeting layout; all lengths are converted to samples.
    if style == 'Meeting':
        meeting_code += f'''
ds = ds.map(mms_msg.MeetingSampler(
    duration={duration*sample_rate}, # in samples
    overlap_sampler=mms_msg.meeting.overlap_sampler.UniformOverlapSampler(
        max_concurrent_spk=2,
        p_silence={p_silence},
        minimum_silence={int(silence[0]*sample_rate)}, # in samples
        maximum_silence={int(silence[1]*sample_rate)}, # in samples
        minimum_overlap={int(overlap[0]*sample_rate)}, # in samples
        maximum_overlap={int(overlap[1]*sample_rate)}, # in samples
    ))(input_dataset))
'''.lstrip()
    elif style == 'Partial overlap':
        meeting_code += f'''
ds = ds.map(mms_msg.PartialOverlapOffsetSampler(
    minimum_overlap={int(overlap[0]*sample_rate)}, # in samples
    maximum_overlap={int(overlap[1]*sample_rate)}, # in samples
))
'''.lstrip()
    elif style == 'Full overlap':
        meeting_code += '''
ds = ds.map(mms_msg.SMSWSJOffsetSampler())
'''.lstrip()
    elif style == 'Same start':
        meeting_code += '''
ds = ds.map(mms_msg.ConstantOffsetSampler(0))  # WSJ0-2mix-like
'''.lstrip()

    code_block.add(meeting_code)

    # Plot and display (shown to the user, never executed)
    display_code = '\n# Plot meeting\n'
    display_code += f'plot_meeting(ds[{example}], sample_rate=sample_rate)\n\n'

    # Load & play
    display_code += '# Load and play audio\n'
    display_code += 'import paderbox as pb\n'

    if reverberation:
        scenario_map_fn = 'mms_msg.scenario.multi_channel_scenario_map_fn'
    else:
        scenario_map_fn = 'mms_msg.scenario.anechoic_scenario_map_fn'

    display_code += f'''
def load_audio(example):
    example['audio_data'] = pb.io.audioread.recursive_load_audio(example['audio_path'])
    return example
ds = ds.map(load_audio)
ds = ds.map(functools.partial({scenario_map_fn}, snr_range={snr_range}))
example = ds[{example}]

# Play, e.g., in a notebook
pb.io.play(example['audio_data']['observation'], sample_rate=sample_rate)
'''.lstrip()

    code_block.add(display_code, executable=False)

    return code_block

In [None]:
from collections import defaultdict
import paderbox as pb

def plot_meeting(ex, ax=None, sample_rate=8000):
    """Plot the per-speaker speech activity of a composed example.

    Args:
        ex: Example dict providing ``offset['original_source']``,
            ``speaker_id`` and the source lengths under either
            ``num_samples['original_source']`` or
            ``num_samples['speech_source']`` (offsets/lengths in samples).
        ax: Matplotlib axes to draw on; defaults to the current axes.
        sample_rate: Used to label the x axis (sample positions) in seconds.
    """
    # Imported here so the function does not depend on a later cell having
    # imported ``plt``.
    from matplotlib import pyplot as plt
    from matplotlib.ticker import FuncFormatter

    if ax is None:
        ax = plt.gca()
    speech_activity = defaultdict(pb.array.interval.zeros)
    try:
        num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)
    except KeyError:
        num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.speech_source', allow_early_stopping=True)
    for offset, length, speaker in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):
        speech_activity[speaker][offset:offset + length] = True

    pb.visualization.plot.activity(speech_activity, ax=ax)
    # Show tick positions (samples) in seconds through a formatter: calling
    # set_xticklabels without fixing the tick positions is unreliable.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f'{x / sample_rate:g}'))
    ax.set_xlabel('Time [s]')

In [None]:
from pygments import highlight

from pygments.lexers import Python3Lexer
from pygments.formatters import HtmlFormatter

def get_formatted_html_code(code):
    """Return ``code`` as Pygments-highlighted HTML with its CSS embedded.

    The result is a ``<div>`` holding a scoped ``<style>`` element (the
    formatter's style definitions plus a ``code`` background rule) followed
    by the highlighted markup, suitable as the value of an HTML widget.
    """
    html_formatter = HtmlFormatter()
    highlighted = highlight(code, Python3Lexer(), html_formatter)
    css = html_formatter.get_style_defs()

    # Assembled line by line so the exact indented layout is explicit.
    return '\n'.join([
        '',
        '    <div>',
        '    <style type="text/css" scoped>',
        f'    {css}',
        '    code {',
        '        background: black;',
        '    }',
        '    </style>',
        '    ',
        f'    {highlighted}',
        '    </div>',
        '    ',
    ])

In [None]:
import numpy as np

def _make_wav(data, rate):
    """ Transform a numpy array to a PCM bytestring
    Taken from IPython display lib """
    from io import BytesIO
    import wave

    data = np.array(data, dtype=float)
    if len(data.shape) == 1:
        nchan = 1
    elif len(data.shape) == 2:
        # In wave files,channels are interleaved. E.g.,
        # "L1R1L2R2..." for stereo. See
        # http://msdn.microsoft.com/en-us/library/windows/hardware/dn653308(v=vs.85).aspx
        # for channel ordering
        nchan = data.shape[0]
        data = data.T.ravel()
    else:
        raise ValueError('Array audio input must be a 1D or 2D array')

    max_abs_value = np.max(np.abs(data))
    normalization_factor = max_abs_value
    scaled = data / normalization_factor * 32767
    scaled = scaled.astype('<h').tobytes()
    fp = BytesIO()
    waveobj = wave.open(fp,mode='wb')
    waveobj.setnchannels(nchan)
    waveobj.setframerate(rate)
    waveobj.setsampwidth(2)
    waveobj.setcomptype('NONE','NONE')
    waveobj.writeframes(scaled)
    val = fp.getvalue()
    waveobj.close()

    return val

In [None]:
from matplotlib import pyplot as plt
from ipywidgets import *
from io import BytesIO

class InteractiveMeetingPlotter():
    """ipywidgets UI to configure, preview and listen to mms_msg examples.

    The control values are turned into code by ``generate_code``; the code is
    shown as highlighted HTML, executed incrementally via
    ``ExecutableCodeBlocks``, and the selected example is plotted. Audio is
    only loaded on demand through the "Load Audio" button.
    """

    def __init__(self):
        # Control widgets. The keys must match the parameter names of
        # ``generate_code_interactive``, because ``interactive`` passes the
        # widget values to it by keyword.
        w = {}

        w['input_database'] = Dropdown(options=['WSJ', 'LibriSpeech'], value='WSJ')
        w['input_dataset_name'] = Dropdown(options=['test_eval92', 'train_si284'])
        w['style'] = Dropdown(options=['Meeting', 'Partial overlap', 'Full overlap', 'Same start'])
        w['example'] = IntSlider(min=0, max=50, value=0)

        # Datasets offered per input database; the first one is selected
        # when switching databases.
        self.database_to_dataset = {
            'WSJ': ('test_eval92', 'train_si284'),
            'LibriSpeech': ('test_clean', 'dev_clean', 'train_clean'),
        }

        def database_callback(change):
            # Replace the dataset options whenever the database changes.
            if change['old'] == change['new']:
                return
            try:
                new_options = self.database_to_dataset[change['new']]
                w['input_dataset_name'].options = new_options
                w['input_dataset_name'].value = new_options[0]
            except Exception as e:
                # Keep the original error as the cause for debugging.
                raise RuntimeError(f'Could not switch input database: {change}') from e
        w['input_database'].observe(database_callback, 'value')

        w['num_speakers'] = IntSlider(min=2, max=5, value=4)
        w['duration'] = IntSlider(max=120, value=60, step=10)
        w['p_silence'] = FloatSlider(min=0, max=1, value=0.1)

        w['silence'] = FloatRangeSlider(min=0, max=6, value=(0, 2), step=0.5)
        w['overlap'] = FloatRangeSlider(min=0, max=15, value=(0, 8), step=0.5)
        w['reverberation'] = Checkbox(value=False)

        w['snr_range'] = FloatRangeSlider(min=0, max=40, value=(20, 30), step=1)

        # Output widgets
        self.w_plot_out = Image()
        self.w_plot_out.layout.height = '200px'
        self.html = HTML()

        self.w_audio_button = Button(description='Load Audio')
        self.w_audio_button.on_click(self.load_audio)
        self.w_audio_players = VBox()

        self.interactive_widgets = w

        # State shared between the callbacks
        self.code = ExecutableCodeBlocks()
        self.code_locals = None  # namespace of the last executed code
        self.generated_example = None
        self.snr_range = (20, 30)  # SNR range of the current settings
        self.audio_data = None
        self.fig = None
        self.ax = None

    def generate_code_interactive(
            self,
            input_database,
            input_dataset_name,
            style,
            num_speakers,
            duration,
            p_silence=0.2, silence=(0, 2),
            overlap=(0, 8),
            example=0,
            snr_range=(20, 30),
            reverberation=False,
    ):
        """``interactive`` callback: regenerate and execute the code, then plot.

        Updates the code view, re-runs the changed code snippets, plots
        example ``example`` and removes audio players of the previous example.
        """
        # While the database is being switched, the dataset dropdown can
        # briefly hold a dataset of the other database; skip that update.
        allowed_datasets = self.database_to_dataset[self.interactive_widgets['input_database'].value]
        if self.interactive_widgets['input_dataset_name'].value not in allowed_datasets:
            return

        generate_code(
            input_database, input_dataset_name, num_speakers, duration,
            p_silence, silence, overlap, example, reverberation,
            style, snr_range,
            code_block=self.code)
        self.html.value = get_formatted_html_code(self.code.get_code())

        self.code_locals = self.code.execute()
        self.generated_example = self.code_locals['ds'][example]
        self.snr_range = snr_range
        with pb.visualization.figure_context(figure_size=(10, 3)):
            if self.fig is None:
                self.fig, self.ax = plt.subplots()

            # Use the database's sample rate (e.g. 16 kHz for LibriSpeech)
            # for the time axis instead of plot_meeting's 8 kHz default.
            plot_meeting(self.generated_example, self.ax,
                         sample_rate=self.code_locals['sample_rate'])
            buffer = BytesIO()
            self.fig.savefig(buffer)
            self.w_plot_out.value = buffer.getvalue()

            self.ax.clear()
            plt.close()

        # Clear audio of the previous example
        self.w_audio_players.children = []

    # Audio widget
    def load_audio_data(self):
        """Load the current example's audio and simulate its observation.

        Does nothing if no example was generated yet or if the observation
        has already been computed (e.g. "Load Audio" clicked twice).
        """
        example = self.generated_example
        if example is None:
            return
        if 'audio_data' not in example:
            example['audio_data'] = pb.io.audioread.recursive_load_audio(example['audio_path'])
        if 'observation' in example['audio_data']:
            return

        # mms_msg is only imported inside the executed generated code, not
        # into this notebook's namespace.
        import mms_msg
        if 'rir' in example['audio_data']:
            scenario_map_fn = mms_msg.scenario.multi_channel_scenario_map_fn
        else:
            scenario_map_fn = mms_msg.scenario.anechoic_scenario_map_fn
        # Same call as in the generated code, so the audio follows the SNR slider.
        self.generated_example = scenario_map_fn(example, snr_range=self.snr_range)

    def load_audio(self, *args):
        """Button callback: show an audio player for the observation."""
        if self.generated_example is None:
            return  # nothing generated yet
        self.load_audio_data()
        observation = self.generated_example['audio_data']['observation']
        wav = _make_wav(observation, self.code_locals['sample_rate'])
        self.w_audio_players.children = [
            HBox([Label('observation'), Audio(value=wav, autoplay=False)]),
        ]

    def show(self):
        """Create the interactive controls and display the complete UI."""
        self.w = interactive(
            self.generate_code_interactive,
            **self.interactive_widgets,
        )

        w_audio = HBox([self.w_audio_button, self.w_audio_players])
        display(HBox([self.w, VBox([self.w_plot_out, w_audio])]), self.html)
# Build the interactive UI and display it as this cell's output; `imp` keeps
# the plotter (its widgets and generated state) in the notebook namespace.
imp = InteractiveMeetingPlotter()
imp.show()