# Source-Filter Utterance Data Generation
*This notebook uses a FAUST implementation of a source-filter model of the vocal tract to generate a dataset of synthetic utterances.*

In [None]:
import numpy as np
import os
import torch
from torch import nn
import torchaudio
from faust_ctypes.wrapper import Faust
import itertools
import math

from IPython.display import Audio
import matplotlib.pyplot as plt

from RSA_helpers import *

## Load the synth, test the synthesis pipeline & select the torch device

In [None]:
# Load the compiled FAUST source-filter synth and play a short test render
# (100,000 samples, played back at 44.1 kHz) to check that the DSP works.
# NOTE(review): playback rate is hard-coded here; confirm it matches AUDIO_SR.
dsp = Faust("../faust_dsp/SF_voc_synth_f.so")
samples = dsp.proc.compute(100_000)
Audio(samples, rate=44_100)

In [None]:
# Pick the compute device: Apple's MPS backend when available, else the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Make float32 the default dtype for newly created floating-point tensors.
torch.set_default_dtype(torch.float)
# Bare expression so the notebook displays the selected device.
device

## Synth function

In [None]:
def synthesize_voice(params_sequence):
    """Render audio for a sequence of control frames with the FAUST synth.

    Each control frame is held for ``AUDIO_SR // CONTROL_SR`` audio samples.
    Uses the module-level ``dsp``, ``AUDIO_SR`` and ``CONTROL_SR``.

    Args:
        params_sequence: 2-D array/tensor indexed ``[step, control]``. Columns
            0-4 are written to the DSP's ``p_freq``, ``p_gain``, ``p_vowel``,
            ``p_fricative`` and ``p_plosive`` zones respectively.

    Returns:
        1-D float32 tensor of ``n_steps * (AUDIO_SR // CONTROL_SR)`` samples.
    """
    # Audio samples rendered per control frame. Using one fixed hop keeps
    # every slice exactly as long as the block that dsp.proc.compute returns,
    # even when AUDIO_SR is not a multiple of CONTROL_SR.
    hop = AUDIO_SR // CONTROL_SR
    # Convert once to Python floats instead of calling .item() per value.
    frames = params_sequence.tolist()
    output = np.zeros(len(frames) * hop, dtype=np.float32)

    # Warm-up: renders and discards 500 samples *before* this sequence's first
    # parameters are set, i.e. with whatever values the DSP currently holds.
    # NOTE(review): if the goal is to settle the synth on the new parameters,
    # this should run after the first set of zone writes.
    dsp.proc.compute(500)

    for i, frame in enumerate(frames):
        # Small positive jitter on pitch and vowel from NumPy's global RNG.
        dsp.ui.b_vocal.p_freq.zone = frame[0] + 0.01 * np.random.rand()
        dsp.ui.b_vocal.p_gain.zone = frame[1]
        dsp.ui.b_vocal.p_vowel.zone = frame[2] + 0.02 * np.random.rand()
        # Fricative / plosive are switched on by fixed thresholds.
        # NOTE(review): rationale for 0.99 / 0.2 is not visible here.
        dsp.ui.b_vocal.p_fricative.zone = frame[3] > 0.99
        dsp.ui.b_vocal.p_plosive.zone = frame[4] > 0.2
        output[i * hop:(i + 1) * hop] = dsp.proc.compute(hop)

    return torch.from_numpy(output)

## Dataset generation

In [None]:
def generate_random_walk_data():
    """Generate, render and save a dataset of random-walk control trajectories.

    For each of N_SAMPLES utterances, every one of the N_VOCAL_TRACT_CONTROLS
    controls follows its own walk. It starts at a value drawn uniformly from
    [0, 1) and advances by a uniform [0, 0.5) increment per step. The stored
    value is ``0.5 + 0.5 * sin(position)``, which lies in [0, 1].

    Columns 0-4 are read by synthesize_voice as freq, gain, vowel, fricative
    and plosive. Writes ``synth_<i>.wav`` per utterance, plus the full
    parameter tensor (N_SAMPLES, SAMPLE_LEN, N_VOCAL_TRACT_CONTROLS) as
    ``params.pt``, under VOCAL_DATA_DIR.
    """
    synth_params = torch.zeros((N_SAMPLES, SAMPLE_LEN, N_VOCAL_TRACT_CONTROLS))

    # Build every trajectory first, drawing random numbers in a fixed order.
    for walk_idx in range(N_SAMPLES):
        position = torch.rand(N_VOCAL_TRACT_CONTROLS)
        for t in range(SAMPLE_LEN):
            synth_params[walk_idx, t, :] = 0.5 + 0.5 * torch.sin(position)
            position += 0.5 * torch.rand(N_VOCAL_TRACT_CONTROLS)

    # Then render each trajectory to a mono wav file.
    for walk_idx in range(N_SAMPLES):
        audio = synthesize_voice(synth_params[walk_idx])
        wav_path = VOCAL_DATA_DIR + 'synth_' + str(walk_idx) + '.wav'
        torchaudio.save(wav_path, audio.unsqueeze(0), sample_rate=AUDIO_SR)

    torch.save(synth_params, VOCAL_DATA_DIR + 'params.pt')


def generate_grid_osc_data():
    """Render one utterance per combination of grid sequence types.

    Every control is assigned a sequence type from range(N_GRID_SEQ_TYPE)
    (built by lookup_grid_seq). Only combinations in which at least one
    control uses type 0, 1, 2 or 3 (the constant types) are rendered. Each
    result is saved under VOCAL_DATA_DIR as ``<t0>-<t1>-...-<tn>.wav``.
    """
    all_combos = itertools.product(range(0, N_GRID_SEQ_TYPE), repeat=N_VOCAL_TRACT_CONTROLS)
    for seq_types in all_combos:
        # Skip combinations with no constant-type control.
        if {0, 1, 2, 3}.isdisjoint(seq_types):
            continue

        synth_params = torch.zeros((SAMPLE_LEN, N_VOCAL_TRACT_CONTROLS))
        for col, seq_type in enumerate(seq_types):
            synth_params[:, col] = lookup_grid_seq(seq_type, SAMPLE_LEN)

        audio = synthesize_voice(synth_params)
        wav_name = '-'.join(str(seq_type) for seq_type in seq_types) + '.wav'
        torchaudio.save(VOCAL_DATA_DIR + wav_name, audio.unsqueeze(0), sample_rate=AUDIO_SR)


def lookup_grid_seq(seq_order, length):
    """Return a 1-D control sequence of ``length`` steps for a grid sequence type.

    Sequence types (``seq_order``):
        0-3   constant ``0.33 * seq_order`` (0, 0.33, 0.66, 0.99)
        4-6   sine 0.5 +/- 0.4 with ``seq_order - 3`` cycles per time unit
        7     sine 0.75 +/- 0.25, 2 cycles per time unit ("biased high")
        8     sine 0.25 +/- 0.25, 2 cycles per time unit ("biased low")
        9     falling saw: two ramps from 1 down towards 0 across the sequence
        >= 10 random values in [0, 1), each held for 3 steps

    The sine variants read the module-level ``CONTROL_SR``: their time axis
    runs from 0 to ``length // CONTROL_SR`` (endpoint inclusive).
    """

    def _sine(offset, amplitude, cycles):
        # NOTE(review): floor division collapses the time axis to 0 (a flat
        # line at `offset`) when length < CONTROL_SR, and truncates it when
        # length is not a multiple of CONTROL_SR -- confirm `//` is intended.
        t = torch.linspace(0, length // CONTROL_SR, steps=length)
        return offset + amplitude * torch.sin(t * 2 * math.pi * cycles)

    # constant
    if seq_order <= 3:
        return 0.33 * seq_order * torch.ones(length)

    # sine
    if seq_order <= 6:
        return _sine(0.5, 0.4, seq_order - 3)

    # sine biased hi
    if seq_order <= 7:
        return _sine(0.75, 0.25, seq_order - 5)

    # sine biased low
    if seq_order <= 8:
        return _sine(0.25, 0.25, seq_order - 6)

    # falling saw: `seq_order - 7` full ramps over the sequence. Mathematically
    # identical to the previous linspace(0, length*(seq_order-7)) % 1, but it
    # avoids large float32 magnitudes that erode the fractional part.
    if seq_order <= 9:
        ramp = torch.linspace(0, seq_order - 7, steps=length)
        return 1 - (ramp % 1)

    # random sample-and-hold: each draw repeated 3 times, trimmed to length
    return torch.rand(length // 3 + 3).repeat_interleave(3)[:length]

## Run!

In [None]:
# Render the grid-oscillation dataset: writes one .wav per qualifying
# sequence-type combination under VOCAL_DATA_DIR. (generate_random_walk_data
# is not invoked in this cell.)
generate_grid_osc_data()