In [1]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os
import json
from random import randint
import matplotlib.pyplot as plt

from models.wavenets import WaveNetFFT, WaveNetQx
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM
from ensemble import Ensemble
from mains import train, generate

from checkpoints import Checkpoint
from datasets import Trainset

import numpy as np

from mimikit.extract.from_neighbors import *

In [2]:
root = "trainings/"

ckpts = [Checkpoint.from_path(f) for f in h5m.FileWalker(r".*verdi.*epoch=.*\.h5", root)
        if "qx" not in f and '88k' not in f]
ckpts

[Checkpoint(id='dc63b032229eafd03f52f56d14aefbb4274bea5ca5fcdd55c6bdfa5cc43d1e8e', epoch=50, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='dc63b032229eafd03f52f56d14aefbb4274bea5ca5fcdd55c6bdfa5cc43d1e8e', epoch=30, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='dc63b032229eafd03f52f56d14aefbb4274bea5ca5fcdd55c6bdfa5cc43d1e8e', epoch=10, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='dc63b032229eafd03f52f56d14aefbb4274bea5ca5fcdd55c6bdfa5cc43d1e8e', epoch=40, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='dc63b032229eafd03f52f56d14aefbb4274bea5ca5fcdd55c6bdfa5cc43d1e8e', epoch=20, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='d9746c79b8e54a1ef67374b3ab711f2f440416b5a9adbda2b6c5e1f16434c505', epoch=50, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='d9746c79b8e54a1ef67374b3ab711f2f440416b5a9adbda2b6c5e1f16434c505', epoch=30, root_dir='trainings/s2s-verdi-x-44k'),
 Checkpoint(id='d9746c79b8e54a1ef67374b3ab711f2f440416b5a9adbda2b6c5e1f16434

In [None]:
import gc

for ck in ckpts[:]:
    net = ck.network
    feature = ck.feature
    bank = Trainset(keyword=ck.train_hp["trainset"], sr=feature.sr)
    saved = {}
    
    
    def process_outputs(outputs, bidx):
        outputs = outputs[0]
        y = feature.transform(bank.bank.snd[:])
        y = torch.from_numpy(y).to(outputs)
        nn = torch.stack([nearest_neighbor(out, y)[1] for out in outputs])
        hx = torch.stack([cum_entropy(n, neg_diff=False) for n in nn]).detach().cpu().numpy()
        idx = np.argsort(hx)
        for i in idx:
            saved[hx[i]] = outputs[i].detach().cpu().numpy().T
        del y
        del nn
        gc.collect()
        torch.cuda.empty_cache()
        

    prompt_files = bank.bank
    batch_item = feature.batch_item(shift=0, length=net.rf, training=False)
    indices = mmk.IndicesSampler(N=500,
                                  indices=torch.arange(0,
                                                       prompt_files.snd.shape[0]-batch_item.getter.length,
                                                       (prompt_files.snd.shape[0]-batch_item.getter.length)//500))
    dl = prompt_files.serve(
        (batch_item, ),
        sampler=indices,
        shuffle=False,
        batch_size=64,
    )
    
    
    loop = mmk.GenerateLoop(
        network=net,
        dataloader=dl,
        inputs=(h5m.Input(None, 
                          getter=h5m.AsSlice(dim=1, shift=-net.rf, length=net.rf),
                          setter=h5m.Setter(dim=1)),),
        n_steps=feature.sr*25//feature.hop_length,
        add_blank=True,
        time_hop=net.hp.get("hop", 1),
        process_outputs=process_outputs
    )
#     print("\n")
    print("\n")
    print("\n")
    print("-----------------------------------------")
    print("\n")
    print("\n")  
    loop.run()
    print("\n")
    print("\n")
    print(ck)

    for k in list(sorted(saved))[-8:]:
        print("SCORE = ", k)
        mmk.audio(saved[k], hop_length=feature.hop_length, sr=feature.sr)
        
    del net
    gc.collect()
    torch.cuda.empty_cache()

    





-----------------------------------------






HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



HBox(children=(FloatProgress(value=0.0, description='Generate', layout=Layout(flex='2'), max=4306.0, style=Pro…



In [None]:
CKPTS = group_ckpts_by_trainset("trainings")
[*CKPTS.keys()]

In [None]:
k = [*CKPTS.keys()][7]
CKPTS[k]

In [None]:
import torch
import h5mapper as h5m
import mimikit as mmk
from pbind import *

from models.ensemble import Ensemble
from datasets import COUGH
from checkpoints import load_files

stream = Pseq([
    Pbind(
        "id", "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd",
        "epoch", Prand([40, 50], inf),
        "seconds", Pwhite(lo=1., hi=8., repeats=1)
        ),
    Pbind(
        "id", "80cb7d5b4ff7af169e74b3617c43580a41d5de5bd6c25e3251db2d11213755cd",
        "epoch", Prand([40, 50], inf),
        "seconds", Pwhite(lo=0.5, hi=1.5, repeats=1)
        ),
], inf).asStream()
    
ensemble = Ensemble(60., 22050, stream)

def process_outputs(outputs, bidx):
    mmk.audio(outputs[0][0].cpu().numpy(), sr=ensemble.base_sr)

prompt_files = load_files(COUGH["Cough"], ensemble.base_sr)
prompt = prompt_files.snd[0:44100]
prompt = torch.as_tensor(prompt).unsqueeze(0)

loop = mmk.GenerateLoop(
    network=ensemble,
    dataloader=[(prompt,)],
    inputs=(h5m.Input(None, 
                      getter=h5m.AsSlice(dim=1, shift=-ensemble.base_sr, length=ensemble.base_sr),
                      setter=h5m.Setter(dim=1)),),
    n_steps=int(ensemble.base_sr * ensemble.max_seconds),
    add_blank=True,
    process_outputs=process_outputs
)
loop.run()


In [None]:
feat = mmk.Spectrogram(sr=44100, normalize=True, emphasis=0.0, n_fft=1024, hop_length=256, coordinate='mag', center=True)
feat.inverse_transform_(torch.randn(1, 20, 513)).shape

In [None]:
ensemble.max_seconds * ensemble.base_sr

# Stream Declaration

In [None]:
import h5mapper as h5m

from checkpoints import group_ckpts_by_trainset

group_ckpts_by_trainset()

# Split Checkpoint Banks

In [None]:
import os
import h5mapper as h5m
from google.cloud import storage

from checkpoints import CkptBank, load_trainings_hp, load_network_cls, Checkpoint
from concurrent.futures import ThreadPoolExecutor, as_completed

def split_by_src(ckpt_path):
    print(ckpt_path)
    bank = CkptBank(ckpt_path)
    hp = bank.ckpt.load_hp()
    to_upload = []
    
    dirname = os.path.dirname(ckpt_path)
    train_hp = load_trainings_hp(dirname)
    net_cls = load_network_cls(train_hp["network_class"])
    hp["cls"] = net_cls
    for ep_id in bank.index.keys():
        new_path = os.path.join(
            dirname, ep_id.split("-")[0] + ".h5"
        )
        if os.path.isfile(new_path):
            to_upload += [new_path]
            continue
        new = CkptBank(new_path, mode="w")
        new.ckpt.save_hp(hp)
        new.flush()
        new.ckpt.add("state_dict", new.ckpt.format(bank.get(ep_id)['ckpt']))
        new.flush()
        new.close()
        to_upload += [new_path]
    return to_upload

def upload_to_gcp(ckpt_path):
#     raise ValueError
    ck = Checkpoint(*Checkpoint.get_id_and_epoch(ckpt_path))
    if not ck.blob.exists():
        print("uploading", ck)
        ck.blob.upload_from_filename(ckpt_path, timeout=None)
    print(ck.blob, ck.blob.exists())
    return



# to_split = h5m.FileWalker(r"checkpoints\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-cough")
to_upload = h5m.FileWalker(r"epoch=.*\.h5", "/home/antoine/ktonal/ax6/trainings/s2s-grid-lungs")

executor = ThreadPoolExecutor(max_workers=4)
# to_upload = [path for x in to_split for path in split_by_src(x)]

as_completed([*executor.map(upload_to_gcp, to_upload)])
executor.shutdown()
1, 2, 3

# Download Checkpoint

In [None]:
from google.cloud import storage
import dataclasses as dtc
import os

from checkpoints import CkptBank, load_feature


client = storage.Client("ax6-Project")

@dtc.dataclass
class Checkpoint:
    id: str
    epoch: int
    bucket = "ax6-outputs"
    root_dir = "./"
    
    @staticmethod
    def get_id_and_epoch(path):
        id_, epoch = path.split("/")[-2:]
        return id_.strip("/"), int(epoch.split(".h5")[0].split("=")[-1])
    
    @staticmethod
    def from_blob(blob):
        path = blob.name
        id_, epoch = Checkpoint.get_id_and_epoch(path)
        ckpt = Checkpoint(id_, epoch)
        ckpt.bucket = blob.bucket.name
        return ckpt
    
    @property
    def gcp_path(self):
        return f"gs://{self.bucket}/checkpoints/{self.id}/epoch={self.epoch}.h5"
    
    @property
    def os_path(self):
        return os.path.join(self.root_dir, f"{self.id}_epoch={self.epoch}.h5")
    
    @property
    def blob(self):
        return client.bucket(self.bucket).blob(f"checkpoints/{self.id}/epoch={self.epoch}.h5")
    
    def download(self):
        os.makedirs(self.root_dir, exist_ok=True)
        client.download_blob_to_file(self.gcp_path, open(self.os_path, "wb"))
        return self
    
    @property
    def network(self):
        if not os.path.isfile(self.os_path):
            self.download()
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return bank.ckpt.load_checkpoint(hp["cls"], "state_dict")
    
    @property
    def feature(self):
        if not os.path.isfile(self.os_path):
            self.download()
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return hp['feature']
    

[Checkpoint.from_blob(blob) 
 for blob in client.list_blobs(Checkpoint.bucket, prefix='checkpoints')
 if "epoch=" in blob.name and ".h5" == blob.name[-3:]
][-1].feature

In [None]:
list(client.list_blobs("ax6-outputs", prefix='checkpoints'))

In [None]:
from models.nnn import *
from checkpoints import *
from datasets import *
import mimikit as mmk

In [None]:
bank = Trainset("Cough").bank


In [None]:
fft = mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag").transform
optimal_path(fft(bank.snd[:8000]), fft(bank.snd[5000:15000]))

In [None]:
nnn = NearestNextNeighbor(
    mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag"),
    bank.snd
)
prompt = fft(bank.snd[3000:8000])

mmk.GenerateLoop(
    nnn,
    [(torch.from_numpy(prompt).unsqueeze(0), )],
    inputs=(h5m.Input(None,
        getter=h5m.AsSlice(dim=1, shift=-22050, length=22050),
        setter=h5m.Setter(dim=1)),),
    n_steps=32,
    device="cpu"
).run()