# Spectrogram Channels U-net

* SPECTROGRAM-CHANNELS U-NET: A SOURCE SEPARATION MODEL VIEWING EACH CHANNEL AS THE SPECTROGRAM OF EACH SOURCE, [arXiv:1810.11520](https://arxiv.org/abs/1810.11520)
  * Jaehoon Oh∗, Duyeon Kim∗, Se-Young Yun

## Import modules

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import collections

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
#import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
tf.enable_eager_execution()

tf.logging.set_verbosity(tf.logging.INFO)

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
# Training flags (hyperparameter configuration).
model_name = 'spectrogram_unet'
train_dir = 'train/{}/exp1/'.format(model_name)  # checkpoints are written here

max_epochs = 1500
save_model_epochs = 20   # save a checkpoint every this many epochs
print_steps = 1          # log progress every this many global steps

batch_size = 8
learning_rate = 2e-4

N = 128          # number of samples in train_dataset
BUFFER_SIZE = N  # shuffle buffer covers the whole training set

## Set up dataset with `tf.data`

In [3]:
data_path = './datasets/tfrecords/'


def _list_tfrecords(split):
  """Return the sorted paths of every file containing 'tfrecord' in data_path/<split>.

  os.listdir returns entries in arbitrary order; sorting makes the shard order
  (and therefore the unshuffled record order) reproducible across machines/runs.
  """
  split_dir = os.path.join(data_path, split)
  return sorted(os.path.join(split_dir, name)
                for name in os.listdir(split_dir) if 'tfrecord' in name)


train_data_filenames = _list_tfrecords('train')
for name in train_data_filenames:
  print(name)
print('--------')

test_data_filenames = _list_tfrecords('test')
for name in test_data_filenames:
  print(name)

./datasets/tfrecords/train/wav_train_00010-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00003-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00012-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00023-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00017-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00006-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00013-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00016-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00015-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00022-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00001-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00021-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00004-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00020-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00011-of-00025.tfrecord
./datasets/tfrecords/train/wav_train_00005-of-00025.tfrecord
./datasets/tfrecords/tra

In [4]:
def _parse_function(example_proto):
  """Decode one serialized tf.train.Example into a multi-source stereo waveform.

  Args:
    example_proto: scalar string tensor holding one serialized Example.

  Returns:
    wav_concat: float32 tensor of shape [6, 2, time_step]; the 6 sources are
      [mixtures, vocals, drums, basses, others, accompaniments] and the 2
      channels are [left, right].
    time_step: int32 scalar, the number of samples per channel.
  """
  num_sources = 6   # mixtures, vocals, drums, basses, others, accompaniments
  num_channels = 2  # left, right

  # track_number / split_number are parsed but not used by this pipeline.
  feature_spec = {
      'wav_concat': tf.io.FixedLenFeature([], tf.string, default_value=""),
      'time_step': tf.io.FixedLenFeature([], tf.int64, default_value=0),
      'track_number': tf.io.FixedLenFeature([], tf.int64, default_value=0),
      'split_number': tf.io.FixedLenFeature([], tf.int64, default_value=0),
  }
  example = tf.io.parse_single_example(example_proto, feature_spec)

  raw_samples = tf.io.decode_raw(example["wav_concat"], out_type=tf.float32)
  time_step = tf.cast(example["time_step"], dtype=tf.int32)
  wav_concat = tf.reshape(raw_samples, shape=[num_sources, num_channels, time_step])
  return wav_concat, time_step

In [5]:
def _augmentation_function(wav_concat, time_step, target_time_step=131072):
  """Randomly crop a fixed-length window from a multi-source waveform (data augmentation).

  With frame_step (hop size) 512 and pad_end=True in the STFT, a 131072-sample
  crop yields 131072 / 512 = 256 frames, i.e. a (256, 1024) spectrogram per
  channel; 131072 / 44100 is about 2.97 sec of audio.

  Args:
    wav_concat: float32 tensor [num_sources, num_channels, time_step].
    time_step: int32 scalar, number of samples along the last axis of wav_concat.
    target_time_step: crop length in samples.

  Returns:
    float32 tensor [num_sources, num_channels, target_time_step]. Clips shorter
    than target_time_step are zero-padded at the end rather than being sliced
    with a negative start index.
  """
  # Zero-pad short clips so every example has at least target_time_step samples.
  pad_amount = tf.maximum(target_time_step - time_step, 0)
  wav_concat = tf.pad(wav_concat, [[0, 0], [0, 0], [0, pad_amount]])

  # Valid start offsets are 0..available_time_step inclusive; maxval is exclusive,
  # hence the +1 so the final full window can also be selected.
  available_time_step = tf.maximum(time_step - target_time_step, 0)
  crop_index = tf.random.uniform(shape=[], minval=0, maxval=available_time_step + 1,
                                 dtype=tf.int32)
  return wav_concat[:, :, crop_index:crop_index + target_time_step]

In [6]:
def _stft(wav, sources_to_wanted):
  """Compute dB-scaled STFT magnitudes of the mixture and the requested sources.

  Args:
    wav: float32 tensor [6, num_channels, num_samples]; index 0 is the mixture and
      indices 1..5 are vocals, drums, basses, others, accompaniments.
    sources_to_wanted: list of source names to extract, any subset of
      ['vocals', 'drums', 'basses', 'others', 'accompaniments']. Output channels
      follow that canonical order, not the order of this list.

  Returns:
    mixture_db: [num_frames, 1024, num_channels] dB magnitude of the mixture.
    targets_db: [num_frames, 1024, num_channels * len(sources)] dB magnitudes of the
      requested sources concatenated along the last axis.

  Raises:
    ValueError: if sources_to_wanted is empty or contains an unknown name.
  """
  source_index = collections.OrderedDict([
      ('vocals', 1), ('drums', 2), ('basses', 3), ('others', 4), ('accompaniments', 5)])
  if not sources_to_wanted:
    raise ValueError('sources_to_wanted must name at least one source.')
  unknown = [name for name in sources_to_wanted if name not in source_index]
  if unknown:
    raise ValueError('Unknown source names: {}'.format(unknown))

  ref_level_db = 40
  log_offset = 1e-6  # keeps log() finite on silent bins

  def _amp_to_db(x):
    # tf.math.log is the natural log; dividing by ln(10) converts to log10.
    return 20 * tf.math.log(x + log_offset) / tf.math.log(10.0)

  def _db_spectrogram(signal):
    # [num_channels, num_samples] -> [num_channels, num_frames, 1025 bins];
    # the Nyquist bin is dropped to keep 1024 frequency bins.
    stft = tf.contrib.signal.stft(signal,
                                  frame_length=2048,
                                  frame_step=512,
                                  fft_length=2048,
                                  pad_end=True)[..., :1024]
    # -> [num_frames, frequency_bin, num_channels]
    stft = tf.transpose(stft, perm=[1, 2, 0])
    return _amp_to_db(tf.math.abs(stft)) - ref_level_db

  indices = [0] + [idx for name, idx in source_index.items() if name in sources_to_wanted]
  stfts_db = [_db_spectrogram(wav[idx]) for idx in indices]
  return stfts_db[0], tf.concat(stfts_db[1:], axis=-1)

In [7]:
def _stft_for_test(wav, sources_to_wanted):
  """Compute dB STFT magnitudes of mixture and sources, plus the sources' phase.

  Args:
    wav: float32 tensor [6, num_channels, num_samples]; index 0 is the mixture and
      indices 1..5 are vocals, drums, basses, others, accompaniments.
    sources_to_wanted: list of source names to extract, any subset of
      ['vocals', 'drums', 'basses', 'others', 'accompaniments']. Output channels
      follow that canonical order, not the order of this list.

  Returns:
    mixture_db: [num_frames, 1024, num_channels] dB magnitude of the mixture.
    targets_db: [num_frames, 1024, num_channels * len(sources)] dB magnitudes of
      the requested sources.
    targets_angle: same shape as targets_db; phase angle (radians) of each
      requested source's STFT (used to resynthesize waveforms).

  Raises:
    ValueError: if sources_to_wanted is empty or contains an unknown name.
  """
  source_index = collections.OrderedDict([
      ('vocals', 1), ('drums', 2), ('basses', 3), ('others', 4), ('accompaniments', 5)])
  if not sources_to_wanted:
    raise ValueError('sources_to_wanted must name at least one source.')
  unknown = [name for name in sources_to_wanted if name not in source_index]
  if unknown:
    raise ValueError('Unknown source names: {}'.format(unknown))

  ref_level_db = 40
  log_offset = 1e-6  # keeps log() finite on silent bins

  def _amp_to_db(x):
    # tf.math.log is the natural log; dividing by ln(10) converts to log10.
    return 20 * tf.math.log(x + log_offset) / tf.math.log(10.0)

  def _complex_spectrogram(signal):
    # [num_channels, num_samples] -> [num_channels, num_frames, 1025 bins];
    # the Nyquist bin is dropped to keep 1024 frequency bins.
    stft = tf.contrib.signal.stft(signal,
                                  frame_length=2048,
                                  frame_step=512,
                                  fft_length=2048,
                                  pad_end=True)[..., :1024]
    # -> [num_frames, frequency_bin, num_channels]
    return tf.transpose(stft, perm=[1, 2, 0])

  indices = [0] + [idx for name, idx in source_index.items() if name in sources_to_wanted]
  stfts = [_complex_spectrogram(wav[idx]) for idx in indices]
  stfts_db = [_amp_to_db(tf.math.abs(s)) - ref_level_db for s in stfts]
  # Only the targets' phase is returned, so the mixture's angle is not computed.
  targets_angle = [tf.math.angle(s) for s in stfts[1:]]
  return stfts_db[0], tf.concat(stfts_db[1:], axis=-1), tf.concat(targets_angle, axis=-1)

In [8]:
def inverse_stft(test_predictions, test_targets_angle):
  """Resynthesize waveforms from dB magnitude spectrograms and phase angles.

  Args:
    test_predictions: float32 [batch, num_frames, 1024, num_channels] spectrograms
      in the dB scale used for training (20*log10(|X|) - 40), Nyquist bin removed.
    test_targets_angle: float32 phase in radians with the same shape (e.g. the
      target angles from the test pipeline).

  Returns:
    float32 [batch, num_channels, num_samples] reconstructed waveforms.
  """
  ref_level_db = 40
  frame_step = 512

  def _db_to_amp(x):
    return tf.math.pow(10.0, x * 0.05)

  magnitude = _db_to_amp(test_predictions + ref_level_db)

  # Restore the dropped Nyquist bin (1024 -> 1025 = fft_length // 2 + 1). Pad in the
  # linear-magnitude domain so the bin is truly silent (0 dB-domain padding would
  # become a magnitude of 10 ** (40 / 20) = 100).
  nyquist_padding = [[0, 0], [0, 0], [0, 1], [0, 0]]
  magnitude = tf.pad(magnitude, nyquist_padding)
  angle = tf.pad(test_targets_angle, nyquist_padding)

  # Polar -> rectangular: |X| * exp(j * angle).
  stfts = tf.complex(magnitude * tf.cos(angle), magnitude * tf.sin(angle))

  # [batch_size, num_frames, frequency_bin, num_channels]
  # -> [batch_size, num_channels, num_frames, frequency_bin]
  stfts = tf.transpose(stfts, perm=[0, 3, 1, 2])

  # The inverse window compensates for the analysis window overlap-add so the
  # forward/inverse pair reconstructs the signal amplitude correctly.
  return tf.contrib.signal.inverse_stft(
      stfts=stfts,
      frame_length=2048,
      frame_step=frame_step,
      window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step))

In [9]:
# Records -> decoded waveforms -> random 131072-sample crops -> dB spectrograms
# (mixture, [vocals | accompaniments]) -> shuffled batches.
extracted_sources_list = ['vocals', 'accompaniments']
train_dataset = (tf.data.TFRecordDataset(train_data_filenames)
                 .map(_parse_function)
                 .map(_augmentation_function)
                 .map(lambda x: _stft(x, extracted_sources_list))
                 .shuffle(BUFFER_SIZE)
                 .batch(batch_size))

In [50]:
# To test for batch
for mixtures, targets in train_dataset.take(1):
  start_time = time.time()
  print(mixtures[0].shape)
  print(mixtures[1].shape)
  print(targets.shape)
  print("duration: {} sec".format(time.time() - start_time))

(256, 1024, 2)
(256, 1024, 2)
(8, 256, 1024, 4)
duration: 0.005666971206665039 sec


## Test dataset

In [56]:
# Same preprocessing as training, but also keep the targets' phase (for waveform
# resynthesis) and use single-example batches without shuffling.
extracted_sources_list = ['vocals', 'accompaniments']
test_dataset = (tf.data.TFRecordDataset(test_data_filenames)
                .map(_parse_function)
                .map(_augmentation_function)
                .map(lambda x: _stft_for_test(x, extracted_sources_list))
                .batch(1))

In [57]:
# To test for batch
for mixtures, targets, targets_imaginary in test_dataset.take(1):
  start_time = time.time()
  print(mixtures.shape)
  print(targets.shape)
  print(targets_imaginary.shape)
  print("duration: {} sec".format(time.time() - start_time))

(1, 256, 1024, 2)
(1, 256, 1024, 4)
(1, 256, 1024, 4)
duration: 0.0005030632019042969 sec


## Build the model

In [14]:
class Downsample(tf.keras.Model):
  """U-Net encoder step.

  A stride-2 'same' convolution (halves height and width), optional batch
  normalization, then leaky ReLU.
  """

  def __init__(self, filters, size, apply_batchnorm=True):
    super(Downsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm
    self.conv1 = tf.keras.layers.Conv2D(
        filters,
        (size, size),
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)
    if apply_batchnorm:
      self.batchnorm = tf.keras.layers.BatchNormalization()

  def call(self, x, training):
    h = self.conv1(x)
    if self.apply_batchnorm:
      h = self.batchnorm(h, training=training)
    # tf.nn.leaky_relu defaults to a negative slope of 0.2.
    return tf.nn.leaky_relu(h)

In [15]:
class Upsample(tf.keras.Model):
  """U-Net decoder step.

  A stride-2 'same' transposed convolution (doubles height and width), batch
  normalization, optional dropout (rate 0.5), ReLU, and finally concatenation
  with a skip-connection tensor along the channel axis.
  """

  def __init__(self, filters, size, apply_dropout=False):
    super(Upsample, self).__init__()
    self.apply_dropout = apply_dropout
    self.up_conv = tf.keras.layers.Conv2DTranspose(
        filters,
        (size, size),
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)
    self.batchnorm = tf.keras.layers.BatchNormalization()
    if apply_dropout:
      self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, x1, x2, training):
    """Upsample x1 and append the skip connection x2 on the last axis."""
    h = self.batchnorm(self.up_conv(x1), training=training)
    if self.apply_dropout:
      h = self.dropout(h, training=training)
    return tf.concat([tf.nn.relu(h), x2], axis=-1)

In [59]:
class SpectrogramChannelsUNet(tf.keras.Model):
  """U-Net mapping a mixture spectrogram to stacked per-source spectrograms.

  Input:  [batch, num_frames, frequency_bin, 2] spectrogram of the stereo mixture
          (e.g. [bs, 256, 1024, 2] for this notebook's pipeline). Height and width
          must be divisible by 64 (six stride-2 downsamplings).
  Output: [batch, num_frames, frequency_bin, output_channels]; with the default
          4 channels: vocals (L, R) and accompaniments (L, R).

  Args:
    output_channels: number of output spectrogram channels (2 per source).
    output_activation: optional callable applied to the final layer's output.
      Defaults to None (linear). The notebook's targets are dB values
      (20*log10(|X| + 1e-6) - 40, i.e. down to about -160), which a bounded
      activation such as tf.nn.tanh (range [-1, 1]) can never reach. Pass
      tf.nn.tanh only if the targets are normalized to [-1, 1].
  """

  def __init__(self, output_channels=4, output_activation=None):
    super(SpectrogramChannelsUNet, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)
    self.output_activation = output_activation

    # Encoder: each step halves height and width.
    self.down1 = Downsample(16, 5)
    self.down2 = Downsample(32, 5)
    self.down3 = Downsample(64, 5)
    self.down4 = Downsample(128, 5)
    self.down5 = Downsample(256, 5)

    self.down6 = Downsample(512, 5)  # bottleneck

    # Decoder: each step doubles height and width and concatenates a skip tensor.
    self.up1 = Upsample(256, 5, apply_dropout=True)
    self.up2 = Upsample(128, 5, apply_dropout=True)
    self.up3 = Upsample(64, 5, apply_dropout=True)
    self.up4 = Upsample(32, 5)
    self.up5 = Upsample(16, 5)

    self.last = tf.keras.layers.Conv2DTranspose(output_channels,
                                                (5, 5),
                                                strides=2,
                                                padding='same',
                                                kernel_initializer=initializer)

  @tf.contrib.eager.defun
  def call(self, x, training):
    # Shape comments assume x of shape (bs, 256, 1024, 2).
    x1 = self.down1(x, training=training)    # (bs, 128, 512, 16)
    x2 = self.down2(x1, training=training)   # (bs, 64, 256, 32)
    x3 = self.down3(x2, training=training)   # (bs, 32, 128, 64)
    x4 = self.down4(x3, training=training)   # (bs, 16, 64, 128)
    x5 = self.down5(x4, training=training)   # (bs, 8, 32, 256)

    x6 = self.down6(x5, training=training)   # (bs, 4, 16, 512) - center

    # Each Upsample concatenates its skip input, doubling the channel count.
    x7 = self.up1(x6, x5, training=training)    # (bs, 8, 32, 512)
    x8 = self.up2(x7, x4, training=training)    # (bs, 16, 64, 256)
    x9 = self.up3(x8, x3, training=training)    # (bs, 32, 128, 128)
    x10 = self.up4(x9, x2, training=training)   # (bs, 64, 256, 64)
    x11 = self.up5(x10, x1, training=training)  # (bs, 128, 512, 32)

    output = self.last(x11)  # (bs, 256, 1024, output_channels)
    if self.output_activation is not None:
      output = self.output_activation(output)
    return output

In [60]:
# Instantiate the U-Net; as a subclassed Keras model its weights are only created
# once it is called on real inputs.
model = SpectrogramChannelsUNet()

In [61]:
# Run one batch through the model so the subclassed Keras model builds its weights.
# training=False keeps this build-only pass from updating BatchNorm moving
# statistics or applying dropout.
for mixtures, targets in train_dataset.take(1):
  model(mixtures, training=False)

In [62]:
# Per-layer parameter counts; only meaningful after the model has been built by a call.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
downsample_12 (Downsample)   multiple                  864       
_________________________________________________________________
downsample_13 (Downsample)   multiple                  12928     
_________________________________________________________________
downsample_14 (Downsample)   multiple                  51456     
_________________________________________________________________
downsample_15 (Downsample)   multiple                  205312    
_________________________________________________________________
downsample_16 (Downsample)   multiple                  820224    
_________________________________________________________________
downsample_17 (Downsample)   multiple                  3278848   
_________________________________________________________________
upsample_10 (Upsample)       multiple                  3277824   
__________

## Train

In [20]:
def singing_voice_loss(y_true, y_pred, alpha=1.0):
  """Weighted L1 (mean absolute error) loss between target and predicted spectrograms.

  Channels [..., :2] are treated as vocals (left, right) and [..., 2:] as
  accompaniments, matching extracted_sources_list = ['vocals', 'accompaniments'].

  Args:
    y_true: target spectrograms [batch, num_frames, frequency_bin, 4].
    y_pred: predicted spectrograms with the same shape.
    alpha: weight of the vocal term; the accompaniment term gets (1 - alpha).
      NOTE: with the default alpha=1.0 the accompaniment channels receive no loss.

  Returns:
    Scalar loss tensor.
  """
  vocal_true, accomp_true = y_true[..., :2], y_true[..., 2:]
  vocal_pred, accomp_pred = y_pred[..., :2], y_pred[..., 2:]

  vocal_term = tf.losses.absolute_difference(vocal_true, vocal_pred)
  accomp_term = tf.losses.absolute_difference(accomp_true, accomp_pred)
  return alpha * vocal_term + (1. - alpha) * accomp_term

In [21]:
# TF1 Adam optimizer; it is tracked by tf.train.Checkpoint together with the model.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

In [22]:
# Checkpoints are written under train_dir with the file prefix "ckpt".
checkpoint_dir = train_dir
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

dir_exists = tf.gfile.Exists(checkpoint_dir)
if not dir_exists:
  tf.gfile.MakeDirs(checkpoint_dir)

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [None]:
print('Start Training.')
global_step = 1

for epoch in range(max_epochs):
  # One full pass over the (reshuffled) training set per epoch.
  for mixtures, targets in train_dataset:
    start_time = time.time()

    with tf.GradientTape() as tape:
      predictions = model(mixtures, training=True)
      loss = singing_voice_loss(targets, predictions)

    # Differentiate w.r.t. trainable variables only; BatchNorm moving statistics
    # are non-trainable and would only yield None gradients.
    trainable_vars = model.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))

    epochs = global_step * batch_size / float(N)
    duration = time.time() - start_time

    if global_step % print_steps == 0:
      display.clear_output(wait=True)
      # The last batch of an epoch may be smaller than batch_size.
      examples_per_sec = int(mixtures.shape[0]) / float(duration)
      print("Epochs: {:.2f} global_step: {} loss: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step, float(loss), examples_per_sec, duration))
    global_step = global_step + 1

  # Save a checkpoint every save_model_epochs epochs.
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

Epochs: 0.19 global_step: 3 loss: 104.520 (105.21 examples/sec; 0.076 sec/batch)


In [None]:
# Separate one test clip and resynthesize waveforms using the targets' phase.
for test_batch in test_dataset.take(1):
  test_mixtures, test_targets, test_targets_angle = test_batch
  test_predictions = model(test_mixtures, training=False)
  test_wav = inverse_stft(test_predictions, test_targets_angle)

In [None]:
# Channels 0-1 of the reconstruction are the vocal estimate (left, right).
vocals_wav = test_wav[0, :2, :]
display.Audio(vocals_wav, rate=44100)