# Spectrogram Channels U-net

* SPECTROGRAM-CHANNELS U-NET: A SOURCE SEPARATION MODEL VIEWING EACH CHANNEL AS THE SPECTROGRAM OF EACH SOURCE, [arXiv:1810.11520](https://arxiv.org/abs/1810.11520)
  * Jaehoon Oh∗, Duyeon Kim∗, Se-Young Yun

## Import modules

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
tf.enable_eager_execution()

tf.logging.set_verbosity(tf.logging.INFO)

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
# Training Flags (hyperparameter configuration)
model_name = 'spectrogram_unet'
train_dir = 'train/' + model_name + '/exp1/'
max_epochs = 200
save_model_epochs = 20
print_steps = 1
batch_size = 8
learning_rate = 2e-4
N = 100 # number of samples in train_dataset

BUFFER_SIZE = N

## Set up dataset with `tf.data`

In [3]:
data_path = './datasets/spectrogram/'
train_data_filenames = [os.path.join(data_path, 'train', name)
                        for name in os.listdir(os.path.join(data_path, 'train')) if 'tfrecord' in name]
for name in train_data_filenames:
  print(name)

./datasets/spectrogram/train/spectrogram_train_00000-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00001-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00002-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00003-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00004-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00005-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00006-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00007-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00008-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00009-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00010-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00011-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00012-of-00020.tfrecord
./datasets/spectrogram/train/spectrogram_train_00013-of-00020.tfrecord
./data

In [4]:
def _parse_function(example_proto):
  features = {'spec_raw': tf.FixedLenFeature([], tf.string, default_value=""),
              'frequency_bin': tf.FixedLenFeature([], tf.int64, default_value=0),
              'time_step': tf.FixedLenFeature([], tf.int64, default_value=0),
              'channel': tf.FixedLenFeature([], tf.string, default_value=""),
              'track_number': tf.FixedLenFeature([], tf.int64, default_value=0),
              'split_number': tf.FixedLenFeature([], tf.int64, default_value=0),}
  
  parsed_features = tf.parse_single_example(example_proto, features)

  spec_raw = tf.decode_raw(parsed_features["spec_raw"], out_type=tf.float32)
  frequency_bin = tf.cast(parsed_features["frequency_bin"], dtype=tf.int32)
  time_step = tf.cast(parsed_features["time_step"], dtype=tf.int32)
  #channel = tf.cast(parsed_features["channel"], dtype=tf.string)
  #track_number = tf.cast(parsed_features["track_number"], dtype=tf.int32)
  #split_number = tf.cast(parsed_features["split_number"], dtype=tf.int32)
  
  num_channels = 6 # for [mixtures, vocals, drums, basses, others, accompaniments]
  spec_raw = tf.reshape(spec_raw, shape=[frequency_bin, time_step, num_channels])

  return spec_raw, time_step

In [5]:
def _augmentation_function(spec_raw, time_step):
  """Random cropping for data augmentation
  """
  target_time_step = 128 # our input size
  available_time_step = time_step - target_time_step
  
  crop_index = tf.random_uniform(shape=[]) * tf.cast(available_time_step, dtype=tf.float32)
  crop_index = tf.cast(crop_index, dtype=tf.int32)
  spec_raw_crop = spec_raw[:, crop_index:crop_index+target_time_step, :]
  
  return spec_raw_crop[..., 0:1], spec_raw_crop[..., 1:]

In [6]:
train_dataset = tf.data.TFRecordDataset(train_data_filenames)
train_dataset = train_dataset.map(_parse_function)
train_dataset = train_dataset.map(_augmentation_function)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(batch_size)

## Build the model

In [7]:
class Downsample(tf.keras.Model):
    
  def __init__(self, filters, size, apply_batchnorm=True):
    super(Downsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm
    initializer = tf.random_normal_initializer(0., 0.02)

    self.conv1 = tf.keras.layers.Conv2D(filters, 
                                        (size, size), 
                                        strides=2, 
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False)
    if self.apply_batchnorm:
        self.batchnorm = tf.keras.layers.BatchNormalization()
  
  def call(self, x, training):
    x = self.conv1(x)
    if self.apply_batchnorm:
        x = self.batchnorm(x, training=training)
    x = tf.nn.leaky_relu(x) # a = 0.2
    return x 

In [8]:
class Upsample(tf.keras.Model):
    
  def __init__(self, filters, size, apply_dropout=False):
    super(Upsample, self).__init__()
    self.apply_dropout = apply_dropout
    initializer = tf.random_normal_initializer(0., 0.02)

    self.up_conv = tf.keras.layers.Conv2DTranspose(filters, 
                                                   (size, size), 
                                                   strides=2, 
                                                   padding='same',
                                                   kernel_initializer=initializer,
                                                   use_bias=False)
    self.batchnorm = tf.keras.layers.BatchNormalization()
    if self.apply_dropout:
        self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, x1, x2, training):
    x = self.up_conv(x1)
    x = self.batchnorm(x, training=training)
    if self.apply_dropout:
        x = self.dropout(x, training=training)
    x = tf.nn.relu(x)
    x = tf.concat([x, x2], axis=-1)
    return x

In [14]:
class SpectrogramChannelsUNet(tf.keras.Model):
  def __init__(self):
    super(SpectrogramChannelsUNet, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)
    
    self.down1 = Downsample(16, 5)
    self.down2 = Downsample(32, 5)
    self.down3 = Downsample(64, 5)
    self.down4 = Downsample(128, 5)
    self.down5 = Downsample(256, 5)
    self.down6 = Downsample(512, 5)
    
    self.up1 = Upsample(256, 5, apply_dropout=True)
    self.up2 = Upsample(128, 5, apply_dropout=True)
    self.up3 = Upsample(64, 5, apply_dropout=True)
    self.up4 = Upsample(32, 5)
    self.up5 = Upsample(16, 5)
    
    self.last = tf.keras.layers.Conv2DTranspose(1, #OUTPUT_CHANNELS
                                                (5, 5), 
                                                strides=2, 
                                                padding='same',
                                                kernel_initializer=initializer)
  
  @tf.contrib.eager.defun
  def call(self, x, training):
    # x shape == (bs, 512, 128, 1)    
    x1 = self.down1(x, training=training) # (bs, 256, 64, 16)
    x2 = self.down2(x1, training=training) # (bs, 128, 32, 32)
    x3 = self.down3(x2, training=training) # (bs, 64, 16, 64)
    x4 = self.down4(x3, training=training) # (bs, 32, 8, 128)
    x5 = self.down5(x4, training=training) # (bs, 16, 4, 256)
    x6 = self.down6(x5, training=training) # (bs, 8, 2, 512)    

    x7 = self.up1(x6, x5, training=training) # (bs, 16, 4, 256)
    x8 = self.up2(x7, x4, training=training) # (bs, 32, 8, 128)
    x9 = self.up3(x8, x3, training=training) # (bs, 64, 16, 64)
    x10 = self.up4(x9, x2, training=training) # (bs, 128, 32, 32)
    x11 = self.up5(x10, x1, training=training) # (bs, 256, 64, 16)
    
    x12 = self.last(x11) # (bs, 512, 128, 1)
    x13 = tf.nn.tanh(x12)

    return x13

In [15]:
model = SpectrogramChannelsUNet()

# Train

In [16]:
#tf.constant([1, 2]) * tf.constant([2, 4])

In [57]:
# The loss function to be optimized
def loss(mask, mixture, target):
    mae = tf.keras.losses.MAE((mask * mixture), targets)
    return tf.reduce_mean(mae)

loss_history = []

In [63]:
opt = tf.train.AdamOptimizer()

for mixtures, targets in train_dataset.take(100): 
  with tf.GradientTape() as tape:
    mask = model(mixtures, training=True)
    loss_value = loss(mask, mixtures, targets[...,1:2])
  
  grads = tape.gradient(loss_value, model.variables)
  opt.apply_gradients(zip(grads, model.variables),
                      global_step=tf.train.get_or_create_global_step())

  loss_history.append(loss_value.numpy())
  print(loss_value.numpy())

0.17438972
0.13549604
0.1627253
0.16768746
0.15774262
0.1581802
0.18270281
0.15513977
0.13489163
0.13389716
0.14207926
0.137745
0.14833662
0.13203382
0.16116506
0.14065489
0.13761377
0.13966236
0.16630752
0.16217561
0.18096146
0.17479874
0.16842443
0.21351776
0.15991601
0.15410852
0.18181835
0.18500432
0.18466529
0.18146083
0.17682369
0.16609979
0.21200842
0.1715177
0.15307137
0.16380733
0.17596932
0.1776322
0.17703666
0.1689827
0.16732062
0.16513914
0.14752947
0.17941774
0.18015602
0.18523388
0.1548334
0.16107185
0.17408752
0.180076
0.13717848
0.13075906
0.16254312
0.17640615
0.1603344
0.16533434
0.14193028
0.18052402
0.14285341
0.15650919
0.16770941
0.15208648
0.22039983
0.17831582
0.15746094
0.17132315
0.21732365
0.18463273
0.15467647
0.15902914
0.16382195
0.21801022
0.18448743
0.17364337
0.19178307
0.22444627
0.15017024
0.18999913
0.17095539
0.2092202
0.17465582
0.14467487
0.20117772
0.20129387
0.19109887
0.17886178
0.20767376
0.19907999
0.17359787
0.1794733
0.14190751
0.1532608
0.

0.20566222
