In [45]:
import os
import argparse
import logging
logging.basicConfig(level=logging.ERROR)
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx
import numpy as np
import gzip, struct
from mxnet.symbol import *

In [77]:
parser = argparse.ArgumentParser(description="train mnist",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-classes', type=int, default=10,
                        help='the number of classes')
fit.add_fit_args(parser)
parser.set_defaults(
    # train
    gpus           = '2',
    batch_size     = 100,
    disp_batches   = 100,
    num_epochs     = 20000,
    num_examples   = 100, 
    wd             = 1e-3,
    lr             = .001,
    lr_factor      = .33,
    optimizer      = 'adam',
    lr_step_epochs = '5000,10000,15000',
)
args = parser.parse_args("")
unsup_multiplier = 1
labeled_per_class = 10
sample_seed = 47
val_interval = 100
num_unsup_examples = 60000
num_sup_examples = labeled_per_class * args.num_classes
sup_batch_size = args.batch_size

In [78]:
def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data',label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data',image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255

def sample_by_label(images, labels, n_per_label, num_labels, seed=None):
    """Extract equal number of sampels per class."""
    res_img = []
    res_lbl = []
    rng = np.random.RandomState(seed=seed)
    for i in range(num_labels):
        a = images[labels == i]
        
        if n_per_label == -1:  # use all available labeled data
            res_img.append(a)
        else:  # use randomly chosen subset
            choice = rng.choice(len(a), n_per_label, False)
            print(choice)
            r = a[choice]
            
            res_img.append(r)
            
            lbls = np.ones(n_per_label) * i
            res_lbl.append(lbls)
    return (res_img, res_lbl)


"""
create data iterator with NDArrayIter
"""
(train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
(val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

(imgs, lbls) = sample_by_label(to4d(train_img), train_lbl, labeled_per_class, 10, seed = sample_seed)
imgs = np.vstack(imgs)
lbls = np.hstack(lbls)

train_sup = mx.io.NDArrayIter(
    imgs, lbls, args.batch_size, shuffle=True, data_name='dataSup', label_name='labelSup')
train_unsup = mx.io.NDArrayIter(
    to4d(train_img), label=None, batch_size=args.batch_size, 
    shuffle=True, data_name='dataUnsup')
val = mx.io.NDArrayIter(
    to4d(val_img), val_lbl, args.batch_size*1) # use larger test batch size

[5840 5656  576 3662 3626 4448 4352 5659 4106 4176]
[1103 5635 3920 6070 6075 3763 6509 1527 1191 2650]
[5128 5073 2352 2550  164 5421 5858 2597 1559 3745]
[1750 1338 2470 1011 2191  774 4962 5773 6060 1810]
[2295 5622 5409  555 5314 2464  909 3597 1125 3279]
[5183 2728 2781 3977  345 4871 4113 1403 2388 1468]
[5142 1629  953 4917  950  672 2646 5433 3551 5325]
[5726 2479 1952 6120 4552 2228 2253 4124 4963 3627]
[4854  222  260 4818 1262 3203 5607  990 5167 1794]
[3534   41 4560 5218 4152  972 5041 2605  640 3111]


In [79]:
class Multi_mnist_iterator(mx.io.DataIter):
    '''multi label mnist iterator'''

    def __init__(self, supIterator, unsupIterator):
        super(Multi_mnist_iterator, self).__init__()
        self.supIterator = supIterator
        self.unsupIterator = unsupIterator
        self.batch_size = self.supIterator.batch_size
        
        self.reset_counter = 0
        self.reset_multiplier = int(np.floor(num_unsup_examples / num_sup_examples / unsup_multiplier))
        print(self.reset_multiplier)

    @property
    def provide_data(self):
        iters = [self.supIterator.provide_data[0]]
        
        for i in range(unsup_multiplier):
            d = self.unsupIterator.provide_data[0]
            desc = mx.io.DataDesc('dataUnsup'+str(i), d.shape, d.dtype, d.layout)
            iters.append(desc)
        
        return iters

    @property
    def provide_label(self):
        return self.supIterator.provide_label

    def hard_reset(self):
        self.supIterator.hard_reset()
        self.unsupIterator.hard_reset()

    def reset(self):
        self.supIterator.reset()
        self.reset_counter = self.reset_counter + 1
        
        # only reset unsup iterator if all images have been traversed
        # samples in batches are shuffled, but always in the same (shuffled) order after a reset
        # in most cases the unsup iterator has a lot more images, so it should be reset less often
        if self.reset_counter % self.reset_multiplier == 0:
            self.unsupIterator.reset()

    def next(self):
        batch0 = self.supIterator.next()
        
        data = [batch0.data[0]]
        
        for i in range(unsup_multiplier):
            batch = self.unsupIterator.next()
            data.append(batch.data[0])
            
        label = batch0.label

        return mx.io.DataBatch(data=data, label=label, \
                pad=batch0.pad, index=batch0.index)
    
train = Multi_mnist_iterator(train_sup, train_unsup)
2
#[d.asnumpy().mean() for d in train.next().data]

600


2

In [80]:
labels = Variable('labels')

equality_matrix = broadcast_equal(reshape(labels, shape=(-1,1)), labels, name="eqmat")
    
equality_matrix = cast(equality_matrix, dtype='float32')

batch=train.next()

labels = batch.label[0]
bDataSup = batch.data[0]
bDataUnsup = batch.data[1]
#print(equality_matrix.eval(labels=labels)[0].asnumpy()[0:10,0:10])
#print(labels.asnumpy())

c = mod.forward()

# first fullc
flatten = mx.symbol.Flatten(data=pool)
embeddings = mx.symbol.FullyConnected(data=flatten, num_hidden=128)

splitted = split(embeddings, num_outputs=2, axis=0)

a = splitted[0]
b = splitted[1]

match_ab = dot(a,transpose(b),name='match_ab')
p_ab = softmax(match_ab, name='p_ab')
p_ba = softmax(transpose(match_ab), name='p_ba')
p_aba = dot(p_ab, p_ba, name='p_aba')

print(bDataUnsup)

NameError: name 'mod' is not defined

In [None]:
def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(1, 1)):
    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    act = mx.symbol.Activation(data=conv, act_type='relu')
    return act

def build_embeddings(data, nembeddings=128, add_stn=False, **kwargs):
    if(add_stn):
        data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
                                         transform_type="affine", sampler_type="bilinear")
    # first conv
    conv = ConvFactory(data=data, kernel=(3,3), num_filter=32)
    conv = ConvFactory(data=conv, kernel=(3,3), num_filter=32)
    pool = mx.symbol.Pooling(data=conv, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    
    conv = ConvFactory(data=pool, kernel=(3,3), num_filter=64)
    conv = ConvFactory(data=conv, kernel=(3,3), num_filter=64)
    pool = mx.symbol.Pooling(data=conv, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    
    conv = ConvFactory(data=pool, kernel=(3,3), num_filter=128)
    conv = ConvFactory(data=conv, kernel=(3,3), num_filter=128)
    pool = mx.symbol.Pooling(data=conv, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    
    # first fullc
    flatten = mx.symbol.Flatten(data=pool)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=nembeddings)
    
    return fc1

In [None]:
def getshape(tensor):
    arg_shape, output_shape, aux_shape = tensor.infer_shape(labelSup=(128,))
    print(output_shape)
def getshapeData(tensor):
    arg_shape, output_shape, aux_shape = tensor.infer_shape(
        dataUnsup=(128,3,28,28),dataSup=(128,3,28,28))
    print(output_shape)
    
def compute_visit_loss(p, visit_weight=1):
    
    visit_probability = mean(p, axis=(0), keepdims=True, name='visit_prob')
    
    t_nb = sup_batch_size * unsup_multiplier
    
    init = mx.initializer.Constant(t_nb)
    t_nb = var('t_nb', init=init, dtype='float32', shape=(1))
    
    target = broadcast_div(ones_like(visit_probability), t_nb)
    
#     arg_shape, output_shape, aux_shape = visit_probability.infer_shape(
#         dataUnsup0=(128,3,28,28),dataSup=(128,3,28,28))
#     print(output_shape)
#     arg_shape, output_shape, aux_shape = target.infer_shape(
#         dataUnsup0=(128,3,28,28),dataSup=(128,3,28,28))
#     print(output_shape)
#     arg_shape, output_shape, aux_shape = t_nb.infer_shape()
#     print(output_shape)
    visit_probability = log(1e-8 + visit_probability)
    visit_loss = SoftmaxOutput(visit_probability, target, grad_scale=visit_weight, name='visit_loss')
    
    return visit_loss
    
def compute_semisup_loss(a,b,labels,walker_weight=1., visit_weight=1.):
    equality_matrix = broadcast_equal(reshape(labels, shape=(-1,1)), labels, name="eqmat")
    
    equality_matrix = cast(equality_matrix, dtype='float32')
    p_target = broadcast_div(equality_matrix,
                             sum(equality_matrix, axis=(1), keepdims=True))
    
    match_ab = dot(a,b, transpose_b=True,name='match_ab')
    p_ab = softmax(match_ab, name='p_ab')
    p_ba = softmax(transpose(match_ab), name='p_ba')
    p_aba = dot(p_ab, p_ba, name='p_aba')
    
    #todo: create walk statistics
    
    # softmaxOutput should be cross entropy loss: https://github.com/dmlc/mxnet/issues/1969
    # apparently this calculates the gradient of cross entropy loss for backprop, so should
    # be equivalent
    
    p_aba = log(1e-8 + p_aba, name='log_aba')
    walker_loss = SoftmaxOutput(p_aba, p_target, name='loss_aba', grad_scale=walker_weight)
    
    visit_loss = compute_visit_loss(p_ab, visit_weight)

    return (walker_loss, visit_loss)

In [81]:
def logit_loss(embeddings, labels, nclasses, grad_scale=1):
    fc1 = mx.symbol.FullyConnected(data=embeddings, num_hidden=nclasses, name='fc1')
    softmax = mx.symbol.SoftmaxOutput(fc1, labels, name='softmax'+str(grad_scale), grad_scale=grad_scale)
    return softmax

In [82]:
def build():
    dataSup = mx.symbol.Variable(name="dataSup")
    labelSup = mx.symbol.Variable(name='labelSup')
    overall_loss = []
    
    if unsup_multiplier >= 1:
        dataUnsup = []
        for i in range(unsup_multiplier):
            dataUnsup.append(Variable(name="dataUnsup"+str(i)))

        # concat data, feed both through the network
        # then split it up again
        data = concat(dataSup, *dataUnsup, dim=0)

        embeddings = build_embeddings(data)
        splitted = split(embeddings, num_outputs=(unsup_multiplier+1), axis=0, name='split')

        supEmbeddings = splitted[0]
        
        unsupEmbeddings = []
        for i in range(unsup_multiplier):
            unsupEmbeddings.append(splitted[1+i])
    
        unsupEmbeddings = concat(*unsupEmbeddings, dim=0)
            
        (walker_loss, visit_loss) = compute_semisup_loss(supEmbeddings, unsupEmbeddings, labelSup, 
                                                     walker_weight=1.0, visit_weight=1.0)
        overall_loss = [walker_loss, visit_loss]
        
    else:
        supEmbeddings = build_embeddings(dataSup)
        
    overall_loss = [logit_loss(supEmbeddings, labelSup, 10)] + overall_loss
    
    return Group(overall_loss)


In [83]:
class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        super(Multi_Accuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):

        #for i in range(len(preds)):
        for i in range(1):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[0].asnumpy().astype('int32')

            #mx.metric.check_label_shapes(label, pred_label)
            
            #print((pred_label.flat == label.flat).sum())
            #print(len(pred_label.flat))

            
            self.sum_metric[i] += (pred_label.flat == label.flat).sum()
            self.num_inst[i] += len(pred_label.flat)

In [84]:
# evaluation metrices
eval_metrics = Multi_Accuracy(num= 3 if unsup_multiplier >= 1 else 1)
                    
def fit_model(args, network, data, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the nerual network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)
    batch_size = sup_batch_size * (unsup_multiplier + 1)
    

    # data iterators
    (train, val) = data
    if args.test_io:
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i+1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
                    i, args.disp_batches*batch_size/(time.time()-tic)))
                tic = time.time()
        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = fit._load_model(args, kv.rank)
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = fit._save_model(args, kv.rank)

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # learning rate
    lr, lr_scheduler = fit._get_lr_scheduler(args, kv)

    data_names = ['dataSup'] + ['dataUnsup'+str(i) for i in range(unsup_multiplier)]
        
    # create model
    model = mx.mod.Module(
        context       = devs,
        symbol        = network,
        label_names   = ['labelSup'],
        data_names   = data_names
    )
    
    def validate_model(epoch, *args):
        if epoch % val_interval != 0: 
            return
        res = model.score(val, eval_metrics)
        #TODO: pull this into default
        for name, value in res:
            print('Epoch[%d] Validation-%s=%f' % (epoch, name, value))

    
    #print(model.label_names)
    lr_scheduler  = lr_scheduler
    optimizer_params = {
            'learning_rate': lr,
            #'momentum' : args.mom,
            'wd' : args.wd,
            'rescale_grad': 1,
            'lr_scheduler': lr_scheduler}

    #monitor = mx.mon.Monitor(interval=1000, pattern='.*aba_backward.*') 
    monitor = mx.mon.Monitor(interval=1000, pattern='.*') 
    
    #initializer = mx.init.Xavier(
    #    rnd_type='gaussian', factor_type="in", magnitude=2)
    initializer   = mx.init.Xavier(rnd_type='gaussian', factor_type="avg", magnitude=2.34)

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(batch_size, args.disp_batches)]
    
    epoch_end_callbacks = validate_model

    # run
    model.fit(train,
        begin_epoch        = args.load_epoch if args.load_epoch else 0,
        num_epoch          = args.num_epochs,
        #eval_data          = val,
        eval_metric        = eval_metrics,
        kvstore            = kv,
        optimizer          = args.optimizer,
        optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        batch_end_callback = batch_end_callbacks,
        epoch_end_callback = epoch_end_callbacks
        #allow_missing      = True
        #monitor            = monitor
             )
    return model

In [85]:
train_sup.reset()
train_unsup.reset()

net = build()

model = fit_model(args, net, (train, val))

[5000.0, 10000.0, 15000.0]
Epoch[0] Validation-multi-accuracy_0=0.186700
Epoch[0] Validation-multi-accuracy_1=nan
Epoch[0] Validation-multi-accuracy_2=nan
Epoch[100] Validation-multi-accuracy_0=0.920200
Epoch[100] Validation-multi-accuracy_1=nan
Epoch[100] Validation-multi-accuracy_2=nan
Epoch[200] Validation-multi-accuracy_0=0.922000
Epoch[200] Validation-multi-accuracy_1=nan
Epoch[200] Validation-multi-accuracy_2=nan
Epoch[300] Validation-multi-accuracy_0=0.932500
Epoch[300] Validation-multi-accuracy_1=nan
Epoch[300] Validation-multi-accuracy_2=nan
Epoch[400] Validation-multi-accuracy_0=0.942300
Epoch[400] Validation-multi-accuracy_1=nan
Epoch[400] Validation-multi-accuracy_2=nan
Epoch[500] Validation-multi-accuracy_0=0.947700
Epoch[500] Validation-multi-accuracy_1=nan
Epoch[500] Validation-multi-accuracy_2=nan
Epoch[600] Validation-multi-accuracy_0=0.950700
Epoch[600] Validation-multi-accuracy_1=nan
Epoch[600] Validation-multi-accuracy_2=nan
Epoch[700] Validation-multi-accuracy_0=0.

Epoch[6000] Validation-multi-accuracy_0=0.967700
Epoch[6000] Validation-multi-accuracy_1=nan
Epoch[6000] Validation-multi-accuracy_2=nan
Epoch[6100] Validation-multi-accuracy_0=0.967800
Epoch[6100] Validation-multi-accuracy_1=nan
Epoch[6100] Validation-multi-accuracy_2=nan
Epoch[6200] Validation-multi-accuracy_0=0.965900
Epoch[6200] Validation-multi-accuracy_1=nan
Epoch[6200] Validation-multi-accuracy_2=nan
Epoch[6300] Validation-multi-accuracy_0=0.967300
Epoch[6300] Validation-multi-accuracy_1=nan
Epoch[6300] Validation-multi-accuracy_2=nan
Epoch[6400] Validation-multi-accuracy_0=0.967100
Epoch[6400] Validation-multi-accuracy_1=nan
Epoch[6400] Validation-multi-accuracy_2=nan
Epoch[6500] Validation-multi-accuracy_0=0.966500
Epoch[6500] Validation-multi-accuracy_1=nan
Epoch[6500] Validation-multi-accuracy_2=nan
Epoch[6600] Validation-multi-accuracy_0=0.967800
Epoch[6600] Validation-multi-accuracy_1=nan
Epoch[6600] Validation-multi-accuracy_2=nan
Epoch[6700] Validation-multi-accuracy_0=0

Epoch[12000] Validation-multi-accuracy_0=0.968200
Epoch[12000] Validation-multi-accuracy_1=nan
Epoch[12000] Validation-multi-accuracy_2=nan
Epoch[12100] Validation-multi-accuracy_0=0.967700
Epoch[12100] Validation-multi-accuracy_1=nan
Epoch[12100] Validation-multi-accuracy_2=nan
Epoch[12200] Validation-multi-accuracy_0=0.966400
Epoch[12200] Validation-multi-accuracy_1=nan
Epoch[12200] Validation-multi-accuracy_2=nan
Epoch[12300] Validation-multi-accuracy_0=0.968400
Epoch[12300] Validation-multi-accuracy_1=nan
Epoch[12300] Validation-multi-accuracy_2=nan
Epoch[12400] Validation-multi-accuracy_0=0.968500
Epoch[12400] Validation-multi-accuracy_1=nan
Epoch[12400] Validation-multi-accuracy_2=nan
Epoch[12500] Validation-multi-accuracy_0=0.967600
Epoch[12500] Validation-multi-accuracy_1=nan
Epoch[12500] Validation-multi-accuracy_2=nan
Epoch[12600] Validation-multi-accuracy_0=0.968100
Epoch[12600] Validation-multi-accuracy_1=nan
Epoch[12600] Validation-multi-accuracy_2=nan
Epoch[12700] Validat

Epoch[17900] Validation-multi-accuracy_0=0.967800
Epoch[17900] Validation-multi-accuracy_1=nan
Epoch[17900] Validation-multi-accuracy_2=nan
Epoch[18000] Validation-multi-accuracy_0=0.967800
Epoch[18000] Validation-multi-accuracy_1=nan
Epoch[18000] Validation-multi-accuracy_2=nan
Epoch[18100] Validation-multi-accuracy_0=0.967700
Epoch[18100] Validation-multi-accuracy_1=nan
Epoch[18100] Validation-multi-accuracy_2=nan
Epoch[18200] Validation-multi-accuracy_0=0.967800
Epoch[18200] Validation-multi-accuracy_1=nan
Epoch[18200] Validation-multi-accuracy_2=nan
Epoch[18300] Validation-multi-accuracy_0=0.967700
Epoch[18300] Validation-multi-accuracy_1=nan
Epoch[18300] Validation-multi-accuracy_2=nan
Epoch[18400] Validation-multi-accuracy_0=0.967900
Epoch[18400] Validation-multi-accuracy_1=nan
Epoch[18400] Validation-multi-accuracy_2=nan
Epoch[18500] Validation-multi-accuracy_0=0.967900
Epoch[18500] Validation-multi-accuracy_1=nan
Epoch[18500] Validation-multi-accuracy_2=nan
Epoch[18600] Validat

In [None]:
shape={"dataSup":(10, 1, 28,28),
                    "dataUnsup0":(10, 1, 28,28),
                    "dataUnsup1":(10, 1, 28,28), 
                     "labelSup": (10)
                    }

plot_network(net, shape=shape)

In [None]:
batch=train.next()

labels = batch.label[0]
bDataSup = batch.data[0]
bDataUnsup = batch.data[1]


a = model.forward(batch)

In [76]:
def confusion_matrix(labels, predictions, num_labels):
  """Compute the confusion matrix."""
  rows = []
  for i in range(num_labels):
    row = np.bincount(predictions[labels == i], minlength=num_labels)
    rows.append(row)
  return np.vstack(rows)

val.reset()
res = model.predict(val)[0].asnumpy().argmax(axis=1)
print(res.shape, val_lbl.shape)

confusion_matrix(val_lbl, res,10)

(10000,) (10000,)


array([[ 976,    0,    2,    0,    0,    0,    0,    1,    1,    0],
       [   0, 1123,    5,    3,    0,    1,    0,    3,    0,    0],
       [   1,    0, 1029,    0,    0,    0,    0,    1,    0,    1],
       [   0,    0,    0, 1001,    0,    4,    0,    2,    2,    1],
       [   1,    0,    0,    0,  955,    0,    2,    0,    0,   24],
       [   0,    0,    0,    2,    0,  887,    2,    1,    0,    0],
       [   5,    3,    1,    0,    1,    3,  944,    0,    1,    0],
       [   0,    1,    8,    1,    0,    1,    0, 1016,    0,    1],
       [   0,    0,    1,    2,    1,    1,    0,    2,  964,    3],
       [   2,    0,    0,    2,    9,    2,    0,   10,    0,  984]])

In [None]:
print(labels.asnumpy())
model.get_outputs()[1].asnumpy()[0:10,0:10]

In [None]:
import re
import copy
import json
def _str2tuple(string):
    """Convert shape string to list, internal use only.
    Parameters
    ----------
    string: str
        Shape string.
    Returns
    -------
    list of str
        Represents shape.
    """
    return re.findall(r"\d+", string)

def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
                 hide_weights=True):
    """Creates a visualization (Graphviz digraph object) of the given computation graph.
    Graphviz must be installed for this function to work.
    Parameters
    ----------
    title: str, optional
        Title of the generated visualization.
    symbol: Symbol
        A symbol from the computation graph. The generated digraph will visualize the part
        of the computation graph required to compute `symbol`.
    shape: dict, optional
        Specifies the shape of the input tensors. If specified, the visualization will include
        the shape of the tensors between the nodes. `shape` is a dictionary mapping
        input symbol names (str) to the corresponding tensor shape (tuple).
    node_attrs: dict, optional
        Specifies the attributes for nodes in the generated visualization. `node_attrs` is
        a dictionary of Graphviz attribute names and values. For example,
            ``node_attrs={"shape":"oval","fixedsize":"false"}``
            will use oval shape for nodes and allow variable sized nodes in the visualization.
    hide_weights: bool, optional
        If True (default), then inputs with names of form *_weight (corresponding to weight
        tensors) or *_bias (corresponding to bias vectors) will be hidden for a cleaner
        visualization.
    Returns
    -------
    dot: Digraph
        A Graphviz digraph object visualizing the computation graph to compute `symbol`.
    Example
    -------
    >>> net = mx.sym.Variable('data')
    >>> net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)
    >>> net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    >>> net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)
    >>> net = mx.sym.SoftmaxOutput(data=net, name='out')
    >>> digraph = mx.viz.plot_network(net, shape={'data':(100,200)},
    ... node_attrs={"fixedsize":"false"})
    >>> digraph.view()
    """
    # todo add shape support
    try:
        from graphviz import Digraph
    except:
        raise ImportError("Draw network requires graphviz library")
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be a Symbol")
    draw_shape = False
    if shape is not None:
        draw_shape = True
        interals = symbol.get_internals()
        _, out_shapes, _ = interals.infer_shape(**shape)
        if out_shapes is None:
            raise ValueError("Input shape is incomplete")
        shape_dict = dict(zip(interals.list_outputs(), out_shapes))
    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    # default attributes of node
    node_attr = {"shape": "box", "fixedsize": "true",
                 "width": "1.3", "height": "0.8034", "style": "filled"}
    # merge the dict provided by user and the default one
    node_attr.update(node_attrs)
    dot = Digraph(name=title, format=save_format)
    # color map
    cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
          "#fdb462", "#b3de69", "#fccde5")

    def looks_like_weight(name):
        """Internal helper to figure out if node should be hidden with `hide_weights`.
        """
        if name.endswith("_weight"):
            return True
        if name.endswith("_bias"):
            return True
        return False

    # make nodes
    hidden_nodes = set()
    for node in nodes:
        op = node["op"]
        name = node["name"]
        # input data
        attr = copy.deepcopy(node_attr)
        label = name

        if op == "null":
            if looks_like_weight(node["name"]):
                if hide_weights:
                    hidden_nodes.add(node["name"])
                # else we don't render a node, but
                # don't add it to the hidden_nodes set
                # so it gets rendered as an empty oval
                continue
            attr["shape"] = "oval" # inputs get their own shape
            label = node["name"]
            attr["fillcolor"] = cm[0]
        elif op == "Convolution":
            label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attr"]["kernel"])),
                                                 "x".join(_str2tuple(node["attr"]["stride"]))
                                                 if "stride" in node["attr"] else "1",
                                                 node["attr"]["num_filter"])
            attr["fillcolor"] = cm[1]
        elif op == "FullyConnected":
            label = r"FullyConnected\n%s" % node["attr"]["num_hidden"]
            attr["fillcolor"] = cm[1]
        elif op == "BatchNorm":
            attr["fillcolor"] = cm[3]
        elif op == "Activation" or op == "LeakyReLU":
            label = r"%s\n%s" % (op, node["attr"]["act_type"])
            attr["fillcolor"] = cm[2]
        elif op == "Pooling":
            label = r"Pooling\n%s, %s/%s" % (node["attr"]["pool_type"],
                                             "x".join(_str2tuple(node["attr"]["kernel"])),
                                             "x".join(_str2tuple(node["attr"]["stride"]))
                                             if "stride" in node["attr"] else "1")
            attr["fillcolor"] = cm[4]
        elif op == "Concat" or op == "Flatten" or op == "Reshape":
            attr["fillcolor"] = cm[5]
        elif op == "Softmax":
            attr["fillcolor"] = cm[6]
        else:
            attr["fillcolor"] = cm[7]
            if op == "Custom":
                label = node["attr"]["op_type"]

        dot.node(name=name, label=label, **attr)

    # add edges
    for node in nodes:          # pylint: disable=too-many-nested-blocks
        op = node["op"]
        name = node["name"]
        if op == "null":
            continue
        else:
            inputs = node["inputs"]
            for item in inputs:
                input_node = nodes[item[0]]
                input_name = input_node["name"]
                if input_name not in hidden_nodes:
                    attr = {"dir": "back", 'arrowtail':'open'}
                    # add shapes
                    if draw_shape:
                        if input_node["op"] != "null":
                            key = input_name + "_output"
                            if "attr" in input_node:
                                params = input_node["attr"]
                                if "num_outputs" in params:
                                    key += str(np.maximum(int(params["num_outputs"]) - 1, 0))
                                    params["num_outputs"] = int(params["num_outputs"]) - 1
                            shape = shape_dict[key][1:]
                            label = "x".join([str(x) for x in shape])
                            attr["label"] = label
                        else:
                            key = input_name
                            shape = shape_dict[key][1:]
                            label = "x".join([str(x) for x in shape])
                            attr["label"] = label
                    dot.edge(tail_name=name, head_name=input_name, **attr)

    return dot