In [154]:
import torch
import json

In [238]:
# Slot index of the 7-long label vectors -> class name.
# Slots 0-2 are binary heads; slots 3-6 are the four cleansing grades
# (presumably BBPS scores 0-3).
index_label = dict(enumerate([
    'outside',
    'nonsense',
    'ileocecal',
    'bbps0',
    'bbps1',
    'bbps2',
    'bbps3',
]))

In [239]:
# Ground-truth and prediction list files: one image name per line followed by
# space-separated label ids (parsed by read_from_imagenet).
gt_path, pred_path = (
    '/mnt/data4/cwy/Datasets/VideoHYL/gt_out.txt',
    '/mnt/data4/cwy/Datasets/VideoHYL/pred_out.txt',
)

In [240]:
def read_from_imagenet(txt_path: str, num_classes: int = 7):
    """Parse an ImageNet-style multi-label list file into names and multi-hot labels.

    Each non-blank line is ``<image name> <label id> [<label id> ...]`` with
    whitespace-separated label ids; each id ``k`` sets slot ``k - 1``. The
    resulting ``num_classes``-long multi-hot vector is then rotated left by 3
    (slots 3.. move to the front, slots 0-2 move to the end).

    Args:
        txt_path: path of the list file.
        num_classes: length of each label vector (default 7).

    Returns:
        (names, labels): image names in file order and a LongTensor of shape
        [len(names), num_classes].

    Raises:
        ValueError: if a label token is not an integer.
        IndexError: if a label id is larger than ``num_classes``.

    NOTE(review): a label id of 0 becomes index -1 and, through negative
    indexing, sets the LAST slot instead of failing. The original comment
    said the rotation moves "the last 4 bbps labels to the front", but
    index_label places bbps at slots 3-6 after rotation -- confirm whether
    the files are 0- or 1-based and what their label order is.
    """
    name_all = []
    label_all = []
    with open(txt_path, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            # Field 0 is the image name, the rest are label ids; split() also
            # tolerates repeated spaces and tabs.
            fields = raw_line.split()
            if not fields:
                # A blank line would otherwise add a nameless all-zero row and
                # shift every later frame out of alignment with the other file.
                continue
            try:
                labels = [int(x) - 1 for x in fields[1:]]
            except ValueError as err:
                raise ValueError(
                    f'{txt_path}:{line_no}: non-integer label in {raw_line.strip()!r}'
                ) from err
            label_list = [0] * num_classes
            for label in labels:
                label_list[label] = 1
            # Rotate left by 3 so the slot order matches index_label.
            label_list = label_list[3:] + label_list[:3]
            name_all.append(fields[0])
            label_all.append(label_list)

    # reshape keeps the [N, num_classes] shape even for an empty file.
    return name_all, torch.tensor(label_all, dtype=torch.long).reshape(-1, num_classes)

In [241]:
gt_names, gt_labels = read_from_imagenet(gt_path)
pred_names, pred_labels = read_from_imagenet(pred_path)
# Each row of gt_labels / pred_labels is a length-7 label vector whose slots
# correspond to the classes in index_label. Every comparison below is done
# row by row (and outlier intervals are frame indices), so both files must
# list the same frames in the same order.
assert len(gt_names) == len(pred_names), f'gt has {len(gt_names)} rows, pred has {len(pred_names)}'
assert gt_names == pred_names, 'gt and pred files list different frames or a different order'

In [242]:
# Snapshot of the predictions before any outlier fixes are applied.
pred_ori = pred_labels.detach().clone()

In [243]:
# JSON report of detected outlier intervals for this experiment.
outlier_json_path = (
    '/mnt/data4/cwy/CQCExp/Experiment/R105_detect_fps_vitp14s336c7_40/outliers/'
    'outlier_result.json'
)

In [244]:
# Outlier intervals for the 'HYL_Clip' entry of the report. Judging by the
# displayed output, presumably outliers[head][pattern] -> list of inclusive
# [start, end] frame intervals -- TODO confirm against the detector.
# The `with` block closes the file (json.load(open(...)) leaked the handle).
with open(outlier_json_path) as outlier_file:
    outliers = json.load(outlier_file)['HYL_Clip']
outliers

{'0': {'0': [[178, 178],
   [430, 430],
   [460, 460],
   [507, 507],
   [1788, 1788],
   [1806, 1806],
   [4010, 4010],
   [4045, 4046],
   [4144, 4145],
   [4147, 4148],
   [4543, 4543]],
  '1': [[450, 450],
   [458, 459],
   [506, 506],
   [508, 509],
   [654, 655],
   [680, 680],
   [1757, 1757],
   [1787, 1787],
   [1789, 1790],
   [1805, 1805],
   [1807, 1807],
   [1817, 1817],
   [1832, 1832],
   [1837, 1837],
   [1903, 1903],
   [1954, 1954],
   [2004, 2004],
   [3320, 3320],
   [3948, 3949],
   [3989, 3989],
   [4009, 4009],
   [4011, 4011],
   [4022, 4022],
   [4027, 4027],
   [4044, 4044],
   [4047, 4047],
   [4114, 4114],
   [4136, 4136],
   [4142, 4143],
   [4146, 4146],
   [4149, 4149],
   [4315, 4315],
   [4430, 4430],
   [4532, 4532],
   [4542, 4542],
   [4544, 4545],
   [4577, 4577],
   [4645, 4645],
   [4679, 4679],
   [4727, 4727],
   [4970, 4970],
   [4978, 4978],
   [4984, 4984],
   [5206, 5206]]},
 '1': {'0': [[178, 178],
   [430, 430],
   [460, 460],
   [507, 507

In [245]:
# outliers[head][pattern]: head '0' -> outside, '1' -> nonsense;
# pattern '0' -> the "101" intervals, '1' -> the "010" intervals.
outside_outliers_101, outside_outliers_010 = outliers['0']['0'], outliers['0']['1']
nonsense_outliers_101, nonsense_outliers_010 = outliers['1']['0'], outliers['1']['1']

In [246]:
def merge_outside_outliers(idx_101, idx_010):
    """Merge two lists of inclusive [start, end] frame intervals into one sorted list.

    Intervals from both inputs are sorted by start. An interval that overlaps
    or directly follows (start == previous end + 1) the current merged
    interval is folded into it.

    Args:
        idx_101: intervals flagged with the "101" pattern.
        idx_010: intervals flagged with the "010" pattern.

    Returns:
        A new list of [start, end] lists; empty when both inputs are empty.
        Neither input nor the interval lists inside them are modified.
    """
    merged = []
    for start, end in sorted([*idx_101, *idx_010], key=lambda x: x[0]):
        if merged and start <= merged[-1][1] + 1:
            # Overlapping or adjacent: extend the current interval.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            # Fresh list so extending it never writes back into the caller's data.
            merged.append([start, end])
    return merged

In [247]:
# Outside intervals from both patterns, merged into one sorted list.
outside_outliers = merge_outside_outliers(
    idx_101=outside_outliers_101,
    idx_010=outside_outliers_010,
)
outside_outliers

[[178, 178],
 [430, 430],
 [450, 450],
 [458, 460],
 [506, 509],
 [654, 655],
 [680, 680],
 [1757, 1757],
 [1787, 1790],
 [1805, 1807],
 [1817, 1817],
 [1832, 1832],
 [1837, 1837],
 [1903, 1903],
 [1954, 1954],
 [2004, 2004],
 [3320, 3320],
 [3948, 3949],
 [3989, 3989],
 [4009, 4011],
 [4022, 4022],
 [4027, 4027],
 [4044, 4047],
 [4114, 4114],
 [4136, 4136],
 [4142, 4149],
 [4315, 4315],
 [4430, 4430],
 [4532, 4532],
 [4542, 4545],
 [4577, 4577],
 [4645, 4645],
 [4679, 4679],
 [4727, 4727],
 [4970, 4970],
 [4978, 4978],
 [4984, 4984],
 [5206, 5206]]

In [248]:
# Only the "101" nonsense intervals are fixed; nonsense_outliers_010 is not
# used in the visible cells. NOTE(review): confirm skipping "010" is intended.
nonsense_outliers = nonsense_outliers_101
nonsense_outliers

[[178, 178],
 [430, 430],
 [460, 460],
 [507, 507],
 [594, 594],
 [597, 597],
 [611, 611],
 [633, 633],
 [676, 676],
 [678, 679],
 [684, 684],
 [686, 687],
 [694, 695],
 [713, 713],
 [715, 716],
 [730, 731],
 [741, 741],
 [784, 784],
 [864, 864],
 [975, 975],
 [996, 997],
 [1006, 1006],
 [1066, 1066],
 [1142, 1142],
 [1158, 1159],
 [1224, 1224],
 [1266, 1267],
 [1269, 1269],
 [1347, 1347],
 [1350, 1350],
 [1382, 1382],
 [1384, 1384],
 [1386, 1386],
 [1389, 1389],
 [1523, 1524],
 [1539, 1539],
 [1542, 1542],
 [1584, 1584],
 [1590, 1591],
 [1627, 1627],
 [1784, 1785],
 [1788, 1788],
 [1806, 1806],
 [1816, 1816],
 [1904, 1904],
 [1908, 1909],
 [1919, 1919],
 [1941, 1941],
 [1950, 1951],
 [2050, 2051],
 [2053, 2054],
 [2190, 2190],
 [2199, 2199],
 [2355, 2355],
 [2559, 2559],
 [3314, 3315],
 [3318, 3319],
 [3408, 3408],
 [3589, 3589],
 [3629, 3629],
 [3709, 3709],
 [3837, 3837],
 [4010, 4010],
 [4045, 4046],
 [4144, 4145],
 [4147, 4148],
 [4431, 4432],
 [4533, 4534],
 [4543, 4543],
 [4576,

In [249]:
def fix_outliers_idx(pred_origin, outlier_list, val_func):
    """Return a copy of ``pred_origin`` with every frame inside the outlier intervals rewritten.

    Args:
        pred_origin: label tensor indexed by frame along dim 0; never modified.
        outlier_list: iterable of inclusive [start, end] frame intervals.
        val_func: ``val_func(pred_origin, interval, frame_idx)`` returning the
            new row for ``frame_idx`` as a list or a tensor.

    Returns:
        A new tensor holding the replaced rows.
    """
    pred_fix = pred_origin.clone().detach()
    for item in outlier_list:
        a, b = item
        for now_i in range(a, b + 1):
            # Decisions always read the untouched input, so rewrites made in
            # this call never influence later decisions.
            new_row = val_func(pred_origin, item, now_i)
            # as_tensor accepts lists and tensors alike; torch.tensor(tensor)
            # emits the "copy construct from a tensor" UserWarning.
            pred_fix[now_i] = torch.as_tensor(new_row, dtype=pred_fix.dtype, device=pred_fix.device)
    return pred_fix

In [250]:
def val_reverse_outside(pred, idx_item, now_i):
    """Decide the corrected row for frame ``now_i`` inside an outside-outlier interval.

    The direction is chosen by the outside bit (column 0) of the interval's
    first frame ``idx_item[0]``:

    * first frame inside (0): frames already outside or nonsense are kept,
      every other frame becomes outside ``[1, 0, 0, 0, 0, 0, 0]``;
    * first frame outside (1): frames predicted inside are kept, every other
      frame becomes nonsense ``[0, 1, 0, 0, 0, 0, 0]``.

    Each decision is printed.

    Args:
        pred: 0/1 label tensor indexed ``pred[frame][class]``.
        idx_item: the inclusive [start, end] interval being fixed.
        now_i: frame index inside the interval.

    Returns:
        The new row as a plain list of ints (same type on every path).

    Raises:
        ValueError: if the first frame's outside bit is neither 0 nor 1.
    """
    a = idx_item[0]
    start_outside = pred[a][0].item()
    row = pred[now_i]
    if start_outside == 0:
        if row[0].item() == 1 or row[1].item() == 1:
            print(f'Fix {idx_item}: {now_i} not change')
            # tolist() keeps the return type consistent and avoids torch.tensor(tensor) warnings.
            return row.tolist()
        print(f'Fix {idx_item}: {now_i}=[1, 0, 0, 0, 0, 0, 0]')
        return [1, 0, 0, 0, 0, 0, 0]
    if start_outside == 1:
        if row[0].item() == 0:
            print(f'Fix {idx_item}: {now_i} not change')
            return row.tolist()
        print(f'Fix {idx_item}: {now_i}=[0, 1, 0, 0, 0, 0, 0]')
        return [0, 1, 0, 0, 0, 0, 0]
    raise ValueError(f'interval {idx_item}: outside bit of frame {a} is {start_outside}, expected 0 or 1')

In [251]:
def val_reverse_nonsense_101(pred, idx_item, now_i):
    """Decide the corrected row for frame ``now_i`` inside a nonsense "101" interval.

    Frames whose nonsense bit (column 1) is 0 become nonsense
    ``[0, 1, 0, 0, 0, 0, 0]``; frames already flagged nonsense are kept.
    Each decision is printed.

    Args:
        pred: 0/1 label tensor indexed ``pred[frame][class]``.
        idx_item: the inclusive [start, end] interval being fixed (only printed).
        now_i: frame index inside the interval.

    Returns:
        The new row as a plain list of ints (same type on every path).

    Raises:
        ValueError: if the nonsense bit is neither 0 nor 1.
    """
    nonsense_bit = pred[now_i][1].item()
    if nonsense_bit == 0:
        print(f'Fix {idx_item}: {now_i}=[0, 1, 0, 0, 0, 0, 0]')
        return [0, 1, 0, 0, 0, 0, 0]
    if nonsense_bit == 1:
        print(f'Fix {idx_item}: {now_i} not change')
        # tolist() keeps the return type consistent and avoids torch.tensor(tensor) warnings.
        return pred[now_i].tolist()
    raise ValueError(f'interval {idx_item}: nonsense bit of frame {now_i} is {nonsense_bit}, expected 0 or 1')

In [252]:
# Working copy of the predictions that the outlier fixes are applied to.
pred_fix = pred_labels.detach().clone()

In [253]:
# Apply the nonsense "101" fixes first; the outside fixes in the next cell build on this result.
pred_fix = fix_outliers_idx(
    pred_origin=pred_fix,
    outlier_list=nonsense_outliers,
    val_func=val_reverse_nonsense_101,
)

Fix [178, 178]: 178=[0, 1, 0, 0, 0, 0, 0]
Fix [430, 430]: 430=[0, 1, 0, 0, 0, 0, 0]
Fix [460, 460]: 460=[0, 1, 0, 0, 0, 0, 0]
Fix [507, 507]: 507=[0, 1, 0, 0, 0, 0, 0]
Fix [594, 594]: 594=[0, 1, 0, 0, 0, 0, 0]
Fix [597, 597]: 597=[0, 1, 0, 0, 0, 0, 0]
Fix [611, 611]: 611=[0, 1, 0, 0, 0, 0, 0]
Fix [633, 633]: 633=[0, 1, 0, 0, 0, 0, 0]
Fix [676, 676]: 676=[0, 1, 0, 0, 0, 0, 0]
Fix [678, 679]: 678=[0, 1, 0, 0, 0, 0, 0]
Fix [678, 679]: 679=[0, 1, 0, 0, 0, 0, 0]
Fix [684, 684]: 684=[0, 1, 0, 0, 0, 0, 0]
Fix [686, 687]: 686=[0, 1, 0, 0, 0, 0, 0]
Fix [686, 687]: 687=[0, 1, 0, 0, 0, 0, 0]
Fix [694, 695]: 694=[0, 1, 0, 0, 0, 0, 0]
Fix [694, 695]: 695=[0, 1, 0, 0, 0, 0, 0]
Fix [713, 713]: 713=[0, 1, 0, 0, 0, 0, 0]
Fix [715, 716]: 715=[0, 1, 0, 0, 0, 0, 0]
Fix [715, 716]: 716=[0, 1, 0, 0, 0, 0, 0]
Fix [730, 731]: 730=[0, 1, 0, 0, 0, 0, 0]
Fix [730, 731]: 731=[0, 1, 0, 0, 0, 0, 0]
Fix [741, 741]: 741=[0, 1, 0, 0, 0, 0, 0]
Fix [784, 784]: 784=[0, 1, 0, 0, 0, 0, 0]
Fix [864, 864]: 864=[0, 1, 0, 0, 0

In [254]:
# Apply the merged outside fixes on top of the nonsense-fixed predictions.
pred_fix = fix_outliers_idx(
    pred_origin=pred_fix,
    outlier_list=outside_outliers,
    val_func=val_reverse_outside,
)

Fix [178, 178]: 178 not change
Fix [430, 430]: 430 not change
Fix [450, 450]: 450=[0, 1, 0, 0, 0, 0, 0]
Fix [458, 460]: 458=[0, 1, 0, 0, 0, 0, 0]
Fix [458, 460]: 459=[0, 1, 0, 0, 0, 0, 0]
Fix [458, 460]: 460 not change
Fix [506, 509]: 506=[0, 1, 0, 0, 0, 0, 0]
Fix [506, 509]: 507 not change
Fix [506, 509]: 508=[0, 1, 0, 0, 0, 0, 0]
Fix [506, 509]: 509=[0, 1, 0, 0, 0, 0, 0]
Fix [654, 655]: 654=[0, 1, 0, 0, 0, 0, 0]
Fix [654, 655]: 655=[0, 1, 0, 0, 0, 0, 0]
Fix [680, 680]: 680=[0, 1, 0, 0, 0, 0, 0]
Fix [1757, 1757]: 1757=[0, 1, 0, 0, 0, 0, 0]
Fix [1787, 1790]: 1787=[0, 1, 0, 0, 0, 0, 0]
Fix [1787, 1790]: 1788 not change
Fix [1787, 1790]: 1789=[0, 1, 0, 0, 0, 0, 0]
Fix [1787, 1790]: 1790=[0, 1, 0, 0, 0, 0, 0]
Fix [1805, 1807]: 1805=[0, 1, 0, 0, 0, 0, 0]
Fix [1805, 1807]: 1806 not change
Fix [1805, 1807]: 1807=[0, 1, 0, 0, 0, 0, 0]
Fix [1817, 1817]: 1817=[0, 1, 0, 0, 0, 0, 0]
Fix [1832, 1832]: 1832=[0, 1, 0, 0, 0, 0, 0]
Fix [1837, 1837]: 1837=[0, 1, 0, 0, 0, 0, 0]
Fix [1903, 1903]: 1903=[0

  pred_fix[now_i] = torch.tensor(val_func(pred_origin, item, now_i))


In [255]:
# Replace the loaded predictions with the fixed ones for the metrics below.
# NOTE(review): pred_ori keeps the unfixed copy; re-running In[252] after this
# cell starts from the already-fixed labels.
pred_labels = pred_fix.detach().clone()

In [256]:
# Zeroed counters: TP/FP/FN/TN for the three binary heads (slots 0-2), then a
# 4x4 pred-vs-gt table for the cleansing grades (slots 3-6).
confuse_matrix = {
    f'label_{index_label[head]}_{stat}': 0.
    for head in range(3)
    for stat in ('TP', 'FP', 'FN', 'TN')
}
confuse_matrix.update({
    f'label_cleansing_pred_{index_label[p]}_gt_{index_label[g]}': 0.
    for p in range(3, 7)  # p: predicted grade slot
    for g in range(3, 7)  # g: ground-truth grade slot
})

In [257]:
# 计算test_acc
mean_acc = float(torch.eq(pred_labels, gt_labels).float().mean().cpu())
mean_acc

0.9793775677680969

In [258]:
# Outside (in-body vs. out-of-body) head: slot 0 of the label vectors.
in_out_labels = pred_labels[:, 0]
# BoolTensor[N]: True where the frame is predicted outside.
label_in_out_pred = torch.gt(in_out_labels, 0)
# BoolTensor[N]: True where the ground truth is outside.
label_in_out_gt = torch.gt(gt_labels[:, 0], 0)
# Assign (not +=) so re-running this cell does not double-count.
for stat, mask in (
    ('TP', label_in_out_pred & label_in_out_gt),
    ('FP', label_in_out_pred & ~label_in_out_gt),
    ('FN', ~label_in_out_pred & label_in_out_gt),
    ('TN', ~label_in_out_pred & ~label_in_out_gt),
):
    confuse_matrix[f'label_{index_label[0]}_{stat}'] = float(mask.sum())


In [259]:
# Nonsense (bad-frame) head: slot 1 of the label vectors.
nonsense_logit = pred_labels[:, 1]
# BoolTensor[N]: True where the frame is predicted nonsense.
label_nonsense_pred = torch.gt(nonsense_logit, 0)
# BoolTensor[N]: True where the ground truth is nonsense.
label_nonsense_gt = torch.gt(gt_labels[:, 1], 0)
# Frames that are outside in either pred or gt are not counted.
flag = ~label_in_out_pred & ~label_in_out_gt
# Assign (not +=) so re-running this cell does not double-count.
for stat, mask in (
    ('TP', label_nonsense_pred & label_nonsense_gt),
    ('FP', label_nonsense_pred & ~label_nonsense_gt),
    ('FN', ~label_nonsense_pred & label_nonsense_gt),
    ('TN', ~label_nonsense_pred & ~label_nonsense_gt),
):
    confuse_matrix[f'label_{index_label[1]}_{stat}'] = float((flag & mask).sum())


In [260]:
# Ileocecal head: slot 2 of the label vectors.
ileo_logit = pred_labels[:, 2]
# BoolTensor[N]: True where the frame is predicted ileocecal.
label_ileo_pred = torch.gt(ileo_logit, 0)
# BoolTensor[N]: True where the ground truth is ileocecal.
label_ileo_gt = torch.gt(gt_labels[:, 2], 0)
# Frames that are outside or nonsense in either pred or gt are not counted.
flag = ~label_in_out_pred & ~label_in_out_gt & ~label_nonsense_pred & ~label_nonsense_gt
# Assign (not +=) so re-running this cell does not double-count.
for stat, mask in (
    ('TP', label_ileo_pred & label_ileo_gt),
    ('FP', label_ileo_pred & ~label_ileo_gt),
    ('FN', ~label_ileo_pred & label_ileo_gt),
    ('TN', ~label_ileo_pred & ~label_ileo_gt),
):
    confuse_matrix[f'label_{index_label[2]}_{stat}'] = float((flag & mask).sum())


In [261]:
# 清洁度logit: FloatTensor[B, 4]
cls_logit = pred_labels[:, 3:]
# 清洁度label: IntTensor[B] (取预测值最大的，但会被outside标签抑制)
label_cls_pred = torch.argmax(cls_logit, dim=-1)
# 清洁度gt: IntTensor[B]
label_cls_gt = torch.argmax(gt_labels[:, 3:], dim=-1)
flag = ~label_in_out_pred & ~label_in_out_gt & ~label_nonsense_pred & ~label_nonsense_gt
for i in range(0, 4):  # i: predict
    for j in range(0, 4):  # j: gt
        confuse_matrix[f'label_cleansing_pred_{index_label[i + 3]}_gt_{index_label[j + 3]}'] += float((flag & torch.eq(label_cls_pred, i) & torch.eq(label_cls_gt, j)).float().sum().cpu())  # flag用于清洁度标签抑制


In [262]:
# Raw counts: TP/FP/FN/TN for the three binary heads and the 4x4 cleansing (pred x gt) table.
confuse_matrix

{'label_outside_TP': 471.0,
 'label_outside_FP': 5.0,
 'label_outside_FN': 45.0,
 'label_outside_TN': 4813.0,
 'label_nonsense_TP': 1372.0,
 'label_nonsense_FP': 36.0,
 'label_nonsense_FN': 69.0,
 'label_nonsense_TN': 3336.0,
 'label_ileocecal_TP': 119.0,
 'label_ileocecal_FP': 3.0,
 'label_ileocecal_FN': 5.0,
 'label_ileocecal_TN': 3209.0,
 'label_cleansing_pred_bbps0_gt_bbps0': 31.0,
 'label_cleansing_pred_bbps0_gt_bbps1': 1.0,
 'label_cleansing_pred_bbps0_gt_bbps2': 61.0,
 'label_cleansing_pred_bbps0_gt_bbps3': 3.0,
 'label_cleansing_pred_bbps1_gt_bbps0': 0.0,
 'label_cleansing_pred_bbps1_gt_bbps1': 169.0,
 'label_cleansing_pred_bbps1_gt_bbps2': 9.0,
 'label_cleansing_pred_bbps1_gt_bbps3': 2.0,
 'label_cleansing_pred_bbps2_gt_bbps0': 0.0,
 'label_cleansing_pred_bbps2_gt_bbps1': 15.0,
 'label_cleansing_pred_bbps2_gt_bbps2': 1264.0,
 'label_cleansing_pred_bbps2_gt_bbps3': 88.0,
 'label_cleansing_pred_bbps3_gt_bbps0': 0.0,
 'label_cleansing_pred_bbps3_gt_bbps1': 0.0,
 'label_cleansing_

In [263]:
metrics = dict()

# Binary heads -- 0: outside, 1: nonsense (frame quality), 2: ileocecal.
# Accuracy = (TP + TN) / all counted frames, or 0 when nothing was counted.
for head in range(3):
    head_name = index_label[head]
    TP, FP, FN, TN = (confuse_matrix[f'label_{head_name}_{stat}'] for stat in ('TP', 'FP', 'FN', 'TN'))
    denom = TP + FP + FN + TN
    metrics[f'label_{head_name}_acc'] = (TP + TN) / denom if denom > 0. else 0.

# 4-class cleansing accuracy: diagonal of the pred x gt table over its total.
cells = {
    (p, g): confuse_matrix[f'label_cleansing_pred_{index_label[p]}_gt_{index_label[g]}']
    for p in range(3, 7)  # p: predict
    for g in range(3, 7)  # g: gt
}
total = sum(cells.values())
correct = sum(count for (p, g), count in cells.items() if p == g)
metrics['label_cleansing_acc'] = correct / total if total > 0. else 0.

In [264]:
# Per-head accuracies derived from confuse_matrix.
metrics

{'label_outside_acc': 0.990626171728534,
 'label_nonsense_acc': 0.9781840847704134,
 'label_ileocecal_acc': 0.9976019184652278,
 'label_cleansing_acc': 0.9322541966426858}

In [237]:
# Print every frame whose prediction differs from the ground truth in any slot.
mismatch_rows = (pred_labels != gt_labels).any(dim=1)
for i in torch.nonzero(mismatch_rows).flatten().tolist():
    print(f"{i}:{pred_labels[i].numpy()} != {gt_labels[i].numpy()}")

172:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
173:[0 0 0 1 0 0 0] != [1 0 0 0 0 0 0]
174:[0 0 0 1 0 0 0] != [1 0 0 0 0 0 0]
178:[0 1 0 0 0 0 0] != [1 0 0 0 0 0 0]
430:[0 1 0 0 0 0 0] != [1 0 0 0 0 0 0]
446:[0 0 0 0 0 1 0] != [1 0 0 0 0 0 0]
447:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
448:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
449:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
451:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
452:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
453:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
454:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
455:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
456:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
457:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
460:[0 1 0 0 0 0 0] != [1 0 0 0 0 0 0]
466:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
467:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
468:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
469:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
470:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
471:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
472:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
473:[0 0 0 0 0 0 1] != [1 0 0 0 0 0 0]
474:[0 0 0 0 0 0 1] != [1