# Single Channel Source Separation
Some ideas:
- Deeper networks
- Wider networks
- Different conv filter sizes
- Noisy inputs
- Dropout

Some notes about data preprocessing:
In the original code, data is preprocessed from raw .wav files into numpy arrays containing the magnitude and phase spectrograms from the short-time Fourier transform (STFT). The class `LargeDataset` in `dataset.py` handles the transformed features and the subsequent batch handling during network training. How does it work?

1. The features (i.e. spectrograms) from the audio files are saved in a directory. If using the `compute_transform` function found in class `Transform` in `transform.py`, then the features for each audio file should be in this directory as:
    - {filename}__{m,p}_.data : numpy array containing the magnitude (m) or phase (p) spectrogram
    - {filename}__{m,p}_.shape : binary file containing shape of array
2. The LargeDataset class is pointed to the feature directory via the `path_transform_in` argument in the constructor.
3. It calls `updatePath()` to update its list of .data files (`self.file_list`), which are all the .data files in the feature directory.
4. `updatePath()` also updates the cumulative number of points per file (`self.num_points`) and the total number of points (`self.total_points`), where a point is a time window of size `time_context`. This is done via the `getNum()` function, which essentially returns `np.floor(time_axis / time_context)`, plus a term if using overlap:

```python
def getNum(self,id):
    """
    For a single .data file computes the number of examples of size "time_context" that can be created
    """
    shape = self.get_shape(os.path.join(self.path_transform_in[self.dirid[id]],self.file_list[id].replace('.data','.shape')))
    time_axis = shape[1]
    return np.maximum(1,int(np.floor((time_axis + (np.floor(float(time_axis)/self.time_context) * self.overlap))  / self.time_context)))
```
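A standalone version of the `getNum` formula makes the arithmetic easy to check. The name `num_windows` is hypothetical. The overlap term adds `floor(T / time_context) * overlap` frames before dividing, and a file always yields at least one window. The 3494-frame example matches the time axis of the feature file used in the training section.

```python
import numpy as np

def num_windows(time_axis, time_context, overlap=0):
    """Number of time_context-sized windows for a file with time_axis frames,
    using the same formula as getNum above (hypothetical standalone helper)."""
    extra = np.floor(float(time_axis) / time_context) * overlap  # frames gained by overlapping
    return max(1, int(np.floor((time_axis + extra) / time_context)))

print(num_windows(3494, 30))      # floor(3494 / 30) = 116
print(num_windows(3494, 30, 10))  # floor((3494 + 116 * 10) / 30) = floor(4654 / 30) = 155
print(num_windows(20, 30))        # shorter than one window, but still counts as 1
```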
5. `updatePath()` also sets the input and output feature sizes via the `getFeatureSize()` function, which returns the number of features (`self.input_size`) and the number of features times the number of sources (`self.output_size`) for the .data files.
6. Finally, `updatePath()` calls `initBatches()`, which allocates the memory for the batch buffers. Several class variables are set:

```python
self.batch_size = np.minimum(self.batch_size,self.num_points[-1]) # size of each batch
self.iteration_size = int(self.total_points / self.batch_size)    # number of batches in dataset
self.batch_memory = np.minimum(self.batch_memory,self.iteration_size) # number of batches held in memory, capped at iteration_size
#...
self.batch_inputs = np.zeros((self.batch_memory*self.batch_size,self.time_context,self.input_size), dtype=self.tensortype)
self.batch_outputs = np.zeros((self.batch_memory*self.batch_size,self.time_context,self.output_size), dtype=self.tensortype)
```
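To make these sizes concrete, here is the same arithmetic on hypothetical numbers: three files yielding 116, 134 and 150 windows, a requested `batch_size` of 32, a requested `batch_memory` of 1000, `time_context = 30`, 512 features and 4 sources. The values are illustrative, not taken from the original configuration.

```python
import numpy as np

# Hypothetical values: three files with 116, 134 and 150 windows each
num_points = np.cumsum([0, 116, 134, 150])  # cumulative counts: [0, 116, 250, 400]
total_points = int(num_points[-1])
time_context, input_size, output_size = 30, 512, 4 * 512  # output: 4 sources stacked on the feature axis

batch_size = int(np.minimum(32, num_points[-1]))      # 32: the dataset is larger than one batch
iteration_size = int(total_points / batch_size)       # 400 / 32 -> 12 full batches per epoch
batch_memory = int(np.minimum(1000, iteration_size))  # cannot hold more batches than exist -> 12

batch_inputs = np.zeros((batch_memory * batch_size, time_context, input_size), dtype=np.float32)
batch_outputs = np.zeros((batch_memory * batch_size, time_context, output_size), dtype=np.float32)
print(batch_inputs.shape, batch_outputs.shape)  # (384, 30, 512) (384, 30, 2048)
```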
7. At this point, all files and directories are accounted for, but nothing has actually been loaded into memory. `initBatches()` calls `loadBatches()`, which calls `genBatches()` to refill `self.batch_inputs` and `self.batch_outputs` whenever the current store is exhausted. First, `genBatches()` calls `getNextIndex()` to update the class variables that determine which files, and how many windows from each, go into the next load:

```python
def getNextIndex(self):
    """
    Returns how many batches/sequences to load from each .data file
    """
    # next time point = (# of loads into memory) * (# of time points per load)
    target_value = (self.scratch_index+1)*(self.batch_memory*self.batch_size)
    # next file index = right-sided search of files with cumulative sum = next time point
    idx_target = np.searchsorted(self.num_points,target_value, side='right')
    # End case: set idxend to the number of points in the last file, and nindex
    # to the last file
    if target_value>self.num_points[-1] or idx_target>=len(self.num_points):
        idx_target = idx_target - 2
        target_value = self.num_points[idx_target]
        self.idxend = self.num_points[idx_target] - self.num_points[idx_target-1]
        self.nindex = idx_target
    # Otherwise, set idxend to number of points after file ending just prior to target
    # time point, and nindex to that file
    else:
        while target_value<=self.num_points[idx_target]:
            idx_target = idx_target - 1
        self.idxend = target_value - self.num_points[idx_target]
        self.nindex = idx_target
```
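The `else` branch of `getNextIndex()` is a cumulative-sum lookup: it searches `self.num_points` with `side='right'` and then steps back until it reaches the last boundary strictly below `target_value`. The sketch below assumes `num_points` is a cumulative count that starts at 0 (e.g. `[0, 116, 250, 400]` for files with 116, 134 and 150 windows) and reaches the same index directly with `side='left'` minus one. `locate` is a hypothetical name, and the end-of-dataset branch is not reproduced; the target is simply clamped instead.

```python
import numpy as np

def locate(num_points, target):
    """Map a global window count to (file index, windows to read from that file).

    Simplified sketch: assumes num_points is cumulative and starts at 0,
    so file k holds windows num_points[k] .. num_points[k + 1] - 1.
    """
    assert target > 0
    target = min(target, int(num_points[-1]))  # clamp to the end of the dataset
    # first boundary >= target, minus one -> file that contains window number `target`
    k = int(np.searchsorted(num_points, target, side='left')) - 1
    return k, target - int(num_points[k])

num_points = np.array([0, 116, 250, 400])
print(locate(num_points, 200))  # (1, 84): stop 84 windows into file 1
print(locate(num_points, 116))  # (0, 116): file 0 is read to its end
print(locate(num_points, 999))  # (2, 150): clamped to the end of the last file
```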
8. Next, `genBatches()` decides how many windows to load from which files in order to fill the batch buffers. It uses the `loadFile()` function to load each .data file into memory between indices `idxbegin` and `idxend`. As seen in the `loadInputOutput()` helper function, the .data file `file` contains the input, i.e. the mixture STFT, in `allmixinput = file[0]` and the outputs, i.e. the separated source STFTs, in `allmixoutput = file[1:]`. The STFTs are multiplied by a scale factor, optionally after a log compression:

```python
#apply a scaled log10(1+value) function to compress large values
#bach10 training: mult_factor_in = mult_factor_out = 0.3 (0.2 for testing)
#                 log_in = log_out = False
if self.log_in==True:
    allmixinput = self.mult_factor_in*np.log10(1.0+allmixinput)
else:
    allmixinput = self.mult_factor_in*allmixinput
if self.log_out==True:
    allmixoutput = self.mult_factor_out*np.log10(1.0+allmixoutput)
else:
    allmixoutput = self.mult_factor_out*allmixoutput
```
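The scaling can be written as a pair of functions. `scale` mirrors the snippet above; `unscale` is a hypothetical inverse (not in the original code) of the kind needed to turn network outputs back into magnitudes. With the Bach10 settings in the comment (`log_in = log_out = False`), both reduce to multiplying or dividing by the factor.

```python
import numpy as np

def scale(x, mult_factor=0.3, log=False):
    """Compress magnitudes as in the snippet above: mult * log10(1 + x), or just mult * x."""
    return mult_factor * np.log10(1.0 + x) if log else mult_factor * x

def unscale(y, mult_factor=0.3, log=False):
    """Hypothetical inverse of scale(), e.g. for mapping network outputs back to magnitudes."""
    return np.power(10.0, y / mult_factor) - 1.0 if log else y / mult_factor

mag = np.array([0.0, 1.0, 9.0, 99.0])
print(scale(mag, log=True))  # 0.3 * log10(1 + x): roughly [0.0, 0.09, 0.3, 0.6]
```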
9. In `loadFile()`, the input and output arrays are first initialized as zeros via `loadOutput()`:

```python
size = idxend - idxbegin
inp = np.zeros((size, self.time_context, self.input_size), dtype=self.tensortype)
out = np.zeros((size, self.time_context, self.output_size), dtype=self.tensortype)
```
10. If the file is shorter than `time_context`, it yields a single example: the whole file is copied into the beginning of the first (zero-initialized) window of `inputs` and `outputs`:

```python
if self.time_context > allmixinput.shape[1]:
    inputs[0,:allmixinput.shape[1],:] = allmixinput[0]
    outputs[0, :allmixoutput.shape[1], :allmixoutput.shape[-1]] = allmixoutput[0]
    # ...
    # concatenate features from rest of sources to third (feature) dimension
    for j in range(1,self.nsources):
        outputs[0, :allmixoutput.shape[1], j*allmixoutput.shape[-1]:(j+1)*allmixoutput.shape[-1]] = allmixoutput[j]
```
11. Otherwise, windows of size `time_context` are cut from the file along the time axis, advancing `time_context - overlap` frames per step, and only the windows whose index lies in [`idxbegin`, `idxend`) are kept:

```python
else:
    while (start + self.time_context) < allmixinput.shape[1]:
        if i>=idxbegin and i<idxend:
            # separate variables names for memory clearing
            allminput = allmixinput[:,start:start+self.time_context,:] #truncate on time axis so it would match the actual context
            allmoutput = allmixoutput[:,start:start+self.time_context,:]
            inputs[i-idxbegin] = allminput[0]
            outputs[i-idxbegin, :, :allmoutput.shape[-1]] = allmoutput[0]
            # ...
            # concatenate features from rest of sources to third (feature) dimension
            for j in range(1,self.nsources):
                outputs[i-idxbegin,:, j*allmoutput.shape[-1]:(j+1)*allmoutput.shape[-1]] = allmoutput[j,:,:]
            # ...

        i = i + 1
        start = start - self.overlap + self.time_context
        #clear memory
        allminput=None
        allmoutput=None
```
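Steps 10 and 11 can be sketched as a standalone function. This is a simplified version under stated assumptions: `make_windows` is a hypothetical name, it returns every window instead of only those in [`idxbegin`, `idxend`), and it takes the mixture as a `(1, T, F)` array and the sources as an `(S, T, F)` array, as the indexing in the snippets (`allmixinput[0]`, `allmixinput.shape[1]`) suggests.

```python
import numpy as np

def make_windows(mix, sources, time_context, overlap=0):
    """Cut a mixture and its sources into windows along the time axis.

    mix:     (1, T, F) mixture magnitude spectrogram
    sources: (S, T, F) per-source magnitude spectrograms
    Returns inputs (N, time_context, F) and outputs (N, time_context, S*F),
    with the S sources concatenated along the last (feature) axis.
    """
    _, T, F = mix.shape
    S = sources.shape[0]
    if time_context > T:
        # Short file: a single zero-padded window
        inputs = np.zeros((1, time_context, F), dtype=mix.dtype)
        outputs = np.zeros((1, time_context, S * F), dtype=sources.dtype)
        inputs[0, :T] = mix[0]
        for j in range(S):
            outputs[0, :T, j * F:(j + 1) * F] = sources[j]
        return inputs, outputs
    starts = []
    start = 0
    # Same stopping rule as the loop above: a window must end strictly before T
    while start + time_context < T:
        starts.append(start)
        start += time_context - overlap
    if not starts:
        return (np.zeros((0, time_context, F), dtype=mix.dtype),
                np.zeros((0, time_context, S * F), dtype=sources.dtype))
    inputs = np.stack([mix[0, s:s + time_context] for s in starts])
    # (S, time_context, F) -> (time_context, S*F): source j occupies features j*F .. (j+1)*F - 1
    outputs = np.stack([np.concatenate(list(sources[:, s:s + time_context]), axis=-1)
                        for s in starts])
    return inputs, outputs
```

For a 100-frame file with `time_context = 30`, this gives 3 windows without overlap and 4 with `overlap = 10`, the same counts the `getNum` formula predicts for those inputs.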
12. `loadFile()` returns a dictionary of input and output (and other) values to `genBatches()`. After loading from as many files in sequence as needed to fill `self.batch_inputs` and `self.batch_outputs`, the batches are shuffled via `shuffleBatches()`. Finally, the class variables are incremented in anticipation of the next call.
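`shuffleBatches()` itself is not shown here. Assuming it shuffles examples while keeping each input window aligned with its target, one way to do that is a single random permutation applied to both arrays. `shuffle_pair` is a hypothetical helper, not the `LargeDataset` method.

```python
import numpy as np

def shuffle_pair(inputs, outputs, rng=None):
    """Shuffle examples along axis 0, applying the same permutation to inputs and
    outputs so each input window stays paired with its target (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(len(inputs))
    return inputs[perm], outputs[perm]
```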

## Initial setup

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session, tf.contrib)
import numpy as np
import os, errno
import re
```
import re

```python
def make_directory(f):
    """Makes directory if does not already exist"""
    try:
        os.makedirs(f)
    except OSError as exception:
        if exception.errno != errno.EEXIST:
            raise
```
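On Python 3.2 and later, the same effect is available from `os.makedirs` directly via `exist_ok=True`. A minimal sketch (the name `make_directory_py3` is hypothetical), with one difference from the version above: if `f` already exists as a regular file, `os.makedirs` still raises, whereas the `errno.EEXIST` check silently ignores it.

```python
import os
import tempfile

def make_directory_py3(f):
    """Create f and any missing parents; an existing directory is not an error."""
    os.makedirs(f, exist_ok=True)

base = tempfile.mkdtemp()
path = os.path.join(base, "results", "trial_1")
make_directory_py3(path)
make_directory_py3(path)    # second call is a no-op
print(os.path.isdir(path))  # True
```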

## Layer definitions

```python
def _check_list(arg):
    """Split a list argument into (first element, remaining elements); a non-list becomes (arg, [])."""
    if isinstance(arg, list):
        return arg[0], arg[1:]
    else:
        return arg, []
```

```python
def _get_variable_initializer(init_type, var_shape, *args):
    if init_type == "random_normal":
        # Fall back to TensorFlow's defaults (mean 0, stddev 1) when no parameters are given,
        # so the default weights_initializer="random_normal" does not raise an IndexError
        mean = float(args[0]) if len(args) > 0 else 0.0
        stddev = float(args[1]) if len(args) > 1 else 1.0
        return tf.random_normal(var_shape, mean=mean, stddev=stddev)
    elif init_type == "truncated_normal":
        mean = float(args[0]) if len(args) > 0 else 0.0
        stddev = float(args[1]) if len(args) > 1 else 1.0
        return tf.truncated_normal(var_shape, mean=mean, stddev=stddev)
    elif init_type == "constant":
        c = args[0]
        return tf.constant(c, dtype=tf.float32, shape=var_shape)
    elif init_type == "xavier":
        # Unit normal scaled by 1/sqrt(n_in)
        n_in = tf.cast(args[0], tf.float32)
        return tf.div(tf.random_normal(var_shape), tf.sqrt(n_in))
    else:
        raise ValueError("Variable initializer \"" + init_type + "\" not supported.")
```

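```python
def _apply_normalization(norm_type, x, *args, **kwargs):
    if norm_type == "batch_norm":
        # batch_norm is never defined in this notebook; TF 1.x's tf.contrib.layers.batch_norm is
        # assumed here. It only accepts "NHWC"/"NCHW", so drop data_format=None (fully_connected).
        if kwargs.get("data_format") is None:
            kwargs.pop("data_format", None)
        return tf.contrib.layers.batch_norm(x, *args, **kwargs)
    else:
        raise ValueError("Normalization type \"" + norm_type + "\" not supported.")
```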

```python
def _apply_activation(activation_type, x, *args):
    if activation_type.lower() == "relu":
        return tf.nn.relu(x, name="Relu")
    elif activation_type.lower() == "leaky_relu":
        return tf.maximum(x, 0.1 * x, name="Leaky_Relu")
    elif activation_type.lower() == "softmax":
        return tf.nn.softmax(x)
    elif activation_type.lower() == "none":
        return x
    else:
        raise ValueError("Activation type \"" + activation_type + "\" not supported.")
```
```python
def conv2d(input_layer,
           num_outputs,
           kernel_size,
           stride=1,
           padding="VALID",
           data_format="NCHW",
           normalizer_fn=None,
           activation_fn=None,
           weights_initializer="random_normal",
           biases_initializer=None,
           trainable=True,
           scope="CONV"):
    with tf.name_scope(scope):
        input_shape = input_layer.get_shape().as_list()

        # Create weights
        W_init_type, W_init_params = _check_list(weights_initializer)
        with tf.name_scope(W_init_type + "_initializer"):
            if data_format == "NHWC":
                input_channels = input_shape[3]
            elif data_format == "NCHW":
                input_channels = input_shape[1]
            W_shape = kernel_size + [input_channels, num_outputs]
            if W_init_type == "xavier":
                # n_in is the size of the whole input layer (e.g. 1*30*512 for CONV_1), not the
                # per-filter fan-in kernel_h*kernel_w*input_channels of standard Xavier/LeCun init
                layer_shape = input_shape[1:]
                n_in = tf.reduce_prod(layer_shape)
                W_init_params = [n_in]
            W_init = _get_variable_initializer(W_init_type,
                                                W_shape,
                                                *W_init_params)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        # Convolve input
        stride_h, stride_w = _check_list(stride)
        if isinstance(stride_w, list):
            if len(stride_w) == 0:
                stride_w = stride_h
            else:
                stride_w = stride_w[0]
        if data_format == "NHWC":
            strides = [1, stride_h, stride_w, 1]
        elif data_format == "NCHW":
            strides = [1, 1, stride_h, stride_w]
        out = tf.nn.conv2d(input_layer,
                            filter=W,
                            strides=strides,
                            padding=padding,
                            data_format=data_format,
                            name="convolution")

        # Apply normalization
        if normalizer_fn is not None:
            norm_type, norm_params = _check_list(normalizer_fn)
            out = _apply_normalization(norm_type,
                                       out,
                                       *norm_params,
                                       data_format=data_format)

        # Add biases (skipped when normalization is applied)
        elif biases_initializer is not None:
            b_init_type, b_init_params = _check_list(biases_initializer)
            if data_format == "NHWC":
                b_shape = [1, 1, 1, num_outputs]
            elif data_format == "NCHW":
                b_shape = [1, num_outputs, 1, 1]
            b_init = _get_variable_initializer(b_init_type,
                                               b_shape,
                                               *b_init_params)
            b = tf.Variable(b_init,
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        # Apply activation
        if activation_fn is not None:
            act_type, act_params = _check_list(activation_fn)
            out = _apply_activation(act_type, out, *act_params)

        return out
```

```python
def conv2d_transpose(x,
                     output_shape,
                     kernel_size,
                     stride=1,
                     padding="VALID",
                     data_format="NCHW",
                     normalizer_fn=None,
                     activation_fn=None,
                     weights_initializer="random_normal",
                     biases_initializer=None,
                     trainable=True,
                     scope="CONV_T"):
    with tf.name_scope(scope):
        x_shape = x.get_shape().as_list()

        # Create weights
        W_init_type, W_init_params = _check_list(weights_initializer)
        with tf.name_scope(W_init_type + "_initializer"):
            if data_format == "NHWC":
                input_channels = x_shape[3]
                num_outputs = output_shape[3]
            elif data_format == "NCHW":
                input_channels = x_shape[1]
                num_outputs = output_shape[1]
            W_shape = kernel_size + [num_outputs, input_channels]
            if W_init_type == "xavier": # based on output size
                layer_shape = output_shape[1:]
                n_out = tf.reduce_prod(layer_shape)
                W_init_params = [n_out]
            W_init = _get_variable_initializer(W_init_type,
                                               W_shape,
                                               *W_init_params)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        # Convolve input (transposed)
        stride_h, stride_w = _check_list(stride)
        if isinstance(stride_w, list):
            if len(stride_w) == 0:
                stride_w = stride_h
            else:
                stride_w = stride_w[0]
        if data_format == "NHWC":
            strides = [1, stride_h, stride_w, 1]
        elif data_format == "NCHW":
            strides = [1, 1, stride_h, stride_w]
        out = tf.nn.conv2d_transpose(x,
                                     filter=W,
                                     output_shape=output_shape,
                                     strides=strides,
                                     padding=padding,
                                     data_format=data_format,
                                     name="convolution_transpose")

        # Apply normalization
        if normalizer_fn is not None:
            norm_type, norm_params = _check_list(normalizer_fn)
            out = _apply_normalization(norm_type,
                                       out,
                                       *norm_params,
                                       data_format=data_format)

        # Add biases (skipped when normalization is applied)
        elif biases_initializer is not None:
            b_init_type, b_init_params = _check_list(biases_initializer)
            if data_format == "NHWC":
                b_shape = [1, 1, 1, num_outputs]
            elif data_format == "NCHW":
                b_shape = [1, num_outputs, 1, 1]
            b_init = _get_variable_initializer(b_init_type,
                                               b_shape,
                                               *b_init_params)
            b = tf.Variable(b_init,
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        # Apply activation
        if activation_fn is not None:
            act_type, act_params = _check_list(activation_fn)
            out = _apply_activation(act_type, out, *act_params)

        return out
```
```python
def flatten(input_layer,
            data_format="NCHW",
            scope="FLAT"):
    with tf.name_scope(scope):
        # Grab runtime values to determine number of elements
        input_shape = tf.shape(input_layer)
        input_ndims = input_layer.get_shape().ndims
        batch_size = tf.slice(input_shape, [0], [1])
        layer_shape = tf.slice(input_shape, [1], [input_ndims-1])
        num_neurons = tf.expand_dims(tf.reduce_prod(layer_shape), 0)
        flattened_shape = tf.concat([batch_size, num_neurons], 0)
        if data_format == "NHWC":
            input_layer = tf.transpose(input_layer, perm=[0, 3, 1, 2])
        flat = tf.reshape(input_layer, flattened_shape)

        # Attempt to set values during graph building
        input_shape = input_layer.get_shape().as_list()
        batch_size, layer_shape = input_shape[0], input_shape[1:]
        if all(layer_shape): # None not present
            num_neurons = 1
            for dim in layer_shape:
                num_neurons *= dim
            flat.set_shape([batch_size, num_neurons])
        else: # None present
            flat.set_shape([batch_size, None])
        return flat
```

```python
def fully_connected(input_layer,
                    num_outputs,
                    normalizer_fn=None,
                    activation_fn=None,
                    weights_initializer="random_normal",
                    biases_initializer=None,
                    trainable=True,
                    scope="FC"):
    with tf.name_scope(scope):
        input_shape = input_layer.get_shape().as_list()

        # Create weights
        W_init_type, W_init_params = _check_list(weights_initializer)
        with tf.name_scope(W_init_type + "_initializer"):
            W_shape = [input_shape[1], num_outputs]
            if W_init_type == "xavier":
                layer_shape = input_shape[1]
                n_in = tf.reduce_prod(layer_shape)
                W_init_params = [n_in]
            W_init = _get_variable_initializer(W_init_type,
                                            W_shape,
                                            *W_init_params)
        W = tf.Variable(W_init,
                        dtype=tf.float32,
                        trainable=trainable,
                        name="weights")

        # Multiply inputs by weights
        out = tf.matmul(input_layer, W)

        # Apply normalization
        if normalizer_fn is not None:
            norm_type, norm_params = _check_list(normalizer_fn)
            out = _apply_normalization(norm_type,
                                       out,
                                       *norm_params,
                                       data_format=None)

        # Add biases
        elif biases_initializer is not None:
            b_init_type, b_init_params = _check_list(biases_initializer)
            b_shape = [num_outputs]
            b_init = _get_variable_initializer(b_init_type,
                                               b_shape,
                                               *b_init_params)
            b = tf.Variable(b_init,
                            dtype=tf.float32,
                            trainable=trainable,
                            name="biases")
            out = tf.add(out, b, name="BiasAdd")

        # Apply activation
        if activation_fn is not None:
            act_type, act_params = _check_list(activation_fn)
            out = _apply_activation(act_type, out, *act_params)

        return out
```
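With `padding="VALID"`, TensorFlow's output length along a dimension is ceil((n - k + 1) / s), which for n ≥ k equals (n - k) // s + 1. The snippet below applies this to the settings used in the graph cell (`time_context = 30`, `feat_size = 512`, 30 filters per layer) to show where `conv2_size` comes from; `valid_out` is a hypothetical helper name. The decoder's `conv2d_transpose` layers run the same steps in reverse.

```python
def valid_out(n, k, s=1):
    """Output length of a VALID convolution over n positions with kernel k and stride s."""
    return (n - k) // s + 1

time_context, feat_size, num_filters = 30, 512, 30

# CONV_1: kernel [1, 30], stride [1, 4] on the (time, frequency) plane
h1, w1 = valid_out(time_context, 1), valid_out(feat_size, 30, 4)
# CONV_2: kernel [int(2 * time_context / 3), 1] = [20, 1], stride [1, 1]
h2, w2 = valid_out(h1, int(2 * time_context / 3)), valid_out(w1, 1)

print((h1, w1), (h2, w2))      # (30, 121) (11, 121)
print(num_filters * h2 * w2)   # conv2_size: 39930
```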

## Graph building

```python
tf.reset_default_graph()

# Settings
#data_format = "NHWC" # if using cpu
data_format = "NCHW" # if using gpu
results_dir = "./results/trial_1/"
make_directory(results_dir)
time_context = 30
feat_size = 512
num_sources = 4
eps = 1e-18 # numerical stability
alpha = 0.001 # learning rate

# Data formatting
if data_format == "NHWC":
    input_shape = [None, time_context, feat_size, 1]
    target_shape = [None, time_context, feat_size, num_sources]
    channel_dim = 3
elif data_format == "NCHW":
    input_shape = [None, 1, time_context, feat_size]
    target_shape = [None, num_sources, time_context, feat_size]
    channel_dim = 1
else:
    raise ValueError("Unknown data format \"" + data_format + "\"")
spectrogram = tf.placeholder(tf.float32,
                             shape=input_shape,
                             name="magnitude_spectrogram")
```

```python
# Convolutional layer 1
conv1 = conv2d(spectrogram,
               num_outputs=30,
               kernel_size=[1, 30],
               stride=[1, 4],
               padding="VALID",
               data_format=data_format,
               weights_initializer="xavier",
               biases_initializer=["constant", 0.0],
               scope="CONV_1")

# Convolutional layer 2
conv2 = conv2d(conv1,
               num_outputs=30,
               kernel_size=[int(2*time_context/3), 1],
               stride=[1, 1],
               padding="VALID",
               data_format=data_format,
               weights_initializer="xavier",
               biases_initializer=["constant", 0.0],
               scope="CONV_2")
conv2_flat = flatten(conv2,
                     data_format=data_format,
                     scope="CONV_2_FLAT")

# Fully-connected layer 1 (encoding)
fc1 = fully_connected(conv2_flat,
                      num_outputs=256,
                      activation_fn="relu",
                      weights_initializer="xavier",
                      biases_initializer=["constant", 0.0],
                      scope="FC_1")

# Get shapes for building decoding layers
batch_size = tf.shape(spectrogram)[0]
conv1_shape = conv1.get_shape().as_list()
conv2_shape = conv2.get_shape().as_list()
conv2_size = conv2_shape[1] * conv2_shape[2] * conv2_shape[3]
```

```python
# Build decoder for each source
fc2, convt1, convt2 = [], [], []
for i in range(num_sources):
    # Fully-connected layer 2 (decoding)
    fc2_i = fully_connected(fc1,
                            num_outputs=conv2_size,
                            activation_fn="relu",
                            weights_initializer="xavier",
                            biases_initializer=["constant", 0.0],
                            scope="FC_2_%d" % (i+1))
    fc2.append(fc2_i)

    # Convolutional transpose layer 1
    # Side note: tf.reshape() can infer size of one dimension given rest, so -1 okay
    #            tf.nn.conv2d_transpose() must know exact dimensions, but batch size can
    #                be inferred at runtime using tf.shape()
    fc2_i = tf.reshape(fc2_i, [-1] + conv2_shape[1:])
    convt1_i = conv2d_transpose(fc2_i,
                                output_shape=[batch_size] + conv1_shape[1:],
                                kernel_size=[int(2*time_context/3), 1],
                                stride=[1, 1],
                                padding="VALID",
                                data_format=data_format,
                                weights_initializer="xavier",
                                biases_initializer=["constant", 0.0],
                                scope="CONVT_1_%d" % (i+1))
    convt1.append(convt1_i)

    # Convolutional transpose layer 2
    convt2_i = conv2d_transpose(convt1_i,
                                output_shape=[batch_size] + input_shape[1:],
                                kernel_size=[1, 30],
                                stride=[1, 4],
                                padding="VALID",
                                data_format=data_format,
                                weights_initializer="xavier",
                                biases_initializer=["constant", 0.0],
                                scope="CONVT_2_%d" % (i+1))
    convt2.append(convt2_i)
```

```python
# Output layer
with tf.name_scope("y_hat"):
    convt2_all = tf.concat(convt2, axis=channel_dim)
    b = tf.Variable(tf.constant(0.0, shape=[1, 1, 1, 1]),
                    dtype=tf.float32,
                    name="bias")
    y_hat = tf.maximum(tf.add(convt2_all, b), 0, name="y_hat")

# Masks: m_n(f) = |y_hat_n(f)| / Σ(|y_hat_n'(f)|)
with tf.name_scope("masks"):
    rand = tf.random_uniform([batch_size] + input_shape[1:])
    den = tf.reduce_sum(y_hat, axis=channel_dim, keep_dims=True) + (eps * rand)
    masks = tf.div(y_hat, den, name="masks") # broadcast along channel dimension

# Source signals: y_tilde_n(f) = m_n(f) * x(f),
# where x(f) is the spectrogram of the input mixture signal
with tf.name_scope("y_tilde"):
    y_tilde = tf.multiply(masks, spectrogram, name="y_tilde") # broadcast along channel dimension

# Loss function: L = Σ(||y_tilde_n - target_n||^2)
with tf.name_scope("loss"):
    targets = tf.placeholder(tf.float32,
                             shape=target_shape,
                             name="target_sources")
    reduc_indices = [i for i in range(4) if i != channel_dim]
    loss_n = tf.reduce_sum(tf.square(y_tilde - targets), axis=reduc_indices, name="loss_n")
    loss_total = tf.reduce_sum(loss_n, name="loss_total")

# Optimizer
with tf.name_scope("train_step"):
    optimizer = tf.train.AdamOptimizer(alpha)
    train_step = optimizer.minimize(loss_total)

# Summaries
saver = tf.train.Saver(max_to_keep=5)
graph = tf.get_default_graph()
writer = tf.summary.FileWriter(results_dir, graph)
loss_sum = []
with tf.name_scope("summaries"):
    for i in range(num_sources):
        loss_sum.append(tf.summary.scalar("loss_%d" % (i+1), loss_n[i]))
    loss_sum.append(tf.summary.scalar("loss_total", loss_total))
    loss_sum = tf.summary.merge(loss_sum)
```
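A NumPy mirror of the `masks`, `y_tilde` and `loss` ops above, on random stand-in data in NCHW layout, can be used to sanity-check the math: wherever the estimates are non-zero, the masks sum to one over the source axis, so the separated estimates add back up to the mixture. The `eps * rand` term only matters where all estimates are zero; there the masks come out as 0 instead of NaN (as long as the uniform draw is non-zero). The data here is random and purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_sources, T, F = 4, 30, 512
eps = 1e-18

x = rng.random((1, 1, T, F))                # mixture magnitudes: (batch, 1, time, freq)
y_hat = rng.random((1, num_sources, T, F))  # non-negative estimates, one channel per source

den = y_hat.sum(axis=1, keepdims=True) + eps * rng.random((1, 1, T, F))
masks = y_hat / den                         # soft masks, broadcast over the source axis
y_tilde = masks * x                         # masked mixture = source estimates

targets = rng.random((1, num_sources, T, F))
loss_n = ((y_tilde - targets) ** 2).sum(axis=(0, 2, 3))  # one value per source
loss_total = loss_n.sum()
print(np.allclose(masks.sum(axis=1), 1.0), np.allclose(y_tilde.sum(axis=1), x[:, 0]))  # True True
```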

## Training

```python
def get_shape(shape_file):
    """Reads a .shape file"""
    with open(shape_file, 'rb') as f:
        line=f.readline().decode('ascii')
        if line.startswith('#'):
            shape=tuple(map(int, re.findall(r'(\d+)', line)))
            return shape
        else:
            raise IOError('Failed to find shape in file')

def create_batches(data, batch_size):
    """Splits data of shape (channels, time, freq) into examples of shape
    (channels, time_context, feat_size); frames and bins that do not fill a whole
    example are dropped. batch_size is currently unused."""
    batches = []
    time_batches = data.shape[1] // time_context
    freq_batches = data.shape[2] // feat_size
    for t in range(time_batches):
        for f in range(freq_batches):
            batches.append(data[:, t*time_context:(t+1)*time_context, f*feat_size:(f+1)*feat_size])
    return np.asarray(batches)
```

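The notebook printed the shape of the feature array as `(5, 3494, 2049)`: the mixture plus four sources, 3494 STFT frames and 2049 frequency bins.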
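For a feature array of shape `(5, 3494, 2049)`, `create_batches` yields 116 time chunks times 4 frequency chunks, i.e. 464 examples. The last 14 frames (3494 - 116 * 30) and the last frequency bin (2049 - 4 * 512) are dropped, and with `batch_size = 32` the training loop runs 14 steps per epoch, leaving 16 examples unused. The sketch below, `split_examples`, is a hypothetical variant of `create_batches` that takes `time_context` and `feat_size` as arguments instead of reading globals; it is checked on a small zero array.

```python
import numpy as np

def split_examples(data, time_context=30, feat_size=512):
    """Same slicing as create_batches, with the sizes passed explicitly
    (hypothetical variant; the notebook version reads them from globals)."""
    n_t, n_f = data.shape[1] // time_context, data.shape[2] // feat_size
    return np.asarray([data[:, t * time_context:(t + 1) * time_context,
                            f * feat_size:(f + 1) * feat_size]
                       for t in range(n_t) for f in range(n_f)])

# For a (5, 3494, 2049) array: 3494 // 30 = 116 time chunks, 2049 // 512 = 4 frequency chunks
n_examples = (3494 // 30) * (2049 // 512)
print(n_examples, n_examples // 32)  # 464 14

# Small stand-in array: 5 channels (mixture + 4 sources), 65 frames, 1030 bins
demo = np.zeros((5, 65, 1030), dtype=np.float32)
print(split_examples(demo[0:1]).shape, split_examples(demo[1:]).shape)  # (4, 1, 30, 512) (4, 4, 30, 512)
```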


```python
# Training settings
params_dir = results_dir + "params/"
make_directory(params_dir)
input_file = "./features/02-AchLiebenChristen__m_.data"
shape_file = "./features/02-AchLiebenChristen__m_.shape"
num_epochs = 100
batch_size = 32

# np.fromfile reads float64 unless a dtype is given; it must match how the features were saved
f_in = np.fromfile(input_file)
if shape_file is not None:
    f_shape = get_shape(shape_file)
    f_in = np.reshape(f_in, f_shape)
input_data = create_batches(f_in[0:1], batch_size) # mixed input
target_data = create_batches(f_in[1:], batch_size) # separate sources
iter_size = len(input_data) // batch_size

# Initialize graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())

global_step = 0
for epoch in range(num_epochs):
    for i in range(iter_size):
        # Get batch from input data and target data
        input_batch = input_data[i*batch_size:(i+1)*batch_size] # magnitude spectrogram of whole
        target_batch = target_data[i*batch_size:(i+1)*batch_size] # magnitude spectrogram of sources

        # Perform training step
        feed_dict = {spectrogram: input_batch, targets: target_batch}
        loss_sum_, _ = sess.run([loss_sum, train_step],
                                feed_dict=feed_dict)
        writer.add_summary(loss_sum_, global_step=global_step)
        writer.flush()
        global_step += 1

    # Save model after each epoch
    saver.save(sess, params_dir + "model",
               global_step=epoch)
```

## Testing

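```python
def load_model(params_file, meta_graph):
    """Rebuild the saved graph from its .meta file and restore the trained weights.

    Sketch: params_file is assumed to be a checkpoint prefix written by saver.save above,
    e.g. params_dir + "model-99", and meta_graph the matching "model-99.meta" file.
    """
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(meta_graph)
    sess = tf.Session()
    saver.restore(sess, params_file)
    return sess, tf.get_default_graph()
```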

```python
params_file = None
meta_graph = None
sess = tf.Session()
saver = tf.train.Saver()
```

