# Build a classifier via a simple softmax
## Load features

In [None]:
import os
import numpy as np
import h5py
from tqdm import tqdm

import mxnet as mx
from mxnet.gluon import nn
from mxnet import autograd
from mxnet import gluon
from mxnet import image
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import nn

import glob
import re
import warnings
# Silences *all* warnings process-wide, including any emitted by MXNet/Gluon.
# NOTE(review): consider narrowing (category=/module=) while debugging, since
# library warnings about misuse are hidden too.
warnings.filterwarnings("ignore")

In [None]:
# Per-split directories holding the pre-extracted feature/label HDF5 files.
# NOTE(review): load_batch_data does not read these constants; it builds
# './features/<scope>/...' paths itself. Keep the two in sync (or wire these in).
features_train_base_dir = './features/train'
features_valid_base_dir = './features/valid'
features_tests_base_dir = './features/tests'

In [None]:
def load_batch_data(batch_num, batch_size=32, scope="train"):
    """Load one batch of pre-extracted features (and labels) from HDF5 files.

    Labels are read from ``./features/<scope>/labls_<bs>_<num>.h5`` (dataset
    ``'labels'``). Features are read from every file matching
    ``./features/<scope>/feats_<bs>_<num>_resnet*.h5`` (dataset
    ``'features'``). Here ``<bs>`` and ``<num>`` are zero-padded to three
    digits. Each feature piece is reshaped to ``(batch_size, -1)``, and the
    pieces are concatenated along the last axis in sorted-filename order.

    Args:
        batch_num: Index of the batch to load.
        batch_size: Batch size encoded in the file names. It is also the
            leading dimension each feature array is reshaped to.
        scope: Sub-directory of ``./features`` to read from. Labels are read
            for every scope except ``"test"``.

    Returns:
        ``(features, labels)`` as ``nd.NDArray`` objects. ``labels`` is
        ``None`` when ``scope == "test"``; previously that path crashed with
        a NameError.

    Raises:
        FileNotFoundError: If no feature file matches this batch.
    """
    batch_size, batch_num = int(batch_size), int(batch_num)

    labels = None
    if scope != 'test':
        labls_file_path = "./features/%s/labls_%03d_%03d.h5" % (scope, batch_size, batch_num)
        with h5py.File(labls_file_path, 'r') as f:
            labels = np.array(f['labels'])

    feats_pattern = "./features/%s/feats_%03d_%03d_resnet*.h5" % (scope, batch_size, batch_num)
    feats_files = sorted(glob.glob(feats_pattern))
    if not feats_files:
        # Fail with a clear message rather than an obscure error from nd.array(None).
        raise FileNotFoundError("no feature files match %r" % feats_pattern)

    # Collect every piece first and concatenate once, instead of re-copying
    # the growing array for each additional file.
    pieces = []
    for feats_path in feats_files:
        with h5py.File(feats_path, 'r') as f:
            pieces.append(np.array(f['features']).reshape((batch_size, -1)))
    feats_all = np.concatenate(pieces, axis=-1)

    return nd.array(feats_all), (None if labels is None else nd.array(labels))

In [None]:
# IPython introspection (not plain Python): show DataLoader's signature, docstring and source.
gluon.data.DataLoader??

In [None]:
# Sanity check: shape of the concatenated feature matrix for training batch 0.
load_batch_data(0, scope="train")[0].shape

In [None]:
def load_data(batch_cnt, scope, batch_size=32):
    """Lazily yield one ``gluon.data.ArrayDataset`` per feature batch.

    Args:
        batch_cnt: Number of batches to yield, for batch indices ``0..batch_cnt-1``.
        scope: Feature sub-directory passed through to ``load_batch_data``.
        batch_size: Batch size encoded in the feature file names. The default
            of 32 matches the previous hard-coded behaviour.

    Yields:
        ``gluon.data.ArrayDataset(features, labels)`` for each batch. The result
        is a generator, so it can be iterated only once.
    """
    for num in range(batch_cnt):
        feats, labels = load_batch_data(num, batch_size=batch_size, scope=scope)
        # load_batch_data already returns NDArrays; wrapping them in nd.array
        # again would only make an extra copy of every batch.
        yield gluon.data.ArrayDataset(feats, labels)

In [None]:
# Number of pre-extracted feature batches per split.
# NOTE(review): fit() trains on 679 batches, not train_batch_cnt (697); one of
# the two numbers is likely a typo, so confirm against the feature files on disk.
train_batch_cnt = 697
valid_batch_cnt = 90
tests_batch_cnt = 90
# Lazy, single-use generators of per-batch ArrayDatasets. fit() below does not
# consume them; it reads batches directly through load_batch_data.
# load_batch_data only skips labels for scope 'test', so iterating tests_data
# requires labls_*.h5 files under ./features/tests.
train_data = load_data(train_batch_cnt,"train")
valid_data = load_data(valid_batch_cnt,"valid")
tests_data = load_data(tests_batch_cnt,"tests")

In [None]:
# Displays the generator object only; no files are read until it is iterated.
valid_data

In [None]:
from mxnet.gluon.model_zoo import vision as models

In [None]:
# Device used for parameters and batches throughout the notebook.
# Prefer the GPU, but fall back to the CPU when no usable GPU exists.
# Creating the context object alone does not touch the device, so allocate a
# tiny array to make MXNet actually try it.
try:
    ctx = mx.gpu()
    _ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
    ctx = mx.cpu()

In [None]:
def get_classifier(num_outputs=120, num_hidden=256, dropout=0.5):
    """Build a small MLP classification head and initialize it on ``ctx``.

    Architecture: ``Dense(num_hidden, relu) -> Dropout(dropout) -> Dense(num_outputs)``.
    The final layer has no activation, so the net outputs unnormalized class
    scores, which is what ``softmax_cross_entropy`` is applied to elsewhere.
    The input width is not fixed here. Gluon infers it on the first forward
    pass, so the head works for any feature size.

    Args:
        num_outputs: Number of classes. The default of 120 matches the
            previous hard-coded value.
        num_hidden: Width of the hidden Dense layer.
        dropout: Dropout rate applied after the hidden layer.

    Returns:
        An initialized ``nn.Sequential`` on the global ``ctx``.
    """
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Dense(num_hidden, activation='relu'))
        net.add(nn.Dropout(dropout))
        net.add(nn.Dense(num_outputs))
    # Initialization does not depend on the name scope; do it once the net is built.
    net.initialize(ctx=ctx)
    return net

In [None]:
# Shared loss for training (fit) and evaluation (evaluate_accuracy). It is
# applied to the classifier's raw final-Dense outputs with class-index labels.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

In [None]:
def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring class equals the label.

    Args:
        outputs: Score matrix with one row per sample and one column per class.
        labels: Class indices, one per row of ``outputs``.

    Returns:
        The mean match rate as a Python scalar in ``[0, 1]``.
    """
    predicted = outputs.argmax(axis=1)
    hits = predicted == labels
    return nd.mean(hits).asscalar()

In [None]:
def evaluate_accuracy(net, valid_steps=90, scope='valid'):
    """Average loss and accuracy of ``net`` over pre-extracted feature batches.

    Despite the name, this returns both the loss and the accuracy.

    Args:
        net: Network mapping a feature batch to class scores.
        valid_steps: Number of batches to evaluate (indices ``0..valid_steps-1``).
            The default of 90 matches the previous hard-coded value.
        scope: Feature sub-directory to read batches from.

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over batches.
        Accuracy is a fraction in ``[0, 1]``.

    Raises:
        ValueError: If ``valid_steps`` is not positive, since the averages
            would otherwise divide by zero.
    """
    if valid_steps <= 0:
        raise ValueError("valid_steps must be positive, got %r" % (valid_steps,))
    total_acc = 0.
    total_loss = 0.
    for batch_index in range(valid_steps):
        feats, labels = load_batch_data(batch_index, scope=scope)
        feats, labels = feats.as_in_context(ctx), labels.as_in_context(ctx)
        # Called outside autograd.record(), so Gluon runs in inference mode
        # and Dropout is a no-op here.
        outputs = net(feats)
        total_acc += accuracy(outputs, labels)
        total_loss += nd.mean(softmax_cross_entropy(outputs, labels)).asscalar()
    return total_loss / valid_steps, total_acc / valid_steps

In [None]:
def fit(epochs=5, train_steps=679, learning_rate=1e-4, wd=1e-5):
    """Train a fresh classifier on the pre-extracted training features.

    Each epoch runs ``train_steps`` Adam updates over training batches
    ``0..train_steps-1``. Progress is printed every 20 batches, and an
    end-of-epoch summary includes validation metrics from
    ``evaluate_accuracy``.

    NOTE(review): ``train_steps`` defaults to 679 (the previous hard-coded
    value), but ``train_batch_cnt`` elsewhere in the notebook is 697. Confirm
    which one matches the feature files on disk.

    Args:
        epochs: Number of passes over the training batches.
        train_steps: Number of training batches per epoch.
        learning_rate: Adam learning rate.
        wd: Weight decay.

    Returns:
        The trained network. Previously it was discarded when ``fit``
        returned.

    Raises:
        ValueError: If ``train_steps`` is not positive.
    """
    if train_steps <= 0:
        raise ValueError("train_steps must be positive, got %r" % (train_steps,))

    net = get_classifier()
    # get_classifier() has already called initialize(); force_reinit makes the
    # switch to Xavier explicit. Use the shared ctx so parameters live on the
    # same device as the batches moved with as_in_context(ctx) below.
    net.initialize(ctx=ctx, init=init.Xavier(), force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate, 'wd': wd})

    for epoch in range(epochs):
        train_loss = 0.
        train_acc = 0.
        for batch_index in range(train_steps):
            feats, labels = load_batch_data(batch_index)
            feats, labels = feats.as_in_context(ctx), labels.as_in_context(ctx)
            with autograd.record():
                outputs = net(feats)
                loss = softmax_cross_entropy(outputs, labels)
            loss.backward()
            # Normalize gradients by the actual number of samples in this batch.
            trainer.step(feats.shape[0])
            train_loss += nd.mean(loss).asscalar()
            train_acc += accuracy(outputs, labels)

            # batch_index is 0-based, so batch_index + 1 batches have been
            # accumulated so far. Dividing by batch_index overstated the averages.
            seen = batch_index + 1
            if batch_index > 0 and batch_index % 20 == 0:
                print("Epoch %d. batch_index: %d. Loss: %f, Train acc %f" % (
                    epoch + 1, batch_index, train_loss / seen, train_acc / seen * 100))

        # evaluate_accuracy reads the 'valid' split, so report it as validation,
        # with accuracy in percent like the training accuracy.
        valid_loss, valid_acc = evaluate_accuracy(net)
        print("Epoch %d. Loss: %f, Train acc %f. Valid loss %f, Valid acc %f" % (
            epoch + 1, train_loss / train_steps, train_acc / train_steps * 100,
            valid_loss, valid_acc * 100))
    return net

In [None]:
# Train with the defaults (5 epochs); progress is printed every 20 batches.
fit()