#### Reference: https://github.com/kuleshov/audio-super-res

# Training the ASR (audio super-resolution) model

In [1]:
import os
import sys

# Make the local helper modules (asr_model, io_utils) importable from the
# notebook's working directory and its parent directory. The membership check
# keeps the cell idempotent: re-running it does not pile duplicate entries
# onto sys.path.
for _path in (os.path.abspath('.'), os.path.dirname(os.path.abspath('.'))):
    if _path not in sys.path:
        sys.path.append(_path)

import numpy as np
import matplotlib
# matplotlib.use('Agg')  # optional non-interactive backend for headless runs;
#                        # NOTE(review): older matplotlib needs this before pyplot is imported
from asr_model import ASRNet, default_opt
from io_utils import load_h5, upsample_wav
import tensorflow as tf

In [2]:
# Run configuration: every tunable setting for this training run.
args = dict(
    train='train.h5',            # training split, read with load_h5
    val='valid.h5',              # validation split, read with load_h5
    alg='adam',                  # optimizer name -> opt_params['alg']
    epochs=10,                   # passed to model.fit as n_epoch
    logname='default_log_name',  # passed to ASRNet as log_prefix
    layers=4,                    # -> opt_params['layers']
    lr=0.0005,                   # learning rate -> opt_params['lr']
    batch_size=100,              # -> opt_params['batch_size']
)
# Show the TensorFlow version in the output so results can be tied to it
# (the run captured below printed 1.5.0).
print(tf.__version__)

1.5.0


In [3]:
# Load (input, target) array pairs for the training split, then the validation split.
(X_train, Y_train), (X_val, Y_val) = [load_h5(args[split]) for split in ('train', 'val')]

List of arrays in input file: KeysView(<HDF5 file "train.h5" (mode r)>)
Shape of X: (852, 16384, 1)
Shape of Y: (852, 16384, 1)
List of arrays in input file: KeysView(<HDF5 file "valid.h5" (mode r)>)
Shape of X: (287, 16384, 1)
Shape of Y: (287, 16384, 1)


In [4]:
# Determine the super-resolution ratio r = len(target) / len(input) from the
# first training pair, and check the channel layout the model expects.
n_dim_y, n_chan_y = Y_train[0].shape
n_dim_x, n_chan_x = X_train[0].shape
print('number of dimension Y:', n_dim_y)
print('number of channel Y:', n_chan_y)
print('number of dimension X:', n_dim_x)
print('number of channel X:', n_chan_x)
# r must be a whole number; int(a / b) would silently truncate a bad ratio,
# so check divisibility and use exact integer division.
assert n_dim_y % n_dim_x == 0, 'Y length must be an exact multiple of X length'
r = n_dim_y // n_dim_x
print('r:', r)
n_chan = n_chan_y
n_dim = n_dim_y
# Only single-channel (mono) data is supported: stop if the target has more channels.
assert n_chan == 1, 'expected 1 channel, got %d' % n_chan

number of dimension Y: 16384
number of channel Y: 1
number of dimension X: 16384
number of channel X: 1
r: 1


In [5]:
# create model
def get_model(args, n_dim, r, from_ckpt=False, train=True):
    """Create an ASRNet configured from the notebook's ``args`` dict.

    Args:
        args: configuration mapping. ``'logname'`` is always read and used as
            the model's ``log_prefix``. When ``train`` is True, ``'alg'``,
            ``'lr'``, ``'batch_size'`` and ``'layers'`` are also read, plus the
            optional optimizer coefficients ``'b1'`` / ``'b2'`` (defaulting to
            0.9 / 0.999, the values previously hard-coded here).
        n_dim: per-example length handed to ASRNet (the notebook passes the
            first dimension of a target example).
        r: super-resolution ratio handed to ASRNet.
        from_ckpt: forwarded to ASRNet unchanged.
        train: if True, build the optimizer params from ``args``; otherwise
            use a copy of the module-level ``default_opt``.

    Returns:
        The constructed ASRNet instance.
    """
    import copy

    if train:
        opt_params = {
            'alg': args['alg'],
            'lr': args['lr'],
            'b1': args.get('b1', 0.9),
            'b2': args.get('b2', 0.999),
            'batch_size': args['batch_size'],
            'layers': args['layers'],
        }
    else:
        # Hand the model its own copy so nothing it does to opt_params can
        # leak back into the shared default_opt used by later calls.
        opt_params = copy.copy(default_opt)

    return ASRNet(
        from_ckpt=from_ckpt,
        n_dim=n_dim,
        r=r,
        opt_params=opt_params,
        log_prefix=args['logname'],
    )

# Build the training model from `args` without restoring a checkpoint.
model = get_model(args, n_dim=n_dim, r=r, from_ckpt=False, train=True)

>> Generator Model init...
D-Block >>  Tensor("generator/Relu:0", shape=(?, ?, 32), dtype=float32)
D-Block >>  Tensor("generator/Relu_1:0", shape=(?, ?, 48), dtype=float32)
D-Block >>  Tensor("generator/Relu_2:0", shape=(?, ?, 64), dtype=float32)
D-Block >>  Tensor("generator/Relu_3:0", shape=(?, ?, 64), dtype=float32)
B-Block >>  Tensor("generator/Relu_4:0", shape=(?, ?, 64), dtype=float32)
U-Block >>  Tensor("generator/concat:0", shape=(?, ?, 128), dtype=float32)
U-Block >>  Tensor("generator/concat_1:0", shape=(?, ?, 128), dtype=float32)
U-Block >>  Tensor("generator/concat_2:0", shape=(?, ?, 96), dtype=float32)
U-Block >>  Tensor("generator/concat_3:0", shape=(?, ?, 64), dtype=float32)
Fin-Layer >>  Tensor("generator/Add:0", shape=(?, ?, 1), dtype=float32)
>> ...finish

creating train_op with params: {'b1': 0.9, 'batch_size': 100, 'layers': 4, 'lr': 0.0005, 'b2': 0.999, 'alg': 'adam'}


In [6]:
# Train for args['epochs'] epochs on (X_train, Y_train), evaluating on the
# validation split (X_val, Y_val); the output below reports per-epoch
# training and validation l2_loss/segsnr.
model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args['epochs'])

start training epoch (n:10)
num-of-batch: 100
count 1 / obj: 0.011600 / snr: 19.355580
count 2 / obj: 0.008685 / snr: 20.612528
count 3 / obj: 0.009123 / snr: 20.398688
count 4 / obj: 0.008165 / snr: 20.880310
count 5 / obj: 0.007995 / snr: 20.972055
count 6 / obj: 0.008121 / snr: 20.903914
count 7 / obj: 0.008175 / snr: 20.875330
count 8 / obj: 0.008135 / snr: 20.896617
count 9 / obj: 0.007721 / snr: 21.123295

Epoch 1 of 10 took 215.091s (8 minibatches)
  training l2_loss/segsnr:		0.008504	14.878838
  validation l2_loss/segsnr:		0.004127	15.779288
-----------------------------------------------------------------------
count 1 / obj: 0.008232 / snr: 20.844777
count 2 / obj: 0.008422 / snr: 20.745691
count 3 / obj: 0.008368 / snr: 20.773929
count 4 / obj: 0.010174 / snr: 19.924932
count 5 / obj: 0.008494 / snr: 20.709130
count 6 / obj: 0.008225 / snr: 20.848514
count 7 / obj: 0.007763 / snr: 21.099658
count 8 / obj: 0.008255 / snr: 20.832661
count 9 / obj: 0.007708 / snr: 21.130599

Ep