In [None]:
import mimikit as mmk
import h5mapper as h5m
import torch.nn as nn
import torch
import os

from models.wavenets import WaveNetFFT
from models.srnns import SampleRNN
from models.s2s import Seq2SeqLSTM, Seq2SeqLSTMv0, Seq2SeqMuLaw
from mains import train, generate

from datasets import from_gcloud, LUNGS, Trainset
from grids import *

# Start from a clean slate: delete a train.h5 left behind by an earlier session,
# so the first GCPSoundBank.create(...) below does not meet a pre-existing file.
_stale_h5 = "train.h5"
if os.path.exists(_stale_h5):
    os.remove(_stale_h5)

# Seq2Seq MuLaw

In [None]:
# for train_file in [*h5m.FileWalker(h5m.Sound.__re__, 'train-data/sounds')][2:]:

#     h5m.sound_bank.callback("train.h5", train_file, sr=22050,)

#     soundbank = h5m.TypedFile("train.h5", mode='r', keep_open=True)
    
# Seq2SeqMuLaw grid: every configuration shares mlp_dim=512, input_dim=128 and a
# MuLawSignal feature (q_levels=128, sr=16000); `sampler_zipper(32, ...)` combines
# that with the hop / sampler / TBPTT options below (presumably sampling up to 32
# configurations — TODO confirm in `grids`).
net_grid = instance_grid(
    Seq2SeqMuLaw,
    dict(mlp_dim=512,
         input_dim=128,
         feature=mmk.MuLawSignal(q_levels=128, sr=16000)),
    sampler_zipper(32,
                   ParameterGrid(dict(hop=[32],
                                      with_sampler=[True, False],
                                      bias=[False],
                                      model_dim=[256],
                                      n_lstm=[1],
                                      with_tbptt=[True]))),
)

# Train on the first COUGH sound set only; `next(iter(...))` takes it without
# first copying every item into a list.
# NOTE(review): COUGH is not imported by name in this notebook — presumably it
# comes from `from grids import *`; confirm, or import it explicitly.
id_, s = next(iter(COUGH.items()))

for net in net_grid:
    feature = net.feature

    # The bank's sample rate follows the current net's feature, so the
    # TypedFile subclass is (re)defined for every configuration.
    class GCPSoundBank(h5m.TypedFile):
        snd = from_gcloud(h5m.Sound(sr=feature.sr, mono=True, normalize=True))

    # NOTE(review): the previous iteration's `soundbank` (keep_open=True) is still
    # referenced when train.h5 is re-created here — confirm h5mapper tolerates it.
    GCPSoundBank.create("train.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)

    print(net.hp)


#     feature = mmk.Spectrogram(sr=soundbank.snd.attrs["sr"],
#                                 n_fft=1024, hop_length=256,
#                                 coordinate='mag',
#                                 center=False,
#                                 normalize=True)

#     net = Seq2SeqLSTM(
#         feature=feature,
#         input_heads=1,
#         output_heads=1,
#         scaled_activation=False,
#         model_dim = 256,
#         num_layers = 1,
#         n_lstm = 1,
#         bottleneck = "add",
#         n_fc = 1,
#         hop = 4,
#         weight_norm=False,
#         with_tbptt=False,
#     )

    # NOTE(review): `optim_grid` is only defined in later cells of this notebook
    # (unless `grids` exports it); on a fresh kernel run those definitions first.
    optims = next(n_choices(optim_grid(), 1))
    print(net.hp.__class__(**optims))

    # 2 when truncated BPTT is enabled, 1 otherwise.
    is_tbptt = 1 + int(net.hp.with_tbptt)
    # Generate 6 * sr samples. `n_steps` and the temperature schedule must cover
    # the same length, so both are derived from this single value.
    gen_length = 6 * net.feature.sr
    train(
        soundbank,
        net,
        input_feature=feature,
        target_feature=feature,
        root_dir="./trainings/s2s-cough-qx",

#         batch_size=8,
        batch_length=net.hop * is_tbptt,
        tbptt_chunk_length=(net.hop * is_tbptt) if net.hp.with_tbptt else None,
        shift_error=0,
        downsampling=1,
        oversampling=1 if net.hp.with_tbptt else None,

        max_epochs=30,
        limit_train_batches=4000,
        **optims,
#         max_lr=4e-4,
#         betas=(0.9, 0.92),
        div_factor=5,
        final_div_factor=1.,
        pct_start=0.,
        cycle_momentum=False,

        CHECKPOINT_TRAINING=False,
        MONITOR_TRAINING=True,
        OUTPUT_TRAINING="",

        every_n_epochs=5,
        n_examples=4,
        prompt_length=net.hop,
        n_steps=int(gen_length // net.hp.hop),
        # One temperature row per example (4 rows), repeated along the last axis
        # for every generated sample.
        temperature=torch.tensor([[[1.1]], [[1.]], [[.5]], [[.25]]]).repeat(1, 1, int(gen_length)),

        trainset=id_
    )

# Seq2Seq FFT

In [None]:
def fft_grid():
    """Return a grid whose only axis, ``feature``, holds magnitude spectrograms.

    Each candidate is an ``mmk.Spectrogram`` with ``center=False`` at 44.1 kHz,
    built from the (sr) x (n_fft, hop_length) x (coordinate) options below.
    """
    sample_rates = ParameterGrid([dict(sr=[44100])])
    fft_settings = [
        # dict(n_fft=2048, hop_length=256),
        dict(n_fft=1024, hop_length=256),
    ]
    coordinates = ParameterGrid([dict(coordinate=["mag"])])
    spectrograms = instance_grid(mmk.Spectrogram, dict(center=False),
                                 sample_rates, fft_settings, coordinates)
    return ParameterGrid(dict(feature=spectrograms))

def optim_grid():
    """Return the grid of optimiser settings sampled for each Seq2Seq run.

    Every combination of ``batch_size``, ``max_lr`` and ``betas`` below is a
    candidate. No class is instantiated (``None``); callers unpack a sampled item
    with ``**``, so each item is expected to be a kwargs mapping.
    """
    search_space = dict(
        batch_size=[8, 16, 32],
        max_lr=[6e-4, 5e-4, 4e-4],
        betas=[(0.9, 0.915), (0.92, 0.92), (0.91846, 0.925), (0.95, 0.95), (0.925, 0.97)],
    )
    return instance_grid(None, {}, ParameterGrid([search_space]))
    
# Seq2SeqLSTM grid: spectrogram features from `fft_grid()` crossed with the
# architecture options below; `sampler_zipper(8, ...)` combines the two grids
# (presumably sampling up to 8 configurations — TODO confirm in `grids`).
net_grid = instance_grid(
    Seq2SeqLSTM,
    dict(),
    sampler_zipper(8,
                   fft_grid(),
                   ParameterGrid(dict(hop=[4, 8],
                                      input_heads=[1],
                                      output_heads=[1],
                                      scaled_activation=[True, False],
                                      bias=[False],
                                      model_dim=[512],
                                      n_lstm=[1],
                                      with_tbptt=[False],
                                      with_sampler=[True, False]))),
)

# Train on the first LUNGS sound set only; `next(iter(...))` takes it without
# first copying every item into a list.
id_, s = next(iter(LUNGS.items()))

for net in net_grid:
    feature = net.feature

    # The bank's sample rate follows this configuration's spectrogram, hence a
    # fresh TypedFile subclass per net.
    class GCPSoundBank(h5m.TypedFile):
        snd = from_gcloud(h5m.Sound(sr=feature.sr, mono=True, normalize=True))

    GCPSoundBank.create("train.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)

    print(net.hp)


#     feature = mmk.Spectrogram(sr=soundbank.snd.attrs["sr"],
#                                 n_fft=1024, hop_length=256,
#                                 coordinate='mag',
#                                 center=False,
#                                 normalize=True)

#     net = Seq2SeqLSTM(
#         feature=feature,
#         input_heads=1,
#         output_heads=1,
#         scaled_activation=False,
#         model_dim = 256,
#         num_layers = 1,
#         n_lstm = 1,
#         bottleneck = "add",
#         n_fc = 1,
#         hop = 4,
#         weight_norm=False,
#         with_tbptt=False,
#     )

    optims = next(n_choices(optim_grid(), 1))
    print(net.hp.__class__(**optims))

    # Data-loading settings that depend on whether truncated BPTT is on.
    use_tbptt = net.hp.with_tbptt
    is_tbptt = 1 + int(use_tbptt)
    if use_tbptt:
        chunk_len, down_factor, over_factor = net.hop * 10 * is_tbptt, 1, 32
    else:
        chunk_len, down_factor, over_factor = None, 128, None

    # 12 * (spectrogram frames per second), later divided by the model hop.
    frames_12s = 12 * (net.feature.sr // net.feature.hop_length)

    train(
        soundbank,
        net,
        input_feature=feature,
        target_feature=feature,
        root_dir="./trainings/s2s-grid-lungs",

#         batch_size=8,
        batch_length=net.hop * is_tbptt,
        tbptt_chunk_length=chunk_len,
        shift_error=0,
        downsampling=down_factor,
        oversampling=over_factor,

        max_epochs=50,
        limit_train_batches=None,
        **optims,
#         max_lr=4e-4,
#         betas=(0.9, 0.92),
        div_factor=5,
        final_div_factor=1.,
        pct_start=0.,
        cycle_momentum=False,

        CHECKPOINT_TRAINING=True,
        MONITOR_TRAINING=True,
        OUTPUT_TRAINING="mp3",

        every_n_epochs=10,
        n_examples=4,
        prompt_length=net.hop,
        n_steps=int(frames_12s // net.hp.hop),
        trainset=id_
    )

In [None]:
# Scratch check of float underflow: the literal 1e-1200 is already below the
# smallest positive float64 (~4.9e-324), so Python reads it as 0.0 and the
# product displayed here is 0.0.
1e-1200 * 1e-32

# SampleRNN

In [None]:
# for train_file in [*h5m.FileWalker(h5m.Sound.__re__, 'train-data/Heyr Himna.m4a')][0:]:
#     h5m.sound_bank.callback("train.h5", train_file, sr=16000,)

#     soundbank = h5m.TypedFile("train.h5", mode='r', keep_open=True)

class GCPSoundBank(h5m.TypedFile):
    """Sound bank with a single `snd` feature, read through `from_gcloud` as
    mono, normalized audio at sr=16000."""
    snd = from_gcloud(h5m.Sound(mono=True, normalize=True, sr=16000))
    
    
for id_, s in COUGH.items():

    GCPSoundBank.create("train.h5", s, parallelism="threads", n_workers=8)
    soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)

    # Feature and model must agree on the number of quantization levels.
    q_levels = 256
    # One model chunk: 16000 * 8 samples (8 s at the bank's 16 kHz). The TBPTT
    # length below is derived from it so the two cannot drift apart.
    chunk_length = 16000 * 8
    batch_length = 512

    feature = mmk.MuLawSignal(sr=soundbank.snd.attrs["sr"],
                              q_levels=q_levels)
    net = SampleRNN(
        feature=feature,
        chunk_length=chunk_length,
        frame_sizes=(16, 8, 8),
        dim=1024,
        n_rnn=1,
        q_levels=q_levels,
        embedding_dim=256,
        mlp_dim=1024,
    )

    # Generate 12 * sr samples; `n_steps` and the temperature schedule share it.
    gen_length = int(12 * net.feature.sr)

    train(
        soundbank,
        net,
        # Same frame sizes, with the last one replaced by 1.
        input_feature=mmk.MultiScale(feature, net.frame_sizes, (*net.frame_sizes[:-1], 1)),
        target_feature=feature,
        root_dir="./trainings/srnn-cough",
        batch_size=32,
        batch_length=batch_length,
        oversampling=1,
        shift_error=0,
        # NOTE(review): the Seq2Seq cells pass `tbptt_chunk_length` instead — confirm
        # `train` still accepts `tbptt_len` rather than ignoring it.
        tbptt_len=chunk_length // batch_length,
        max_epochs=100,
        limit_train_batches=8000,

        max_lr=5e-4,
        betas=(0.9, 0.99),
        div_factor=5.,
        final_div_factor=1.,
        pct_start=0.,
        cycle_momentum=False,

        CHECKPOINT_TRAINING=True,
        MONITOR_TRAINING=True,
        OUTPUT_TRAINING="mp3",

        every_n_epochs=5,
        n_examples=4,
        prompt_length=16000,
        n_steps=gen_length,
        # One temperature row per example (4 rows), repeated for every generated sample.
        temperature=torch.tensor([[1.51], [1.25], [1.05], [.95]]).repeat(1, gen_length),
        trainset=id_
    )

# WaveNet FFT

In [None]:
# Scratch arithmetic: 4 to the 4th power (256).
pow(4, 4)

In [None]:
# for train_file in h5m.FileWalker(h5m.Sound.__re__, 'train-data/sounds'):
#     h5m.sound_bank.callback("train.h5", train_file, sr=22050,)

#     soundbank = h5m.TypedFile("train.h5", mode='r', keep_open=True)
class GCPSoundBank(h5m.TypedFile):
    """Sound bank with a single `snd` feature, read through `from_gcloud` as
    mono, normalized audio at sr=16000."""
    snd = from_gcloud(h5m.Sound(mono=True, normalize=True, sr=16000))
    
# Magnitude spectrogram shared by every WaveNetFFT configuration in this cell.
feature = mmk.Spectrogram(sr=16000,
                          n_fft=1024,
                          hop_length=128,
                          coordinate='mag',
                          center=False,
                          normalize=True)

# Architecture kwargs common to all configurations.
wavenet_base_hp = dict(
    feature=feature,
    kernel_sizes=(4,),
    input_heads=2,
    output_heads=2,
    dims_dilated=(512,),
    dims_1x1=(),
    residuals_dim=None,
    apply_residuals=False,
    skips_dim=None,
    groups=2,
    act_f=nn.Tanh(),
    act_g=nn.Sigmoid(),
    pad_side=0,
    stride=1,
    bias=True,
)

# Combine the base kwargs with the swept options (`blocks`, `scaled_activation`).
net_grid = sampler_zipper(
    8,
    [wavenet_base_hp],
    ParameterGrid([
        dict(
            blocks=[(4,), (3,), (2,)],
            scaled_activation=[True, False],
#             phs=["b", "c"],
        )
    ]),
)
    
def optim_grid():
    """Return the grid of optimiser settings sampled for each WaveNetFFT run.

    Candidates are every combination of ``max_lr`` and ``betas`` below.

    NOTE(review): this re-definition shadows the ``optim_grid`` from the Seq2Seq
    FFT cell. It has no ``batch_size`` axis because the WaveNet training call
    passes ``batch_size`` explicitly. As a result, Seq2Seq cells executed after
    this cell no longer sample a batch size.
    """
    search_space = dict(
#         batch_size=[8, 16, 32],
#         batch_length=[64, 64+32, 128, 128+32],
#         batch_length=[16, 32, 64],
        max_lr=[7e-4, 5e-4, 3e-4],
        betas=[(0.9, 0.9), (0.92, 0.92), (0.97, 0.95), (0.95, 0.97), (.92, .98)],
    )
    return instance_grid(None, {}, ParameterGrid([search_space]))

    
    
# Materialise the hyper-parameter grid once, so that every sound set sees the
# full grid even if `sampler_zipper` returns a one-shot iterator.
hp_configs = list(net_grid)

for id_, s in LUNGS.items():
#     if id_ != "Verdi_X_3_bis":
#         continue

    # The bank depends only on the sound set `s`, not on the net's
    # hyper-parameters. Build it once per set instead of once per `hp`, which
    # repeated identical create work for every configuration.
    GCPSoundBank.create("train.h5", s, parallelism="threads", n_workers=8)

    for hp in hp_configs:
        # Fresh read-only handle for each run, as before.
        soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)

        #     net = WaveNetFFT(
        #         feature=feature,
        #         input_heads=2,
        #         output_heads=2,
        #         scaled_activation=True,

        #         kernel_sizes=(2,),
        #         blocks=(4,),
        #         dims_dilated=(1024,),
        #         dims_1x1=(),
        #         residuals_dim=None,
        #         apply_residuals=False,
        #         skips_dim=None,
        #         groups=2,
        #         act_f=nn.Tanh(),
        #         act_g=nn.Sigmoid(),
        #         pad_side=0,
        #         stride=1,
        #         bias=True,
        #     )
        net = WaveNetFFT(**hp)

        # NOTE(review): disables the model's fast generation path; the reason is not visible here.
        net.use_fast_generate = False

        optims = next(n_choices(optim_grid(), 1))
        print(net.hp)
        print(net.hp.__class__(**optims))
        train(
            soundbank,
            net,
            root_dir="./trainings/wn-lungs",
            input_feature=feature,
            target_feature=feature,
            **optims,
            batch_size=8,
            batch_length=256,
            downsampling=128,
            shift_error=0,

            max_epochs=50,
            limit_train_batches=None,

#             max_lr=5e-4,
#             betas=(0.9, 0.92),
            div_factor=5.,
            final_div_factor=1.,
            pct_start=0.,
            cycle_momentum=False,

            CHECKPOINT_TRAINING=True,
            MONITOR_TRAINING=True,
            OUTPUT_TRAINING="mp3",

            every_n_epochs=10,
            n_examples=4,
            prompt_length=64,
            # 12 * (spectrogram frames per second).
            n_steps=int(12*(net.feature.sr//net.feature.hop_length)),
    #         temperature=torch.tensor([[.85] * 200]),
            trainset=id_
        )

In [None]:
class GCPSoundBank(h5m.TypedFile):
    """Sound bank with a single `snd` feature, read through `from_gcloud` as
    mono, normalized audio at sr=16000."""
    snd = from_gcloud(h5m.Sound(mono=True, normalize=True, sr=16000))

# Build the bank from the first sound set only.
# NOTE(review): `sets` is neither defined nor imported by name in this notebook —
# presumably it comes from `from grids import *`; confirm, otherwise this cell
# fails with NameError on a fresh kernel.
first_set = next(iter(sets.values()))
GCPSoundBank.create("train.h5", first_set, parallelism="threads", n_workers=8)

soundbank = GCPSoundBank("train.h5", mode='r', keep_open=True)

net = SampleRNN(
    feature=mmk.MuLawSignal(sr=soundbank.snd.attrs["sr"],
                            q_levels=256,
                            normalize=True),
    chunk_length=16000*8,
    frame_sizes=(16, 8, 8),
    dim=512,
    n_rnn=2,
    q_levels=256,
    embedding_dim=256,
    mlp_dim=512,
)

# Pull one batch to inspect what the model is fed.
# NOTE(review): the positional arguments (8, 32, 1, 0) are unnamed here — check
# `SampleRNN.train_dataloader`'s signature for their meaning.
dl = net.train_dataloader(soundbank, 8, 32, 1, 0)
inp, trg = next(iter(dl))
inp[0], trg[0]