In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""Regeneration of lead synth from combined signal"""
from tensorflow.keras.layers import Dense, Dropout, PReLU
from tensorflow.keras.optimizers import Adam

from music_generator.basic.random import generate_dataset
from music_generator.basic.signalproc import SamplingInfo
from music_generator.musical.timing import Tempo
from music_generator.musical.scales import GenericScale
from music_generator.basic.signalproc import mix_at
from music_generator.analysis import preprocessing

from music_generator.musical import scales
import numpy as np
from multiprocessing import Pool
from functools import partial

import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io.wavfile import read
from music_generator.analysis import regen_models
from scipy.io import wavfile
from music_generator.analysis import jamdataset
import tensorflow as tf

from music_generator.analysis import regen_models
import importlib

from tensorflow.keras.layers import Input, GRU, PReLU, Dropout, Dense, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop, Adam

from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, ModelCheckpoint
import datetime as dt

In [None]:
# Batching configuration shared by the data-preparation and training cells.
batch_size = 32
fragment_length = 5 * 4096  # number of samples in one training fragment

# Sampling configuration; `sr` is read back from the SamplingInfo object so the
# rest of the notebook uses the rate that object reports.
sampling_info = SamplingInfo(44100)
sr = sampling_info.sample_rate

## Generate data

In [None]:
# Generate in all keys: collect the symbol of every note that the chromatic
# scale on C yields for generate(0, 1).
all_roots = scales.chromatic_scale('C')
roots = []
for note in all_roots.generate(0, 1):
    roots.append(note.get_symbol())
print(roots)

def generate_dataset_for_root(root):
    """Generate a 32-measure dataset at tempo 120 in a scale built on ``root``.

    Args:
        root: Root note passed straight to ``GenericScale``.

    Returns:
        Whatever ``generate_dataset`` returns for these settings, rendered with
        the notebook-level ``sampling_info``.
    """
    tempo = Tempo(120)
    # presumably semitone offsets from the root (natural-minor pattern) —
    # TODO confirm against GenericScale
    interval_pattern = [0, 2, 3, 5, 7, 8, 10]
    scale = GenericScale(root, interval_pattern)
    return generate_dataset(n_measures=32,
                            tempo=tempo,
                            scale=scale,
                            sampling_info=sampling_info)

def reshape_batches(x, batch_size, fragment_length):
    """Cut a signal into fixed-length fragments, keeping only whole batches.

    Trailing samples that do not fill a complete batch of ``batch_size``
    fragments are dropped.

    Args:
        x: Array whose first axis is indexed by sample.
        batch_size: Number of fragments per batch.
        fragment_length: Number of samples per fragment.

    Returns:
        Array of shape ``(-1, fragment_length)`` whose row count is a multiple
        of ``batch_size`` (zero rows when ``x`` is shorter than one batch).
    """
    # Number of complete batches available in the signal.
    n_batches = len(x) // fragment_length // batch_size
    n_usable_samples = n_batches * fragment_length * batch_size
    usable = x[:n_usable_samples]
    return usable.reshape(-1, fragment_length)

def process(dataset, batch_size, fragment_length):
    """Turn generated datasets into fragmented (input, target) arrays.

    The datasets are merged via ``preprocessing.combine_datasets``; the mix is
    used as model input and track index 2 as the target (presumably the lead
    synth — TODO confirm against ``combine_datasets``).

    Args:
        dataset: Collection of datasets accepted by ``combine_datasets``.
        batch_size: Number of fragments per batch.
        fragment_length: Number of samples per fragment.

    Returns:
        Tuple ``(x, y)`` of arrays reshaped by ``reshape_batches``.
    """
    # Make one big data set and make sure data is of same size
    tracks, mixed = preprocessing.combine_datasets(dataset)
    target = tracks[2]

    return (reshape_batches(mixed, batch_size, fragment_length),
            reshape_batches(target, batch_size, fragment_length))

# Load the paired recordings: the full mix is the model input, the isolated
# guitar is the regression target.
sr_mix, x = wavfile.read("../data/full-mix-jam1-01.wav")
sr_guitar, y = wavfile.read("../data/only-guitar-jam1-01.wav")
assert sr_mix == sr and sr_guitar == sr, "Sample rate of the training files does not match `sr`"

# Trim both signals to a common length so they yield the same number of
# fragments and stay aligned row by row.
n_common = min(len(x), len(y))

# Scale by 2**15 — assumes 16-bit PCM files (TODO confirm).
x_all = reshape_batches(x[:n_common] / 2**15, batch_size, fragment_length)
y_all = reshape_batches(y[:n_common] / 2**15, batch_size, fragment_length)

# Hold out the last batch of fragments so the evaluation cells further down
# have an x_test / y_test to work with on a fresh kernel.
assert len(x_all) > batch_size, "Need at least two full batches to hold out a test batch"
x_train, y_train = x_all[:-batch_size], y_all[:-batch_size]
x_test, y_test = x_all[-batch_size:], y_all[-batch_size:]
    
# with Pool(8) as pool:
#     datasets_train = pool.map(generate_dataset_for_root, roots)
#     datasets_test = pool.map(generate_dataset_for_root, roots)
    
# x_train, y_train = process(datasets_train, batch_size, fragment_length)
# x_test, y_test = process(datasets_train, batch_size, fragment_length)
    

In [None]:
# Listen to a single input fragment (row 18 of the training input).
Audio(x_train[18].ravel(), rate=sr)

In [None]:
from tensorflow.keras.layers import Conv1D

In [None]:
# Conv1D?

In [None]:
# x_train.shape

In [None]:
# n_train = 4096
# n_test = 128

In [None]:
# Audio(input_track, rate=sr)
# Audio(target_track, rate=sr)

In [None]:
# x, y = preprocessing.create_training_data_set(n_train + n_test, 
#                                               fragment_length, 
#                                               input_track, 
#                                               target_track)

In [None]:
# Listen to the first input fragment; swap in the commented line to hear the
# matching target fragment instead.
Audio(x_train[0], rate=sr)
# Audio(y_train[0], rate=sr)

## Construct model

In [None]:
def build_conv_model(batch_size, fragment_length):
    """Build the fragment-to-fragment convolutional regression model.

    The input fragment is treated as a single-channel sequence, passed through
    two linear Conv1D layers (20 filters, kernel size 5, 'same' padding) and
    then through Dense layers applied to the feature axis, ending in one value
    per sample.

    Args:
        batch_size: Unused; kept so existing callers keep working.
        fragment_length: Number of samples in each input/output fragment.

    Returns:
        An uncompiled Keras ``Model`` mapping ``[fragment_length]`` inputs to
        ``[fragment_length]`` outputs.
    """
    waveform_in = Input(shape=[fragment_length])

    # Add a channel axis so Conv1D can slide along the time axis.
    features = Reshape([fragment_length, 1])(waveform_in)

    for _ in range(2):
        features = Conv1D(filters=20, kernel_size=5, padding='same')(features)

    for units in (1024, 512):
        features = Dense(units, activation="relu")(features)

    # Collapse the feature axis to one value per sample, then drop that axis.
    per_sample = Dense(1)(features)
    waveform_out = Reshape([fragment_length])(per_sample)

    return Model(waveform_in, waveform_out)
    
    
def fft_loss(y_target, y_predicted):
    """Squared difference between the magnitude spectra of target and prediction.

    The first 100 frequency bins are excluded from the comparison.

    Args:
        y_target: Real-valued tensor whose innermost axis is time.
        y_predicted: Tensor of the same shape as ``y_target``.

    Returns:
        Tensor of element-wise squared magnitude differences with the innermost
        axis shortened by the skipped bins; it is not reduced to a scalar here.
    """
    n_skipped_bins = 100

    y_target_complex = tf.cast(y_target, dtype=tf.complex64)
    y_predicted_complex = tf.cast(y_predicted, dtype=tf.complex64)

    target_spectrum = tf.abs(tf.signal.fft(y_target_complex))
    predicted_spectrum = tf.abs(tf.signal.fft(y_predicted_complex))

    # tf.signal.fft transforms the innermost axis, so the bins must be skipped
    # on that axis. A plain `[100:]` would slice the batch axis of a
    # (batch, fragment_length) tensor instead.
    # NOTE(review): the spectrum of a real signal is mirrored, so the matching
    # bins at the end of the axis are still included — confirm this is intended.
    loss = tf.square(target_spectrum[..., n_skipped_bins:] -
                     predicted_spectrum[..., n_skipped_bins:])
    return loss

# Build the network, print its layer summary, then compile it for plain MSE
# regression with RMSprop at a learning rate of 1e-3.
model = build_conv_model(batch_size=batch_size, fragment_length=fragment_length)
model.summary()
model.compile(optimizer=RMSprop(1e-3), loss='mse')

## Setup callbacks

In [None]:
# Keras callbacks handed to model.fit. The list is currently empty; uncomment
# one of the groups below to enable TensorBoard logging, learning-rate
# reduction on plateau, or per-epoch weight checkpoints.
callbacks = []

# log_file_name = f"tensorboard/{dt.datetime.now().strftime('%Y%m%d%H%M%S')}"
# tensorboard_callback = TensorBoard(log_dir=log_file_name, histogram_freq=1, update_freq='batch')
# callbacks.append(tensorboard_callback)

# reduce_lr_callback = ReduceLROnPlateau(verbose=1)
# callbacks.append(reduce_lr_callback)

# model_checkpoint_callback = ModelCheckpoint("weights.{epoch:02d}.h5")
# callbacks.append(model_checkpoint_callback)

In [None]:
# Baseline: listen to what the model outputs for the first 64 fragments
# before any fitting has been run.
Audio(model.predict(x_train[:64], verbose=1).ravel(), rate=sr)

## Fit model

In [None]:
# validation_data=[x_test, y_test], 

In [None]:
# Train for one epoch, then listen to the prediction for the first five
# batches of training fragments, concatenated into one signal.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=1,
          shuffle=True,
          callbacks=callbacks)
Audio(model.predict(x_train[:5 * batch_size], verbose=1).reshape(-1), rate=sr)

In [None]:
# n_batches_inference = len(input_track) // fragment_length
# inference_ds = input_track[:n_batches_inference * fragment_length]
# inference_ds = inference_ds.reshape(-1, fragment_length)

In [None]:
# n_batches_inference

In [None]:
# Predict on the test fragments and flatten them into one continuous signal.
# NOTE: requires `x_test` to have been created by the data-preparation cell.
output = model.predict(x_test, verbose=1).ravel()

In [None]:
# Load the full mix again for inference and cut it into model-sized fragments.
sr_file, data = wavfile.read("../data/full-mix-jam1-01.wav")
assert sr_file == sr, "Sample rate does not match, you will need to retrain"
# Apply the same 1 / 2**15 scaling that the training input received; feeding
# raw PCM integers to a model fitted on scaled samples gives meaningless output.
data = reshape_batches(data / 2**15, batch_size, fragment_length)

In [None]:
# Run the model over every fragment of the full mix and flatten the result
# into one continuous signal.
output_2 = np.ravel(model.predict(data, verbose=1))

In [None]:
# Listen to the first 1,000,000 samples of the model *input* (the fragmented
# full mix), not the prediction.
Audio(data.reshape(-1)[:1000000], rate=sr)

## Distortion due to phase matching issue?

In [None]:
# Keep the per-fragment predictions (not flattened) for inspection below.
# NOTE: requires `x_test` to have been created by the data-preparation cell.
tmp = model.predict(x_test, verbose=1)

In [None]:
# The model ends in Reshape([fragment_length]), so this should show
# (number of fragments, fragment_length).
tmp.shape

In [None]:
# Plot the first 1000 predicted samples of the flattened prediction to inspect
# the waveform for the distortion discussed in this section.
plt.plot(tmp.reshape(-1)[:1000])