# Single Channel Source Separation
TensorFlow implementation of the convolutional autoencoder from [DeepConvSep](https://github.com/MTG/DeepConvSep) (Chandna et al., 2017).

## Initial setup
For unknown reasons the GPU is not recognized in the `elec576` conda environment, so use the `vizdoom` environment for now.

In [1]:
# If using one or multiple GPUs
#import os
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
#from tensorflow.python.client import device_lib
#print(device_lib.list_local_devices())

In [2]:
import tensorflow as tf
import numpy as np
import math
import os, errno
import re
from time import time, sleep
#from tqdm import tnrange, tqdm_notebook

In [3]:
def make_directory(f):
    """Create directory ``f`` (and any missing parents) if it does not exist.

    Calling it on an existing directory is a no-op.

    Args:
        f: Path of the directory to create.

    Raises:
        FileExistsError: if ``f`` exists but is not a directory. The previous
            errno check silently accepted this case, so later writes into the
            "directory" failed confusingly.
        OSError: for any other failure (e.g. permission denied).
    """
    # exist_ok=True tolerates only an existing *directory*; an existing
    # regular file at the same path still raises FileExistsError.
    os.makedirs(f, exist_ok=True)

## Layer definitions
The decoding portion of the network was initially unclear to me, so I checked the Lasagne documentation for clarification. According to the Lasagne docs:

---
The `InverseLayer` class performs inverse operations for a single layer of a neural network by applying the partial derivative of the layer to be inverted with respect to its input: transposed layer for a `DenseLayer`, deconvolutional layer for `Conv2DLayer`, `Conv1DLayer`; or an unpooling layer for `MaxPool2DLayer`.

It is specially useful for building (convolutional) autoencoders with tied parameters.

Note that if the layer to be inverted contains a nonlinearity and/or a bias, the InverseLayer will include the derivative of that in its computation.

---
For convolutional networks, applying the derivative of a layer with respect to its input (i.e. dCONV/dX) amounts to multiplying by the transpose of its weight matrix (assuming no activation function in between). That is why, for convolutional layers, `conv2d_transpose` (using the same weights) and `InverseLayer` are equivalent.

In [4]:

def _check_list(arg):
    if isinstance(arg, list):
        try:
            return arg[0], arg[1:]
        except IndexError:
            return arg[0], []
    else:
        return arg, []

def _get_variable_initializer(init_type, var_shape, *args):
    if init_type == "random_normal":
        mean = float(args[0])
        stddev = float(args[1])
        return tf.random_normal(var_shape, mean=mean, stddev=stddev)
    elif init_type == "truncated_normal":
        mean = float(args[0])
        stddev = float(args[1])
        return tf.truncated_normal(var_shape, mean=mean, stddev=stddev)
    elif init_type == "constant":
        c = args[0]
        return tf.constant(c, dtype=tf.float32, shape=var_shape)
    elif init_type == "xavier":
        n_in = tf.cast(args[0], tf.float32)
        return tf.div(tf.random_normal(var_shape), tf.sqrt(n_in))
    else:
        raise ValueError("Variable initializer \"" + init_type + "\" not supported.")

def _apply_normalization(norm_type, x, *args, **kwargs):
    if norm_type == "batch_norm":
        return batch_norm(x, *args, **kwargs)
    else:
        raise ValueError("Normalization type \"" + norm_type + "\" not supported.")

def _apply_activation(activation_type, x, *args):
    if activation_type.lower() == "relu":
        return tf.nn.relu(x, name="Relu")
    elif activation_type.lower() == "leaky_relu":
        return tf.maximum(x, 0.1 * x, name="Leaky_Relu")
    elif activation_type.lower() == "softmax":
        return tf.nn.softmax(x)
    elif activation_type.lower() == "none":
        return x
    else:
        raise ValueError("Activation type \"" + activation_type + "\" not supported.")
        
def conv2d(input_layer,
           num_outputs,
           kernel_size,
           stride=1,
           padding="VALID",
           data_format="NCHW",
           normalizer_fn=None,
           activation_fn=None,
           weights_initializer="random_normal",
           biases_initializer=None,
           trainable=True,
           scope="CONV"):
    """2-D convolution layer with optional normalization or bias, then activation.

    Args:
        input_layer: 4-D input tensor laid out according to ``data_format``.
        num_outputs: Number of output channels (filters).
        kernel_size: ``[k_h, k_w]`` (list or tuple). The weight variable has
            shape ``[k_h, k_w, in_channels, num_outputs]``.
        stride: int, or ``[stride_h]`` / ``[stride_h, stride_w]``. A single value
            is used for both spatial dimensions.
        padding: "VALID" or "SAME", passed to ``tf.nn.conv2d``.
        data_format: "NHWC" or "NCHW".
        normalizer_fn: Normalization name, or ``[name, *args]``. When given, no
            bias is added (see the ``elif`` below).
        activation_fn: Activation name, or ``[name, *args]``; None for linear.
        weights_initializer: Initializer name, or ``[name, *args]``.
        biases_initializer: Initializer name/list for a per-channel bias, or
            None for no bias.
        trainable: Whether the created variables are trainable.
        scope: Name scope of the ops and of the "weights"/"biases" variables
            (Network.build_graph looks the weights up by this name).

    Returns:
        The output tensor.

    Raises:
        ValueError: if ``data_format`` is not "NHWC" or "NCHW". This previously
            surfaced as an UnboundLocalError further down.
    """
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError("Unknown data format \"" + str(data_format) + "\"")
    channel_axis = 3 if data_format == "NHWC" else 1

    with tf.name_scope(scope):
        input_shape = input_layer.get_shape().as_list()

        # Create weights: [k_h, k_w, in_channels, num_outputs]
        W_init_type, W_init_params = _check_list(weights_initializer)
        with tf.name_scope(W_init_type + "_initializer"):
            W_shape = list(kernel_size) + [input_shape[channel_axis], num_outputs]
            if W_init_type == "xavier":
                # NOTE(review): n_in is the size of the whole input feature map
                # (all non-batch dims), not the usual conv fan-in
                # k_h * k_w * in_channels. Kept unchanged so existing results
                # stay reproducible; confirm whether this was intended.
                n_in = tf.reduce_prod(input_shape[1:])
                W_init_params = [n_in]
            W_init = _get_variable_initializer(W_init_type,
                                               W_shape,
                                               *W_init_params)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        # Convolve input; a scalar or one-element stride applies to both axes.
        stride_h, stride_rest = _check_list(stride)
        stride_w = stride_rest[0] if stride_rest else stride_h
        if data_format == "NHWC":
            strides = [1, stride_h, stride_w, 1]
        else:
            strides = [1, 1, stride_h, stride_w]
        out = tf.nn.conv2d(input_layer,
                           filter=W,
                           strides=strides,
                           padding=padding,
                           data_format=data_format,
                           name="convolution")

        # Apply normalization, or else add a per-channel bias. Presumably the
        # normalizer supplies its own offset, hence the mutually exclusive branches.
        if normalizer_fn is not None:
            norm_type, norm_params = _check_list(normalizer_fn)
            out = _apply_normalization(norm_type,
                                       out,
                                       *norm_params,
                                       data_format=data_format)
        elif biases_initializer is not None:
            b_init_type, b_init_params = _check_list(biases_initializer)
            # Broadcastable bias: size num_outputs on the channel axis only.
            b_shape = [1, 1, 1, 1]
            b_shape[channel_axis] = num_outputs
            b_init = _get_variable_initializer(b_init_type,
                                               b_shape,
                                               *b_init_params)
            b = tf.Variable(b_init,
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        # Apply activation
        if activation_fn is not None:
            act_type, act_params = _check_list(activation_fn)
            out = _apply_activation(act_type, out, *act_params)

        return out

def inverse_conv2d(x,
                   output_shape,
                   conv_weights,
                   conv_stride,
                   padding="VALID",
                   data_format="NCHW",
                   scope="CONV_T"):
    """Invert a ``conv2d`` layer using that layer's own (tied) weights.

    Implemented with ``tf.nn.conv2d_transpose``. The inverted layer's bias is
    neither subtracted nor re-added, and no activation is applied.

    Args:
        x: 4-D tensor shaped like the output of the conv layer being inverted.
        output_shape: Full result shape including the batch dimension,
            normally the input shape of the conv layer being inverted. The
            batch entry may be a runtime value such as ``tf.shape(t)[0]``.
        conv_weights: Weights of the conv layer being inverted, shaped
            ``[k_h, k_w, in_channels, out_channels]`` as ``conv2d`` creates them.
        conv_stride: Stride used by that conv layer (int, ``[h]`` or ``[h, w]``).
        padding: Padding used by that conv layer.
        data_format: "NHWC" or "NCHW".
        scope: Name scope for the created ops.

    Returns:
        Tensor of shape ``output_shape``.

    Raises:
        ValueError: if ``data_format`` is not "NHWC" or "NCHW" (previously an
            UnboundLocalError on ``strides``).
    """
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError("Unknown data format \"" + str(data_format) + "\"")

    with tf.name_scope(scope):
        # conv2d_transpose expects its filter as
        # [height, width, output_channels, in_channels], where "output" is the
        # transpose's output, i.e. the conv layer's input. That matches the
        # conv layer's own [k_h, k_w, in_ch, out_ch] layout, so the shared
        # weights are used as-is with no transpose.
        W = conv_weights

        # A scalar or one-element stride applies to both spatial axes.
        stride_h, stride_rest = _check_list(conv_stride)
        stride_w = stride_rest[0] if stride_rest else stride_h
        if data_format == "NHWC":
            strides = [1, stride_h, stride_w, 1]
        else:
            strides = [1, 1, stride_h, stride_w]

        return tf.nn.conv2d_transpose(x,
                                      filter=W,
                                      output_shape=output_shape,
                                      strides=strides,
                                      padding=padding,
                                      data_format=data_format,
                                      name="convolution_transpose")
    
def flatten(input_layer,
            data_format="NCHW",
            scope="FLAT"):
    """Collapse all non-batch dimensions of ``input_layer`` into one.

    The reshape target ``[batch, prod(other dims)]`` is computed from the
    runtime shape, so an unknown batch size is fine. NHWC input is first
    transposed with perm ``[0, 3, 1, 2]`` (so it must be 4-D). Features are
    therefore flattened in channel-major order for either layout. The static
    shape of the result is then set as precisely as graph-build time allows.

    Returns:
        2-D tensor ``[batch, num_features]``.
    """
    with tf.name_scope(scope):
        # Runtime target shape: [batch, product of remaining dims]
        dynamic_shape = tf.shape(input_layer)
        rank = input_layer.get_shape().ndims
        batch_dim = tf.slice(dynamic_shape, [0], [1])
        feature_dims = tf.slice(dynamic_shape, [1], [rank - 1])
        feature_count = tf.expand_dims(tf.reduce_prod(feature_dims), 0)
        target_shape = tf.concat([batch_dim, feature_count], 0)

        source = input_layer
        if data_format == "NHWC":
            source = tf.transpose(source, perm=[0, 3, 1, 2])
        flat = tf.reshape(source, target_shape)

        # Static shape hint: exact feature count if every non-batch dim is
        # known at graph-build time, otherwise leave it undetermined.
        static_shape = source.get_shape().as_list()
        static_batch, static_rest = static_shape[0], static_shape[1:]
        if not all(static_rest):
            flat.set_shape([static_batch, None])
        else:
            count = 1
            for dim in static_rest:
                count *= dim
            flat.set_shape([static_batch, count])
        return flat

def fully_connected(input_layer,
                    num_outputs,
                    normalizer_fn=None,
                    activation_fn=None,
                    weights_initializer="random_normal",
                    biases_initializer=None,
                    trainable=True,
                    scope="FC"):
    """Dense layer: ``x @ W``, then normalization or bias, then activation.

    Args:
        input_layer: 2-D tensor ``[batch, n_in]``; ``n_in`` must be known
            at graph-build time.
        num_outputs: Number of output units.
        normalizer_fn: Normalization name, or ``[name, *args]``. When given, no
            bias is added.
        activation_fn: Activation name, or ``[name, *args]``; None for linear.
        weights_initializer: Initializer name, or ``[name, *args]``.
        biases_initializer: Initializer name/list for the bias, or None.
        trainable: Whether the created variables are trainable.
        scope: Name scope for the ops and the "weights"/"biases" variables.

    Returns:
        Tensor ``[batch, num_outputs]``.
    """
    with tf.name_scope(scope):
        n_in = input_layer.get_shape().as_list()[1]

        # Weights [n_in, num_outputs]
        init_name, init_args = _check_list(weights_initializer)
        with tf.name_scope(init_name + "_initializer"):
            if init_name == "xavier":
                # The fan-in of a dense layer is its input width.
                init_args = [tf.reduce_prod(n_in)]
            W_init = _get_variable_initializer(init_name,
                                               [n_in, num_outputs],
                                               *init_args)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        out = tf.matmul(input_layer, W)

        # Normalization and bias are mutually exclusive.
        if normalizer_fn is not None:
            norm_name, norm_args = _check_list(normalizer_fn)
            out = _apply_normalization(norm_name,
                                       out,
                                       *norm_args,
                                       data_format=None)
        elif biases_initializer is not None:
            bias_name, bias_args = _check_list(biases_initializer)
            b = tf.Variable(_get_variable_initializer(bias_name,
                                                      [num_outputs],
                                                      *bias_args),
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        if activation_fn is not None:
            act_name, act_args = _check_list(activation_fn)
            out = _apply_activation(act_name, out, *act_args)

        return out

## Dataset handling
The training data is loaded into memory a few mixture/source feature-file pairs at a time. Either as many whole files as fit within `mem_len` time frames are loaded, or exactly one file when `load_by_file=True`.
There are two approaches to using memory mapping, and unfortunately neither completely avoids moving large amounts of data at initialization:
1. Use `np.memmap` following instructions from [this stackoverflow question](https://stackoverflow.com/questions/13780907/is-it-possible-to-np-concatenate-memory-mapped-files). The third (placeholder) array ends up writing all data to the file of the initial array. This in essence creates a single, giant array that contains the concatenated information from all songs along the time axis. While this would work, it doubles the amount of space on the hard drive if not deleted after training, and takes a long time (~30 min) to initialize if deleted after every use.
2. Use the `mmap_mode` arg in `np.load`. While this works for single files, any results of manipulation of the arrays (e.g. `np.concatenate` along time axis) are loaded into memory, which defeats the purpose of using mmap in the first place.

I think the best compromise is to use `np.load(filename, mmap_mode='r')` to point to the arrays and grab shapes initially, which can be used to track the global time point. Then some number of files can be loaded at a time that correspond to the number of time points to load at a time. Class variables can track the global time point, time point within loaded files, etc.

In [5]:
class Dataset(object):
    """Batches of (mixture, sources) magnitude spectrograms streamed from .npy files.

    Each file holds a 2-D ``[time, features]`` array. All files are opened as
    read-only memory maps at construction time. Only a block of whole files
    (up to ``mem_len`` time frames, or exactly one file when ``load_by_file``
    is True) is copied into RAM at once. Files are paired by sorted name: the
    i-th mixture file goes with the i-th file of every source.
    """

    def __init__(self,
                 input_dir,
                 target_dir,
                 batch_size=32,
                 time_context=30,
                 mem_len=1e5,
                 load_by_file=False,
                 sources=('bass', 'drums', 'other', 'vocal'),
                 data_format="NHWC",
                 shuffle=True,
                 scale_factor=1.0,
                 verbose=True):
        """Index the spectrogram files and load the first block into memory.

        Args:
            input_dir: Directory of mixture spectrograms; every file whose name
                contains "mag" is used.
            target_dir: Directory of source spectrograms. A file belongs to a
                source when its file name contains "mag" and the source name.
            batch_size: Number of ``time_context``-frame windows per batch.
            time_context: Frames per window.
            mem_len: Maximum number of time frames held in RAM at once (only
                whole files are loaded).
            load_by_file: Load exactly one file at a time instead.
            sources: Source names, in output-channel order.
            data_format: "NHWC" or "NCHW" layout of the returned batches.
            shuffle: Shuffle ``time_context``-sized chunks after every load.
            scale_factor: Multiplier applied to inputs and targets when loaded.
            verbose: Print progress messages.

        Raises:
            ValueError: if no mixture files are found, if a source's file count
                differs from the mixture file count, or if spectrogram shapes
                disagree.
        """
        # Grab arguments
        self.batch_size = batch_size
        self.time_context = time_context
        self.num_sources = len(sources)
        self.data_format = data_format
        self.shuffle = shuffle
        self.scale_factor = scale_factor  # used by load_memory (was a global)
        self.verbose = verbose
        self.load_by_file = load_by_file

        # Get data files
        self.input_files = sorted([os.path.join(input_dir, f)
                                   for f in os.listdir(input_dir) if "mag" in f])
        target_files = sorted([os.path.join(target_dir, f)
                               for f in os.listdir(target_dir) if "mag" in f])
        if not self.input_files:
            raise ValueError("No magnitude spectrogram files found in " + str(input_dir))
        self.source_files = []  # [source_id][file]
        for s in sources:
            # Match on the file name only, so a directory name that happens to
            # contain a source name cannot pull in every target file.
            matches = [t for t in target_files if s in os.path.basename(t)]
            if len(matches) != len(self.input_files):
                raise ValueError("Found %d files for source \"%s\" but %d mixture files."
                                 % (len(matches), s, len(self.input_files)))
            self.source_files.append(matches)

        # Read mem maps of data files
        self.inputs = []
        self.sources = []  # [file][source_id]
        self.shapes = []
        shapes_ = []  # per-file source shapes, checked against the mixture shapes
        n_files = len(self.input_files)
        for i, in_file in enumerate(self.input_files):
            if verbose:
                end = '\n' if i == n_files - 1 else '\r'
                print("Reading file %d of %d" % (i+1, n_files), end=end)

            # Add mem map of mixture file
            f_in = np.load(in_file, mmap_mode='r')
            self.inputs.append(f_in)

            # Add mem maps of source files
            f_s = [np.load(files[i], mmap_mode='r') for files in self.source_files]
            self.sources.append(f_s)

            # Get shapes
            self.shapes.append(f_in.shape)
            shapes_.append([f.shape for f in f_s])

        # Set class variables
        self.shapes = np.asarray(self.shapes)
        self.t_total = np.sum(self.shapes[:, 0])
        self.feat_size = self.shapes[0, 1]

        # Check that shapes [file, array, dim] agree with the mixture for every file
        shapes_ = np.concatenate([self.shapes[:, np.newaxis, :], np.asarray(shapes_)], axis=1)
        if not (shapes_[:, :, 0].T == shapes_[:, 0, 0]).all():
            raise ValueError("All spectrograms must be of same length.")
        if not (shapes_[:, :, 1].T == shapes_[:, 0, 1]).all():
            raise ValueError("All spectograms must have same number of features.")
        shapes_ = None  # release from memory

        # Set variables to track time position
        self.t_c = np.cumsum(self.shapes[:, 0])  # end frame (exclusive) of each file
        self.mem_len = mem_len
        self.t = 0  # current global time index
        self.t_ = 0  # current index into the in-memory block
        self.reset = False  # if True, the next batch restarts from the first file
        self.load_next = False  # if True, the next batch loads the next block
        self.inputs_, self.sources_ = [], []
        self.load_memory()

    def load_memory(self):
        """Copy the next block of whole files, starting at ``self.t``, into RAM.

        Fills ``self.inputs_`` ([time, features]) and ``self.sources_``
        ([time, features, num_sources]), scales both by ``scale_factor``,
        optionally shuffles them and resets ``self.t_``.

        Raises:
            ValueError: if not even one file fits within ``mem_len`` frames.
        """
        # Files whose end frame lies within [t, t + mem_len], or the next file
        t_start = self.t
        t_end = self.t + self.mem_len
        idx_start = np.searchsorted(self.t_c, t_start, side='right')
        if self.load_by_file:
            idx_end = idx_start + 1
        else:
            idx_end = np.searchsorted(self.t_c, t_end, side='right')
        if idx_end - idx_start < 1:
            raise ValueError("Unable to load next file. Check mem_len (must be greater than file size)")
        if self.verbose:
            print("Loading files %d to %d of %d into memory..." % (idx_start+1, idx_end, len(self.inputs)))

        # Release the previous block before allocating the next one
        del self.sources_
        del self.inputs_

        # Filling a preallocated block is much faster than np.concatenate on mem maps
        t_shape = np.sum(self.shapes[idx_start:idx_end, 0])
        self.inputs_ = np.zeros([t_shape, self.feat_size])
        t_idx = 0
        for i in range(idx_start, idx_end):
            n = self.shapes[i, 0]
            self.inputs_[t_idx:t_idx+n] = self.inputs[i]
            t_idx += n

        # Sources are stacked one file at a time so only one file's worth of
        # source data is materialized in addition to the block itself.
        self.sources_ = np.zeros([t_shape, self.feat_size, self.num_sources])
        t_idx = 0
        for i in range(idx_start, idx_end):
            n = self.shapes[i, 0]
            # [num_sources, time, features] -> [time, features, num_sources]
            self.sources_[t_idx:t_idx+n] = np.transpose(np.asarray(self.sources[i]), [1, 2, 0])
            t_idx += n

        # Scale if specified
        self.inputs_ *= self.scale_factor
        self.sources_ *= self.scale_factor

        # Shuffle chunks of size time context
        if self.shuffle:
            self.shuffle_data()

        # Reset counter
        self.t_ = 0

    def shuffle_data(self):
        """Shuffle the in-memory block in whole ``time_context`` chunks.

        Inputs and sources are permuted identically. A trailing partial chunk
        is kept unshuffled at the end.
        """
        # This increases memory requirements by ~50% (copying self.inputs_ or
        # self.sources_ during the reshape/concatenate operations).
        if self.verbose:
            print("Shuffling data...")

        # Reshape data into chunks of size time_context and save ends
        concat = False
        if self.inputs_.shape[0] % self.time_context != 0:
            end_idx = self.inputs_.shape[0] // self.time_context * self.time_context
            end_inputs = self.inputs_[end_idx:]
            self.inputs_ = self.inputs_[:end_idx]
            end_sources = self.sources_[end_idx:]
            self.sources_ = self.sources_[:end_idx]
            concat = True
        self.inputs_ = self.inputs_.reshape([-1, self.time_context, self.feat_size])
        self.sources_ = self.sources_.reshape([-1, self.time_context, self.feat_size, self.num_sources])

        # Shuffle inputs and sources in unison by replaying the RNG state
        rng_state = np.random.get_state()
        np.random.shuffle(self.inputs_)
        np.random.set_state(rng_state)
        np.random.shuffle(self.sources_)

        # Reshape data into original shape [time, features, [num_sources]]
        self.inputs_ = self.inputs_.reshape([-1, self.feat_size])
        self.sources_ = self.sources_.reshape([-1, self.feat_size, self.num_sources])
        if concat:
            self.inputs_ = np.concatenate([self.inputs_, end_inputs])
            self.sources_ = np.concatenate([self.sources_, end_sources])

    def reset_database(self):
        """Rewind to the first file and reload the first memory block."""
        if self.verbose:
            print("Resetting database...")
        self.t = 0
        self.t_ = 0
        self.reset = False
        self.load_next = False
        self.load_memory()

    def create_batch(self):
        """Return the next ``(inputs_batch, sources_batch)`` pair.

        A batch holds up to ``batch_size`` windows of ``time_context`` frames.
        When fewer than one further window would remain in the in-memory block
        after a full batch, this batch takes every remaining whole window and
        skips the leftover frames. The following call then loads the next
        block, or restarts from the first file if the block reached the end
        of the dataset (``is_empty()`` is True in that case).

        Returns:
            NHWC: ``[n, time_context, feat_size, 1]`` and
            ``[n, time_context, feat_size, num_sources]``; NCHW moves the last
            axis to position 1.
        """
        # Load memory if exhausted
        if self.reset:
            self.reset_database()
        elif self.load_next:
            self.load_memory()

        # Determine batch length and set loading flags for the next batch.
        # Decide on what is left *in memory*: a batch must never extend past
        # the loaded block, otherwise unloaded files would be skipped.
        batch_len = self.batch_size * self.time_context
        rem = 0
        self.reset = False
        self.load_next = False
        mem_remaining = self.inputs_.shape[0] - self.t_
        if mem_remaining < batch_len + self.time_context:
            batch_len = mem_remaining // self.time_context * self.time_context
            rem = mem_remaining % self.time_context
            # The block ends at a file boundary; was it the last file?
            if self.t + mem_remaining >= self.t_total:
                self.reset = True
            else:
                self.load_next = True

        # Get batches of inputs and sources
        inputs_batch = np.reshape(self.inputs_[self.t_:self.t_+batch_len],
                                  [-1, self.time_context, self.feat_size, 1])
        sources_batch = np.reshape(self.sources_[self.t_:self.t_+batch_len],
                                   [-1, self.time_context, self.feat_size, self.num_sources])
        if self.data_format == "NCHW":
            inputs_batch = np.transpose(inputs_batch, axes=[0, 3, 1, 2])
            sources_batch = np.transpose(sources_batch, axes=[0, 3, 1, 2])

        # Increment counters (leftover frames are skipped)
        self.t_ += batch_len + rem
        self.t += batch_len + rem

        return inputs_batch, sources_batch

    def is_empty(self):
        """True once the last batch of the dataset has been returned."""
        return self.reset

    def remaining_batches(self):
        """Number of (possibly partial) batches left in the in-memory block."""
        return math.ceil((self.inputs_.shape[0] - self.t_)
                         / (self.batch_size * self.time_context))

    def remaining_files(self):
        """Number of files from the current global position to the end."""
        return len(self.inputs) - np.searchsorted(self.t_c, self.t, side='right')

## Graph building
For comparison, check out [this example](https://github.com/pkmital/tensorflow_tutorials/blob/master/python/09_convolutional_autoencoder.py) on github.

In [6]:
class Network:
    """Convolutional autoencoder for single-channel source separation.

    Encoder: CONV_1 (kernel [1, 30], stride [1, 4]) -> CONV_2 (kernel
    [2*time_context/3, 1]) -> FC_1 (256 units, ReLU). For each source a
    decoder FC_2_i -> CONVT_1_i -> CONVT_2_i mirrors the encoder, reusing the
    CONV_2 and CONV_1 weights in transposed convolutions. The rectified
    decoder outputs are normalized into soft masks across sources, and the
    masks are multiplied with the input mixture spectrogram.

    Construction resets the TensorFlow default graph, builds the graph, opens
    a session, initializes all variables and optionally restores a checkpoint.
    """

    def __init__(self,
                 results_dir,
                 params_file=None,
                 data_format="NCHW",
                 time_context=30,
                 feat_size=513,
                 num_sources=4,
                 alpha=0.001,
                 mean_loss=True,
                 verbose=True,
                 train_mode=True,
                 scope=""):
        """Build the graph, start a session and initialize (or restore) variables.

        Args:
            results_dir: Directory for TensorBoard summaries and checkpoints.
            params_file: Optional checkpoint path restored after initialization.
            data_format: "NHWC" or "NCHW" layout of the placeholders.
            time_context: Spectrogram frames per example.
            feat_size: Frequency bins per frame.
            num_sources: Number of sources (decoder branches / output channels).
            alpha: Adam learning rate.
            mean_loss: Average (True) or sum (False) the squared error.
            verbose: Print progress while building the graph.
            train_mode: Stored as ``self.train_mode``; not used by this class.
            scope: Outer name scope of the graph ("" for none).
        """
        # Get args
        self.results_dir = results_dir
        self.data_format = data_format
        self.time_context = time_context
        self.feat_size = feat_size
        self.num_sources = num_sources
        self.alpha = alpha
        self.scope = scope
        self.verbose = verbose
        self.train_mode = train_mode

        # Build graph
        tf.reset_default_graph()
        with tf.name_scope(scope):
            self.build_graph(mean_loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        if params_file is not None:
            self.load_model(params_file)

        self.global_step = 0

    def build_graph(self, mean_loss):
        """Build placeholders, encoder, decoders, masks, loss, optimizer and summaries.

        Args:
            mean_loss: Use ``tf.reduce_mean`` (True) or ``tf.reduce_sum``
                (False) for the per-source and total losses.

        Raises:
            ValueError: if ``self.data_format`` is not "NHWC" or "NCHW".
        """
        eps = 1e-18 # numerical stability
        self.graph = tf.get_default_graph()

        # Data formatting
        if self.verbose:
            print("Building input...")
        if self.data_format == "NHWC":
            input_shape = [None, self.time_context, self.feat_size, 1]
            target_shape = [None, self.time_context, self.feat_size, self.num_sources]
            channel_dim = 3
        elif self.data_format == "NCHW":
            input_shape = [None, 1, self.time_context, self.feat_size]
            target_shape = [None, self.num_sources, self.time_context, self.feat_size]
            channel_dim = 1
        else:
            raise ValueError("Unknown data format \"" + self.data_format + "\"")
        self.spectrogram = tf.placeholder(tf.float32,
                                          shape=input_shape,
                                          name="magnitude_spectrogram")

        # Encoder
        if self.verbose:
            print("Building encoder...")
        # Convolutional layer 1
        self.conv1 = conv2d(self.spectrogram,
                            num_outputs=30,
                            kernel_size=[1, 30],
                            stride=[1, 4],
                            padding="VALID",
                            data_format=self.data_format,
                            weights_initializer="xavier",
                            biases_initializer=["constant", 0.0],
                            scope="CONV_1")

        # Convolutional layer 2
        self.conv2 = conv2d(self.conv1,
                            num_outputs=30,
                            kernel_size=[int(2*self.time_context/3), 1],
                            stride=[1, 1],
                            padding="VALID",
                            data_format=self.data_format,
                            weights_initializer="xavier",
                            biases_initializer=["constant", 0.0],
                            scope="CONV_2")
        self.conv2_flat = flatten(self.conv2,
                                  data_format=self.data_format,
                                  scope="CONV_2_FLAT")

        # Fully-connected layer 1 (encoding)
        self.fc1 = fully_connected(self.conv2_flat,
                                   num_outputs=256,
                                   activation_fn="relu",
                                   weights_initializer="xavier",
                                   biases_initializer=["constant", 0.0],
                                   scope="FC_1")
        # Decoder
        if self.verbose:
            print("Building decoder...")

        # conv2d does not return its variables, so the decoder looks the
        # encoder's weights up by name in order to tie them.
        batch_size = tf.shape(self.spectrogram)[0]
        prefix = self.scope + '/' if self.scope else ""
        conv1_shape = self.conv1.get_shape().as_list()
        conv1_weights = self.graph.get_tensor_by_name(prefix + "CONV_1/weights:0")
        conv2_shape = self.conv2.get_shape().as_list()
        conv2_size = conv2_shape[1] * conv2_shape[2] * conv2_shape[3]
        conv2_weights = self.graph.get_tensor_by_name(prefix + "CONV_2/weights:0")

        # Build decoder for each source
        self.fc2, self.convt1, self.convt2 = [], [], []
        for i in range(self.num_sources):
            # Fully-connected layer 2 (decoding)
            fc2_i = fully_connected(self.fc1,
                                    num_outputs=conv2_size,
                                    activation_fn="relu",
                                    weights_initializer="xavier",
                                    biases_initializer=["constant", 0.0],
                                    scope="FC_2_%d" % (i+1))
            self.fc2.append(fc2_i)

            # Convolutional transpose layer 1
            # tf.reshape() can infer one dimension (-1), but
            # tf.nn.conv2d_transpose() needs the exact output shape; only the
            # batch size is taken at runtime via tf.shape().
            fc2_i = tf.reshape(fc2_i, [-1] + conv2_shape[1:])
            convt1_i = inverse_conv2d(fc2_i,
                                      output_shape=[batch_size] + conv1_shape[1:],
                                      conv_weights=conv2_weights,
                                      conv_stride=[1, 1],
                                      padding="VALID",
                                      data_format=self.data_format,
                                      scope="CONVT_1_%d" % (i+1))
            self.convt1.append(convt1_i)

            # Convolutional transpose layer 2
            convt2_i = inverse_conv2d(convt1_i,
                                      output_shape=[batch_size] + input_shape[1:],
                                      conv_weights=conv1_weights,
                                      conv_stride=[1, 4],
                                      padding="VALID",
                                      data_format=self.data_format,
                                      scope="CONVT_2_%d" % (i+1))
            self.convt2.append(convt2_i)

        # Output
        if self.verbose:
            print("Building output...")

        # Output layer: per-source bias, then rectification
        with tf.name_scope("y_hat"):
            convt2_all = tf.concat(self.convt2, axis=channel_dim)
            b_shape = [1, 1, 1, 1]
            b_shape[channel_dim] = self.num_sources
            b = tf.Variable(tf.constant(0.0, shape=b_shape),
                            dtype=tf.float32,
                            name="bias")
            self.y_hat = tf.maximum(tf.add(convt2_all, b), 0, name="y_hat")

        # Masks: m_n(f) = |y_hat_n(f)| / Σ(|y_hat_n'(f)|)
        with tf.name_scope("masks"):
            # Tiny random term keeps the denominator non-zero.
            rand = tf.random_uniform([batch_size] + input_shape[1:])
            den = tf.reduce_sum(self.y_hat, axis=channel_dim, keep_dims=True) + (eps * rand)
            self.masks = tf.div(self.y_hat, den, name="masks") # broadcast along channel dimension

        # Source signals: y_tilde_n(f) = m_n(f) * x(f),
        # where x(f) is the spectrogram of the input mixture signal
        with tf.name_scope("y_tilde"):
            self.y_tilde = tf.multiply(self.masks, self.spectrogram, name="y_tilde") # broadcast along channel dimension

        if self.verbose:
            print("Building losses and summaries...")

        # Loss function: L = 1/N * Σ(||y_tilde_n - target_n||^2)
        # Changed from total to mean loss to account for different feature sizes
        with tf.name_scope("loss"):
            self.targets = tf.placeholder(tf.float32,
                                          shape=target_shape,
                                          name="target_sources")
            # Reduce over every axis except the source/channel axis
            reduc_indices = [i for i in range(4) if i != channel_dim]
            if mean_loss:
                loss_fn = tf.reduce_mean
            else:
                loss_fn = tf.reduce_sum
            self.loss_n = loss_fn(tf.square(self.y_tilde - self.targets),
                                  axis=reduc_indices,
                                  name="loss_n")
            self.loss_total = loss_fn(self.loss_n, name="loss_total")

        # Optimizer
        with tf.name_scope("train_step"):
            self.optimizer = tf.train.AdamOptimizer(self.alpha)
            self.train_step = self.optimizer.minimize(self.loss_total)

        # Summaries
        self.saver = tf.train.Saver(max_to_keep=1)
        self.writer = tf.summary.FileWriter(self.results_dir, self.graph)
        with tf.name_scope("summaries"):
            # Loss summaries
            loss_sum = []
            with tf.name_scope("losses"):
                for i in range(self.num_sources):
                    loss_sum.append(tf.summary.scalar("loss_%d" % (i+1), self.loss_n[i]))
                loss_sum.append(tf.summary.scalar("loss_total", self.loss_total))
                self.loss_sum = tf.summary.merge(loss_sum)

            # Variable summaries
            var_sum = []
            with tf.name_scope("trainable_variables"):
                for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                    with tf.name_scope(var.name[:-2]):
                        mean = tf.reduce_mean(var)
                        var_sum.append(tf.summary.scalar("mean", mean))
                        with tf.name_scope("stddev"):
                            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
                        var_sum.append(tf.summary.scalar("stddev", stddev))
                        var_sum.append(tf.summary.scalar("max", tf.reduce_max(var)))
                        var_sum.append(tf.summary.scalar("min", tf.reduce_min(var)))
                        var_sum.append(tf.summary.histogram("histogram", var))
                self.var_sum = tf.summary.merge(var_sum)

    def perform_training_step(self, x, y):
        """Run one optimizer step on batch ``(x, y)``.

        Returns the total loss fetched in that same run and increments
        ``self.global_step``.
        """
        feed_dict = {self.spectrogram: x, self.targets: y}
        loss_, _ = self.sess.run([self.loss_total, self.train_step],
                                 feed_dict=feed_dict)
        self.global_step += 1
        return loss_

    def predict(self, x, y=None):
        """Estimate the separated sources for mixture spectrograms ``x``.

        Returns ``y_tilde``; when targets ``y`` are given, returns
        ``(y_tilde, loss_total)`` instead.
        """
        if y is None:
            # sess.run on a single tensor returns that tensor's value itself,
            # not a one-element list. The old ``y_tilde_, = ...`` unpacked the
            # batch axis, which failed for any batch larger than one.
            return self.sess.run(self.y_tilde,
                                 feed_dict={self.spectrogram: x})
        feed_dict = {self.spectrogram: x, self.targets: y}
        y_tilde_, loss_ = self.sess.run([self.y_tilde, self.loss_total],
                                        feed_dict=feed_dict)
        return y_tilde_, loss_

    def save_summaries(self, x, y):
        """Write loss and variable summaries for batch ``(x, y)`` at ``global_step``."""
        feed_dict = {self.spectrogram: x, self.targets: y}
        loss_sum_, var_sum_ = self.sess.run([self.loss_sum, self.var_sum],
                                            feed_dict=feed_dict)
        self.writer.add_summary(loss_sum_, global_step=self.global_step)
        self.writer.add_summary(var_sum_, global_step=self.global_step)
        self.writer.flush()

    def save_model(self, epoch, save_meta=True):
        """Save a checkpoint named ``model-<epoch>`` inside ``results_dir``."""
        # os.path.join works whether or not results_dir ends with a separator.
        self.saver.save(self.sess, os.path.join(self.results_dir, "model"),
                        global_step=epoch, write_meta_graph=save_meta)

    def load_model(self, params_file):
        """Restore all saved variables from checkpoint ``params_file``."""
        self.saver.restore(self.sess, params_file)
        

## Training
The initial training run (trial 1) on a song from the Bach10 dataset worked very well. That network only took in magnitude spectrograms with 513 features, as opposed to the 1025 features of our current DSD100 spectrograms. Further, building graphs for the full 2049 features of the original Bach10 spectrograms causes memory errors. Perhaps there are simply too many parameters to learn for stable training. Try downsizing the spectrograms to 513 features.

In [7]:
def save_train_details(folder, desc):
    """Write the current training configuration to ``folder + "settings.txt"``.

    The settings are read from notebook-level globals (``input_dir``,
    ``target_dir``, ``num_epochs``, ``learning_rate``, ``mean_loss``,
    ``batch_size``, ``time_context``, ``mem_len``, ``data_format``,
    ``shuffle``, ``scale_factor`` and ``ds.feat_size``), so they must be
    defined before calling.

    Args:
        folder: Output directory; it is concatenated directly with the file
            name, so it must end with a path separator.
        desc: Free-text description of the run.
    """
    lines = [
        "Description: " + desc,
        "Input: " + input_dir,
        "Output: " + target_dir,
        "Epochs: " + str(num_epochs),
        "Learning rate: " + str(learning_rate),
        "Loss type: " + ("mean" if mean_loss else "sum"),
        "Batch size: " + str(batch_size),
        "Time context: " + str(time_context),
        "Dataset memory length: " + str(mem_len),
        "Data format: " + data_format,
        "Shuffle: " + str(shuffle),
        "Scale factor: " + str(scale_factor),
        "Number of features: " + str(ds.feat_size),
    ]
    # The context manager guarantees the file is flushed and closed.
    with open(folder + "settings.txt", "w+") as f:
        f.write("\n".join(lines))

In [10]:
# Training settings
# Trains a fresh Network on the Bach10 training features and checkpoints it
# after every epoch; all settings below are module-level globals that
# save_train_details() also reads.
desc = "y_hat bias shape changed to per source (rather than [1, 1, 1, 1])"
input_dir="../../data/Bach10/features/1025/train/Mixtures/"
target_dir="../../data/Bach10/features/1025/train/Sources/"
results_dir = "../../results/trial_27/train_data/"
make_directory(results_dir)
num_epochs = 30
learning_rate = 0.001
batch_size = 32
time_context = 30 # passed to both Dataset and Network (window length)
mem_len = 5e5 # forwarded to Dataset; presumably time steps held in memory (see old Dataset versions) -- confirm
data_format = "NCHW"
mean_loss = False
#sources = ['bass', 'drums', 'other', 'vocal'] # DSD100
sources = ['violin', 'clarinet', 'saxphone', 'bassoon'] # Bach10
shuffle = True
scale_factor = 0.2 # forwarded to Dataset; semantics defined there

# Create dataset
print("Creating dataset...")
ds = Dataset(input_dir=input_dir, 
              target_dir=target_dir, 
              batch_size=batch_size,
              time_context=time_context,
              mem_len=mem_len,
              load_by_file=False,
              sources=sources,
              data_format=data_format,
              shuffle=shuffle,
              scale_factor=scale_factor,
              verbose=True)

# Create network
print("Creating network...")
net = Network(results_dir,
              data_format=data_format,
              time_context=time_context,
              feat_size=ds.feat_size,
              num_sources=ds.num_sources,
              alpha=learning_rate,
              mean_loss=mean_loss,
              train_mode=True)

# Save settings
save_train_details(results_dir, desc)

print("------------\nTraining\n------------")
start_time = time()
for epoch in range(num_epochs):
    print("Epoch %3d of %3d" % (epoch+1, num_epochs))
    epoch_start_time = time()
    # The dataset is left exhausted by the previous epoch; rewind it.
    if epoch > 0:
        ds.reset_database()
    t = 0 # batch counter within this epoch
    # NOTE(review): assumes is_empty() turns True after the epoch's final batch
    # (the old Dataset versions below return their reset flag) -- confirm for
    # the current Dataset.
    while not ds.is_empty():
        #sys.stdout.write("\rEstimated batches remaining: " + str(ds.remaining_batches()))
        #sys.stdout.flush()
        print("Estimated batches remaining: %3d" % ds.remaining_batches(), end="\r")
        
        # Get batch from input data and target data
        input_batch, target_batch = ds.create_batch()
        
        # Perform training step and save summaries every so often
        # (the returned loss is not used here; losses reach TensorBoard only
        # through save_summaries, every 100 batches)
        loss = net.perform_training_step(input_batch, target_batch)
        if t % 100 == 0:
            net.save_summaries(input_batch, target_batch)
        
        t += 1
    
    # Save model after each epoch (meta graph only written on the first epoch)
    net.save_model(epoch+1, save_meta=(epoch==0))
    
    # TODO: this isn't accurate, doesn't take loading data into account
    print("Epoch complete                        ")
    end_time = time()
    elap_time = end_time - start_time
    # Remaining-time estimate: last epoch's duration times the epochs left.
    rem_time = (end_time - epoch_start_time) * (num_epochs - (epoch + 1))
    print("Elapsed time: %02d:%02d:%02d" % (elap_time // 3600, 
                                            elap_time % 3600 // 60, 
                                            elap_time % 3600 % 60))
    print("Estimated time remaining: %02d:%02d:%02d" % (rem_time // 3600, 
                                                        rem_time % 3600 // 60, 
                                                        rem_time % 3600 % 60))
    print("--------------------------------------")

Creating dataset...
Reading file 5 of 5
Loading files 1 to 5 of 5 into memory...
Shuffling data...
Creating network...
Building input...
Building encoder...
Building decoder...
Building output...
Building losses and summaries...
------------
Training
------------
Epoch   1 of  30
Epoch complete                        
Elapsed time: 00:00:09
Estimated time remaining: 00:04:29
--------------------------------------
Epoch   2 of  30
Resetting database...
Loading files 1 to 5 of 5 into memory...
Shuffling data...
Epoch complete                        
Elapsed time: 00:00:18
Estimated time remaining: 00:04:13
--------------------------------------
Epoch   3 of  30
Resetting database...
Loading files 1 to 5 of 5 into memory...
Shuffling data...
Epoch complete                        
Elapsed time: 00:00:27
Estimated time remaining: 00:03:59
--------------------------------------
Epoch   4 of  30
Resetting database...
Loading files 1 to 5 of 5 into memory...
Shuffling data...
Epoch complete   

## Testing

In [8]:
def save_test_details(folder, desc):
    """Write the current testing configuration to ``folder + "settings.txt"``.

    The settings are read from notebook-level globals (``input_dir``,
    ``target_dir``, ``params_file``, ``mean_loss``, ``batch_size``,
    ``time_context``, ``mem_len``, ``data_format``, ``shuffle``,
    ``scale_factor`` and ``ds.feat_size``), so they must be defined before
    calling.

    Args:
        folder: Output directory; it is concatenated directly with the file
            name, so it must end with a path separator.
        desc: Free-text description of the run.
    """
    lines = [
        "Description: " + desc,
        "Input: " + input_dir,
        "Output: " + target_dir,
        "Params: " + params_file,
        "Loss type: " + ("mean" if mean_loss else "sum"),
        "Batch size: " + str(batch_size),
        "Time context: " + str(time_context),
        "Dataset memory length: " + str(mem_len),
        "Data format: " + data_format,
        "Shuffle: " + str(shuffle),
        "Scale factor: " + str(scale_factor),
        "Number of features: " + str(ds.feat_size),
    ]
    # The context manager guarantees the file is flushed and closed.
    with open(folder + "settings.txt", "w+") as f:
        f.write("\n".join(lines))

In [9]:
# Testing settings
# Runs a trained checkpoint over every file, one file in memory at a time, and
# saves the per-source predicted spectrograms.
desc = "base network"
input_dir = "../../data/Bach10/features/1025/train/Mixtures/"
target_dir = "../../data/Bach10/features/1025/train/Sources/"
results_dir = "../../results/trial_27/test_data/train/"
make_directory(results_dir)
# np.save does not create missing directories, so create the output folder up front.
spec_dir = results_dir + "Spectrograms/"
make_directory(spec_dir)
params_file = "../../results/trial_27/train_data/model-30"
batch_size = 32
time_context = 30
mem_len = 5e5
data_format = "NCHW"
mean_loss = False
#sources = ['bass', 'drums', 'other', 'vocal'] # DSD100
sources = ['violin', 'clarinet', 'saxphone', 'bassoon'] # Bach10
shuffle = False
scale_factor = 0.2

# Create dataset
print("Creating dataset...")
ds = Dataset(input_dir=input_dir,
             target_dir=target_dir,
             batch_size=batch_size,
             time_context=time_context,
             mem_len=mem_len,
             load_by_file=True,
             sources=sources,
             data_format=data_format,
             shuffle=shuffle,
             scale_factor=scale_factor,
             verbose=True)

# Create network
print("Creating network...")
net = Network(results_dir + "Network/",
              params_file=params_file,
              data_format=data_format,
              time_context=time_context,
              feat_size=ds.feat_size,
              num_sources=ds.num_sources,
              mean_loss=mean_loss,
              train_mode=False)

# Save settings
save_test_details(results_dir, desc)


def to_frames(batch):
    """Flatten a batch laid out like the targets into [time, features, sources].

    NCHW batches are [N, sources, time, features] (the layout the old
    Dataset.create_batch versions in this notebook produce) and are moved to
    [N, time, features, sources] before flattening; a plain reshape would mix
    sources and time. NHWC batches already have that layout.
    NOTE(review): assumes net.predict returns the targets' layout -- confirm.
    """
    if data_format == "NCHW":
        batch = np.transpose(batch, axes=[0, 2, 3, 1])
    return np.reshape(batch, [-1, ds.feat_size, ds.num_sources])


print("------------\nTesting\n------------")
t = 0 # index of the file currently being processed
pred = []
loss = 0
while not ds.is_empty():
    print("Estimated files remaining: %3d" % ds.remaining_files(), end="\r")

    # Get batch from input data and target data
    input_batch, target_batch = ds.create_batch()

    # Make prediction; store it as [time, features, sources] and accumulate loss
    pred_, loss_ = net.predict(input_batch, target_batch)
    pred.append(to_frames(pred_))
    loss += loss_

    # Save prediction if reached end of current file
    if ds.load_next or ds.reset:
        print("File %d loss: %.4f; saving results...           " % (t+1, loss / len(pred)))

        # Predict the file's final time_context frames, which the full windows
        # above may not have reached. The window is built like the training
        # batches: [1, time, features, channels], transposed for NCHW.
        # NOTE(review): assumes ds.inputs_ is [time, features] and ds.sources_
        # is [time, features, sources], as in the old Dataset versions.
        last_input = ds.inputs_[-time_context:][np.newaxis, :, :, np.newaxis]
        last_target = ds.sources_[-time_context:][np.newaxis]
        if data_format == "NCHW":
            last_input = np.transpose(last_input, axes=[0, 3, 1, 2])
            last_target = np.transpose(last_target, axes=[0, 3, 1, 2])
        pred_, _ = net.predict(last_input, last_target)
        pred_ = to_frames(pred_)

        # Merge the last window, averaging the frames it shares with the
        # predictions so far. n_missing frames at the end of the file were not
        # covered by any full window (presumably 0 <= n_missing < time_context,
        # since only trailing partial windows are dropped); the first
        # time_context - n_missing frames of the window coincide with the last
        # frames of pred.
        pred = np.concatenate(pred, axis=0)
        n_missing = ds.shapes[t, 0] - pred.shape[0]
        n_shared = time_context - n_missing
        if n_shared > 0:
            pred[-n_shared:] = (pred[-n_shared:] + pred_[:n_shared]) / 2.0
        pred = np.concatenate([pred, pred_[n_shared:]], axis=0)

        # Save source predictions
        base = os.path.split(ds.input_files[t])[-1].split('.')[0]
        for i, s in enumerate(sources):
            np.save(spec_dir + base + "-%s-pred" % s, pred[:, :, i])

        # Reset per-file accumulators and move on to the next file
        loss = 0
        pred = []
        t += 1

print("Done.")

Creating dataset...
Reading file 5 of 5
Loading files 1 to 1 of 5 into memory...
Creating network...
Building input...
Building encoder...
Building decoder...
Building output...
Building losses and summaries...
INFO:tensorflow:Restoring parameters from ../../results/trial_27/train_data/model-30
------------
Testing
------------
File 1 loss: 57.5643; saving results...           
Loading files 2 to 2 of 5 into memory...
File 2 loss: 99.2381; saving results...           
Loading files 3 to 3 of 5 into memory...
File 3 loss: 62.1705; saving results...           
Loading files 4 to 4 of 5 into memory...
File 4 loss: 88.5627; saving results...           
Loading files 5 to 5 of 5 into memory...
File 5 loss: 84.9008; saving results...           
Done.


## Old code
Just stashing away old code in case I need to reference or pull from it.

### Old Datasets

In [117]:
class Dataset(object):
    """Sequential batch loader over mixture / source magnitude spectrograms.

    Every file in ``input_dir`` is a mixture spectrogram saved with ``np.save``
    and shaped ``[time, features]``. ``target_dir`` holds the matching source
    spectrograms; a target file is assigned to a source when the source name
    is a substring of its path. Files are memory-mapped read-only and chunks
    of consecutive time steps (across file boundaries) are copied into memory
    as needed.
    """

    def __init__(self,
                 input_dir,
                 target_dir,
                 batch_size=32,
                 time_context=30,
                 mem_len=1e5,
                 sources=['bass', 'drums', 'other', 'vocal']):
        # create_batch reads these, so they must be stored.
        self.batch_size = batch_size
        self.time_context = time_context

        # Get data files
        self.input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir)])
        target_files = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir)])
        self.source_files = []
        for s in sources:
            self.source_files.append([t for t in target_files if s in t])

        # Memory-map every file read-only. np.load parses the .npy header (a
        # raw np.memmap would read the header bytes as data), and nothing is
        # ever written back to the files on disk.
        self.inputs = []
        for i, f in enumerate(self.input_files):
            print("Reading file %d of %d" % (i+1, len(self.input_files)))
            self.inputs.append(np.load(f, mmap_mode='r'))
        self.sources = [[np.load(f, mmap_mode='r') for f in files]
                        for files in self.source_files]

        # Get shapes of data files
        input_shapes = np.asarray([f.shape for f in self.inputs])
        self.t_total = np.sum(input_shapes[:, 0])
        self.feat_size = input_shapes[0, 1]
        if not (input_shapes[:, 1] == self.feat_size).all():
            raise ValueError("All spectograms must have same number of features.")
        # File i covers global time steps [t_bounds[i], t_bounds[i+1])
        self.t_bounds = np.concatenate([[0], np.cumsum(input_shapes[:, 0])])

        self.mem_len = int(mem_len) # slice bounds must be integers
        self.t = 0 # current global time index
        self.t_ = 0 # current memory time index
        self.reset = False # if True, reset database
        self.load_memory()

    def _read_range(self, arrays, start, stop):
        """Copy global time steps [start, stop) of the per-file ``arrays`` into memory."""
        parts = []
        for i, arr in enumerate(arrays):
            lo, hi = self.t_bounds[i], self.t_bounds[i + 1]
            a, b = max(start, lo), min(stop, hi)
            if a < b:
                parts.append(arr[a - lo:b - lo])
        if not parts:
            return np.empty((0, self.feat_size), dtype=arrays[0].dtype)
        return np.concatenate(parts, axis=0)

    def load_memory(self):
        """Loads part of dataset into memory, starting at global time ``self.t``.

        At least one full batch is loaded, so a batch never has to be read
        past the end of memory (unless the dataset itself ends first).
        """
        chunk = max(self.mem_len, self.batch_size * self.time_context)
        stop = min(self.t + chunk, self.t_total)
        self.inputs_ = self._read_range(self.inputs, self.t, stop)
        self.sources_ = [self._read_range(s, self.t, stop) for s in self.sources]
        self.t_ = 0

    def reset_database(self):
        """Rewind to the start of the dataset and reload the first chunk."""
        self.t = 0
        self.t_ = 0
        self.reset = False
        self.load_memory()

    def create_batch(self):
        """Creates batch of training data from datset.

        Returns:
            ``(inputs_batch, sources_batch)``: ``inputs_batch`` has shape
            ``[n, time_context, feat_size]`` and ``sources_batch`` is a list
            with one array of that shape per source. ``n`` is ``batch_size``
            except for the final batch of the dataset, which holds the
            remaining whole windows (trailing steps shorter than
            ``time_context`` are dropped). After the final batch
            ``self.reset`` is True and the next call starts over.
        """
        # Reset database if exhausted
        if self.reset:
            self.reset_database()

        full_len = self.batch_size * self.time_context
        remaining = self.t_total - self.t
        # Reload when the upcoming batch is not entirely in memory
        if self.t_ + min(full_len, remaining) > self.inputs_.shape[0]:
            self.load_memory()

        if remaining <= full_len: # last batch of the dataset
            batch_len = remaining - remaining % self.time_context
            self.reset = True
        else:
            batch_len = full_len
            self.reset = False

        # Get batches of inputs and sources
        shape = [-1, self.time_context, self.feat_size]
        inputs_batch = np.reshape(self.inputs_[self.t_:self.t_+batch_len], shape)
        sources_batch = [np.reshape(s[self.t_:self.t_+batch_len], shape)
                         for s in self.sources_]
        self.t_ += batch_len
        self.t += batch_len

        return inputs_batch, sources_batch
    

In [5]:
class Dataset(object):
    """Sequential batch loader over per-file magnitude spectrograms.

    Mixture files (``input_dir``) and source files (``target_dir``) are the
    ``.npy`` files whose names contain "mag"; a source file belongs to the
    source whose name is a substring of its path. All files are memory-mapped
    read-only. ``load_memory`` copies whole files into memory and
    ``create_batch`` serves them in order as windows of ``time_context``
    steps.
    """
    
    def __init__(self,
                 input_dir, 
                 target_dir, 
                 batch_size=32,
                 time_context=30,
                 mem_len=1e5,
                 sources=['bass', 'drums', 'other', 'vocal'],
                 data_format="NHWC",
                 shuffle_len=30,
                 verbose=True):
        # Grab arguments
        self.batch_size = batch_size
        self.time_context = time_context
        self.num_sources = len(sources)
        self.data_format = data_format
        self.shuffle_len = shuffle_len # stored but not used by this class
        self.verbose = verbose
        
        # Get data files
        self.input_files = sorted([os.path.join(input_dir, f) 
                                   for f in os.listdir(input_dir) if "mag" in f])
        target_files = sorted([os.path.join(target_dir, f) 
                               for f in os.listdir(target_dir) if "mag" in f])
        self.source_files = [] # [source_id][t]
        for s in sources:
            self.source_files.append([t for t in target_files if s in t])
        
        # Read mem maps of data files
        self.inputs = []
        self.sources = [] # [t][source_id]
        self.shapes = []
        shapes_ = [] # placeholder to ensure all associated shapes equal
        for i in range(len(self.input_files)):
            if verbose:
                print("Reading file %d of %d" % (i+1, len(self.input_files)))
            
            # Add mem map of mixture file
            f_in = np.load(self.input_files[i], mmap_mode='r')
            self.inputs.append(f_in)
            
            # Add mem maps of source files
            f_s = [np.load(self.source_files[j][i], mmap_mode='r') 
                   for j in range(len(sources))]
            self.sources.append(f_s)
            
            # Get shapes
            self.shapes.append(f_in.shape)
            shapes_.append([f.shape for f in f_s])
        
        # Set class variables
        self.shapes = np.asarray(self.shapes)
        self.t_total = np.sum(self.shapes[:, 0])
        self.feat_size = self.shapes[0, 1]
        
        # Check that shapes [index, ch, shape] are equal for each time point
        shapes_ = np.concatenate([self.shapes[:, np.newaxis, :], np.asarray(shapes_)], axis=1)
        if not (shapes_[:, :, 0].T == shapes_[:, 0, 0]).all():
            raise ValueError("All spectrograms must be of same length.")
        if not (shapes_[:, :, 1].T == shapes_[:, 0, 1]).all():
            raise ValueError("All spectograms must have same number of features.")
        shapes_ = None # release from memory
            
        # Set variables to track time position
        self.t_c = np.cumsum(self.shapes[:, 0]) # end time of each file
        self.mem_len = mem_len
        self.t = 0 # current global time index
        self.t_ = 0 # current memory time index
        self.reset = False # if True, reset database
        self.load_next = False # if True, load next batch of files
        self.load_memory()
    
    def load_memory(self):
        """Copy the files starting at global time ``self.t`` into memory.

        Loads every file that ends within ``mem_len`` steps of ``self.t``, and
        always at least one file. ``inputs_`` becomes [time, features] and
        ``sources_`` becomes [time, features, num_sources].
        """
        # Get indices of files corresponding to next mem_len time points
        t_start = self.t
        t_end = self.t + self.mem_len
        idx_start = np.searchsorted(self.t_c, t_start, side='right')
        idx_end = np.searchsorted(self.t_c, t_end, side='right')
        # A file longer than mem_len would otherwise give an empty selection
        idx_end = max(idx_end, idx_start + 1)
        if self.verbose:
            print("Loading files %d to %d of %d into memory..." % (idx_start+1, idx_end, len(self.inputs)))
        
        # Load from mem maps of files into shape [time, features, [num_sources]]
        f_in = [self.inputs[t] for t in range(idx_start, idx_end)] # get list of next files
        self.inputs_ = None # clear previous memory (avoids temporary double storage)
        self.inputs_ = np.concatenate(f_in, axis=0) # loads into memory
        f_s = [np.asarray(self.sources[t]) for t in range(idx_start, idx_end)]
        self.sources_ = None
        self.sources_ = np.transpose(np.concatenate(f_s, axis=1), axes=[1, 2, 0]) # loads into memory
                                          
        # Reset counter
        self.t_ = 0
    
    def reset_database(self):
        """Rewind to the first file and reload memory."""
        if self.verbose:
            print("Resetting database...")
        self.t = 0
        self.t_ = 0
        self.reset = False
        self.load_memory()
    
    def create_batch(self):
        """Creates batch of training data from datset.

        Returns ``(inputs_batch, sources_batch)`` shaped
        [n, time_context, feat_size, 1] and
        [n, time_context, feat_size, num_sources] (NHWC), or with the channel
        axis moved to position 1 for NCHW. The batch that exhausts the loaded
        files sets ``load_next`` (or ``reset`` when it is the end of all
        data); a trailing partial window is skipped.
        """
        # Load memory if exhausted
        if self.reset:
            self.reset_database()
        elif self.load_next:
            self.load_memory()
        
        # Determine batch length and set loading bools for next batch
        batch_len = self.batch_size * self.time_context
        rem = 0
        self.reset = False
        self.load_next = False
        mem_left = self.inputs_.shape[0] - self.t_ # loaded steps not yet served
        # `<=` so an exact fit is flagged now instead of yielding an empty batch next call
        if mem_left <= batch_len: # this batch exhausts the loaded data
            batch_len = mem_left // self.time_context * self.time_context
            rem = mem_left % self.time_context
            if mem_left >= self.t_total - self.t: # loaded data runs to the end of all data
                self.reset = True
            else:
                self.load_next = True
        
        # Get batches of inputs and sources
        inputs_batch = np.reshape(self.inputs_[self.t_:self.t_+batch_len],
                                  [-1, self.time_context, self.feat_size, 1])
        sources_batch = np.reshape(self.sources_[self.t_:self.t_+batch_len],
                                    [-1, self.time_context, self.feat_size, self.num_sources])
        if self.data_format == "NCHW":
            inputs_batch = np.transpose(inputs_batch, axes=[0, 3, 1, 2])
            sources_batch = np.transpose(sources_batch, axes=[0, 3, 1, 2])
        
        # Increment counters
        self.t_ += batch_len + rem 
        self.t  += batch_len + rem
        
        return inputs_batch, sources_batch
    
    def is_empty(self):
        """True once the final batch of the dataset has been served."""
        return self.reset
    
    def remaining_batches(self):
        """Estimated number of batches left in the loaded memory."""
        return math.ceil((self.inputs_.shape[0] - self.t_) 
                         / (self.batch_size * self.time_context))

In [5]:
class Dataset(object):
    """Batch loader over per-file magnitude spectrograms with shuffled batches.

    Mixture files (``input_dir``) and source files (``target_dir``) are the
    ``.npy`` files whose names contain "mag"; a source file belongs to the
    source whose name is a substring of its path. All files are memory-mapped
    read-only. ``load_memory`` copies whole files into memory, and
    ``create_batch`` serves batch-aligned chunks of that memory in a random
    order.
    """
    
    def __init__(self,
                 input_dir, 
                 target_dir, 
                 batch_size=32,
                 time_context=30,
                 mem_len=1e5,
                 sources=['bass', 'drums', 'other', 'vocal'],
                 data_format="NHWC",
                 shuffle_len=30,
                 verbose=True):
        # Grab arguments
        self.batch_size = batch_size
        self.time_context = time_context
        self.num_sources = len(sources)
        self.data_format = data_format
        self.shuffle_len = shuffle_len # stored but not used by this class
        self.verbose = verbose
        
        # Get data files
        self.input_files = sorted([os.path.join(input_dir, f) 
                                   for f in os.listdir(input_dir) if "mag" in f])
        target_files = sorted([os.path.join(target_dir, f) 
                               for f in os.listdir(target_dir) if "mag" in f])
        self.source_files = [] # [source_id][t]
        for s in sources:
            self.source_files.append([t for t in target_files if s in t])
        
        # Read mem maps of data files
        self.inputs = []
        self.sources = [] # [t][source_id]
        self.shapes = []
        shapes_ = [] # placeholder to ensure all associated shapes equal
        for i in range(len(self.input_files)):
            if verbose:
                print("Reading file %d of %d" % (i+1, len(self.input_files)))
            
            # Add mem map of mixture file
            f_in = np.load(self.input_files[i], mmap_mode='r')
            self.inputs.append(f_in)
            
            # Add mem maps of source files
            f_s = [np.load(self.source_files[j][i], mmap_mode='r') 
                   for j in range(len(sources))]
            self.sources.append(f_s)
            
            # Get shapes
            self.shapes.append(f_in.shape)
            shapes_.append([f.shape for f in f_s])
        
        # Set class variables
        self.shapes = np.asarray(self.shapes)
        self.t_total = np.sum(self.shapes[:, 0])
        self.feat_size = self.shapes[0, 1]
        
        # Check that shapes [index, ch, shape] are equal for each time point
        shapes_ = np.concatenate([self.shapes[:, np.newaxis, :], np.asarray(shapes_)], axis=1)
        if not (shapes_[:, :, 0].T == shapes_[:, 0, 0]).all():
            raise ValueError("All spectrograms must be of same length.")
        if not (shapes_[:, :, 1].T == shapes_[:, 0, 1]).all():
            raise ValueError("All spectograms must have same number of features.")
        shapes_ = None # release from memory
            
        # Set variables to track time position
        self.t_c = np.cumsum(self.shapes[:, 0]) # end time of each file
        self.mem_len = mem_len
        self.t = 0 # current global time index
        self.t_ = 0 # current memory time index
        self.reset = False # if True, reset database
        self.load_next = False # if True, load next batch of files
        self.load_memory()
    
    def load_memory(self):
        """Copy the files starting at global time ``self.t`` into memory.

        Loads every file that ends within ``mem_len`` steps of ``self.t``, and
        always at least one file, then shuffles the batch start offsets.
        """
        # Get indices of files corresponding to next mem_len time points
        t_start = self.t
        t_end = self.t + self.mem_len
        idx_start = np.searchsorted(self.t_c, t_start, side='right')
        idx_end = np.searchsorted(self.t_c, t_end, side='right')
        # A file longer than mem_len would otherwise give an empty selection
        idx_end = max(idx_end, idx_start + 1)
        if self.verbose:
            print("Loading files %d to %d of %d into memory..." % (idx_start+1, idx_end, len(self.inputs)))
        
        # Load from mem maps of files into shape [time, features, [num_sources]]
        f_in = [self.inputs[t] for t in range(idx_start, idx_end)] # get list of next files
        self.inputs_ = None # clear previous memory (avoids temporary double storage)
        self.inputs_ = np.concatenate(f_in, axis=0) # loads into memory
        f_s = [np.asarray(self.sources[t]) for t in range(idx_start, idx_end)]
        self.sources_ = None
        self.sources_ = np.transpose(np.concatenate(f_s, axis=1), axes=[1, 2, 0]) # loads into memory
                                          
        # Set vector of random start points from loaded memory. Only offsets
        # strictly inside memory: an offset equal to the loaded length would
        # produce an empty batch.
        batch_len = self.batch_size * self.time_context
        self.t_rand = np.arange(0, self.inputs_.shape[0], batch_len)
        np.random.shuffle(self.t_rand)
        self.t_idx = 0
        self.t_ = 0
    
    def reset_database(self):
        """Rewind to the first file and reload memory."""
        if self.verbose:
            print("Resetting database...")
        self.t = 0
        self.t_ = 0
        self.reset = False
        self.load_next = False
        self.load_memory()
    
    def create_batch(self):
        """Creates batch of training data from datset.

        Serves the chunk at the next shuffled offset in memory. The chunk
        ending at the end of memory is truncated to whole windows. ``self.t``
        counts the steps served, which is what triggers ``reset``.
        """
        # Load memory if exhausted
        if self.reset:
            self.reset_database()
        elif self.load_next:
            self.load_memory()

        # Determine batch length and set loading bools for next batch
        batch_len = self.batch_size * self.time_context
        rem = 0
        self.reset = False
        self.load_next = False
        t_i = self.t_rand[self.t_idx]
        if self.t + batch_len >= self.t_total: # reach end of all data
            self.reset = True
        if t_i + batch_len >= self.inputs_.shape[0]: # reach end of loaded data
            batch_len = (self.inputs_.shape[0] - t_i) // self.time_context * self.time_context
            rem = (self.inputs_.shape[0] - t_i) % self.time_context
        if self.t_idx >= len(self.t_rand) - 1: # exhausted all loaded data
            self.load_next = True
        
        # Get batches of inputs and sources
        inputs_batch = np.reshape(self.inputs_[t_i:t_i+batch_len],
                                  [-1, self.time_context, self.feat_size, 1])
        sources_batch = np.reshape(self.sources_[t_i:t_i+batch_len],
                                    [-1, self.time_context, self.feat_size, self.num_sources])
        if self.data_format == "NCHW":
            inputs_batch = np.transpose(inputs_batch, axes=[0, 3, 1, 2])
            sources_batch = np.transpose(sources_batch, axes=[0, 3, 1, 2])
        
        # Increment counters
        self.t  += batch_len + rem
        self.t_ += batch_len + rem
        self.t_idx += 1
        
        return inputs_batch, sources_batch
    
    def is_empty(self):
        """True once the final batch of the dataset has been served."""
        return self.reset
    
    def remaining_batches(self):
        """Estimated number of batches left in the loaded memory."""
        return math.ceil((self.inputs_.shape[0] - self.t_) 
                         / (self.batch_size * self.time_context))

### Old batch handling

In [None]:
def create_batches(data, batch_size, time_context=None, feat_size=None):
    """Reshapes data into batches of input size for network.

    Cuts ``data`` into non-overlapping tiles of ``time_context`` steps by
    ``feat_size`` features.

    Args:
        data: Array indexed as [channel, time, feature] (axis 0 is kept whole).
        batch_size: Unused; kept so existing calls keep working.
        time_context: Tile length along axis 1. Defaults to the module-level
            ``time_context`` global (the previous implicit behaviour).
        feat_size: Tile width along axis 2. Defaults to the module-level
            ``feat_size`` global.

    Returns:
        Array of shape [num_tiles, channels, time_context, feat_size], ordered
        time-major (all feature tiles of the first time span come first).
        Trailing steps/features that do not fill a whole tile are dropped.
    """
    if time_context is None:
        time_context = globals()["time_context"]
    if feat_size is None:
        feat_size = globals()["feat_size"]
    time_batches = data.shape[1] // time_context
    freq_batches = data.shape[2] // feat_size
    batches = [data[:, t*time_context:(t+1)*time_context, f*feat_size:(f+1)*feat_size]
               for t in range(time_batches)
               for f in range(freq_batches)]
    if not batches:
        # Keep a well-formed 4-D shape instead of np.asarray([]) -> shape (0,)
        return np.empty((0, data.shape[0], time_context, feat_size), dtype=data.dtype)
    return np.asarray(batches)

In [55]:
# Training settings
# Legacy single-file training loop (pre-Dataset/Network classes). It relies on
# names defined in other cells: results_dir, get_shape, create_batches'
# time_context/feat_size globals, spectrogram, targets, loss_sum, train_step,
# writer and saver.
params_dir = results_dir + "params/"
make_directory(params_dir)
input_file = "./features/02-AchLiebenChristen__m_.data"
shape_file = "./features/02-AchLiebenChristen__m_.shape"
num_epochs = 100
batch_size = 32


# np.fromfile reads a flat binary array (default dtype float64); the .shape
# file is used to restore its dimensions.
f_in = np.fromfile(input_file)
if shape_file is not None:
    f_shape = get_shape(shape_file)
    f_in = np.reshape(f_in, f_shape)
# Channel 0 is the mixture, the remaining channels the sources.
# NOTE(review): create_batches ignores its batch_size argument.
input_data = create_batches(f_in[0:1], batch_size) # mixed input
target_data = create_batches(f_in[1:], batch_size) # separate sources
# Integer division: a final partial batch is never trained on.
iter_size = len(input_data) // batch_size

# Initialize graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())

global_step = 0
for epoch in range(num_epochs):
    for i in range(iter_size):
        # Get batch from input data and target data
        input_batch = input_data[i*batch_size:(i+1)*batch_size] # magnitude spectrogram of whole
        target_batch = target_data[i*batch_size:(i+1)*batch_size] # magnitude spectrogram of sources
        
        # Perform training step (loss summary logged on every step)
        feed_dict = {spectrogram: input_batch, targets: target_batch}
        loss_sum_, _ = sess.run([loss_sum, train_step], 
                                feed_dict=feed_dict)
        writer.add_summary(loss_sum_, global_step=global_step)
        writer.flush()
        global_step += 1
    
    # Save model after each epoch (checkpoint suffix starts at 0)
    saver.save(sess, params_dir + "model", 
               global_step=epoch)

### Dataset testing (ignore)

In [248]:
# Smoke test: build a dataset over the mini DSD100 dev split.
ds = Dataset(sources=['bass', 'drums', 'other', 'vocal'],
             time_context=30,
             batch_size=32,
             target_dir='../../data/DSD100/features_mini/Sources/Dev/',
             input_dir='../../data/DSD100/features_mini/Mixtures/Dev/')

Reading file 1 of 5
Reading file 2 of 5
Reading file 3 of 5
Reading file 4 of 5
Reading file 5 of 5
Indices t 0 to 100000
Loading files 1 to 2 of 5 into memory...


In [249]:
# Step through 200 batches, logging the global (t) and in-memory (t_) time
# positions before each one.
for i in range(200):
    position = (i + 1, ds.t, ds.t_)
    print("Batch %d, t %d, t_ %d" % position)
    _, _ = ds.create_batch()

Batch 1, t 0, t_ 0
Batch 2, t 960, t_ 960
Batch 3, t 1920, t_ 1920
Batch 4, t 2880, t_ 2880
Batch 5, t 3840, t_ 3840
Batch 6, t 4800, t_ 4800
Batch 7, t 5760, t_ 5760
Batch 8, t 6720, t_ 6720
Batch 9, t 7680, t_ 7680
Batch 10, t 8640, t_ 8640
Batch 11, t 9600, t_ 9600
Batch 12, t 10560, t_ 10560
Batch 13, t 11520, t_ 11520
Batch 14, t 12480, t_ 12480
Batch 15, t 13440, t_ 13440
Batch 16, t 14400, t_ 14400
Batch 17, t 15360, t_ 15360
Batch 18, t 16320, t_ 16320
Batch 19, t 17280, t_ 17280
Batch 20, t 18240, t_ 18240
Batch 21, t 19200, t_ 19200
Batch 22, t 20160, t_ 20160
Batch 23, t 21120, t_ 21120
Batch 24, t 22080, t_ 22080
Batch 25, t 23040, t_ 23040
Batch 26, t 24000, t_ 24000
Batch 27, t 24960, t_ 24960
Batch 28, t 25920, t_ 25920
Batch 29, t 26880, t_ 26880
Batch 30, t 27840, t_ 27840
Batch 31, t 28800, t_ 28800
Batch 32, t 29760, t_ 29760
Batch 33, t 30720, t_ 30720
Batch 34, t 31680, t_ 31680
Batch 35, t 32640, t_ 32640
Batch 36, t 33600, t_ 33600
Batch 37, t 34560, t_ 34560
Bat

### Graph building without class

In [254]:
# Build the source-separation autoencoder graph directly at notebook top level
# (no wrapper class). Running this cell resets the default graph and defines,
# as notebook globals, the placeholders, layers, masks, loss, train step,
# saver, summary writer and merged summaries used by the training cell.
tf.reset_default_graph()

# Settings
#data_format = "NHWC" # if using cpu
data_format = "NCHW" # if using gpu
# Directory for the TensorBoard FileWriter created at the end of this cell.
results_dir = "./results/trial_1/"
make_directory(results_dir)
# Frames per example (the time axis of input_shape below).
time_context = 30
# Size of the per-frame feature axis (presumably STFT frequency bins -- confirm
# against the feature extraction).
feat_size = 513
# One decoder branch / output channel per source.
num_sources = 4
# Scales the random jitter added to the mask denominator (see "masks").
eps = 1e-18 # numerical stability
alpha = 0.001 # learning rate

# Data formatting: the mixture has a single channel, the targets have one
# channel per source; channel_dim is the axis that indexes sources later on.
if data_format == "NHWC":
    input_shape = [None, time_context, feat_size, 1]
    target_shape = [None, time_context, feat_size, num_sources]
    channel_dim = 3
elif data_format == "NCHW":
    input_shape = [None, 1, time_context, feat_size]
    target_shape = [None, num_sources, time_context, feat_size]
    channel_dim = 1
else:
    raise ValueError("Unknown data format \"" + data_format + "\"")
# Mixture input; batch size is left dynamic (None).
spectrogram = tf.placeholder(tf.float32, 
                             shape=input_shape, 
                             name="magnitude_spectrogram")

# Convolutional layer 1: 1 x 30 kernels with stride 4 along the feature axis.
# NOTE(review): assuming the project's conv2d uses standard VALID arithmetic,
# the output here is 30 channels x 30 frames x 121 ((513 - 30) // 4 + 1).
conv1 = conv2d(spectrogram,
               num_outputs=30,
               kernel_size=[1, 30],
               stride=[1, 4],
               padding="VALID",
               data_format=data_format,
               weights_initializer="xavier",
               biases_initializer=["constant", 0.0],
               scope="CONV_1")

# Convolutional layer 2: kernels spanning 2/3 of the time context (20 frames
# for time_context = 30) along the time axis, width 1.
# NOTE(review): under the same VALID assumption this yields 30 x 11 x 121.
conv2 = conv2d(conv1,
               num_outputs=30,
               kernel_size=[int(2*time_context/3), 1],
               stride=[1, 1],
               padding="VALID",
               data_format=data_format,
               weights_initializer="xavier",
               biases_initializer=["constant", 0.0],
               scope="CONV_2")
conv2_flat = flatten(conv2,
                     data_format=data_format,
                     scope="CONV_2_FLAT")

# Fully-connected layer 1 (encoding): 256-unit ReLU bottleneck shared by all
# source decoders.
fc1 = fully_connected(conv2_flat,
                      num_outputs=256,
                      activation_fn="relu",
                      weights_initializer="xavier",
                      biases_initializer=["constant", 0.0],
                      scope="FC_1")

# Get shapes for building decoding layers. batch_size is a dynamic scalar
# tensor; the other dimensions are static Python ints from the encoder.
batch_size = tf.shape(spectrogram)[0]
conv1_shape = conv1.get_shape().as_list()
conv2_shape = conv2.get_shape().as_list()
# Number of units per example in conv2 (product of its non-batch dims).
conv2_size = conv2_shape[1] * conv2_shape[2] * conv2_shape[3]

# Build decoder for each source: FC -> reshape -> transpose of CONV_2 ->
# transpose of CONV_1, each with its own scope suffix _1.._num_sources.
fc2, convt1, convt2 = [], [], []
for i in range(num_sources):
    # Fully-connected layer 2 (decoding): back up to conv2's flattened size.
    fc2_i = fully_connected(fc1,
                            num_outputs=conv2_size,
                            activation_fn="relu",
                            weights_initializer="xavier",
                            biases_initializer=["constant", 0.0],
                            scope="FC_2_%d" % (i+1))
    fc2.append(fc2_i)
    
    # Convolutional transpose layer 1
    # Side note: tf.reshape() can infer size of one dimension given rest, so -1 okay
    #            tf.nn.conv2d_transpose() must know exact dimensions, but batch size can
    #                be inferred at runtime using tf.shape()
    fc2_i = tf.reshape(fc2_i, [-1] + conv2_shape[1:])
    # Mirrors CONV_2 (same kernel and stride) and restores conv1's shape.
    convt1_i = conv2d_transpose(fc2_i,
                                output_shape=[batch_size] + conv1_shape[1:],
                                kernel_size=[int(2*time_context/3), 1],
                                stride=[1, 1],
                                padding="VALID",
                                data_format=data_format,
                                weights_initializer="xavier",
                                biases_initializer=["constant", 0.0],
                                scope="CONVT_1_%d" % (i+1))
    convt1.append(convt1_i)
    
    # Convolutional transpose layer 2: mirrors CONV_1 and restores the
    # single-channel input shape, so each source gets one output channel.
    convt2_i = conv2d_transpose(convt1_i,
                                output_shape=[batch_size] + input_shape[1:],
                                kernel_size=[1, 30],
                                stride=[1, 4],
                                padding="VALID",
                                data_format=data_format,
                                weights_initializer="xavier",
                                biases_initializer=["constant", 0.0],
                                scope="CONVT_2_%d" % (i+1))
    convt2.append(convt2_i)

# Output layer: stack the per-source outputs along the channel axis, add one
# scalar bias shared by all sources ([1, 1, 1, 1] broadcasts everywhere), then
# clamp at zero (ReLU via tf.maximum).
with tf.name_scope("y_hat"):
    convt2_all = tf.concat(convt2, axis=channel_dim)
    b = tf.Variable(tf.constant(0.0, shape=[1, 1, 1, 1]),
                    dtype=tf.float32,
                    name="bias")
    y_hat = tf.maximum(tf.add(convt2_all, b), 0, name="y_hat")

# Masks: m_n(f) = |y_hat_n(f)| / Σ(|y_hat_n'(f)|)
# (y_hat is already non-negative, so no explicit abs is taken.)
with tf.name_scope("masks"):
    # eps-scaled uniform noise keeps the denominator non-zero when every source
    # estimate is 0; because it is random, the masks are not bit-for-bit
    # deterministic across runs.
    rand = tf.random_uniform([batch_size] + input_shape[1:])
    # NOTE(review): `keep_dims` is the older spelling; later TF 1.x releases
    # deprecate it in favor of `keepdims`.
    den = tf.reduce_sum(y_hat, axis=channel_dim, keep_dims=True) + (eps * rand)
    masks = tf.div(y_hat, den, name="masks") # broadcast along channel dimension
    
# Source signals: y_tilde_n(f) = m_n(f) * x(f), 
# where x(f) is the spectrogram of the input mixture signal
with tf.name_scope("y_tilde"):
    y_tilde = tf.multiply(masks, spectrogram, name="y_tilde") # broadcast along channel dimension

# Loss function: L = Σ(||y_tilde_n - target_n||^2)
with tf.name_scope("loss"):
    targets = tf.placeholder(tf.float32, 
                             shape=target_shape, 
                             name="target_sources")
    # Sum squared error over every axis except the source axis, which gives one
    # loss value per source (shape [num_sources]); loss_total is their sum.
    reduc_indices = [i for i in range(4) if i != channel_dim]
    loss_n = tf.reduce_sum(tf.square(y_tilde - targets), axis=reduc_indices, name="loss_n")
    loss_total = tf.reduce_sum(loss_n, name="loss_total")

# Optimizer: Adam with learning rate alpha on the summed loss.
with tf.name_scope("train_step"):
    optimizer = tf.train.AdamOptimizer(alpha)
    train_step = optimizer.minimize(loss_total)

# Summaries. The Saver is created after minimize(), so it also covers the
# variables minimize() adds; it keeps the 5 most recent checkpoints.
saver = tf.train.Saver(max_to_keep=5)        
graph = tf.get_default_graph()
writer = tf.summary.FileWriter(results_dir, graph)
# One scalar summary per source loss plus the total, merged into a single op.
loss_sum = []
with tf.name_scope("summaries"):
    for i in range(num_sources):
        loss_sum.append(tf.summary.scalar("loss_%d" % (i+1), loss_n[i]))
    loss_sum.append(tf.summary.scalar("loss_total", loss_total))
    loss_sum = tf.summary.merge(loss_sum)

In [None]:
def conv2d_transpose(x,
                     output_shape,
                     kernel_size,
                     stride=1,
                     padding="VALID",
                     data_format="NCHW",
                     normalizer_fn=None,
                     activation_fn=None,
                     weights_initializer="random_normal",
                     biases_initializer=None,
                     trainable=True,
                     scope="CONV_T"):
    """Transposed 2-D convolution layer built on ``tf.nn.conv2d_transpose``.

    Creates a fresh weight variable and applies the transposed convolution.
    It then applies either a normalizer (if ``normalizer_fn`` is given) or a
    per-output-channel bias (if ``biases_initializer`` is given), and finally
    an optional activation.

    Args:
        x: 4-D input tensor laid out according to ``data_format``; its channel
            dimension must be statically known.
        output_shape: Full output shape ``[batch, ...]`` in ``data_format``
            layout, forwarded to ``tf.nn.conv2d_transpose``. The batch entry
            may be a dynamic tensor (e.g. from ``tf.shape``). The channel entry
            sizes the weights and biases and should be a Python int.
        kernel_size: ``[kernel_h, kernel_w]`` as a list or tuple, or a single
            int used for both dimensions.
        stride: Stride spec unpacked via ``_check_list``: a scalar or
            ``[stride_h, stride_w]``. A missing width stride falls back to the
            height stride.
        padding: Padding mode forwarded to TensorFlow (``"VALID"``/``"SAME"``).
        data_format: ``"NHWC"`` or ``"NCHW"``.
        normalizer_fn: Optional spec for ``_apply_normalization``. When given,
            no bias variable is created.
        activation_fn: Optional spec for ``_apply_activation``.
        weights_initializer: Spec for ``_get_variable_initializer``. For
            ``"xavier"``, any supplied parameters are replaced by the number
            of output units per example (product of ``output_shape[1:]``).
        biases_initializer: Optional bias initializer spec; ignored when
            ``normalizer_fn`` is set.
        trainable: Whether the created variables are trainable.
        scope: Name scope for the created ops and variables.

    Returns:
        The output tensor of the layer.

    Raises:
        ValueError: If ``data_format`` is not ``"NHWC"`` or ``"NCHW"``.
    """
    # Fail fast with a clear message; otherwise the layout-dependent names
    # below (channel counts, strides, bias shape) would be left unbound.
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError("Unknown data format \"%s\"" % (data_format,))
    channel_axis = 3 if data_format == "NHWC" else 1

    # Normalize to a fresh [kernel_h, kernel_w] list so ints and tuples work
    # too (the weight shape below is built by list concatenation).
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    else:
        kernel_size = list(kernel_size)

    with tf.name_scope(scope):
        x_shape = x.get_shape().as_list()
        input_channels = x_shape[channel_axis]
        num_outputs = output_shape[channel_axis]

        # Create weights. tf.nn.conv2d_transpose expects filters shaped
        # [height, width, output_channels, input_channels].
        W_init_type, W_init_params = _check_list(weights_initializer)
        with tf.name_scope(W_init_type + "_initializer"):
            W_shape = kernel_size + [num_outputs, input_channels]
            if W_init_type == "xavier":
                # Scale by output units per example; overrides user params.
                n_out = tf.reduce_prod(output_shape[1:])
                W_init_params = [n_out]
            W_init = _get_variable_initializer(W_init_type,
                                               W_shape,
                                               *W_init_params)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        # Resolve strides; an empty width component reuses the height stride.
        stride_h, stride_w = _check_list(stride)
        if isinstance(stride_w, list):
            stride_w = stride_w[0] if stride_w else stride_h
        if data_format == "NHWC":
            strides = [1, stride_h, stride_w, 1]
        else:
            strides = [1, 1, stride_h, stride_w]
        out = tf.nn.conv2d_transpose(x,
                                     filter=W,
                                     output_shape=output_shape,
                                     strides=strides,
                                     padding=padding,
                                     data_format=data_format,
                                     name="convolution_transpose")

        if normalizer_fn is not None:
            # Apply normalization in place of a bias (no bias is created).
            norm_type, norm_params = _check_list(normalizer_fn)
            out = _apply_normalization(norm_type,
                                       out,
                                       *norm_params,
                                       data_format=data_format)
        elif biases_initializer is not None:
            # One bias per output channel, shaped to broadcast in this layout.
            b_init_type, b_init_params = _check_list(biases_initializer)
            b_shape = [1, 1, 1, 1]
            b_shape[channel_axis] = num_outputs
            b_init = _get_variable_initializer(b_init_type,
                                               b_shape,
                                               *b_init_params)
            b = tf.Variable(b_init,
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        # Apply activation
        if activation_fn is not None:
            act_type, act_params = _check_list(activation_fn)
            out = _apply_activation(act_type, out, *act_params)

        return out