<a href="https://colab.research.google.com/github/matthewmcq/lanternfish/blob/main/vocals_model_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive/; later cells read project code from and
# write models/examples to /content/drive/MyDrive/lanternfish/.
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [2]:
import sys
sys.path.append('/content/drive/MyDrive/lanternfish/Code/')
import Utils.Batch.generate_examples
import Utils.Batch.batch_data
import Utils.Plot
import tensorflow as tf
# import Models.wavelet_unet
import Config as cfg
# from Train import train, WaveletLoss
import numpy as np
import cv2
# from Utils.Wavelets import inverseWaveletReshape
import Utils.Wavelets
# import Utils.Wavelets.inverseWaveletReshape
import matplotlib.pyplot as plt
import pywt
import soundfile as sf

In [3]:
# NOTE(review): tensorflow is already imported by the previous cell, so an upgraded
# package only takes effect after a runtime restart; the version is also unpinned.
!pip install --upgrade tensorflow



In [4]:
# Sample rate (Hz) passed to sf.write() by get_wav_output().
# NOTE(review): train() writes its example WAVs at a hard-coded 22050 - confirm which rate matches the data.
SR = 44100

### Dataset Paths

In [5]:
### DO NOT CHANGE ###
# MedleyDB roots handed to generate_data() by preprocess_medleydb(). They are
# relative paths, so they resolve against the notebook's working directory.
MEDLEY2_PATH = 'Datasets/MedleyDB/V2/'
MEDLEY1_PATH = 'Datasets/MedleyDB/V1/'
# Training-data directory passed to the example generator and the batching helpers.
TRAIN_PATH =  "/content/Datasets/TrainingData/"

### Set stem type to process
Options are: 'vocals', 'drums', 'bass', 'midrange'

In [6]:
# Stem selector: used by the batching helpers, the example/zip paths and the saved model file names.
CURR_STEM_TYPE = 'vocals'

### Preprocessing + Batching

In [7]:
def preprocess_medleydb(stem_type: str, clean: bool =False, sample_length=65536) -> None:
    '''
    Generate training examples for one stem type from both MedleyDB releases.

    params:
    - stem_type: str, type of stem to split (e.g. vocals, drums, bass, midrange)
    - clean: bool, when True clean_training_data() is run on TRAIN_PATH first
    - sample_length: forwarded unchanged to generate_data()

    return: None
    '''
    examples = Utils.Batch.generate_examples

    ## optional cleanup pass, useful when an earlier run went wrong
    if clean:
        examples.clean_training_data(TRAIN_PATH, stem_type)

    ## V1 is processed before V2; both are generated into TRAIN_PATH
    for medley_root in (MEDLEY1_PATH, MEDLEY2_PATH):
        examples.generate_data(medley_root, TRAIN_PATH, stem_type, sample_length)

In [8]:
def batch_training_data(level: int = 12, batch_size: int = 8, max_songs: int = 2, max_samples_per_song: int = 10, num_features: int=65536) -> tuple:
    '''
    Batch the wavelet data for training

    params:
    - level: int, level of wavelet decomposition
    - batch_size: int, number of samples per batch
    - max_songs: int, maximum number of songs to include in the batch
    - max_samples_per_song: int, maximum number of samples per song
    - num_features: int, forwarded unchanged to batch_wavelets_dataset()

    return:
    - tuple of the three values produced by batch_wavelets_dataset(), in order.
      main_train()/retrain() use the first two as training and validation data
      for model.fit() and print the third as the coefficient shape.
    '''
    ## the stem type and data location come from the notebook-level settings
    batcher = Utils.Batch.batch_data.batch_wavelets_dataset
    y_train, y_true, shape = batcher(TRAIN_PATH, CURR_STEM_TYPE, level, batch_size, max_songs, max_samples_per_song, num_features, diff=False)

    return y_train, y_true, shape


def batch_training_data_debug(level: int = 12, batch_size: int = 8, max_songs: int = 2, max_samples_per_song: int = 10, num_features: int=65536) -> tuple:
    '''
    Load wavelet data through the debug batching helper

    params:
    - level: int, level of wavelet decomposition
    - batch_size: int, number of samples per batch
    - max_songs: int, maximum number of songs to include in the batch
    - max_samples_per_song: int, maximum number of samples per song
    - num_features: int, forwarded unchanged to batch_wavelets_debug()

    return:
    - tuple of the three values produced by batch_wavelets_debug(), in order.
      train() indexes the first two per example (y_train[i], y_true[i]) to
      export audio; the third is the coefficient shape.
    '''
    ## the stem type and data location come from the notebook-level settings
    batcher = Utils.Batch.batch_data.batch_wavelets_debug
    y_train, y_true, shape = batcher(TRAIN_PATH, CURR_STEM_TYPE, level, batch_size, max_songs, max_samples_per_song, num_features, diff=False)

    return y_train, y_true, shape

In [9]:

def inverseWaveletReshape(tensor_coeffs, shape, wavelet_depth):
    """
    Resample every column of a coefficient matrix to its original per-level length.

    Args:
        tensor_coeffs: 2-D array-like (tf.Tensor or np.ndarray) indexed as
            [coefficient, level]; columns 0..wavelet_depth are read.
        shape: sequence indexed by level; the first item of shape[level] is the
            target length for that level.
        wavelet_depth (int): depth of the wavelet decomposition; wavelet_depth + 1
            columns are processed.

    Returns:
        list: one 1-D np.ndarray per level, of length shape[level][0].
    """
    resized_levels = []

    for level in range(wavelet_depth + 1):
        # Works for eager tensors and plain ndarrays alike. The previous
        # `.numpy()` call raised AttributeError for ndarray input.
        level_coeffs = np.ascontiguousarray(tensor_coeffs[:, level])

        # cv2 takes dsize as (width, height): the column is resized as a 1 x N image.
        dsize = (shape[level][0], 1)
        resized = cv2.resize(level_coeffs.reshape(1, -1), dsize=dsize, interpolation=cv2.INTER_AREA)

        resized_levels.append(resized.flatten())

    return resized_levels

### Get prediction

In [10]:
def get_prediction(model, y_train, y_true, shape, model_config, index=0):
  """
  Predict one example and resample prediction and target to per-level coefficient lists.

  Args:
      model: object with a Keras-style predict() method.
      y_train: indexable collection of model inputs; y_train[index] is used.
      y_true: indexable collection of targets; y_true[index] is used.
      shape: per-level shapes forwarded to inverseWaveletReshape().
      model_config: dict providing 'wavelet_depth'.
      index: position of the example to use. The function previously read an
          undefined free variable `i`; the default keeps existing five-argument
          calls working.

  Returns:
      (predicted, true): the two lists returned by inverseWaveletReshape().

  NOTE(review): inverseWaveletReshape() reads wavelet_depth + 1 columns, while the
  WaveletUNet in this notebook ends in a single-filter convolution - confirm the
  model passed here really emits one channel per level.
  """
  example = tf.expand_dims(y_train[index], axis=0)
  print(f"expand dims shape: {example.shape}")

  # predict() yields an ndarray while the target may be a tensor; convert both
  # to tensors so the reshape helper receives the same type for each.
  predict_train_0 = tf.convert_to_tensor(model.predict(example)[0])
  predict_true_0 = tf.convert_to_tensor(y_true[index])

  print(f"predict_train_0.shape before reshape: {predict_train_0.shape}")

  predict_train_0 = inverseWaveletReshape(predict_train_0, shape, model_config['wavelet_depth'])
  predict_true_0 = inverseWaveletReshape(predict_true_0, shape, model_config['wavelet_depth'])

  print(f"predict_train_0.shape aftr reshape: {([len(coef) for coef in predict_train_0])}")

  print(f"predict_train_0: {(predict_train_0)}")
  print(f"predict_true_0: {(predict_true_0)}")

  return predict_train_0, predict_true_0

### Generating Audio Output

In [11]:
def get_wav_output(predict_train_0, predict_true_0, index):
  """
  Reconstruct audio from two coefficient lists and write both as WAV files.

  Each list is inverted with a Haar pywt.waverec(); the results are written to
  'test{index}.wav' (prediction) and 'true{index}.wav' (target) in the current
  working directory at sample rate SR.
  """
  # Both reconstructions are computed before anything is written to disk.
  predicted_audio, reference_audio = (
      pywt.waverec(coeffs, 'haar', axis=-1)
      for coeffs in (predict_train_0, predict_true_0)
  )

  sf.write(f'test{index}.wav', predicted_audio, SR)
  sf.write(f'true{index}.wav', reference_audio, SR)

### Plot Wavelet Data

In [12]:
def plot_wavelet_data(predict_train_0, predict_true_0, model_config):
  """
  Plot true vs. predicted coefficients, one subplot per coefficient array.

  Args:
      predict_train_0: list of predicted coefficient arrays (wavelet_depth + 1 entries are read).
      predict_true_0: list of target coefficient arrays, same layout.
      model_config: dict providing 'wavelet_depth'.

  The previous version indexed the lists with [-level], which selected index 0
  first and then walked backwards from the end, left the first axis empty and
  never drew one of the arrays. Array k is now drawn on axis k, in list order.
  """
  num_bands = model_config['wavelet_depth'] + 1

  plt.rcParams['figure.figsize'] = (8, 8)
  plt.rcParams['font.size'] = 8

  fig, ax = plt.subplots(num_bands, 1, figsize=(12, 6))
  # plt.subplots returns a bare Axes (not an array) when only one row is requested.
  ax = np.atleast_1d(ax)

  # `start` offsets each band so the x axes continue from one subplot to the next.
  start = 0
  for band in range(num_bands):
      true_coef = predict_true_0[band]
      pred_coef = predict_train_0[band]

      ax[band].plot(start + np.arange(len(true_coef)), true_coef, label='True')
      ax[band].plot(start + np.arange(len(pred_coef)), pred_coef, label='Predicted')
      ax[band].set_title(f'Wavelet Level {band + 1}')
      ax[band].set_xlabel('Time')
      ax[band].set_ylabel(f'Coef Level {band + 1}')
      ax[band].legend()
      start += len(true_coef)

  plt.tight_layout()
  plt.show()

### Model Definition - WaveletUNet
Defines the `WaveletUNet` model and its `DownsamplingLayer` building block. Training and saving to a .keras file happen in `main_train()` further down.

In [13]:
@tf.keras.utils.register_keras_serializable()
class WaveletUNet(tf.keras.Model):
    """
    U-Net style 1-D convolutional model over stacked wavelet-coefficient channels.

    Encoder: each of the `num_layers` stages applies a DownsamplingLayer convolution
    (its output is kept as a skip connection), splits the sequence into even and odd
    time steps and combines them with two learned convolutions:
    d = odd - P(even), c = even + U(d). c and d are concatenated on the channel
    axis, which halves the time axis and doubles the channel count.

    Decoder: each stage splits the channels in half, projects both halves with 1x1
    convolutions, applies the reverse steps with the same P/U layers, interleaves
    the two results along time, concatenates the matching skip connection and
    applies a merge convolution.

    Output: the channel-sum of the input is concatenated to the decoder output and
    passed through a single-filter tanh convolution, giving (batch, time, 1).
    """

    def __init__(self, num_coeffs, wavelet_depth, batch_size, channels, num_layers, num_init_filters, filter_size, merge_filter_size, l1_reg, l2_reg, **kwargs):
        """
        Args:
            num_coeffs: length of the time axis of one input example.
            wavelet_depth: decomposition depth; the input carries wavelet_depth + 1 channels.
            batch_size: batch size, only recorded in `input_shape`.
            channels: stored and serialised, not otherwise used by this class.
            num_layers: number of encoder stages (and decoder stages).
            num_init_filters: filters in the first stage; stage i uses num_init_filters * (i + 1).
            filter_size: kernel size of the encoder and bottleneck convolutions.
            merge_filter_size: kernel size of the decoder merge convolutions.
            l1_reg, l2_reg: L1/L2 factors for kernel and activity regularisation.
            **kwargs: forwarded to tf.keras.Model.
        """
        super().__init__(**kwargs)
        self.num_coeffs = num_coeffs
        # Stored as the channel count (depth + 1); get_config() converts it back so
        # the constructor argument survives a save/load round trip unchanged.
        self.wavelet_depth = wavelet_depth + 1
        self.batch_size = batch_size
        self.channels = channels
        self.num_layers = num_layers
        self.num_init_filters = num_init_filters
        self.filter_size = filter_size
        self.merge_filter_size = merge_filter_size
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg

        self.input_shape = (self.batch_size, self.num_coeffs, self.wavelet_depth)

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from a get_config() dictionary without mutating the caller's dict."""
        config = dict(config)

        # Extract the necessary arguments from the config dictionary
        num_coeffs = config.pop('num_coeffs')
        wavelet_depth = config.pop('wavelet_depth')
        batch_size = config.pop('batch_size')
        channels = config.pop('channels')
        num_layers = config.pop('num_layers')
        num_init_filters = config.pop('num_init_filters')
        filter_size = config.pop('filter_size')
        merge_filter_size = config.pop('merge_filter_size')
        l1_reg = config.pop('l1_reg')
        l2_reg = config.pop('l2_reg')

        return cls(
            num_coeffs=num_coeffs,
            wavelet_depth=wavelet_depth,
            batch_size=batch_size,
            channels=channels,
            num_layers=num_layers,
            num_init_filters=num_init_filters,
            filter_size=filter_size,
            merge_filter_size=merge_filter_size,
            l1_reg=l1_reg,
            l2_reg=l2_reg,
            **config  # Pass any remaining arguments to the constructor
        )

    def get_config(self):
        """Return the constructor arguments needed to recreate this model."""
        config = super().get_config()
        config.update({
            'num_coeffs': self.num_coeffs,
            # __init__ adds 1 to its argument; undo that here, otherwise every
            # save/load cycle would increase the stored depth by one.
            'wavelet_depth': self.wavelet_depth - 1,
            'batch_size': self.batch_size,
            'channels': self.channels,
            'num_layers': self.num_layers,
            'num_init_filters': self.num_init_filters,
            'filter_size': self.filter_size,
            'merge_filter_size': self.merge_filter_size,
            'l1_reg': self.l1_reg,
            'l2_reg': self.l2_reg
        })
        return config

    def build(self, input_shape):
        """Create the encoder, bottleneck, decoder and output layers, keyed by stage name '1'..'num_layers'."""
        # Create downsampling blocks
        self.downsampling_blocks = {}
        self.learnable_downsampling_blocks = {}
        self.P = {}
        self.U = {}
        for i in range(self.num_layers):
            block_name = f'{i+1}'
            num_filters = self.num_init_filters + (self.num_init_filters * i)
            self.downsampling_blocks[block_name] = DownsamplingLayer(num_filters, self.filter_size, name=block_name, l1_reg=self.l1_reg, l2_reg=self.l2_reg)
            # P and U are shared between the encoder stage and the decoder stage of the same name.
            self.P[block_name] = tf.keras.layers.Conv1D(
                num_filters,
                3,
                activation=None,
                padding='same',
                name=f'P_{block_name}'
            )
            self.U[block_name] = tf.keras.layers.Conv1D(
                num_filters,
                3,
                activation=None,
                padding='same',
                name=f'U_{block_name}'
            )

        # Create bottle neck
        self.bottle_neck = tf.keras.layers.Conv1D(
            self.num_init_filters * (self.num_layers + 1),
            self.filter_size,
            activation='leaky_relu',
            padding='same',
            name='bottleneck_conv',
            kernel_regularizer=tf.keras.regularizers.l1_l2(l1=self.l1_reg, l2=self.l2_reg),
            activity_regularizer=tf.keras.regularizers.l1_l2(l1=self.l1_reg, l2=self.l2_reg)
        )

        # Create upsampling blocks (deepest stage first)
        self.us_conv1d = {}
        self.even = {}
        self.odd = {}
        for i in range(self.num_layers):
            block_name = f'{self.num_layers - i}'
            num_filters = self.num_init_filters + (self.num_init_filters * (self.num_layers - i - 1))

            self.us_conv1d[block_name] = tf.keras.layers.Conv1D(
                num_filters,
                self.merge_filter_size,
                activation='leaky_relu',
                padding='same',
                name=f'us_conv1d_{block_name}',
                trainable=True
            )
            self.even[block_name] = tf.keras.layers.Conv1D(
                num_filters,
                1,
                activation=None,
                padding='same',
                name=f'even_{block_name}',
                trainable=True
            )
            self.odd[block_name] = tf.keras.layers.Conv1D(
                num_filters,
                1,
                activation=None,
                padding='same',
                name=f'odd_{block_name}',
                trainable=True
            )

        self.output_conv3 = tf.keras.layers.Conv1D(
            1,
            1,
            activation='tanh',
            padding='same',
            name='output_conv3',
        )
        super().build(input_shape)

    @staticmethod
    def _match_time_length(tensor, target_len):
        """Symmetric-pad or crop `tensor` along axis 1 so that its length equals `target_len`."""
        diff = target_len - tensor.shape[1]
        if diff > 0:
            pad_start = diff // 2
            pad_end = diff - pad_start
            return tf.pad(tensor, [[0, 0], [pad_start, pad_end], [0, 0]], mode='SYMMETRIC')
        if diff < 0:
            crop_start = (-diff) // 2
            return tf.slice(tensor, [0, crop_start, 0], [-1, target_len, -1])
        return tensor

    def call(self, inputs, is_training=True):
        """
        Run the forward pass.

        Args:
            inputs: tensor indexed as (batch, time, channels).
            is_training: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape (batch, time, 1) with values in the tanh range.
        """
        current_layer = inputs

        # Channel-sum of the input; fed to the output convolution next to the decoder result.
        full_mix = tf.math.reduce_sum(current_layer, axis=-1)

        enc_outputs = list()

        # Scaling factor applied to the two branches of every split/merge step.
        A = 2**(1/2)

        # Downsampling path
        for i in range(self.num_layers):
            block_name = f'{i+1}'

            current_layer = self.downsampling_blocks[block_name](current_layer)

            # Save for skip connections
            enc_outputs.append(current_layer)

            # Decimation step: split into even/odd time steps, then predict and update.
            x_even, x_odd = current_layer[:, ::2, :], current_layer[:, 1::2, :]
            d = x_odd - self.P[block_name](x_even)
            c = x_even + self.U[block_name](d)

            c = c * A
            d = d * 1/A

            current_layer = tf.concat([c, d], axis=-1)

        # Bottle neck
        current_layer = self.bottle_neck(current_layer)

        # Upsampling path
        for i in range(self.num_layers):
            block_name = f'{self.num_layers - i}'

            # Split the channels in half and project each half to this stage's filter count.
            x_even, x_odd = current_layer[:, :, :-current_layer.shape[-1]//2], current_layer[:, :, -current_layer.shape[-1]//2:]
            x_even = self.even[block_name](x_even)
            x_odd = self.odd[block_name](x_odd)

            x_odd *= A
            x_even *= 1/A

            # Reverse of the encoder step, reusing the same P/U layers.
            c = x_even - self.U[block_name](x_odd)
            d = x_odd + self.P[block_name](c)

            output = tf.concat([c, d], axis=1)

            # Interleave along time: c[0], d[0], c[1], d[1], ...
            num_entries = x_even.shape[1]
            indices = [
                idx // 2 if idx % 2 == 0 else num_entries + idx // 2
                for idx in range(2 * num_entries)
            ]
            current_layer = tf.gather(output, indices, axis = 1)

            # Get skip connection and make the time axes agree before concatenating
            skip_conn = enc_outputs[-i-1]
            current_layer = self._match_time_length(current_layer, skip_conn.shape[1])

            # Concatenate with skip connection
            current_layer = tf.keras.layers.Concatenate()([current_layer, skip_conn])

            current_layer = self.us_conv1d[block_name](current_layer)

        current_layer = self._match_time_length(current_layer, full_mix.shape[1])

        current_layer = tf.keras.layers.Concatenate()([tf.expand_dims(full_mix, axis=-1), current_layer])

        current_layer = self.output_conv3(current_layer)

        return current_layer

@tf.keras.utils.register_keras_serializable()
class DownsamplingLayer(tf.keras.layers.Layer):
    """
    Encoder convolution block: a single 'same'-padded Conv1D with leaky-ReLU
    activation and L1/L2 kernel and activity regularisation. Despite the name,
    the layer itself does not change the sequence length.
    """

    def __init__(self, num_filters, filter_size, l1_reg=0.0, l2_reg=0.0, **kwargs):
        """
        Args:
            num_filters: number of convolution filters.
            filter_size: convolution kernel size.
            l1_reg, l2_reg: L1/L2 regularisation factors.
            **kwargs: forwarded to tf.keras.layers.Layer (e.g. name).
        """
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg

    def build(self, input_shape):
        """Create the wrapped Conv1D layer."""
        self.conv = tf.keras.layers.Conv1D(
            self.num_filters,
            self.filter_size,
            activation='leaky_relu',
            padding='same',
            kernel_regularizer=tf.keras.regularizers.l1_l2(l1=self.l1_reg, l2=self.l2_reg),
            activity_regularizer=tf.keras.regularizers.l1_l2(l1=self.l1_reg, l2=self.l2_reg),
            name=f'downsampling_conv_{self.num_filters}'
        )
        super().build(input_shape)

    def call(self, inputs):
        """Apply the convolution to `inputs`."""
        x = self.conv(inputs)
        return x

    def get_config(self):
        """Return the constructor arguments needed to recreate this layer."""
        config = super().get_config()
        config.update({
            'num_filters': self.num_filters,
            'filter_size': self.filter_size,
            'l1_reg': self.l1_reg,
            'l2_reg': self.l2_reg
        })
        return config

    # Must be a classmethod: it is called on the class during deserialisation.
    # Without the decorator, `DownsamplingLayer.from_config(config)` would bind
    # the config dict to `cls` and fail with a missing-argument TypeError.
    @classmethod
    def from_config(cls, config):
        """Recreate the layer from a get_config() dictionary."""
        return cls(**config)





### 2D Model

### CONFIG

In [14]:
def cfg():
    """Return a fresh dictionary of hyper-parameters for an initial training run."""
    architecture = {
        'num_layers': 12,          # How many U-Net layers
        'filter_size': 15,         # Filter size of conv in downsampling block
        'merge_filter_size': 5,    # Filter size of conv in upsampling block
        'num_init_filters': 24,    # THIS MUST BE DIVISIBLE BY 4 IF USING DUALWAVELETUNET
    }

    optimisation = {
        'learning_rate': 1e-4,     # learning rate handed to the optimizer
        'validation_split': 0.2,   # fraction of training data meant for validation
    }

    input_layout = {
        'channels': 1,
        'num_coeffs': 16384,       # length of the time axis of one model input
        'wavelet_depth': 2,
    }

    run_settings = {
        'batch_size': 16,          # Batch size
        'epochs': 1000,
        'max_songs': 86,           # 86 = all songs vox, 84 = all songs bass, 106 = all songs drumkit
        'max_samples_per_song': 700,
    }

    regularisation = {
        'l1_reg': 1e-11,           # L1 regularization -> sparse
        'l2_reg': 1e-12,           # L2 regularization -> non-sparse
    }

    loss_weights = {
        'lambda_vec': [1],
        'lambda_11': 1,
        'lambda_12': 1,
    }

    # Merged in this order so the key order matches the flat layout used before.
    model_config = {
        **architecture,
        **optimisation,
        **input_layout,
        **run_settings,
        **regularisation,
        **loss_weights,
    }
    return model_config


### Retrain config

In [15]:
def cfg_retrain():
    """Return a fresh dictionary of hyper-parameters for continuing training of a saved model."""
    architecture = {
        'num_layers': 12,          # How many U-Net layers
        'filter_size': 15,         # Filter size of conv in downsampling block
        'merge_filter_size': 5,    # Filter size of conv in upsampling block
        'num_init_filters': 24,    # THIS MUST BE DIVISIBLE BY 4 IF USING DUALWAVELETUNET
    }

    optimisation = {
        'learning_rate': 4e-5,     # learning rate handed to the optimizer
        'validation_split': 0.2,   # fraction of training data meant for validation
    }

    input_layout = {
        'channels': 1,
        'num_coeffs': 16384,       # length of the time axis of one model input
        'wavelet_depth': 2,
    }

    run_settings = {
        'batch_size': 32,          # Batch size
        'epochs': 1000,
        'max_songs': 86,           # 86 = all songs
        'max_samples_per_song': 700,
    }

    regularisation = {
        'l1_reg': 1e-12,           # L1 regularization -> sparse
        'l2_reg': 1e-11,           # L2 regularization -> non-sparse
    }

    loss_weights = {
        'lambda_vec': [1],
        'lambda_11': 1,
        'lambda_12': 1,
    }

    # Merged in this order so the key order matches the flat layout used before.
    model_config = {
        **architecture,
        **optimisation,
        **input_layout,
        **run_settings,
        **regularisation,
        **loss_weights,
    }
    return model_config

### Train

In [16]:
import soundfile as sf

def train(model, model_config, loss, train, val):
    """
    Compile and fit `model`, then export a few audio examples for listening.

    Args:
        model: Keras model to train (compiled here with Adam).
        model_config: dict providing 'learning_rate', 'epochs', 'wavelet_depth' and 'batch_size'.
        loss: loss object passed to model.compile().
        train: training data passed to model.fit().
        val: validation data passed to model.fit().

    Returns:
        The trained model. Early stopping on 'val_loss' with patience 20 restores
        the best weights.

    Side effects:
        Writes train_/true_/pred_ WAV files for up to 50 examples into
        /content/drive/MyDrive/lanternfish/Code/examples/{CURR_STEM_TYPE}/.
    """
    es = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=20,
        restore_best_weights=True,
        start_from_epoch=0
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate=model_config['learning_rate'])
    metrics = [tf.keras.metrics.RootMeanSquaredError(), tf.keras.metrics.MeanSquaredError()]
    # Compile the model
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Train the model
    model.fit(
        train,
        epochs=model_config['epochs'],
        validation_data=val,
        callbacks=[es]
    )

    # Small debug batch (40 songs, 10 samples per song) used only for the audio examples.
    BATCH_PARAMS = (model_config['wavelet_depth'], model_config['batch_size'], 40, 10)
    y_train, y_true, shape = batch_training_data_debug(*BATCH_PARAMS)

    # NOTE(review): the examples are written at 22050 Hz while the notebook-level
    # SR is 44100 - confirm which rate matches the training data.
    example_sr = 22050

    # Never index past the loaded examples: an IndexError here would discard the
    # trained model before it is returned to the caller for saving.
    num_examples = min(50, len(y_train))

    for i in range(num_examples):
        prediction = model.predict(tf.expand_dims(y_train[i], axis=0))[0]

        # Summing over the channel axis works for any wavelet depth; it equals the
        # former three-way unpack-and-add when there are three channels.
        sum_true = np.sum(np.asarray(y_true[i]), axis=-1)
        sum_train = np.sum(np.asarray(y_train[i]), axis=-1)

        sum_pred = tf.squeeze(prediction, axis=-1)

        sf.write(f'/content/drive/MyDrive/lanternfish/Code/examples/{CURR_STEM_TYPE}/train_{i}.wav', sum_train, example_sr)
        sf.write(f'/content/drive/MyDrive/lanternfish/Code/examples/{CURR_STEM_TYPE}/true_{i}.wav', sum_true, example_sr)
        sf.write(f'/content/drive/MyDrive/lanternfish/Code/examples/{CURR_STEM_TYPE}/pred_{i}.wav', sum_pred, example_sr)

    return model

@tf.keras.utils.register_keras_serializable()
class WaveletLoss(tf.keras.losses.Loss):
    """
    Mean squared error between the channel-summed target and channel-summed prediction.

    The weighting/regularisation arguments are stored and serialised through
    get_config() but are not used by call().
    """

    def __init__(self, wavelet_level=4, lambda_vec=[10, 1000, 1000], lambda_11=1, lambda_12=0.25, name='wavelet_loss',   l1_reg=0.0, l2_reg=0.0, **kwargs):
        """
        Args:
            wavelet_level: decomposition depth (stored only).
            lambda_vec: per-level weights (stored only).
            lambda_11, lambda_12: loss weights (stored only).
            name: loss name forwarded to tf.keras.losses.Loss.
            l1_reg, l2_reg: regularisation factors (stored only).
            **kwargs: forwarded to tf.keras.losses.Loss.
        """
        super().__init__(name=name, **kwargs)
        self.wavelet_level = wavelet_level
        # Copy so instances never alias the shared default list or the caller's list.
        self.lambda_vec = list(lambda_vec)
        self.lambda_11 = lambda_11
        self.lambda_12 = lambda_12
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg

    def call(self, y_true, y_pred):
        """Return the scalar MSE between the last-axis sums of `y_true` and `y_pred`."""
        # Sum along the last (wavelet channel) axis for each example in the batch
        summed_true = tf.math.reduce_sum(y_true, axis=-1)
        summed_pred = tf.math.reduce_sum(y_pred, axis=-1)

        # Mean of the squared difference over all remaining axes, batch included
        mse = tf.math.reduce_mean(tf.math.square(summed_true - summed_pred))

        return mse

    def get_config(self):
        """Return the constructor arguments needed to recreate this loss."""
        config = super().get_config()
        config.update({
            'wavelet_level': self.wavelet_level,
            'lambda_vec': self.lambda_vec,
            'lambda_11': self.lambda_11,
            'lambda_12': self.lambda_12,
            'l1_reg': self.l1_reg,
            'l2_reg': self.l2_reg
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the loss from a get_config() dictionary."""
        return cls(**config)

### Main Train

In [17]:
def main_train():
    """
    Train a new WaveletUNet with the cfg() settings and save it to Drive.

    Steps: build the model and run a random dummy batch through it to create the
    weights, print the summary, load the training/validation data, train via
    train(), save to new_goated_{CURR_STEM_TYPE}_v1.keras and reload the file
    once as a check that it deserialises.
    """
    model_config = cfg()

    wavelet_depth = model_config['wavelet_depth']
    batch_size = model_config['batch_size']

    ## positional arguments for batch_training_data(): level, batch size, song and sample limits
    batch_params = (
        wavelet_depth,
        batch_size,
        model_config['max_songs'],
        model_config['max_samples_per_song'],
    )

    model = WaveletUNet(
            num_coeffs=model_config['num_coeffs'],
            wavelet_depth=wavelet_depth,
            batch_size=batch_size,
            channels=model_config['channels'],
            num_layers=model_config['num_layers'],
            num_init_filters=model_config['num_init_filters'],
            filter_size=model_config['filter_size'],
            merge_filter_size=model_config['merge_filter_size'],
            l1_reg=model_config['l1_reg'],
            l2_reg=model_config['l2_reg']
        )

    # define a dummy input to build the model
    model(tf.random.normal(shape=(batch_size, model_config['num_coeffs'], wavelet_depth + 1)))

    # print the model summary
    model.summary()

    dataset, validation_data, shape = batch_training_data(*batch_params)

    print("y_train shape:", shape)
    loss = WaveletLoss(wavelet_level=wavelet_depth, lambda_vec=model_config['lambda_vec'], lambda_11=model_config['lambda_11'], lambda_12=model_config['lambda_12'], name='wavelet_loss')

    ## train the model
    model = train(model, model_config, loss, dataset, validation_data)

    model_name = f'/content/drive/MyDrive/lanternfish/Code/new_goated_{CURR_STEM_TYPE}_v1.keras'
    model.save(model_name)

    # Reload once so a serialisation problem surfaces right after training.
    tf.keras.models.load_model(model_name)
    # loaded_model = tf.keras.models.load_model('wavelet_unet_model.h5')

### retrain

In [18]:
def retrain():
    """
    Continue training the model saved by main_train() using the cfg_retrain() settings.

    Loads new_goated_{CURR_STEM_TYPE}_v1.keras, runs a random dummy batch through
    it, prints the summary, trains via train() on freshly batched data, saves the
    result as new_goated_{CURR_STEM_TYPE}_v1_RETRAIN.keras and reloads that file
    once as a check that it deserialises.
    """
    model_config = cfg_retrain()

    wavelet_depth = model_config['wavelet_depth']
    batch_size = model_config['batch_size']

    ## positional arguments for batch_training_data(): level, batch size, song and sample limits
    batch_params = (
        wavelet_depth,
        batch_size,
        model_config['max_songs'],
        model_config['max_samples_per_song'],
    )

    custom_objects = {
        'WaveletUNet': WaveletUNet,
        'DownsamplingLayer': DownsamplingLayer
    }
    model = tf.keras.models.load_model(f'/content/drive/MyDrive/lanternfish/Code/new_goated_{CURR_STEM_TYPE}_v1.keras', custom_objects=custom_objects)

    # define a dummy input to build the model
    model(tf.random.normal(shape=(batch_size, model_config['num_coeffs'], wavelet_depth + 1)))

    # print the model summary
    model.summary()

    dataset, validation_data, shape = batch_training_data(*batch_params)

    print("y_train shape:", shape)
    loss = WaveletLoss(wavelet_level=wavelet_depth, lambda_vec=model_config['lambda_vec'], lambda_11=model_config['lambda_11'], lambda_12=model_config['lambda_12'], name='wavelet_loss')

    ## train the model (train() recompiles it with the retrain learning rate)
    model = train(model, model_config, loss, dataset, validation_data)

    model_name = f'/content/drive/MyDrive/lanternfish/Code/new_goated_{CURR_STEM_TYPE}_v1_RETRAIN.keras'
    model.save(model_name)

    # Reload once so a serialisation problem surfaces right after training.
    tf.keras.models.load_model(model_name)
    # loaded_model = tf.keras.models.load_model('wavelet_unet_model.h5')

### Main Method - Load model

In [19]:
def main_load():
  """
  Legacy end-to-end driver: batch data, sanity-check the loss, build and train a
  model, save it, then plot and export ten predictions.

  NOTE(review): this function does not match the rest of the notebook as written:
  - `Models` is referenced below, but `import Models.wavelet_unet` is commented
    out in the imports cell.
  - train() in this notebook takes (model, model_config, loss, train, val); the
    call below passes six positional arguments in a different order.
  - get_prediction() takes a required `model_config` argument that the call
    below omits.
  - main_train() hands the first two results of batch_training_data() to
    model.fit() as datasets, whereas here they are used like tensors
    (tf.zeros_like, direct loss calls) - confirm what the helper returns.
  Despite the name, no saved model is loaded here.
  """

  # model_config = cfg.test_saving()
  model_config = cfg()

  ## Set the parameters -- might want to move to Config.py later
  WAVELET_DEPTH = model_config['wavelet_depth'] # level of wavelet decomposition
  BATCH_SIZE = model_config['batch_size'] # number of samples per batch
  MAX_SONGS = model_config['max_songs'] # maximum number of songs to include in the batch
  MAX_SAMPLES_PER_SONG = model_config['max_samples_per_song'] # maximum number of samples per song to include in the batch

  ## Set the batch parameters, pass to batch_training_data()
  BATCH_PARAMS = (WAVELET_DEPTH, BATCH_SIZE, MAX_SONGS, MAX_SAMPLES_PER_SONG)

  ## batch the data for medleyDB
  # preprocess_medleydb(CURR_STEM_TYPE, clean=True)

  ## set the batch size and epochs
  batch_size = model_config['batch_size']
  epochs = model_config['epochs']

  ## test that generate_pairs() works
  y_train, y_true, shape = batch_training_data(*BATCH_PARAMS)

  print("y_train shape:", shape)

  ## check the loss function for all zeros
  zero_train = tf.zeros_like(y_train)


  # Loss instance later handed to train(); built with the class's default name.
  wavelet_loss = WaveletLoss(
      wavelet_level=model_config['wavelet_depth'],
      lambda_vec=model_config['lambda_vec'],
      lambda_11=model_config['lambda_11'],
      lambda_12=model_config['lambda_12'],
  )

  ## check default loss:
  # Second instance with the same settings, used only for the two printed sanity checks.
  loss = WaveletLoss( wavelet_level=model_config['wavelet_depth'], lambda_vec=model_config['lambda_vec'], lambda_11=model_config['lambda_11'], lambda_12=model_config['lambda_12'], name='wavelet_loss')
  print("Default Loss with regularization:", loss(y_true, y_train))
  print("Default Loss (All zeros):", loss(y_true, zero_train))

  ## define the model
  # NOTE(review): `Models` is not imported in this notebook as shown; the
  # WaveletUNet class defined above takes the same keyword arguments.
  model = Models.wavelet_unet.WaveletUNet(
      num_coeffs=model_config['num_coeffs'],
      wavelet_depth=model_config['wavelet_depth'],
      batch_size=model_config['batch_size'],
      channels=1,
      num_layers=model_config['num_layers'],
      num_init_filters=model_config['num_init_filters'],
      filter_size=model_config['filter_size'],
      merge_filter_size=model_config['merge_filter_size'],
      l1_reg=model_config['l1_reg'],
      l2_reg=model_config['l2_reg']
      )

  # define a dummy input to build the model
  model(tf.random.normal(shape=(batch_size, model_config['num_coeffs'], WAVELET_DEPTH+1)))

  # print the model summary
  model.summary()
  ## train the model
  # NOTE(review): argument list does not match train(model, model_config, loss, train, val).
  model = train(model, wavelet_loss, y_train, y_true, epochs, batch_size)

  # Saved relative to the current working directory, with the key hyper-parameters in the file name.
  model_name = f'wavelet_unet_model_nif{model_config["num_init_filters"]}_filter{model_config["filter_size"]}_layers{model_config["num_layers"]}.keras'
  model.save(model_name)
  # model.save('wavelet_unet_model.h5')

  # loaded_model = tf.keras.models.load_model(model_name)
  # loaded_model = tf.keras.models.load_model('wavelet_unet_model.h5')

  # NOTE(review): get_prediction() is called without its `model_config` argument,
  # and the loop index is not passed to it.
  for i in range(10):
    predict_train_0, predict_true_0 = get_prediction(model, y_train, y_true, shape)
    plot_wavelet_data(predict_train_0, predict_true_0, model_config)
    get_wav_output(predict_train_0, predict_true_0, i)
### Run

In [20]:
import os
import zipfile

def unzip_to_colab():
    """
    Extract the current stem's training archive from Drive into local Colab storage.

    Reads {CURR_STEM_TYPE}.zip from the Drive TrainingData folder and unpacks it
    into /content/Datasets/TrainingData/, creating that directory when needed.
    """
    archive_path = f'/content/drive/MyDrive/lanternfish/Code/Datasets/TrainingData/{CURR_STEM_TYPE}.zip'
    target_dir = '/content/Datasets/TrainingData/'

    # The destination must exist before extraction; an existing directory is fine.
    os.makedirs(target_dir, exist_ok=True)

    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(target_dir)

    print("Unzipping to Colab storage completed.")

In [21]:
# unzip_to_colab()

Unzipping to Colab storage completed.


In [None]:
# Train from scratch with the cfg() settings; the model is saved to Drive when training ends.
main_train()



(TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None), TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None), TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None))
(TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None), TensorSpec(shape=(16384, 3), dtype=tf.float32, name=None))
y_train shape: [(4096,), (4096,), (8192,)]
Epoch 1/1000
[1m2008/2008[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m342s[0m 135ms/step - loss: 0.0053 - mean_squared_error: 0.0043 - root_mean_squared_error: 0.0626 - val_loss: 0.0030 - val_mean_squared_error: 0.0025 - val_root_mean_squared_error: 0.0503
Epoch 2/1000
[1m2008/2008[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m226s[0m 113ms/step - loss: 0.0029 - mean_squared_error: 0.0032 - root_mean_squared_error: 0.0563 - val_loss: 0.0027 - val_mean_squared_error: 0.0033 - val_root_mean_squared_error: 0.0573
Epoch 3/1000
[1m1936/2008[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m7s

In [None]:
# Continue training the model file written by main_train(), using the cfg_retrain() settings.
retrain()