In [None]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os
import json
from random import randint
import matplotlib.pyplot as plt

from models.wavenets import WaveNetFFT, WaveNetQx
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM
from models.ensemble import Ensemble
from mains import train, generate

from checkpoints import group_ckpts_by_trainset, load_feature, load_files, load_network_cls
from datasets import TRAINSET, VERDI_X

import numpy as np

from mimikit.extract.from_neighbors import *

In [None]:
# Directory of the training runs inspected by the analysis loop further down
# (generated sounds are looked up under `root + "/" + hp["id"]`).
root = "trainings/wn-verdi-x"

# Result of scanning `root` with the helper imported from `checkpoints`;
# the cells below use it as a mapping (`.values()`, `.keys()`).
CKPTS = group_ckpts_by_trainset(root)

In [None]:
# Display the grouped checkpoints for inspection.
CKPTS

In [None]:
# Scratch arithmetic.
# NOTE(review): the meaning of these factors is not documented anywhere in this
# notebook — name them in a config cell or drop this cell.
2730*16*16

In [None]:
# Scratch check: number of dimensions of the tensor returned by `torch.arange`.
torch.arange(64).dim()

In [None]:
# For the first checkpoint of every group in CKPTS: load the files it was
# trained on, then, for each generated sound found under the run's directory,
# match the sound's features against the training features and plot
#   1. the matched indices over time,
#   2. a histogram of the matched indices,
#   3. the repeat rate together with the entropy of the running distribution
#      of matched indices.
# `nearest_neighbor` and `repeat_rate` come from the wildcard import of
# `mimikit.extract.from_neighbors`; their exact semantics are not visible here.

# fall back to the CPU when no GPU is available instead of failing on .to("cuda")
device = "cuda" if torch.cuda.is_available() else "cpu"

for group in CKPTS.values():
    # only the first entry of each group is inspected
    net_cls, ckpt, feat, epochs, hp = group[0]

    print(feat)
    # named `trainset` (not `train`) so the `train` function imported from
    # `mains` is not overwritten for the rest of the notebook
    trainset = load_files(hp["files"], feat.sr)
    y = feat.transform(trainset.snd[:])
    # the training features are the same for every output: build them once
    Y = torch.as_tensor(y).unsqueeze(0).to(device)

    for output in h5m.FileWalker(h5m.Sound.__re__, root+"/"+hp["id"]):
        x = h5m.Sound(sr=feat.sr).load(output)
        print(output)
        mmk.audio(x, sr=feat.sr)
        x = feat.transform(x)

        X = torch.as_tensor(x).unsqueeze(0).to(device)

        with torch.no_grad():
            # named `neighbors` (not `nn`) so the `torch.nn as nn` import is not overwritten
            _, neighbors = nearest_neighbor(X, Y)
            rr = repeat_rate(neighbors, 88, 1)
            items, idx = torch.unique(neighbors, return_inverse=True)
            n_steps = neighbors.size(1)
            # one-hot encode, per time step, which unique item was matched;
            # allocated on the indices' device so the indexed assignment never mixes devices
            cum_probs = torch.zeros(neighbors.size(0), items.size(0), n_steps,
                                    device=neighbors.device)
            cum_probs[:, idx, torch.arange(n_steps, device=neighbors.device)] = 1
            # running count of every item up to each time step
            cum_probs = torch.cumsum(cum_probs, dim=2)
            print(cum_probs)

            # normalise the counts into a distribution over items at each time step
            cum_probs = cum_probs / cum_probs.sum(dim=1, keepdims=True)
            # entropy of that distribution over time; entries equal to 0 contribute 0
            e_wrt_t = (-cum_probs*torch.where(cum_probs > 0, torch.log(cum_probs), cum_probs)).sum(dim=1)
            print((torch.sign(e_wrt_t[:, 1:] - e_wrt_t[:, :-1]) * e_wrt_t[:, :-1]).sum(dim=1))

        plt.figure(figsize=(18, 4))
        plt.plot(neighbors.cpu().numpy()[0])
        plt.figure(figsize=(18, 4))
        plt.hist(neighbors.cpu().numpy()[0], bins=512)
        plt.figure(figsize=(18, 4))
        # the repeat rate is scaled by the entropy's maximum so both curves share one axis
        plt.plot(rr.cpu().numpy()[0] * e_wrt_t.cpu().numpy()[0].max().item())
        plt.plot(e_wrt_t.cpu().numpy()[0])
        plt.show()



In [None]:
# Re-scan everything under "trainings" (this rebinds CKPTS) and list the keys.
CKPTS = group_ckpts_by_trainset("trainings")
list(CKPTS.keys())

In [None]:
# Select one group by its position in the key listing and show its entries.
group_keys = list(CKPTS.keys())
k = group_keys[7]
CKPTS[k]

In [None]:
import torch
import h5mapper as h5m
import mimikit as mmk
from pbind import *

from models.ensemble import Ensemble
from datasets import COUGH
from checkpoints import load_files

# Both patterns use the same "id" value, so it is spelled out once.
ckpt_id = "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd"

# First pattern: "seconds" drawn by Pwhite between 1 and 8.
long_segment = Pbind(
    "id", ckpt_id,
    "epoch", Prand([40, 50], inf),
    "seconds", Pwhite(lo=1., hi=8., repeats=1)
)
# Second pattern: "seconds" drawn by Pwhite between 0.5 and 1.5.
short_segment = Pbind(
    "id", ckpt_id,
    "epoch", Prand([40, 50], inf),
    "seconds", Pwhite(lo=0.5, hi=1.5, repeats=1)
)

# Sequence the two patterns, repeated `inf` times, and turn the result into a stream.
stream = Pseq([long_segment, short_segment], inf).asStream()

# NOTE(review): 60. and 22050 presumably end up as `max_seconds` and `base_sr`
# (both attributes are read further down) — confirm against `Ensemble`.
ensemble = Ensemble(60., 22050, stream)

def process_outputs(outputs, bidx):
    """Hand the first element of the first output to `mmk.audio`.

    Parameters
    ----------
    outputs
        nested indexable; `outputs[0][0]` must provide `.cpu().numpy()`.
    bidx
        accepted for the callback signature, not used.

    Returns
    -------
    None
    """
    first_item = outputs[0][0]
    samples = first_item.cpu().numpy()
    mmk.audio(samples, sr=ensemble.base_sr)

# Prompt: the first 44100 samples of the "Cough" files loaded at the ensemble's
# sample rate, with a leading axis added by `unsqueeze(0)`.
prompt_files = load_files(COUGH["Cough"], ensemble.base_sr)
prompt = torch.as_tensor(prompt_files.snd[0:44100]).unsqueeze(0)

# Getter and setter both work along dim 1; the getter's shift and length are
# one `base_sr` worth of samples.
window = ensemble.base_sr
prompt_input = h5m.Input(
    None,
    getter=h5m.AsSlice(dim=1, shift=-window, length=window),
    setter=h5m.Setter(dim=1),
)

# One step per sample for `max_seconds` seconds at `base_sr`.
total_steps = int(ensemble.base_sr * ensemble.max_seconds)

loop = mmk.GenerateLoop(
    network=ensemble,
    dataloader=[(prompt,)],
    inputs=(prompt_input,),
    n_steps=total_steps,
    add_blank=True,
    process_outputs=process_outputs
)
loop.run()


In [None]:
# Shape check: run the spectrogram feature's inverse transform on a random
# tensor of shape (1, 20, 513) and display the shape of the result.
feat = mmk.Spectrogram(
    sr=44100,
    normalize=True,
    emphasis=0.0,
    n_fft=1024,
    hop_length=256,
    coordinate='mag',
    center=True,
)
random_input = torch.randn(1, 20, 513)
feat.inverse_transform_(random_input).shape

In [None]:
# Same product that sets `n_steps` of the generate loop (before its `int(...)` cast).
ensemble.max_seconds * ensemble.base_sr

# Stream Declaration

In [None]:
import h5mapper as h5m

from checkpoints import group_ckpts_by_trainset

# NOTE(review): called without a root directory here, whereas the other cells
# pass one explicitly — confirm that the helper has a suitable default.
group_ckpts_by_trainset()

# Split Checkpoint Banks

In [None]:
import os
import h5mapper as h5m
from google.cloud import storage

from checkpoints import CkptBank, load_trainings_hp, load_network_cls, Checkpoint
from concurrent.futures import ThreadPoolExecutor, as_completed

def split_by_src(ckpt_path):
    """Split the checkpoint bank at `ckpt_path` into one bank file per index entry.

    Every key of the bank's index yields a file named after the part of the key
    before its first "-" plus ".h5", written next to `ckpt_path`. Each file
    receives the source bank's hyper-parameters (with the network class added
    under "cls") and the entry's checkpoint under "state_dict". Files that
    already exist are left untouched.

    Parameters
    ----------
    ckpt_path : str
        path of the bank to split.

    Returns
    -------
    list of str
        paths of the per-entry files, existing or freshly written, in index order.
    """
    print(ckpt_path)
    bank = CkptBank(ckpt_path)
    try:
        hp = bank.ckpt.load_hp()
        dirname = os.path.dirname(ckpt_path)
        train_hp = load_trainings_hp(dirname)
        hp["cls"] = load_network_cls(train_hp["network_class"])

        to_upload = []
        for ep_id in bank.index.keys():
            new_path = os.path.join(dirname, ep_id.split("-")[0] + ".h5")
            if not os.path.isfile(new_path):
                new = CkptBank(new_path, mode="w")
                try:
                    new.ckpt.save_hp(hp)
                    new.flush()
                    new.ckpt.add("state_dict", new.ckpt.format(bank.get(ep_id)['ckpt']))
                    new.flush()
                finally:
                    # release the file handle even when writing fails
                    new.close()
            to_upload.append(new_path)
        return to_upload
    finally:
        # the source bank used to stay open after the split
        bank.close()

def upload_to_gcp(ckpt_path):
    """Upload the file at `ckpt_path` to its checkpoint blob unless the blob exists.

    The `Checkpoint` is built from the id and epoch that
    `Checkpoint.get_id_and_epoch` extracts from the path. After the optional
    upload, the blob and the result of a fresh existence check are printed.

    Parameters
    ----------
    ckpt_path : str
        local path of the checkpoint file.

    Returns
    -------
    None
    """
    id_and_epoch = Checkpoint.get_id_and_epoch(ckpt_path)
    checkpoint = Checkpoint(*id_and_epoch)
    already_uploaded = checkpoint.blob.exists()
    if not already_uploaded:
        print("uploading", checkpoint)
        checkpoint.blob.upload_from_filename(ckpt_path, timeout=None)
    print(checkpoint.blob, checkpoint.blob.exists())



# NOTE(review): absolute, machine-specific paths — consider a configurable root.
# To split a monolithic bank first, build `to_upload` from `split_by_src` instead:
#     to_split = h5m.FileWalker(r"checkpoints\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-cough")
#     to_upload = [path for x in to_split for path in split_by_src(x)]
to_upload = h5m.FileWalker(r"epoch=.*\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-lungs")

# The context manager shuts the pool down even when an upload raises.
with ThreadPoolExecutor(max_workers=4) as executor:
    # `as_completed` expects Future objects (from `submit`); `executor.map`
    # yields plain results, so the previous call never waited on anything.
    futures = [executor.submit(upload_to_gcp, path) for path in to_upload]
    for future in as_completed(futures):
        # re-raise here any exception raised inside the worker thread
        future.result()

# Download Checkpoint

In [None]:
from datasets import *

# Class-level setting shared by every Trainset created afterwards.
# NOTE(review): presumably the directory `Trainset.download()` writes into
# (a later cell reads a file from "./train-data/...") — confirm in `datasets`.
Trainset.root_dir = "./train-data"


In [None]:
# Declare both train sets first, then download them (cough before lungs).
cough_set = Trainset("Cough")
lungs_set = Trainset("Lung Collection")

cough = cough_set.download()
lungs = lungs_set.download()

# Show the `.index` of what each download returned.
cough.index, lungs.index

In [None]:
import librosa
# Load one file from the train-data directory with librosa and display the
# returned value — presumably a check that the downloaded file decodes.
librosa.load("./train-data/Lung Collection/Breath.mp3")