In [None]:
import os

import ipywidgets as widgets;
from ipywebrtc import AudioRecorder, CameraStream, AudioStream;

from onsei.utils import SpeechRecord, generate_transcript, segment_speech;
from onsei.figures import ViewRecordFigure, CompareFigure;

# Globals


def get_jsut_samples(basepath="data/jsut_basic5000_sample"):
    """Load the sample index from ``transcript_utf8.txt`` inside *basepath*.

    Every non-blank line of the transcript must have the form
    ``<basename>:<sentence>``. Only the first colon separates the two fields,
    so a sentence that itself contains ``:`` is kept intact. Blank lines are
    skipped.

    Args:
        basepath: Directory holding the transcript and the ``.wav`` files.

    Returns:
        A dict keyed by sentence. Each value is a dict with the
        ``"filename"`` of the matching ``<basename>.wav`` (joined onto
        *basepath*) and the ``"sentence"`` itself.

    Raises:
        OSError: If the transcript cannot be opened.
        ValueError: If a non-blank line contains no colon.
    """
    samples = {}
    transcript_path = os.path.join(basepath, "transcript_utf8.txt")
    # Decode explicitly as UTF-8 (as the file name states) rather than relying
    # on the platform's locale default.
    with open(transcript_path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            basename, sentence = line.split(":", 1)
            samples[sentence] = {
                "filename": os.path.join(basepath, f"{basename}.wav"),
                "sentence": sentence,
            }
    return samples

# All available samples, keyed by sentence.
samples = get_jsut_samples()

# The first sample listed in the transcript is shown on start-up.
default_sample_key = list(samples)[0]
default_sample = samples[default_sample_key]

# Records for the current exercise; assigned by the callbacks below.
teacher_rec = None
student_rec = None

# Initial state of the "Autoplay" option.
default_autoplay = True


# Create widgets


# Sample picker: one entry per sentence.
w_select = widgets.Dropdown(
    options=samples.keys(),
    value=default_sample_key,
    description="Samples:",
    disabled=False,
    layout=widgets.Layout(width="100%"),
)

# "Options" accordion (collapsed initially) holding the autoplay checkbox.
w_autoplay_tick = widgets.Checkbox(
    value=default_autoplay,
    description="Autoplay",
    disabled=False,
    indent=False,
)
w_options_accordion = widgets.Accordion(
    children=[w_autoplay_tick],
    selected_index=None,
)
w_options_accordion.set_title(0, "Options")

# Audio player; its content is filled in when a sample is loaded.
w_audio = widgets.Audio(
    value=b"",
    format="wav",
    autoplay=default_autoplay,
    loop=False,
)

# HTML area showing the current sentence.
w_sentence = widgets.HTML(value="")

# Audio-only media stream feeding the recorder widget.
camera = CameraStream(constraints={"audio": True, "video": False})
w_recorder = AudioRecorder(stream=camera)

# Comparison trigger and its textual result.
w_compare_btn = widgets.Button(description="Compare")
w_cmp_result = widgets.Label(value="")

# Figures: one per recording, plus the comparison of both.
fig_teacher = ViewRecordFigure(title="Teacher's recording")
fig_student = ViewRecordFigure(title="Your recording")
fig_cmp = CompareFigure()


# Callbacks


def update_autoplay(change):
    """Copy the new value of the "Autoplay" checkbox onto the audio player.

    Args:
        change: Change notification dict; its ``'new'`` entry is the value
            that was just set.
    """
    autoplay_enabled = change['new']
    w_audio.autoplay = autoplay_enabled


# Re-run the handler whenever the checkbox value changes.
w_autoplay_tick.observe(update_autoplay, names='value')


def get_sample_audio_data(sample):
    """Return the raw bytes of the audio file referenced by *sample*.

    Args:
        sample: Mapping with a ``'filename'`` entry naming the file to read.

    Returns:
        The complete file content as ``bytes``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the handle deterministically; the previous
    # ``open(...).read()`` form left that to the garbage collector.
    with open(sample['filename'], 'rb') as audio_file:
        return audio_file.read()


def update_sample(sample):
    """Make *sample* the current exercise and reset the previous attempt.

    Updates the sentence display, rebuilds the module-level ``teacher_rec``,
    loads the sample audio into the player, refreshes the teacher figure and
    clears the student figure, the comparison figure and the result label.

    Args:
        sample: Dict with ``'sentence'`` and ``'filename'`` entries.
    """
    global teacher_rec

    sentence = sample['sentence']
    with w_sentence.hold_sync():
        reading = generate_transcript(sentence)
        w_sentence.value = (
            f'<p style="font-size: xx-large">{sentence}</p> ({reading})'
        )

    teacher_rec = SpeechRecord(sample['filename'], sentence=sentence, name="Teacher")
    w_audio.value = get_sample_audio_data(sample)

    fig_teacher.update_data(teacher_rec)
    # Anything belonging to a previous attempt is no longer relevant.
    for stale_figure in (fig_student, fig_cmp):
        stale_figure.clear()
    w_cmp_result.value = ""


# Show the default sample right away.
update_sample(default_sample)


def load_selected_sample(change):
    """Switch to the sample whose key was just selected in the dropdown.

    Args:
        change: Change notification dict; its ``"new"`` entry is the key of
            the selected sample in ``samples``.
    """
    update_sample(samples[change["new"]])


# Load a new sample each time the dropdown selection changes.
w_select.observe(load_selected_sample, names='value')


def get_student_wav_filename():
    """Save the recorder content and convert it to a WAV file.

    The recording is written to ``test.webm`` and then converted by
    ``ffmpeg`` into ``test.wav`` with a 16000 Hz sample rate and one audio
    channel (``-ar 16000 -ac 1``).

    Returns:
        The name of the converted file, ``'test.wav'``.

    Raises:
        ValueError: Propagated from ``w_recorder.save``. When its message
            starts with "No data", a hint is shown in the result label first.
        OSError: If ``ffmpeg`` cannot be launched (e.g. it is not installed).
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero
            status.
    """
    # Function-scope import keeps this block self-contained.
    import subprocess

    try:
        w_recorder.save('test.webm')
    except ValueError as exc:
        if str(exc).startswith('No data'):
            w_cmp_result.value = "Record something first !"
        raise

    # Run ffmpeg through subprocess instead of the IPython ``!`` magic so that
    # a failed conversion raises here, rather than returning a stale or
    # missing 'test.wav' to the caller.
    try:
        subprocess.run(
            [
                "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
                "-i", "test.webm",
                "-ar", "16000",
                "-ac", "1",
                "test.wav",
            ],
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        w_cmp_result.value = "Audio conversion failed !"
        raise

    return 'test.wav'


def run_compare(_):
    """Compare the student's recording with the teacher's.

    Builds the module-level ``student_rec`` from the recorder output, shows
    it in the student figure, aligns it with ``teacher_rec`` and writes the
    mean pitch distance to the result label. The comparison figure is
    updated only when that step succeeds; on failure the label reads
    "FAILED !" and the exception is re-raised.

    Args:
        _: Unused; the argument passed by ``on_click``.
    """
    global student_rec

    current = samples[w_select.value]

    wav_path = get_student_wav_filename()
    # Alternatively, a pre-recorded sample can be used:
    # wav_path = "data/mizo_wo_student.wav"

    student_rec = SpeechRecord(wav_path, current['sentence'], name="Student")
    fig_student.update_data(student_rec)

    try:
        student_rec.align_with(teacher_rec)
        mean_distance = student_rec.compare_pitch()
        w_cmp_result.value = f"Success !\nMean distance = {mean_distance:.2f}"
    except Exception:
        w_cmp_result.value = "FAILED !"
        raise
    else:
        fig_cmp.update_data(teacher_rec, student_rec)


w_compare_btn.on_click(run_compare)


# Layout

box = widgets.Box(
    children=[
        # Row 1: sample picker next to the options accordion.
        widgets.Box(children=[w_select, w_options_accordion]),
        # Row 2: the current sentence.
        w_sentence,
        # Row 3: three columns -- teacher audio, recorder, comparison.
        widgets.Box(children=[
            widgets.VBox(
                children=[widgets.Label(value="Teacher's recording:"), w_audio],
                layout=widgets.Layout(width="33%"),
            ),
            widgets.VBox(
                children=[widgets.Label(value="Your recording:"), w_recorder],
                layout=widgets.Layout(width="33%"),
            ),
            widgets.VBox(
                children=[w_compare_btn, w_cmp_result],
                layout=widgets.Layout(width="33%"),
            ),
        ]),
        # Remaining rows: the figures, comparison first.
        fig_cmp,
        fig_teacher,
        fig_student,
    ],
    layout=widgets.Layout(
        display="flex",
        flex_flow="column",
        align_items="stretch",
        align_content="center",
    ),
)

display(box)