In [None]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os

from models.wavenets import WaveNetFFT, WaveNetQx
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM, Seq2SeqLSTMv0
from mains import train, generate

from datasets import from_gcloud, INSECTS_X, VERDI_X, COUGH, LUNGS

# Drop any "train-2.h5" bank left behind by a previous session; a missing file
# is simply nothing to clean up.
try:
    os.remove("train-2.h5")
except FileNotFoundError:
    pass

# Seq2Seq LSTM

In [None]:
from grids import fft_grid, fft_io_grid


class GCPSoundBank(h5m.TypedFile):
    # h5mapper file schema with a single `snd` feature: sounds decoded by
    # h5m.Sound at 22050 Hz, mono, normalized, wrapped by `from_gcloud`
    # (presumably fetching the sources from Google Cloud — confirm in datasets).
    # NOTE(review): this class is redefined by later cells with sr=16000; each
    # definition shadows the previous one in the notebook namespace.
    snd = from_gcloud(h5m.Sound(sr=22050, mono=True, normalize=True))
    
    
# Seq2SeqLSTM on polar-coordinate spectrograms: one training run per COUGH set.
# Settings that do not change between runs are grouped once, up front.
s2s_model_kwargs = dict(
    input_heads=1,
    output_heads=1,
    scaled_activation=True,
    model_dim=1024,
    num_layers=1,
    n_lstm=1,
    bottleneck="add",
    n_fc=1,
    hop=4,
    weight_norm=False,
    with_tbptt=False,
    with_sampler=False,
)
# parameter names match torch's OneCycleLR (div_factor, pct_start, ...)
s2s_optim_kwargs = dict(
    max_lr=4e-4,
    betas=(0.9, 0.93),
    div_factor=5.,
    final_div_factor=1.,
    pct_start=0.,
    cycle_momentum=False,
)
s2s_logging_kwargs = dict(
    CHECKPOINT_TRAINING=False,
    MONITOR_TRAINING=True,
    OUTPUT_TRAINING="",
    every_n_epochs=10,
    n_examples=4,
    prompt_length=4,
)

# NOTE(review): each iteration re-creates "train-2.h5" while the previous
# iteration's bank (opened with keep_open=True) is never closed — confirm that
# GCPSoundBank.create tolerates an already-open target file.
for id_, s in COUGH.items():
    GCPSoundBank.create("train-2.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train-2.h5", mode='r', keep_open=True)

    feature = mmk.Spectrogram(
        sr=soundbank.snd.attrs["sr"],
        n_fft=2048,
        hop_length=512,
        coordinate='pol',
        center=False,
        normalize=True,
    )
    net = Seq2SeqLSTM(feature=feature, **s2s_model_kwargs)

    train(
        soundbank,
        net,
        input_feature=feature,
        target_feature=feature,
        root_dir="./trainings/s2s-cough-pol",
        batch_size=8,
        batch_length=4,
        shift_error=0,
        downsampling=32,
        max_epochs=50,
        limit_train_batches=None,
        # 12 * (frames per second) grouped by the model hop;
        # presumably ~12 s of generated audio — confirm.
        n_steps=int((12 * (net.feature.sr // net.feature.hop_length)) // net.hp.hop),
        **s2s_optim_kwargs,
        **s2s_logging_kwargs,
    )

# SampleRNN

In [None]:
# Scratch check: sample rate * seconds / frame length.
# NOTE(review): reading inferred from the values (16 kHz matches the banks
# below, 512 the largest SampleRNN frame size) — confirm the intent.
sr_hz, n_seconds, frame_len = 16000, 4, 512
sr_hz * n_seconds / frame_len

In [None]:
# for train_file in [*h5m.FileWalker(h5m.Sound.__re__, 'train-data/Heyr Himna.m4a')][0:]:
#     h5m.sound_bank.callback("train.h5", train_file, sr=16000,)

#     soundbank = h5m.TypedFile("train.h5", mode='r', keep_open=True)

class GCPSoundBank(h5m.TypedFile):
    # h5mapper file schema with a single `snd` feature: sounds decoded by
    # h5m.Sound at 16000 Hz, mono, normalized, wrapped by `from_gcloud`
    # (presumably fetching the sources from Google Cloud — confirm in datasets).
    # Redefining the class here replaces the 22050 Hz variant from the cell above.
    snd = from_gcloud(h5m.Sound(sr=16000, mono=True, normalize=True))
    
    
# SampleRNN on each LUNGS sound set, once per (learning rate, betas) pair.
# The bank only depends on the sound set, so it is built and opened once per
# set rather than once per hyper-parameter pair. Otherwise the sources would be
# re-fetched and "train-2.h5" re-created while the previous run's handle
# (keep_open=True) is still open.
for id_, s in LUNGS.items():
    GCPSoundBank.create("train-2.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train-2.h5", mode='r', keep_open=True)

    for lr, betas in [(7e-4, (0.97, 0.995)), (5e-4, (.95, .985))]:
        # Fresh feature and model for every run so runs share no state.
        feature = mmk.MuLawSignal(sr=soundbank.snd.attrs["sr"],
                                  q_levels=128)
        net = SampleRNN(
            feature=feature,
            chunk_length=2048,
            frame_sizes=(512, 64, 8, 8),
            dim=512,
            n_rnn=1,
            q_levels=128,
            embedding_dim=256,
            mlp_dim=512,
        )
        print("------------", lr, betas)

        # 12 * sr generation steps (presumably 12 s of audio), computed once
        # and shared by n_steps and the temperature schedule.
        srnn_n_steps = int(12 * net.feature.sr)
        # One constant temperature per generated example; n_examples is derived
        # from the number of rows so the two cannot disagree.
        srnn_temperature = torch.tensor(
            [[1.25], [1.1], [1.], [.995], [.95], [.75]]
        ).repeat(1, srnn_n_steps)

        train(
            soundbank,
            net,
            input_feature=mmk.MultiScale(feature, net.frame_sizes, (*net.frame_sizes[:-1], 1)),
            target_feature=feature,
            root_dir="./trainings/srnn-lungs",
            batch_size=16,
            batch_length=2048,
            oversampling=32,
            shift_error=0,
            tbptt_chunk_length=2048 * 8,
            max_epochs=1000,
            limit_train_batches=None,

            max_lr=lr,
            betas=betas,
            div_factor=5.,
            final_div_factor=1.,
            pct_start=0.,
            cycle_momentum=False,

            CHECKPOINT_TRAINING=True,
            MONITOR_TRAINING=True,
            OUTPUT_TRAINING="mp3",

            every_n_epochs=50,
            n_examples=srnn_temperature.shape[0],
            prompt_length=16000,
            n_steps=srnn_n_steps,
            temperature=srnn_temperature,
            trainset=id_
        )

# WaveNets

In [None]:
# Scratch check: kernel size raised to the number of layers.
# NOTE(review): presumably the receptive field of the dilated stack
# (kernel_sizes=(4,), blocks=(4,)); it matches prompt_length=256 in the
# WaveNet cell below — confirm.
wn_kernel_size, wn_n_layers = 4, 4
wn_kernel_size ** wn_n_layers

In [None]:
# for train_file in h5m.FileWalker(h5m.Sound.__re__, 'train-data/sounds'):
#     h5m.sound_bank.callback("train.h5", train_file, sr=22050,)

#     soundbank = h5m.TypedFile("train.h5", mode='r', keep_open=True)
class GCPSoundBank(h5m.TypedFile):
    # h5mapper file schema with a single `snd` feature: sounds decoded by
    # h5m.Sound at 16000 Hz, mono, normalized, wrapped by `from_gcloud`
    # (presumably fetching the sources from Google Cloud — confirm in datasets).
    snd = from_gcloud(h5m.Sound(sr=16000, mono=True, normalize=True))
    
    
# Quantized WaveNet (WaveNetQx) on each LUNGS sound set.
for id_, s in LUNGS.items():
    GCPSoundBank.create("train-2.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train-2.h5", mode='r', keep_open=True)

    # Read the sample rate from the bank, as the other training cells do,
    # instead of hard-coding 16000. This keeps the feature in sync with
    # GCPSoundBank's Sound(sr=...).
    feature = mmk.MuLawSignal(sr=soundbank.snd.attrs["sr"],
                              q_levels=256)
    net = WaveNetQx(
        feature=feature,
        mlp_dim=512,

        kernel_sizes=(4, ),
        blocks=(4,),
        dims_dilated=(1024,),
        dims_1x1=(),
        residuals_dim=None,
        apply_residuals=False,
        skips_dim=None,
        groups=8,
        pad_side=0,
        stride=1,
        bias=True,
    )
    net.use_fast_generate = True

    # 6 * sr generation steps (presumably 6 s of audio), computed once and
    # shared by n_steps and the temperature schedule.
    wn_n_steps = int(6 * net.feature.sr)
    # One constant temperature per generated example; n_examples is derived
    # from the number of rows so the two cannot disagree.
    wn_temperature = torch.tensor(
        [[2.1], [1.95], [1.75], [1.5], [1.], [.9], [.5], [.0005]]
    ).repeat(1, wn_n_steps)

    train(
        soundbank,
        net,
        root_dir="./trainings/wn-lungs-qx",
        input_feature=feature,
        target_feature=feature,
        batch_size=16,
        batch_length=1024,
        downsampling=4,
        shift_error=0,

        max_epochs=500,
        limit_train_batches=8000,

        max_lr=5e-4,
        betas=(0.92, 0.975),
        div_factor=5.,
        final_div_factor=1.,
        pct_start=0.,
        cycle_momentum=False,

        CHECKPOINT_TRAINING=True,
        MONITOR_TRAINING=True,
        OUTPUT_TRAINING="mp3",

        every_n_epochs=25,
        n_examples=wn_temperature.shape[0],
        prompt_length=256,
        n_steps=wn_n_steps,
        temperature=wn_temperature,
        trainset=id_
    )

In [None]:
class GCPSoundBank(h5m.TypedFile):
    # h5mapper file schema for the debug cell: sounds decoded by h5m.Sound at
    # 16000 Hz, mono, normalized, wrapped by `from_gcloud` (presumably fetching
    # the sources from Google Cloud — confirm in datasets).
    snd = from_gcloud(h5m.Sound(sr=16000, mono=True, normalize=True))

# Debug: build a bank from the first set of a collection and inspect one
# SampleRNN training batch.
# NOTE(review): LUNGS is assumed because it is the 16 kHz collection used by
# the SampleRNN cell; point `debug_sets` at another dict from `datasets` if needed.
debug_sets = LUNGS
GCPSoundBank.create("train.h5", next(iter(debug_sets.values())), parallelism="threads", n_workers=8)


soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)


net = SampleRNN(
    # NOTE(review): the other MuLawSignal calls do not pass `normalize`;
    # confirm MuLawSignal accepts it.
    feature=mmk.MuLawSignal(sr=soundbank.snd.attrs["sr"],
                            q_levels=256,
                            normalize=True),
    chunk_length=16000*8,
    frame_sizes = (16, 8, 8),
    dim= 512,
    n_rnn = 2,
    q_levels = 256,
    embedding_dim = 256,
    mlp_dim = 512,
)
# NOTE(review): positional args — presumably (soundbank, batch_size,
# batch_length, downsampling, shift_error), mirroring train(); confirm against
# the SampleRNN.train_dataloader signature.
dl = net.train_dataloader(soundbank, 8, 32, 1, 0)
inp, trg = next(iter(dl))
inp[0], trg[0]