In [None]:
!pip install git+https://github.com/ktonal/mimikit@experiment/new-data-again

### Connect to Google Drive to make a db from your audio files (a `freqnet-db` won't work...)

In [None]:
# Mount Google Drive at /gdrive and work from "My Drive", so that the db built below
# (and the training outputs) can be read from / written to your own Drive.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive/MyDrive

## Imports and Model Definition

In [None]:
import numpy as np
from numpy.lib.stride_tricks import as_strided as np_as_strided
from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
from torch.utils.data import Sampler, RandomSampler, BatchSampler
import math
from typing import Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt

from mimikit.kit.connectors.neptune import NeptuneConnector
from mimikit.utils import audio
from mimikit.audios import transforms as A
from mimikit.kit.ds_utils import ShiftedSequences
from mimikit.kit import DBDataset, SuperAdam, SequenceModel, DataSubModule
from mimikit.kit import get_trainer
from mimikit.kit import tqdm


# First we define a sampler & a db, then the network. Finally, we "merge" them all into a LightningModule.

class TBPTTSampler(Sampler):
    """
    yields batches of indices for performing Truncated Back Propagation Through Time

    The first ``n_samples // chunk_length`` chunks of ``chunk_length`` samples are shuffled and
    grouped ``batch_size`` at a time. For each group, ``chunk_length // seq_len`` consecutive
    batches are yielded: batch ``k`` of a group holds, for every chunk ("track") of the group,
    the start index of its ``k``-th slice of ``seq_len`` samples. Consecutive batches therefore
    continue the same tracks, which lets the model carry its hidden state from batch to batch.

    When there are fewer chunks than ``batch_size``, one smaller group made of all the chunks is
    used (instead of yielding nothing), so ``len(self)`` always matches what ``__iter__`` yields.
    When there is not even one full chunk, nothing is yielded and ``len(self) == 0``.
    """
    def __init__(self,
                 n_samples,
                 batch_size=64,  # nbr of "tracks" per batch
                 chunk_length=8*16000,  # total length of a track
                 seq_len=1024  # nbr of samples per backward pass
                 ):
        super().__init__(None)
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.chunk_length = chunk_length
        self.seq_len = seq_len
        self.n_chunks = self.n_samples // self.chunk_length
        self.n_per_chunk = self.chunk_length // self.seq_len

    def _n_groups(self):
        """number of groups of chunks (i.e. of runs of ``n_per_chunk`` batches) per epoch"""
        if self.n_chunks <= 0:
            return 0
        return max(1, self.n_chunks // self.batch_size)

    def __iter__(self):
        if self.n_chunks <= 0:
            # RandomSampler refuses an empty data source
            return
        smp = RandomSampler(torch.arange(self.n_chunks))
        # drop the incomplete last group, unless it is the only one we have
        drop_last = self.n_chunks >= self.batch_size
        for top in BatchSampler(smp, self.batch_size, drop_last):
            for start in range(self.n_per_chunk):
                # start indices of the batch
                yield tuple((t * self.chunk_length) + (start * self.seq_len) for t in top)

    def __len__(self):
        return self._n_groups() * self.n_per_chunk


class FramesDB(DBDataset):
    """
    Db of mu-law compressed audio serving SampleRNN training examples.

    Each item is ``(inputs, target)``. ``inputs`` holds one array per tier: non-overlapping frames
    of ``frame_size`` samples for the upper tiers, and a sliding window (stride of one sample) of
    ``frame_sizes[-1]`` samples for the bottom tier. ``target`` holds ``seq_len`` samples.
    """
    # quantized signal, stored by `extract` with shape (n_samples, 1);
    # presumably populated by DBDataset from the "qx" feature — TODO confirm in mimikit
    qx = None

    @staticmethod
    def extract(path, sr=16000, mu=255):
        """mu-law compress the audio file at ``path`` and return it as the "qx" feature (n_samples, 1)"""
        signal = A.FileTo.mu_law_compress(path, sr=sr, mu=mu)
        return dict(qx=(dict(sr=sr, mu=mu), signal.reshape(-1, 1), None))

    def prepare_dataset(self, model, datamodule):
        """
        Configure the slicing of ``self.qx`` for ``model`` and make ``datamodule`` batch with a
        ``TBPTTSampler``.

        ``model.batch_info()`` must return ``(batch_size, chunk_len, batch_seq_len, frame_sizes)``;
        ``frame_sizes`` may be any sequence (tuple or list).
        """
        batch_size, chunk_len, batch_seq_len, frame_sizes = model.batch_info()
        # accept lists as well (e.g. frame_sizes=[16, 4, 4]): `frame_sizes + (0,)` needs a tuple
        frame_sizes = tuple(frame_sizes)
        shifts = [frame_sizes[0] - size for size in frame_sizes + (0,)]  # (0,) for the target
        lengths = [batch_seq_len for _ in frame_sizes[:-1]]
        # bottom tier
        lengths += [frame_sizes[0] + batch_seq_len]
        # targets
        lengths += [batch_seq_len]

        self.slicer = ShiftedSequences(len(self.qx), list(zip(shifts, lengths)))
        self.frame_sizes = frame_sizes
        self.seq_len = batch_seq_len

        # round the size of the dataset to a multiple of the chunk size :
        batch_sampler = TBPTTSampler(chunk_len * (len(self.qx) // chunk_len),
                                     batch_size,
                                     chunk_len,
                                     batch_seq_len)
        datamodule.loader_kwargs.update(dict(batch_sampler=batch_sampler))
        # torch's DataLoader rejects these options together with `batch_sampler`
        for k in ("batch_size", "shuffle", "drop_last"):
            datamodule.loader_kwargs.pop(k, None)
        datamodule.loader_kwargs["sampler"] = None

    def __getitem__(self, item):
        """return ``(inputs, target)`` for the start index ``item`` (one tuple entry per tier in ``inputs``)"""
        if isinstance(self.qx, torch.Tensor):
            def as_strided(slc, fs):
                # (seq_len, fs) sliding window advancing by one element per row
                return torch.as_strided(self.qx[slc],
                                        size=(self.seq_len, fs),
                                        stride=(1, 1))
        else:
            itemsize = self.qx.dtype.itemsize

            def as_strided(slc, fs):
                # (seq_len, fs) sliding window advancing by one element per row
                return np_as_strided(self.qx[slc],
                                     shape=(self.seq_len, fs),
                                     strides=(itemsize, itemsize))

        slices = self.slicer(item)
        tiers_slc, bottom_slc, target_slc = slices[:-2], slices[-2], slices[-1]
        # upper tiers: non-overlapping frames of `fs` samples
        inputs = [self.qx[slc].reshape(-1, fs) for slc, fs in zip(tiers_slc, self.frame_sizes[:-1])]
        # ugly but necessary if self.qx became a tensor...
        with torch.no_grad():
            inputs += [as_strided(bottom_slc, self.frame_sizes[-1])]

        target = self.qx[target_slc]

        return tuple(inputs), target

    def __len__(self):
        return len(self.qx)


@dataclass(init=True, repr=False, eq=False, frozen=False, unsafe_hash=True)
class SampleRNNTier(nn.Module):
    """
    One tier of a SampleRNN, configured through dataclass fields; its sub-modules are built in
    ``__post_init__``.

    Which sub-modules are built depends on ``is_bottom`` (``up_sampling == 1``):

    * upper tiers: ``nn.Linear`` projection of the (linearized) frame -> ``nn.LSTM`` with ``n_rnn``
      layers -> ``TimeUpscalerLinear`` producing ``up_sampling`` output steps per input step;
    * bottom tier: ``nn.Embedding`` -> ``Conv1d`` spanning the whole frame -> 3-layer MLP producing
      ``q_levels`` logits. ``embedding_dim`` and ``mlp_dim`` must then be set, since they are the
      Conv1d input channels and the MLP hidden size.

    NOTE(review): ``is_bottom`` only checks ``up_sampling == 1``, so an upper tier whose
    up-sampling factor is 1 would be built as a bottom tier.
    """
    tier_index: int  # position of the tier in the stack; not used inside this class
    frame_size: int  # number of samples consumed per step
    dim: int  # size of the projection / LSTM / up-sampling features
    up_sampling: int = 1  # output steps per input step; 1 marks the bottom tier
    n_rnn: int = 2  # number of LSTM layers (ignored by the bottom tier)
    q_levels: int = 256  # number of quantization levels (embedding rows and output logits)
    embedding_dim: Optional[int] = None  # None -> inputs are linearized instead of embedded
    mlp_dim: Optional[int] = None  # hidden size of the bottom tier's output MLP

    # True for the bottom (sample-level) tier
    is_bottom = property(lambda self: self.up_sampling == 1)

    def linearize(self, q_samples):
        """ maps quantized samples (0 <= qx < q_levels) to floats (-.5 <= x < .5) """
        return (q_samples.float() / self.q_levels) - .5

    def embeddings_(self):
        """``nn.Embedding(q_levels, embedding_dim)`` when ``embedding_dim`` is set, else ``None``"""
        if self.embedding_dim is not None:
            return nn.Embedding(self.q_levels, self.embedding_dim)
        return None

    def input_proj_(self):
        """frame projection: ``nn.Linear(frame_size, dim)`` for upper tiers, ``BottomProjector`` for the bottom tier"""

        if not self.is_bottom:  # top & middle tiers
            return nn.Linear(self.frame_size, self.dim, bias=False)

        else:  # bottom tier
            class BottomProjector(nn.Module):
                """``Conv1d`` whose kernel spans the whole frame: one ``out_dim`` vector per frame of embeddings"""
                def __init__(self, emb_dim, out_dim, frame_size):
                    super(BottomProjector, self).__init__()
                    self.cnv = nn.Conv1d(emb_dim, out_dim, kernel_size=frame_size, bias=False)

                def forward(self, hx):
                    """ hx : (B x T x FS x E) -> (B x T x out_dim) """
                    B, T, FS, E = hx.size()
                    # Conv1d wants channels first: (B*T, E, FS)
                    hx = self.cnv(hx.view(B * T, FS, E).transpose(1, 2).contiguous())
                    # now hx : (B*T, DIM, 1)
                    return hx.squeeze().reshape(B, T, -1)

            return BottomProjector(self.embedding_dim, self.dim, self.frame_size)

    def rnn_(self):
        """``nn.LSTM`` (batch_first, ``n_rnn`` layers) for upper tiers, ``None`` for the bottom tier"""
        # no rnn for bottom tier
        if self.is_bottom:
            return None

        return nn.LSTM(self.dim, self.dim, self.n_rnn, batch_first=True)

    def up_sampling_net_(self):
        """``TimeUpscalerLinear(dim, dim, up_sampling)`` for upper tiers, ``None`` for the bottom tier"""
        # no up sampling for bottom tier
        if self.is_bottom:
            return None

        class TimeUpscalerLinear(nn.Module):
            """linear map of every step to ``up_sampling`` steps: (B, T, in_dim) -> (B, T * up_sampling, out_dim)"""

            def __init__(self, in_dim, out_dim, up_sampling, **kwargs):
                super(TimeUpscalerLinear, self).__init__()
                self.up_sampling = up_sampling
                self.out_dim = out_dim
                self.fc = nn.Linear(in_dim, out_dim * up_sampling, **kwargs)

            def forward(self, x):
                B, T, _ = x.size()
                return self.fc(x).reshape(B, T * self.up_sampling, self.out_dim)

        return TimeUpscalerLinear(self.dim, self.dim, self.up_sampling)

    def mlp_(self):
        """output MLP (dim -> mlp_dim -> mlp_dim -> q_levels logits) for the bottom tier, ``None`` otherwise"""
        if not self.is_bottom:
            return None
        return nn.Sequential(
            nn.Linear(self.dim, self.mlp_dim), nn.ReLU(),
            nn.Linear(self.mlp_dim, self.mlp_dim), nn.ReLU(),
            nn.Linear(self.mlp_dim, self.q_levels),
        )

    def __post_init__(self):
        """build the sub-modules once the dataclass fields are set"""
        # nn.Module.__init__ is called first so that the sub-modules below get registered
        nn.Module.__init__(self)
        self.embeddings = self.embeddings_()
        self.inpt_proj = self.input_proj_()
        self.rnn = self.rnn_()
        self.up_net = self.up_sampling_net_()
        self.mlp = self.mlp_()

    def forward(self, input_samples, prev_tier_output=None, hidden=None):
        """
        Run the tier on a batch of frames.

        - ``input_samples``: quantized samples of shape (B, T, frame_size)
        - ``prev_tier_output``: output of the tier above, added to the projected frames, or None
        - ``hidden``: LSTM state ``(h, c)`` from the previous call, or None

        Returns ``(x, hidden)``: the up-sampled features (upper tiers) or the ``q_levels`` logits
        (bottom tier), and the new LSTM state (the ``hidden`` argument, unchanged, for tiers
        without rnn).
        """

        if self.embeddings is None:
            x = self.linearize(input_samples)
        else:
            x = self.embeddings(input_samples)

        if self.inpt_proj is not None:
            p = self.inpt_proj(x)
            if prev_tier_output is not None:
                x = p + prev_tier_output
            else:
                x = p

        if self.rnn is not None:
            # start from a small random state when there is none or the batch size changed
            if hidden is None or x.size(0) != hidden[0].size(1):
                h0 = (torch.randn(self.n_rnn, x.size(0), self.dim) * .05).to(x)
                c0 = (torch.randn(self.n_rnn, x.size(0), self.dim) * .05).to(x)
                hidden = (h0, c0)
            else:
                # TRUNCATED back propagation through time == detach()!
                hidden = tuple(h.detach() for h in hidden)

            x, hidden = self.rnn(x, hidden)

        if self.up_net is not None:
            x = self.up_net(x)

        if self.mlp is not None:
            x = self.mlp(x)

        return x, hidden


@dataclass(init=True, repr=False, eq=False, frozen=False, unsafe_hash=True)
class SampleRNNNetwork(nn.Module):
    """
    Stack of ``SampleRNNTier`` modules, from the top (coarsest) tier to the bottom one.

    ``frame_sizes`` gives, from top to bottom, the number of samples each tier consumes per step.
    It is validated in ``__post_init__`` and a ``ValueError`` is raised unless:

    * there are at least 2 tiers;
    * the last two frame sizes are equal and greater than 1;
    * every other frame size is an exact multiple, greater than 1, of the next one.

    ``self.hidden`` keeps the recurrent state of each tier between calls (cleared by ``reset_h0``).
    """

    frame_sizes: tuple  # from top to bottom!
    dim: int = 512
    n_rnn: int = 2
    q_levels: int = 256
    embedding_dim: int = 256
    mlp_dim: int = 512

    @staticmethod
    def _check_frame_sizes(frame_sizes):
        """Raise ``ValueError`` when ``frame_sizes`` cannot form a consistent stack of tiers."""
        fs = tuple(frame_sizes)
        if len(fs) < 2:
            raise ValueError(f"SampleRNN needs at least 2 tiers, got frame_sizes={fs}")
        # the before-last tier up-samples by frame_sizes[-1] so that it outputs one step per sample;
        # an up-sampling of 1 would make SampleRNNTier build it as a bottom tier
        if fs[-2] != fs[-1] or fs[-1] < 2:
            raise ValueError("the last two frame_sizes must be equal and greater than 1, "
                             f"got frame_sizes={fs}")
        for upper, lower in zip(fs[:-2], fs[1:-1]):
            # `upper // lower` is the up-sampling of the upper tier; 1 would make it a bottom tier
            if lower < 1 or upper % lower != 0 or upper // lower < 2:
                raise ValueError("every frame_size but the last two must be an exact multiple, greater than 1, "
                                 f"of the next one, got {upper} -> {lower} in frame_sizes={fs}")

    def __post_init__(self):
        # fail early with a clear message instead of an obscure error while building / running the tiers
        SampleRNNNetwork._check_frame_sizes(self.frame_sizes)
        nn.Module.__init__(self)
        n_tiers = len(self.frame_sizes)
        tiers = []
        if n_tiers > 2:
            for i, fs in enumerate(self.frame_sizes[:-2]):
                tiers += [SampleRNNTier(i, fs, self.dim,
                                        up_sampling=fs // self.frame_sizes[i + 1],
                                        n_rnn=self.n_rnn,
                                        q_levels=self.q_levels,
                                        embedding_dim=None,
                                        mlp_dim=None)]
        # before last tier
        tiers += [SampleRNNTier(n_tiers - 2, self.frame_sizes[-2], self.dim,
                                up_sampling=self.frame_sizes[-1],
                                n_rnn=self.n_rnn,
                                q_levels=self.q_levels,
                                embedding_dim=None,
                                mlp_dim=None)]
        # bottom tier
        tiers += [SampleRNNTier(n_tiers - 1, self.frame_sizes[-1], self.dim,
                                up_sampling=1,
                                n_rnn=self.n_rnn,
                                q_levels=self.q_levels,
                                embedding_dim=self.embedding_dim,
                                mlp_dim=self.mlp_dim)]

        self.tiers = nn.ModuleList(tiers)
        self.hidden = [None] * n_tiers

    def forward(self, tiers_inputs):
        """run the tiers top to bottom, feeding each one the previous tier's output; returns the bottom tier's logits"""
        prev_out = None
        for i, (tier, inpt) in enumerate(zip(self.tiers, tiers_inputs)):
            prev_out, self.hidden[i] = tier(inpt, prev_out, self.hidden[i])
        return prev_out

    def reset_h0(self):
        """forget the recurrent state: the tiers re-initialize it on their next call"""
        self.hidden = [None] * len(self.frame_sizes)


class SampleRNN(SequenceModel,
                DataSubModule,
                SuperAdam,
                SampleRNNNetwork,
                LightningModule):
    """
    SampleRNN as a ``LightningModule``: ``SampleRNNNetwork`` combined with the mimikit mixins
    ``SequenceModel``, ``DataSubModule`` (data from a ``FramesDB``) and ``SuperAdam``.

    Training uses truncated back propagation through time: the hidden state is carried across the
    consecutive batches of one group of chunks and reset when a new group starts.
    """

    @staticmethod
    def loss_fn(output, target):
        """mean cross-entropy between logits ``output`` (..., q_levels) and integer ``target`` (...)"""
        criterion = nn.CrossEntropyLoss(reduction="mean")
        return criterion(output.view(-1, output.size(-1)), target.view(-1))

    db_class = FramesDB

    def __init__(self,
                 frame_sizes=(4, 4),
                 net_dim=512,
                 emb_dim=256,
                 mlp_dim=512,
                 n_rnn=2,
                 q_levels=256,  # == mu + 1
                 max_lr=1e-3,
                 betas=(.9, .9),
                 div_factor=3.,
                 final_div_factor=1.,
                 pct_start=.25,
                 cycle_momentum=True,
                 db=None,
                 files=None,
                 batch_size=64,
                 batch_seq_len=1024,
                 chunk_len=8*16000,
                 in_mem_data=True,
                 splits=None,  # tbptt should implement the splits...
                 **loaders_kwargs
                 ):
        # NOTE(review): super(LightningModule, self) resolves to the class *after* LightningModule in
        # the MRO, i.e. LightningModule.__init__ itself is skipped — confirm this is intended
        super(LightningModule, self).__init__()
        SequenceModel.__init__(self)
        DataSubModule.__init__(self, db, files, in_mem_data, splits, batch_size=batch_size, **loaders_kwargs)
        SuperAdam.__init__(self, max_lr, betas, div_factor, final_div_factor, pct_start, cycle_momentum)
        SampleRNNNetwork.__init__(self, frame_sizes, net_dim, n_rnn, q_levels, emb_dim, mlp_dim)
        self.save_hyperparameters()

    def batch_info(self, *args, **kwargs):
        """``(batch_size, chunk_len, batch_seq_len, frame_sizes)`` from the saved hyper-parameters"""
        return tuple(self.hparams[key] for key in ["batch_size", "chunk_len", "batch_seq_len", "frame_sizes"])

    def setup(self, stage: str):
        SuperAdam.setup(self, stage)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        """
        Reset the hidden state whenever a new group of chunks starts.

        TBPTTSampler yields ``chunk_len // batch_seq_len`` consecutive batches per group of chunks,
        so a group starts every that-many batches. (The previous test,
        ``batch_idx * batch_seq_len % chunk_len == 0``, only agreed with this when ``chunk_len`` is a
        multiple of ``batch_seq_len``.) ``dataloader_idx`` defaults to 0 because newer Lightning
        versions no longer pass it.
        """
        n_per_chunk = max(1, self.hparams.chunk_len // self.hparams.batch_seq_len)
        if batch_idx % n_per_chunk == 0:
            self.reset_h0()

## Make or Get a db

As-is, this cell is supposed to throw an error (`db` is not defined)! Uncomment the lines you need.

In [None]:
db_name = "test_rnn.h5"

# connector to the Neptune project used to upload / download the db
nep_con = NeptuneConnector(
    user="k-tonal",
    setup=dict(db="data-and-base-notebooks"),
)

# Uncomment ONE of the two options below; until then `db` is undefined on a fresh kernel
# and the last line raises an error.

# option 1 — build a new db from the files under the current directory, then upload it:
# db = SampleRNN.db_class.make(db_name, roots=["./"], sr=16000, mu=255)
# nep_con.upload_database("db", db_name)

# option 2 — download an existing db and open it:
# nep_con.download_database("db")
# db = SampleRNN.db_class(db_name)

db

## About the model's args


- the [original paper](https://arxiv.org/pdf/1612.07837.pdf) has some recommendations, and the [Dadabots](http://dadabots.com/nips2017/generating-black-metal-and-math-rock.pdf) did a pretty good job of describing their experiments.


- the most important arg is `frame_sizes` (see [the original repo](https://github.com/soroushmehr/sampleRNN_ICLR2017) for a visual aid to this)

    - `SampleRNN` doesn't have "layers", it has "tiers" : small models that process `frame_size` inputs at a time and combine their outputs with the outputs of the previous tier.
    
    - the `frame_sizes` argument determines how many samples each tier (from top to bottom!) processes at a time. The repo's image corresponds to `frame_sizes=(16, 4, 4)` for tiers 3, 2, 1, which does, in fact, work pretty well...
    
    - **IMPORTANT!** you can have as many tiers as you want, but :
    
        1. the last two tiers must have the same `frame_size`
        
        2. dividing a tier's frame_size by the next tier's frame_size must always give an exact integer (and, above the last two tiers, an integer greater than 1). e.g. 
            - (128, 4, 4) => **yes** because 128 / 4 == 32
            - (12, 11, 11) => **no** because 12 / 11 == 1.0909090909090908
            
        3. The first frame_size has to be smaller than or equal to another arg: `batch_seq_len`

    - the original paper says `(8, 2, 2)` worked best. Dadabots used only 2 tiers, probably `(4, 4)` or similar. With this implementation you could go wild and do `(256, 128, 64, 32, 16, 8, 4, 2, 2)` or even more...
    

- Each tier but the last has a Recurrent Network with 1 or more layers. The `n_rnn` argument specifies how many layers **per tier**. It seems to me that it starts working when the whole model has a total of at least 4 rnns : e.g. `frame_sizes=(8, 2, 2)` & `n_rnn=2` corresponds to 4 rnns total (last tier always has 0 rnns). Dadabots made their streams with 2 tiers and the top tier had between 5 and 9 rnns...


- `*_dim` arguments are very similar to `model_dim` in `FreqNet`. 

    - `net_dim` is the most important and will greatly influence the trade-off between speed & expressivity. It could have been named `model_dim` because most of the network's parameters will have sizes proportional to `net_dim`. `512` works well. Maybe you can go down to `256` for more speed or up to `1024` for more expressivity... Definitely worth playing with!
    
    - `emb_dim` only affects a few parameters and might not be very important. `256` works, but I would expect `128` or `64`, maybe even `32`, to work too... More than `256` could be too much but, honestly, I don't know!... :)
    
    - `mlp_dim` is for the tip of the model (which makes the prediction). `512` works. Once again, I'm not sure how relevant this `dim` is...


- `max_epochs`: SampleRNN seems to generate quite well very early! Loss values that resulted in cool outputs for me were around 1.6 to 1.9, and this was after just a few epochs! It even seems that training too long results in long silent outputs: this happened to me after 100 epochs, at a loss around 1.4.


- `max_lr`: it also seems that SampleRNN withstands high learning rates, which also means faster training! For comparison, FreqNet starts to diverge with `max_lr > 1e-3`, but here `5e-3` works, even if it's probably already at the limit... If the loss starts to increase, fall back to `max_lr=1e-3` and you should be fine.
    
    
- the values used in the next cell seem to work quite well. When in doubt, use them.

In [None]:
# the model — see the notes above for what each argument does
net = SampleRNN(
    db=db,
    frame_sizes=(16, 4, 4),
    net_dim=512,
    emb_dim=256,
    mlp_dim=512,
    n_rnn=2,
    max_lr=5e-3,
    div_factor=3.,
    betas=(.9, .9),
    batch_size=64,
    batch_seq_len=1024,
)

# the trainer
# NOTE(review): `epochs=[49]` is forwarded as-is to mimikit's `get_trainer`
# (presumably the epoch(s) at which checkpoints are saved) — TODO confirm
trainer = get_trainer(
    root_dir="test_rnn",
    max_epochs=50,
    epochs=[49],
    # to track the run with neptune, also pass:
    # model=net,
    # neptune_connector=nep_con,
)

net

## train
You can stop the training at any time and start/resume it by running the next cell.

In [None]:
# train `net` with the trainer configured above (interrupt the cell to stop; run it again to continue)
trainer.fit(net)

## Generation

Generating with SampleRNN is quite flexible!

1. The model has an internal state that is supposed to encode what the model has seen _until now_. So before we let it run loose, we can "warm up" the model with a prompt of `n_warmups` batches. (I'm not sure if it changes much for the outputs but it's worth playing with...)


2. The generation method has 2 modes : deterministic and probabilistic. In the first, the output will always be the same for the same prompt/warm-up. But the second mode samples the outputs from probabilities that can be modified with a very interesting parameter : `temperature`.
    - `temperature` must be greater than 0, and although it could theoretically be greater than 1, values above 1 might not be so interesting because:
    - the higher the `temperature`, the "noisier" the output. The lower the temp, the more "frozen" the output. It is called "temperature" by analogy with statistical physics (a softmax divided by a temperature has the form of a Boltzmann distribution): more heat = particles move faster, less heat = particles stop moving. Musically, it means: hotter = more contrasts, cooler = longer sounds.
    - Concretely, I recommend starting around `temperature=0.5` and going tiny bits up & down....
    
    
3. Because generating in time-domain is much slower than in freq-dom, **generation is split in 2 cells**:
    - the first gets a **new prompt** and does some warm-ups
    - the second generates and **appends** the results to what has been previously generated.
This way, you can evaluate the first once and evaluate the 2nd several times. This beats waiting 30min to discover that it generated 30 seconds of silence...


4. You can generate `n_prompts` at the same time (like in redundance rate). This is much faster than generating one prompt at a time.


5. Because feeding data to SampleRNN is a bit complex, you'll have to stick to random prompts from the training data for now...

In [None]:
# WARM-UP: run a few training batches through the model so that its recurrent state has seen
# some real audio; the predictions on the last batch become the prompt (`new`) for generation.

n_prompts = 8
n_warmups = 20

assert n_warmups >= 1, "at least one warm-up batch is needed to get a prompt (`new`)"

# use the GPU when there is one, otherwise fall back to the CPU (generation will be much slower)
device = "cuda" if torch.cuda.is_available() else "cpu"

net.eval()
net = net.to(device)

# the TBPTT sampler yields runs of `chunk_len // batch_seq_len` consecutive batches of the same chunks;
# a new run starts on unrelated audio, so the hidden state is reset at the start of each run
n_per_chunk = max(1, net.hparams.chunk_len // net.hparams.batch_seq_len)

dl = iter(net.datamodule.train_dataloader())
for i in tqdm(range(n_warmups)):
    if i % n_per_chunk == 0:
        net.reset_h0()
    inpt, _ = next(dl)
    inpt = tuple(x[:n_prompts].to(device) for x in inpt)
    with torch.no_grad():
        new = nn.Softmax(dim=-1)(net(inpt)).argmax(dim=-1)

new.size()

In [None]:
########### PLAY WITH THOSE : ###################

# to use the deterministic mode, set temperature to None

temperature = 0.5

# 1 second = 16000 steps !

n_steps = 32000

###############################################

## LOS GEHT'S!

fs = [*net.frame_sizes]
outputs = [None] * (len(fs) - 1)
hiddens = net.hidden
tiers = net.tiers

# every tier must run on the first step, otherwise `outputs` is still empty below
assert new.size(1) % fs[0] == 0, (
    f"the prompt length ({new.size(1)}) must be a multiple of frame_sizes[0] ({fs[0]}): "
    "pick `n_steps` as a multiple of it")

with torch.no_grad():
    for t in tqdm(range(new.size(1), n_steps + new.size(1))):
        # upper tiers only run when a whole new frame of samples is available
        for i in range(len(tiers)-1):
            if t % fs[i] == 0:
                inpt = new[:, t-fs[i]:t].reshape(-1, 1, fs[i])

                if i == 0:
                    prev_out = None
                else:
                    # the step of the upper tier's up-sampled output that matches this frame.
                    # Keep a time axis of length 1: a (B, dim) tensor added to the (B, 1, dim)
                    # projection would broadcast to (B, B, dim) as soon as n_prompts > 1
                    prev_out = outputs[i-1][:, (t // fs[i]) % (fs[i-1] // fs[i])].unsqueeze(1)

                out, h = tiers[i](inpt, prev_out, hiddens[i])
                hiddens[i] = h
                outputs[i] = out

        prev_out = outputs[-1]
        inpt = new[:, t-fs[-1]:t].reshape(-1, 1, fs[-1])

        out, _ = tiers[-1](inpt, prev_out[:, (t % fs[-1]) - fs[-1]].unsqueeze(1))

        if temperature is None:
            pred = (nn.Softmax(dim=-1)(out)).argmax(dim=-1)
        else:
            # great place to implement dynamic cooling/heating !
            tempered = out / temperature
            # squeeze only the time axis: a bare .squeeze() would also drop the batch axis
            # when n_prompts == 1 and break the concatenation below
            pred = torch.multinomial(nn.Softmax(dim=-1)(tempered).squeeze(1), 1)

        new = torch.cat((new, pred), dim=-1)


for i in range(new.size(0)):

    y = new[i].squeeze().cpu().numpy()
    y = A.SignalFrom.mu_law_compressed(y - 128, 255)

    print("prompt number", i)
    plt.figure(figsize=(20, 8))
    plt.plot(y)
    plt.show()

    audio(y, sr=16000)