In [None]:
# Fetch the ax6 repository together with its git submodules.
!git clone --recurse-submodules https://github.com/ktonal/ax6.git
# Pin torchaudio to 0.10.0+cu113, resolved through the PyTorch cu113 wheel index.
%pip install torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# Install the pypbind, h5mapper and mimikit packages from the cloned checkout.
# NOTE(review): `-e` appears once, before the first path only — confirm whether
# all three packages are meant to be installed in editable mode.
%pip install -e ax6/pypbind ax6/h5mapper ax6/mimikit/

In [None]:
import h5mapper as h5m
import mimikit as mmk
from pbind import *
import os, sys

# Colab-only bootstrap: authenticate the user and make the cloned `ax6/`
# directory importable. Outside Colab the `google.colab` import fails and the
# whole step is skipped.
try:
    from google.colab import auth
except ModuleNotFoundError:
    # Not running on Colab: nothing to authenticate, no path to add.
    pass
else:
    # Kept outside the `try` so that a ModuleNotFoundError raised *inside*
    # authentication is reported instead of being mistaken for "not on Colab".
    auth.authenticate_user()
    ax6_dir = os.path.join(os.getcwd(), "ax6/")
    # Re-running this cell must not pile up duplicate entries on sys.path.
    if ax6_dir not in sys.path:
        sys.path.append(ax6_dir)

from ensemble import Ensemble
from datasets import Trainset
from checkpoints import Checkpoint
from models.nnn import NearestNextNeighbor

In [None]:
# Sample rate used for the prompts and for generation.
BASE_SR = 22_050

# Building the Trainset for the 'Cough' keyword downloads it; its `.bank`
# is what the prompts are taken from.
prompt_files = Trainset(
    keyword="Cough",
    sr=BASE_SR,
).bank


In [None]:
# Every prompt is a slice of BASE_SR samples taken from the 'snd' data.
prompt_input = h5m.Input(
    data='snd',
    getter=h5m.AsSlice(shift=0, length=BASE_SR),
)

# INDICES FOR THE PROMPTS :
prompt_sampler = mmk.IndicesSampler(
    indices=(0, BASE_SR * 8, BASE_SR * 16),
)

prompts = prompt_files.serve(
    (prompt_input,),
    shuffle=False,
    # batch_size=1 --> a new stream for each prompt
    # batch_size=8 --> one stream shared by 8 prompts
    batch_size=1,
    sampler=prompt_sampler,
)

# Model IDs; they can be copied from axx.
wavenet_fft_cough = "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd"
srnn_cough = "cbba48a801f8b21600818da1362c61aa1287d81793e8cc154771d666956bdcef"

# THE MODELS PATTERN: which checkpoint (id, epoch) generates, and for how
# long (seconds). The three events below are cycled in this order.

wavenet_event = Pbind(
    "type", Checkpoint,
    "id", wavenet_fft_cough,
    "epoch", Prand([40, 50], inf),
    "seconds", Pwhite(lo=3., hi=5., repeats=1),
)

# This event inserts the most similar continuation from the Trainset "Cough".
nearest_neighbor_event = Pbind(
    "type", NearestNextNeighbor,
    "keyword", "Cough",
    "feature", mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag"),
    "seconds", Pwhite(lo=2., hi=5., repeats=1),
)

srnn_event = Pbind(
    "type", Checkpoint,
    "id", srnn_cough,
    "epoch", Prand([200, 300], inf),
    # SampleRNN checkpoints work best with a temperature parameter.
    "temperature", Pwhite(lo=.25, hi=1.5),
    "seconds", Pwhite(lo=.5, hi=2.5, repeats=1),
)

stream = Pseq(
    [wavenet_event, nearest_neighbor_event, srnn_event],
    inf,
).asStream()


# Duration, in seconds, handed to the Ensemble.
TOTAL_SECONDS = 30.0

ensemble = Ensemble(
    TOTAL_SECONDS,
    BASE_SR,
    stream,
    print_events=False,  # set to True to print the events
)

def process_outputs(outputs, bidx):
    """Pass every item of ``outputs[0]`` to ``mmk.audio`` at ``BASE_SR``.

    Each item is moved to CPU and converted with ``.numpy()`` first.
    ``bidx`` is accepted but not used here.
    """
    generated = outputs[0]
    for signal in generated:
        samples = signal.cpu().numpy()
        mmk.audio(samples, sr=BASE_SR)

# Input descriptor for the loop: the getter slices dim 1 with
# shift=-BASE_SR and length=BASE_SR, the setter targets dim 1.
loop_input = h5m.Input(
    None,
    getter=h5m.AsSlice(dim=1, shift=-BASE_SR, length=BASE_SR),
    setter=h5m.Setter(dim=1),
)

loop = mmk.GenerateLoop(
    network=ensemble,
    dataloader=prompts,
    inputs=(loop_input,),
    # number of steps = sample rate * the ensemble's max_seconds
    n_steps=int(BASE_SR * ensemble.max_seconds),
    add_blank=True,
    process_outputs=process_outputs,
)
loop.run()


In [None]:
# Quick check: pull the next event from `stream` and show whether its "type"
# entry is the Checkpoint class.
# NOTE(review): calling next() here takes one event off the stream — confirm
# that is acceptable if `stream` is reused afterwards.
next(stream)["type"] is Checkpoint