In [1]:
# Mount Google Drive into the Colab VM at /content/drive (Colab-only API).
# force_remount=True re-mounts even if a previous mount is still present.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [0]:
# Root of the dataset. NOTE: this is relative to the current working directory,
# while Drive is mounted at the absolute path /content/drive, so it only
# resolves when the notebook's working directory is /content.
base_path = './drive/My Drive/Colab Notebooks/Speech Denoising/LotOfData/'

# Audio splits: training, validation and test recordings.
base_path_train, base_path_val, base_path_test = (
    base_path + split for split in ('tr/', 'v/', 'te/'))

# Derived artefacts: cached features, denoised output and model checkpoints.
base_path_pickle, base_path_result, base_path_model = (
    base_path + folder for folder in ('pickle/', 'result/', 'model/'))

In [3]:
# NOTE(review): the librosa version is not pinned. The `save` helper in this
# notebook calls `librosa.output.write_wav`; confirm that the release installed
# here still provides the `librosa.output` module.
!pip install librosa



In [0]:
import librosa
import pickle

In [0]:
import os
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
from sklearn.model_selection import train_test_split
import tensorflow as tf

In [0]:
# Placeholders for the training data; they are overwritten by the loading
# cell further down. Each name gets its own, independent empty list.
trx, trs, trn = [], [], []
trx_val, trs_val, trn_val = [], [], []
target = []
trx_len = []

# Shape of the zero-padded array built by pad_zeros: (max_length, max_width).
max_length = 200
max_width = 513

In [0]:
def get_stft(file_path):
  """Read an audio file and return its complex STFT, transposed.

  The file is loaded with ``sr=None`` (no resampling); the sample rate that
  ``librosa.load`` returns is discarded. The STFT uses ``n_fft=1024`` and
  ``hop_length=512`` and the result of ``librosa.stft`` is transposed before
  being returned, so the first axis of the returned array is the last axis of
  librosa's output.
  """
  samples, _ = librosa.load(file_path, sr=None)
  spectrogram = librosa.stft(samples, n_fft=1024, hop_length=512)
  return spectrogram.T

In [0]:
def load_from_directory(directory, file_prefix, all_required, stft_fn=None):
  """Load STFTs for every noisy ``<file_prefix>x<id>.wav`` file in ``directory``.

  Files are visited in sorted name order. The id of a recording is the part of
  the file name between the ``<file_prefix>x`` prefix and the ``.wav`` suffix.

  Args:
    directory: folder to scan.
    file_prefix: split prefix; ``'x'``, ``'s'`` and ``'n'`` are appended to it
      to form the noisy, clean and noise file prefixes respectively.
    all_required: when true, a recording is used only if BOTH its clean
      (``<file_prefix>s<id>.wav``) and noise (``<file_prefix>n<id>.wav``)
      companions exist; recordings missing either one are skipped.
    stft_fn: optional callable ``path -> array`` used to load one file.
      Defaults to ``get_stft``.

  Returns:
    ``(noisy, clean, noise, ids)`` when ``all_required`` is true, otherwise
    ``(noisy, ids)``. The first three are Python lists of per-file results;
    ``ids`` is a NumPy array of id strings aligned with those lists.
  """
  if stft_fn is None:
    stft_fn = get_stft

  file_prefix_dirty = file_prefix + 'x'
  file_prefix_clean = file_prefix + 's'
  file_prefix_noise = file_prefix + 'n'

  lfd_x, lfd_s, lfd_n, lfd_id = [], [], [], []

  n = 0
  for file in sorted(os.listdir(directory)):
    # consider only .wav files starting with file_prefix_dirty
    if not (file.endswith('.wav') and file.startswith(file_prefix_dirty)):
      continue

    # progress indicator: one line every 50 candidate files
    if n % 50 == 0:
      print(n, flush=True)
    n += 1

    file_id = file[len(file_prefix_dirty):-len('.wav')]
    dirty_file_path = os.path.join(directory, file)

    if all_required:
      # Build companion names from the id rather than str.replace(), which
      # would also rewrite any later occurrence of the prefix inside the id.
      clean_file_path = os.path.join(
          directory, file_prefix_clean + file_id + '.wav')
      noise_file_path = os.path.join(
          directory, file_prefix_noise + file_id + '.wav')
      # Skip unless both the clean AND the noise recording are present
      # (the noise check previously re-tested the clean path by mistake).
      if not (os.path.exists(clean_file_path)
              and os.path.exists(noise_file_path)):
        continue

    lfd_x.append(stft_fn(dirty_file_path))
    lfd_id.append(file_id)

    if all_required:
      lfd_s.append(stft_fn(clean_file_path))
      lfd_n.append(stft_fn(noise_file_path))

  if all_required:
    return lfd_x, lfd_s, lfd_n, np.array(lfd_id)
  return lfd_x, np.array(lfd_id)

In [0]:
def pad_zeros(stft, length=None, width=None):
  """Return the magnitude of ``stft`` as a fixed-size, zero-padded 2-D array.

  The element-wise absolute value of ``stft`` is copied into the top-left
  corner of a zero array of shape ``(length, width)``. Input that is larger
  than the target along either axis is truncated to fit instead of raising a
  broadcasting error.

  Args:
    stft: 2-D array (real or complex).
    length: number of rows in the result; defaults to the notebook-level
      ``max_length``.
    width: number of columns in the result; defaults to the notebook-level
      ``max_width``.

  Returns:
    A float array of shape ``(length, width)``.
  """
  if length is None:
    length = max_length
  if width is None:
    width = max_width

  # Clip first so an over-long/over-wide input cannot overflow the target.
  magnitude = np.abs(stft)[:length, :width]

  stft_val = np.zeros((length, width))
  stft_val[:magnitude.shape[0], :magnitude.shape[1]] = magnitude
  return stft_val

def get_abs(stft_list):
  """Apply ``pad_zeros`` to every item of ``stft_list`` and stack the results
  into a single NumPy array (one entry per input item, in the same order)."""
  padded = list(map(pad_zeros, stft_list))
  return np.array(padded)

def get_len(stft_list):
  """Return a NumPy array holding ``len(item)`` for each item of ``stft_list``,
  in order."""
  lengths = list(map(len, stft_list))
  return np.array(lengths)

In [11]:
# check if pickle files exist
# Load the training STFTs, preferring the pickle cache over re-reading audio.
tr_pickle_path = base_path_pickle + 'tr.pickle'

if os.path.exists(tr_pickle_path):
  print('loading pickle...')
  # SECURITY: pickle.load can execute arbitrary code; only load cache files
  # that this notebook wrote itself.
  with open(tr_pickle_path, 'rb') as f:
    trx, trs, trn, trx_id = pickle.load(f)
  print('loading pickle complete.')
else:
  print('loading from directory...')
  trx, trs, trn, trx_id = load_from_directory(base_path_train, 'tr', True)
  print('loading from directory complete. saving...')
  # Create the cache folder if needed so the expensive load above is not
  # thrown away by a FileNotFoundError on the first run.
  os.makedirs(base_path_pickle, exist_ok=True)
  with open(tr_pickle_path, 'wb') as f:
    pickle.dump([trx, trs, trn, trx_id], f)
  print('saving complete.')

# Zero-padded magnitude arrays for the noisy, clean and noise recordings.
trx_val = get_abs(trx)
trs_val = get_abs(trs)
trn_val = get_abs(trn)
# Binary target: 1 where the clean magnitude exceeds the noise magnitude.
target_tr = 1 * (trs_val > trn_val)

loading pickle...
loading pickle complete.


In [12]:
# Number of training recordings loaded.
len(trx)

1200

In [13]:
# Sanity check: the padded noisy/clean/noise arrays and the target must all share one shape.
trx_val.shape, trs_val.shape, trn_val.shape, target_tr.shape

((1200, 200, 513), (1200, 200, 513), (1200, 200, 513), (1200, 200, 513))

In [0]:
# Optimiser settings.
learning_rate, num_epochs = 0.001, 100

# Batching and layer sizes. num_features is the size of the last axis of the
# placeholders defined in the next cell; num_hidden is the LSTM cell size.
batch_size = 20
num_features = 513
num_hidden = 256

# Number of training examples (first axis of the padded training array).
num_samples_tr = len(trx_val)

In [0]:
# Graph inputs, both shaped (batch, max_length, num_features) with a free batch size.
# X is fed the padded noisy magnitudes; Y is fed the binary target
# (see the feed_dict usages further down the notebook).
X = tf.placeholder(tf.float32, [None, max_length, num_features])
Y = tf.placeholder(tf.float32, [None, max_length, num_features])

In [16]:
# Model: one LSTM layer (num_hidden units) unrolled over the sequence, followed by
# a sigmoid dense layer applied to the RNN output.
# NOTE(review): DropoutWrapper is constructed without any keep-probability
# arguments, so presumably no dropout is actually applied — confirm against the
# TF 1.x DropoutWrapper defaults.
# NOTE: the variables created here are the ones `saver.restore` later loads from
# the checkpoint, so changing this cell can invalidate saved checkpoints.
cell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_hidden))
output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
# units=513 is hard-coded and equals num_features; the sigmoid keeps outputs in (0, 1).
dense_1 = tf.layers.Dense(units=513, activation=tf.nn.sigmoid)(output)

Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Colocations handled automatically by placer.


In [17]:
# Display the output tensor to check its static shape.
dense_1

<tf.Tensor 'dense/Sigmoid:0' shape=(?, 200, 513) dtype=float32>

In [18]:
# Mean-squared error between the network output and the target Y.
# NOTE(review): no `weights` argument is passed, so the error is computed over
# every element of the padded tensors, zero-padded frames included — the loss is
# NOT restricted to the valid (unpadded) part of each sequence.
loss = tf.losses.mean_squared_error(labels=Y, predictions=dense_1)
# Single Adam update step minimising the loss.
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=loss)
# NOTE: `init` is not referenced by the visible cells below; the session cell
# builds a fresh tf.global_variables_initializer() instead.
init = tf.global_variables_initializer()

Instructions for updating:
Use tf.cast instead.


In [0]:
# Session used by every later cell, and the Saver used to restore/save checkpoints.
# The session is never closed in the visible cells.
sess = tf.Session()
saver = tf.train.Saver()

In [20]:
# Initialise all variables, then overwrite them with the values stored in the
# 'model_95.ckpt' checkpoint.
sess.run(tf.global_variables_initializer())

saver.restore(sess, base_path_model + 'model_95.ckpt')

# Disabled training loop, kept for reference. It runs mini-batch training for
# num_epochs epochs and saves 'model_<epoch>.ckpt' every 5th epoch; presumably
# this is how the checkpoint restored above was produced.
# for epoch in range(num_epochs):
#   loss_val = 0
#   for i in range(0, num_samples_tr, batch_size):
#     start_idx = i
#     end_idx = min(i + batch_size, num_samples_tr)
    
#     batch_x = trx_val[start_idx:end_idx]
#     batch_y = target_tr[start_idx:end_idx]
    
#     _, lv = sess.run([train, loss], feed_dict={X: batch_x, Y: batch_y})
#     loss_val += lv
    
#   if epoch % 5 == 0:
#     print(epoch, loss_val)
#     saver.save(sess, base_path_model + 'model_' + str(epoch) + '.ckpt')

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./drive/My Drive/Colab Notebooks/Speech Denoising/LotOfData/model/model_95.ckpt


In [0]:
def snr(dirty, clean):
  """Signal-to-noise ratio in decibels of ``dirty`` measured against ``clean``.

  Computed as ``10 * log10(sum(clean^2) / sum((clean - dirty)^2))``.
  """
  signal_energy = np.sum(np.square(clean))
  error_energy = np.sum(np.square(clean - dirty))
  return 10 * np.log10(signal_energy / error_energy)

In [0]:
def save(cleaned, filename, sr=16000, hop_length=512):
  """Invert an STFT back to a waveform and write it to ``filename`` as WAV.

  Args:
    cleaned: STFT array; it is transposed before being passed to
      ``librosa.istft``, i.e. it is expected in the same layout that
      ``get_stft`` returns.
    filename: destination path of the WAV file.
    sr: sample rate written to the file. Defaults to 16000, the value that
      was previously hard-coded. ``get_stft`` loads audio at its native rate,
      so pass the real rate if the recordings are not 16 kHz.
    hop_length: hop length used for the inverse STFT. Defaults to 512, which
      matches the hop length used in ``get_stft``.
  """
  sh_test = librosa.istft(cleaned.T, hop_length=hop_length)

  # Save to a file.
  # NOTE(review): `librosa.output` was removed in librosa 0.8; this call needs
  # an older librosa release — confirm the installed version.
  librosa.output.write_wav(filename, sh_test, sr)

In [23]:
# check if pickle files exist
# Load the validation STFTs, preferring the pickle cache over re-reading audio.
v_pickle_path = base_path_pickle + 'v.pickle'

if os.path.exists(v_pickle_path):
  print('loading pickle...')
  # SECURITY: pickle.load can execute arbitrary code; only load cache files
  # that this notebook wrote itself.
  with open(v_pickle_path, 'rb') as f:
    vx, vs, vn, vx_id = pickle.load(f)
  print('loading pickle complete.')
else:
  print('loading from directory...')
  vx, vs, vn, vx_id = load_from_directory(base_path_val, 'v', True)
  print('loading from directory complete. saving...')
  # Create the cache folder if needed so the expensive load above is not
  # thrown away by a FileNotFoundError on the first run.
  os.makedirs(base_path_pickle, exist_ok=True)
  with open(v_pickle_path, 'wb') as f:
    pickle.dump([vx, vs, vn, vx_id], f)
  print('saving complete.')

# Zero-padded magnitude arrays for the noisy, clean and noise recordings.
vx_val = get_abs(vx)
vs_val = get_abs(vs)
vn_val = get_abs(vn)
# Original (unpadded) length of each noisy recording's STFT.
vx_len = get_len(vx)
# Binary target: 1 where the clean magnitude exceeds the noise magnitude.
target_v = 1 * (vs_val > vn_val)

loading pickle...
loading pickle complete.


In [0]:
# Validation SNR: run the network on the padded noisy magnitudes, apply the
# predicted mask to each noisy STFT, and accumulate signal / error energy over
# the whole validation set before taking a single log ratio.
num_samples_v = len(vx)
total_snr = 0  # not used by the visible cells; kept in case later cells read it
num, den = 0, 0

for i in range(0, num_samples_v, batch_size):
  start_idx = i
  end_idx = min(i + batch_size, num_samples_v)
  batch_x = vx_val[start_idx:end_idx]

  m_pred = sess.run([dense_1], feed_dict={X: batch_x})
  for j in range(start_idx, end_idx):
    # The network only sees max_length frames, so never use more than that
    # from the original STFT; otherwise `x * m` would have mismatched shapes.
    n_frames = min(vx_len[j], max_length)
    x = vx[j][:n_frames]
    s_val = vs_val[j][:n_frames, :]
    m = m_pred[0][j - start_idx][:n_frames, :]

    cleaned = x * m
    num += np.sum(np.square(s_val))
    den += np.sum(np.square(s_val - np.abs(cleaned)))

snr_val = 10 * np.log10(num/den)

In [25]:
# Aggregate validation SNR in dB (10 * log10 of total signal energy over total error energy).
snr_val

10.595921530168525

In [26]:
# Load the test split. Only the noisy recordings are read (all_required=False),
# and, unlike the other splits, nothing is cached to pickle.
print('loading from directory...')
tex, tex_id = load_from_directory(base_path_test, 'te', all_required=False)
print('loading from directory complete.')

# Padded magnitudes for the network, plus each recording's original length.
tex_val, tex_len = get_abs(tex), get_len(tex)

loading from directory...
0
50
100
150
200
250
300
350
loading from directory complete.


In [0]:
# Denoise the test set: predict a mask per recording, apply it to the noisy
# complex STFT, and write the reconstructed audio to the result folder.
num_samples_te = len(tex)

# Make sure the output folder exists before the first write.
os.makedirs(base_path_result, exist_ok=True)

for i in range(0, num_samples_te, batch_size):
  start_idx = i
  end_idx = min(i + batch_size, num_samples_te)

  batch_x = tex_val[start_idx:end_idx]

  m_pred = sess.run([dense_1], feed_dict={X: batch_x})
  for j in range(start_idx, end_idx):
    # The network only sees max_length frames, so never use more than that
    # from the original STFT; otherwise `x * m` would have mismatched shapes.
    n_frames = min(tex_len[j], max_length)
    x = tex[j][:n_frames]
    m = m_pred[0][j - start_idx][:n_frames, :]

    cleaned = x * m
    fname = base_path_result + 'cleaned' + tex_id[j] + '.wav'
    save(cleaned, fname)