In [None]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os
import json
from random import randint
import matplotlib.pyplot as plt

from models.wavenets import WaveNetFFT, WaveNetQx
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM
from models.ensemble import Ensemble
from mains import train, generate

from checkpoints import group_ckpts_by_trainset, load_feature, load_files, load_network_cls
from datasets import TRAINSET, VERDI_X

import numpy as np

from mimikit.extract.from_neighbors import *

In [None]:
root = "trainings/wn-verdi-x"

CKPTS = group_ckpts_by_trainset(root)

In [None]:
CKPTS

In [12]:
2730*16*16

698880

In [None]:
torch.arange(64).dim()

In [None]:
for ckpt in CKPTS.values():
    net_cls, ckpt, feat, epochs, hp = ckpt[0]
    
    print(feat)
    train = load_files(hp["files"], feat.sr)
    y = feat.transform(train.snd[:])
    
    for output in h5m.FileWalker(h5m.Sound.__re__, root+"/"+hp["id"]):
        x = h5m.Sound(sr=feat.sr).load(output)
        print(output)
        mmk.audio(x, sr=feat.sr)
        x = feat.transform(x)
        
        X = torch.as_tensor(x).unsqueeze(0).to("cuda")
        Y = torch.as_tensor(y).unsqueeze(0).to("cuda")

        with torch.no_grad():
            _, nn = nearest_neighbor(X, Y)
            rr = repeat_rate(nn, 88, 1)
            items, idx = torch.unique(nn, return_inverse=True)
            cum_probs = torch.zeros(nn.size(0), items.size(0), nn.size(1))
            cum_probs[:, idx, torch.arange(nn.size(1))] = 1
            cum_probs = torch.cumsum(cum_probs, dim=2)
            print(cum_probs)
            
            cum_probs = cum_probs / cum_probs.sum(dim=1, keepdims=True)
            e_wrt_t = (-cum_probs*torch.where(cum_probs > 0, torch.log(cum_probs), cum_probs)).sum(dim=1)
            print((torch.sign(e_wrt_t[:, 1:] - e_wrt_t[:, :-1]) * e_wrt_t[:, :-1]).sum(dim=1))
            
        plt.figure(figsize=(18, 4))
        plt.plot(nn.cpu().numpy()[0])
        plt.figure(figsize=(18, 4))
        plt.hist(nn.cpu().numpy()[0], bins=512)
        plt.figure(figsize=(18, 4))
        plt.plot(rr.cpu().numpy()[0] * e_wrt_t.cpu().numpy()[0].max().item())
        plt.plot(e_wrt_t.cpu().numpy()[0])
        plt.show()



In [None]:
CKPTS = group_ckpts_by_trainset("trainings")
[*CKPTS.keys()]

In [None]:
k = [*CKPTS.keys()][7]
CKPTS[k]

In [None]:
import torch
import h5mapper as h5m
import mimikit as mmk
from pbind import *

from models.ensemble import Ensemble
from datasets import COUGH
from checkpoints import load_files

stream = Pseq([
    Pbind(
        "id", "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd",
        "epoch", Prand([40, 50], inf),
        "seconds", Pwhite(lo=1., hi=8., repeats=1)
        ),
    Pbind(
        "id", "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd",
        "epoch", Prand([40, 50], inf),
        "seconds", Pwhite(lo=0.5, hi=1.5, repeats=1)
        ),
], inf).asStream()
    
ensemble = Ensemble(60., 22050, stream)

def process_outputs(outputs, bidx):
    mmk.audio(outputs[0][0].cpu().numpy(), sr=ensemble.base_sr)

prompt_files = load_files(COUGH["Cough"], ensemble.base_sr)
prompt = prompt_files.snd[0:44100]
prompt = torch.as_tensor(prompt).unsqueeze(0)

loop = mmk.GenerateLoop(
    network=ensemble,
    dataloader=[(prompt,)],
    inputs=(h5m.Input(None, 
                      getter=h5m.AsSlice(dim=1, shift=-ensemble.base_sr, length=ensemble.base_sr),
                      setter=h5m.Setter(dim=1)),),
    n_steps=int(ensemble.base_sr * ensemble.max_seconds),
    add_blank=True,
    process_outputs=process_outputs
)
loop.run()


In [None]:
feat = mmk.Spectrogram(sr=44100, normalize=True, emphasis=0.0, n_fft=1024, hop_length=256, coordinate='mag', center=True)
feat.inverse_transform_(torch.randn(1, 20, 513)).shape

In [None]:
ensemble.max_seconds * ensemble.base_sr

# Stream Declaration

In [4]:
import h5mapper as h5m

from checkpoints import group_ckpts_by_trainset

group_ckpts_by_trainset()

{'1e0fc278-5aa0-4be6-9784-677452f26eac': [(ax6.models.wavenets.WaveNetFFT,
   <CkptBank trainings/wn-medium-pol/3ae777af59c413ebae8637bea38d6665ded2c9eac001a8c7dfe98e56d2651d78/checkpoints.h5>,
   Spectrogram(sr=22050, normalize=True, emphasis=0.0, n_fft=2048, hop_length=512, coordinate='pol', center=True),
   ['epoch=5-step=67135',
    'epoch=10-step=134270',
    'epoch=15-step=201405',
    'epoch=20-step=268540',
    'epoch=25-step=335675',
    'epoch=30-step=402810',
    'epoch=35-step=469945',
    'epoch=40-step=537080',
    'epoch=45-step=604215',
    'epoch=50-step=671350'],
   {'files': ['train-data/sounds/raw/benny-hinn-anointing-of-the-spirit.mp3'],
    'network_class': 'WaveNetFFT',
    'network': {'kernel_sizes': [2],
     'blocks': [4],
     'act_f': 'Tanh()',
     'act_g': 'Sigmoid()',
     'apply_residuals': False,
     'bias': True,
     'dims_1x1': [],
     'dims_dilated': [2048],
     'groups': 4,
     'pad_side': 0,
     'residuals_dim': None,
     'skips_dim': None,


# Split Checkpoint Banks

In [3]:
import os
import h5mapper as h5m
from google.cloud import storage

from checkpoints import CkptBank, load_trainings_hp, load_network_cls, Checkpoint
from concurrent.futures import ThreadPoolExecutor, as_completed

def split_by_src(ckpt_path):
    print(ckpt_path)
    bank = CkptBank(ckpt_path)
    hp = bank.ckpt.load_hp()
    to_upload = []
    
    dirname = os.path.dirname(ckpt_path)
    train_hp = load_trainings_hp(dirname)
    net_cls = load_network_cls(train_hp["network_class"])
    hp["cls"] = net_cls
    for ep_id in bank.index.keys():
        new_path = os.path.join(
            dirname, ep_id.split("-")[0] + ".h5"
        )
        if os.path.isfile(new_path):
            to_upload += [new_path]
            continue
        new = CkptBank(new_path, mode="w")
        new.ckpt.save_hp(hp)
        new.flush()
        new.ckpt.add("state_dict", new.ckpt.format(bank.get(ep_id)['ckpt']))
        new.flush()
        new.close()
        to_upload += [new_path]
    return to_upload

def upload_to_gcp(ckpt_path):
#     raise ValueError
    ck = Checkpoint(*Checkpoint.get_id_and_epoch(ckpt_path))
    if not ck.blob.exists():
        print("uploading", ck)
        ck.blob.upload_from_filename(ckpt_path, timeout=None)
    print(ck.blob, ck.blob.exists())
    return



# to_split = h5m.FileWalker(r"checkpoints\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-cough")
to_upload = h5m.FileWalker(r"epoch=.*\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-lungs")

executor = ThreadPoolExecutor(max_workers=4)
# to_upload = [path for x in to_split for path in split_by_src(x)]

as_completed([*executor.map(upload_to_gcp, to_upload)])
executor.shutdown()
1, 2, 3

uploading Checkpoint(id='e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3', epoch=15, root_dir='./')
uploading Checkpoint(id='e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3', epoch=5, root_dir='./')
uploading Checkpoint(id='e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3', epoch=10, root_dir='./')
uploading Checkpoint(id='e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3', epoch=20, root_dir='./')
<Blob: ax6-outputs, checkpoints/e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3/epoch=20.h5, None> True
uploading Checkpoint(id='121b54aa1a5fa7b91a35347761278620369626ccd4e2394573e6515d2c5da752', epoch=15, root_dir='./')
<Blob: ax6-outputs, checkpoints/e80943abd23eebf72427064dc25e6ee141ee1ed70faba52599a21e18c22e9ea3/epoch=5.h5, None> True
uploading Checkpoint(id='121b54aa1a5fa7b91a35347761278620369626ccd4e2394573e6515d2c5da752', epoch=5, root_dir='./')
<Blob: ax6-outputs, checkpoints/e80943abd23eebf72427064dc25

(1, 2, 3)

# Download Checkpoint

In [None]:
from google.cloud import storage
import dataclasses as dtc
import os

from checkpoints import CkptBank, load_feature


client = storage.Client("ax6-Project")

@dtc.dataclass
class Checkpoint:
    id: str
    epoch: int
    bucket = "ax6-outputs"
    root_dir = "./"
    
    @staticmethod
    def get_id_and_epoch(path):
        id_, epoch = path.split("/")[-2:]
        return id_.strip("/"), int(epoch.split(".h5")[0].split("=")[-1])
    
    @staticmethod
    def from_blob(blob):
        path = blob.name
        id_, epoch = Checkpoint.get_id_and_epoch(path)
        ckpt = Checkpoint(id_, epoch)
        ckpt.bucket = blob.bucket.name
        return ckpt
    
    @property
    def gcp_path(self):
        return f"gs://{self.bucket}/checkpoints/{self.id}/epoch={self.epoch}.h5"
    
    @property
    def os_path(self):
        return os.path.join(self.root_dir, f"{self.id}_epoch={self.epoch}.h5")
    
    @property
    def blob(self):
        return client.bucket(self.bucket).blob(f"checkpoints/{self.id}/epoch={self.epoch}.h5")
    
    def download(self):
        os.makedirs(self.root_dir, exist_ok=True)
        client.download_blob_to_file(self.gcp_path, open(self.os_path, "wb"))
        return self
    
    @property
    def network(self):
        if not os.path.isfile(self.os_path):
            self.download()
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return bank.ckpt.load_checkpoint(hp["cls"], "state_dict")
    
    @property
    def feature(self):
        if not os.path.isfile(self.os_path):
            self.download()
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return hp['feature']
    

[Checkpoint.from_blob(blob) 
 for blob in client.list_blobs(Checkpoint.bucket, prefix='checkpoints')
 if "epoch=" in blob.name and ".h5" == blob.name[-3:]
][-1].feature

In [None]:
list(client.list_blobs("ax6-outputs", prefix='checkpoints'))

In [1]:
from models.nnn import *
from checkpoints import *
from datasets import *
import mimikit as mmk

In [2]:
bank = Trainset("Cough").bank


In [3]:
fft = mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag").transform
optimal_path(fft(bank.snd[:8000]), fft(bank.snd[5000:15000]))

array([[ 0,  7],
       [ 1,  8],
       [ 2,  8],
       [ 3,  9],
       [ 4, 10],
       [ 5, 11],
       [ 6, 12],
       [ 7, 13],
       [ 8, 14],
       [ 9, 15],
       [10, 16],
       [11, 16],
       [12, 17],
       [13, 17],
       [14, 17],
       [15, 17]])

In [6]:
nnn = NearestNextNeighbor(
    mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag"),
    bank.snd
)
prompt = fft(bank.snd[3000:8000])

mmk.GenerateLoop(
    nnn,
    [(torch.from_numpy(prompt).unsqueeze(0), )],
    inputs=(h5m.Input(None,
        getter=h5m.AsSlice(dim=1, shift=-22050, length=22050),
        setter=h5m.Setter(dim=1)),),
    n_steps=32,
    device="cpu"
).run()

HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=32.0, style=Progr…

