# Spectrogram-Channels U-Net

* Spectrogram-Channels U-Net: A Source Separation Model Viewing Each Channel as the Spectrogram of Each Source, [arXiv:1810.11520](https://arxiv.org/abs/1810.11520)
  * Jaehoon Oh∗, Duyeon Kim∗, Se-Young Yun

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
# Expose only GPU 0 to TensorFlow; set before any eager op touches a device.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# This notebook is written against the TF 1.x eager API.
tf.enable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# ---- Training flags (hyperparameter configuration) ----
model_name = 'spectrogram_unet'
# Checkpoints and logs for this experiment live under train/<model_name>/exp1/
train_dir = 'train/{}/exp1/'.format(model_name)

# Schedule
max_epochs = 200
save_model_epochs = 20   # checkpoint frequency, in epochs
print_steps = 1          # logging frequency, in global steps

# Optimization
batch_size = 8
learning_rate = 2e-4

# Dataset size: number of samples in train_dataset. It is also used as the
# shuffle buffer size so the buffer can hold the whole training set.
N = 100
BUFFER_SIZE = N

## Set up dataset with `tf.data`

In [None]:
data_path = './datasets/spectrogram/'

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort
# them so the order in which TFRecord files are read is deterministic.
train_data_filenames = sorted(
    os.path.join(data_path, 'train', name)
    for name in os.listdir(os.path.join(data_path, 'train'))
    if 'tfrecord' in name)
for name in train_data_filenames:
  print(name)

# Test split (not used yet). Both the listing and the joined paths must use
# the 'test' directory.
#test_data_filenames = sorted(
#    os.path.join(data_path, 'test', name)
#    for name in os.listdir(os.path.join(data_path, 'test'))
#    if 'tfrecord' in name)
#for name in test_data_filenames:
#  print(name)

In [None]:
def _parse_function(example_proto):
  """Decode one serialized tf.train.Example into a spectrogram tensor.

  Args:
    example_proto: scalar string tensor holding one serialized example.

  Returns:
    A pair `(spec_raw, time_step)`:
      spec_raw: float32 tensor of shape [frequency_bin, time_step, 6]; the 6
        channels are [mixtures, vocals, drums, basses, others, accompaniments].
      time_step: scalar int32 tensor, the length of spec_raw's time axis.
  """
  features = {'spec_raw': tf.FixedLenFeature([], tf.string, default_value=""),
              'frequency_bin': tf.FixedLenFeature([], tf.int64, default_value=0),
              'time_step': tf.FixedLenFeature([], tf.int64, default_value=0),
              'channel': tf.FixedLenFeature([], tf.string, default_value=""),
              'track_number': tf.FixedLenFeature([], tf.int64, default_value=0),
              'split_number': tf.FixedLenFeature([], tf.int64, default_value=0),}

  parsed_features = tf.parse_single_example(example_proto, features)

  # The spectrogram is stored as a flat byte string of float32 values; its
  # shape is rebuilt from the frequency_bin / time_step fields stored with it.
  spec_raw = tf.decode_raw(parsed_features["spec_raw"], out_type=tf.float32)
  frequency_bin = tf.cast(parsed_features["frequency_bin"], dtype=tf.int32)
  time_step = tf.cast(parsed_features["time_step"], dtype=tf.int32)
  # 'channel', 'track_number' and 'split_number' are parsed but not used yet.

  num_channels = 6 # for [mixtures, vocals, drums, basses, others, accompaniments]
  spec_raw = tf.reshape(spec_raw, shape=[frequency_bin, time_step, num_channels])

  return spec_raw, time_step

In [None]:
def _augmentation_function(spec_raw, time_step, target_time_step=128):
  """Randomly crop a fixed-length window along the time axis (augmentation).

  Args:
    spec_raw: tensor of shape [frequency_bin, time_step, 6] as produced by
      `_parse_function`; channel 0 is the mixture, channels 1: are the sources.
    time_step: scalar int32 tensor, the length of spec_raw's time axis.
    target_time_step: width of the cropped window (the model's input width).

  Returns:
    A pair `(mixture, sources)` with shapes
    [frequency_bin, target_time_step, 1] and [frequency_bin, target_time_step, 5].
  """
  available_time_step = time_step - target_time_step

  # Every start offset in [0, available_time_step] is a valid crop, so draw an
  # integer with exclusive upper bound available_time_step + 1. (Scaling a
  # float in [0, 1) and truncating could never pick the last offset.) If the
  # clip is shorter than the window, maxval <= minval and random_uniform raises
  # an error instead of silently slicing with a negative offset.
  crop_index = tf.random_uniform(shape=[],
                                 minval=0,
                                 maxval=available_time_step + 1,
                                 dtype=tf.int32)
  spec_raw_crop = spec_raw[:, crop_index:crop_index+target_time_step, :]

  return spec_raw_crop[..., 0:1], spec_raw_crop[..., 1:]

In [None]:
train_dataset = tf.data.TFRecordDataset(train_data_filenames)
train_dataset = train_dataset.map(_parse_function)
train_dataset = train_dataset.map(_augmentation_function)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(batch_size)

## Build the model

In [None]:
class ConvBlock(tf.keras.Model):
  """Two stacked (Conv2D -> BatchNorm -> ReLU) stages with 'same' padding.

  Spatial size is preserved; the channel count becomes `num_filters`.
  """

  def __init__(self, num_filters, size):
    super(ConvBlock, self).__init__()
    kernel = (size, size)

    def make_conv():
      # Bias is omitted because the following BatchNormalization re-centers.
      return layers.Conv2D(filters=num_filters,
                           kernel_size=kernel,
                           padding='same',
                           use_bias=False)

    self.conv1 = make_conv()
    self.batchnorm1 = layers.BatchNormalization()
    self.conv2 = make_conv()
    self.batchnorm2 = layers.BatchNormalization()

  def call(self, x, training=True):
    stages = ((self.conv1, self.batchnorm1),
              (self.conv2, self.batchnorm2))
    out = x
    for conv, norm in stages:
      out = tf.nn.relu(norm(conv(out), training=training))
    return out

In [None]:
class EncoderBlock(tf.keras.Model):
  """U-Net encoder stage: a ConvBlock followed by 2x2 max-pooling.

  `call` returns `(pooled, features)`: the pooled tensor feeds the next
  encoder stage, and the un-pooled ConvBlock output is kept as the skip
  connection for the matching decoder stage.
  """

  def __init__(self, num_filters, size):
    super(EncoderBlock, self).__init__()
    # `size` is the kernel size of both convolutions inside the ConvBlock.
    self.conv_block = ConvBlock(num_filters, size)
    # Built once here rather than on every call (pooling has no weights).
    self.pool = layers.MaxPooling2D((2, 2), strides=(2, 2))

  def call(self, x, training=True):
    encoder = self.conv_block(x, training=training)
    encoder_pool = self.pool(encoder)

    return encoder_pool, encoder

In [None]:
class ConvTransposeBlock(tf.keras.Model):
  """Two stacked (Conv2DTranspose -> BatchNorm -> ReLU) stages, stride 1.

  With stride 1 and 'same' padding the spatial size is unchanged; the channel
  count becomes `num_filters`.
  """

  def __init__(self, num_filters, size):
    super(ConvTransposeBlock, self).__init__()
    kernel = (size, size)

    def make_conv_t():
      return layers.Conv2DTranspose(filters=num_filters,
                                    kernel_size=kernel,
                                    padding='same',
                                    use_bias=False)

    self.convT1 = make_conv_t()
    self.batchnorm1 = layers.BatchNormalization()
    self.convT2 = make_conv_t()
    self.batchnorm2 = layers.BatchNormalization()

  def call(self, x, training=True):
    out = x
    for conv_t, norm in ((self.convT1, self.batchnorm1),
                         (self.convT2, self.batchnorm2)):
      out = conv_t(out)
      out = norm(out, training=training)
      out = tf.nn.relu(out)
    return out

In [None]:
class DecoderBlock(tf.keras.Model):
  """U-Net decoder stage.

  Steps: upsample x2 (strided Conv2DTranspose -> BatchNorm -> ReLU ->
  Dropout), concatenate the encoder skip tensor on the channel axis, then
  refine with a ConvTransposeBlock.
  """

  def __init__(self, num_filters, size):
    super(DecoderBlock, self).__init__()
    up_kernel = size + 2
    self.convT = layers.Conv2DTranspose(filters=num_filters,
                                        kernel_size=(up_kernel, up_kernel),
                                        strides=(2, 2),
                                        padding='same',
                                        use_bias=False)
    self.batchnorm = layers.BatchNormalization()
    self.dropout = layers.Dropout(0.4)
    self.convT_block = ConvTransposeBlock(num_filters, size)

  def call(self, input_tensor, concat_tensor, training=True):
    # Upsampling path.
    upsampled = tf.nn.relu(
        self.batchnorm(self.convT(input_tensor), training=training))
    upsampled = self.dropout(upsampled, training=training)

    # Skip connection, then two consecutive stride-1 transposed convolutions.
    merged = tf.concat([upsampled, concat_tensor], axis=-1)
    return self.convT_block(merged, training=training)

In [None]:
class SpectrogramChannelsUNet(tf.keras.Model):
  """Four-level U-Net mapping a 1-channel mixture to `num_sources` channels.

  Each output channel is passed through a sigmoid, so outputs lie in (0, 1).

  Args:
    num_sources: number of output channels. It defaults to 5, matching the
      5 source channels that follow the mixture in the parsed data.
  """

  def __init__(self, num_sources=5):
    super(SpectrogramChannelsUNet, self).__init__()
    self.down1 = EncoderBlock(32, 3)
    self.down2 = EncoderBlock(64, 3)
    self.down3 = EncoderBlock(128, 3)
    self.down4 = EncoderBlock(256, 3)
    self.center = ConvBlock(512, 3)

    self.up1 = DecoderBlock(256, 3)
    self.up2 = DecoderBlock(128, 3)
    self.up3 = DecoderBlock(64, 3)
    self.up4 = DecoderBlock(32, 3)

    self.last = layers.Conv2D(filters=num_sources,
                              kernel_size=(1, 1),
                              padding='same')

  @tf.contrib.eager.defun
  def call(self, x, training):
    # x shape == (bs, F, T, 1). With this notebook's random crop T = 128.
    # F is the data's frequency_bin (an earlier comment assumed 1024; TODO
    # confirm). Four 2x2 poolings followed by four stride-2 upsamplings only
    # line up for the skip concatenations when F and T are divisible by 16.
    x1_pool, x1 = self.down1(x, training=training) # (bs, F/2, T/2, 32)
    x2_pool, x2 = self.down2(x1_pool, training=training) # (bs, F/4, T/4, 64)
    x3_pool, x3 = self.down3(x2_pool, training=training) # (bs, F/8, T/8, 128)
    x4_pool, x4 = self.down4(x3_pool, training=training) # (bs, F/16, T/16, 256)
    x_center = self.center(x4_pool, training=training) # (bs, F/16, T/16, 512)

    x5 = self.up1(x_center, x4, training=training) # (bs, F/8, T/8, 256)
    x6 = self.up2(x5, x3, training=training) # (bs, F/4, T/4, 128)
    x7 = self.up3(x6, x2, training=training) # (bs, F/2, T/2, 64)
    x8 = self.up4(x7, x1, training=training) # (bs, F, T, 32)

    x_last = self.last(x8) # (bs, F, T, num_sources)
    x_last = tf.math.sigmoid(x_last)

    return x_last

In [None]:
# Instantiate the U-Net with its default configuration.
model = SpectrogramChannelsUNet()

In [None]:
# Push a single training batch through the model so that every layer is built
# (and its variables created) before model.summary() is called.
for inputs, targets in train_dataset.take(1):
  x = model(inputs, training=True)
  y = targets

In [None]:
# Print the per-layer output / parameter summary of the built model.
model.summary()

## Define the loss function and optimizer

In [None]:
def bce_loss(y_true, y_pred):
  """Return the scalar mean binary cross-entropy between targets and predictions.

  Keras' per-element binary cross-entropy is averaged down to one scalar.
  NOTE(review): BCE assumes `y_true` lies in [0, 1]; confirm the stored
  source spectrograms are normalized to that range.
  """
  return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))

In [None]:
# Adam with the learning rate from the hyperparameter cell.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

In [None]:
# Checkpoints (model + optimizer state) are written as <train_dir>/ckpt-N.
checkpoint_dir = train_dir
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
if not tf.gfile.Exists(checkpoint_dir):
  tf.gfile.MakeDirs(checkpoint_dir)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [None]:
tf.logging.info('Start Training.')
global_step = tf.train.get_or_create_global_step()
for epoch in range(max_epochs):

  for mixtures, targets in train_dataset:
    start_time = time.time()

    with tf.GradientTape() as tape:
      predictions = model(mixtures, training=True)
      loss = bce_loss(targets, predictions)

    # Differentiate only w.r.t. trainable variables: model.variables also
    # contains BatchNormalization moving statistics, whose gradients are None.
    trainable_variables = model.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables),
                              global_step=global_step)

    step = global_step.numpy()
    # Approximate progress in epochs; assumes N is the true training-set size.
    epochs = step * batch_size / float(N)
    duration = time.time() - start_time

    if step % print_steps == 0:
      display.clear_output(wait=True)
      # Use the real batch size: the last batch of an epoch may be smaller.
      examples_per_sec = int(mixtures.shape[0]) / float(duration)
      print("Epochs: {:.2f} global_step: {} loss: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, step, loss.numpy(), examples_per_sec, duration))
      # TODO: preview predictions once a test split is loaded. training=True
      # is intentional there: it uses the test batch's own statistics instead
      # of the moving statistics accumulated during training.
      #for test_input, test_target in test_dataset.take(1):
      #  prediction = model(test_input, training=True)
      #  print_or_save_sample_images(test_input, test_target, prediction)

  # TODO: periodic sample-image saving needs save_images_epochs,
  # constant_test_input, constant_test_target and print_or_save_sample_images,
  # none of which are defined in this notebook yet.
  #if (epoch + 1) % save_images_epochs == 0:
  #  display.clear_output(wait=True)
  #  print("This images are saved at {} epoch".format(epoch+1))
  #  prediction = model(constant_test_input, training=True)
  #  print_or_save_sample_images(constant_test_input, constant_test_target, prediction,
  #                              is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # Save (checkpoint) the model every save_model_epochs epochs so there is
  # something for the restore cell to load.
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

## Restore the latest checkpoint and test

In [None]:
# Restore the latest checkpoint in checkpoint_dir, if one exists.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
  # restore(None) loads nothing, so report it instead of failing silently.
  tf.logging.warning('No checkpoint found in %s; keeping current model weights.',
                     checkpoint_dir)
else:
  tf.logging.info('Restoring checkpoint %s', latest_checkpoint)
checkpoint.restore(latest_checkpoint)

## Testing on the entire test dataset