In [None]:
import mimikit as mmk
import h5mapper as h5m

import torch
import torch.nn as nn
from torch.utils.data import Sampler, SequentialSampler

# Architecture + Feature => Network

In [None]:
# Assemble the three objects the later cells reuse: the data file, the
# network and the audio feature.
tp = h5m.TypedFile("test.h5")
net = mmk.SampleRNN()  # no arguments: relies on the constructor's defaults
feat = mmk.MuLawSignal(
    sr=16000,
    normalize=True,
    q_levels=256,
)

# Iterate over the network and display whatever it yields.
# NOTE(review): what `net` yields on iteration is defined in mimikit — confirm.
[item for item in net]

# Train Dataloader

In [None]:
# Ask the network for the getters of its inputs and targets.
getters = net.getters(batch_length=512, shift_error=0)

# One Input per entry of getters['inputs'], plus a single Target; all of them
# read from `tp.snd` and go through the same feature transform.
train_inputs = tuple([
    h5m.Input(proxy=tp.snd, getter=g_input, transform=feat.transform)
    for g_input in getters['inputs']
])
train_target = h5m.Target(
    proxy=tp.snd,
    getter=getters['targets'],
    transform=feat.transform,
)
batch = (train_inputs, train_target)

# NOTE(review): seq_len repeats the 512 passed as batch_length above, and
# 16000 repeats the sample rate given to the feature — presumably these have
# to stay in sync (chunk_length would then be 8 seconds of audio); confirm.
train_sampler = mmk.TBPTTSampler(
    tp.snd.shape[0],
    batch_size=8,
    chunk_length=16000*8,
    seq_len=512,
)
dl = tp.serve(
    batch,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,  # flagged as required by the original author ("need this!")
    batch_sampler=train_sampler,
)

# Pull a single batch to sanity-check the pipeline and display its target.
first_batch = next(iter(dl))
inp, outp = first_batch
# inp.shape, outp.shape, inp, outp, tp.snd[1980:2000]
outp

# Train Loop

In [None]:
def _loss_as_dict(out, trgt):
    """Return the feature's loss for (out, trgt) wrapped as ``{"loss": value}``."""
    return {"loss": feat.loss_fn(out, trgt)}


optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
tr_loop = mmk.TrainLoop(
    loader=dl,
    net=net,
    loss_fn=_loss_as_dict,
    optim=optimizer,
    tbptt_len=2,
)

# Declare a "Logs" file type with a single `ckpt` entry built from the
# network's current state dict, then create logs.h5 (mode 'w').
ckpt_feature = h5m.TensorDict(net.state_dict())
Logs = h5m.typedfile("Logs", {'ckpt': ckpt_feature})
logs = Logs("logs.h5", mode='w')

# The checkpoint callback writes into logs.ckpt; the logger receives the
# whole log file.
checkpoint_cb = mmk.MMKCheckpoint(h5_tensor_dict=logs.ckpt, epochs=1)
callbacks = [checkpoint_cb]
logger = mmk.LossLogger(logs)

# Short smoke run: 2 epochs, capped at 10 training batches.
tr_loop.run(
    max_epochs=2,
    logger=logger,
    callbacks=callbacks,
    limit_train_batches=10,
)

# Summarise the log file, then display the `loss` entry as the cell output.
logs.info()
logs.loss

In [None]:
# Read back what the training run stored and restore the network from a
# saved checkpoint; each step is named so the displayed tuple is readable.
log_index = logs.index
loss_values = logs.loss[:]
# NOTE(review): the key format 'epoch=<e>-step=<s>' is produced by the
# checkpoint callback — confirm this key exists for the run above.
restore_result = net.load_state_dict(logs.ckpt['epoch=1-step=10'])
hyper_params = logs.ckpt.load_hp()
log_index, loss_values, restore_result, hyper_params

# Generate Loop

In [None]:

# Generation settings.
n_batches, batch_size = 20, 8
prompt_length, n_steps = 32, 100

# Gen DataLoader
# NOTE(review): `gen_getters` is not read anywhere else in this cell.
gen_getters = net.getters(batch_length=prompt_length, shift_error=0)
prompt_getter = h5m.AsSlice(dim=0, shift=0, length=prompt_length)
gen_batch = (
    h5m.Input(proxy=tp.snd, getter=prompt_getter, transform=feat.transform),
)
# batch_size * n_batches random integers in [0, tp.snd.shape[0]), handed to
# the loader as its sampler.
prompt_indices = torch.randint(0, tp.snd.shape[0], (batch_size*n_batches,),)
gen_dl = tp.serve(
    gen_batch,
    shuffle=False,
    batch_size=batch_size,
    sampler=prompt_indices,
)


def _random_temperature(src):
    """Return a (batch_size, n_steps) tensor of uniform values in [1, 2); `src` is ignored."""
    return torch.rand(batch_size, n_steps) + 1


def _store_output(out, i):
    """Record `out` in the `outputs` dict under key `i`."""
    outputs[i] = out


# Gen Loop
outputs = {}

# First interface: no initial source; reads a slice of length net.shift
# located net.shift positions before the reference point along dim 1, and
# writes through a Setter on dim 1.
signal_interface = mmk.DynamicDataInterface(
    None,
    getter=h5m.AsSlice(dim=1, shift=-net.shift, length=net.shift),
    setter=mmk.Setter(dim=1),
)
# Second interface (temperature): prepared from random values, read one
# step at a time along dim 1, never written back (setter=None).
temperature_interface = mmk.DynamicDataInterface(
    None,
    prepare=_random_temperature,
    getter=h5m.AsSlice(dim=1, shift=0, length=1),
    setter=None,
)

loop = mmk.GenerateLoop(
    network=net,
    dataloader=gen_dl,
    interfaces=[signal_interface, temperature_interface],
    n_batches=n_batches,
    n_steps=n_steps,
    device='cpu',
    process_outputs=_store_output,
)

loop.run()

# Number of entries collected by the process_outputs hook.
len(outputs)