In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib agg

In [None]:
import copy

class ExecutableCodeBlocks:
    """Incrementally executable sequence of generated code snippets.

    Snippets are added as ``(code, executable)`` pairs via :meth:`add`.
    :meth:`execute` runs the executable snippets in order and caches a
    namespace snapshot after each one. After :meth:`step` starts a new round,
    only the first changed snippet and everything after it are re-run.
    """

    def __init__(self):
        # Blocks of the previous round, compared against ``blocks`` to find changes.
        self.old_blocks = []
        # Blocks of the current round as ``(code, executable)`` tuples.
        self.blocks = []
        # Namespace snapshot after each executable block, indexed by the block's
        # position among executable blocks; ``None`` marks a missing/invalid one.
        self.states = []

    def step(self):
        """Start a new round: remember the current blocks and begin an empty list."""
        self.old_blocks = self.blocks
        self.blocks = []

    def add(self, code, executable=True):
        """Append a snippet; non-executable snippets are only shown, never run."""
        self.blocks.append((code, executable))

    def execute(self):
        """Run the executable blocks, re-using cached results where possible.

        Execution starts at the first executable block that differs from the
        block at the same position in the previous round (or whose cached
        snapshot is missing). It continues with every later executable block.
        Each block runs with this module's globals and a namespace seeded from
        the snapshot of the preceding executable block.

        Returns:
            A fresh dict with the names defined by the executable blocks
            (e.g. ``'ds'``), or an empty dict if there is no executable
            block. It is a copy, so mutating it does not corrupt the cache.

        Raises:
            Whatever an executed block raises. The snapshots of that block and
            all later ones are discarded first, so the next call re-runs them
            instead of silently reusing stale results.
        """
        num_blocks = len(self.blocks)
        # Pad the previous round so every current block has a partner to compare with.
        previous = list(self.old_blocks) + [None] * max(0, num_blocks - len(self.old_blocks))
        if len(self.states) < num_blocks:
            self.states.extend([None] * (num_blocks - len(self.states)))

        namespace = None  # stays None while cached snapshots can be reused
        position = 0  # index among executable blocks (== index into self.states)
        for block, old_block in zip(self.blocks, previous):
            code, executable = block
            if not executable:
                continue
            if namespace is None and (block != old_block or self.states[position] is None):
                # First block that must run: resume from the previous snapshot.
                namespace = dict(self.states[position - 1]) if position > 0 else {}
            if namespace is not None:
                try:
                    # Explicit namespace instead of mutating locals(), which is
                    # not reliable across Python versions (PEP 667).
                    exec(code, globals(), namespace)
                except BaseException:
                    self.states[position:] = [None] * (len(self.states) - position)
                    raise
                self.states[position] = dict(namespace)
            position += 1

        if position == 0:
            return {}
        return dict(self.states[position - 1])

    def get_code(self):
        """Return the source of all blocks (executable or not) joined by newlines."""
        return '\n'.join(b for b, _ in self.blocks)

def generate_code(input_database='WSJ', input_dataset_name='test_eval92', num_speakers=4, duration=60, p_silence=0.2, silence=(0, 2), overlap=(0, 8), example=0, reverberation=False, code_block=None, rir_scenarios_json='/net/db/sms_wsj/rirs/scenarios.json'):
    """Generate example code for a simulated meeting, split into blocks.

    The first four blocks (mixture generator import, input database,
    composition, meeting sampler) are executable. The last one (plotting and
    audio playback) is added with ``executable=False`` and is only meant to be
    displayed.

    Args:
        input_database: ``'WSJ'`` or ``'LibriSpeech'``.
        input_dataset_name: Dataset name passed to ``db.get_dataset`` and to
            the RIR scenario lookup.
        num_speakers: Number of speakers per meeting composition.
        duration: Meeting duration in seconds.
        p_silence: Forwarded to the overlap sampler's ``p_silence``.
        silence: ``(min, max)`` silence length in seconds.
        overlap: ``(min, max)`` overlap length in seconds.
        example: Index of the example used in the display block.
        reverberation: If true, the RIR sampler line is active (otherwise it
            is emitted commented out) and the display block uses the
            multi-channel scenario map function.
        code_block: Container to fill. If ``None``, a new
            ``ExecutableCodeBlocks`` is created. Otherwise ``step()`` is
            called on it first, so unchanged blocks can be reused.
        rir_scenarios_json: Path of the scenarios JSON used by the RIR sampler.

    Returns:
        The filled code block container.

    Raises:
        ValueError: If ``input_database`` is not supported.
    """
    # module path and class name of each supported input database
    database_classes = {
        'WSJ': ('padercontrib.database.wsj', 'WSJ_8kHz'),
        'LibriSpeech': ('padercontrib.database.librispeech', 'LibriSpeech'),
    }
    # Validate before touching code_block so it is never left half-filled.
    if input_database not in database_classes:
        raise ValueError(
            f'Unsupported input_database {input_database!r}; '
            f'expected one of {sorted(database_classes)}'
        )

    if code_block is None:
        code_block = ExecutableCodeBlocks()
    else:
        code_block.step()

    # Seconds are converted to samples at 8 kHz, matching WSJ_8kHz.
    # NOTE(review): LibriSpeech may use a different sampling rate — confirm.
    sample_rate = 8000

    # Widen degenerate ranges by 0.1 s (presumably the sampler needs min < max).
    if silence[0] == silence[1]:
        silence = (silence[0], silence[0] + 0.1)
    if overlap[0] == overlap[1]:
        overlap = (overlap[0], overlap[0] + 0.1)

    code_block.add('from padercontrib.database import mixture_generator as g\n\n')

    ## Select input database
    module, class_name = database_classes[input_database]
    code_block.add(
        '# Set up the input database\n'
        f'from {module} import {class_name}\n'
        f'db = {class_name}()\n'
        f'input_dataset = db.get_dataset("{input_dataset_name}")\n'
    )

    ## Composition
    code_block.add(f'ds = g.get_composition_dataset(input_dataset, num_speakers={num_speakers})\n')

    ## Log weights, reverberation (commented out when disabled) and meeting sampler
    rir_prefix = '' if reverberation else '# '
    meeting_code = (
        'ds = ds.map(g.UniformLogWeightSampler())\n'
        f'{rir_prefix}ds = ds.map(g.RIRSampler.from_scenarios_json('
        f"{rir_scenarios_json!r}, '{input_dataset_name}'))\n"
        'ds = ds.map(g.MeetingSampler(\n'
        f'    duration={duration * sample_rate}, # in samples\n'
        '    overlap_sampler=g.meeting.overlap_sampler.UniformOverlapSampler(\n'
        '        max_concurrent_spk=2,\n'
        f'        p_silence={p_silence},\n'
        f'        minimum_silence={int(silence[0] * sample_rate)}, # in samples\n'
        f'        maximum_silence={int(silence[1] * sample_rate)}, # in samples\n'
        f'        minimum_overlap={int(overlap[0] * sample_rate)}, # in samples\n'
        f'        maximum_overlap={int(overlap[1] * sample_rate)}, # in samples\n'
        '))(input_dataset))\n'
    )
    code_block.add(meeting_code)

    # Plot, load and play (display only). Audio is loaded before the scenario
    # map function runs, mirroring InteractiveMeetingPlotter.load_audio_data.
    scenario_fn = 'multi_channel_scenario_map_fn' if reverberation else 'anechoic_scenario_map_fn'
    display_code = (
        '\n# Plot meeting\n'
        f'plot_meeting(ds[{example}])\n\n'
        '# Load and play audio\n'
        'import paderbox as pb\n'
        'def load_audio(example):\n'
        "    example['audio_data'] = pb.io.audioread.recursive_load_audio(example['audio_path'])\n"
        '    return example\n'
        'ds = ds.map(load_audio)\n'
        f'from padercontrib.database.mixture_generator.scenario import {scenario_fn}\n'
        f'ds = ds.map({scenario_fn})\n'
        f'example = ds[{example}]\n\n'
        '# Play, e.g., in a notebook\n'
        "pb.io.play(example['audio_data']['observation'])\n"
    )
    code_block.add(display_code, executable=False)

    return code_block

In [None]:
from collections import defaultdict
import paderbox as pb

def plot_meeting(ex, ax, sample_rate=8000):
    """Draw the per-speaker speech activity of a meeting example on ``ax``.

    Args:
        ex: Meeting example dict. Reads ``ex['offset']['original_source']``,
            ``ex['speaker_id']`` and per-utterance lengths from
            ``num_samples.original_source``, falling back to
            ``num_samples.speech_source`` when that path raises ``KeyError``.
        ax: Matplotlib axes to draw on.
        sample_rate: Divisor used to label the x axis (sample indices) in seconds.
    """
    from matplotlib.ticker import FuncFormatter

    speech_activity = defaultdict(pb.array.interval.zeros)
    try:
        num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)
    except KeyError:
        num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.speech_source', allow_early_stopping=True)
    for offset, length, speaker in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):
        speech_activity[speaker][offset:offset + length] = True

    pb.visualization.plot.activity(speech_activity, ax=ax)
    # Format the locator's own ticks instead of freezing labels with
    # set_xticklabels, which warns without fixed ticks and goes stale when
    # the view changes.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f'{x / sample_rate:g}'))
    ax.set_xlabel('Time (s)')

In [None]:
from pygments import highlight

from pygments.lexers import Python3Lexer
from pygments.formatters import HtmlFormatter

def get_formatted_html_code(code):
    """Return ``code`` as a self-contained, syntax-highlighted HTML snippet.

    The snippet bundles the Pygments CSS rules (plus a black ``code``
    background) in a scoped ``<style>`` element next to the highlighted markup.
    """
    html_formatter = HtmlFormatter()
    highlighted = highlight(code, Python3Lexer(), html_formatter)
    css_rules = html_formatter.get_style_defs()
    return f'''
    <div>
    <style type="text/css" scoped>
    {css_rules}
    code {{
        background: black;
    }}
    </style>
    
    {highlighted}
    </div>
    '''

In [None]:
from padercontrib.database.mixture_generator.scenario import anechoic_scenario_map_fn, multi_channel_scenario_map_fn

In [None]:
import numpy as np

def _make_wav(data, rate):
    """ Transform a numpy array to a PCM bytestring
    Taken from IPython display lib """
    from io import BytesIO
    import wave

    data = np.array(data, dtype=float)
    if len(data.shape) == 1:
        nchan = 1
    elif len(data.shape) == 2:
        # In wave files,channels are interleaved. E.g.,
        # "L1R1L2R2..." for stereo. See
        # http://msdn.microsoft.com/en-us/library/windows/hardware/dn653308(v=vs.85).aspx
        # for channel ordering
        nchan = data.shape[0]
        data = data.T.ravel()
    else:
        raise ValueError('Array audio input must be a 1D or 2D array')

    max_abs_value = np.max(np.abs(data))
    normalization_factor = max_abs_value
    scaled = data / normalization_factor * 32767
    scaled = scaled.astype('<h').tobytes()
    fp = BytesIO()
    waveobj = wave.open(fp,mode='wb')
    waveobj.setnchannels(nchan)
    waveobj.setframerate(rate)
    waveobj.setsampwidth(2)
    waveobj.setcomptype('NONE','NONE')
    waveobj.writeframes(scaled)
    val = fp.getvalue()
    waveobj.close()

    return val

In [None]:
from matplotlib import pyplot as plt
from ipywidgets import *
from io import BytesIO

class InteractiveMeetingPlotter():
    """ipywidgets UI to configure, inspect and listen to a simulated meeting.

    Changing a control regenerates the example code (shown as highlighted
    HTML), executes it incrementally and plots the speech activity of the
    selected example. The "Load Audio" button loads and mixes that example's
    audio and shows a player for the observation.
    """

    def __init__(self):
        # Control widgets
        self.w_input_database = Dropdown(options=['WSJ', 'LibriSpeech'], value='WSJ')
        self.w_input_dataset = Dropdown(options=['test_eval92', 'train_si284'])
        self.w_example = IntSlider(min=0, max=50, value=0)

        # NOTE(review): this observer is disabled, so the dataset dropdown keeps
        # offering WSJ names even when LibriSpeech is selected. The LibriSpeech
        # dataset names below are unverified.
        # def update_input_dataset(args):
        #     if args['new'] == 'WSJ':
        #         self.w_input_dataset.options = ['test_eval92', 'train_si284']
        #     elif args['new'] == 'LibriSpeech':
        #         self.w_input_dataset.options = ['train_clean', 'test_clean']
        # self.w_input_database.observe(update_input_dataset, 'value')

        self.w_num_speakers = IntSlider(min=2, max=5, value=4)
        self.w_duration = IntSlider(max=120, value=60, step=10)  # seconds
        self.w_p_silence = FloatSlider(min=0, max=1, value=0.1)

        self.w_silence = FloatRangeSlider(min=0, max=6, value=(0, 2), step=0.5)  # seconds
        self.w_overlap = FloatRangeSlider(min=0, max=15, value=(0, 8), step=0.5)  # seconds
        self.w_reverb = Checkbox(value=False)

        # Output widgets
        self.w_plot_out = Image()
        self.w_plot_out.layout.height = '200px'
        self.html = HTML()
        self.w_audio_button = Button(description='Load Audio')
        self.w_audio_button.on_click(self.load_audio)
        self.w_audio_players = VBox()

        # State shared between the callbacks
        self.code = ExecutableCodeBlocks()
        self.generated_example = None
        self.audio_data = None
        self.fig = None
        self.ax = None

    def generate_code_interactive(self, input_database, input_dataset_name, num_speakers, duration, p_silence=0.2, silence=(0, 2), overlap=(0, 8), example=0, reverberation=False):
        """Widget callback: regenerate, execute and plot the selected meeting.

        ``silence`` and ``overlap`` are ``(min, max)`` ranges in seconds, as
        ``generate_code`` indexes them; their defaults match ``generate_code``.
        """
        generate_code(input_database, input_dataset_name, num_speakers, duration, p_silence, silence, overlap, example, reverberation, code_block=self.code)
        self.html.value = get_formatted_html_code(self.code.get_code())

        self.generated_example = self.code.execute()['ds'][example]
        with pb.visualization.figure_context(figure_size=(10, 3)):
            if self.fig is None:
                # Created once and reused for every update.
                self.fig, self.ax = plt.subplots()

            plot_meeting(self.generated_example, self.ax)
            buffer = BytesIO()
            self.fig.savefig(buffer)
            self.w_plot_out.value = buffer.getvalue()

            # Reset for the next update; the rendered image is already stored.
            self.ax.clear()

        # Players of the previous example no longer match the plot.
        self.w_audio_players.children = []

    # Audio widget
    def load_audio_data(self):
        """Load the current example's audio and apply the scenario map fn.

        Does nothing if no example has been generated yet. It also does nothing
        if the example's audio already contains ``'observation'``, so repeated
        clicks do not apply the scenario function twice.
        """
        example = self.generated_example
        if example is None:
            return
        if 'audio_data' not in example:
            example['audio_data'] = pb.io.audioread.recursive_load_audio(example['audio_path'])
        if 'observation' in example['audio_data']:
            # NOTE(review): assumes only the scenario map fns add
            # 'observation' (load_audio reads it afterwards) — confirm.
            return

        if 'rir' in example['audio_data']:
            self.generated_example = multi_channel_scenario_map_fn(example)
        else:
            self.generated_example = anechoic_scenario_map_fn(example)

    def load_audio(self, *args):
        """Button callback: show an audio player for the mixed observation."""
        if self.generated_example is None:
            return
        self.load_audio_data()
        observation = self.generated_example['audio_data']['observation']
        self.w_audio_players.children = [
            HBox([Label('observation'), Audio(value=_make_wav(observation, 8000), autoplay=False)]),
        ]

    def show(self):
        """Display the controls, plot, audio area and generated code."""
        w = interactive(
            self.generate_code_interactive,
            input_database=self.w_input_database,
            input_dataset_name=self.w_input_dataset,
            num_speakers=self.w_num_speakers,
            duration=self.w_duration,
            p_silence=self.w_p_silence,
            overlap=self.w_overlap,
            silence=self.w_silence,
            example=self.w_example,
            reverberation=self.w_reverb,
        )

        w_audio = HBox([self.w_audio_button, self.w_audio_players])
        display(HBox([w, VBox([self.w_plot_out, w_audio])]), self.html)
# Build the widgets and render the interactive meeting explorer.
InteractiveMeetingPlotter().show()