# DeepLoc MXNet Port

In [None]:
import mxnet as mx
import numpy as np
import time
import mxnet.ndarray as nd
import mxnet.initializer as init
from mxnet import npx, autograd, optimizer, gluon
from mxnet.gluon import nn, rnn
from mxboard import SummaryWriter


## Parameters

In [None]:
# Model / training hyper-parameters. They are module-level globals that the
# network classes and the training loop read directly.
epoch = 200                #-- integer, number of training epochs (run once per validation partition)
batch_size = 32            #-- integer, minibatch size
max_seq_size = 1000        #-- integer, maximum sequence size
n_hid = 256                #-- integer, hidden units per LSTM direction (also the attention layer size)
n_feat = 20                #-- integer, number of features encoded per position, i.e. X_test.shape[2]
n_class = 10               #-- integer, number of output classes
lr = 0.0005                #-- float, Adam learning rate
drop_per  = 0.2            #-- float, dropout rate applied to the input features
drop_hid = 0.5             #-- float, dropout rate applied to hidden representations
n_filt_1 = 20              #-- integer, filters per branch in the first convolutional stage
n_filt_2 = 128             #-- integer, filters in the second (merging) convolutional layer
seed     = 123456          #-- integer, seed for the MXNet and NumPy random generators
loss_fn  = 'cross_entropy' #-- 'cross_entropy' or 'cosine' loss function
test_data_set = 'large'    #-- 'large' or 'small'
train_data_set = 'large'   #-- 'large' or 'small'


## Initialization

In [None]:
# Run on the default GPU when at least one is visible, otherwise on the CPU.
ctx = mx.gpu() if mx.context.num_gpus() else mx.cpu()
# Seed the MXNet, MXNet numpy-extension and NumPy generators so runs are repeatable
# (GPU kernels may still introduce some nondeterminism).
mx.random.seed(seed)
npx.random.seed(seed)
np.random.seed(seed)


def generate_run_id():
    """Return a UTC timestamp tag such as ``20200802-0216`` (YYYYMMDD-HHMM).

    The tag names the hyper-parameter, model and log files of a training run.
    """
    return time.strftime("%Y%m%d-%H%M", time.gmtime())


## Training Set and Test Set

In [None]:
# The 'large' setting reads both splits from the single full DeepLoc archive
# (X_train/... and X_test/... are both extracted from it below); 'small' uses
# separate train and test archives.
train_file = ('data/deeploc_full.npz' if train_data_set == 'large'
              else 'subcellular_localization/data/train.npz')
test_file = ('data/deeploc_full.npz' if test_data_set == 'large'
             else 'subcellular_localization/data/test.npz')

train_npz = np.load(train_file)
test_npz = np.load(test_file)


In [None]:
# Pull the arrays out of the .npz archives as plain NumPy arrays; they are
# converted to MXNet NDArrays later, per partition, right before use.
mask_train, partition, X_train, y_train = (
    train_npz[key] for key in ('mask_train', 'partition', 'X_train', 'y_train'))
X_test, mask_test, y_test = (
    test_npz[key] for key in ('X_test', 'mask_test', 'y_test'))

In [None]:
# Release the .npz file handles; every array needed later was already
# extracted into a NumPy array in the previous cell.
train_npz.close()
test_npz.close()

## Network

### Convolutional Layers

In [None]:

class ConvLayer(nn.Block):
    """Multi-scale 1-D convolutional feature extractor.

    Six parallel ReLU convolutions with kernel widths 1, 3, 5, 9, 15 and 21
    (each padded by ``width // 2`` so the sequence length is preserved) look
    at the input at different scales. Their ``n_filt_1``-channel outputs are
    concatenated along the channel axis and mixed by a final width-3 ReLU
    convolution with ``n_filt_2`` channels. Input and output use NCW layout.
    """

    # Kernel widths of the parallel branches, in registration order.
    _BRANCH_WIDTHS = (1, 3, 5, 9, 15, 21)

    def __init__(self, **kwargs):
        super(ConvLayer, self).__init__(**kwargs)

        with self.name_scope():
            for width in self._BRANCH_WIDTHS:
                # Attribute names (l_conv_01, ...) and prefixes ('01_', ...)
                # are kept exactly as before because saved parameters are keyed on them.
                branch = nn.Conv1D(prefix='%02d_' % width, channels=n_filt_1,
                                   kernel_size=width, padding=width // 2,
                                   layout='NCW', activation='relu')
                setattr(self, 'l_conv_%02d' % width, branch)

            self.l_conv_final = nn.Conv1D(prefix='conc_', channels=n_filt_2,
                                          kernel_size=3, padding=1,
                                          layout='NCW', activation='relu')

    def forward(self, x):
        branch_outputs = [getattr(self, 'l_conv_%02d' % width)(x)
                          for width in self._BRANCH_WIDTHS]
        stacked = nd.concat(*branch_outputs, dim=1)  # channel axis in NCW
        return self.l_conv_final(stacked)


### Bidirectional LSTM Layer

In [None]:

class BidirectionalLSTM(nn.Block):
    """Bidirectional LSTM over a masked, variable-length batch.

    ``forward(x, mask)`` takes ``x`` of shape (N, T, C) and ``mask`` of shape
    (N, T). The number of valid steps of sample i is ``sum(mask[i])``. As in
    the original per-sample slicing, the valid steps are assumed to form a
    prefix of each row (TODO confirm against the data encoding).

    The forward LSTM reads each sequence left to right. The backward LSTM reads
    the valid part right to left, and its output is flipped back so that
    position t of both directions refers to input step t. The result has shape
    (N, T, 2 * n_hid).
    """

    def __init__(self, **kwargs):
        super(BidirectionalLSTM, self).__init__(**kwargs)

        with self.name_scope():
            self.l_fwd = rnn.LSTM(hidden_size=n_hid, layout='TNC', prefix='Fwd_')
            self.l_bck = rnn.LSTM(hidden_size=n_hid, layout='TNC', prefix='Bck_')

    @staticmethod
    def _reverse_tnc(x_tnc, lengths):
        """Reverse each sequence of a (T, N, C) array along time within its length.

        Positions beyond a sample's length are left where they are. The op is
        differentiable, unlike copying slices into a preallocated buffer.
        """
        return nd.SequenceReverse(x_tnc, sequence_length=lengths,
                                  use_sequence_length=True)

    def _reverse(self, x, mask):
        """(N, T, C) variant: time-reverse the valid prefix of each sample, zero the padding."""
        lengths = nd.sum(mask, axis=1)
        tnc = (1, 0, 2)
        reversed_ntc = nd.transpose(self._reverse_tnc(nd.transpose(x, tnc), lengths), tnc)
        return reversed_ntc * nd.expand_dims(mask, axis=2)

    def forward(self, x, mask):
        tnc = (1, 0, 2)
        lengths = nd.sum(mask, axis=1)          # (N,) valid steps per sample
        x_tnc = nd.transpose(x, tnc)            # LSTMs use TNC layout

        fwd = self.l_fwd(x_tnc)
        # Reverse in time (the old code reversed the channel axis), run the
        # backward LSTM, then undo the reversal so outputs line up with fwd.
        # Padding positions are processed after the valid steps in both
        # directions, so they cannot influence the valid outputs.
        bck = self.l_bck(self._reverse_tnc(x_tnc, lengths))
        bck = self._reverse_tnc(bck, lengths)

        conc = nd.concat(fwd, bck, dim=2)       # (T, N, 2 * n_hid)
        return nd.transpose(conc, tnc)


### Decoder with Attention Mechanism

In [None]:

class LSTMAttentionDecodeFeedback(nn.Block):
    """Peephole-LSTM decoder with additive attention over an encoded sequence.

    At every one of ``n_decodesteps`` steps the previous decoder hidden state
    scores each encoder position with ``v_align^T tanh(hid W_align + input U_align)``.
    Masked positions are suppressed, a softmax turns the scores into weights,
    and the weighted sum of ``input`` (the context vector) drives a peephole
    LSTM cell. ``forward`` returns the context vector of the last step.

    Args:
        num_units: decoder LSTM hidden size.
        aln_num_units: hidden size of the alignment (attention) layer.
        n_decodesteps: number of decoding iterations.
        num_inputs: size of the last axis of the encoder output given to
            ``forward``. Defaults to 512, the value that used to be hard-coded.
    """

    def __init__(self,
                 num_units,
                 aln_num_units,
                 n_decodesteps=10,
                 num_inputs=512,
                 **kwargs):

        super(LSTMAttentionDecodeFeedback, self).__init__(**kwargs)

        self.num_units = num_units
        self.aln_num_units = aln_num_units
        self.n_decodesteps = n_decodesteps
        self.attention_softmax_function = nd.softmax
        self.peepholes = True

        # Feature size of the encoder output (last axis of `input`).
        self.num_inputs = num_inputs

        self.nonlinearity_align = nd.tanh

        self.nonlinearity_ingate = nd.sigmoid
        self.nonlinearity_forgetgate = nd.sigmoid
        self.nonlinearity_cell = nd.tanh
        self.nonlinearity_outgate = nd.sigmoid

        self.nonlinearity_out = nd.tanh

        # Parameter attribute names and creation order are unchanged so that
        # previously saved parameter files still map onto them.

        # Decoder hidden state -> gates.
        self.W_hid_to_ingate = self.params.get('W_hid_to_ingate', shape=(num_units, num_units),
                                               init=init.Normal(0.1),
                                               allow_deferred_init=True)
        self.W_hid_to_forgetgate = self.params.get('W_hid_to_forgetgate', shape=(num_units, num_units),
                                                   init=init.Normal(0.1),
                                                   allow_deferred_init=True)
        self.W_hid_to_cell = self.params.get('W_hid_to_cell', shape=(num_units, num_units),
                                             init=init.Normal(0.1),
                                             allow_deferred_init=True)
        self.W_hid_to_outgate = self.params.get('W_hid_to_outgate', shape=(num_units, num_units),
                                                init=init.Normal(0.1),
                                                allow_deferred_init=True)

        # Gate biases.
        self.b_ingate = self.params.get('b_ingate', shape=(num_units,),
                                        init=init.Constant(0),
                                        allow_deferred_init=True)
        self.b_forgetgate = self.params.get('b_forgetgate', shape=(num_units,),
                                            init=init.Constant(0),
                                            allow_deferred_init=True)
        self.b_cell = self.params.get('b_cell', shape=(num_units,),
                                      init=init.Constant(0),
                                      allow_deferred_init=True)
        self.b_outgate = self.params.get('b_outgate', shape=(num_units,),
                                         init=init.Constant(0),
                                         allow_deferred_init=True)

        # Context vector (attention-weighted encoder output) -> gates.
        self.W_weightedhid_to_ingate = self.params.get('W_weightedhid_to_ingate',
                                                       shape=(self.num_inputs, num_units),
                                                       init=init.Normal(0.1),
                                                       allow_deferred_init=True)
        self.W_weightedhid_to_forgetgate = self.params.get('W_weightedhid_to_forgetgate',
                                                           shape=(self.num_inputs, num_units),
                                                           init=init.Normal(0.1),
                                                           allow_deferred_init=True)
        self.W_weightedhid_to_cell = self.params.get('W_weightedhid_to_cell',
                                                     shape=(self.num_inputs, num_units),
                                                     init=init.Normal(0.1),
                                                     allow_deferred_init=True)
        self.W_weightedhid_to_outgate = self.params.get('W_weightedhid_to_outgate',
                                                        shape=(self.num_inputs, num_units),
                                                        init=init.Normal(0.1),
                                                        allow_deferred_init=True)

        # Peephole weights (cell -> gate), applied element-wise.
        self.W_cell_to_ingate = self.params.get('W_cell_to_ingate',
                                                shape=(num_units,),
                                                init=init.Normal(0.1),
                                                allow_deferred_init=True)
        self.W_cell_to_forgetgate = self.params.get('W_cell_to_forgetgate',
                                                    shape=(num_units,),
                                                    init=init.Normal(0.1),
                                                    allow_deferred_init=True)
        self.W_cell_to_outgate = self.params.get('W_cell_to_outgate',
                                                 shape=(num_units,),
                                                 init=init.Normal(0.1),
                                                 allow_deferred_init=True)

        # Attention: W_align (num_units, aln), U_align (num_inputs, aln), v_align (aln, 1).
        self.W_align = self.params.get('W_align',
                                       shape=(num_units, self.aln_num_units),
                                       init=init.Normal(0.1))
        self.U_align = self.params.get('U_align', shape=(self.num_inputs, self.aln_num_units),
                                       init=init.Normal(0.1),
                                       allow_deferred_init=True)
        self.v_align = self.params.get('v_align', shape=(self.aln_num_units, 1),
                                       init=init.Normal(0.1))

    def slice_w(self, x, n):
        """Return the n-th ``num_units``-wide column block of the stacked gate matrix."""
        return x[:, n*self.num_units:(n+1)*self.num_units]

    def step(self, cell_previous, hid_previous, alpha_prev, weighted_hidden_prev,
             input, mask, hUa, W_align, v_align,
             W_hid_stacked, W_weightedhid_stacked, W_cell_to_ingate,
             W_cell_to_forgetgate, W_cell_to_outgate,
             b_stacked, *args):
        """Run one decoding step.

        Shapes: cell_previous/hid_previous (BS, num_units); input
        (BS, Seqlen, num_inputs); hUa (BS, Seqlen, aln_num_units); mask
        (BS, Seqlen). ``alpha_prev`` and ``weighted_hidden_prev`` are accepted
        for interface compatibility but not used.

        Returns:
            [cell, hid, alpha, weighted_hidden] with shapes (BS, num_units),
            (BS, num_units), (BS, Seqlen) and (BS, num_inputs).
        """
        # Alignment scores for every encoder position.
        sWa = nd.expand_dims(nd.dot(hid_previous, W_align), axis=1)  # (BS, 1, aln)
        tanh_sWahUa = self.nonlinearity_align(sWa + hUa)              # (BS, Seqlen, aln)
        a = nd.dot(tanh_sWahUa, v_align)                              # (BS, Seqlen, 1)
        a = nd.reshape(a, (a.shape[0], a.shape[1]))                   # (BS, Seqlen)

        # Positions with mask == 0 get a score of -10000 so softmax gives them ~0 weight.
        a = a*mask - (1-mask)*10000
        alpha = self.attention_softmax_function(a)

        # Context vector: attention-weighted sum over the sequence axis.
        weighted_hidden = nd.sum(input * nd.expand_dims(alpha, axis=2), axis=1)

        # All four gate pre-activations at once, then split into blocks.
        gates = nd.dot(hid_previous, W_hid_stacked) + b_stacked
        gates = gates + nd.dot(weighted_hidden, W_weightedhid_stacked)

        ingate = self.slice_w(gates, 0)
        forgetgate = self.slice_w(gates, 1)
        cell_input = self.slice_w(gates, 2)
        outgate = self.slice_w(gates, 3)

        if self.peepholes:
            ingate = ingate + cell_previous*W_cell_to_ingate
            forgetgate = forgetgate + cell_previous*W_cell_to_forgetgate

        ingate = self.nonlinearity_ingate(ingate)
        forgetgate = self.nonlinearity_forgetgate(forgetgate)
        cell_input = self.nonlinearity_cell(cell_input)

        cell = forgetgate*cell_previous + ingate*cell_input

        # The output-gate peephole uses the NEW cell and must be added before
        # the sigmoid. Adding it afterwards would let the gate leave (0, 1).
        if self.peepholes:
            outgate = outgate + cell*W_cell_to_outgate
        outgate = self.nonlinearity_outgate(outgate)

        hid = outgate*self.nonlinearity_out(cell)

        return [cell, hid, alpha, weighted_hidden]

    def forward(self, input, mask):
        """Decode ``input`` (BS, Seqlen, num_inputs) under ``mask`` (BS, Seqlen).

        Returns:
            The context vector of the final decoding step, (BS, num_inputs).
        """
        num_batch = input.shape[0]
        encode_seqlen = input.shape[1]
        # Allocate everything on the input's device instead of a global context.
        dev = input.context

        W_hid_stacked = nd.concat(
            self.W_hid_to_ingate.data(dev),
            self.W_hid_to_forgetgate.data(dev),
            self.W_hid_to_cell.data(dev),
            self.W_hid_to_outgate.data(dev),
            dim=1)

        W_weightedhid_stacked = nd.concat(
            self.W_weightedhid_to_ingate.data(dev),
            self.W_weightedhid_to_forgetgate.data(dev),
            self.W_weightedhid_to_cell.data(dev),
            self.W_weightedhid_to_outgate.data(dev),
            dim=1)

        b_stacked = nd.concat(
            self.b_ingate.data(dev),
            self.b_forgetgate.data(dev),
            self.b_cell.data(dev),
            self.b_outgate.data(dev),
            dim=0)

        cell = nd.zeros((num_batch, self.num_units), ctx=dev)
        hid = nd.zeros((num_batch, self.num_units), ctx=dev)
        alpha = nd.zeros((num_batch, encode_seqlen), ctx=dev)
        # The context vector lives in the encoder feature space (num_inputs wide).
        weighted_hidden = nd.zeros((num_batch, self.num_inputs), ctx=dev)

        # The encoder-side attention projection does not change across steps.
        hUa = nd.dot(input, self.U_align.data(dev))
        W_align = self.W_align.data(dev)
        v_align = self.v_align.data(dev)

        W_cell_to_ingate = self.W_cell_to_ingate.data(dev)
        W_cell_to_forgetgate = self.W_cell_to_forgetgate.data(dev)
        W_cell_to_outgate = self.W_cell_to_outgate.data(dev)

        for _ in range(self.n_decodesteps):
            cell, hid, alpha, weighted_hidden = self.step(
                cell, hid, alpha, weighted_hidden,
                input, mask, hUa, W_align, v_align,
                W_hid_stacked, W_weightedhid_stacked, W_cell_to_ingate,
                W_cell_to_forgetgate, W_cell_to_outgate,
                b_stacked)

        return weighted_hidden
    

## DeepLoc Model

In [None]:

class Model(nn.Block):
    """DeepLoc network: dropout -> multi-scale conv -> BiLSTM -> attention decoder -> linear classifier.

    ``forward(input, mask)`` returns unnormalised class scores (logits) of
    shape (N, n_class). ``input`` is presumably (N, T, n_feat) with features
    on the last axis (see the ``n_feat`` comment), and ``mask`` is (N, T).
    TODO confirm against the data files.
    """

    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)

        with self.name_scope():
            self.l_dropout_1 = nn.Dropout(rate=drop_per)
            self.l_dropout_2 = nn.Dropout(rate=drop_hid)
            self.l_dropout_3 = nn.Dropout(rate=drop_hid)
            # Kept so existing references to the attribute stay valid. It is
            # no longer applied: dropping out output logits only adds noise
            # to the loss.
            self.l_dropout_4 = nn.Dropout(rate=drop_hid)
            self.l_conv = ConvLayer(prefix='Conv_')
            self.l_lstm = BidirectionalLSTM(prefix='BLSTM_')
            # Linear output layer. The loss (softmax cross entropy) normalises
            # the logits itself; a ReLU here would clamp negative logits to 0
            # and stop their gradient. Parameter names are unchanged.
            self.l_dense = nn.Dense(units=n_class)
            self.l_decoder = LSTMAttentionDecodeFeedback(
                              prefix='Decoder_',
                              num_units=2*n_hid, aln_num_units=n_hid, n_decodesteps=10)

    def forward(self, input, mask):
        # Child blocks are invoked through __call__ (not .forward) so Gluon
        # hooks and HybridBlock dispatch run as intended.
        x = self.l_dropout_1(input)
        x = nd.transpose(x, (0, 2, 1))   # -> channels-first (NCW) for the convolutions
        x = self.l_conv(x)
        x = nd.transpose(x, (0, 2, 1))   # back to (N, T, C) for the LSTMs
        x = self.l_dropout_2(x)
        x = self.l_lstm(x, mask)
        x = self.l_decoder(x, mask)
        x = self.l_dropout_3(x)
        return self.l_dense(x)
        

## Confusion Matrix

In [None]:
import numpy as np


class ConfusionMatrix:
    """
       Simple confusion matrix class
       row is the true class, column is the predicted class

       Per-class metrics are derived from the one-vs-rest counts returned by
       ``get_errors`` (tp, fn, fp, tn). Entries whose denominator is zero are
       dropped from the returned arrays.
    """
    def __init__(self, num_classes, class_names=None):
        self.n_classes = num_classes
        if class_names is None:
            # A list, not map(): in Python 3 map() is a one-shot iterator and
            # cannot be indexed later in __str__.
            class_names = [str(i) for i in range(num_classes)]

        # Right-pad every name to the longest one. Work on a copy so the
        # caller's list is not mutated.
        self.max_len = max((len(name) for name in class_names), default=0)
        self.class_names = [name.ljust(self.max_len) for name in class_names]

        self.mat = np.zeros((num_classes, num_classes), dtype='int')

    def __str__(self):
        # Totals per true class (rows) and per predicted class (columns).
        row_totals = np.sum(self.mat, axis=1)
        col_totals = np.sum(self.mat, axis=0)

        rows = self.mat.__str__().replace('[', '').replace(']', '').split('\n')

        lines = []
        for idx, row in enumerate(rows):
            # numpy prints '[[' before the first row but ' [' before the
            # others, so the first row needs one extra space to line up.
            pad = " " if idx == 0 else ""
            lines.append(" " + self.class_names[idx] + " |" + pad + row
                         + " |" + str(row_totals[idx]))

        footer = [(self.max_len + 4) * " " + " ".join(map(str, col_totals))]
        hline = [(1 + self.max_len) * " " + "-" * len(footer[0])]

        return "".join(line + '\n' for line in hline + lines + hline + footer)

    def batch_add(self, targets, preds):
        """Add one batch of (true label, predicted label) pairs.

        :param targets: integer class indices (array-like, any shape)
        :param preds: integer class indices, same shape as ``targets``
        """
        targets = np.asarray(targets)
        preds = np.asarray(preds)
        assert targets.shape == preds.shape
        targets = targets.flatten().astype(np.int64)
        preds = preds.flatten().astype(np.int64)
        if targets.size == 0:
            return
        assert 0 <= targets.min() and targets.max() < self.n_classes
        assert 0 <= preds.min() and preds.max() < self.n_classes
        # Unbuffered add so repeated (target, pred) pairs are all counted.
        np.add.at(self.mat, (targets, preds), 1)

    def ret_mat(self):
        return self.mat

    def get_errors(self):
        """Return one-vs-rest counts per class as float arrays: (tp, fn, fp, tn)."""
        tp = np.asarray(np.diag(self.mat).flatten(),dtype='float')
        fn = np.asarray(np.sum(self.mat, axis=1).flatten(),dtype='float') - tp
        fp = np.asarray(np.sum(self.mat, axis=0).flatten(),dtype='float') - tp
        tn = np.asarray(np.sum(self.mat)*np.ones(self.n_classes).flatten(),
                        dtype='float') - tp - fn - fp
        return tp, fn, fp, tn

    @staticmethod
    def _ratio(numerator, denominator):
        """Element-wise ratio as a 1-D array, dropping undefined (0/0) entries."""
        with np.errstate(divide='ignore', invalid='ignore'):
            res = np.atleast_1d(np.asarray(numerator, dtype=float) / denominator)
        return res[~np.isnan(res)]

    def accuracy(self):
        """
        Calculates global accuracy
        :return: accuracy
        :example: >>> conf = ConfusionMatrix(3)
                  >>> conf.batch_add(np.array([0,0,1]), np.array([0,0,2]))
                  >>> print(conf.accuracy())
                  0.6666666666666666
        """
        tp, _, _, _ = self.get_errors()
        n_samples = np.sum(self.mat)
        return np.sum(tp) / n_samples

    # NOTE: every metric below unpacks get_errors() in its real order
    # (tp, fn, fp, tn); unpacking it as (tp, tn, fp, fn) swaps tn and fn.

    def sensitivity(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(tp, tp + fn)

    def specificity(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(tn, tn + fp)

    def positive_predictive_value(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(tp, tp + fp)

    def negative_predictive_value(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(tn, tn + fn)

    def false_positive_rate(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(fp, fp + tn)

    def false_discovery_rate(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(fp, tp + fp)

    def F1(self):
        tp, fn, fp, tn = self.get_errors()
        return self._ratio(2*tp, 2*tp + fp + fn)

    def matthews_correlation(self):
        """Per-class Matthews correlation coefficient."""
        tp, fn, fp, tn = self.get_errors()
        numerator = tp*tn - fp*fn
        denominator = np.sqrt((tp + fp)*(tp + fn)*(tn + fp)*(tn + fn))
        return self._ratio(numerator, denominator)

    def OMCC(self):
        """Overall MCC from one-vs-rest counts summed over classes; 1-element array (empty if undefined)."""
        tp, fn, fp, tn = (np.sum(v) for v in self.get_errors())
        numerator = tp*tn - fp*fn
        denominator = np.sqrt((tp + fp)*(tp + fn)*(tn + fp)*(tn + fn))
        return self._ratio(numerator, denominator)

### Training loop

In [None]:

import json
import os

run_id = generate_run_id()
print(run_id)

hparams_path = './models/{}_hparams.json'.format(run_id)
# Base name of this run. Each fold saves its best weights to
# './models/<run_id>_p<fold>.params'.
net_params_path = './models/{}.params'.format(run_id)
logdir = './logs/{}/{}'.format(train_data_set, run_id)

# open() does not create directories, so make sure ./models exists.
os.makedirs(os.path.dirname(hparams_path), exist_ok=True)

# Named `hparams` so neither the `dict` builtin nor the `json` module is shadowed.
hparams = {'run_id' : run_id,
           'epoch' : epoch,
           'batch_size' : batch_size,
           'n_hid' : n_hid,
           'n_feat' : n_feat,
           'n_class' : n_class,
           'lr' : lr,
           'drop_per' : drop_per,
           'drop_hid' : drop_hid,
           'n_filt_1' : n_filt_1,
           'n_filt_2' : n_filt_2,
           'seed' : seed,
           'loss_fn' : loss_fn,
           'train_data_set' : train_data_set}

with open(hparams_path, "w") as f:
    json.dump(hparams, f)

if loss_fn == 'cosine':
    loss_function = gluon.loss.CosineEmbeddingLoss()
elif loss_fn == 'cross_entropy':
    loss_function = gluon.loss.SoftmaxCrossEntropyLoss()
else:
    raise ValueError("loss_fn must be 'cross_entropy' or 'cosine', got {!r}".format(loss_fn))


def _compute_loss(output, label):
    """Per-sample loss of `output` against integer `label` for the configured loss_fn."""
    if loss_fn == 'cosine':
        # Cosine loss against the one-hot target; label 1 asks for similarity.
        return loss_function(output, nd.one_hot(label, n_class),
                             nd.array([1.0], ctx=output.context))
    # Softmax cross entropy
    return loss_function(output, label)


def _n_real(batch):
    """Number of genuine samples in an NDArrayIter batch.

    With the default last_batch_handle='pad', the final batch is filled up
    with extra samples and `batch.pad` counts them. Those padded samples must
    not be scored.
    """
    return batch.data[0].shape[0] - (batch.pad or 0)


# The test set never changes, so build its tensors and iterator once.
X_test_nd = nd.from_numpy(X_test.astype(np.float32)).as_in_context(ctx)
mask_test_nd = nd.from_numpy(mask_test.astype(np.float32)).as_in_context(ctx)
y_test_nd = nd.from_numpy(y_test.astype(np.int32)).as_in_context(ctx)
test_iter = mx.io.NDArrayIter([X_test_nd, mask_test_nd], y_test_nd, batch_size, shuffle=False)

sw = SummaryWriter(logdir=logdir)
try:
    for p in range(1, 5):
        # A fresh network and optimizer per fold. Reusing the previous fold's
        # weights would mean they were already trained on this fold's
        # validation partition.
        net = Model(prefix='net_')
        net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
        params = net.collect_params()
        trainer = gluon.Trainer(params=params,
                                optimizer='adam', optimizer_params={'learning_rate': lr})
        fold_params_path = './models/{}_p{}.params'.format(run_id, p)

        # Train and validation sets
        train_index = np.where(partition != p)
        val_index = np.where(partition == p)
        X_tr = nd.from_numpy(X_train[train_index].astype(np.float32)).as_in_context(ctx)
        X_val = nd.from_numpy(X_train[val_index].astype(np.float32)).as_in_context(ctx)
        y_tr = nd.from_numpy(y_train[train_index].astype(np.int32)).as_in_context(ctx)
        y_val = nd.from_numpy(y_train[val_index].astype(np.int32)).as_in_context(ctx)
        mask_tr = nd.from_numpy(mask_train[train_index].astype(np.float32)).as_in_context(ctx)
        mask_val = nd.from_numpy(mask_train[val_index].astype(np.float32)).as_in_context(ctx)

        train_iter = mx.io.NDArrayIter([X_tr, mask_tr], y_tr, batch_size, shuffle=True)
        val_iter = mx.io.NDArrayIter([X_val, mask_val], y_val, batch_size, shuffle=False)

        best_val_acc = 0

        for e in range(1, epoch + 1):
            step = ((p - 1) * epoch) + e
            begin_time = time.perf_counter()
            train_iter.reset()
            val_iter.reset()

            # Full pass training set
            train_err = 0
            train_batches = 0
            confusion_train = ConfusionMatrix(n_class)

            for batch in train_iter:
                data, mask, label = batch.data[0], batch.data[1], batch.label[0]

                with autograd.record():
                    output = net(data, mask)
                    l = _compute_loss(output, label)

                l.backward()
                trainer.step(batch_size)

                train_err += l.mean().asscalar()
                train_batches += 1
                n = _n_real(batch)
                preds = output.argmax(axis=1)
                confusion_train.batch_add(label[:n].astype('int32').asnumpy(),
                                          preds[:n].astype('int32').asnumpy())

            train_time = time.perf_counter() - begin_time
            train_loss = train_err / train_batches
            train_accuracy = confusion_train.accuracy()

            sw.add_scalar(tag='train_time', value=train_time, global_step=step)
            sw.add_scalar(tag='train_loss', value=train_loss, global_step=step)
            sw.add_scalar(tag='train_accuracy', value=train_accuracy, global_step=step)

            # Log the gradients of every parameter to check convergence.
            for name, param in params.items():
                sw.add_histogram(tag=name, values=param.grad(), global_step=step, bins=1000)

            print("%d,%.5f,%.5f,%.5f" % (e, train_time, train_accuracy, train_loss))

            # Full pass validation set
            val_err = 0
            val_batches = 0
            confusion_valid = ConfusionMatrix(n_class)

            for batch in val_iter:
                data, mask, label = batch.data[0], batch.data[1], batch.label[0]

                with autograd.predict_mode():
                    output = net(data, mask)
                    l = _compute_loss(output, label)

                n = _n_real(batch)
                preds = output.argmax(axis=1)
                confusion_valid.batch_add(label[:n].astype('int32').asnumpy(),
                                          preds[:n].astype('int32').asnumpy())
                val_err += l[:n].mean().asscalar()
                val_batches += 1

            val_loss = val_err / val_batches
            val_accuracy = confusion_valid.accuracy()

            sw.add_scalar(tag='val_loss', value=val_loss, global_step=step)
            sw.add_scalar(tag='val_accuracy', value=val_accuracy, global_step=step)

            # Full pass over the test set whenever validation accuracy does not decrease
            if val_accuracy >= best_val_acc:
                confusion_test = ConfusionMatrix(n_class)
                test_iter.reset()

                for batch in test_iter:
                    data, mask, label = batch.data[0], batch.data[1], batch.label[0]

                    with autograd.predict_mode():
                        output = net(data, mask)

                    n = _n_real(batch)
                    preds = output.argmax(axis=1)
                    confusion_test.batch_add(label[:n].astype('int32').asnumpy(),
                                             preds[:n].astype('int32').asnumpy())

                print(confusion_test.accuracy())
                print(confusion_test.ret_mat())

                best_val_acc = val_accuracy
                net.save_parameters(fold_params_path)
finally:
    # Flush and close the event writer even if training is interrupted.
    sw.close()
    

## Prediction

In [None]:
# small - cross entropy
# 20200705-02-57.params

# full - cosine
# 20200802-0216.params

net_params_path = 'models/20200802-0216.params'

mask_all = nd.from_numpy(mask_test.astype(np.float32)).as_in_context(ctx)
X    = nd.from_numpy(X_test.astype(np.float32)).as_in_context(ctx)
y    = nd.from_numpy(y_test.astype(np.int32)).as_in_context(ctx)

test_iter = mx.io.NDArrayIter([X, mask_all], y, 5, shuffle=False)

net = Model(prefix='net_')
net.load_parameters(net_params_path, ctx=ctx)

confusion_test = ConfusionMatrix(n_class)

for batch in test_iter:
    data = batch.data[0]
    mask = batch.data[1]
    label = batch.label[0]

    with mx.autograd.predict_mode():
        output = net(data, mask)

    # NDArrayIter's default last_batch_handle='pad' fills the final batch with
    # extra samples (counted by batch.pad). Score only the genuine ones so no
    # test sample is counted twice.
    n_real = data.shape[0] - (batch.pad or 0)
    preds = output.argmax(axis=1)
    np_label = label[:n_real].astype('int32').asnumpy()
    np_preds = preds[:n_real].astype('int32').asnumpy()
    confusion_test.batch_add(np_label, np_preds)

print("Accuracy: %.5f" % (confusion_test.accuracy()))
print(confusion_test.ret_mat())

