# Convert wave to tfrecords format

* Sampling rate: 44100 Hz
* musdb data
  * `train_data`: 100 tracks
    * shortest track: 12.91 sec (66th track)
    * longest track: 628.38 sec (51st track)
  * `test_data`: 50 tracks
    * shortest track: 76.00 sec (29th track)
    * longest track: 430.20 sec (15th track)
* So, I split each wave file into consecutive `5 second` segments
  * for example, a track of 17 sec
  * is split into part1: [0:5] sec, part2: [5:10] sec, part3: [10:17] sec
  * the remainder (the last 2 sec, too short for its own segment) is merged into the last part

In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import librosa
import librosa.display
import musdb

from IPython import display

import tensorflow as tf

In [None]:
# Root directory of the local MUSDB18 copy.
musdb_root = './datasets/musdb18/'
mus = musdb.DB(root_dir=musdb_root)

In [None]:
# Select which MUSDB18 subset to load and convert ('train' or 'test').
split_name = 'train'
assert split_name in ('train', 'test')

tracks = mus.load_mus_tracks(subsets=[split_name])
for info in (type(tracks), len(tracks)):
  print(info)

In [None]:
# Check each track's duration (minutes, seconds) and report the shortest and longest tracks; uncomment to run.
# min_sec = 1000.0
# max_sec = 0.0
# for i, track in enumerate(tracks):
#   sec = track.audio.T.shape[1] / 44100
#   if sec < min_sec:
#     min_sec = sec
#   if sec > max_sec:
#     max_sec = sec
  
#   minute = int(sec / 60)
#   sec = sec - minute * 60
#   print("{}th: {} min {:.2f} sec".format(i, minute, sec))
# print("the shortest time track: {:.2f} sec".format(min_sec))
# print("the largest time track: {:.2f} sec".format(max_sec))

In [None]:
# Names of the target stems of the first track (last expression, so the notebook displays it).
first_track_targets = tracks[0].targets
first_track_targets.keys()

In [None]:
# Index of the track inspected in the following cells.
index = 0
track_name = tracks[index].name
print(track_name)

### Listen to the track

In [None]:
# Play the original mixture (audio transposed so that channels come first).
mixture = tracks[index].audio.T
display.Audio(mixture, rate=44100)

In [None]:
# Shape of the transposed mixture: the transpose of an array reverses its shape.
print(tracks[index].audio.shape[::-1])

In [None]:
# To listen to an individual stem source, uncomment the corresponding line below
# display.Audio(tracks[index].targets['vocals'].audio.T, rate=44100)
# display.Audio(tracks[index].targets['drums'].audio.T, rate=44100)
# display.Audio(tracks[index].targets['bass'].audio.T, rate=44100)
# display.Audio(tracks[index].targets['other'].audio.T, rate=44100)
# display.Audio(tracks[index].targets['accompaniment'].audio.T, rate=44100)

## Plot a short excerpt

In [None]:
# Take the first `second` seconds of the left and right channels.
second = 10
num_samples = 44100 * second
channels_first = tracks[index].audio.T
left_wave = channels_first[0][:num_samples]
right_wave = channels_first[1][:num_samples]

In [None]:
# Plot each channel in its own wide figure, then render both.
for channel_wave in (left_wave, right_wave):
  plt.figure(figsize=[18, 3])
  plt.plot(channel_wave)
plt.show()

In [None]:
# Number of samples in the left-channel excerpt.
print(np.shape(left_wave))

In [None]:
# Short-time Fourier transform of the left-channel excerpt.
#   n_fft: number of samples used for each FFT frame
#   hop_length: samples between the starts of consecutive frames (like a stride)
stft_kwargs = dict(n_fft=2048, hop_length=512)
left_stft = librosa.core.stft(left_wave, **stft_kwargs)
print(left_stft.shape)
first_bin = left_stft[0, 0]
print(type(first_bin))
print(left_stft[1, 1])

In [None]:
# Show the STFT magnitude on a linear scale, then in dB relative to its peak.
left_abs = np.abs(left_stft)
for image in (left_abs, librosa.amplitude_to_db(left_abs, ref=np.max)):
  librosa.display.specshow(image)
  plt.colorbar()
  plt.show()

## Normalized log-magnitude spectrogram (a commonly used preprocessing method)

In [None]:
# Dynamic-range settings for spectrogram normalization.
# Tacotron uses min_level_db = -100 and ref_level_db = 20; the values
# below are this notebook's own choice.
min_level_db = -110  # dB level that _normalize maps to 0 (anything lower is clipped)
ref_level_db = 40    # dB offset subtracted from the log magnitude before normalizing

In [None]:
def spectrogram(y):
  """Return the normalized log-magnitude spectrogram of waveform `y`.

  The STFT magnitude is converted to dB, shifted down by `ref_level_db`,
  and finally scaled/clipped into [0, 1] by `_normalize`.
  """
  magnitude = np.abs(_stft(y))
  level_db = _amp_to_db(magnitude) - ref_level_db
  return _normalize(level_db)

In [None]:
def _stft(y):
  """Complex STFT of `y` using 1024-sample frames with 50% overlap.

  An earlier configuration used n_fft = win_length = 2048, hop_length = 512.
  """
  frame_size = 1024
  hop = 512
  return librosa.stft(y=y, n_fft=frame_size, hop_length=hop, win_length=frame_size)

In [None]:
def _amp_to_db(x):
  """Convert amplitude `x` to decibels (20 * log10).

  Amplitudes are floored at 1e-5 first, so the result never goes below
  -100 dB and log10(0) is avoided.
  """
  floored = np.maximum(1e-5, x)
  return 20 * np.log10(floored)

In [None]:
def _normalize(S):
  """Linearly map dB values from [min_level_db, 0] onto [0, 1].

  Values below `min_level_db` become 0 and values above 0 dB become 1.
  """
  scaled = (S - min_level_db) / -min_level_db
  return np.clip(scaled, 0, 1)

In [None]:
left_spec = spectrogram(left_wave)
print(left_spec.shape)

In [None]:
plt.figure(figsize=(16, 4))
librosa.display.specshow(left_spec)
plt.colorbar()
plt.show()

## Convert to tfrecords format

In [None]:
def int64_feature(values):
  """Wrap an integer, or a list/tuple of integers, in a tf.train.Feature.

  Args:
    values: A scalar, or a list/tuple of scalars. Anything that is not a
      list or tuple is treated as a single value.

  Returns:
    A tf.train.Feature whose int64_list holds `values`.
  """
  value_list = values if isinstance(values, (tuple, list)) else [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))


def bytes_feature(values):
  """Wrap one bytes value in a tf.train.Feature.

  Args:
    values: A single bytes object (for example a serialized numpy array);
      it becomes the only element of the feature's bytes_list.

  Returns:
    A tf.train.Feature whose bytes_list holds `values`.
  """
  bytes_list = tf.train.BytesList(value=[values])
  return tf.train.Feature(bytes_list=bytes_list)


def float_feature(values):
  """Wrap a float, or a list/tuple of floats, in a tf.train.Feature.

  Args:
    values: A scalar, or a list/tuple of scalars. Anything that is not a
      list or tuple is treated as a single value.

  Returns:
    A tf.train.Feature whose float_list holds `values`.
  """
  value_list = values if isinstance(values, (tuple, list)) else [values]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value_list))

In [None]:
def _get_dataset_filename(dataset_dir, split_name, shard_id, num_shards):
  """Return the path of one shard, e.g. `<dir>/wav_train_00003-of-00025.tfrecord`."""
  return os.path.join(
      dataset_dir,
      'wav_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, num_shards))

In [None]:
# Target stems stacked after the mixture, in this fixed order, in every example.
_TARGET_NAMES = ('vocals', 'drums', 'bass', 'other', 'accompaniment')


def _load_sources(track):
  """Return [mixture, vocals, drums, bass, other, accompaniment] for `track`.

  Each array is the transposed musdb audio, indexed as (channel, sample).

  Raises:
    ValueError: If any target stem's shape differs from the mixture's.
  """
  mixture = track.audio.T
  sources = [mixture] + [track.targets[name].audio.T for name in _TARGET_NAMES]
  for name, wav in zip(_TARGET_NAMES, sources[1:]):
    if wav.shape != mixture.shape:
      raise ValueError('%s stem shape %s does not match mixture shape %s'
                       % (name, wav.shape, mixture.shape))
  return sources


def _stack_sources(sources, start, end):
  """Slice samples [start:end) from every source and stack them on a new axis 0.

  `end=None` runs to the end of the track. The result has shape
  (len(sources), channels, segment_samples).
  """
  return np.stack([wav[:, start:end] for wav in sources], axis=0)


def _write_example(tfrecord_writer, wav_concat, track_number, split_number):
  """Serialize one stacked segment as a tf.train.Example and write it.

  Only the raw bytes of `wav_concat` and its last-axis length (`time_step`)
  are stored; its dtype and the (num_sources, channels) dimensions are not,
  so readers must know them.
  """
  features = tf.train.Features(feature={
      'wav_concat': bytes_feature(wav_concat.tobytes()),
      'time_step': int64_feature(int(wav_concat.shape[2])),
      'track_number': int64_feature(int(track_number)),
      'split_number': int64_feature(int(split_number)),
  })
  example = tf.train.Example(features=features)
  tfrecord_writer.write(example.SerializeToString())


def convert_dataset(split_name, dataset_dir, N, num_shards, track_list=None,
                    segment_seconds=5, sample_rate=44100):
  """Converts musdb tracks into sharded TFRecord files of fixed-length segments.

  Each track is cut into consecutive `segment_seconds`-long segments, and the
  last segment also absorbs the remainder. That last segment is therefore at
  least one and less than two segments long, except that a track shorter than
  one segment becomes a single example. For every segment, the mixture and the
  five target stems are stacked into one array of shape
  (6, channels, time_step). Each such array is written as one tf.train.Example
  with features wav_concat, time_step, track_number and split_number. Segments
  of a track are numbered 0, 1, ..., consecutively. The first N tracks are
  shuffled before being distributed across the shards.

  Args:
    split_name: The name of the dataset split, either 'train' or 'test'.
    dataset_dir: Directory under which a `split_name` subdirectory holding the
      TFRecord files is created.
    N: Number of tracks to convert (train: 100, test: 50).
    num_shards: Number of TFRecord files the tracks are spread across.
    track_list: Sequence of musdb tracks to read; defaults to the
      notebook-level `tracks`.
    segment_seconds: Length of one segment in seconds.
    sample_rate: Sampling rate of the audio in Hz.

  Raises:
    ValueError: On an unknown split, fewer than one shard, a segment shorter
      than one sample, N larger than the number of available tracks, or a
      stem whose shape differs from the mixture.
  """
  if split_name not in ('train', 'test'):
    raise ValueError("split_name must be 'train' or 'test', got %r" % (split_name,))
  if num_shards < 1:
    raise ValueError('num_shards must be at least 1, got %r' % (num_shards,))
  segment_length = int(sample_rate * segment_seconds)
  if segment_length < 1:
    raise ValueError('segment must contain at least one sample')
  if track_list is None:
    track_list = tracks  # notebook-level list loaded via mus.load_mus_tracks
  if N > len(track_list):
    raise ValueError('N=%d exceeds the %d available tracks' % (N, len(track_list)))

  # data split
  dataset_path = os.path.join(dataset_dir, split_name)
  print(dataset_path)
  if not tf.gfile.Exists(dataset_path):
    tf.gfile.MakeDirs(dataset_path)

  # Shuffle the track order so every shard mixes different tracks.
  permutation_track_number = np.random.permutation(N)

  # Round up: with floor division, N % num_shards tracks would be silently dropped.
  num_per_shard = (N + num_shards - 1) // num_shards
  for shard_id in range(num_shards):
    output_filename = _get_dataset_filename(
        dataset_path, split_name, shard_id, num_shards)
    print('Writing', output_filename)

    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
      start_ndx = shard_id * num_per_shard
      end_ndx = min((shard_id + 1) * num_per_shard, N)

      for i in range(start_ndx, end_ndx):
        sys.stdout.write('\r>> Converting wave %d/%d shard %d\n' % (
            i+1, N, shard_id))
        sys.stdout.flush()

        track_number = int(permutation_track_number[i])
        sources = _load_sources(track_list[track_number])
        number_of_samples = sources[0].shape[1]

        num_split = number_of_samples // segment_length
        print("{}th track; num_split: {}".format(track_number, num_split))
        # Index of the final segment; clamped so tracks shorter than one
        # segment still produce exactly one example starting at sample 0.
        last_split = max(num_split - 1, 0)

        # Every full-length segment except the last one.
        for split_index in range(last_split):
          start = split_index * segment_length
          wav_concat = _stack_sources(sources, start, start + segment_length)
          print("{}th track; {}th split: wav_shape: {}".format(track_number, split_index, wav_concat.shape))
          _write_example(tfrecord_writer, wav_concat, track_number, split_index)

        # The last segment runs to the end of the track so no audio is dropped.
        wav_concat = _stack_sources(sources, last_split * segment_length, None)
        print("{}th track; {}th split: wav_shape: {}".format(track_number, last_split, wav_concat.shape))
        _write_example(tfrecord_writer, wav_concat, track_number, last_split)

In [None]:
tfrecords_dir = './datasets/tfrecords'
# Derive the settings from the split loaded above. Previously N and NUM_SHARDS
# were hard-coded for 'train', so switching split_name to 'test' (50 tracks)
# would have indexed past the end of `tracks`.
NUM_SHARDS = {'train': 25, 'test': 10}[split_name]
N = len(tracks)  # train: 100 tracks, test: 50 tracks
convert_dataset(split_name, tfrecords_dir, N, NUM_SHARDS)