* 导入packages

In [93]:
from __future__ import print_function
import os, sys, random
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
%matplotlib inline

import mxnet as mx
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \
    BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout
from mxnet import autograd
import numpy as np

from mxboard import SummaryWriter

import cv2
from skimage.transform import resize
from sklearn.cross_validation import train_test_split

import logging
logging.getLogger().setLevel(logging.DEBUG)


DEBUG:matplotlib.backends:backend module://ipykernel.pylab.backend_inline version unknown


* 定义数据集类

In [95]:
class TripletDataset(gluon.data.dataset.Dataset):
    def __init__(self, rd, rl, transform=None):
        self.__rd = rd  # 原始数据
        self.__rl = rl  # 原始标签
        self._data = None
        self._label = None
        self._transform = transform
        self._get_data()

    def __getitem__(self, idx):
        if self._transform is not None:
            return self._transform(self._data[idx], self._label[idx])
        return self._data[idx], self._label[idx]

    def __len__(self):
        return len(self._label)

    def _get_data(self):
        label_list = np.unique(self.__rl)
        digit_indices = [np.where(self.__rl == i)[0] for i in label_list]
        tl_pairs = self.create_pairs(self.__rd, digit_indices, len(label_list))
        self._data = tl_pairs
        self._label = np.ones(tl_pairs.shape[0])

    @staticmethod
    def create_pairs(x, digit_indices, num_classes):
        x = x.asnumpy()  # 转换数据格式
        pairs = []
        n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1  # 最小类别数
        for d in range(num_classes):
            for i in range(n):
                np.random.shuffle(digit_indices[d])
                z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
                inc = random.randrange(1, num_classes)
                dn = (d + inc) % num_classes
                z3 = digit_indices[dn][i]
                pairs += [[x[z1], x[z2], x[z3]]]
        return np.asarray(pairs)

In [96]:
def evaluate_net(model, test_data, ctx=mx.cpu() ):
    triplet_loss = gluon.loss.TripletLoss(margin=0)
    sum_correct = 0.0
    sum_all = 0
    rate = 0.0
    for i, (data, _) in enumerate(test_data):
        data = data.as_in_context(ctx)

        anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
        
        inter1 = model(anc_ins)  # 训练的时候组合
        inter2 = model(pos_ins)
        inter3 = model(neg_ins)
#         print( inter1.shape )
        loss = triplet_loss(inter1, inter2, inter3)

        loss = loss.asnumpy()
        n_all = loss.shape[0]
        n_correct = np.sum(np.where(loss == 0, 1, 0))

        sum_correct += n_correct
        sum_all += n_all
    rate = sum_correct / sum_all

#     print('accuracy : %.4f (%s / %s)' % (rate, sum_correct, sum_all))
    return rate


* np.where是找到满足条件的下标，可以加上x和y两个参数，会返回一个矩阵，满足条件的值得位置赋值为x，不满足的为y。

In [97]:
a = np.array([2,1,3])
b = np.where( a == 1, -1, -2 )
print( b )

[-2 -1 -2]


* 导入数据，设置训练的一些参数

In [98]:
ctx = mx.gpu()
batch_size = 1024
random.seed(47)
mnist_data_dir = '../dataset/mnist'

mnist_train = gluon.data.vision.MNIST(train=True, root=mnist_data_dir)  # load train data
tr_data = mnist_train._data.reshape((-1, 28 * 28)) 
tr_label = mnist_train._label  # 标签

mnist_test = gluon.data.vision.MNIST(train=False, root=mnist_data_dir)  # load test data
te_data = mnist_test._data.reshape((-1, 28 * 28))
te_label = mnist_test._label

def transform(data_, label_):
    return data_.astype(np.float32) / 255., label_.astype(np.float32)

train_data = gluon.data.DataLoader(
    TripletDataset(rd=tr_data, rl=tr_label, transform=transform),
    batch_size, shuffle=True)

test_data = gluon.data.DataLoader(
    TripletDataset(rd=te_data, rl=te_label, transform=transform),
    batch_size, shuffle=True)


base_net = gluon.nn.Sequential()
with base_net.name_scope():
    base_net.add(gluon.nn.Dense(256, activation='relu'))
    base_net.add(gluon.nn.Dense(128, activation='relu'))


* 定义triplet loss，开始训练

In [123]:
base_net.collect_params().initialize(mx.init.Uniform(scale=0.1), ctx=ctx, force_reinit=True)

triplet_loss = gluon.loss.TripletLoss()  # TripletLoss损失函数
trainer_triplet = gluon.Trainer(base_net.collect_params(), 'sgd', {'learning_rate': 0.03})

for epoch in range(100):
    curr_loss = 0.0
    for i, (data, _) in enumerate(train_data):
        data = data.as_in_context(ctx)
        anc_ins, pos_ins, neg_ins = data[:, 0], data[:, 1], data[:, 2]
        with autograd.record():
            inter1 = base_net(anc_ins)
            inter2 = base_net(pos_ins)
            inter3 = base_net(neg_ins)
            loss = triplet_loss(inter1, inter2, inter3)  # Triplet Loss
        loss.backward()
        trainer_triplet.step(batch_size)
        curr_loss = mx.nd.mean(loss).asscalar()
        # print('Epoch: %s, Batch: %s, Triplet Loss: %s' % (epoch, i, curr_loss))
    if epoch % 10 == 0:
        val_acc = evaluate_net(base_net, test_data, ctx=ctx)
        print('Epoch: %s, Triplet Loss: %s, validation accuracy : %f' % (epoch, curr_loss, val_acc))


Epoch: 0, Triplet Loss: 0.28602317, validation accuracy : 0.877329
Epoch: 10, Triplet Loss: 0.0814259, validation accuracy : 0.954658
Epoch: 20, Triplet Loss: 0.0684311, validation accuracy : 0.963524
Epoch: 30, Triplet Loss: 0.0485348, validation accuracy : 0.968126
Epoch: 40, Triplet Loss: 0.03245996, validation accuracy : 0.970034
Epoch: 50, Triplet Loss: 0.030095275, validation accuracy : 0.971829
Epoch: 60, Triplet Loss: 0.024903722, validation accuracy : 0.973625
Epoch: 70, Triplet Loss: 0.010364741, validation accuracy : 0.973513
Epoch: 80, Triplet Loss: 0.015544275, validation accuracy : 0.973962
Epoch: 90, Triplet Loss: 0.007999075, validation accuracy : 0.974411


* 使用tensorboard进行emebedding的可视化

In [124]:
trans_te_data, trans_te_label = transform(te_data, te_label)
trans_te_data = trans_te_data[0:1000]
trans_te_label = trans_te_label[0:1000]
trans_te_label = mx.nd.array( trans_te_label )
# tb_projector(trans_te_data.asnumpy(), trans_te_label, os.path.join(ROOT_DIR, 'logs', 'origin'))
# 如果需要看初始时刻的embedding情况，可以强制初始化
# base_net.collect_params().initialize(mx.init.Uniform(scale=0.1), ctx=ctx, force_reinit=True)
trans_te_res = base_net(trans_te_data.as_in_context( context=ctx ))

# 转换成4D数据 NCHW
trans_te_data = trans_te_data.reshape( (-1,28,28))
trans_te_data = mx.nd.expand_dims( trans_te_data, axis=(1) )

label_str = [str(int(idx)) for idx in trans_te_label.asnumpy()]

with SummaryWriter(logdir='./logs') as sw:
    sw.add_image(tag='mnists', image=trans_te_data)
    sw.add_embedding(tag='mnist_codes', embedding=trans_te_res, images=trans_te_data, labels=label_str)

INFO:mxboard.event_file_writer:successfully opened events file: ./logs/events.out.tfevents.1544098248.ININ-Z640
INFO:mxboard.writer:saved embedding labels to ./logs/mnist_codes
INFO:mxboard.event_file_writer:wrote 1 event to disk
INFO:mxboard.writer:saved embedding images to ./logs/mnist_codes
INFO:mxboard.event_file_writer:wrote 1 event to disk
INFO:mxboard.writer:saved embedding data to ./logs/mnist_codes


In [125]:
print( evaluate_net(base_net, test_data, ctx=ctx) )

0.9750841750841751
