In [None]:
# Clone the ax6 repo (with its git submodules) into ./ax6 relative to the current working directory.
# NOTE(review): re-running this cell only prints "destination path 'ax6' already exists"; it does not fail.
!git clone --recurse-submodules https://github.com/ktonal/ax6.git
# Editable installs of the three packages that live inside the clone, into the kernel's environment.
%pip install -e ax6/pypbind ax6/h5mapper ax6/mimikit/

In [None]:
import h5mapper as h5m
import mimikit as mmk
from pbind import *
import os, sys

# Make the cloned repo's top-level modules (`ensemble`, `datasets`) importable.
# This is done on every platform, not only on Colab, because the first cell always clones
# into ./ax6. The directory is prepended rather than appended so that an installed package
# with the same name (e.g. Hugging Face's `datasets`) cannot shadow the local module.
# The membership check keeps the cell idempotent when it is re-run.
AX6_DIR = os.path.join(os.getcwd(), "ax6")
if os.path.isdir(AX6_DIR) and AX6_DIR not in sys.path:
    sys.path.insert(0, AX6_DIR)

try:
    # Colab only: authenticate the Google user.
    # NOTE(review): presumably required by the downloads below; confirm.
    from google.colab import auth
    auth.authenticate_user()
except ModuleNotFoundError:
    # Not running on Colab: there is nothing to authenticate.
    pass

from ensemble import Ensemble
from datasets import Trainset

In [None]:
# Sample rate (Hz) shared by the prompts, the generation length and audio playback.
BASE_SR = 22050

# Download the 'Cough' trainset; in this notebook it only serves as a source of prompts.
cough_trainset = Trainset(keyword="Cough", sr=BASE_SR)
prompt_files = cough_trainset.download()


In [None]:
# Positions in the trainset where each prompt starts.
# NOTE(review): presumably sample offsets, since each prompt is a slice of length BASE_SR.
PROMPT_INDICES = (0, BASE_SR * 8, BASE_SR * 16, BASE_SR * 32)

prompt_input = h5m.Input(data='snd', getter=h5m.AsSlice(shift=0, length=BASE_SR))
prompt_sampler = mmk.IndicesSampler(indices=PROMPT_INDICES)

prompts = prompt_files.serve(
    (prompt_input,),
    shuffle=False,
    # batch_size=1 --> a new stream for each prompt;
    # batch_size=8 --> one stream shared by up to 8 prompts
    # (only len(PROMPT_INDICES) == 4 prompts are requested here).
    batch_size=8,
    sampler=prompt_sampler,
)

# Checkpoint ID of the model; IDs of the models can be copied from axx.
wavenet_fft_cough = "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd"


def _checkpoint_pattern(model_id, epochs, lo, hi):
    """Build the event pattern for one checkpoint of `model_id`.

    Each event carries the keys "id", "epoch" and "seconds".
    NOTE(review): assuming SuperCollider-style semantics, the epoch is drawn at random
    from `epochs` and the duration uniformly from [lo, hi]. Because Pwhite has
    repeats=1, the pattern presumably yields a single event per pass.
    """
    return Pbind(
        "id", model_id,
        "epoch", Prand(epochs, inf),
        "seconds", Pwhite(lo=lo, hi=hi, repeats=1),
    )


# THE MODELS PATTERN defines which checkpoint (id, epoch) generates for how long (seconds).
# The two phrases are sequenced with an infinite repeat count: long phrases from the later
# epochs, then short phrases from the earlier ones.
stream = Pseq([
    _checkpoint_pattern(wavenet_fft_cough, [40, 50], lo=1., hi=8.),
    _checkpoint_pattern(wavenet_fft_cough, [10, 20], lo=0.5, hi=1.5),
], inf).asStream()


# Requested duration, in seconds, of the whole generation.
# NOTE(review): the loop below reads it back as `ensemble.max_seconds`; confirm they match.
TOTAL_SECONDS = 60.

# Ensemble is driven by the event `stream` defined above.
ensemble = Ensemble(
    TOTAL_SECONDS,
    BASE_SR,
    stream,
    print_events=False,  # flip to True to print each event as it is used
)

def process_outputs(outputs, bidx):
    """Render every generated example of a batch with ``mmk.audio``.

    outputs: outputs handed over by the generation loop. Only ``outputs[0]`` is used,
        and each of its items must support ``.cpu().numpy()`` (i.e. a torch tensor).
    bidx: batch index passed by the loop; not needed here.
    """
    generated_batch = outputs[0]
    for example in generated_batch:
        waveform = example.cpu().numpy()
        mmk.audio(waveform, sr=BASE_SR)

# Input spec for the generation loop: a getter reading a BASE_SR-long slice along dim 1
# and a setter writing along dim 1.
# NOTE(review): presumably this feeds the generated audio back as the next input; confirm
# against the h5mapper docs.
generation_inputs = (
    h5m.Input(
        None,
        getter=h5m.AsSlice(dim=1, shift=BASE_SR, length=BASE_SR),
        setter=h5m.Setter(dim=1),
    ),
)

# Number of steps = sample rate x the ensemble's maximum duration in seconds.
loop = mmk.GenerateLoop(
    network=ensemble,
    dataloader=prompts,
    inputs=generation_inputs,
    n_steps=int(BASE_SR * ensemble.max_seconds),
    add_blank=True,
    process_outputs=process_outputs,
)
loop.run()
