In [None]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os
import json
from random import randint
import matplotlib.pyplot as plt

from models.wavenets import WaveNetFFT, WaveNetQx
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM
from models.ensemble import Ensemble
from mains import train, generate

from checkpoints import group_ckpts_by_trainset, load_feature, load_files, load_network_cls
from datasets import TRAINSET, VERDI_X

import numpy as np

from mimikit.extract.from_neighbors import *

In [None]:
# Root folder of one training run (presumably WaveNet models on the Verdi data — confirm);
# also used by the analysis cell below to locate generated outputs under root/<id>.
root = "trainings/wn-verdi-x"

# Checkpoints under `root`, grouped by training set: a mapping from group key to a
# sequence of checkpoint entries (indexed with [0] and iterated with .values() below).
CKPTS = group_ckpts_by_trainset(root)

In [None]:
# Display the checkpoint groups found under `root`.
CKPTS

In [None]:
# Scratch check: distinct values drawn from 48 random ints in [0, 8).
torch.randint(0, 8, (48,)).unique()

In [None]:
# Scratch check: rank of a 1-D arange tensor.
torch.arange(64).ndim

In [None]:
# For each training group: match every generated output frame-by-frame against the
# training set with `nearest_neighbor`, then track how the distribution of reused
# training frames (and its entropy) evolves over time.
for ckpt_group in CKPTS.values():
    # Only the first entry of each group is inspected; tuple layout is
    # (network class, checkpoint, feature, epochs, hyper-parameters).
    net_cls, ckpt, feat, epochs, hp = ckpt_group[0]

    print(feat)
    # Named `train_files` (not `train`) so it does not shadow the imported `mains.train`.
    train_files = load_files(hp["files"], feat.sr)
    y = feat.transform(train_files.snd[:])
    # The training-set tensor is identical for every output: build it once.
    Y = torch.as_tensor(y).unsqueeze(0).to("cuda")

    for output in h5m.FileWalker(h5m.Sound.__re__, os.path.join(root, hp["id"])):
        x = h5m.Sound(sr=feat.sr).load(output)
        print(output)
        mmk.audio(x, sr=feat.sr)
        x = feat.transform(x)
        X = torch.as_tensor(x).unsqueeze(0).to("cuda")

        with torch.no_grad():
            # `nn_idx`, not `nn`: rebinding `nn` clobbered the `torch.nn as nn` import.
            # NOTE(review): presumably (batch, time) indices of the nearest Y frame for
            # each X frame — confirm against mimikit.extract.from_neighbors.
            _, nn_idx = nearest_neighbor(X, Y)
            rr = repeat_rate(nn_idx, 88, 1)  # NOTE(review): meaning of 88 / 1 not visible here

            # counts[b, j, t] = 1 iff step t of item b used the j-th distinct neighbour.
            # Allocated on nn_idx's device (indexing a CPU tensor with CUDA indices fails),
            # and scatter_ keeps batch items separate (the old fancy index mixed them).
            items, idx = torch.unique(nn_idx, return_inverse=True)
            counts = torch.zeros(nn_idx.size(0), items.size(0), nn_idx.size(1),
                                 device=nn_idx.device)
            counts.scatter_(1, idx.unsqueeze(1), 1.0)

            # Running distribution over the neighbours used up to step t, and its entropy.
            cum_probs = torch.cumsum(counts, dim=2)
            cum_probs = cum_probs / cum_probs.sum(dim=1, keepdim=True)
            # xlogy(p, p) == p * log(p) with the 0 * log(0) = 0 convention.
            e_wrt_t = -torch.xlogy(cum_probs, cum_probs).sum(dim=1)
            # Sum over t of sign(e[t+1] - e[t]) * e[t].
            print((torch.sign(e_wrt_t[:, 1:] - e_wrt_t[:, :-1]) * e_wrt_t[:, :-1]).sum(dim=1))

        nn_np = nn_idx[0].cpu().numpy()
        e_np = e_wrt_t[0].cpu().numpy()
        rr_np = rr[0].cpu().numpy()

        fig, (ax_idx, ax_hist, ax_ent) = plt.subplots(3, 1, figsize=(18, 12))
        ax_idx.plot(nn_np)
        ax_idx.set(title=f"{output}: nearest training frame per step",
                   xlabel="step", ylabel="frame index")
        ax_hist.hist(nn_np, bins=512)
        ax_hist.set(title="nearest-neighbour histogram", xlabel="frame index", ylabel="count")
        # Repeat rate is rescaled to the entropy's max so both curves share one axis.
        ax_ent.plot(rr_np * e_np.max().item(), label="repeat rate (scaled)")
        ax_ent.plot(e_np, label="cumulative entropy")
        ax_ent.set(title="repeat rate vs. entropy of reused frames", xlabel="step")
        ax_ent.legend()
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop



In [None]:
# Re-bind CKPTS to every training group under "trainings" (replaces the single-run
# selection made from `root` above) and list the available group keys.
CKPTS = group_ckpts_by_trainset("trainings")
list(CKPTS)

In [None]:
# Inspect one group by position; 7 is an arbitrary pick among the listed keys.
k = list(CKPTS)[7]
CKPTS[k]

In [None]:
SR = 16000           # sample rate (Hz) shared by prompts, generation and playback
N_SECONDS = 60       # length of each generated piece, in seconds
PROMPT_OFFSET_S = 2  # the one-second prompt starts this many seconds into the training audio

ens = Ensemble(SR, (None, None))


def process_outputs(outputs, bidx):
    """Play the first generated batch item (callback for `mmk.GenerateLoop`)."""
    mmk.audio(outputs[0][0].cpu().numpy(), sr=SR)


# For every training group, generate N_SECONDS of audio where each step is produced
# by a randomly drawn checkpoint of that group.
for k in CKPTS.keys():
    prompt_files = load_files(CKPTS[k][0][-1]["files"], SR)
    prompt_start = PROMPT_OFFSET_S * SR
    # One second of audio, matching the `length=SR` window the input getter reads.
    prompt = prompt_files.snd[prompt_start:prompt_start + SR]

    def next_model():
        """Load the last-epoch checkpoint of a randomly chosen entry of group `k`.

        Returns:
            (net, feature) for that entry.
        """
        cls, tp, feat, epochs, hp = CKPTS[k][randint(0, len(CKPTS[k]) - 1)]
        ep = epochs[-1]
        net = tp.ckpt.load_checkpoint(cls, ep)
        return net, feat

    class Chainer:
        """Network stand-in for `mmk.GenerateLoop`: each generation step is delegated to
        `ens.single_step` with a freshly drawn random checkpoint of group `k`."""
        training = False
        device = "cuda"

        def eval(self):
            # Mirror nn.Module.eval()/.to(), which return the module itself.
            return self

        def to(self, *args, **kwargs):
            return self

        def generate_step(self, t, inputs, ctx):
            if t / SR >= N_SECONDS - 1:
                # From (N_SECONDS - 1) seconds on, no model is used: pad with zeros up to
                # (N_SECONDS + 1) * SR samples without loading a checkpoint for nothing.
                return torch.zeros(1, (N_SECONDS + 1) * SR - t, device=self.device)

            print("************************************************************************")
            net, feature = next_model()
            net = net.to("cuda")
            if hasattr(net, "use_fast_generate"):
                net.use_fast_generate = True
            print(net.__class__, net.rf, feature.sr, inputs[0].shape)

            # Last argument of ens.single_step: the feature's sample rate for mu-law
            # signals, otherwise its frame rate divided by the net's optional `hop`
            # (presumably the rate of the net's outputs — confirm in Ensemble).
            if isinstance(feature, mmk.MuLawSignal):
                step_rate = feature.sr
            else:
                step_rate = (feature.sr // feature.hop_length) // getattr(net, "hop", 1)
            return ens.single_step(inputs[0], net, feature, step_rate)

    loop = mmk.GenerateLoop(
        network=Chainer(),
        dataloader=[(torch.as_tensor(prompt).unsqueeze(0),)],
        inputs=(h5m.Input(None,
                          getter=h5m.AsSlice(dim=1, shift=-SR, length=SR),
                          setter=h5m.Setter(dim=1)),),
        n_steps=SR * N_SECONDS,
        add_blank=True,
        process_outputs=process_outputs,
    )
    loop.run()

In [None]:
feat = mmk.Spectrogram(
    sr=44100,
    normalize=True,
    emphasis=0.0,
    n_fft=1024,
    hop_length=256,
    coordinate='mag',
    center=True,
)
# Smoke test: invert 20 random magnitude frames and inspect the output shape.
# 513 == 1024 // 2 + 1. NOTE(review): assumes input layout (batch, frames, bins) — confirm.
feat.inverse_transform_(torch.randn(1, 20, 1024 // 2 + 1)).shape

In [None]:
import numpy as np

class X:
    # Stand-in for a loaded dataset: 32000 samples of Gaussian noise exposed as `.snd`.
    snd = np.random.randn(32000)

# One resolver per tier: (tier index, n_fft, hop_length, shift, length).
# NOTE(review): `net2` is not defined in any visible cell (and `batch_item_resolver`
# may only come from a star import), so this cell fails on a fresh kernel — confirm.
batch = [
    batch_item_resolver(
        net2.tiers[tier],
        mmk.Spectrogram(n_fft=n_fft, hop_length=hop, center=False, coordinate="mag"),
        shift=shift,
        length=length,
    )
    for tier, n_fft, hop, shift, length in (
        (0, 16, 16, 0, 8),
        (1, 4, 4, 16 // 4 - 1, 8 * (16 // 4)),
        (-1, 4, 1, 4 * (16 // 4 - 1), 8 * 16),
    )
]

[resolve(X, 12).shape for resolve in batch]