In [None]:
from os import path, listdir, mkdir
import time

from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal
from scipy.io import wavfile
import scipy.signal as sps
import random

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

In [None]:
# Colab-only step: mount Google Drive at /content/drive. The data, checkpoint
# and model paths used further down all live under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Detect the TPU attached to this runtime and build a distribution strategy
# for it. TPUClusterResolver() raises ValueError when no TPU can be found.
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
  # Raise an ordinary Exception subclass (not a bare BaseException) so generic
  # `except Exception` handlers and tooling treat this as a regular error,
  # and chain the original ValueError so the root cause stays visible.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err

# Connect to the TPU workers, initialize the TPU system, then create the
# strategy that the model-building cell uses via `tpu_strategy.scope()`.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

In [None]:
def load_wav_mono(filename, out_rate):
    """Read a WAV file, resample it to ``out_rate`` and shape it as one batch item.

    This is plain NumPy/SciPy file I/O, so it is intentionally NOT wrapped in
    ``tf.function``: tracing it would re-trace for every distinct filename and
    embed the decoded audio in the graph as a constant.

    Args:
        filename: Path of the WAV file to read.
        out_rate: Target sample rate in Hz.

    Returns:
        A NumPy array of shape ``(1, n_samples, 1)`` where
        ``n_samples == int(out_rate / rate * n_original)``.
    """
    rate, data = wavfile.read(filename)
    if data.ndim > 1:
        # Multi-channel file: average the channels so the result is truly mono
        # (otherwise the output would gain an extra trailing channel axis).
        data = data.mean(axis=1)
    n_out = int(out_rate / rate * len(data))
    data = sps.resample(data, n_out)
    # (n_samples,) -> (1, n_samples, 1): leading batch axis, trailing channel axis.
    return data[np.newaxis, :, np.newaxis]

In [None]:
def unet(pretrained_weights=None, input_size=(None, 1)):
  """Build and compile a 1-D U-Net style convolutional model.

  Layout: an input batch-normalization, three encoder stages (16, 32, 64
  filters; three width-100 ReLU convolutions each, then max-pooling by 10 and
  batch normalization), a 128-filter bottleneck, and three decoder stages that
  upsample by 10 and concatenate the matching encoder activations. Two linear
  convolutions (8 filters, then 1) produce the output. Dropout(0.5) is applied
  after the third encoder stack and after the bottleneck stack.

  NOTE(review): three pool-by-10 / upsample-by-10 stages presumably require
  the time dimension to be a multiple of 1000 for the skip concatenations to
  line up -- confirm against the training data.

  Args:
    pretrained_weights: Optional path handed to ``model.load_weights``;
      ignored when falsy.
    input_size: Shape of a single input sample (batch axis excluded).

  Returns:
    A ``Model`` compiled with mean-squared-error loss, Adam
    (learning rate 0.001) and ``steps_per_execution=1000``.
  """
  kernel_width = 100
  scale = 10

  def conv(filters, activation='relu'):
    # Every convolution shares kernel width, padding and initializer.
    return Conv1D(filters, kernel_width, activation=activation,
                  padding='same', kernel_initializer='glorot_uniform')

  def conv_stack(tensor, filters, depth):
    # Chain `depth` freshly created ReLU convolutions.
    for _ in range(depth):
      tensor = conv(filters)(tensor)
    return tensor

  def decoder_stage(tensor, skip, filters):
    # Upsample, reduce channels, merge with the encoder skip, refine, normalize.
    upsampled = UpSampling1D(scale)(tensor)
    reduced = conv(filters)(upsampled)
    merged = Concatenate()([skip, reduced])
    refined = conv_stack(merged, filters, 2)
    return BatchNormalization()(refined)

  inputs = Input(input_size)
  normed_input = BatchNormalization()(inputs)

  # ---- Encoder ----
  skip1 = conv_stack(normed_input, 16, 3)
  down1 = MaxPooling1D(pool_size=scale)(skip1)
  down1 = BatchNormalization()(down1)

  skip2 = conv_stack(down1, 32, 3)
  down2 = MaxPooling1D(pool_size=scale)(skip2)
  down2 = BatchNormalization()(down2)

  skip3 = conv_stack(down2, 64, 3)
  skip3 = Dropout(0.5)(skip3)  # the dropped-out tensor is what the decoder reuses
  down3 = MaxPooling1D(pool_size=scale)(skip3)
  down3 = BatchNormalization()(down3)

  # ---- Bottleneck ----
  bottleneck = conv_stack(down3, 128, 3)
  bottleneck = Dropout(0.5)(bottleneck)
  bottleneck = BatchNormalization()(bottleneck)

  # ---- Decoder ----
  decoded = decoder_stage(bottleneck, skip3, 64)
  decoded = decoder_stage(decoded, skip2, 32)
  decoded = decoder_stage(decoded, skip1, 16)

  # ---- Output head: two linear convolutions down to a single channel ----
  head = conv(8, activation='linear')(decoded)
  outputs = conv(1, activation='linear')(head)

  model = Model(inputs=inputs, outputs=outputs)
  model.compile(
      loss=tf.keras.losses.MeanSquaredError(),
      steps_per_execution=1000,
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
  )

  if pretrained_weights:
    model.load_weights(pretrained_weights)

  return model

In [None]:
# Instantiate the U-Net inside the TPUStrategy scope. unet() also compiles the
# model, so compilation happens under the same strategy.
with tpu_strategy.scope():
  model = unet()

def load_a_batch_data(file_src_list, sample_rate=8000):
  """Load (input, target) WAV pairs and concatenate them along the batch axis.

  Runs eagerly on purpose: the work is Python-side file I/O, and wrapping it
  in ``tf.function`` would retrace for every new list of paths and store the
  decoded audio as graph constants.

  Args:
    file_src_list: Sequence of pairs; element ``[0]`` is the input (mix) file
      path and element ``[1]`` is the target file path.
    sample_rate: Rate in Hz both files are resampled to (default 8000).

  Returns:
    ``(X, Y)``: two tensors built with ``tf.concat(..., 0)`` from the
    per-file arrays returned by ``load_wav_mono``.

  Raises:
    ValueError: If ``file_src_list`` is empty.
  """
  if len(file_src_list) == 0:
    raise ValueError('file_src_list is empty; there is nothing to load.')

  inputs, targets = [], []
  for pair in file_src_list:
    inputs.append(load_wav_mono(pair[0], sample_rate))
    targets.append(load_wav_mono(pair[1], sample_rate))

  # NOTE(review): concatenating on axis 0 requires every clip to have the same
  # number of samples after resampling -- confirm the dataset guarantees this.
  X = tf.concat(inputs, 0)
  Y = tf.concat(targets, 0)
  return X, Y

def load_batch_data(load_batch_size, start_point, random_select=False,
                    mix_dir='/content/drive/MyDrive/songs/mix/',
                    vocal_dir='/content/drive/MyDrive/songs/vocal/'):
  """Pick a chunk of (mix, vocal) file pairs and load them as tensors.

  Runs eagerly on purpose. Under ``tf.function`` the directory listing and
  ``random.sample`` would execute only while tracing, so repeated calls with
  the same arguments would replay a frozen file list / frozen "random" choice.

  Files in both directories are sorted by name and paired positionally.

  Args:
    load_batch_size: Maximum number of pairs to load.
    start_point: Index of the first pair when ``random_select`` is False.
    random_select: When True, draw pairs at random instead of slicing.
    mix_dir: Directory holding the input (mixed) WAV files.
    vocal_dir: Directory holding the target (vocal) WAV files.

  Returns:
    ``(X, Y)`` as returned by ``load_a_batch_data``.
  """
  mixed_srcs = [path.join(mix_dir, name) for name in sorted(listdir(mix_dir))]
  talk_srcs = [path.join(vocal_dir, name) for name in sorted(listdir(vocal_dir))]

  # zip() stops at the shorter list, so derive every size from the pairs
  # themselves rather than from one directory alone.
  data_srcs = list(zip(mixed_srcs, talk_srcs))

  if random_select:
    # Clamp so a request larger than the dataset does not make sample() raise.
    data_srcs = random.sample(data_srcs, min(load_batch_size, len(data_srcs)))
  else:
    # Slicing already clamps the upper bound to the list length.
    data_srcs = data_srcs[start_point:start_point + load_batch_size]

  X, Y = load_a_batch_data(data_srcs)
  return X, Y

def train_model(model, data_size, batch_size, epoches, start_point):
  """Train ``model`` over ``data_size`` file pairs, loaded in chunks of up to 200.

  For each chunk the data is loaded with ``load_batch_data``, the model is
  fitted with a 20% validation split, and the full model is saved to Drive.
  The best weights (lowest ``val_loss``) are checkpointed separately.

  Args:
    model: Compiled Keras model to train.
    data_size: Total number of file pairs to train on.
    batch_size: Mini-batch size passed to ``model.fit``.
    epoches: Number of epochs to run on each loaded chunk.
    start_point: Index of the first file pair to use.
  """
  LOAD_BATCH_SIZE = 200
  # Ceiling division: round() uses banker's rounding and silently dropped a
  # trailing partial chunk (e.g. data_size=100 trained on nothing).
  n_chunks = int(-(-data_size // LOAD_BATCH_SIZE))

  checkpoint_filepath = '/content/drive/MyDrive/checkpoint/checkpoint.h5'
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_filepath,
      save_weights_only=True,
      monitor='val_loss',
      mode='min',
      save_best_only=True
      )

  total_time = 0
  for i in range(n_chunks):
    t_start = time.time()
    # The final chunk only covers what is left of data_size.
    chunk_size = int(min(LOAD_BATCH_SIZE, data_size - i * LOAD_BATCH_SIZE))
    print(f'Loading Data Batch {i + 1}/{n_chunks}')
    X, Y = load_batch_data(chunk_size, start_point, False)
    print(f'Data Batch {i + 1}/{n_chunks} has been loaded, ready to train. ')
    start_point += chunk_size
    model.fit(x=X, y=Y, batch_size=batch_size, epochs=epoches,
              validation_split=0.2, callbacks=[model_checkpoint_callback],
              shuffle=True)
    # Persist the whole model after every chunk so progress survives a
    # runtime disconnect.
    model.save('/content/drive/MyDrive/model/model.h5')
    t_per_loop = time.time() - t_start
    total_time += t_per_loop
    print('Time per Loop: ' + str(round(t_per_loop)))
  print('Total Time: ' + str(round(total_time)))

In [None]:
# Print the layer-by-layer architecture of the model built above.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, 1)      4           input_1[0][0]                    
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 16)     1616        batch_normalization[0][0]        
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 16)     25616       conv1d[0][0]                     
______________________________________________________________________________________________

In [None]:
# Kick off training; keyword arguments make each number's role explicit.
train_model(
    model,
    data_size=2000,
    batch_size=32,
    epoches=300,
    start_point=6000,
)