In [None]:
import os
import subprocess

from bqplot import pyplot as plt;
from bqplot import DateScale, LinearScale, Lines, Axis, Figure, Label;
import ipywidgets as widgets;
from ipywebrtc import AudioRecorder, CameraStream, AudioStream;
import numpy as np;

from onsei.utils import SpeechRecord, generate_transcript, segment_speech;


# Globals


def get_jsut_samples(basepath="data/jsut_basic5000_sample"):
    """Load the sample sentences listed in ``<basepath>/transcript_utf8.txt``.

    Each non-blank transcript line has the form ``<basename>:<sentence>`` and
    refers to the audio file ``<basepath>/<basename>.wav``.

    Args:
        basepath: Directory holding ``transcript_utf8.txt`` and the wav files.

    Returns:
        A dict mapping each sentence to ``{"filename": ..., "sentence": ...}``
        in file order. If a sentence appears twice, the later line wins.

    Raises:
        ValueError: if a non-blank line contains no ``:`` separator.
    """
    samples = {}
    # The transcript is UTF-8 Japanese text: don't depend on the platform's
    # default encoding.
    with open(os.path.join(basepath, "transcript_utf8.txt"), encoding="utf-8") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            # Split on the first colon only; the sentence itself may contain ':'.
            basename, sentence = line.split(':', 1)
            samples[sentence] = {
                "filename": os.path.join(basepath, f"{basename}.wav"),
                "sentence": sentence,
            }
    return samples

samples = get_jsut_samples()

# Preselect the first transcript entry (dicts keep file/insertion order).
default_sample_key = list(samples)[0]
default_sample = samples[default_sample_key]

# Current recordings: set by update_sample() (teacher) and run_compare() (student).
teacher_rec = None
student_rec = None

# Initial state of the "Autoplay" option for the teacher audio player.
default_autoplay = True


# Create widgets


# Sample picker, preselected on the default sample.
w_select = widgets.Dropdown(
    options=samples.keys(),
    value=default_sample_key,
    description='Samples:',
)

# Collapsible "Options" panel (no pane open initially) with the autoplay toggle.
w_autoplay_tick = widgets.Checkbox(
    value=default_autoplay,
    description='Autoplay',
    indent=False,
)
w_options_accordion = widgets.Accordion(
    children=[w_autoplay_tick],
    selected_index=None,
)
w_options_accordion.set_title(0, "Options")

# Teacher audio player; its bytes are replaced whenever a sample is loaded.
w_audio = widgets.Audio(
    value=b'',
    format='wav',
    autoplay=default_autoplay,
    loop=False,
)

# Shows the selected sentence together with its transcript.
w_sentence = widgets.HTML(value='')

# Student input: an audio-only media stream feeding the recorder widget.
camera = CameraStream(constraints={'audio': True, 'video': False})
w_recorder = AudioRecorder(stream=camera)

w_compare_btn = widgets.Button(description="Compare")


# Scales. The teacher figure (scale_ts) and the comparison figure
# (scale_cmp_ts) use separate time scales, so each x range is computed
# only from that figure's own marks.
scale_ts = LinearScale()
scale_cmp_ts = LinearScale()
scale_pitch = LinearScale()
scale_intensity = LinearScale()
scale_norm_pitch = LinearScale()

# Teacher figure: intensity area, detected-speech span, pitch line and
# phoneme labels.
line_pitch = Lines(
    x=[], y=[], scales={'x': scale_ts, 'y': scale_pitch},
    labels=["Pitch"], colors=["dodgerblue"], display_legend=True,
)
line_intensity = Lines(
    x=[], y=[], scales={'x': scale_ts, 'y': scale_intensity},
    labels=["Intensity"], colors=["lightgreen"], fill="bottom", display_legend=True,
)
line_vad_intensity = Lines(
    x=[], y=[], scales={'x': scale_ts, 'y': scale_intensity},
    labels=["Detected Speech"], colors=["red"], fill="bottom", display_legend=True,
)
ax_ts = Axis(scale=scale_ts, label="Time (s)", grid_lines="solid")
# The left axis must read off the same scale line_pitch is drawn on, so its
# ticks are pitch values (Hz), not intensity values.
ax_pitch = Axis(scale=scale_pitch, label="Pitch (Hz)", orientation="vertical", grid_lines="solid", side="left")
ax_intensity = Axis(scale=scale_intensity, label="Intensity (dB)", orientation="vertical", grid_lines="solid", side="right")
label_transcript_teacher = Label(x=[], y=[], text=[], colors=[])
fig_teacher = Figure(
    marks=[line_intensity, line_vad_intensity, line_pitch, label_transcript_teacher],
    axes=[ax_ts, ax_pitch, ax_intensity],
    legend_location="top-right",
    title="Teacher's recording"
)

# Comparison figure: both normalized pitch curves on the comparison time
# scale, the one its x axis (ax_cmp_ts) is bound to.
line_cmp_pitch_teacher = Lines(
    x=[], y=[], scales={'x': scale_cmp_ts, 'y': scale_norm_pitch},
    labels=["Teacher Norm Pitch"], colors=["blue"], display_legend=True,
)
line_cmp_pitch_student = Lines(
    x=[], y=[], scales={'x': scale_cmp_ts, 'y': scale_norm_pitch},
    labels=["Student Norm Pitch"], colors=["red"], display_legend=True,
)
ax_cmp_ts = Axis(scale=scale_cmp_ts, label="Time (s)", grid_lines="solid")
ax_cmp_pitch = Axis(scale=scale_norm_pitch, label="Normalized Pitch", orientation="vertical", grid_lines="solid", side="left")
fig_cmp = Figure(
    marks=[line_cmp_pitch_teacher, line_cmp_pitch_student],
    axes=[ax_cmp_ts, ax_cmp_pitch],
    legend_location="top-right",
    title="Pitch comparison",
)


# Callbacks


def update_autoplay(change):
    """Observer for the Autoplay checkbox: copy its new value onto the audio player."""
    enabled = change['new']
    w_audio.autoplay = enabled


w_autoplay_tick.observe(update_autoplay, 'value')


def clear_compare_plot():
    """Empty both normalized-pitch lines of the comparison figure.

    Used when a different sample is selected, so no stale comparison remains.
    Changes are batched in ``hold_sync`` blocks.
    """
    with line_cmp_pitch_student.hold_sync(), line_cmp_pitch_teacher.hold_sync():
        for mark in (line_cmp_pitch_student, line_cmp_pitch_teacher):
            mark.x = []
            mark.y = []


def update_compare_plot():
    """Draw the normalized, aligned pitch curves of teacher and student.

    Both curves use ``teacher_rec.align_ts`` as x values. This presumably is
    the common timeline produced by ``student_rec.align_with(teacher_rec)``;
    confirm against ``SpeechRecord``.
    """
    shared_ts = teacher_rec.align_ts
    with line_cmp_pitch_student.hold_sync(), line_cmp_pitch_teacher.hold_sync():
        line_cmp_pitch_student.x = shared_ts
        line_cmp_pitch_student.y = student_rec.norm_aligned_pitch
        line_cmp_pitch_teacher.x = shared_ts
        line_cmp_pitch_teacher.y = teacher_rec.norm_aligned_pitch


def get_sample_audio_data(sample):
    """Return the raw bytes of the file named by ``sample['filename']``.

    The file is read inside a ``with`` block, so the handle is closed right
    away instead of being left for the garbage collector.
    """
    with open(sample['filename'], 'rb') as f:
        return f.read()


def build_phoneme_labels(phonemes, t_start, t_end):
    """Lay out phoneme-segment labels as ratios of the ``[t_start, t_end]`` span.

    Args:
        phonemes: iterable of ``(begin_ts, end_ts, phoneme)`` segments.
        t_start: time mapped to ratio 0.
        t_end: time mapped to ratio 1.

    Returns:
        ``(xs, texts, colors)`` parallel lists. The first part holds two gray
        ``"|"`` markers per segment, at its begin and end. It is followed by
        one orange phoneme label per segment, at its midpoint.
    """
    # Materialize once: the segments are walked twice below.
    phonemes = list(phonemes)
    span = t_end - t_start
    if span == 0:
        # Degenerate time axis (e.g. a single pitch frame): avoid dividing by zero.
        span = 1.0

    def to_ratio(ts):
        return (ts - t_start) / span

    xs, texts, colors = [], [], []
    for beg, end, _ in phonemes:
        xs.extend([to_ratio(beg), to_ratio(end)])
        texts.extend(["|", "|"])
        colors.extend(["gray", "gray"])
    for beg, end, pho in phonemes:
        xs.append(to_ratio(beg + (end - beg) / 2))
        texts.append(pho)
        colors.append("orange")
    return xs, texts, colors


def update_teacher_plot():
    """Redraw the teacher figure from the global ``teacher_rec``.

    It updates:
    - the pitch line;
    - the intensity area;
    - the intensity restricted to ``begin_idx:end_idx`` ("Detected Speech");
    - the phoneme labels from ``segment_speech`` over ``begin_ts..end_ts``.

    The phoneme segments are also printed.
    """
    rec = teacher_rec  # only read here, so no ``global`` declaration is needed

    with line_pitch.hold_sync(), line_intensity.hold_sync(), \
            line_vad_intensity.hold_sync(), label_transcript_teacher.hold_sync():
        pitch_xs = rec.pitch.xs()
        pitch = rec.pitch_freq_filtered
        # Zero pitch values become NaN so they are not drawn as dips to 0 Hz.
        # np.where also works when the array has an integer dtype.
        line_pitch.x = pitch_xs
        line_pitch.y = np.where(pitch == 0, np.nan, pitch)

        intensity_xs = rec.intensity.xs()
        intensity_ys = rec.intensity.values.T
        line_intensity.x = intensity_xs
        line_intensity.y = intensity_ys

        # Overlay only the detected-speech index range.
        speech = slice(rec.begin_idx, rec.end_idx)
        line_vad_intensity.x = intensity_xs[speech]
        line_vad_intensity.y = intensity_ys[speech]

        phonemes = segment_speech(rec.wav_filename, rec.transcript, rec.begin_ts, rec.end_ts)
        print(phonemes)

        # Label x positions are ratios of the pitch line's time span.
        xs, texts, colors = build_phoneme_labels(phonemes, pitch_xs[0], pitch_xs[-1])
        label_transcript_teacher.x = xs
        label_transcript_teacher.y = [0.5] * len(xs)
        label_transcript_teacher.text = texts
        label_transcript_teacher.colors = colors


def update_sample(sample):
    """Make *sample* the current teacher recording and refresh the UI.

    It does four things:
    - shows the sentence with its generated transcript;
    - replaces the global ``teacher_rec`` with a new ``SpeechRecord``;
    - loads the sample's audio bytes into the player;
    - redraws the teacher plot.
    """
    global teacher_rec

    sentence = sample['sentence']
    with w_sentence.hold_sync():
        transcript = generate_transcript(sentence)
        w_sentence.value = f"<b>{sentence}</b> ({transcript})"

    teacher_rec = SpeechRecord(sample['filename'], sentence=sentence, name="Teacher")
    w_audio.value = get_sample_audio_data(sample)
    update_teacher_plot()


# Populate the UI with the default sample on startup.
update_sample(default_sample)


def load_selected_sample(change):
    """Dropdown observer: load the newly selected sample and clear the comparison plot."""
    update_sample(samples[change["new"]])
    clear_compare_plot()


w_select.observe(load_selected_sample, 'value')
        

def run_compare(_):
    """Compare the student's recording with the selected teacher sample.

    Click handler for the Compare button; the button argument is ignored.

    Steps:
    1. Save the recorder's audio to ``test.webm``.
    2. Convert it with ffmpeg to 16 kHz mono ``test.wav``.
    3. Build the student ``SpeechRecord`` and align it with ``teacher_rec``.
    4. Print the mean pitch distance.
    5. Redraw the comparison plot.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        RuntimeError: if ffmpeg exits with a non-zero status. The message
            includes ffmpeg's stderr.
    """
    global student_rec

    sample = samples[w_select.value]

    w_recorder.save('test.webm')
    # Unlike the IPython ``!`` shell magic, check=True stops here when the
    # conversion fails, instead of silently comparing a stale test.wav from
    # an earlier run. stderr is captured so the reason shows up in the error.
    try:
        subprocess.run(
            ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
             "-i", "test.webm", "-ar", "16000", "-ac", "1", "test.wav"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"ffmpeg could not convert test.webm: {err.stderr.strip()}") from err
    student_wav_filename = 'test.wav'
    # Alternatively, here is a sample:
    #student_wav_filename = "data/ps/ps1_boku_no_chijin-student3.wav"

    student_rec = SpeechRecord(student_wav_filename, sample['sentence'], name="Student")
    student_rec.align_with(teacher_rec)
    mean_distance = student_rec.compare_pitch()
    print(f"mean_distance = {mean_distance}")

    update_compare_plot()


w_compare_btn.on_click(run_compare)


# Layout

# Row 1: sample picker beside the options panel.
items = [w_select, w_options_accordion]
w_hbox = widgets.HBox(items)
display(w_hbox)

# Row 2: the selected sentence and its transcript.
display(w_sentence)

# Row 3: teacher playback, student recorder and the Compare button.
display(
    widgets.HBox([
        widgets.VBox([widgets.Label(value="Teacher's recording:"), w_audio]),
        widgets.VBox([widgets.Label(value="Your recording:"), w_recorder]),
        w_compare_btn,
    ])
)

# Plots: pitch comparison first, then the teacher's analysis.
display(fig_cmp)
display(fig_teacher)