In [None]:
import mimikit as mmk
import h5mapper as h5m
from itertools import chain
import torch
import torch.nn as nn


# Architecture + Feature => Network

In [None]:
# Data file, network architecture and audio feature used by the rest of the notebook.
tp = h5m.TypedFile("test.h5")
arch = mmk.WNBlock(blocks=(4,), pad_side=0)
feat = mmk.MuLawSignal(sr=16000, normalize=True, q_levels=256)

# Pair each feature with the next dimension taken from the architecture's
# hyper-parameters (dilated dims first, then 1x1 dims) to build its input module.
# The loop variable is `f` so that it does not shadow the outer `feat`.
inpt_mods = [
    f.input_module(dim)
    for f, dim in zip([feat], chain(arch.hp.dims_dilated, arch.hp.dims_1x1))
]

# Output width: the skips dimension when it is set, else the first dilated dimension.
if arch.hp.skips_dim is None:
    out_d = arch.hp.dims_dilated[0]
else:
    out_d = arch.hp.skips_dim
outpt_mods = [f.output_module(out_d) for f in [feat]]

# Attach the input/output modules to the architecture.
net = arch.with_io(inpt_mods, outpt_mods)

# Smoke test: feed a (4, 32) batch of random integers in [0, 256) through the
# network and display the output size next to a few attributes of `net.s`.
(
    net(torch.randint(0, 256, (4, 32)), temperature=.95).size(),
    net.s.args,
    net.s.full_kwargs,
    net.s.default,
    net.s.in_,
)

In [None]:
# Getters provided by the network for slicing training inputs and targets.
getters = net.getters(batch_length=32, stride=1, hop_length=1, shift_error=0)

# Inputs and targets read from the same sound proxy and share the feature's
# transform; they differ only in their item type and in the getter they use.
batch = tuple(
    item_type(proxy=tp.snd, getter=getters[key], transform=feat.transform)
    for item_type, key in ((h5m.Input, 'inputs'), (h5m.Target, 'targets'))
)

# Serve shuffled batches of 8 items using 8 worker processes.
dl = tp.serve(
    batch,
    shuffle=True,
    batch_size=8,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,  # need this!
)

# Uncomment to inspect a single batch:
# inp, outp = next(iter(dl))
# inp.shape, outp.shape, inp, outp, tp.snd[1980:2000]

In [None]:
def compute_loss(out, trgt):
    """Return the feature's loss wrapped in the dict form given to the train loop."""
    return {"loss": feat.loss_fn(out, trgt)}


# Training loop: Adam over all network parameters, fed by the loader `dl`.
tr_loop = mmk.TrainLoop(
    loader=dl,
    net=net,
    loss_fn=compute_loss,
    optim=torch.optim.Adam(net.parameters(), lr=1e-3),
)


class Logs(h5m.TypedFile):
    # Checkpoint storage declared from the network's current state dict.
    ckpt = h5m.TensorDict(net.state_dict())


# NOTE(review): mode='w' presumably replaces any existing "logs.h5" - confirm.
logs = Logs("logs.h5", mode='w')

# Checkpoint into `logs.ckpt` (configured with epochs=1) and log losses into `logs`.
callbacks = [
    mmk.MMKCheckpoint(h5_tensor_dict=logs.ckpt, epochs=1),
]
logger = mmk.LossLogger(logs)

# Train for 5 epochs, each limited to 58 batches.
tr_loop.run(
    max_epochs=5,
    logger=logger,
    callbacks=callbacks,
    limit_train_batches=58,
)

# Summarise the logs file, then display the logged losses.
logs.info()
logs.loss

In [None]:
# Display the `index` attribute of the logs file created by the training cell.
logs.index

In [None]:
# Read back everything stored under `logs.loss`, restore the network weights
# from the checkpoint saved under 'epoch=1-step=58', and load the
# hyper-parameters stored alongside the checkpoints.
(
    logs.loss[:],
    net.load_state_dict(logs.ckpt['epoch=1-step=58']),
    logs.ckpt.load_hp(),
)

In [None]:
import numpy as np  # only referenced by the commented-out snippet further down

# Alternative setup: re-open an existing logs file instead of reusing `logs`.
# Logs = h5m.typedfile("Logs", {})
# logs = Logs("logs.h5", mode='r+')


# logs.info()

# Functions handed to `logs.loss.compute`; each lambda takes one value `x`.
# Presumably each result is stored under its key - confirm against h5mapper.
fdict = dict(
    test=lambda x: x - 1,
    test2=lambda x: x + 2,
)
logs.loss.compute(fdict)

# Presumably the manual equivalent of `compute`, kept for reference.
# NOTE: `np.bool` was removed in NumPy 1.24; use `bool` if this is re-enabled.
# for i in logs.__src__.id[logs.loss.refs[:].astype(np.bool)]:
#     print(i)
#     res = {k: f(logs.loss.get(i)) for k, f in fdict.items()}
#     logs.add(i, res)

# Display every source id currently registered in the logs file.
logs.__src__.id[:]

In [None]:
# Look up "4" in the logs file and display what is returned.
logs.get("4")

In [None]:
# Compare the shape of the loss references with the shape of the source ids,
# and display the file index next to them.
(
    logs.loss.refs[:].shape,
    logs.__src__.id[:].shape,
    logs.index,
)

In [None]:
# Call the logs file's `info()` again, after the cells above have run.
logs.info()

# Generate Loop

In [None]:
# Turn off the network's fast-generate mode for this run.
net.use_fast_generate = False

# Generation settings.
n_batches = 2
batch_size = 8
prompt_length = 32
n_steps = 100

# Gen DataLoader
gen_getters = net.getters(batch_length=prompt_length, stride=1, hop_length=1, shift_error=0)
gen_batch = (h5m.Input(proxy=tp.snd, getter=gen_getters['inputs'], transform=feat.transform),)

# Draw the prompt start positions so that a full prompt always fits in the data.
# Sampling from the whole [0, len(snd)) range could pick a start within the last
# `prompt_length` samples and produce a truncated prompt.
# NOTE(review): assumes the inputs getter reads `prompt_length` samples starting
# at the drawn index (stride=1) - confirm against `net.getters`.
max_start = tp.snd.shape[0] - prompt_length
gen_dl = tp.serve(gen_batch,
                  shuffle=False,
                  batch_size=batch_size,
                  # `high` is exclusive in torch.randint, hence the + 1
                  sampler=torch.randint(0, max_start + 1, (batch_size*n_batches,),)
                 )

# Gen Loop
# Generated batches are collected here by `process_outputs` below.
outputs = {}
loop = mmk.GenerateLoop(
    network=net,
    dataloader=gen_dl,
    interfaces=[
        # signal: the getter slices `net.rf` steps along dim 1 (shift=-net.rf),
        # and the setter also operates on dim 1
        mmk.DynamicDataInterface(
            None,
            getter=h5m.AsSlice(dim=1, shift=-net.rf, length=net.rf),
            setter=mmk.Setter(dim=1)
        ),
        # temperature: one value in [1, 2) per batch item and step, read one
        # step at a time (length=1) and never written back (no setter)
        mmk.DynamicDataInterface(
            None,
            prepare=lambda src: torch.rand(batch_size, n_steps) + 1,
            getter=h5m.AsSlice(dim=1, shift=0, length=1),
            setter=None,
        )
    ],
    n_batches=n_batches,
    n_steps=n_steps,
    device='cpu',
    # store each output under the index `i` the loop passes with it
    process_outputs=lambda out, i: outputs.__setitem__(i, out)
)

loop.run()

# Number of collected outputs and the first of them.
len(outputs), outputs[0]