In [None]:
"""Regeneration of lead synth from combined signal"""
from tensorflow.keras.layers import Dense, Dropout, PReLU
from tensorflow.keras.optimizers import Adam

from music_generator.basic.random import generate_dataset
from music_generator.basic.signalproc import SamplingInfo
from music_generator.musical.timing import Tempo
from music_generator.musical.scales import GenericScale
from music_generator.basic.signalproc import mix_at
from music_generator.analysis import preprocessing

from music_generator.musical import scales
import numpy as np
from multiprocessing import Pool
from functools import partial

import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io.wavfile import read
from music_generator.analysis import regen_models
from scipy.io import wavfile
from music_generator.analysis import jamdataset
import tensorflow as tf

from music_generator.analysis import regen_models
import importlib

from tensorflow.keras.layers import Input, GRU, PReLU, Dropout, Dense, Reshape, Conv1D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop, Adam

from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, ModelCheckpoint
import datetime as dt

In [None]:
# Sampling / batching configuration used by the data-generation and model cells.
sampling_info = SamplingInfo(44100)
sr = sampling_info.sample_rate  # read the rate back from the SamplingInfo object
fragment_length = 4096
batch_size = 32

In [None]:
from music_generator.analysis.data.filtering import generate

In [None]:
# Produce the train/test input and target arrays from the synthetic generator.
(x_train, y_train, x_test, y_test) = generate.generate_synthetic_data(
    batch_size,
    fragment_length,
    sr,
)

In [None]:
def build_fft_model(batch_size, fragment_length, n_fft_steps=4):
    """Build a model that re-shapes spectral magnitudes and keeps the input phase.

    The input fragment is split into ``n_fft_steps`` consecutive chunks and each
    chunk is Fourier-transformed. The magnitude spectrum is passed through two
    GRU layers and two ReLU dense layers, then recombined with the phase of the
    *input* spectrum before the inverse FFT turns it back into a waveform.

    Args:
        batch_size: Batch dimension fixed into the input layer via
            ``batch_shape``.
        fragment_length: Number of samples per input fragment.
        n_fft_steps: Number of equally sized chunks the fragment is split into
            (the sequence length seen by the GRUs). Must divide
            ``fragment_length`` evenly. Defaults to 4.

    Returns:
        An uncompiled ``Model`` mapping ``[batch_size, fragment_length]`` inputs
        to outputs of the same shape.

    Raises:
        ValueError: If ``n_fft_steps`` is not positive or does not divide
            ``fragment_length``.
    """
    # Fail early with a clear message instead of a shape error inside Reshape.
    if n_fft_steps <= 0 or fragment_length % n_fft_steps != 0:
        raise ValueError(
            "fragment_length (%r) must be a positive multiple of n_fft_steps (%r)"
            % (fragment_length, n_fft_steps)
        )
    n_channels_fft = fragment_length // n_fft_steps

    inp = Input(batch_shape=[batch_size, fragment_length])
    chunks = Reshape([n_fft_steps, n_channels_fft])(inp)

    # The FFT runs over the innermost axis, i.e. independently per chunk.
    spectrum = tf.signal.fft(tf.cast(chunks, tf.complex64))
    magnitude = tf.math.abs(spectrum)
    phase = tf.math.angle(spectrum)

    # Only the magnitudes are learned; ReLU keeps them non-negative.
    magnitude = GRU(2048, return_sequences=True)(magnitude)
    magnitude = GRU(1024, return_sequences=True)(magnitude)
    magnitude = Dense(n_channels_fft, activation="relu")(magnitude)
    magnitude = Dense(n_channels_fft, activation="relu")(magnitude)

    # Polar -> rectangular form, reusing the untouched input phase.
    spectrum = tf.complex(magnitude * tf.math.cos(phase),
                          magnitude * tf.math.sin(phase))

    # Keep only the real part of the inverse transform (this is what casting a
    # complex tensor to float32 did implicitly).
    waveform = tf.math.real(tf.signal.ifft(spectrum))
    out = Reshape([fragment_length])(waveform)

    return Model(inp, out)

def fft_loss(y_target, y_predicted, first_bin=100):
    """Squared difference of FFT magnitudes, ignoring the lowest bins.

    Both signals are cast to complex64 and transformed with ``tf.signal.fft``,
    which operates on the innermost axis. Bins below ``first_bin`` on that axis
    are excluded from the comparison.

    Args:
        y_target: Reference signal(s); the last axis is the sample axis.
        y_predicted: Predicted signal(s), same shape as ``y_target``.
        first_bin: Index of the first frequency bin that is compared.
            Defaults to 100.

    Returns:
        Tensor of element-wise squared magnitude differences (not reduced).
    """
    target_spectrum = tf.abs(tf.signal.fft(tf.cast(y_target, dtype=tf.complex64)))
    predicted_spectrum = tf.abs(tf.signal.fft(tf.cast(y_predicted, dtype=tf.complex64)))

    # Slice the frequency (last) axis. A plain `[first_bin:]` would slice the
    # leading batch axis of batched input and typically yield an empty tensor.
    # NOTE(review): for real-valued signals the spectrum is mirrored, so the
    # trailing bins repeat the skipped low frequencies - confirm whether those
    # should be dropped as well.
    return tf.square(target_spectrum[..., first_bin:] -
                     predicted_spectrum[..., first_bin:])

# Instantiate the FFT model, print its layer summary, then compile it with
# RMSprop and a mean-squared-error loss.
model = build_fft_model(batch_size, fragment_length)
model.summary()
model.compile(optimizer=RMSprop(1e-3), loss='mse')

In [None]:
# No Keras callbacks are registered for this run.
callbacks = []

# The model's input layer has a fixed batch dimension, so train with the same
# `batch_size` it was built with rather than a separately hard-coded value.
# Validation data is passed as the documented (x, y) tuple.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    callbacks=callbacks,
    epochs=25,
    batch_size=batch_size,
)

In [None]:
# Run the trained model over both splits and flatten each result into a single
# 1-D signal. `batch_size` is passed explicitly because the model was built
# with a fixed batch dimension; relying on predict's default only works while
# the two happen to be equal.
output_test_1 = model.predict(x_test, batch_size=batch_size, verbose=1).reshape(-1)
output_train_1 = model.predict(x_train, batch_size=batch_size, verbose=1).reshape(-1)

# Listen to the first 1,000,000 samples of the test prediction.
Audio(output_test_1[:1000000], rate=sr)

In [None]:
# Samples per window for the fix-up model; the flattened signals are cut into
# rows of this length before being fed to it.
WINDOW_LENGTH = 512

In [None]:
def build_fix_model(window_length):
    """Create and compile a small convolutional post-processing model.

    The network takes a window of ``window_length`` samples, applies a single
    ``Conv1D`` layer (25 filters, kernel size 15, "same" padding) and projects
    the filters back to one value per sample with ``Dense(1)``.

    Args:
        window_length: Number of samples in each input/output window.

    Returns:
        A ``Model`` compiled with the "adam" optimizer and "mse" loss.
    """
    signal_in = Input(shape=[window_length])

    # Conv1D needs a channel axis, so append a singleton one.
    features = Reshape([window_length, 1])(signal_in)
    features = Conv1D(filters=25, kernel_size=15, padding="same")(features)

    # One output value per time step, then drop the channel axis again.
    per_sample = Dense(1)(features)
    signal_out = Reshape([window_length])(per_sample)

    fix_net = Model(signal_in, signal_out)
    fix_net.compile("adam", "mse")
    return fix_net
    
# Build the compiled fix-up model for WINDOW_LENGTH-sample windows and print
# its layer summary.
fix_model = build_fix_model(WINDOW_LENGTH)
fix_model.summary()
    

In [None]:
# Cut the first-stage training prediction and the training target into rows of
# WINDOW_LENGTH samples each (input and target for the fix-up model).
y_train_fix_input = output_train_1.reshape(-1, WINDOW_LENGTH)
y_train_fix_target = y_train.reshape(-1, WINDOW_LENGTH)

In [None]:
# Train the fix-up model to map first-stage output windows onto target windows.
fix_model.fit(
    x=y_train_fix_input,
    y=y_train_fix_target,
    epochs=10,
)

In [None]:
# Window the first-stage test prediction, run it through the fix-up model and
# flatten the result back into one 1-D signal.
output_fixed = fix_model.predict(
    output_test_1.reshape(-1, WINDOW_LENGTH)
).reshape(-1)

In [None]:
# Listen to the first 1,000,000 samples of the fix-up model's output.
Audio(output_fixed[:1000000], rate=sr)

In [None]:
# Listen to the first 1,000,000 samples of the FFT model's test prediction
# (before the fix-up model), for comparison.
Audio(output_test_1[:1000000], rate=sr)