In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""Regeneration of lead synth from combined signal"""
from tensorflow.keras.layers import Dense, Dropout, PReLU
from tensorflow.keras.optimizers import Adam

from music_generator.basic.random import generate_dataset
from music_generator.basic.signalproc import SamplingInfo
from music_generator.musical.timing import Tempo
from music_generator.musical.scales import GenericScale
from music_generator.basic.signalproc import mix_at
from music_generator.analysis import preprocessing

from music_generator.musical import scales
import numpy as np
from multiprocessing import Pool
from functools import partial

import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io.wavfile import read
from music_generator.analysis import regen_models
from scipy.io import wavfile
from music_generator.analysis import jamdataset
import tensorflow as tf

from music_generator.analysis import regen_models
import importlib

from tensorflow.keras.layers import Input, GRU, PReLU, Dropout, Dense, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop, Adam

from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, ModelCheckpoint
import datetime as dt

In [None]:
# Sample rate (presumably Hz). `sr` is read back from SamplingInfo so it
# always matches the value the rest of the pipeline uses.
sampling_info = SamplingInfo(11025)
sr = sampling_info.sample_rate

# Number of training / test fragments and samples per fragment.
n_train = 4096
n_test = 128
fragment_length = 4096

## Generate data

In [None]:
# Generate in all keys: collect the root symbol of every note that the
# chromatic scale starting at C yields for generate(0, 1).
all_roots = scales.chromatic_scale('C')
roots = list(map(lambda note: note.get_symbol(), all_roots.generate(0, 1)))
print(roots)

def generate_dataset_for_root(root):
    """Generate a 32-measure dataset at tempo 120 in the scale rooted at `root`.

    The interval list [0, 2, 3, 5, 7, 8, 10] matches the natural-minor
    pattern if GenericScale interprets it as semitone offsets
    (NOTE(review): confirm). Reads the module-level `sampling_info`.

    Args:
        root: Note symbol for the scale root (one entry of `roots`).

    Returns:
        Whatever `generate_dataset` returns for that key.
    """
    tempo = Tempo(120)
    key_scale = GenericScale(root, [0, 2, 3, 5, 7, 8, 10])
    return generate_dataset(
        n_measures=32,
        tempo=tempo,
        scale=key_scale,
        sampling_info=sampling_info,
    )
    
# With the 'fork' start method, Pool workers inherit this kernel's NumPy
# global RNG state. NOTE(review): if generate_dataset draws from np.random,
# the workers would otherwise replay identical random streams. Calling
# np.random.seed() with no argument in each worker re-seeds it from fresh
# OS entropy; it is harmless if generate_dataset does not use np.random.
with Pool(8, initializer=np.random.seed) as pool:
    datasets = pool.map(generate_dataset_for_root, roots)

# Make one big data set and make sure data is of same size
audio_tracks, mix = preprocessing.combine_datasets(datasets)

input_track = mix
# Track index 2 is the regeneration target (presumably the lead synth, per
# the notebook title) -- TODO confirm against generate_dataset's track order.
target_track = audio_tracks[2]

In [None]:
# Audio(input_track, rate=sr)
# Audio(target_track, rate=sr)

In [None]:
# Request n_train + n_test paired fragments (presumably fragment_length
# samples each) cut from the input mix and the target track.
x, y = preprocessing.create_training_data_set(
    n_train + n_test, fragment_length, input_track, target_track
)

In [None]:
# Disjoint split: the first n_train fragments train, the next n_test validate.
# Slicing with x[-n_test:] would silently overlap the training slice if fewer
# than n_train + n_test fragments came back, so check the count explicitly.
assert len(x) == len(y), "inputs and targets have different lengths"
assert len(x) >= n_train + n_test, "not enough fragments for a disjoint train/test split"
x_train, y_train = x[:n_train], y[:n_train]
x_test, y_test = x[n_train:n_train + n_test], y[n_train:n_train + n_test]

In [None]:
# Drop the unsplit arrays so they cannot be reused by accident.
# NOTE(review): if x/y are NumPy arrays, x_train/x_test are slice views that
# keep the underlying buffers alive, so this likely frees no memory by itself.
del x, y

In [None]:
# Listen to the first training input fragment; swap in the commented line
# to hear its target instead.
Audio(data=x_train[0], rate=sr)
# Audio(y_train[0], rate=sr)

## Construct model

In [None]:
def build_fft_model(fragment_length, n_frames=8):
    """Build a spectral-magnitude regeneration model.

    The input fragment is reshaped into `n_frames` consecutive frames of
    `fragment_length // n_frames` samples. Each frame is transformed with a
    complex FFT (innermost axis). A GRU/Dense stack, run over the frame axis,
    predicts a new magnitude for every FFT bin. The prediction is recombined
    with the *input's* phase and inverted with an inverse FFT. The real part
    is flattened back to `fragment_length` samples.

    Args:
        fragment_length: Samples per input/output fragment.
        n_frames: Number of frames (GRU time steps) the fragment is split
            into. Must evenly divide `fragment_length`. Defaults to 8, the
            value this function previously hard-coded.

    Returns:
        A tf.keras Model mapping inputs of shape (batch, fragment_length)
        to outputs of shape (batch, fragment_length).

    Raises:
        ValueError: If `fragment_length` or `n_frames` is not positive, or
            `fragment_length` is not a multiple of `n_frames`.
    """
    if n_frames <= 0 or fragment_length <= 0 or fragment_length % n_frames:
        raise ValueError(
            "fragment_length ({}) must be a positive multiple of n_frames ({})".format(
                fragment_length, n_frames))

    inp = Input(shape=[fragment_length])

    # FFT size per frame; tf.signal.fft transforms the innermost axis.
    n_channels_fft = fragment_length // n_frames

    hidden = Reshape([n_frames, n_channels_fft])(inp)

    spectrum = tf.signal.fft(tf.cast(hidden, tf.complex64))
    hidden_abs = tf.math.abs(spectrum)
    hidden_ang = tf.math.angle(spectrum)

    # Predict new per-bin magnitudes; the final ReLU keeps them non-negative.
    hidden_abs = GRU(2048, return_sequences=True)(hidden_abs)
    hidden_abs = GRU(1024, return_sequences=True)(hidden_abs)
    hidden_abs = Dense(n_channels_fft, activation="relu")(hidden_abs)
    hidden_abs = Dense(n_channels_fft, activation="relu")(hidden_abs)

    # Re-attach the input phase (polar -> rectangular) and invert the FFT.
    spectrum = tf.complex(hidden_abs * tf.math.cos(hidden_ang),
                          hidden_abs * tf.math.sin(hidden_ang))
    hidden = tf.signal.ifft(spectrum)

    # NOTE(review): bin magnitudes are predicted independently, so the
    # spectrum need not be Hermitian and the inverse FFT may have a nonzero
    # imaginary part. Only the real part is kept, which is exactly what the
    # previous tf.cast(..., tf.float32) did implicitly. tf.signal.rfft/irfft
    # would guarantee a real signal by construction.
    hidden = tf.math.real(hidden)
    out = Reshape([fragment_length])(hidden)

    return Model(inp, out)

def fft_loss(y_target, y_predicted):
    """Element-wise squared difference of FFT magnitude spectra.

    Both signals are cast to complex64 and transformed with tf.signal.fft
    along their innermost axis. The loss compares magnitudes only, so phase
    differences are ignored.

    Args:
        y_target: Ground-truth signals.
        y_predicted: Model outputs with the same shape as `y_target`.

    Returns:
        A tensor with the same shape as the inputs, holding
        (|FFT(target)| - |FFT(prediction)|) ** 2 per bin.
    """
    def _magnitude_spectrum(signal):
        return tf.abs(tf.signal.fft(tf.cast(signal, dtype=tf.complex64)))

    target_mag = _magnitude_spectrum(y_target)
    predicted_mag = _magnitude_spectrum(y_predicted)
    return tf.square(target_mag - predicted_mag)

model = build_fft_model(fragment_length)
model.summary()
# Time-domain MSE; fft_loss is the spectral alternative defined in this cell.
model.compile(optimizer=RMSprop(learning_rate=1e-3), loss='mse')

## Set up callbacks

In [None]:
# Optional training callbacks; flip a flag to True to enable one.
USE_TENSORBOARD = False
USE_REDUCE_LR = False
USE_CHECKPOINTS = False

callbacks = []

if USE_TENSORBOARD:
    # One timestamped log directory per run.
    log_file_name = f"tensorboard/{dt.datetime.now().strftime('%Y%m%d%H%M%S')}"
    callbacks.append(TensorBoard(log_dir=log_file_name, histogram_freq=1, update_freq='batch'))

if USE_REDUCE_LR:
    callbacks.append(ReduceLROnPlateau(verbose=1))

if USE_CHECKPOINTS:
    callbacks.append(ModelCheckpoint("weights.{epoch:02d}.h5"))

## Fit model

In [None]:
# Keras documents validation_data as a tuple (x_val, y_val); keep the
# returned History object for inspecting loss curves afterwards.
history = model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    callbacks=callbacks,
    epochs=5,
    shuffle=True,
    batch_size=32,
)

In [None]:
# Trim the mix to a whole number of fragments and stack them row-wise.
# NOTE(review): this is the same mix the training fragments were cut from,
# so the result is not a held-out evaluation.
n_batches_inference = len(input_track) // fragment_length
inference_ds = input_track[:n_batches_inference * fragment_length].reshape(-1, fragment_length)

In [None]:
# Display how many whole fragments the inference mix was cut into.
n_batches_inference

In [None]:
# Regenerate every fragment, then flatten the per-fragment predictions back
# into one continuous 1-D signal.
output = model.predict(inference_ds, verbose=1)
output = output.reshape(-1)

In [None]:
# Play back at most the first 1,000,000 samples of the regenerated signal
# (about 90 s at 11025 samples per second).
Audio(data=output[:1000000], rate=sr)