<a href="https://colab.research.google.com/github/k-washi/notebook/blob/master/segan_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GANによる音声強調

## 音声強調とは

音声信号は、他の音響信号の干渉やネットワーク障害など、様々な要因で劣化する。音声強調とは、劣化した音声(NoisyAudio)から目的の音声(Target)を取り出す処理である。

NoisyAudio = Noise + Target

[SEGAN: Speech Enhancement Generative Adversarial Network](https://arxiv.org/pdf/1703.09452.pdf)


## インストール

In [0]:
pip install tensorboardx



In [0]:
!unzip './drive/My Drive/data/speech_enhance/noisy_testset_16kHz.zip'

In [0]:
!unzip './drive/My Drive/data/speech_enhance/clean_testset_16kHz.zip'

In [0]:
# The Colab magic below would select TensorFlow 1.x; it is left disabled, so the
# runtime's default TensorFlow is used (the recorded output shows 2.2.0-rc3).
#%tensorflow_version 1.x
import tensorflow as tf
# Print the active TensorFlow version for reproducibility.
print(tf.__version__)

2.2.0-rc3


In [0]:

import os
from glob import glob
import numpy as np
from scipy.io.wavfile import read as wavread
import tensorboardX as tbx
from functools import partial

from keras.layers import Input, Reshape, Concatenate, Activation, concatenate
from keras.layers import Conv2D, Conv2DTranspose, Dense, UpSampling2D
from keras.layers.advanced_activations import PReLU #Leaky ReLU
#from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.initializers import RandomNormal, Ones
from keras.optimizers import RMSprop, Adam

from keras import backend as K 
from keras.layers import Layer, Lambda
from keras.layers.merge import Concatenate

Using TensorFlow backend.


## 定義

In [0]:
# Directories expected to contain the clean / noisy wav files; iterate_minibatches
# globs each of them for *.wav.
clean_data_path = '/content/clean_testset_16kHz'
noisy_data_path = '/content/noisy_testset_16kHz'

# Divisor used to scale int16 samples to roughly [-1, 1). 32768 is 2**15, i.e. the
# magnitude of the most negative int16 value (int16 spans -32768..32767).
MAX_WAV_VALUE = 32768.0
# Number of samples per training excerpt.
# NOTE(review): not referenced by the visible code; train() and
# iterate_minibatches() carry their own audio_length=16384 defaults.
AUDIO_LENGTH = 16384

## データの取得 & 整形

[Noisy speech database(Voice Bankコーパスに基づく雑音付加音声データセット、エディンバラ大学 DataShare)](https://datashare.is.ed.ac.uk/handle/10283/1942)
soxにより48kHz -> 16kHzにダウンサンプリング

訓練データ *_trainset_16kHz.zip
テストデータ *_testset_16kHz.zip

※簡単のため、テストデータで訓練を行う


In [0]:
# Check that a dataset directory exists
def confirm_dir(path):
  """Print whether `path` is an existing directory.

  Args:
    path: Directory path to check.

  Returns:
    None. The result is only reported on stdout.
  """
  if os.path.isdir(path):
    print(f"{path}が見つかりました")
    return
  print(f"{path}が見つかりません")

# Report whether each of the two dataset directories exists.
confirm_dir(clean_data_path)
confirm_dir(noisy_data_path)

/content/clean_testset_16kHzが見つかりました
/content/noisy_testset_16kHzが見つかりました


In [0]:
# minibatch iterator
def _crop_or_pad(signal, start, length):
  """Return `length` samples of the 1-D `signal` from `start`, zero-padded at the end if short."""
  window = signal[start: start + length]
  missing = length - len(window)
  if missing > 0:
    window = np.pad(window, (0, missing), mode='constant')
  return window


def iterate_minibatches(clean_data_path, noisy_data_path, audio_length=16384, batch_size=128,
                        max_wav_value=None):
  """Endlessly yield minibatches of aligned clean / noisy audio excerpts.

  The wav files of both directories are paired by sorted file name. The file order
  is reshuffled at the start of every pass; files left over after the last full
  batch of a pass are skipped for that pass.

  Args:
    clean_data_path: Directory containing the clean ``*.wav`` files.
    noisy_data_path: Directory containing the matching noisy ``*.wav`` files.
    audio_length: Number of samples per excerpt. Clips shorter than this are
      zero-padded at the end.
    batch_size: Number of clean/noisy pairs per batch.
    max_wav_value: Divisor applied to the raw samples. ``None`` (default) uses the
      module-level ``MAX_WAV_VALUE``.

  Yields:
    ``(train_data, cur_batch)``. ``train_data`` has shape
    ``(batch_size, 2, 1, audio_length, 1)``, with index 0 of axis 1 holding the clean
    excerpt and index 1 the noisy one. ``cur_batch`` is the index of the batch that
    the next iteration will produce (0 means a new pass starts).

  Raises:
    Exception: If the two directories hold a different number of wav files.
    ValueError: If ``batch_size`` is not positive or exceeds the number of pairs.
  """
  clean_filepaths = sorted(glob(os.path.join(clean_data_path, '*.wav')))
  noisy_filepaths = sorted(glob(os.path.join(noisy_data_path, '*.wav')))

  n_files = len(clean_filepaths)
  if n_files != len(noisy_filepaths):
    raise Exception("データセットがミスマッチです")

  # Fail early with a clear message instead of an IndexError / modulo-by-zero later.
  if batch_size < 1 or n_files < batch_size:
    raise ValueError(
        f"batch_size must be between 1 and the number of wav pairs ({n_files}), got {batch_size}")
  n_batches = n_files // batch_size

  if max_wav_value is None:
    max_wav_value = MAX_WAV_VALUE

  cur_batch = 0

  while True:
    # cur_batch wraps to 0 after the last full batch, which marks a new pass.
    if cur_batch == 0:
      ids = np.random.permutation(n_files)

    # Read each pair and cut the same random window out of both signals.
    train_data = []
    offset = cur_batch * batch_size
    for file_id in ids[offset: offset + batch_size]:
      # wavread returns (sampling rate, samples); only the samples are used.
      # Assumes mono files, i.e. 1-D sample arrays.
      _, clean_data = wavread(clean_filepaths[file_id])
      _, noisy_data = wavread(noisy_filepaths[file_id])

      # Use the common length so the window is valid for both signals.
      n_samples = min(len(clean_data), len(noisy_data))
      # randint excludes the upper bound, hence +1 to make the last window reachable;
      # clips no longer than audio_length always start at 0.
      start_id = np.random.randint(0, max(n_samples - audio_length, 0) + 1)

      clean_excerpt = _crop_or_pad(clean_data, start_id, audio_length)
      noisy_excerpt = _crop_or_pad(noisy_data, start_id, audio_length)

      # (audio_length,) -> (1, audio_length, 1): add height and channel axes.
      train_data.append([clean_excerpt[None, :, None], noisy_excerpt[None, :, None]])

    cur_batch = (cur_batch + 1) % n_batches
    train_data = np.array(train_data) / max_wav_value

    yield train_data, cur_batch

In [0]:
#test = iterate_minibatches(clean_data_path, noisy_data_path)


# Speech Enhancement GAN

In [0]:
# Kernel initializer shared by the convolution layers of the discriminator and the
# generator: normal distribution with mean 0 and standard deviation 0.02.
weight_init = RandomNormal(mean=0., stddev=0.02)

### 識別器の実装

In [0]:
def build_segen_discriminator(noisy_input_shape, clean_input_shape,
                              n_filters = [64, 128, 256, 512, 1024], kernel_size=(1, 31)):
  """Build the SEGAN discriminator.

  The candidate (clean or generated) signal and the noisy signal are concatenated
  along the channel axis, as in a conditional GAN, passed through a stack of
  strided convolutions and reduced to a single linear score.

  Args:
    noisy_input_shape: Shape of the noisy input without the batch axis.
    clean_input_shape: Shape of the clean / generated input without the batch axis.
    n_filters: Number of filters of each convolution layer, in order.
    kernel_size: Convolution kernel size.

  Returns:
    A Keras ``Model`` named "Discriminator" taking ``[noisy, clean]`` and returning
    one un-activated value per sample.
  """
  print("Start segen discriminator build")
  clean_input = Input(shape=clean_input_shape)
  noisy_input = Input(shape=noisy_input_shape)

  # Same idea as a conditional GAN: condition by concatenating on the channel axis.
  x = Concatenate(-1)([clean_input, noisy_input])

  # One strided convolution block per entry of n_filters.
  for n_out in n_filters:
    # Time-series data: long kernel and stride along the width (time) axis only.
    x = Conv2D(filters=n_out, kernel_size=kernel_size, strides=(1, 4),
               padding='same', use_bias=True, kernel_initializer=weight_init)(x)
    # NOTE(review): epsilon=1e-5 / momentum=0.1 look like PyTorch defaults; Keras
    # defines momentum as the weight of the *old* moving average, so 0.9 may have
    # been intended. Left unchanged; confirm.
    x = BatchNormalization(epsilon=1e-5, momentum=0.1)(x)
    x = PReLU()(x)

  print(x.shape)
  # Flatten instead of a hard-coded Reshape((16384,)) so other input lengths or
  # filter lists also work; for the default configuration the result is identical.
  x = Flatten()(x)

  x = Dense(256, activation=None, use_bias=True)(x)
  x = PReLU()(x)
  x = Dense(128, activation=None, use_bias=True)(x)
  x = PReLU()(x)

  # Least-squares GAN objective => linear output activation.
  x = Dense(1, activation=None, use_bias=True)(x)

  # `outputs` is the documented keyword; `output` is only accepted through
  # legacy argument handling.
  model = Model(inputs=[noisy_input, clean_input], outputs=x, name="Discriminator")

  print("Build SEGEN Discriminator")
  return model

### 生成器の実装

U-Netの構成を流用

[カスタムKerasレイヤーを作成](https://keras.io/ja/layers/writing-your-own-keras-layers/)

In [0]:
# https://github.com/keras-team/keras/issues/7736
# The trainable tensor is stored as `a_weights` because `Layer.weights` cannot be
# overridden.
class ScaleLayer(Layer):
    """Trainable element-wise scaling of the input.

    The layer multiplies its input by a weight tensor of shape
    ``(input_shape[1], output_dim)`` that is initialised to ones, so it starts as
    the identity.

    Args:
        output_dim: Size of the last axis of the scale tensor. Assumed to equal the
            last (channel) axis of the input so that broadcasting applies one scale
            per channel -- TODO confirm against the intended use.
    """

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(ScaleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the scale weights."""
        self.a_weights = self.add_weight(
            name='scale_weights', shape=(input_shape[1], self.output_dim),
            initializer=Ones(), trainable=True)
        super(ScaleLayer, self).build(input_shape)

    def call(self, x):
        """Scale `x` element-wise (broadcast multiply, not a matrix product)."""
        return x * self.a_weights

    def compute_output_shape(self, input_shape):
        # call() is an element-wise multiply, so the input shape is preserved.
        # Reporting (batch, output_dim) would describe a matrix product instead.
        return input_shape

    def get_config(self):
        """Include `output_dim` so the layer can be re-created from its config."""
        config = super(ScaleLayer, self).get_config()
        config['output_dim'] = self.output_dim
        return config

In [0]:
def build_segan_generator(noisy_input_shape, z_input_shape,
                          n_filters=[64, 128, 256, 512, 1024], kernel_size=(1, 31), use_upsampling=False):
  """Build the U-Net style SEGAN generator.

  The encoder applies one strided convolution + PReLU per entry of `n_filters` and
  keeps every activation as a skip connection. The decoder starts from the latent
  input, concatenates the skip connection of the same depth on the channel axis
  and upsamples by 4 along the width axis at each level. The last level has one
  output channel and a tanh activation.

  Args:
    noisy_input_shape: Shape of the noisy input without the batch axis.
    z_input_shape: Shape of the latent input without the batch axis.
    n_filters: Encoder filter counts, in order.
    kernel_size: Convolution kernel size.
    use_upsampling: If True, upsample with UpSampling2D followed by Conv2D;
      otherwise use Conv2DTranspose.

  Returns:
    A Keras ``Model`` named 'Generator' taking ``[noisy, z]``.
  """
  noisy_input = Input(shape=noisy_input_shape)
  z_input = Input(shape=z_input_shape)

  print("Create Unet based model")
  encoder_outputs = []
  features = noisy_input
  for n_out in n_filters:
    features = Conv2D(filters=n_out, kernel_size=kernel_size, strides=(1, 4),
                      padding='same', use_bias=True, kernel_initializer=weight_init)(features)
    features = PReLU()(features)
    encoder_outputs.append(features)

  print("add skip connection layer")
  # Decoder widths mirror the encoder shifted by one level, ending in 1 channel.
  decoder_filters = [1] + n_filters[:-1]
  features = z_input

  for level in reversed(range(len(decoder_filters))):
    skip = encoder_outputs[level]
    # Merge the skip connection on the channel axis.
    features = Concatenate(3)([features, skip])

    if use_upsampling:
      features = UpSampling2D(size=(1, 4))(features)
      features = Conv2D(filters=decoder_filters[level], kernel_size=kernel_size, strides=(1, 1),
                        padding='same', kernel_initializer=weight_init, use_bias=True)(features)
    else:
      features = Conv2DTranspose(filters=decoder_filters[level], kernel_size=kernel_size,
                                 strides=(1, 4), padding='same',
                                 kernel_initializer=weight_init)(features)

    # tanh on the output level, PReLU everywhere else.
    activation = Activation('tanh') if level == 0 else PReLU()
    features = activation(features)

  print(features.shape)
  model = Model(inputs=[noisy_input, z_input], outputs=features, name='Generator')
  print("Build SEGEN Generator")

  return model
  


# モデル学習のためのロギング

In [0]:
def log_audio(audios, logger, name="synthesis", sr=16000):
    """Write every clip in `audios` to `logger` as an audio summary.

    Args:
        audios: Indexable sequence of audio clips.
        logger: Object providing ``add_audio(tag, data, step, sample_rate=...)``.
        name: Tag prefix; the clip index is appended as ``<name>_<index>``.
        sr: Sample rate passed to the logger.

    All clips are written at global step 0.
    """
    n_clips = len(audios)
    for idx in range(n_clips):
        tag = '{}_{}'.format(name, idx)
        clip = audios[idx]
        logger.add_audio(tag, clip, 0, sample_rate=sr)


def log_losses(loss_d, loss_g, iteration, logger):
    """Write the discriminator and generator losses of one iteration as scalars.

    Args:
        loss_d: Indexable with three entries: total, real and fake discriminator loss.
        loss_g: Indexable with three entries: total, adversarial and reconstruction
            generator loss.
        iteration: Global step for all six scalars.
        logger: Object providing ``add_scalar(tag, value, step)``.
    """
    # (tag, source, position) in the order the scalars are written.
    series = (
        ("losses_d", loss_d, 0),
        ("losses_d_real", loss_d, 1),
        ("losses_d_fake", loss_d, 2),
        ("losses_g", loss_g, 0),
        ("losses_g_fake", loss_g, 1),
        ("losses_g_reconstruction", loss_g, 2),
    )
    for tag, losses, position in series:
        logger.add_scalar(tag, losses[position], iteration)

# モデルの学習

In [0]:
# List the GPUs TensorFlow can see; the list is empty when no GPU is available.
physical_devices = tf.config.list_physical_devices('GPU') 
print(physical_devices)
#https://stackoverflow.com/questions/57062456/function-call-stack-keras-scratch-graph-error
if physical_devices:
    try:
        # Restrict TensorFlow to the first listed GPU (index 0) and enable memory
        # growth on it.
        tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')
        tf.config.experimental.set_memory_growth(physical_devices[0], True) 
        # Currently, memory growth needs to be the same across GPUs
            
        print("Set growth")
        #logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        #print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print("Error")
        print(e)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Set growth


In [0]:
def train(ndf=64, ngf=64, use_upsampling=False, lr_d=5e-5, lr_g=5e-5, reconstruction_weight=100,
          audio_length=16384, n_cp_audio=8, n_iter=int(1e6), batch_size=256, out_dir='data'):
  """Train the SEGAN generator / discriminator pair.

  The discriminator is trained with mean squared error against 1 for
  (noisy, clean) pairs and 0 for (noisy, generated) pairs. The generator is trained
  through the stacked model with mean squared error against 1 plus a mean absolute
  error between its output and the clean excerpt, weighted by
  `reconstruction_weight`.

  Args:
    ndf, ngf: Intended number of discriminator / generator filters. Currently
      unused: both builders are called with their own default filter lists.
    use_upsampling: Use upsampling + convolution instead of transposed
      convolutions in the generator.
    lr_d, lr_g: Adam learning rates of the discriminator and of the stacked
      generator model.
    reconstruction_weight: Weight of the reconstruction loss relative to the
      adversarial loss.
    audio_length: Samples per excerpt. Must be a multiple of 4**5 = 1024 so that
      the five stride-4 levels of the default generator line up with its skip
      connections. Any size assumption inside the builder functions still applies.
    n_cp_audio: Number of fixed excerpts logged as audio for monitoring.
    n_iter: Number of training iterations.
    batch_size: Minibatch size.
    out_dir: Directory the TensorBoard logs are written to.

  Raises:
    ValueError: If `audio_length` is not a positive multiple of 1024.
  """
  # Five stride-4 encoder levels with 1024 filters at the bottleneck: these match
  # the default n_filters of build_segan_generator.
  total_stride = 4 ** 5
  bottleneck_channels = 1024
  if audio_length < total_stride or audio_length % total_stride != 0:
    raise ValueError(
        f"audio_length must be a positive multiple of {total_stride}, got {audio_length}")

  logger = tbx.SummaryWriter(out_dir)
  try:
    data_iterator = iterate_minibatches(clean_data_path, noisy_data_path, audio_length, batch_size)

    # Derived from audio_length instead of a hard-coded 16384 so the model inputs
    # always match what the iterator yields.
    clean_input_shape = (1, audio_length, 1)
    noisy_input_shape = (1, audio_length, 1)
    z_input_shape = (1, audio_length // total_stride, bottleneck_channels)

    # Instantiate the two networks.
    print("Build Model")
    D = build_segen_discriminator(noisy_input_shape, clean_input_shape)
    print("D build")
    G = build_segan_generator(noisy_input_shape, z_input_shape, use_upsampling=use_upsampling)
    print("G build")

    z_input = Input(shape=z_input_shape)
    # Layer names now match the tensors they label (they were swapped before).
    noisy_input = Input(shape=noisy_input_shape, name="main_input_noisy")
    clean_input = Input(shape=clean_input_shape, name="main_input_clean")

    print("Create IO")
    G_out = G([noisy_input, z_input])

    print("Denoise")

    D_out = D([noisy_input, G_out])

    print("Compile Model")
    D.compile(optimizer=Adam(learning_rate=lr_d), loss='mean_squared_error')
    print("Compiled D")

    # Freeze D before compiling the stacked model so that only G is updated
    # through D_of_G.
    D.trainable = False

    print("Create Fa")
    print(D_out, G_out)
    D_of_G = Model(inputs=[noisy_input, z_input, clean_input], outputs=[D_out, G_out])

    print(D_out.shape, G_out.shape)
    print("Compiled G")

    # Adversarial loss on D_out, L1 reconstruction loss between G_out and the
    # clean target passed as the second training target.
    D_of_G.compile(optimizer=Adam(learning_rate=lr_g),
                   loss=['mean_squared_error', 'mean_absolute_error'],
                   loss_weights=[1, reconstruction_weight])

    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = np.zeros((batch_size, 1), dtype=np.float32)

    # Fixed latent vectors and excerpts used to monitor progress.
    z_fixed = np.random.normal(0, 1, size=(n_cp_audio,) + z_input_shape)
    data_batch, cur_batch = next(data_iterator)
    clean_fixed = data_batch[:n_cp_audio, 0]
    noisy_fixed = data_batch[:n_cp_audio, 1]
    log_audio(clean_fixed[:, 0, :, 0], logger, 'clean')
    log_audio(noisy_fixed[:, 0, :, 0], logger, 'noisy')

    for i in range(n_iter):
      # Log the denoised monitoring excerpts whenever the iterator reports that
      # batch 1 is next.
      if cur_batch == 1:
        G.trainable = False
        fake_audio = G.predict([noisy_fixed, z_fixed])
        log_audio(fake_audio[:n_cp_audio, 0, :, 0], logger, 'denoised')

      # --- discriminator step ---
      # NOTE(review): trainable flags toggled after compile() may not affect the
      # already compiled models; kept as in the original -- confirm against the
      # Keras version in use.
      D.trainable = True
      G.trainable = False

      z = np.random.normal(0, 1, size=(batch_size, ) + z_input_shape)
      data_batch, cur_batch = next(data_iterator)
      clean_batch = data_batch[:, 0]
      noisy_batch = data_batch[:, 1]
      fake_batch = G.predict([noisy_batch, z])
      loss_real = D.train_on_batch([noisy_batch, clean_batch], ones)
      loss_fake = D.train_on_batch([noisy_batch, fake_batch], zeros)
      loss_d = [loss_real + loss_fake, loss_real, loss_fake]

      # --- generator step (on a fresh batch) ---
      D.trainable = False
      G.trainable = True
      data_batch, cur_batch = next(data_iterator)
      clean_batch = data_batch[:, 0]
      noisy_batch = data_batch[:, 1]
      z = np.random.normal(0, 1, size=(batch_size, ) + z_input_shape)
      loss_g = D_of_G.train_on_batch([noisy_batch, z, clean_batch], [ones, clean_batch])

      # Output range on the fixed excerpts, printed as a quick health check.
      fake_audio = G.predict([noisy_fixed, z_fixed])
      log_losses(loss_d, loss_g, i, logger)
      print("nxt_batch", cur_batch, "min", fake_audio.min(), "max", fake_audio.max())
  finally:
    # Flush and release the event file even if training is interrupted.
    logger.close()



In [0]:
# Short demonstration run: 500 iterations instead of the default int(1e6).
train(n_iter=500)

Build Model
Start segen discriminator build
(None, 1, 16, 1024)
Build SEGEN Discriminator
D build
Create Unet based model
add skip connection layer
(None, None, None, 1)
Build SEGEN Generator
G build
Create IO




Denoise
Compile Model
Compiled D
Create Fa
Tensor("Discriminator_1/dense_6/BiasAdd:0", shape=(None, 1), dtype=float32) Tensor("Generator_1/activation_2/Tanh:0", shape=(None, None, None, 1), dtype=float32)
(None, 1) (None, None, None, 1)
Compiled G
nxt_batch 0 min -0.48481172 max 0.49202752
nxt_batch 2 min -0.45815352 max 0.42805457
nxt_batch 1 min -0.4280539 max 0.36807194
nxt_batch 0 min -0.3953347 max 0.3635697
nxt_batch 2 min -0.36276612 max 0.36627218
nxt_batch 1 min -0.33817253 max 0.35781533
nxt_batch 0 min -0.3291124 max 0.3442443
nxt_batch 2 min -0.31024712 max 0.32682407
nxt_batch 1 min -0.2840935 max 0.30551782
nxt_batch 0 min -0.25518474 max 0.28293547
nxt_batch 2 min -0.2322183 max 0.2605797
nxt_batch 1 min -0.22377825 max 0.23673826
nxt_batch 0 min -0.21318962 max 0.21328878
nxt_batch 2 min -0.20032485 max 0.19538715
nxt_batch 1 min -0.18513227 max 0.1821143
nxt_batch 0 min -0.1756675 max 0.16911241
nxt_batch 2 min -0.16364497 max 0.15748526
nxt_batch 1 min -0.15226272 max

In [0]:
# Print the NVIDIA driver's status report for the attached GPU(s).
!nvidia-smi

In [0]:
# Load the TensorBoard notebook extension and open it on ./data, which is the
# default out_dir that train() writes its logs to.
%load_ext tensorboard
%tensorboard --logdir ./data

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 1007), started 1:47:36 ago. (Use '!kill 1007' to kill it.)

<IPython.core.display.Javascript object>