In [4]:
import tensorflow as tf

# 张量排序

 * sort/argsort：sort 返回排序后的序列；argsort 返回排序后各元素在原序列中的索引
 * topk 得到前 k 个最大值及其索引（对输入取负后可得到前 k 个最小值）
 * top-5 acc

## sort/argsort

In [7]:
# Build a random permutation of 0..4 by shuffling tf.range(5).
a = tf.random.shuffle(tf.range(5))
a

<tf.Tensor: id=14, shape=(5,), dtype=int32, numpy=array([1, 2, 4, 0, 3], dtype=int32)>

In [9]:
# Sort the values in descending order.
tf.sort(a, direction="DESCENDING")

<tf.Tensor: id=31, shape=(5,), dtype=int32, numpy=array([4, 3, 2, 1, 0], dtype=int32)>

In [10]:
# Indices into the original tensor that would arrange it in descending order.
tf.argsort(a, direction="DESCENDING")

<tf.Tensor: id=41, shape=(5,), dtype=int32, numpy=array([2, 4, 1, 0, 3], dtype=int32)>

In [14]:
# Gathering `a` at its argsort indices reproduces the result of tf.sort.
tf.gather(a, tf.argsort(a, direction="DESCENDING"))

<tf.Tensor: id=63, shape=(5,), dtype=int32, numpy=array([4, 3, 2, 1, 0], dtype=int32)>

In [21]:
# Random 3x3 int32 matrix; maxval=10 is the (exclusive) upper bound of the values.
a = tf.random.uniform([3, 3], maxval=10, dtype=tf.int32)
a

<tf.Tensor: id=91, shape=(3, 3), dtype=int32, numpy=
array([[1, 9, 7],
       [6, 0, 7],
       [9, 1, 6]], dtype=int32)>

In [24]:
# On a 2-D tensor the sort is applied along the last axis, i.e. each row is
# sorted independently (ascending first, then descending).
tf.sort(a), tf.sort(a, direction="DESCENDING")

(<tf.Tensor: id=133, shape=(3, 3), dtype=int32, numpy=
 array([[1, 7, 9],
        [0, 6, 7],
        [1, 6, 9]], dtype=int32)>,
 <tf.Tensor: id=141, shape=(3, 3), dtype=int32, numpy=
 array([[9, 7, 1],
        [7, 6, 0],
        [9, 6, 1]], dtype=int32)>)

In [25]:
# Per-row column indices that would sort each row in ascending order.
tf.argsort(a)

<tf.Tensor: id=152, shape=(3, 3), dtype=int32, numpy=
array([[0, 2, 1],
       [1, 0, 2],
       [1, 2, 0]], dtype=int32)>

## topk

In [27]:
# Fresh random 3x3 int32 matrix (values below 10) used for the top_k examples.
a = tf.random.uniform([3, 3], maxval=10, dtype=tf.int32)
a

<tf.Tensor: id=160, shape=(3, 3), dtype=int32, numpy=
array([[4, 4, 8],
       [6, 8, 8],
       [8, 4, 2]], dtype=int32)>

In [32]:
# Keep the 2 largest entries of each row; the result carries both the values
# and their column indices (see `.values` / `.indices`).
# NOTE(review): the saved output of this cell is a 2-tuple (res, indices tensor),
# so it appears to come from an earlier version of the cell — re-run to refresh.
res = tf.math.top_k(a, k=2)
res

(TopKV2(values=<tf.Tensor: id=174, shape=(3, 2), dtype=int32, numpy=
 array([[8, 4],
        [8, 8],
        [8, 4]], dtype=int32)>, indices=<tf.Tensor: id=175, shape=(3, 2), dtype=int32, numpy=
 array([[2, 0],
        [1, 2],
        [0, 1]], dtype=int32)>),
 <tf.Tensor: id=175, shape=(3, 2), dtype=int32, numpy=
 array([[2, 0],
        [1, 2],
        [0, 1]], dtype=int32)>)

In [35]:
# Column positions of the top-2 entries per row, followed by the entries themselves.
res.indices, res.values

(<tf.Tensor: id=175, shape=(3, 2), dtype=int32, numpy=
 array([[2, 0],
        [1, 2],
        [0, 1]], dtype=int32)>,
 <tf.Tensor: id=174, shape=(3, 2), dtype=int32, numpy=
 array([[8, 4],
        [8, 8],
        [8, 4]], dtype=int32)>)

### top-k accuracy

In [36]:
# Two samples over three classes: each row of `prob` holds per-class probabilities,
# and `target` holds the true class index of each sample.
prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.7, 0.1]])
target = tf.constant([2, 0])
# With k=3 (all classes) this is each row's class indices ordered from most to
# least probable.
k_b = tf.math.top_k(prob, k=3).indices
k_b

<tf.Tensor: id=180, shape=(2, 3), dtype=int32, numpy=
array([[2, 1, 0],
       [1, 0, 2]], dtype=int32)>

In [39]:
# Transpose [batch, k] -> [k, batch] so that row i holds every sample's
# (i+1)-th most probable class.
# Note: this rebinds `k_b`, so running the cell a second time transposes it back.
k_b = tf.transpose(k_b, perm=[1, 0])
k_b

<tf.Tensor: id=186, shape=(3, 2), dtype=int32, numpy=
array([[2, 1],
       [1, 0],
       [0, 2]], dtype=int32)>

In [44]:
# Repeat the labels along the k axis so they line up element-wise with the
# transposed [3, 2] prediction matrix.
target = tf.broadcast_to(target, [3, 2])
target

<tf.Tensor: id=195, shape=(3, 2), dtype=int32, numpy=
array([[2, 0],
       [2, 0],
       [2, 0]], dtype=int32)>

In [54]:
# True where the rank-i prediction equals the label; a sample is a top-k hit
# if any of its first k rows is True.
tf.equal(k_b, target)

<tf.Tensor: id=238, shape=(3, 2), dtype=bool, numpy=
array([[ True, False],
       [False,  True],
       [False, False]])>

## 实战

In [68]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: Tensor of shape [batch, num_classes] with per-class scores
            (logits or probabilities); only their ordering matters.
        target: Integer tensor of shape [batch] with the true class index of
            each sample. Any integer dtype is accepted.
        topk: Iterable of k values to evaluate. ``max(topk)`` must not exceed
            ``num_classes`` (a requirement of ``tf.math.top_k``).

    Returns:
        A list of Python floats, one per entry of ``topk`` and in the same
        order, each being the percentage of samples whose true class is among
        the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.shape[0]

    # [batch, maxk] class indices (best first) -> [maxk, batch]
    pred = tf.math.top_k(output, maxk).indices
    pred = tf.transpose(pred, perm=[1, 0])

    # tf.equal requires both operands to share a dtype. top_k returns int32
    # indices while labels are frequently int64 (e.g. from tf.argmax) or uint8,
    # so align the labels with the predictions before comparing.
    target_ = tf.cast(target, pred.dtype)
    target_ = tf.broadcast_to(target_, pred.shape)
    # [maxk, batch]: True where the rank-i prediction matches the label
    correct = tf.equal(pred, target_)

    res = []
    for k in topk:
        # Each sample has at most one True across the first k rows (top_k
        # indices are distinct), so the sum counts the top-k hits.
        correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
        correct_k = tf.reduce_sum(correct_k)
        acc = float(correct_k * (100.0 / batch_size))
        res.append(acc)

    return res

# Fake a batch of 10 samples over 6 classes: random scores turned into probabilities.
output = tf.random.normal([10, 6])
output = tf.math.softmax(output, axis=1) # softmax along axis 1, so each row sums to 1
print("prob:", output.numpy())
# Random integer labels below 6, one per sample.
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print("label:", target.numpy())
# Top-1 prediction: the index of the largest probability in each row.
pred = tf.argmax(output, axis=1)
print("pred:", pred.numpy())

# Evaluate top-1 through top-6; with 6 classes, top-6 covers every class.
acc = accuracy(output, target, topk=(1, 2, 3, 4, 5, 6))
print("top 1-6 acc:", acc)

prob: [[0.18353066 0.3276817  0.06344536 0.12068    0.06104458 0.24361764]
 [0.14046003 0.23030654 0.11896093 0.28100115 0.08454478 0.14472651]
 [0.12014476 0.23867853 0.09241256 0.06428436 0.34006873 0.14441106]
 [0.02193103 0.02008838 0.06097358 0.73129296 0.12790377 0.03781031]
 [0.11057741 0.08675427 0.07851733 0.19058196 0.30722696 0.2263421 ]
 [0.19347434 0.10302109 0.19567528 0.1416116  0.06699216 0.29922545]
 [0.12074371 0.04231453 0.40411508 0.10170089 0.02152895 0.30959684]
 [0.19445543 0.25950745 0.11772771 0.17805274 0.202736   0.04752066]
 [0.09833962 0.16481824 0.14184082 0.06099552 0.17508325 0.3589225 ]
 [0.20977901 0.16823663 0.10440225 0.09709366 0.13565136 0.28483704]]
label: [1 5 2 5 1 2 0 3 2 5]
pred: [1 3 4 3 4 5 2 1 5 5]
top 1-6 acc: [20.0, 30.0, 50.0, 80.0, 100.0, 100.0]
