In [1]:
import mxnet as mx
import os, time, shutil
import logging
from mxnet import gluon, image, init, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms
from gluoncv.model_zoo import get_model
import numpy as np
import gluoncv
from gluoncv.utils import export_block,viz,makedirs




# --- Model identification -------------------------------------------------
model_name = 'xception'
weights_name = 'gluoncv-xception'
num_classes = 25

# --- Device / schedule settings -------------------------------------------
num_gpus = 2
batch_per_gpu = 8
train_epoch = 10
# Global batch size: per-GPU batch multiplied by the number of GPUs.
batch_size = num_gpus * batch_per_gpu

# --- Filesystem layout ----------------------------------------------------
# NOTE(review): absolute, machine-specific path; anyone else running this
# notebook has to edit it.
root_dir = '/home/lin/sheldon/my_mxnet'
# Directory the record-file iterators read from.
rec_path = os.path.join(root_dir, 'img_file')
# Checkpoint prefix handed to mx.model.load_checkpoint.
prefix = os.path.join(root_dir, 'weights/xception')



def get_iterators(batch_size = batch_size, data_shape=(3, 299, 299),shuffle=True):
    """Build the training and validation ``mx.io.ImageRecordIter`` pair.

    Both iterators read ``<split>.rec`` / ``<split>.idx`` from the
    module-level ``rec_path`` and share the batch size, data shape, resize
    value and data/label names. Only the training iterator additionally
    receives the augmentation options and the ``shuffle`` flag.

    Parameters
    ----------
    batch_size : int
        Forwarded as ``batch_size`` to both iterators. The default is the
        value the module-level ``batch_size`` had when this ``def`` ran.
    data_shape : tuple
        Forwarded as ``data_shape`` to both iterators.
    shuffle : bool
        Forwarded as ``shuffle`` to the training iterator only.

    Returns
    -------
    tuple
        ``(train_data, val_data)``.
    """
    def _make_iter(split, **extra):
        # Options common to both splits; `extra` layers split-specific ones on top.
        record_stem = os.path.join(rec_path, split)
        options = dict(
            path_imgrec=record_stem + '.rec',
            path_imgidx=record_stem + '.idx',
            data_shape=data_shape,
            batch_size=batch_size,
            resize=299,
            data_name='data',
            label_name='softmax_label',
        )
        options.update(extra)
        return mx.io.ImageRecordIter(**options)

    train_data = _make_iter(
        'train',
        shuffle=shuffle,
        rand_mirror=True,
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        # NOTE(review): presumably `rotate` is a fixed angle rather than a
        # random range (that would be `max_rotate_angle`) - confirm against
        # the ImageRecordIter docs that 180 is really what is wanted here.
        rotate=180,
    )
    val_data = _make_iter('val')
    return (train_data, val_data)

In [100]:
# Restore the symbol and the parameter dictionaries saved under `prefix`.
# NOTE(review): the epoch literal 10 currently equals `train_epoch`; confirm
# whether the two are meant to stay in sync before replacing the literal.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 10)

# One context per configured GPU, so the device list follows `num_gpus`
# instead of a second hard-coded [gpu(0), gpu(1)] list.
mod = mx.mod.Module(symbol=sym,
                    context=[mx.gpu(gpu_id) for gpu_id in range(num_gpus)],
                    data_names=["data"], label_names=None)
# Bind for inference only. The batch dimension is taken from `batch_size`
# (the same value handed to get_iterators) so the bound shapes cannot drift
# away from what the iterators deliver.
# NOTE(review): `label_names=None` above while `label_shapes` is supplied
# here - confirm this combination is intentional.
mod.bind(for_training=False,
         data_shapes=[("data", (batch_size, 3, 299, 299))],
         label_shapes=[("softmax_label", (batch_size,))])
# Set the model parameters.
mod.set_params(arg_params, aux_params, allow_missing=True)



In [114]:
def _f1_from_scores(label, pred, average):
    """Compute an F1 score from per-class scores (shared by both metrics).

    Parameters
    ----------
    label : array-like
        Ground-truth class indices; passed unchanged to ``f1_score``.
    pred : array-like
        Per-sample class scores. Size-1 axes are squeezed out; if that leaves
        fewer than two axes (e.g. a batch holding a single sample) the array
        is reshaped to ``(number of labels, -1)`` so the batch axis survives.
    average : str
        Forwarded to ``sklearn.metrics.f1_score`` as ``average``.

    Returns
    -------
    The value returned by ``f1_score``.
    """
    # Imported inside the function, as the original cells did, so the
    # notebook's import cell does not have to change.
    from sklearn.metrics import f1_score

    scores = np.squeeze(np.asarray(pred))
    if scores.ndim < 2:
        # squeeze() also removed the batch axis; restore it from the label
        # count so the per-row argmax below keeps working.
        scores = scores.reshape(np.size(label), -1)

    # Highest-scoring class per row; ties resolve to the lowest class index.
    predicted = np.argmax(scores, axis=1)

    # Only classes that occur among the predictions are scored.
    # NOTE(review): classes that are never predicted are therefore left out
    # of the average - confirm this is the intended definition.
    return f1_score(label, predicted, average=average,
                    labels=np.unique(predicted))


# The two public metrics stay plain named functions: the metric keys shown in
# the scoring results ('micro_f1_score', 'macro_f1_score') match these names.
def micro_f1_score(label,pred):
    """Micro-averaged F1 of the per-row top-scoring class in ``pred`` against ``label``."""
    return _f1_from_scores(label, pred, average='micro')


def macro_f1_score(label,pred):
    """Macro-averaged F1 of the per-row top-scoring class in ``pred`` against ``label``."""
    return _f1_from_scores(label, pred, average='macro')
    

In [115]:
# Individual metrics: top-1 accuracy, top-5 accuracy and the two custom
# F1 functions wrapped as MXNet metrics.
eval_metrics_1 = mx.metric.Accuracy()
eval_metrics_2 = mx.metric.TopKAccuracy(5)
eval_metrics_3 = mx.metric.create(micro_f1_score)
eval_metrics_4 = mx.metric.create(macro_f1_score)

# Bundle them so a single scoring pass reports all four.
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in (
    eval_metrics_1,
    eval_metrics_2,
    eval_metrics_3,
    eval_metrics_4,
):
    eval_metrics.add(child_metric)

# Fresh iterators; only `val` is consumed by the scoring call.
train, val = get_iterators(batch_size)

# Score a single validation batch and show the results as a dict.
a = dict(mod.score(eval_data=val, eval_metric=eval_metrics, num_batch=1))
a

{'accuracy': 0.9375,
 'macro_f1_score': 0.9722222222222223,
 'micro_f1_score': 0.9666666666666667,
 'top_k_accuracy_5': 1.0}

In [94]:
# Re-wrap the metric results held in `a` as a plain dict and display them.
# NOTE(review): this cell's execution count (94) is lower than that of the
# scoring cell above it (115), so the output shown beneath it presumably
# comes from an earlier run with different scoring arguments - confirm
# before quoting these numbers.
a = dict(a)
a

{'accuracy': 0.9198998178506376,
 'macro_f1_score': 0.8641668563195138,
 'micro_f1_score': 0.9198998178506376,
 'top_k_accuracy_5': 0.9951502732240437}

In [None]:
# Access the module's output shapes; nothing is displayed because this is not
# the last expression of the cell.
mod.output_shapes
# Rebuild both iterators, rewind the validation one and take its first batch.
(train, val) = get_iterators(batch_size)
val.reset()
batch_data = val.next()

In [None]:
# Display the label field of the validation batch stored in `batch_data`.
batch_data.label

In [None]:
# Rewind the validation iterator and run a forward pass on its first batch;
# the outputs are read back from the module in the following cell.
val.reset()
mod.forward(val.next())

In [None]:
# First output of the module's most recent forward pass as a NumPy array,
# with all size-1 axes removed.
prob = np.squeeze(mod.get_outputs()[0].asnumpy())
# Per row, the index of the highest score: the last column of an ascending
# argsort along axis 1.
a = prob.argsort(axis=1)[:, -1]
a

In [116]:
# Run prediction over a single batch drawn from `val` and display the result.
mod.predict(val,num_batch= 1)


[[3.98765520e-13 9.99844313e-01 6.19093266e-09 5.54473658e-08
  6.38566838e-11 2.15013662e-13 2.82305124e-10 9.32476469e-05
  6.21436084e-06 1.63510322e-07 3.25453982e-07 6.74017439e-11
  1.33658684e-09 4.49267446e-07 2.33182675e-11 4.80407519e-11
  4.96623294e-08 1.01876330e-08 5.50551595e-05 7.88477172e-10
  6.65730804e-10 6.48560997e-11 6.89227520e-08 3.38429035e-10
  1.82499582e-09]
 [9.31293029e-11 7.07452585e-09 1.97008835e-14 6.91337043e-10
  1.14139125e-08 2.88335336e-16 5.15910266e-11 1.13320738e-07
  1.40374496e-12 1.47252355e-09 9.44443036e-06 3.55245069e-11
  5.99424885e-11 7.47082743e-11 2.25096244e-11 1.73540187e-08
  1.35595923e-10 8.76335325e-05 5.52964993e-05 2.01976640e-08
  9.99845743e-01 5.39090203e-13 2.65926992e-09 3.09577359e-07
  1.38796941e-06]
 [3.53099605e-10 6.99830055e-11 9.67295022e-10 1.62404930e-07
  6.28601893e-10 7.45947037e-09 1.46401710e-10 2.56007793e-09
  2.23793484e-11 2.95561722e-06 9.99959826e-01 5.30654560e-08
  1.96779060e-09 4.13154788e-10 1