In [None]:
from __future__ import print_function
import os
import numpy as np
import csv
import time
import pickle
import cPickle
import h5py
import argparse
import soundfile
import librosa
import matplotlib.pyplot as plt
from scipy import signal
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import config as cfg

## Вспомогательные функции

In [None]:
def create_folder(fd):
    """Make sure directory `fd` exists, creating missing parent directories.

    Does nothing when `fd` is already present on disk.
    """
    if os.path.exists(fd):
        return
    os.makedirs(fd)

def get_filename(path):
    """Return the base name of `path` without its final extension.

    The path is first resolved with `os.path.realpath`, so symlinks are
    followed. `os.path.basename` is used instead of splitting on '/', so the
    function also works with the native separator on Windows.

    Example: "/data/MAPS_MUS-alb_esp2.wav" -> "MAPS_MUS-alb_esp2".
    """
    full_path = os.path.realpath(path)
    return os.path.splitext(os.path.basename(full_path))[0]

In [None]:
def read_audio(path, target_fs=None):
    """Load an audio file as a mono 1d sequence, optionally resampled.

    Multi-channel audio is down-mixed by averaging over the channel axis.

    Args:
      path: string, path of the audio file.
      target_fs: int or None, desired sampling rate. When given and different
        from the file's rate, the audio is resampled with librosa.

    Returns:
      audio: 1d ndarray, the (mono) audio samples.
      fs: int, sampling rate of the returned audio.
    """
    audio, fs = soundfile.read(path)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    needs_resample = target_fs is not None and fs != target_fs
    if not needs_resample:
        return audio, fs
    resampled = librosa.resample(audio, orig_sr=fs, target_sr=target_fs)
    return resampled, target_fs

def write_audio(path, audio, sample_rate):
    """Write an audio sequence to a sound file (e.g. .wav) via soundfile.

    Args:
      path: string, path to write the file to.
      audio: ndarray, audio sequence to write out.
      sample_rate: int, sample rate to write out.

    Returns:
      None.
    """
    soundfile.write(file=path, data=audio, samplerate=sample_rate)

In [None]:
def spectrogram(audio):
    """Compute the magnitude spectrogram of an audio sequence.

    Uses a Hamming window of `cfg.n_window` samples with `cfg.n_overlap`
    samples of overlap, no detrending and a one-sided spectrum.

    Args:
      audio: 1darray, audio sequence.

    Returns:
      x: float32 ndarray of shape (n_time, n_freq), with
         n_freq = cfg.n_window // 2 + 1.
    """
    n_window = cfg.n_window
    n_overlap = cfg.n_overlap

    ham_win = np.hamming(n_window)
    # `signal.spectral` is a deprecated private namespace in recent SciPy;
    # the public `signal.spectrogram` is the same function.
    _, _, x = signal.spectrogram(
        audio,
        window=ham_win,
        nperseg=n_window,
        noverlap=n_overlap,
        detrend=False,
        return_onesided=True,
        mode='magnitude')
    # SciPy returns (n_freq, n_time); transpose to time-major.
    return x.T.astype(np.float32)

def logmel(audio, n_mels=229):
    """Compute the log Mel spectrogram of an audio sequence.

    A Hamming-window magnitude spectrogram (`cfg.n_window`, `cfg.n_overlap`)
    is projected onto a Mel filterbank spanning 0 .. cfg.sample_rate / 2 and
    log-compressed.

    The filterbank is cached in the module-level global `melW`. The cache is
    keyed on (sample_rate, n_window, n_mels), so changing `cfg` or `n_mels`
    rebuilds it instead of silently reusing a stale filterbank.

    Args:
      audio: 1darray, audio sequence.
      n_mels: int, number of Mel bands (default 229).

    Returns:
      x: float32 ndarray, log Mel spectrogram of shape (n_time, n_mels).
    """
    global melW, _melW_key
    n_window = cfg.n_window
    n_overlap = cfg.n_overlap
    fs = cfg.sample_rate

    ham_win = np.hamming(n_window)
    # Public `signal.spectrogram` (the `signal.spectral` namespace is deprecated).
    _, _, x = signal.spectrogram(
        audio,
        window=ham_win,
        nperseg=n_window,
        noverlap=n_overlap,
        detrend=False,
        return_onesided=True,
        mode='magnitude')
    x = x.T  # (n_time, n_freq)

    mel_key = (fs, n_window, n_mels)
    if globals().get('melW') is None or globals().get('_melW_key') != mel_key:
        melW = librosa.filters.mel(sr=fs,
                                   n_fft=n_window,
                                   n_mels=n_mels,
                                   fmin=0,
                                   fmax=fs / 2.)
        _melW_key = mel_key
    x = np.dot(x, melW.T)
    # Small epsilon avoids log(0) in silent bins.
    x = np.log(x + 1e-8)
    return x.astype(np.float32)

In [None]:
def calculate_features(args):
    """Calculate and write out features & ground-truth piano rolls for every
    song in the MUS directory of every piano in cfg.tr_pianos + cfg.te_pianos.

    For each `<piano>/MUS/<name>.wav` under `args.dataset_dir`, this function
    does the following:
      - Computes the feature selected by `args.feat_type` ("spectrogram" or
        "logmel").
      - Reads the matching `<name>.txt` note annotation into a piano roll with
        one row per feature frame.
      - Keeps the pitch columns [cfg.pitch_bgn, cfg.pitch_fin).
      - Pickles `[x, y]` to `<workspace>/features/<feat_type>/<name>.p`.

    Raises:
      Exception: if `args.feat_type` is not a supported feature type.
    """
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    feat_type = args.feat_type
    fs = cfg.sample_rate
    pitch_bgn = cfg.pitch_bgn
    pitch_fin = cfg.pitch_fin

    # Validate up front so a typo fails before any audio is read.
    if feat_type == "spectrogram":
        feat_func = spectrogram
    elif feat_type == "logmel":
        feat_func = logmel
    else:
        raise Exception("Unknown feat_type %r! Expected 'spectrogram' or 'logmel'." % (feat_type,))

    out_dir = os.path.join(workspace, "features", feat_type)
    create_folder(out_dir)

    cnt = 0
    for piano in cfg.tr_pianos + cfg.te_pianos:
        audio_dir = os.path.join(dataset_dir, piano, "MUS")
        # Sorted for a deterministic processing order (os.listdir is arbitrary).
        wav_names = sorted(na for na in os.listdir(audio_dir) if na.endswith('.wav'))

        for wav_na in wav_names:
            bare_na = os.path.splitext(wav_na)[0]
            wav_path = os.path.join(audio_dir, wav_na)
            (audio, _) = read_audio(wav_path, target_fs=fs)

            x = feat_func(audio)    # (n_time, n_freq)

            # Ground truth: one piano-roll row per feature frame.
            n_time = x.shape[0]
            txt_path = os.path.join(audio_dir, "%s.txt" % bare_na)
            roll = txt_to_midi_roll(txt_path, max_fr_len=n_time)    # (n_time, 128)
            # Presumably 88 piano keys (pitch_fin - pitch_bgn) — see cfg.
            y = roll[:, pitch_bgn:pitch_fin]

            out_path = os.path.join(out_dir, "%s.p" % bare_na)
            print(cnt, out_path, x.shape, y.shape)
            with open(out_path, 'wb') as f:
                cPickle.dump([x, y], f, protocol=cPickle.HIGHEST_PROTOCOL)
            cnt += 1

In [None]:
def is_in_pianos(na, list_of_piano):
    """Return True if any piano identifier in `list_of_piano` occurs as a
    substring of the file name `na`, otherwise False.

    Example: na="MAPS_MUS-alb_esp2_SptkBGCl.wav", list_of_piano=['SptkBGCl', ...]
    gives True.
    """
    return any(piano in na for piano in list_of_piano)

In [None]:
def pack_features(args):
    """Pack per-song feature files into one train and one test pickle.

    Loads every `[x, y]` pickle in `<workspace>/features/<feat_type>/`. Each
    file goes to the training split if its name contains a piano from
    cfg.tr_pianos, otherwise to the testing split if it contains one from
    cfg.te_pianos. For each split, `[x_list, y_list, wav_name_list]` is
    written to `<workspace>/packed_features/<feat_type>/{train,test}.p`.
    One big file loads much faster later than many small ones.

    Raises:
      Exception: if a feature file matches neither piano list.
    """
    workspace = args.workspace
    feat_type = args.feat_type
    tr_pianos = cfg.tr_pianos
    te_pianos = cfg.te_pianos

    fe_dir = os.path.join(workspace, "features", feat_type)
    # Sorted so the packed lists have the same order on every run/platform.
    fe_names = sorted(os.listdir(fe_dir))

    tr_x_list, tr_y_list, tr_na_list = [], [], []
    te_x_list, te_y_list, te_na_list = [], [], []
    t1 = time.time()
    for cnt, fe_na in enumerate(fe_names):
        print(cnt)
        bare_na = os.path.splitext(fe_na)[0]
        fe_path = os.path.join(fe_dir, fe_na)
        # Files are produced by calculate_features; pickle is only safe on trusted data.
        with open(fe_path, 'rb') as f:
            [x, y] = cPickle.load(f)

        if is_in_pianos(fe_na, tr_pianos):
            tr_x_list.append(x)
            tr_y_list.append(y)
            tr_na_list.append("%s.wav" % bare_na)
        elif is_in_pianos(fe_na, te_pianos):
            te_x_list.append(x)
            te_y_list.append(y)
            te_na_list.append("%s.wav" % bare_na)
        else:
            raise Exception("File not in tr_pianos or te_pianos!")

    out_dir = os.path.join(workspace, "packed_features", feat_type)
    create_folder(out_dir)
    tr_packed_feat_path = os.path.join(out_dir, "train.p")
    te_packed_feat_path = os.path.join(out_dir, "test.p")

    with open(tr_packed_feat_path, 'wb') as f:
        cPickle.dump([tr_x_list, tr_y_list, tr_na_list], f, protocol=cPickle.HIGHEST_PROTOCOL)
    with open(te_packed_feat_path, 'wb') as f:
        cPickle.dump([te_x_list, te_y_list, te_na_list], f, protocol=cPickle.HIGHEST_PROTOCOL)
    print("Packing time: %s s" % (time.time() - t1,))

In [None]:
def compute_scaler(args):
    """Fit a feature scaler on the packed training features and pickle it.

    Reads `<workspace>/packed_features/<feat_type>/train.p` and concatenates
    all songs' feature matrices along the first (time) axis. It then fits a
    zero-mean / unit-variance `StandardScaler` and writes it to
    `<workspace>/scalers/<feat_type>/scaler.p`. Standardised inputs speed up
    neural network training.

    NOTE(review): `preprocessing` is not imported anywhere in this notebook;
    it presumably refers to `sklearn.preprocessing` and needs
    `from sklearn import preprocessing` in the imports cell — confirm.
    """
    workspace = args.workspace
    feat_type = args.feat_type

    t1 = time.time()
    packed_feat_path = os.path.join(workspace, "packed_features", feat_type, "train.p")
    with open(packed_feat_path, 'rb') as f:
        [x_list, _, _] = cPickle.load(f)

    x_all = np.concatenate(x_list)
    scaler = preprocessing.StandardScaler(with_mean=True, with_std=True).fit(x_all)
    print(scaler.mean_)
    print(scaler.scale_)

    out_path = os.path.join(workspace, "scalers", feat_type, "scaler.p")
    create_folder(os.path.dirname(out_path))
    with open(out_path, 'wb') as f:
        pickle.dump(scaler, f)
    print("Compute scaler finished! %s s" % (time.time() - t1,))
    
def scale_on_x_list(x_list, scaler):
    """Apply `scaler.transform` to every array in `x_list`.

    Returns a new list in the same order; the input list is not modified.
    """
    scaled = []
    for arr in x_list:
        scaled.append(scaler.transform(arr))
    return scaled

In [None]:
def data_to_3d(x_list, y_list, n_concat, n_hop):
    """Convert per-song 2d data to context windows for frame classification.

    Every song is cut into windows of `n_concat` consecutive frames (step
    `n_hop`). For the targets, only the label of each window's centre frame
    is kept.

    Args:
      x_list: list of ndarray, e.g., [(N1, n_freq), (N2, n_freq), ...]
      y_list: list of ndarray, e.g., [(N1, 88), (N2, 88), ...]
      n_concat: int, number of frames per window (expected to be odd so the
        centre frame is well defined).
      n_hop: int, hop between window starts, in frames.

    Returns:
      x_all: (n_samples, n_concat, n_freq)
      y_all: (n_samples, n_out)
    """
    # Integer division: under Python 3 `/` gives a float, which cannot be
    # used as an ndarray index.
    n_half = (n_concat - 1) // 2
    x_all = np.concatenate([mat_2d_to_3d(e, n_concat, n_hop) for e in x_list], axis=0)
    y_all = np.concatenate([mat_2d_to_3d(e, n_concat, n_hop) for e in y_list], axis=0)
    # (n_samples, n_concat, n_out) -> centre-frame labels (n_samples, n_out)
    y_all = y_all[:, n_half, :]
    return x_all, y_all
    
def mat_2d_to_3d(x, agg_num, hop):
    """Slice a 2d array into windows of `agg_num` consecutive rows.

    Windows start at rows 0, hop, 2*hop, ... and only complete windows are
    kept. Inputs shorter than one window are first zero-padded at the end to
    exactly `agg_num` rows, so at least one window is always returned.

    Args:
      x: 2darray, e.g., (N, n_in)
      agg_num: int, number of frames per window.
      hop: int, positive step between window starts.

    Returns:
      x3d: 3darray, e.g., (n_samples, agg_num, n_in), same dtype as `x`.
    """
    len_x, n_in = x.shape
    if len_x < agg_num:
        # Pad with x's dtype so short inputs are not silently upcast
        # (np.zeros defaults to float64).
        pad = np.zeros((agg_num - len_x, n_in), dtype=x.dtype)
        x = np.concatenate((x, pad))
        len_x = agg_num

    starts = range(0, len_x - agg_num + 1, hop)
    return np.array([x[i:i + agg_num] for i in starts])

In [None]:
def txt_to_midi_roll(txt_path, max_fr_len):
    """Read a tab-separated note annotation file into a frame-level MIDI roll.

    The first line is treated as a header and skipped. Every following
    non-empty line must hold `onset_time<TAB>offset_time<TAB>midi_pitch`.
    A note marks the frames
    [floor(onset / step_sec), ceil(offset / step_sec) + 1) as active, with
    step_sec = cfg.step_sec. Frames past `max_fr_len` are dropped by the
    slice.

    Args:
      txt_path: string, path of the note info txt.
      max_fr_len: int, should be the same as the number of frames of the
          calculated feature.

    Returns:
      midi_roll: float ndarray of shape (max_fr_len, 128).
    """
    step_sec = cfg.step_sec

    # Text mode: under Python 3 csv.reader needs str lines, not bytes ('rb').
    with open(txt_path, 'r') as f:
        rows = list(csv.reader(f, delimiter='\t'))

    midi_roll = np.zeros((max_fr_len, 128))
    for row in rows[1:]:
        # Skip blank lines (e.g. trailing newlines at the end of the file).
        if not row:
            continue
        onset_time, offset_time, midi_pitch = row
        onset_fr = int(np.floor(float(onset_time) / step_sec))
        offset_fr = int(np.ceil(float(offset_time) / step_sec)) + 1
        midi_roll[onset_fr:offset_fr, int(midi_pitch)] = 1

    return midi_roll

def prob_to_midi_roll(x, thres):
    """Threshold note probabilities and place them into a 128-pitch MIDI roll.

    Args:
      x: ndarray (n_time, n_pitch), probabilities for consecutive MIDI
        pitches starting at cfg.pitch_bgn (n_pitch is 88 for a piano roll).
      thres: float, values >= thres become 1, all others 0.

    Returns:
      out: float ndarray (n_time, 128), binary MIDI roll.
    """
    pitch_bgn = cfg.pitch_bgn
    n_time, n_pitch = x.shape
    out = np.zeros((n_time, 128))
    # Width taken from the input instead of a hard-coded 88.
    out[:, pitch_bgn:pitch_bgn + n_pitch] = (x >= thres)
    return out

def write_midi_roll_to_midi(x, out_path):
    """Write a binary MIDI roll to a single-track MIDI file.

    Every maximal run of 1s in column `p` of `x` becomes one note with MIDI
    pitch `p`. Frame indices are converted to beats using cfg.step_sec
    seconds per frame at a fixed tempo of 120 BPM.

    Args:
      x: (n_time, n_pitch) midi roll whose values must be 0 or 1.
      out_path: string, path to write out the midi.

    Raises:
      Exception: if `x` contains a value other than 0 or 1.

    NOTE(review): `MIDIFile` is not imported in this notebook; it presumably
    comes from midiutil (`from midiutil.MidiFile import MIDIFile`) — confirm.
    """
    step_sec = cfg.step_sec

    def _get_bgn_fin_pairs(ary):
        """Return [(bgn_frame, fin_frame), ...] for each run of 1s in `ary`;
        fin_frame is exclusive."""
        pairs = []
        bgn_fr = None
        prev = 0
        # Starting from prev=0 lets a note begin at frame 0 (previously it got
        # bgn -1); the trailing 0 sentinel closes a note still sounding at the
        # last frame (previously dropped).
        for i2, cur in enumerate(list(ary) + [0]):
            if cur != 0 and cur != 1:
                raise Exception("Input must be binary matrix!")
            if prev == 0 and cur == 1:
                bgn_fr = i2
            elif prev == 1 and cur == 0:
                pairs.append((bgn_fr, i2))
            prev = cur
        return pairs

    # Collect (pitch, bgn_frame, fin_frame) triples, sorted by begin frame.
    triples = []
    n_pitch = x.shape[1]
    for midi_pitch in range(n_pitch):
        for (bgn_fr, fin_fr) in _get_bgn_fin_pairs(x[:, midi_pitch]):
            triples.append((midi_pitch, bgn_fr, fin_fr))
    triples.sort(key=lambda t: t[1])

    # Write out midi.
    track = 0
    start_time = 0
    tempo = 120
    sec_per_beat = 60. / float(tempo)
    my_midi = MIDIFile(1)    # one track
    my_midi.addTrackName(track, start_time, "Sample Track")
    my_midi.addTempo(track, start_time, tempo)

    for (midi_pitch, bgn_fr, fin_fr) in triples:
        bgn_beat = bgn_fr * step_sec / sec_per_beat
        fin_beat = fin_fr * step_sec / sec_per_beat
        my_midi.addNote(track=track,          # track the note is added to
                        channel=0,            # MIDI channel [0-15]
                        pitch=midi_pitch,     # MIDI pitch [0-127]
                        time=bgn_beat,        # onset in beats
                        duration=fin_beat - bgn_beat,  # duration in beats
                        volume=100)           # velocity [0-127]
    with open(out_path, 'wb') as out_file:
        my_midi.writeFile(out_file)

In [None]:
def tp_fn_fp_tn(p_y_pred, y_gt, thres, average):
    """Count true/false positives/negatives of thresholded predictions.

    A prediction is positive when p_y_pred > thres (strictly greater).
    y_gt is expected to be binary (0/1).

    Args:
      p_y_pred: shape = (n_samples,) or (n_samples, n_classes)
      y_gt: shape = (n_samples,) or (n_samples, n_classes)
      thres: float between 0 and 1.
      average: only used for 2d input. None -> per-class lists;
        'micro' or 'macro' -> counts summed over all classes (for raw counts
        both give the same result).

    Returns:
      tp, fn, fp, tn (scalars), or four per-class lists when average is None
      and the input is 2d.

    Raises:
      Exception: on an unknown `average` (2d input) or input that is neither
        1d nor 2d.
    """
    if p_y_pred.ndim == 1:
        y_pred = np.zeros_like(p_y_pred)
        y_pred[np.where(p_y_pred > thres)] = 1.
        tp = np.sum(y_pred + y_gt > 1.5)
        fn = np.sum(y_gt - y_pred > 0.5)
        fp = np.sum(y_pred - y_gt > 0.5)
        tn = np.sum(y_pred + y_gt < 0.5)
        return tp, fn, fp, tn
    elif p_y_pred.ndim == 2:
        tps, fns, fps, tns = [], [], [], []
        n_classes = p_y_pred.shape[1]
        # `range` instead of Python-2-only `xrange`.
        for j1 in range(n_classes):
            (tp, fn, fp, tn) = tp_fn_fp_tn(p_y_pred[:, j1], y_gt[:, j1], thres, None)
            tps.append(tp)
            fns.append(fn)
            fps.append(fp)
            tns.append(tn)
        if average is None:
            return tps, fns, fps, tns
        elif average == 'micro' or average == 'macro':
            return np.sum(tps), np.sum(fns), np.sum(fps), np.sum(tns)
        else:
            raise Exception("Incorrect average arg!")
    else:
        raise Exception("Incorrect dimension!")

In [None]:
def prec_recall_fvalue(p_y_pred, y_gt, thres, average):
    """Precision, recall and F-value of thresholded predictions.

    Denominators are clamped to at least 1e-10, so empty cases yield 0
    instead of a ZeroDivisionError.

    Args:
      p_y_pred: shape = (n_samples,) or (n_samples, n_classes)
      y_gt: shape = (n_samples,) or (n_samples, n_classes)
      thres: float between 0 and 1.
      average: only used for 2d input. None -> per-class lists |
        'micro' -> metrics on all elements pooled together |
        'macro' -> per-class metrics, then averaged.

    Returns:
      prec, recall, fvalue (scalars), or three per-class lists when average
      is None and the input is 2d.

    Raises:
      Exception: on an unknown `average` (2d input) or input that is neither
        1d nor 2d.
    """
    eps = 1e-10
    if p_y_pred.ndim == 1:
        (tp, fn, fp, tn) = tp_fn_fp_tn(p_y_pred, y_gt, thres, average=None)
        prec = tp / max(float(tp + fp), eps)
        recall = tp / max(float(tp + fn), eps)
        fvalue = 2 * (prec * recall) / max(float(prec + recall), eps)
        return prec, recall, fvalue
    elif p_y_pred.ndim == 2:
        n_classes = p_y_pred.shape[1]
        if average is None or average == 'macro':
            precs, recalls, fvalues = [], [], []
            # `range` instead of Python-2-only `xrange`.
            for j1 in range(n_classes):
                (prec, recall, fvalue) = prec_recall_fvalue(p_y_pred[:, j1], y_gt[:, j1], thres, average=None)
                precs.append(prec)
                recalls.append(recall)
                fvalues.append(fvalue)
            if average is None:
                return precs, recalls, fvalues
            return np.mean(precs), np.mean(recalls), np.mean(fvalues)
        elif average == 'micro':
            return prec_recall_fvalue(p_y_pred.flatten(), y_gt.flatten(), thres, average=None)
        else:
            raise Exception("Incorrect average arg!")
    else:
        raise Exception("Incorrect dimension!")

In [None]:
import numpy as np

class DataGenerator(object):
    """Mini-batch generator over a single (x, y) pair of arrays.

    Args:
      batch_size: int, number of samples per yielded batch (the last batch
        of a pass may be smaller).
      type: 'train' -> reshuffle and loop forever; 'test' -> stop after one
        pass, or earlier after `te_max_iter` batches when that is given.
      te_max_iter: int or None, batch limit used only in 'test' mode.
    """

    def __init__(self, batch_size, type, te_max_iter=None):
        assert type in ['train', 'test']
        self._batch_size_ = batch_size
        self._type_ = type
        self._te_max_iter_ = te_max_iter

    def generate(self, xs, ys):
        """Yield (x_batch, y_batch) pairs drawn from xs[0] / ys[0].

        Sample order is shuffled with the global NumPy RNG once up front and,
        in train mode, again at every epoch boundary.
        """
        data_x, data_y = xs[0], ys[0]
        step = self._batch_size_
        total = len(data_x)

        order = np.arange(total)
        np.random.shuffle(order)

        n_steps = 0
        n_epochs = 0
        cursor = 0
        while True:
            is_test = self._type_ == 'test'
            limit = self._te_max_iter_
            if is_test and limit is not None and n_steps == limit:
                break
            n_steps += 1

            if cursor >= total:
                n_epochs += 1
                if is_test and n_epochs == 1:
                    break
                cursor = 0
                np.random.shuffle(order)

            picked = order[cursor:min(cursor + step, total)]
            cursor += step
            yield data_x[picked], data_y[picked]

## Функции для обучения и валидации

In [None]:
def uniform_weights(m, scale=0.1):
    """Initialise a Linear-type module: weights ~ U(-scale, scale), bias 0.

    Intended for `model.apply(uniform_weights)`. Modules whose class name
    does not contain 'Linear' are left untouched.
    """
    if m.__class__.__name__.find('Linear') == -1:
        return
    # In-place `uniform_` replaces the deprecated `nn.init.uniform`.
    torch.nn.init.uniform_(m.weight, -scale, scale)
    # Linear layers built with bias=False have m.bias == None.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.)

def glorot_uniform_weights(m):
    """Glorot/Xavier-uniform initialisation for Linear-type modules, bias 0.

    Intended for `model.apply(glorot_uniform_weights)`. Modules whose class
    name does not contain 'Linear' are left untouched.
    """
    if m.__class__.__name__.find('Linear') == -1:
        return
    # In-place `xavier_uniform_` replaces the deprecated `xavier_uniform`.
    torch.nn.init.xavier_uniform_(m.weight)
    # Linear layers built with bias=False have m.bias == None.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.)

In [None]:
def eval(model, gen, xs, ys, cuda):
    """Evaluate `model` on the batches from `gen` and print micro metrics.

    Puts the model in eval mode (dropout off) and runs inference without
    building an autograd graph. Predictions are thresholded at 0.5, and
    micro-averaged precision / recall / F-value over all yielded samples are
    printed.

    Args:
      model: torch module mapping a float batch to per-note probabilities.
      gen: object whose `generate(xs=..., ys=...)` yields (batch_x, batch_y).
      xs, ys: lists of arrays passed through to `gen.generate`.
      cuda: bool, move inputs to the GPU.

    Returns:
      (prec, recall, fvalue), the same values that are printed.
    """
    model.eval()
    pred_all = []
    y_all = []
    # `Variable(..., volatile=True)` is ignored since PyTorch 0.4; no_grad()
    # keeps evaluation from storing activations for backprop.
    with torch.no_grad():
        for (batch_x, batch_y) in gen.generate(xs=xs, ys=ys):
            batch_x = torch.Tensor(batch_x)
            if cuda:
                batch_x = batch_x.cuda()
            pred = model(batch_x)
            pred_all.append(pred.cpu().numpy())
            y_all.append(batch_y)

    pred_all = np.concatenate(pred_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)
    (prec, recall, fvalue) = prec_recall_fvalue(pred_all, y_all, thres=0.5, average='micro')

    print("prec: %f, recall: %f, fvalue: %f" % (prec, recall, fvalue))
    return prec, recall, fvalue

In [None]:
class Net(nn.Module):
    """Fully connected frame classifier: three hidden ReLU layers with
    dropout, followed by a sigmoid output layer.

    Each sample (e.g. a (n_concat, n_freq) context window) is flattened to
    n_concat * n_freq inputs; the output is (batch, n_out) probabilities.

    Args:
      n_concat: int, frames per input window.
      n_freq: int, feature bins per frame.
      n_out: int, number of outputs.
      n_hid: int, hidden layer width (default 500).
      drop_p: float, dropout probability after each hidden layer (default 0.2).
    """
    def __init__(self, n_concat, n_freq, n_out, n_hid=500, drop_p=0.2):
        super(Net, self).__init__()
        n_in = n_concat * n_freq
        self.drop_p = drop_p

        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, n_hid)
        self.fc3 = nn.Linear(n_hid, n_hid)
        self.fc4 = nn.Linear(n_hid, n_out)

    def forward(self, x):
        drop_p = self.drop_p
        # reshape (not view) also accepts non-contiguous inputs.
        x1 = x.reshape(x.size(0), -1)
        x2 = F.dropout(F.relu(self.fc1(x1)), p=drop_p, training=self.training)
        x3 = F.dropout(F.relu(self.fc2(x2)), p=drop_p, training=self.training)
        x4 = F.dropout(F.relu(self.fc3(x3)), p=drop_p, training=self.training)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc4(x4))

In [None]:
def train(args):
    """Train `Net` on packed features and periodically evaluate / checkpoint.

    Data preparation:
      - Reads the packed train/test features and the fitted scaler from
        `args.workspace` and scales both splits.
      - Builds 3-frame context windows.

    Training uses Adam with learning rate `args.lr` on binary cross-entropy.
    Every 1000 iterations it evaluates on a training subset (20 batches) and
    on the whole test set. It also saves a checkpoint dict (`iter`,
    `state_dict`, `optimizer`) to
    `<workspace>/models/<script_na>/<feat_type>/md_<iter>iters.tar`.
    Training stops once the iteration counter reaches 10001.

    If `args.resume_model_path` names an existing file, the weights,
    iteration counter and (when present) optimizer state are restored from
    it, keeping the learning rate given by `args.lr`.
    """
    cuda = args.use_cuda and torch.cuda.is_available()
    workspace = args.workspace
    feat_type = args.feat_type
    lr = args.lr
    resume_model_path = args.resume_model_path
    script_na = args.script_na
    print("cuda:", cuda)

    # Hyper-parameters / schedule.
    n_concat = 3
    n_hop = 1
    batch_size = 500
    eval_interval = 1000
    save_interval = 1000
    print_interval = 200
    max_iter = 10001
    eps = 1e-8

    # Load data.
    t1 = time.time()
    packed_dir = os.path.join(workspace, "packed_features", feat_type)
    with open(os.path.join(packed_dir, "train.p"), 'rb') as f:
        [tr_x_list, tr_y_list, _] = cPickle.load(f)
    with open(os.path.join(packed_dir, "test.p"), 'rb') as f:
        [te_x_list, te_y_list, _] = cPickle.load(f)
    print("Loading packed feature time: %s s" % (time.time() - t1,))

    # Scale.
    scale_path = os.path.join(workspace, "scalers", feat_type, "scaler.p")
    with open(scale_path, 'rb') as f:
        scaler = pickle.load(f)
    tr_x_list = scale_on_x_list(tr_x_list, scaler)
    te_x_list = scale_on_x_list(te_x_list, scaler)

    # Data to 3d.
    (tr_x, tr_y) = data_to_3d(tr_x_list, tr_y_list, n_concat, n_hop)
    (te_x, te_y) = data_to_3d(te_x_list, te_y_list, n_concat, n_hop)
    n_freq = tr_x.shape[-1]
    n_out = tr_y.shape[-1]
    print(tr_x.shape, tr_y.shape)

    # Model.
    model = Net(n_concat, n_freq, n_out)
    checkpoint = None
    if resume_model_path and os.path.isfile(resume_model_path):
        print("Loading checkpoint '%s'" % resume_model_path)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(resume_model_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        iteration = checkpoint['iter']
    else:
        print("Train from random initialization. ")
        model.apply(glorot_uniform_weights)
        iteration = 0

    if cuda:
        model.cuda()

    # Optimizer: use the requested learning rate (previously hard-coded 1e-3).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    if checkpoint is not None and 'optimizer' in checkpoint:
        # Restore Adam moments so resuming continues the same optimisation,
        # but keep the learning rate requested for this run.
        optimizer.load_state_dict(checkpoint['optimizer'])
        for group in optimizer.param_groups:
            group['lr'] = lr

    # Data generators.
    tr_gen = DataGenerator(batch_size=batch_size, type='train')
    eval_tr_gen = DataGenerator(batch_size=batch_size, type='test', te_max_iter=20)
    eval_te_gen = DataGenerator(batch_size=batch_size, type='test')

    iters_per_epoch = len(tr_x) // batch_size
    print("Iters_per_epoch: %d" % iters_per_epoch)

    # Train.
    for (batch_x, batch_y) in tr_gen.generate(xs=[tr_x], ys=[tr_y]):
        if iteration % eval_interval == 0:
            print("\n--- Evaluation of training set (subset), iteration: %d ---" % iteration)
            eval(model, eval_tr_gen, [tr_x], [tr_y], cuda)
            print("--- Evaluation of testing set, iteration: %d ---" % iteration)
            eval(model, eval_te_gen, [te_x], [te_y], cuda)
            print("-----------------------------------------------\n")

        batch_x = torch.Tensor(batch_x)
        batch_y = torch.Tensor(batch_y)
        if cuda:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
        optimizer.zero_grad()
        model.train()
        output = model(batch_x)
        # Keep probabilities away from exactly 0/1 so BCE's log stays finite.
        output = torch.clamp(output, eps, 1. - eps)

        loss = F.binary_cross_entropy(output, batch_y)
        loss.backward()
        optimizer.step()

        if iteration % print_interval == 0:
            print("Iter: %d loss: %f" % (iteration, loss.item()))

        iteration += 1

        # Save model.
        if iteration % save_interval == 0:
            save_out_dict = {'iter': iteration,
                             'state_dict': model.state_dict(),
                             'optimizer': optimizer.state_dict(), }
            save_out_path = os.path.join(workspace, "models", script_na, feat_type, "md_%diters.tar" % iteration)
            create_folder(os.path.dirname(save_out_path))
            torch.save(save_out_dict, save_out_path)
            print("Save model to %s" % save_out_path)

        # `>=` so that resuming from a checkpoint past the limit still stops.
        if iteration >= max_iter:
            break

In [None]:
def inference(args):
    """Transcribe every test song with a trained model and write MIDI files.

    Steps:
      - Loads the packed test features and the scaler.
      - Restores `Net` weights from
        `<workspace>/models/<script_na>/<feat_type>/<model_name>`.
      - Predicts per-frame note probabilities. Each song is zero-padded at
        both ends, so there is one prediction per input frame.
      - Thresholds the predictions at 0.5.
      - Writes `<workspace>/out_midis/<script_na>/<feat_type>/<song>.mid`.
      - Shows a ground-truth / probability / thresholded plot per song.

    Raises:
      Exception: if the model checkpoint does not exist.
    """
    cuda = args.use_cuda and torch.cuda.is_available()
    workspace = args.workspace
    model_name = args.model_name
    feat_type = args.feat_type
    script_na = args.script_na

    n_concat = 3
    te_n_hop = 1
    thres = 0.5

    # Load data.
    te_packed_feat_path = os.path.join(workspace, "packed_features", feat_type, "test.p")
    with open(te_packed_feat_path, 'rb') as f:
        [te_x_list, te_y_list, te_na_list] = cPickle.load(f)

    # Scale.
    scale_path = os.path.join(workspace, "scalers", feat_type, "scaler.p")
    with open(scale_path, 'rb') as f:
        scaler = pickle.load(f)
    te_x_list = scale_on_x_list(te_x_list, scaler)

    # Construct model topology.
    n_freq = te_x_list[0].shape[-1]
    n_out = te_y_list[0].shape[-1]
    model = Net(n_concat, n_freq, n_out)

    # Init the weights of model using trained weights.
    model_path = os.path.join(workspace, "models", script_na, feat_type, model_name)
    if not os.path.isfile(model_path):
        raise Exception("Model path %s does not exist!" % model_path)
    print("Loading checkpoint '%s'" % model_path)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])

    if cuda:
        model.cuda()
    model.eval()

    # `__file__` is undefined inside a Jupyter kernel; use script_na, the same
    # name that keys the model directory above.
    out_midi_dir = os.path.join(workspace, "out_midis", script_na, feat_type)
    create_folder(out_midi_dir)

    # Integer division: a float would break np.zeros below under Python 3.
    n_half = (n_concat - 1) // 2
    for i1 in range(len(te_x_list)):
        x = te_x_list[i1]   # (n_time, n_freq)
        y = te_y_list[i1]   # (n_time, n_out)
        bare_na = os.path.splitext(te_na_list[i1])[0]
        n_freq = x.shape[1]

        # Pad n_half frames at each end so (for odd n_concat) every original
        # frame is the centre of exactly one window.
        zero_pad = np.zeros((n_half, n_freq), dtype=x.dtype)
        x = np.concatenate((zero_pad, x, zero_pad), axis=0)
        x3d = mat_2d_to_3d(x, n_concat, te_n_hop)     # (n_time, n_concat, n_freq)

        x3d = torch.Tensor(x3d)
        if cuda:
            x3d = x3d.cuda()

        # Inference without building an autograd graph.
        with torch.no_grad():
            pred = model(x3d)   # (n_time, n_out)
        pred = pred.cpu().numpy()

        # Threshold and write out predicted piano roll to midi file.
        mid_roll = prob_to_midi_roll(pred, thres)
        out_path = os.path.join(out_midi_dir, "%s.mid" % bare_na)
        print("Write out to: %s" % out_path)
        write_midi_roll_to_midi(mid_roll, out_path)

        # Debug plot.
        fig, axs = plt.subplots(3, 1, sharex=True)
        axs[0].matshow(y.T, origin='lower', aspect='auto')
        axs[1].matshow(pred.T, origin='lower', aspect='auto')
        # Same `>=` rule as prob_to_midi_roll, so the plot matches the MIDI.
        binary_pred = (pred >= thres).astype(np.float32)
        axs[2].matshow(binary_pred.T, origin='lower', aspect='auto')
        axs[0].set_title("Ground truth")
        axs[1].set_title("DNN output probability")
        axs[2].set_title("DNN output probability after thresholding")
        for ax in axs:
            ax.set_ylabel('note index')
            ax.set_xlabel('frames')
            ax.xaxis.set_label_coords(1.06, -0.01)
            ax.xaxis.tick_bottom()
        plt.tight_layout()
        plt.show()
        # Free the figure; one is created per test song.
        plt.close(fig)
