In [None]:
import mimikit as mmk
import h5mapper as h5m
import dill as pickle
import torch
import torch.nn as nn
from torch.utils.data import Sampler, SequentialSampler

# Architecture + Feature => Network

In [None]:
# Sound database: "test.h5" is opened with mode='r' and kept open for the
# whole session (later cells read `tp.snd` from it).
tp = h5m.TypedFile(
    "test.h5",
    mode='r',
    keep_open=True,
)

# Sequence-to-sequence LSTM with the library's default hyper-parameters.
net = mmk.Seq2SeqLSTM()

# Spectrogram feature. `feat.hop_length` is reused by later cells to convert
# between hop counts and sample counts.
feat = mmk.Spectrogram(
    sr=16000,
    normalize=True,
    n_fft=1024,
    hop_length=256,
    coordinate='mag',
)

# Display the network as the cell output.
net

# Train Dataloader

In [None]:
# NOTE: an earlier variant obtained the getters from the network itself:
#   getters = net.getters(batch_length=net.shift, shift_error=0)

# `net.shift` is counted in hops: it is multiplied by `feat.hop_length`
# wherever a length or an offset in samples is needed.
net.shift = 8
hops_per_batch = 2

# Input and target slices share the same length and stride; only the target
# is offset, by `net.shift` hops.
# Alternative length kept for reference:
#   (feat.n_fft - feat.hop_length) + (net.shift - 1) * feat.hop_length
slice_length = (hops_per_batch * net.shift - 1) * feat.hop_length

input_getter = h5m.AsSlice(
    length=slice_length,
    stride=feat.hop_length,
)
model_input = h5m.Input(
    proxy=tp.snd,
    getter=input_getter,
    transform=feat.transform,
)

target_getter = h5m.AsSlice(
    shift=net.shift * feat.hop_length,
    length=slice_length,
    stride=feat.hop_length,
)
model_target = h5m.Target(
    proxy=tp.snd,
    getter=target_getter,
    transform=feat.transform,
)

batch = (model_input, model_target)

# The sampler is sized by the number of whole hops in the sound array.
n_hops = tp.snd.shape[0] // feat.hop_length
tbptt_sampler = mmk.TBPTTSampler(
    n_hops,
    batch_size=8,
    chunk_length=net.shift * 20,
    seq_len=net.shift,
)
dl = tp.serve(batch, batch_sampler=tbptt_sampler)

# Quick sanity check (uncomment to inspect one batch):
#   inp, outp = next(iter(dl))
#   inp.shape, outp.shape  # optionally also: inp, outp, tp.snd[1980:2000]
tp.f_

# Train Loop

In [None]:
def spectrogram_loss(out, trgt):
    """Wrap the feature's loss in the dict that is handed to the train loop.

    Returns a dict with the single key "loss" mapped to
    `feat.loss_fn(out, trgt)`.
    """
    return {"loss": feat.loss_fn(out, trgt)}


adam = torch.optim.Adam(net.parameters(), lr=1e-3)

tr_loop = mmk.TrainLoop(
    loader=dl,
    net=net,
    loss_fn=spectrogram_loss,
    optim=adam,
    # NOTE(review): 20 equals chunk_length / seq_len of the sampler built in
    # the dataloader cell (net.shift*20 / net.shift) — presumably these must
    # stay in sync; confirm.
    tbptt_len=20,
)

# Typed log file with a `ckpt` entry declared from the network's current
# state dict; "logs.h5" is opened with mode='w'.
ckpt_template = h5m.TensorDict(net.state_dict())
Logs = h5m.typedfile("Logs", {'ckpt': ckpt_template})
logs = Logs("logs.h5", mode='w')

# Checkpoints are written into `logs.ckpt`, losses are logged into `logs`.
callbacks = [mmk.MMKCheckpoint(h5_tensor_dict=logs.ckpt, epochs=1)]
logger = mmk.LossLogger(logs)

tr_loop.run(
    max_epochs=10,
    logger=logger,
    callbacks=callbacks,
    limit_train_batches=100,
)

# Summarise the log file, then display the logged loss as the cell output.
logs.info()
logs.loss

In [None]:
# Inspect the training logs and restore a checkpoint, all displayed as one tuple.
(
    # index of the log file
    logs.index,
    # every logged loss value
    logs.loss[:],
    # reload the weights stored under this checkpoint key into `net`
    # NOTE(review): the key is hard-coded — confirm it matches an entry
    # actually written by the checkpoint callback.
    net.load_state_dict(logs.ckpt['epoch=1-step=10']),
    # presumably the hyper-parameters saved alongside the checkpoint — confirm
    logs.ckpt.load_hp(),
)

# Generate Loop

In [None]:

# Generation settings.
n_batches = 2
batch_size = 8
prompt_length = 32            # multiplied by feat.hop_length below to size the prompt slice
n_steps = 10 * net.shift

# --- Gen DataLoader -------------------------------------------------------
# Each prompt is a slice of `prompt_length * feat.hop_length` items of
# `tp.snd`, passed through the same feature transform as the training data.
prompt_getter = h5m.AsSlice(
    dim=0,
    shift=0,
    length=prompt_length * feat.hop_length,
)
gen_batch = (
    h5m.Input(proxy=tp.snd, getter=prompt_getter, transform=feat.transform),
)

# One random index per prompt, enough for `n_batches` batches of `batch_size`.
prompt_indices = torch.randint(
    0,
    tp.snd.shape[0] // feat.hop_length,
    (batch_size * n_batches,),
)
gen_dl = tp.serve(
    gen_batch,
    shuffle=False,
    batch_size=batch_size,
    sampler=prompt_indices,
)

# --- Gen Loop ---------------------------------------------------------------
outputs = {}


def store_output(out, i):
    """Keep the output `out` in `outputs` under the key `i`."""
    outputs[i] = out


# Interface over the generated data: the getter takes `net.shift` items along
# dim 1 with shift=-net.shift, the setter writes along dim 1.
feedback_interface = mmk.DynamicDataInterface(
    None,
    getter=h5m.AsSlice(dim=1, shift=-net.shift, length=net.shift),
    setter=mmk.Setter(dim=1),
)

# Optional second interface (disabled) — temperature:
#   mmk.DynamicDataInterface(
#       None,
#       prepare=lambda src: torch.rand(batch_size, n_steps) + 1,
#       getter=h5m.AsSlice(dim=1, shift=0, length=1),
#       setter=None,
#   )

loop = mmk.GenerateLoop(
    network=net,
    dataloader=gen_dl,
    interfaces=[feedback_interface],
    hop=net.shift,
    n_batches=n_batches,
    n_steps=n_steps,
    device='cpu',
    process_outputs=store_output,
)

loop.run()

# Number of stored outputs, displayed as the cell output.
len(outputs)