In [19]:
import numpy as np
import mxnet as mx
from mxnet.test_utils import *

In [2]:
def np_softmax(x, axis=-1):
    """Return the softmax of ``x`` computed along ``axis`` (NumPy reference).

    The maximum of every slice along ``axis`` is subtracted before
    exponentiating so that large inputs do not overflow; the shift cancels
    out in the normalisation. The input array is not modified.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

In [181]:
def cls2onehot(x, axis=-1, num_classes=None):
    """Convert an array of integer class ids into a one-hot encoding.

    Parameters
    ----------
    x : array_like of int
        Class ids, any shape.
    axis : int, optional
        The one-hot (class) axis is inserted immediately *after* this axis of
        ``x``. Negative values count from the end, so the default ``-1``
        appends the class axis as the last axis. Values beyond the last axis
        are clamped (the class axis is appended), values before the first
        axis put the class axis first.
    num_classes : int, optional
        Length of the class axis. Defaults to ``max(x) + 1``; pass it
        explicitly when the largest class id may be absent from ``x``.

    Returns
    -------
    numpy.ndarray
        Float array with one more dimension than ``x`` holding 1 at the
        position of each class id and 0 elsewhere.
    """
    labels = np.asarray(x)
    depth = int(np.max(labels)) + 1 if num_classes is None else int(num_classes)

    # Build the encoding with the class axis last: one row per label.
    flat = labels.reshape(-1)
    one_hot = np.zeros((flat.size, depth))
    one_hot[np.arange(flat.size), flat] = 1
    one_hot = one_hot.reshape(labels.shape + (depth,))

    # Work out where the class axis has to go. A plain reshape cannot be used
    # for this: it would reinterpret the buffer instead of moving the axis.
    ndim = labels.ndim
    if axis < 0:
        axis += ndim
    position = min(max(axis + 1, 0), ndim)
    return np.ascontiguousarray(np.moveaxis(one_hot, -1, position))

# Sanity check: one-hot encode six random class ids and use the encoding as a
# mask to pick, for every row, the score of its labelled class.
np.random.seed(10)
dummy = np.random.randint(0, 5, size=[6])
prob_dummy = np.random.normal(size=[6, 5])
print(dummy)

# axis=1 lies past the only axis of `dummy`, so the class axis ends up last.
one_hot = cls2onehot(dummy, 1)
print(one_hot)

# Keep the score where the mask is set, zero elsewhere; the row sums are then
# exactly the selected scores.
pred_dummy = np.where(one_hot == 1, prob_dummy, 0)
print(pred_dummy)
print(pred_dummy.sum(axis=1))

[1 4 0 1 3 4]
[[ 0.  1.  0.  0.  0.]
 [ 0.  0.  0.  0.  1.]
 [ 1.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  1.]]
[[ 0.         -0.00838385  0.          0.          0.        ]
 [ 0.          0.          0.          0.          1.20303737]
 [-0.96506567  0.          0.          0.          0.        ]
 [ 0.          1.484537    0.          0.          0.        ]
 [ 0.          0.          0.          1.67262221  0.        ]
 [ 0.          0.          0.          0.         -0.54930901]]
[-0.00838385  1.20303737 -0.96506567  1.484537    1.67262221 -0.54930901]


In [258]:
# Focal-loss hyper-parameters
gamma = 2        # exponent of the (1 - p) modulating factor
alpha = 0.25     # weight where label >= 1; label 0 is weighted 1 - alpha
epsilon = 1e-16  # added to the probability before taking the log

# Problem size
N = 4             # batch size
H = 3             # height
W = 3             # width
num_anchors = 5   # number of anchors per location
num_classes = 12  # with background as index 0
C = num_anchors * num_classes  # channel dimension of the logits

In [276]:
# ---------------------------------------------------------------------------
# NumPy reference
# ---------------------------------------------------------------------------
logits = np.random.normal(5, 1, [N, C, H, W])
# Split the channel axis into (anchor, class) so softmax can run per anchor.
logits_reshape = logits.reshape([N, num_anchors, num_classes, H, W])
label_data = np.random.randint(0, num_classes, [N, num_anchors, H, W])
# Only the shape of the one-hot labels is reported in this cell.
label_one_hot = cls2onehot(label_data, 1)
print('label_one_hot', label_one_hot.shape)

# Softmax along the class axis, then back to the operator's (N, C, H, W) layout.
prob = np_softmax(logits_reshape, axis=2)
softmax_prob = prob.reshape([N, C, H, W])

# Probability of the labelled class for every (sample, anchor, row, column):
# preds[i, j, u, v] = prob[i, j, label_data[i, j, u, v], u, v].
# The open-mesh index arrays broadcast against label_data, which replaces a
# four-deep Python loop with a single indexing operation.
n_idx, a_idx, h_idx, w_idx = np.ogrid[:N, :num_anchors, :H, :W]
preds = prob[n_idx, a_idx, label_data, h_idx, w_idx]

# alpha for labels >= 1, (1 - alpha) for label 0 (background).
alpha_ = np.where(label_data >= 1, alpha, 1 - alpha)

# Focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t); epsilon keeps the log finite.
expected_losses = -alpha_ * (1. - preds)**gamma * np.log(preds + epsilon)


# ---------------------------------------------------------------------------
# MXNet operator under test
# ---------------------------------------------------------------------------
ctx = mx.gpu(0)  # single place to change the device

x = mx.sym.Variable('x')
label = mx.sym.Variable('label')
norm = mx.sym.Variable('norm')
x_nd = mx.nd.array(logits, ctx=ctx)
label_nd = mx.nd.array(label_data, ctx=ctx)
# NOTE(review): a normalizer of 1 presumably leaves the loss unscaled - confirm
# against the operator's definition.
norm_nd = mx.nd.array([1], ctx=ctx)

sym = mx.sym.contrib.SoftmaxFocalLoss(data=x, label=label, normalizer=norm,
                                      gamma=gamma, alpha=alpha, num_classes=num_classes)
arg_shapes, out_shapes, _ = sym.infer_shape(x=x_nd.shape, label=label_nd.shape, norm=norm_nd.shape)
# One gradient buffer per argument, in the order reported by infer_shape.
args_grad = [mx.nd.empty(s, ctx=ctx) for s in arg_shapes]
ex = sym.bind(ctx=ctx, args={'x': x_nd, 'label': label_nd, 'norm': norm_nd}, args_grad=args_grad)
ex.forward(is_train=True)
# First output is read as the focal loss, second as the softmax probabilities.
focal_loss_out = ex.outputs[0].asnumpy()
softmax_out = ex.outputs[1].asnumpy()

label_one_hot (4, 5, 12, 3, 3)


In [277]:
# Compare the operator's outputs with the NumPy reference computed earlier in
# the notebook: the softmax probabilities first, then the focal loss.
# NOTE(review): the two calls pass their arguments in opposite order
# (operator, reference) vs (reference, operator) - confirm whether
# assert_almost_equal applies its tolerance symmetrically.
assert_almost_equal(softmax_out, softmax_prob)
assert_almost_equal(expected_losses, focal_loss_out)