In [None]:
"""Regeneration of lead synth from combined signal"""
from tensorflow.keras.layers import Dense, Dropout, PReLU
from tensorflow.keras.optimizers import Adam

from music_generator.basic.random import generate_dataset
from music_generator.basic.signalproc import SamplingInfo
from music_generator.musical.timing import Tempo
from music_generator.musical.scales import GenericScale
from music_generator.basic.signalproc import mix_at
from music_generator.analysis import preprocessing

from music_generator.musical import scales
import numpy as np
from multiprocessing import Pool
from functools import partial

import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io.wavfile import read
from music_generator.analysis import regen_models
from scipy.io import wavfile
from music_generator.analysis import jamdataset
import tensorflow as tf

from music_generator.analysis import regen_models
import importlib

from tensorflow.keras.layers import Input, GRU, PReLU, Dropout, Dense, Reshape, Conv1D, LocallyConnected1D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop, Adam

from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, ModelCheckpoint
import datetime as dt

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Signal settings.
SR = 44100                 # sample rate, in samples per second
TOTAL_LENGTH = 10 ** 6     # length of the synthetic test signal, in samples

# Model / training settings.
BATCH_SIZE = 32            # batch size; baked into the stateful model's input shape
FRAGMENT_LENGTH = 4096     # samples per training fragment (one model input row)

In [None]:
# Sample-index axis and a unit-amplitude 440 Hz sine sampled at SR.
x = np.arange(TOTAL_LENGTH)
t_seconds = x / SR
y = np.sin(t_seconds * 440 * 2 * np.pi)

In [None]:
# Peak value of the test tone (sanity check: should be ~1.0 for a unit sine).
y.max()

In [None]:
def build_fully_seq_model(batch_size, fragment_length, gru_units=128, dense_units=None):
    """Build a stateful GRU model mapping a waveform fragment to a waveform fragment.

    The input fragment is treated as a sequence of ``fragment_length`` scalar
    samples. A stateful GRU processes it, and a per-timestep dense head maps
    each GRU output back to a single sample.

    Args:
        batch_size: Fixed batch size baked into the input shape. It is required
            because the GRU is stateful.
        fragment_length: Number of samples per fragment.
        gru_units: Width of the GRU layer (default 128, as before).
        dense_units: Width of the hidden per-timestep Dense layer. ``None``
            (the default) keeps the original behaviour of using
            ``fragment_length`` units. That layer's output has shape
            ``(batch_size, fragment_length, dense_units)``, so with the default
            its size grows quadratically in ``fragment_length``. Pass a small
            value (e.g. 128) to keep memory in check.

    Returns:
        A Keras ``Model`` mapping ``(batch_size, fragment_length)`` to
        ``(batch_size, fragment_length)``.
    """
    if dense_units is None:
        dense_units = fragment_length

    inp = Input(batch_shape=[batch_size, fragment_length])

    # (batch, fragment_length) -> (batch, fragment_length, 1): one feature per timestep.
    hidden = Reshape([fragment_length, 1])(inp)

    # Stateful: the hidden state carries over between consecutive batches, so the
    # caller must feed batches in order and reset states when appropriate.
    hidden = GRU(gru_units, return_sequences=True, stateful=True)(hidden)

    # Dense on a 3-D tensor is applied along the last axis, i.e. per timestep.
    hidden = Dense(dense_units, activation="relu")(hidden)
    hidden = Dense(1)(hidden)

    # Drop the trailing feature axis to get back a flat fragment.
    out = Reshape([fragment_length])(hidden)

    return Model(inp, out)

def fft_loss(y_target, y_predicted):
    """Element-wise squared difference between FFT magnitude spectra.

    Both tensors are cast to complex64 and transformed with ``tf.signal.fft``
    along their innermost axis. The squared difference of the two magnitude
    spectra is returned unreduced; reduction is left to the caller (Keras).
    """
    def magnitude_spectrum(signal):
        # |FFT(signal)|, computed on a complex copy of the real input.
        return tf.abs(tf.signal.fft(tf.cast(signal, dtype=tf.complex64)))

    spectral_diff = magnitude_spectrum(y_target) - magnitude_spectrum(y_predicted)
    return tf.square(spectral_diff)

# Instantiate the stateful GRU model for the configured batch/fragment sizes,
# print its layer summary, and compile it for plain MSE reconstruction.
model = build_fully_seq_model(BATCH_SIZE, FRAGMENT_LENGTH)
model.summary()
optimizer = Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss='mse')

In [None]:
def build_fft_model(batch_size, fragment_length, n_fft_steps=1):
    """Build a stateful model that transforms a fragment in the frequency domain.

    Processing steps:
      1. Split the fragment into ``n_fft_steps`` equal frames and FFT each frame.
      2. Pass the magnitude spectrum through two stateful GRUs and two ReLU
         Dense layers.
      3. Pass the phase through a single linear Dense layer.
      4. Recombine magnitude and phase, inverse-FFT, and keep the real part.

    Args:
        batch_size: Fixed batch size baked into the input shape. It is required
            by the stateful GRUs.
        fragment_length: Number of samples per fragment.
        n_fft_steps: Number of frames the fragment is split into. The default
            of 1 means a single FFT over the whole fragment, as before. Must
            divide ``fragment_length``.

    Returns:
        A Keras ``Model`` mapping ``(batch_size, fragment_length)`` to
        ``(batch_size, fragment_length)``.

    Raises:
        ValueError: if ``n_fft_steps`` is not a positive divisor of
            ``fragment_length``.
    """
    if n_fft_steps <= 0 or fragment_length % n_fft_steps != 0:
        raise ValueError(
            "n_fft_steps ({}) must be a positive divisor of fragment_length ({})".format(
                n_fft_steps, fragment_length))
    n_channels_fft = fragment_length // n_fft_steps

    inp = Input(batch_shape=[batch_size, fragment_length])

    # (batch, fragment_length) -> (batch, n_fft_steps, n_channels_fft): one frame per step.
    hidden = Reshape([n_fft_steps, n_channels_fft])(inp)

    # FFT along the innermost axis (within each frame), then split into polar form.
    spectrum = tf.signal.fft(tf.cast(hidden, tf.complex64))
    hidden_abs = tf.math.abs(spectrum)
    hidden_ang = tf.math.angle(spectrum)

    # Magnitude branch; the final ReLU keeps predicted magnitudes non-negative.
    hidden_abs = GRU(2048, return_sequences=True, stateful=True)(hidden_abs)
    hidden_abs = GRU(1024, return_sequences=True, stateful=True)(hidden_abs)
    hidden_abs = Dense(n_channels_fft, activation="relu")(hidden_abs)
    hidden_abs = Dense(n_channels_fft, activation="relu")(hidden_abs)

    # Phase branch: a single linear projection.
    hidden_ang = Dense(n_channels_fft)(hidden_ang)

    # Back to rectangular form: abs * (cos(ang) + i*sin(ang)).
    spectrum = tf.complex(hidden_abs * tf.math.cos(hidden_ang),
                          hidden_abs * tf.math.sin(hidden_ang))

    # The predicted spectrum is not constrained to be Hermitian, so the inverse
    # FFT is complex in general. Keep the real part explicitly: tf.cast from
    # complex to float would do the same, but silently.
    waveform = tf.math.real(tf.signal.ifft(spectrum))
    out = Reshape([fragment_length])(waveform)

    return Model(inp, out)

# NOTE: `fft_loss` is defined (identically) in the model-building cell above.
# A byte-for-byte duplicate definition used to live here and silently shadowed
# that one; it was removed so there is a single source of truth for the loss.

# model = build_fft_model(BATCH_SIZE, FRAGMENT_LENGTH)
# model.summary()
# model.compile(RMSprop(1e-3), loss=fft_loss)

In [None]:
from music_generator.analysis.data.filtering.generate import reshape_batches

In [None]:
# Cut the continuous test tone `y` into training fragments for the stateful model.
# NOTE(review): presumably returns an array of shape (n_fragments, FRAGMENT_LENGTH),
# ordered so that consecutive batches continue each other (as a stateful model
# needs). Confirm against reshape_batches.
x_train = reshape_batches(y, BATCH_SIZE, FRAGMENT_LENGTH)

In [None]:
# Train the model to reproduce its input: input and target are both x_train.
# The model is stateful and was built with a fixed batch size, so:
#   - train with exactly BATCH_SIZE rather than relying on fit's default, and
#   - do not shuffle, which would break the batch-to-batch state continuity
#     the stateful GRU relies on.
# Reset states first so re-running this cell does not inherit stale state.
model.reset_states()
model.fit(x_train, x_train, epochs=2, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Reconstruct the whole signal and play it back.
# Reset the stateful GRU so prediction does not start from the state left over
# at the end of training. Predict with the batch size the model was built for.
model.reset_states()
pred = model.predict(x_train, batch_size=BATCH_SIZE, verbose=1).reshape(-1)
Audio(pred, rate=SR)

In [None]:
# Overlay a short window of the reconstruction on the original signal.
window = slice(2500, 5000)  # sample range to display
fig, ax = plt.subplots(figsize=[32, 8])
ax.plot(pred[window], label="prediction")
ax.plot(x_train.reshape(-1)[window], label="target")
ax.set(title="Model output vs. input signal",
       xlabel="sample (offset from {})".format(window.start),
       ylabel="amplitude")
ax.legend();

In [None]:
# Play the original (target) signal for comparison with the reconstruction.
target_audio = x_train.reshape(-1)
Audio(data=target_audio, rate=SR)