## Connect to your GDrive 
To train the network on your own data, create a directory named `data/`
in the current working directory (cwd) of this notebook and put your audio files in it.
When running on Colab with Google Drive connected, the cwd is the `MyDrive/` directory of your Google Drive account.

In [None]:
from google.colab import drive
# Mount Google Drive at /gdrive (only works inside Google Colab).
drive.mount('/gdrive')
# This sets the cwd of the notebook to the root of your Drive, so that the
# relative paths used in the cells below resolve inside MyDrive/.
%cd /gdrive/MyDrive 

### Install `mimikit`

In [None]:
# Remove torchtext before pinning torch.
# NOTE(review): presumably the torchtext build preinstalled on Colab conflicts
# with the torch version pinned below — confirm.
%pip uninstall torchtext -y
# Pin torch / torchvision / torchaudio; the +cu116 builds are fetched from the
# extra PyTorch wheel index given with --extra-index-url.
%pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
# Install mimikit 0.4.3 together with its `colab` extra.
%pip install mimikit[colab]==0.4.3

In [None]:
# Colab crashes if the following import is done from within mimikit, so do it here first
import pytorch_lightning as pl

### Imports

In [None]:
import h5mapper as h5m
import mimikit as mmk
from pbind import Pseq, Pbind, Pwhite, inf

### Get some checkpoints

In [None]:
# Directory that is searched for checkpoint files.
ROOT_DIR = './'

# Every path under ROOT_DIR matching mmk.CHECKPOINT_REGEX is loaded as a
# Checkpoint and keyed by the order in which the walker yields it (0, 1, 2, ...).
checkpoints = dict(enumerate(
    map(mmk.Checkpoint.from_path, h5m.FileWalker(mmk.CHECKPOINT_REGEX, ROOT_DIR))
))
checkpoints

### Get the prompts from which to generate

In [None]:
# The prompts are read from the dataset attached to the first checkpoint.
db = checkpoints[0].dataset

OUTPUT_SR = 22050

# Start position of each prompt and the length of every prompt, in SECONDS.
# (These constants previously held raw sample counts despite their `_SEC`
# suffix; they are now real seconds and converted explicitly below.)
PROMPTS_POS_SEC = (
    0., .5, 1.
)
PROMPT_LENGTH_SEC = 1.

# The sampler indices and the slice length are expressed in samples (the
# original code passed multiples of OUTPUT_SR), so convert seconds -> samples.
# NOTE(review): assumes the 'signal' feature of `db` is sampled at OUTPUT_SR — confirm.
prompts_pos_samples = tuple(int(round(sec * OUTPUT_SR)) for sec in PROMPTS_POS_SEC)
prompt_length_samples = int(round(PROMPT_LENGTH_SEC * OUTPUT_SR))

# get a batch of prompts
prompts = next(iter(db.serve(
    (h5m.Input(data='signal', getter=h5m.AsSlice(shift=0, length=prompt_length_samples)),),
    shuffle=False,
    # batch_size=1 --> new stream for each prompt <> batch_size=8 --> one stream for 8 prompts :
    batch_size=len(prompts_pos_samples),
    sampler=mmk.IndicesSampler(
        # INDICES FOR THE PROMPTS (in samples) :
        indices=prompts_pos_samples
    ))))[0]
prompts.shape

### Define a pattern of models

In [None]:
# THE MODELS PATTERN defines which checkpoint (id, epoch) generates for how long (seconds)

# The models pattern: each event names the checkpoint that generates
# ("generator") and for how long it does so ("seconds").

# First event: checkpoints[0] generates for a random duration drawn between 3 and 5 seconds.
first_event = Pbind(
    "generator", checkpoints[0],
    "seconds", Pwhite(lo=3.0, hi=5.0, repeats=1),
)

# TODO (disabled event, kept for reference): insert the most similar
# continuation from the Trainset "Cough".
# Pbind(
#     "seconds", Pwhite(lo=2., hi=5., repeats=1)
# ),

# Second event: checkpoints[1] generates for a random duration drawn between 0.1 and 1 second.
second_event = Pbind(
    "generator", checkpoints[1],
    # SampleRNN Checkpoints work best with a temperature parameter :
    "temperature", Pwhite(lo=0.25, hi=1.5),
    "seconds", Pwhite(lo=0.1, hi=1.0, repeats=1),
)

# Alternate between the two events indefinitely.
stream = Pseq([first_event, second_event], inf).asStream()
stream

### Generate

In [None]:
# Duration to generate, in seconds.
TOTAL_SECONDS = 10.0

# Logger used to display the generated batch at the output sample rate.
logger = mmk.AudioLogger(sr=OUTPUT_SR)

ensemble = mmk.EnsembleGenerator(
    prompts,
    TOTAL_SECONDS,
    OUTPUT_SR,
    stream,
    # set to True to print each event as it is played
    print_events=False,
)
outputs = ensemble.run()
logger.display_batch(outputs)

----------------------------

<img src="https://ktonal.com/k-circle-bw.png" alt="logo" width="75"/>